-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #245 from flatironinstitute/takeout-scripts
Add the option to download 'takeout scripts'
- Loading branch information
Showing
20 changed files
with
683 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
tryCatch({ | ||
cmdstanr::cmdstan_path() | ||
}, error = function(e) { | ||
if ("--install-cmdstan" %in% args) { | ||
cmdstanr::install_cmdstan() | ||
} else { | ||
stop("cmdstan not found, use --install-cmdstan to install") | ||
} | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
.option_names_map = c( | ||
init_radius="init", | ||
num_warmup="iter_warmup", | ||
num_samples="iter_sampling", | ||
num_chains="chains" | ||
) | ||
|
||
sampling_opts <- list() | ||
|
||
if (file.exists("./sampling_opts.json")) { | ||
opts <- jsonlite::fromJSON("./sampling_opts.json") | ||
for (key in names(opts)) { | ||
out_key <- key | ||
if (key %in% names(.option_names_map)) { | ||
out_key <- .option_names_map[[key]] | ||
} | ||
sampling_opts[[out_key]] <- opts[[key]] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
options(repos = c('https://cloud.r-project.org/')) | ||
if (!require("posterior")) { | ||
install.packages("posterior") | ||
} | ||
if (!require("cmdstanr")) { | ||
install.packages("cmdstanr", repos = c('https://stan-dev.r-universe.dev', getOption("repos"))) | ||
} | ||
if (!require("jsonlite")) { | ||
install.packages("jsonlite") | ||
} | ||
|
||
library(cmdstanr) | ||
library(jsonlite) | ||
|
||
args <- commandArgs(trailingOnly = TRUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
draws <- fit$draws(format="draws_array") | ||
|
||
grDevices::pdf(onefile=FALSE) | ||
source("analysis.R", local=TRUE, print.eval=TRUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
source("data.R", local=TRUE, print.eval=TRUE) | ||
if (typeof(data) != "list") { | ||
stop("[stan-playground] data must be a list") | ||
} | ||
data <- list(data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
print("compiling model") | ||
model = cmdstanr::cmdstan_model("./main.stan") | ||
|
||
print("sampling") | ||
fit = do.call(model$sample, as.list(c(data=data, sampling_opts))) | ||
|
||
print(fit$summary()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import { ProjectDataModel } from "@SpCore/Project/ProjectDataModel"; | ||
import indent from "@SpUtil/indent"; | ||
|
||
const makeRuntimeScript = async ( | ||
project: ProjectDataModel, | ||
type: "R" | "py", | ||
) => { | ||
const { | ||
spPreamble, | ||
spRunData, | ||
spLoadConfig, | ||
spCmdStan, | ||
spRunSampling, | ||
spLoadDraws, | ||
spRunAnalysis, | ||
} = await loadParts(type); | ||
|
||
const keyType = type === "R" ? "R" : "Py"; | ||
|
||
const hasDataJson = project.dataFileContent.length > 0; | ||
const hasDataScript = project[`data${keyType}FileContent`].length > 0; | ||
const hasAnalysisScript = project[`analysis${keyType}FileContent`].length > 0; | ||
|
||
let script = | ||
`TITLE = ${JSON.stringify(project.meta.title)}\n` + spPreamble + "\n"; | ||
|
||
// data | ||
if (hasDataJson && hasDataScript) { | ||
const { start, middle, end } = checkIfIgnoreData(type); | ||
script += start; | ||
script += indent(spRunData); | ||
script += middle; | ||
script += ` print("Loading data from data.json, pass --ignore-saved-data to run data.${type} instead")\n`; | ||
script += indent(loadDataFromFile(type)); | ||
script += end; | ||
} else if (hasDataJson) { | ||
script += loadDataFromFile(type); | ||
} else if (hasDataScript) { | ||
script += spRunData; | ||
script += "\n"; | ||
} else { | ||
script += `data = ""\n\n`; | ||
} | ||
|
||
// running sampler | ||
script += spLoadConfig; | ||
script += "\n"; | ||
script += spCmdStan; | ||
script += "\n"; | ||
script += spRunSampling; | ||
|
||
// analysis | ||
if (hasAnalysisScript) { | ||
script += "\n"; | ||
script += spLoadDraws; | ||
script += "\n"; | ||
script += spRunAnalysis; | ||
} | ||
|
||
return script; | ||
}; | ||
|
||
const loadParts = async (type: "R" | "py") => { | ||
const [ | ||
spPreamble, | ||
spRunData, | ||
spLoadConfig, | ||
spCmdStan, | ||
spRunSampling, | ||
spLoadDraws, | ||
spRunAnalysis, | ||
] = await Promise.all([ | ||
import(`./${type}/preamble.${type}?raw`).then((m) => m.default), | ||
import(`./${type}/run_data.${type}?raw`).then((m) => m.default), | ||
import(`./${type}/load_args.${type}?raw`).then((m) => m.default), | ||
import(`./${type}/cmdstan.${type}?raw`).then((m) => m.default), | ||
import(`./${type}/sample.${type}?raw`).then((m) => m.default), | ||
type === "py" | ||
? import(`../pyodide/sp_load_draws.py?raw`).then((m) => m.default) | ||
: Promise.resolve("") /* R uses posterior, no custom script needed */, | ||
import(`./${type}/run_analysis.${type}?raw`).then((m) => m.default), | ||
]); | ||
|
||
return { | ||
spPreamble, | ||
spRunData, | ||
spLoadConfig, | ||
spCmdStan, | ||
spRunSampling, | ||
spLoadDraws, | ||
spRunAnalysis, | ||
}; | ||
}; | ||
|
||
const loadDataFromFile = (type: "R" | "py") => { | ||
if (type === "R") { | ||
return `data <- "./data.json"\n`; | ||
} | ||
return `data = os.path.join(HERE, 'data.json')\n`; | ||
}; | ||
|
||
const checkIfIgnoreData = (type: "R" | "py") => { | ||
if (type === "R") { | ||
const start = `if ("--ignore-saved-data" %in% args) {\n`; | ||
const middle = `\n} else {\n`; | ||
const end = `\n}\n\n`; | ||
return { start, middle, end }; | ||
} | ||
return { | ||
start: `if args.ignore_saved_data:\n`, | ||
middle: "\nelse:\n", | ||
end: "\n\n", | ||
}; | ||
}; | ||
|
||
export default makeRuntimeScript; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
try: | ||
cmdstanpy.cmdstan_path() | ||
except Exception: | ||
if args.install_cmdstan: | ||
cmdstanpy.install_cmdstan() | ||
else: | ||
raise ValueError("cmdstan not found, use --install-cmdstan to install") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
_option_names_map = { | ||
"init_radius": "inits", | ||
"num_warmup": "iter_warmup", | ||
"num_samples": "iter_sampling", | ||
"num_chains": "chains", | ||
} | ||
|
||
if os.path.isfile(os.path.join(HERE, "sampling_opts.json")): | ||
print("loading sampling_opts.json") | ||
with open(os.path.join(HERE, "sampling_opts.json")) as f: | ||
s = json.load(f) | ||
sampling_opts = {_option_names_map.get(k, k): v for k, v in s.items()} | ||
else: | ||
sampling_opts = {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import argparse | ||
import json | ||
import os | ||
|
||
import cmdstanpy | ||
|
||
HERE = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
argparser = argparse.ArgumentParser(prog=f"Stan-Playground: {TITLE}") | ||
argparser.add_argument("--install-cmdstan", action="store_true", help="Install cmdstan if it is missing") | ||
argparser.add_argument("--ignore-saved-data", action="store_true", help="Ignore saved data.json files") | ||
args, _ = argparser.parse_known_args() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import matplotlib.pyplot as plt | ||
|
||
print("executing analysis.py") | ||
|
||
sp_data = { | ||
"draws": fit.draws(concat_chains=True).transpose(), | ||
"paramNames": fit.metadata.cmdstan_config["raw_header"].split(","), | ||
"numChains": fit.chains, | ||
} | ||
|
||
draws = sp_load_draws(sp_data) | ||
del sp_data | ||
|
||
with open(os.path.join(HERE, "analysis.py")) as f: | ||
exec(f.read()) | ||
|
||
if len(plt.gcf().get_children()) > 1: | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
print("executing data.py") | ||
with open(os.path.join(HERE, "data.py")) as f: | ||
exec(f.read()) | ||
if "data" not in locals(): | ||
raise ValueError("data variable not defined in data.py") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
print("compiling model") | ||
model = cmdstanpy.CmdStanModel(stan_file=os.path.join(HERE, "main.stan")) | ||
|
||
print("sampling") | ||
fit = model.sample(data=data, **sampling_opts) | ||
|
||
print(fit.summary()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
const indent = (s: string) => { | ||
return s | ||
.trim() | ||
.split("\n") | ||
.map((x) => " " + x) | ||
.join("\n"); | ||
}; | ||
|
||
export default indent; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import JSZip from "jszip"; | ||
|
||
export const serializeAsZip = async ( | ||
folderName: string, | ||
files: { [key: string]: string } = {}, | ||
): Promise<[Blob, string]> => { | ||
const zip = new JSZip(); | ||
const folder = zip.folder(folderName); | ||
if (!folder) { | ||
throw new Error("Error creating folder in zip file"); | ||
} | ||
|
||
Object.entries(files).forEach(([name, content]) => { | ||
if (content.trim() !== "") { | ||
folder.file(name, content); | ||
} | ||
}); | ||
const zipBlob = await zip.generateAsync({ type: "blob" }); | ||
|
||
return [zipBlob, folderName]; | ||
}; |
Oops, something went wrong.