diff --git a/gui/src/app/core/Project/ProjectSerialization.ts b/gui/src/app/core/Project/ProjectSerialization.ts index 68cd7a9a..d7b1e2fd 100644 --- a/gui/src/app/core/Project/ProjectSerialization.ts +++ b/gui/src/app/core/Project/ProjectSerialization.ts @@ -4,7 +4,6 @@ import { FileRegistry, ProjectFileMap, mapFileContentsToModel, - mapModelToFileManifest, } from "@SpCore/Project/FileMapping"; import { ProjectDataModel, @@ -16,7 +15,6 @@ import { parseSamplingOpts, persistStateToEphemera, } from "@SpCore/Project/ProjectDataModel"; -import { replaceSpacesWithUnderscores } from "@SpUtil/replaceSpaces"; import JSZip from "jszip"; export const serializeProjectToLocalStorage = ( @@ -49,26 +47,6 @@ export const deserializeProjectFromLocalStorage = ( } }; -export const serializeAsZip = async ( - data: ProjectDataModel, -): Promise<[Blob, string]> => { - const fileManifest = mapModelToFileManifest(data); - const folderName = replaceSpacesWithUnderscores(data.meta.title); - const zip = new JSZip(); - const folder = zip.folder(folderName); - if (!folder) { - throw new Error("Error creating folder in zip file"); - } - Object.entries(fileManifest).forEach(([name, content]) => { - if (content.trim() !== "") { - folder.file(name, content); - } - }); - const zipBlob = await zip.generateAsync({ type: "blob" }); - - return [zipBlob, folderName]; -}; - export const parseFile = (fileBuffer: ArrayBuffer) => { const content = new TextDecoder().decode(fileBuffer); return content; @@ -98,7 +76,7 @@ export const deserializeZipToFiles = async (zipBuffer: ArrayBuffer) => { const content = await file.async("arraybuffer"); const decoded = new TextDecoder().decode(content); files[basename] = decoded; - } else { + } else if (!["run.R", "run.py"].includes(basename)) { throw new Error( `Unrecognized file in zip: ${file.name} (basename ${basename})`, ); diff --git a/gui/src/app/core/Scripting/Takeout/R/cmdstan.R b/gui/src/app/core/Scripting/Takeout/R/cmdstan.R new file mode 100644 index 00000000..7e46109c --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/R/cmdstan.R @@ -0,0 +1,9 @@ +tryCatch({ + cmdstanr::cmdstan_path() +}, error = function(e) { + if ("--install-cmdstan" %in% args) { + cmdstanr::install_cmdstan() + } else { + stop("cmdstan not found, use --install-cmdstan to install") + } +}) diff --git a/gui/src/app/core/Scripting/Takeout/R/load_args.R b/gui/src/app/core/Scripting/Takeout/R/load_args.R new file mode 100644 index 00000000..1593d7b7 --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/R/load_args.R @@ -0,0 +1,19 @@ +.option_names_map = c( + init_radius="init", + num_warmup="iter_warmup", + num_samples="iter_sampling", + num_chains="chains" +) + +sampling_opts <- list() + +if (file.exists("./sampling_opts.json")) { + opts <- jsonlite::fromJSON("./sampling_opts.json") + for (key in names(opts)) { + out_key <- key + if (key %in% names(.option_names_map)) { + out_key <- .option_names_map[[key]] + } + sampling_opts[[out_key]] <- opts[[key]] + } +} diff --git a/gui/src/app/core/Scripting/Takeout/R/preamble.R b/gui/src/app/core/Scripting/Takeout/R/preamble.R new file mode 100644 index 00000000..102aa12a --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/R/preamble.R @@ -0,0 +1,15 @@ +options(repos = c('https://cloud.r-project.org/')) +if (!require("posterior")) { + install.packages("posterior") +} +if (!require("cmdstanr")) { + install.packages("cmdstanr", repos = c('https://stan-dev.r-universe.dev', getOption("repos"))) +} +if (!require("jsonlite")) { + install.packages("jsonlite") +} + +library(cmdstanr) +library(jsonlite) + +args <- commandArgs(trailingOnly = TRUE) diff --git a/gui/src/app/core/Scripting/Takeout/R/run_analysis.R b/gui/src/app/core/Scripting/Takeout/R/run_analysis.R new file mode 100644 index 00000000..beff1cf0 --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/R/run_analysis.R @@ -0,0 +1,4 @@ +draws <- fit$draws(format="draws_array") + +grDevices::pdf(onefile=FALSE) +source("analysis.R", local=TRUE, print.eval=TRUE) diff --git a/gui/src/app/core/Scripting/Takeout/R/run_data.R b/gui/src/app/core/Scripting/Takeout/R/run_data.R new file mode 100644 index 00000000..3c16e0d5 --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/R/run_data.R @@ -0,0 +1,5 @@ +source("data.R", local=TRUE, print.eval=TRUE) +if (typeof(data) != "list") { + stop("[stan-playground] data must be a list") +} +data <- list(data) diff --git a/gui/src/app/core/Scripting/Takeout/R/sample.R b/gui/src/app/core/Scripting/Takeout/R/sample.R new file mode 100644 index 00000000..634103b5 --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/R/sample.R @@ -0,0 +1,7 @@ +print("compiling model") +model = cmdstanr::cmdstan_model("./main.stan") + +print("sampling") +fit = do.call(model$sample, as.list(c(data=data, sampling_opts))) + +print(fit$summary()) diff --git a/gui/src/app/core/Scripting/Takeout/makeRuntime.ts b/gui/src/app/core/Scripting/Takeout/makeRuntime.ts new file mode 100644 index 00000000..950a6517 --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/makeRuntime.ts @@ -0,0 +1,116 @@ +import { ProjectDataModel } from "@SpCore/Project/ProjectDataModel"; +import indent from "@SpUtil/indent"; + +const makeRuntimeScript = async ( + project: ProjectDataModel, + type: "R" | "py", +) => { + const { + spPreamble, + spRunData, + spLoadConfig, + spCmdStan, + spRunSampling, + spLoadDraws, + spRunAnalysis, + } = await loadParts(type); + + const keyType = type === "R" ? "R" : "Py"; + + const hasDataJson = project.dataFileContent.length > 0; + const hasDataScript = project[`data${keyType}FileContent`].length > 0; + const hasAnalysisScript = project[`analysis${keyType}FileContent`].length > 0; + + let script = + `TITLE = ${JSON.stringify(project.meta.title)}\n` + spPreamble + "\n"; + + // data + if (hasDataJson && hasDataScript) { + const { start, middle, end } = checkIfIgnoreData(type); + script += start; + script += indent(spRunData); + script += middle; + script += ` print("Loading data from data.json, pass --ignore-saved-data to run data.${type} instead")\n`; + script += indent(loadDataFromFile(type)); + script += end; + } else if (hasDataJson) { + script += loadDataFromFile(type); + } else if (hasDataScript) { + script += spRunData; + script += "\n"; + } else { + script += `data = ""\n\n`; + } + + // running sampler + script += spLoadConfig; + script += "\n"; + script += spCmdStan; + script += "\n"; + script += spRunSampling; + + // analysis + if (hasAnalysisScript) { + script += "\n"; + script += spLoadDraws; + script += "\n"; + script += spRunAnalysis; + } + + return script; +}; + +const loadParts = async (type: "R" | "py") => { + const [ + spPreamble, + spRunData, + spLoadConfig, + spCmdStan, + spRunSampling, + spLoadDraws, + spRunAnalysis, + ] = await Promise.all([ + import(`./${type}/preamble.${type}?raw`).then((m) => m.default), + import(`./${type}/run_data.${type}?raw`).then((m) => m.default), + import(`./${type}/load_args.${type}?raw`).then((m) => m.default), + import(`./${type}/cmdstan.${type}?raw`).then((m) => m.default), + import(`./${type}/sample.${type}?raw`).then((m) => m.default), + type === "py" + ? import(`../pyodide/sp_load_draws.py?raw`).then((m) => m.default) + : Promise.resolve("") /* R uses posterior, no custom script needed */, + import(`./${type}/run_analysis.${type}?raw`).then((m) => m.default), + ]); + + return { + spPreamble, + spRunData, + spLoadConfig, + spCmdStan, + spRunSampling, + spLoadDraws, + spRunAnalysis, + }; +}; + +const loadDataFromFile = (type: "R" | "py") => { + if (type === "R") { + return `data <- "./data.json"\n`; + } + return `data = os.path.join(HERE, 'data.json')\n`; +}; + +const checkIfIgnoreData = (type: "R" | "py") => { + if (type === "R") { + const start = `if ("--ignore-saved-data" %in% args) {\n`; + const middle = `\n} else {\n`; + const end = `\n}\n\n`; + return { start, middle, end }; + } + return { + start: `if args.ignore_saved_data:\n`, + middle: "\nelse:\n", + end: "\n\n", + }; +}; + +export default makeRuntimeScript; diff --git a/gui/src/app/core/Scripting/Takeout/py/cmdstan.py b/gui/src/app/core/Scripting/Takeout/py/cmdstan.py new file mode 100644 index 00000000..4d23c586 --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/py/cmdstan.py @@ -0,0 +1,7 @@ +try: + cmdstanpy.cmdstan_path() +except Exception: + if args.install_cmdstan: + cmdstanpy.install_cmdstan() + else: + raise ValueError("cmdstan not found, use --install-cmdstan to install") diff --git a/gui/src/app/core/Scripting/Takeout/py/load_args.py b/gui/src/app/core/Scripting/Takeout/py/load_args.py new file mode 100644 index 00000000..cc0dca34 --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/py/load_args.py @@ -0,0 +1,14 @@ +_option_names_map = { + "init_radius": "inits", + "num_warmup": "iter_warmup", + "num_samples": "iter_sampling", + "num_chains": "chains", +} + +if os.path.isfile(os.path.join(HERE, "sampling_opts.json")): + print("loading sampling_opts.json") + with open(os.path.join(HERE, "sampling_opts.json")) as f: + s = json.load(f) + sampling_opts = {_option_names_map.get(k, k): v for k, v in s.items()} +else: + sampling_opts = {} diff --git a/gui/src/app/core/Scripting/Takeout/py/preamble.py b/gui/src/app/core/Scripting/Takeout/py/preamble.py new file mode 100644 index 00000000..015290f5 --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/py/preamble.py @@ -0,0 +1,12 @@ +import argparse +import json +import os + +import cmdstanpy + +HERE = os.path.dirname(os.path.abspath(__file__)) + +argparser = argparse.ArgumentParser(prog=f"Stan-Playground: {TITLE}") +argparser.add_argument("--install-cmdstan", action="store_true", help="Install cmdstan if it is missing") +argparser.add_argument("--ignore-saved-data", action="store_true", help="Ignore saved data.json files") +args, _ = argparser.parse_known_args() diff --git a/gui/src/app/core/Scripting/Takeout/py/run_analysis.py b/gui/src/app/core/Scripting/Takeout/py/run_analysis.py new file mode 100644 index 00000000..3ccfbc0d --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/py/run_analysis.py @@ -0,0 +1,18 @@ +import matplotlib.pyplot as plt + +print("executing analysis.py") + +sp_data = { + "draws": fit.draws(concat_chains=True).transpose(), + "paramNames": fit.metadata.cmdstan_config["raw_header"].split(","), + "numChains": fit.chains, +} + +draws = sp_load_draws(sp_data) +del sp_data + +with open(os.path.join(HERE, "analysis.py")) as f: + exec(f.read()) + +if len(plt.gcf().get_children()) > 1: + plt.show() diff --git a/gui/src/app/core/Scripting/Takeout/py/run_data.py b/gui/src/app/core/Scripting/Takeout/py/run_data.py new file mode 100644 index 00000000..e6e65a6d --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/py/run_data.py @@ -0,0 +1,5 @@ +print("executing data.py") +with open(os.path.join(HERE, "data.py")) as f: + exec(f.read()) +if "data" not in locals(): + raise ValueError("data variable not defined in data.py") diff --git a/gui/src/app/core/Scripting/Takeout/py/sample.py b/gui/src/app/core/Scripting/Takeout/py/sample.py new file mode 100644 index 00000000..54d23f0a --- /dev/null +++ b/gui/src/app/core/Scripting/Takeout/py/sample.py @@ -0,0 +1,7 @@ +print("compiling model") +model = cmdstanpy.CmdStanModel(stan_file=os.path.join(HERE, "main.stan")) + +print("sampling") +fit = model.sample(data=data, **sampling_opts) + +print(fit.summary()) diff --git a/gui/src/app/util/indent.ts b/gui/src/app/util/indent.ts new file mode 100644 index 00000000..1feec892 --- /dev/null +++ b/gui/src/app/util/indent.ts @@ -0,0 +1,9 @@ +const indent = (s: string) => { + return s + .trim() + .split("\n") + .map((x) => " " + x) + .join("\n"); +}; + +export default indent; diff --git a/gui/src/app/util/serializeAsZip.ts b/gui/src/app/util/serializeAsZip.ts new file mode 100644 index 00000000..f6698629 --- /dev/null +++ b/gui/src/app/util/serializeAsZip.ts @@ -0,0 +1,21 @@ +import JSZip from "jszip"; + +export const serializeAsZip = async ( + folderName: string, + files: { [key: string]: string } = {}, +): Promise<[Blob, string]> => { + const zip = new JSZip(); + const folder = zip.folder(folderName); + if (!folder) { + throw new Error("Error creating folder in zip file"); + } + + Object.entries(files).forEach(([name, content]) => { + if (content.trim() !== "") { + folder.file(name, content); + } + }); + const zipBlob = await zip.generateAsync({ type: "blob" }); + + return [zipBlob, folderName]; +}; diff --git a/gui/src/app/windows/ExportProjectWindow/ExportProjectPanel.tsx b/gui/src/app/windows/ExportProjectWindow/ExportProjectPanel.tsx index f5082aff..7d7dc0bd 100644 --- a/gui/src/app/windows/ExportProjectWindow/ExportProjectPanel.tsx +++ b/gui/src/app/windows/ExportProjectWindow/ExportProjectPanel.tsx @@ -1,15 +1,19 @@ -import { FunctionComponent, use, useState } from "react"; +import { FunctionComponent, use, useEffect, useState } from "react"; import Table from "@mui/material/Table"; import TableBody from "@mui/material/TableBody"; import TableCell from "@mui/material/TableCell"; import TableContainer from "@mui/material/TableContainer"; +import Button from "@mui/material/Button"; +import TextField from "@mui/material/TextField"; + import { AlternatingTableRow } from "@SpComponents/StyledTables"; import { mapModelToFileManifest } from "@SpCore/Project/FileMapping"; import { ProjectContext } from "@SpCore/Project/ProjectContextProvider"; +import makeRuntimeScript from "@SpCore/Scripting/Takeout/makeRuntime"; import { triggerDownload } from "@SpUtil/triggerDownload"; -import Button from "@mui/material/Button"; -import { serializeAsZip } from "@SpCore/Project/ProjectSerialization"; -import TextField from "@mui/material/TextField"; +import { replaceSpacesWithUnderscores } from "@SpUtil/replaceSpaces"; +import { serializeAsZip } from "@SpUtil/serializeAsZip"; + import GistExportPanel from "./GistExportPanel"; import GistUpdatePanel from "./GistUpdatePanel"; @@ -26,6 +30,32 @@ const ExportProjectPanel: FunctionComponent = ({ const [exportingToGist, setExportingToGist] = useState(false); const [updatingExistingGist, setUpdatingExistingGist] = useState(false); + const [includeRunPy, setIncludeRunPy] = useState( + data.analysisPyFileContent.length > 0 || data.dataPyFileContent.length > 0, + ); + const [runPy, setRunPy] = useState(""); + + const [includeRunR, setIncludeRunR] = useState( + data.analysisRFileContent.length > 0 || data.dataRFileContent.length > 0, + ); + const [runR, setRunR] = useState(""); + + useEffect(() => { + if (includeRunPy) { + makeRuntimeScript(data, "py").then(setRunPy); + } else { + setRunPy(""); + } + }, [includeRunPy, data]); + + useEffect(() => { + if (includeRunR) { + makeRuntimeScript(data, "R").then(setRunR); + } else { + setRunR(""); + } + }, [includeRunR, data]); + return (
@@ -59,6 +89,38 @@ const ExportProjectPanel: FunctionComponent = ({ ), )} + + + run.py + + + setIncludeRunPy(e.target.checked)} + /> +   {runPy.length} bytes + + + + + run.R + + + setIncludeRunR(e.target.checked)} + /> +   {runR.length} bytes + + @@ -67,7 +129,16 @@ const ExportProjectPanel: FunctionComponent = ({