Skip to content

Commit

Permalink
Merge pull request #245 from flatironinstitute/takeout-scripts
Browse files Browse the repository at this point in the history
Add the option to download 'takeout scripts'
  • Loading branch information
WardBrian authored Feb 24, 2025
2 parents f05874f + 4bd4294 commit 6e9286f
Show file tree
Hide file tree
Showing 20 changed files with 683 additions and 28 deletions.
24 changes: 1 addition & 23 deletions gui/src/app/core/Project/ProjectSerialization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
FileRegistry,
ProjectFileMap,
mapFileContentsToModel,
mapModelToFileManifest,
} from "@SpCore/Project/FileMapping";
import {
ProjectDataModel,
Expand All @@ -16,7 +15,6 @@ import {
parseSamplingOpts,
persistStateToEphemera,
} from "@SpCore/Project/ProjectDataModel";
import { replaceSpacesWithUnderscores } from "@SpUtil/replaceSpaces";
import JSZip from "jszip";

export const serializeProjectToLocalStorage = (
Expand Down Expand Up @@ -49,26 +47,6 @@ export const deserializeProjectFromLocalStorage = (
}
};

export const serializeAsZip = async (
data: ProjectDataModel,
): Promise<[Blob, string]> => {
const fileManifest = mapModelToFileManifest(data);
const folderName = replaceSpacesWithUnderscores(data.meta.title);
const zip = new JSZip();
const folder = zip.folder(folderName);
if (!folder) {
throw new Error("Error creating folder in zip file");
}
Object.entries(fileManifest).forEach(([name, content]) => {
if (content.trim() !== "") {
folder.file(name, content);
}
});
const zipBlob = await zip.generateAsync({ type: "blob" });

return [zipBlob, folderName];
};

export const parseFile = (fileBuffer: ArrayBuffer) => {
const content = new TextDecoder().decode(fileBuffer);
return content;
Expand Down Expand Up @@ -98,7 +76,7 @@ export const deserializeZipToFiles = async (zipBuffer: ArrayBuffer) => {
const content = await file.async("arraybuffer");
const decoded = new TextDecoder().decode(content);
files[basename] = decoded;
} else {
} else if (!["run.R", "run.py"].includes(basename)) {
throw new Error(
`Unrecognized file in zip: ${file.name} (basename ${basename})`,
);
Expand Down
9 changes: 9 additions & 0 deletions gui/src/app/core/Scripting/Takeout/R/cmdstan.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
tryCatch({
cmdstanr::cmdstan_path()
}, error = function(e) {
if ("--install-cmdstan" %in% args) {
cmdstanr::install_cmdstan()
} else {
stop("cmdstan not found, use --install-cmdstan to install")
}
})
19 changes: 19 additions & 0 deletions gui/src/app/core/Scripting/Takeout/R/load_args.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
.option_names_map = c(
init_radius="init",
num_warmup="iter_warmup",
num_samples="iter_sampling",
num_chains="chains"
)

sampling_opts <- list()

if (file.exists("./sampling_opts.json")) {
opts <- jsonlite::fromJSON("./sampling_opts.json")
for (key in names(opts)) {
out_key <- key
if (key %in% names(.option_names_map)) {
out_key <- .option_names_map[[key]]
}
sampling_opts[[out_key]] <- opts[[key]]
}
}
15 changes: 15 additions & 0 deletions gui/src/app/core/Scripting/Takeout/R/preamble.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
options(repos = c('https://cloud.r-project.org/'))
if (!require("posterior")) {
install.packages("posterior")
}
if (!require("cmdstanr")) {
install.packages("cmdstanr", repos = c('https://stan-dev.r-universe.dev', getOption("repos")))
}
if (!require("jsonlite")) {
install.packages("jsonlite")
}

library(cmdstanr)
library(jsonlite)

args <- commandArgs(trailingOnly = TRUE)
4 changes: 4 additions & 0 deletions gui/src/app/core/Scripting/Takeout/R/run_analysis.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
draws <- fit$draws(format="draws_array")

grDevices::pdf(onefile=FALSE)
source("analysis.R", local=TRUE, print.eval=TRUE)
5 changes: 5 additions & 0 deletions gui/src/app/core/Scripting/Takeout/R/run_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
source("data.R", local=TRUE, print.eval=TRUE)
if (typeof(data) != "list") {
stop("[stan-playground] data must be a list")
}
data <- list(data)
7 changes: 7 additions & 0 deletions gui/src/app/core/Scripting/Takeout/R/sample.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
print("compiling model")
model = cmdstanr::cmdstan_model("./main.stan")

print("sampling")
fit = do.call(model$sample, as.list(c(data=data, sampling_opts)))

print(fit$summary())
116 changes: 116 additions & 0 deletions gui/src/app/core/Scripting/Takeout/makeRuntime.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import { ProjectDataModel } from "@SpCore/Project/ProjectDataModel";
import indent from "@SpUtil/indent";

const makeRuntimeScript = async (
project: ProjectDataModel,
type: "R" | "py",
) => {
const {
spPreamble,
spRunData,
spLoadConfig,
spCmdStan,
spRunSampling,
spLoadDraws,
spRunAnalysis,
} = await loadParts(type);

const keyType = type === "R" ? "R" : "Py";

const hasDataJson = project.dataFileContent.length > 0;
const hasDataScript = project[`data${keyType}FileContent`].length > 0;
const hasAnalysisScript = project[`analysis${keyType}FileContent`].length > 0;

let script =
`TITLE = ${JSON.stringify(project.meta.title)}\n` + spPreamble + "\n";

// data
if (hasDataJson && hasDataScript) {
const { start, middle, end } = checkIfIgnoreData(type);
script += start;
script += indent(spRunData);
script += middle;
script += ` print("Loading data from data.json, pass --ignore-saved-data to run data.${type} instead")\n`;
script += indent(loadDataFromFile(type));
script += end;
} else if (hasDataJson) {
script += loadDataFromFile(type);
} else if (hasDataScript) {
script += spRunData;
script += "\n";
} else {
script += `data = ""\n\n`;
}

// running sampler
script += spLoadConfig;
script += "\n";
script += spCmdStan;
script += "\n";
script += spRunSampling;

// analysis
if (hasAnalysisScript) {
script += "\n";
script += spLoadDraws;
script += "\n";
script += spRunAnalysis;
}

return script;
};

const loadParts = async (type: "R" | "py") => {
const [
spPreamble,
spRunData,
spLoadConfig,
spCmdStan,
spRunSampling,
spLoadDraws,
spRunAnalysis,
] = await Promise.all([
import(`./${type}/preamble.${type}?raw`).then((m) => m.default),
import(`./${type}/run_data.${type}?raw`).then((m) => m.default),
import(`./${type}/load_args.${type}?raw`).then((m) => m.default),
import(`./${type}/cmdstan.${type}?raw`).then((m) => m.default),
import(`./${type}/sample.${type}?raw`).then((m) => m.default),
type === "py"
? import(`../pyodide/sp_load_draws.py?raw`).then((m) => m.default)
: Promise.resolve("") /* R uses posterior, no custom script needed */,
import(`./${type}/run_analysis.${type}?raw`).then((m) => m.default),
]);

return {
spPreamble,
spRunData,
spLoadConfig,
spCmdStan,
spRunSampling,
spLoadDraws,
spRunAnalysis,
};
};

const loadDataFromFile = (type: "R" | "py") => {
if (type === "R") {
return `data <- "./data.json"\n`;
}
return `data = os.path.join(HERE, 'data.json')\n`;
};

const checkIfIgnoreData = (type: "R" | "py") => {
if (type === "R") {
const start = `if ("--ignore-saved-data" %in% args) {\n`;
const middle = `\n} else {\n`;
const end = `\n}\n\n`;
return { start, middle, end };
}
return {
start: `if args.ignore_saved_data:\n`,
middle: "\nelse:\n",
end: "\n\n",
};
};

export default makeRuntimeScript;
7 changes: 7 additions & 0 deletions gui/src/app/core/Scripting/Takeout/py/cmdstan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
try:
cmdstanpy.cmdstan_path()
except Exception:
if args.install_cmdstan:
cmdstanpy.install_cmdstan()
else:
raise ValueError("cmdstan not found, use --install-cmdstan to install")
14 changes: 14 additions & 0 deletions gui/src/app/core/Scripting/Takeout/py/load_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
_option_names_map = {
"init_radius": "inits",
"num_warmup": "iter_warmup",
"num_samples": "iter_sampling",
"num_chains": "chains",
}

if os.path.isfile(os.path.join(HERE, "sampling_opts.json")):
print("loading sampling_opts.json")
with open(os.path.join(HERE, "sampling_opts.json")) as f:
s = json.load(f)
sampling_opts = {_option_names_map.get(k, k): v for k, v in s.items()}
else:
sampling_opts = {}
12 changes: 12 additions & 0 deletions gui/src/app/core/Scripting/Takeout/py/preamble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import argparse
import json
import os

import cmdstanpy

HERE = os.path.dirname(os.path.abspath(__file__))

argparser = argparse.ArgumentParser(prog=f"Stan-Playground: {TITLE}")
argparser.add_argument("--install-cmdstan", action="store_true", help="Install cmdstan if it is missing")
argparser.add_argument("--ignore-saved-data", action="store_true", help="Ignore saved data.json files")
args, _ = argparser.parse_known_args()
18 changes: 18 additions & 0 deletions gui/src/app/core/Scripting/Takeout/py/run_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import matplotlib.pyplot as plt

print("executing analysis.py")

sp_data = {
"draws": fit.draws(concat_chains=True).transpose(),
"paramNames": fit.metadata.cmdstan_config["raw_header"].split(","),
"numChains": fit.chains,
}

draws = sp_load_draws(sp_data)
del sp_data

with open(os.path.join(HERE, "analysis.py")) as f:
exec(f.read())

if len(plt.gcf().get_children()) > 1:
plt.show()
5 changes: 5 additions & 0 deletions gui/src/app/core/Scripting/Takeout/py/run_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
print("executing data.py")
with open(os.path.join(HERE, "data.py")) as f:
exec(f.read())
if "data" not in locals():
raise ValueError("data variable not defined in data.py")
7 changes: 7 additions & 0 deletions gui/src/app/core/Scripting/Takeout/py/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
print("compiling model")
model = cmdstanpy.CmdStanModel(stan_file=os.path.join(HERE, "main.stan"))

print("sampling")
fit = model.sample(data=data, **sampling_opts)

print(fit.summary())
9 changes: 9 additions & 0 deletions gui/src/app/util/indent.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
const indent = (s: string) => {
return s
.trim()
.split("\n")
.map((x) => " " + x)
.join("\n");
};

export default indent;
21 changes: 21 additions & 0 deletions gui/src/app/util/serializeAsZip.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import JSZip from "jszip";

export const serializeAsZip = async (
folderName: string,
files: { [key: string]: string } = {},
): Promise<[Blob, string]> => {
const zip = new JSZip();
const folder = zip.folder(folderName);
if (!folder) {
throw new Error("Error creating folder in zip file");
}

Object.entries(files).forEach(([name, content]) => {
if (content.trim() !== "") {
folder.file(name, content);
}
});
const zipBlob = await zip.generateAsync({ type: "blob" });

return [zipBlob, folderName];
};
Loading

0 comments on commit 6e9286f

Please sign in to comment.