diff --git a/gui/src/app/SamplerOutputView/ConsoleOutput.tsx b/gui/src/app/SamplerOutputView/ConsoleOutput.tsx new file mode 100644 index 00000000..2aff92a6 --- /dev/null +++ b/gui/src/app/SamplerOutputView/ConsoleOutput.tsx @@ -0,0 +1,17 @@ +import Box from "@mui/material/Box"; + +import { FunctionComponent } from "react"; + +type ConsoleOutputProps = { + text: string; +}; + +const ConsoleOutput: FunctionComponent = ({ text }) => { + return ( + +
{text}
+
+ ); +}; + +export default ConsoleOutput; diff --git a/gui/src/app/SamplerOutputView/DrawsView.tsx b/gui/src/app/SamplerOutputView/DrawsView.tsx new file mode 100644 index 00000000..50592abf --- /dev/null +++ b/gui/src/app/SamplerOutputView/DrawsView.tsx @@ -0,0 +1,198 @@ +import { Download } from "@mui/icons-material"; +import Button from "@mui/material/Button"; +import IconButton from "@mui/material/IconButton"; +import Table from "@mui/material/Table"; +import TableBody from "@mui/material/TableBody"; +import TableCell from "@mui/material/TableCell"; +import TableContainer from "@mui/material/TableContainer"; +import { + SuccessBorderedTableRow, + SuccessColoredTableHead, +} from "@SpComponents/StyledTables"; +import { SamplingOpts } from "@SpCore/ProjectDataModel"; +import { triggerDownload } from "@SpUtil/triggerDownload"; +import JSZip from "jszip"; +import { FunctionComponent, useCallback, useMemo, useState } from "react"; + +type DrawsViewProps = { + draws: number[][]; + paramNames: string[]; + drawChainIds: number[]; + drawNumbers: number[]; + samplingOpts: SamplingOpts; // for including in exported zip +}; + +const DrawsView: FunctionComponent = ({ + draws, + paramNames, + drawChainIds, + drawNumbers, + samplingOpts, +}) => { + const [abbreviatedToNumRows, setAbbreviatedToNumRows] = useState< + number | undefined + >(300); + const formattedDraws = useMemo(() => { + if (abbreviatedToNumRows === undefined) return draws; + return draws.map((draw) => + formatDraws(draw.slice(0, abbreviatedToNumRows)), + ); + }, [draws, abbreviatedToNumRows]); + const handleExportToCsv = useCallback(() => { + const csvText = prepareCsvText( + draws, + paramNames, + drawChainIds, + drawNumbers, + ); + downloadTextFile(csvText, "draws.csv"); + }, [draws, paramNames, drawChainIds, drawNumbers]); + const handleExportToMultipleCsvs = useCallback(async () => { + const uniqueChainIds = Array.from(new Set(drawChainIds)); + const csvTexts = prepareMultipleCsvsText( + draws, + paramNames, + drawChainIds, + uniqueChainIds, + ); + const blob = await createZipBlobForMultipleCsvs( + csvTexts, + uniqueChainIds, + samplingOpts, + ); + const fileName = "SP-draws.zip"; + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = fileName; + a.click(); + URL.revokeObjectURL(url); + }, [draws, paramNames, drawChainIds, samplingOpts]); + return ( + <> +
+ + +  Export to single .csv + +   + + +  Export to multiple .csv + +
+
+ + + + + Chain + Draw + {paramNames.map((name, i) => ( + {name} + ))} + + + + {formattedDraws[0].map((_, i) => ( + + {drawChainIds[i]} + {drawNumbers[i]} + {formattedDraws.map((draw, j) => ( + {draw[i]} + ))} + + ))} + +
+ {abbreviatedToNumRows !== undefined && + abbreviatedToNumRows < draws[0].length && ( +
+ +
+ )} +
+ + ); +}; + +const formatDraws = (draws: number[]) => { + if (draws.every((x) => Number.isInteger(x))) return draws; + return draws.map((x) => x.toPrecision(6)); +}; + +const prepareCsvText = ( + draws: number[][], + paramNames: string[], + drawChainIds: number[], + drawNumbers: number[], +) => { + // draws: Each element of draws is a column corresponding to a parameter, across all chains + // paramNames: The paramNames array contains the names of the parameters in the same order that they appear in the draws array + // drawChainIds: The drawChainIds array contains the chain id for each row in the draws array + // uniqueChainIds: The uniqueChainIds array contains the unique chain ids + const lines = draws[0].map((_, i) => { + return [ + `${drawChainIds[i]}`, + `${drawNumbers[i]}`, + ...paramNames.map((_, j) => draws[j][i]), + ].join(","); + }); + return [["Chain", "Draw", ...paramNames].join(","), ...lines].join("\n"); +}; + +const prepareMultipleCsvsText = ( + draws: number[][], + paramNames: string[], + drawChainIds: number[], + uniqueChainIds: number[], +) => { + // See the comments in prepareCsvText for the meaning of the arguments. + // Whereas prepareCsvText returns a CSV that represents a long-form table, + // this function returns multiple CSVs, one for each chain. + return uniqueChainIds.map((chainId) => { + const drawIndicesForChain = drawChainIds + .map((id, i) => (id === chainId ? i : -1)) + .filter((i) => i >= 0); + const lines = drawIndicesForChain.map((i) => { + return paramNames.map((_, j) => draws[j][i]).join(","); + }); + + return [paramNames.join(","), ...lines].join("\n"); + }); +}; + +const createZipBlobForMultipleCsvs = async ( + csvTexts: string[], + uniqueChainIds: number[], + samplingOpts: SamplingOpts, +) => { + const zip = new JSZip(); + // put them all in a folder called 'draws' + const folder = zip.folder("draws"); + if (!folder) throw new Error("Failed to create folder"); + csvTexts.forEach((text, i) => { + folder.file(`chain_${uniqueChainIds[i]}.csv`, text); + }); + const samplingOptsText = JSON.stringify(samplingOpts, null, 2); + folder.file("sampling_opts.json", samplingOptsText); + const blob = await zip.generateAsync({ type: "blob" }); + return blob; +}; + +const downloadTextFile = (text: string, filename: string) => { + const blob = new Blob([text], { type: "text/plain" }); + triggerDownload(blob, filename, () => {}); +}; + +export default DrawsView; diff --git a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx index da35e0e4..c4b9200b 100644 --- a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx +++ b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx @@ -1,26 +1,13 @@ -import Button from "@mui/material/Button"; -import IconButton from "@mui/material/IconButton"; -import { Download } from "@mui/icons-material"; -import Table from "@mui/material/Table"; -import TableBody from "@mui/material/TableBody"; -import TableCell from "@mui/material/TableCell"; -import TableContainer from "@mui/material/TableContainer"; +import { FunctionComponent, useMemo } from "react"; -import { FunctionComponent, useCallback, useMemo, useState } from "react"; - -import JSZip from "jszip"; - -import { - SuccessBorderedTableRow, - SuccessColoredTableHead, -} from "@SpComponents/StyledTables"; +import DrawsView from "@SpComponents/DrawsView"; import HistsView from "@SpComponents/HistsView"; import SummaryView from "@SpComponents/SummaryView"; import TabWidget from "@SpComponents/TabWidget"; +import TracePlotsView from "@SpComponents/TracePlotsView"; +import ConsoleOutput from "@SpComponents/ConsoleOutput"; import { SamplingOpts } from "@SpCore/ProjectDataModel"; import { StanRun } from "@SpStanSampler/useStanSampler"; -import { triggerDownload } from "@SpUtil/triggerDownload"; -import TracePlotsView from "./TracePlotsView"; type SamplerOutputViewProps = { latestRun: StanRun; @@ -29,15 +16,20 @@ type SamplerOutputViewProps = { const SamplerOutputView: FunctionComponent = ({ latestRun, }) => { - const { draws, paramNames, computeTimeSec, samplingOpts } = latestRun; + if (!latestRun.runResult || !latestRun.samplingOpts) return ; + + const { + samplingOpts, + runResult: { draws, paramNames, computeTimeSec, consoleText }, + } = latestRun; - if (!draws || !paramNames || !samplingOpts) return ; return ( ); }; @@ -47,6 +39,7 @@ type DrawsDisplayProps = { paramNames: string[]; computeTimeSec: number | undefined; samplingOpts: SamplingOpts; + consoleText: string; }; const DrawsDisplay: FunctionComponent = ({ @@ -54,6 +47,7 @@ const DrawsDisplay: FunctionComponent = ({ paramNames, computeTimeSec, samplingOpts, + consoleText, }) => { const numChains = samplingOpts.num_chains; @@ -71,7 +65,9 @@ const DrawsDisplay: FunctionComponent = ({ }, [draws, numChains]); return ( - + = ({ paramNames={paramNames} drawChainIds={drawChainIds} /> + ); }; -type DrawsViewProps = { - draws: number[][]; - paramNames: string[]; - drawChainIds: number[]; - drawNumbers: number[]; - samplingOpts: SamplingOpts; // for including in exported zip -}; - -const DrawsView: FunctionComponent = ({ - draws, - paramNames, - drawChainIds, - drawNumbers, - samplingOpts, -}) => { - const [abbreviatedToNumRows, setAbbreviatedToNumRows] = useState< - number | undefined - >(300); - const formattedDraws = useMemo(() => { - if (abbreviatedToNumRows === undefined) return draws; - return draws.map((draw) => - formatDraws(draw.slice(0, abbreviatedToNumRows)), - ); - }, [draws, abbreviatedToNumRows]); - const handleExportToCsv = useCallback(() => { - const csvText = prepareCsvText( - draws, - paramNames, - drawChainIds, - drawNumbers, - ); - downloadTextFile(csvText, "draws.csv"); - }, [draws, paramNames, drawChainIds, drawNumbers]); - const handleExportToMultipleCsvs = useCallback(async () => { - const uniqueChainIds = Array.from(new Set(drawChainIds)); - const csvTexts = prepareMultipleCsvsText( - draws, - paramNames, - drawChainIds, - uniqueChainIds, - ); - const blob = await createZipBlobForMultipleCsvs( - csvTexts, - uniqueChainIds, - samplingOpts, - ); - const fileName = "SP-draws.zip"; - const url = URL.createObjectURL(blob); - const a = document.createElement("a"); - a.href = url; - a.download = fileName; - a.click(); - URL.revokeObjectURL(url); - }, [draws, paramNames, drawChainIds, samplingOpts]); - return ( - <> -
- - -  Export to single .csv - -   - - -  Export to multiple .csv - -
-
- - - - - Chain - Draw - {paramNames.map((name, i) => ( - {name} - ))} - - - - {formattedDraws[0].map((_, i) => ( - - {drawChainIds[i]} - {drawNumbers[i]} - {formattedDraws.map((draw, j) => ( - {draw[i]} - ))} - - ))} - -
- {abbreviatedToNumRows !== undefined && - abbreviatedToNumRows < draws[0].length && ( -
- -
- )} -
- - ); -}; - -const formatDraws = (draws: number[]) => { - if (draws.every((x) => Number.isInteger(x))) return draws; - return draws.map((x) => x.toPrecision(6)); -}; - -const prepareCsvText = ( - draws: number[][], - paramNames: string[], - drawChainIds: number[], - drawNumbers: number[], -) => { - // draws: Each element of draws is a column corresponding to a parameter, across all chains - // paramNames: The paramNames array contains the names of the parameters in the same order that they appear in the draws array - // drawChainIds: The drawChainIds array contains the chain id for each row in the draws array - // uniqueChainIds: The uniqueChainIds array contains the unique chain ids - const lines = draws[0].map((_, i) => { - return [ - `${drawChainIds[i]}`, - `${drawNumbers[i]}`, - ...paramNames.map((_, j) => draws[j][i]), - ].join(","); - }); - return [["Chain", "Draw", ...paramNames].join(","), ...lines].join("\n"); -}; - -const prepareMultipleCsvsText = ( - draws: number[][], - paramNames: string[], - drawChainIds: number[], - uniqueChainIds: number[], -) => { - // See the comments in prepareCsvText for the meaning of the arguments. - // Whereas prepareCsvText returns a CSV that represents a long-form table, - // this function returns multiple CSVs, one for each chain. - return uniqueChainIds.map((chainId) => { - const drawIndicesForChain = drawChainIds - .map((id, i) => (id === chainId ? i : -1)) - .filter((i) => i >= 0); - const lines = drawIndicesForChain.map((i) => { - return paramNames.map((_, j) => draws[j][i]).join(","); - }); - - return [paramNames.join(","), ...lines].join("\n"); - }); -}; - -const createZipBlobForMultipleCsvs = async ( - csvTexts: string[], - uniqueChainIds: number[], - samplingOpts: SamplingOpts, -) => { - const zip = new JSZip(); - // put them all in a folder called 'draws' - const folder = zip.folder("draws"); - if (!folder) throw new Error("Failed to create folder"); - csvTexts.forEach((text, i) => { - folder.file(`chain_${uniqueChainIds[i]}.csv`, text); - }); - const samplingOptsText = JSON.stringify(samplingOpts, null, 2); - folder.file("sampling_opts.json", samplingOptsText); - const blob = await zip.generateAsync({ type: "blob" }); - return blob; -}; - -const downloadTextFile = (text: string, filename: string) => { - const blob = new Blob([text], { type: "text/plain" }); - triggerDownload(blob, filename, () => {}); -}; - export default SamplerOutputView; diff --git a/gui/src/app/Scripting/Analysis/useAnalysisState.ts b/gui/src/app/Scripting/Analysis/useAnalysisState.ts index 74495674..26de83cb 100644 --- a/gui/src/app/Scripting/Analysis/useAnalysisState.ts +++ b/gui/src/app/Scripting/Analysis/useAnalysisState.ts @@ -22,9 +22,10 @@ const useAnalysisState = (latestRun: StanRun) => { useEffect(() => { clearOutputDivs(consoleRef, imagesRef); - }, [latestRun.draws]); + }, [latestRun.runResult?.draws]); - const { draws, paramNames, samplingOpts, status: samplerStatus } = latestRun; + const { runResult, samplingOpts, status: samplerStatus } = latestRun; + const { draws, paramNames } = runResult || {}; const numChains = samplingOpts?.num_chains; const spData = useMemo(() => { if (samplerStatus === "completed" && draws && numChains && paramNames) { diff --git a/gui/src/app/StanSampler/StanModelWorker.ts b/gui/src/app/StanSampler/StanModelWorker.ts index b71df3bb..1515a795 100644 --- a/gui/src/app/StanSampler/StanModelWorker.ts +++ b/gui/src/app/StanSampler/StanModelWorker.ts @@ -41,6 +41,7 @@ export type StanModelReplyMessage = draws: number[][]; paramNames: string[]; error: null; + consoleText: string; } | { purpose: Replies.Progress; @@ -77,8 +78,16 @@ const parseProgress = (msg: string): Progress => { return { chain, iteration, totalIterations, percent, warmup }; }; +let consoleText = ""; const progressPrintCallback = (msg: string) => { + if (!msg) { + return; + } if (!msg.startsWith("Chain") && !msg.startsWith("Iteration:")) { + // storing this has a not-insignificant overhead when a model + // has print statements, but is much faster than posting + // every single line to the main thread + consoleText += msg + "\n"; console.log(msg); return; } @@ -116,6 +125,7 @@ self.onmessage = (e: MessageEvent) => { return; } try { + consoleText = ""; const { paramNames, draws } = model.sample(e.data.sampleConfig); // TODO? use an ArrayBuffer so we can transfer without serialization cost postReply({ @@ -123,6 +133,7 @@ self.onmessage = (e: MessageEvent) => { draws, paramNames, error: null, + consoleText, }); } catch (e: any) { postReply({ purpose: Replies.StanReturn, error: e.toString() }); @@ -138,6 +149,7 @@ self.onmessage = (e: MessageEvent) => { return; } try { + consoleText = ""; const { draws, paramNames } = model.pathfinder(e.data.pathfinderConfig); // TODO? use an ArrayBuffer so we can transfer without serialization cost postReply({ @@ -145,6 +157,7 @@ self.onmessage = (e: MessageEvent) => { draws, paramNames, error: null, + consoleText, }); } catch (e: any) { postReply({ purpose: Replies.StanReturn, error: e.toString() }); diff --git a/gui/src/app/StanSampler/StanSampler.ts b/gui/src/app/StanSampler/StanSampler.ts index 69004520..29fa6ffa 100644 --- a/gui/src/app/StanSampler/StanSampler.ts +++ b/gui/src/app/StanSampler/StanSampler.ts @@ -78,6 +78,7 @@ class StanSampler { draws: e.data.draws, paramNames: e.data.paramNames, computeTimeSec: Date.now() / 1000 - this.#samplingStartTimeSec, + consoleText: e.data.consoleText, }); } break; diff --git a/gui/src/app/StanSampler/useStanSampler.ts b/gui/src/app/StanSampler/useStanSampler.ts index a79d56a4..428863db 100644 --- a/gui/src/app/StanSampler/useStanSampler.ts +++ b/gui/src/app/StanSampler/useStanSampler.ts @@ -12,9 +12,12 @@ export type StanRun = { progress?: Progress; samplingOpts?: SamplingOpts; data?: string; - draws?: number[][]; - paramNames?: string[]; - computeTimeSec?: number; + runResult?: { + consoleText: string; + draws: number[][]; + paramNames: string[]; + computeTimeSec: number; + }; }; const initialStanRun: StanRun = { @@ -43,6 +46,7 @@ export type StanRunAction = draws: number[][]; paramNames: string[]; computeTimeSec: number; + consoleText: string; }; export const StanRunReducer = ( @@ -73,9 +77,12 @@ export const StanRunReducer = ( return { ...state, status: "completed", - draws: action.draws, - paramNames: action.paramNames, - computeTimeSec: action.computeTimeSec, + runResult: { + draws: action.draws, + paramNames: action.paramNames, + computeTimeSec: action.computeTimeSec, + consoleText: action.consoleText, + }, }; default: return unreachable(action); diff --git a/gui/test/app/StanSampler/useStanSampler.test.ts b/gui/test/app/StanSampler/useStanSampler.test.ts index 5ddd94be..966509e7 100644 --- a/gui/test/app/StanSampler/useStanSampler.test.ts +++ b/gui/test/app/StanSampler/useStanSampler.test.ts @@ -131,7 +131,9 @@ describe("useStanSampler", () => { await waitFor(() => { expect(result.current.latestRun.status).toBe("completed"); - expect(result.current.latestRun.paramNames).toEqual(mockedParamNames); + expect(result.current.latestRun.runResult?.paramNames).toEqual( + mockedParamNames, + ); }); expect(mockedStderr).not.toHaveBeenCalled(); }); @@ -178,7 +180,9 @@ describe("useStanSampler", () => { await waitFor(() => { expect(result.current.latestRun.status).toBe("completed"); - expect(result.current.latestRun.paramNames).toEqual(mockedParamNames); + expect(result.current.latestRun.runResult?.paramNames).toEqual( + mockedParamNames, + ); }); expect(mockedStderr).not.toHaveBeenCalled(); @@ -206,17 +210,21 @@ describe("useStanSampler", () => { describe("outputs", () => { test("undefined sampler returns undefined", () => { const { result } = renderHook(() => useStanSampler(undefined)); - expect(result.current.latestRun.draws).toBeUndefined(); - expect(result.current.latestRun.paramNames).toBeUndefined(); - expect(result.current.latestRun.computeTimeSec).toBeUndefined(); + expect(result.current.latestRun.runResult?.draws).toBeUndefined(); + expect(result.current.latestRun.runResult?.paramNames).toBeUndefined(); + expect( + result.current.latestRun.runResult?.computeTimeSec, + ).toBeUndefined(); }); test("sampling changes output", async () => { const { result } = await loadedSampler(); - expect(result.current.latestRun.draws).toBeUndefined(); - expect(result.current.latestRun.paramNames).toBeUndefined(); - expect(result.current.latestRun.computeTimeSec).toBeUndefined(); + expect(result.current.latestRun.runResult?.draws).toBeUndefined(); + expect(result.current.latestRun.runResult?.paramNames).toBeUndefined(); + expect( + result.current.latestRun.runResult?.computeTimeSec, + ).toBeUndefined(); expect(result.current.latestRun.samplingOpts).toBeUndefined(); const testingSamplingOpts = { @@ -229,10 +237,15 @@ describe("useStanSampler", () => { }); await waitFor(() => { - expect(result.current.latestRun.draws).toEqual(mockedDraws); - expect(result.current.latestRun.paramNames).toEqual(mockedParamNames); expect(result.current.latestRun.samplingOpts).toBe(testingSamplingOpts); - expect(result.current.latestRun.computeTimeSec).toBeDefined(); + expect(result.current.latestRun.runResult).toBeDefined(); + expect(result.current.latestRun.runResult?.draws).toEqual(mockedDraws); + expect(result.current.latestRun.runResult?.paramNames).toEqual( + mockedParamNames, + ); + expect( + result.current.latestRun.runResult?.computeTimeSec, + ).toBeDefined(); }); expect(result.current.latestRun.status).toBe("completed");