Skip to content

Commit 2863624

Browse files
authored
Python snaphots: add workerd cli option to save memory snapshot to disk (#1878)
1 parent db44674 commit 2863624

12 files changed

+123
-48
lines changed

src/pyodide/internal/metadata.js

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { default as ArtifactBundler } from "pyodide-internal:artifacts";
55

66
export const IS_WORKERD = MetadataReader.isWorkerd();
77
export const IS_TRACING = MetadataReader.isTracing();
8+
export const SHOULD_SNAPSHOT_TO_DISK = MetadataReader.shouldSnapshotToDisk();
89
export const IS_CREATING_BASELINE_SNAPSHOT =
910
MetadataReader.isCreatingBaselineSnapshot();
1011
export const WORKERD_INDEX_URL = PYODIDE_BUCKET.PYODIDE_PACKAGE_BUCKET_URL;

src/pyodide/internal/snapshot.js

+18-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { default as ArtifactBundler } from "pyodide-internal:artifacts";
22
import { default as UnsafeEval } from "internal:unsafe-eval";
3+
import { default as DiskCache } from "pyodide-internal:disk_cache";
34
import {
45
SITE_PACKAGES_INFO,
56
SITE_PACKAGES_SO_FILES,
@@ -8,6 +9,7 @@ import {
89
import { default as TarReader } from "pyodide-internal:packages_tar_reader";
910
import processScriptImports from "pyodide-internal:process_script_imports.py";
1011
import {
12+
SHOULD_SNAPSHOT_TO_DISK,
1113
IS_CREATING_BASELINE_SNAPSHOT,
1214
MEMORY_SNAPSHOT_READER,
1315
} from "pyodide-internal:metadata";
@@ -27,8 +29,8 @@ import { reportError, simpleRunPython } from "pyodide-internal:util";
2729
*/
2830
import { _createPyodideModule } from "pyodide-internal:generated/pyodide.asm";
2931

30-
const SHOULD_UPLOAD_SNAPSHOT =
31-
ArtifactBundler.isEnabled() || ArtifactBundler.isEwValidating();
32+
const TOP_LEVEL_SNAPSHOT = ArtifactBundler.isEwValidating() || SHOULD_SNAPSHOT_TO_DISK;
33+
const SHOULD_UPLOAD_SNAPSHOT = ArtifactBundler.isEnabled() || TOP_LEVEL_SNAPSHOT;
3234

3335
/**
3436
* Global variable for the memory snapshot. On the first run we stick a copy of
@@ -61,11 +63,13 @@ export async function uploadArtifacts() {
6163
* Used to hold the memory that needs to be uploaded for the validator.
6264
*/
6365
let MEMORY_TO_UPLOAD = undefined;
64-
export function getMemoryToUpload() {
66+
function getMemoryToUpload() {
6567
if (!MEMORY_TO_UPLOAD) {
6668
throw new TypeError("Expected MEMORY_TO_UPLOAD to be set");
6769
}
68-
return MEMORY_TO_UPLOAD;
70+
const tmp = MEMORY_TO_UPLOAD;
71+
MEMORY_TO_UPLOAD = undefined;
72+
return tmp;
6973
}
7074

7175
/**
@@ -281,8 +285,9 @@ function setUploadFunction(toUpload) {
281285
if (toUpload.constructor.name !== "Uint8Array") {
282286
throw new TypeError("Expected TO_UPLOAD to be a Uint8Array");
283287
}
284-
if (ArtifactBundler.isEwValidating()) {
288+
if (TOP_LEVEL_SNAPSHOT) {
285289
MEMORY_TO_UPLOAD = toUpload;
290+
return;
286291
}
287292
DEFERRED_UPLOAD_FUNCTION = async () => {
288293
try {
@@ -412,3 +417,11 @@ export function finishSnapshotSetup(pyodide) {
412417
});
413418
}
414419
}
420+
421+
export function maybeStoreMemorySnapshot() {
422+
if (ArtifactBundler.isEwValidating()) {
423+
ArtifactBundler.storeMemorySnapshot(getMemoryToUpload());
424+
} else if (SHOULD_SNAPSHOT_TO_DISK) {
425+
DiskCache.put("snapshot.bin", getMemoryToUpload());
426+
}
427+
}

src/pyodide/python-entrypoint-helper.js

+8-8
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// python-entrypoint.js USER module.
33

44
import { loadPyodide } from "pyodide-internal:python";
5-
import { uploadArtifacts, getMemoryToUpload } from "pyodide-internal:snapshot";
5+
import { uploadArtifacts, maybeStoreMemorySnapshot } from "pyodide-internal:snapshot";
66
import { enterJaegerSpan } from "pyodide-internal:jaeger";
77
import {
88
REQUIREMENTS,
@@ -17,7 +17,6 @@ import {
1717
MAIN_MODULE_NAME,
1818
WORKERD_INDEX_URL,
1919
} from "pyodide-internal:metadata";
20-
import { default as ArtifactBundler } from "pyodide-internal:artifacts";
2120
import { reportError } from "pyodide-internal:util";
2221
import { default as Limiter } from "pyodide-internal:limiter";
2322

@@ -201,12 +200,13 @@ try {
201200
}
202201
}
203202
}
204-
205-
// Store the memory snapshot in the ArtifactBundler so that the validator can read it out.
206-
// Needs to happen at the top level because the validator does not perform requests.
207-
if (ArtifactBundler.isEwValidating()) {
208-
ArtifactBundler.storeMemorySnapshot(getMemoryToUpload());
209-
}
203+
/**
204+
* Store the memory snapshot in the ArtifactBundler so that the validator can
205+
* read it out. Needs to happen at the top level because the validator does
206+
* not perform requests. In workerd, this will save a snapshot to disk if the
207+
* `--python-save-snapshot` or `--python-save-baseline-snapshot` is passed.
208+
*/
209+
maybeStoreMemorySnapshot();
210210
} catch (e) {
211211
console.warn("Error in top level in python-entrypoint-helper.js");
212212
reportError(e);

src/workerd/api/pyodide/pyodide.c++

+17-4
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ int ArtifactBundler::readMemorySnapshot(int offset, kj::Array<kj::byte> buf) {
7070
return readToTarget(KJ_REQUIRE_NONNULL(existingSnapshot), offset, buf);
7171
}
7272

73-
jsg::Ref<PyodideMetadataReader> makePyodideMetadataReader(Worker::Reader conf) {
73+
jsg::Ref<PyodideMetadataReader> makePyodideMetadataReader(Worker::Reader conf, const PythonConfig& pythonConfig) {
7474
auto modules = conf.getModules();
7575
auto mainModule = kj::str(modules.begin()->getName());
7676
int numFiles = 0;
@@ -117,9 +117,22 @@ jsg::Ref<PyodideMetadataReader> makePyodideMetadataReader(Worker::Reader conf) {
117117
}
118118
names.add(kj::str(module.getName()));
119119
}
120-
return jsg::alloc<PyodideMetadataReader>(kj::mv(mainModule), names.finish(), contents.finish(),
121-
requirements.finish(), true /* isWorkerd */,
122-
false /* isTracing */, false /* createBaselineSnapshot */, kj::none /* memorySnapshot */);
120+
bool createSnapshot = pythonConfig.createSnapshot;
121+
bool createBaselineSnapshot = pythonConfig.createBaselineSnapshot;
122+
bool snapshotToDisk = createSnapshot || createBaselineSnapshot;
123+
// clang-format off
124+
return jsg::alloc<PyodideMetadataReader>(
125+
kj::mv(mainModule),
126+
names.finish(),
127+
contents.finish(),
128+
requirements.finish(),
129+
true /* isWorkerd */,
130+
false /* isTracing */,
131+
snapshotToDisk,
132+
createBaselineSnapshot,
133+
kj::none /* memorySnapshot */
134+
);
135+
// clang-format on
123136
}
124137

125138
const kj::Maybe<kj::Own<const kj::Directory>> DiskCache::NULL_CACHE_ROOT = kj::none;

src/workerd/api/pyodide/pyodide.h

+28-12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212

1313
namespace workerd::api::pyodide {
1414

15+
struct PythonConfig {
16+
kj::Maybe<kj::Own<const kj::Directory>> diskCacheRoot;
17+
bool createSnapshot;
18+
bool createBaselineSnapshot;
19+
};
20+
1521
// A function to read a segment of the tar file into a buffer
1622
// Set up this way to avoid copying files that aren't accessed.
1723
class PackagesTarReader : public jsg::Object {
@@ -36,17 +42,20 @@ class PyodideMetadataReader : public jsg::Object {
3642
kj::Array<kj::String> requirements;
3743
bool isWorkerdFlag;
3844
bool isTracingFlag;
45+
bool snapshotToDisk;
3946
bool createBaselineSnapshot;
4047
kj::Maybe<kj::Array<kj::byte>> memorySnapshot;
4148

4249
public:
4350
PyodideMetadataReader(kj::String mainModule, kj::Array<kj::String> names,
4451
kj::Array<kj::Array<kj::byte>> contents, kj::Array<kj::String> requirements,
4552
bool isWorkerd, bool isTracing,
53+
bool snapshotToDisk,
4654
bool createBaselineSnapshot,
4755
kj::Maybe<kj::Array<kj::byte>> memorySnapshot)
4856
: mainModule(kj::mv(mainModule)), names(kj::mv(names)), contents(kj::mv(contents)),
4957
requirements(kj::mv(requirements)), isWorkerdFlag(isWorkerd), isTracingFlag(isTracing),
58+
snapshotToDisk(snapshotToDisk),
5059
createBaselineSnapshot(createBaselineSnapshot),
5160
memorySnapshot(kj::mv(memorySnapshot)) {}
5261

@@ -58,6 +67,10 @@ class PyodideMetadataReader : public jsg::Object {
5867
return this->isTracingFlag;
5968
}
6069

70+
bool shouldSnapshotToDisk() {
71+
return snapshotToDisk;
72+
}
73+
6174
bool isCreatingBaselineSnapshot() {
6275
return createBaselineSnapshot;
6376
}
@@ -101,6 +114,7 @@ class PyodideMetadataReader : public jsg::Object {
101114
JSG_METHOD(getMemorySnapshotSize);
102115
JSG_METHOD(readMemorySnapshot);
103116
JSG_METHOD(disposeMemorySnapshot);
117+
JSG_METHOD(shouldSnapshotToDisk);
104118
JSG_METHOD(isCreatingBaselineSnapshot);
105119
}
106120

@@ -131,24 +145,21 @@ class ArtifactBundler : public jsg::Object {
131145
existingSnapshot(kj::mv(existingSnapshot)),
132146
uploadMemorySnapshotCb(kj::mv(uploadMemorySnapshotCb)),
133147
hasUploaded(false),
134-
isValidating(false),
135-
createBaselineSnapshot(false) {};
148+
isValidating(false) {};
136149

137150
ArtifactBundler(kj::Maybe<kj::Array<kj::byte>> existingSnapshot)
138151
: storedSnapshot(kj::none),
139152
existingSnapshot(kj::mv(existingSnapshot)),
140153
uploadMemorySnapshotCb(kj::none),
141154
hasUploaded(false),
142-
isValidating(false),
143-
createBaselineSnapshot(false) {};
155+
isValidating(false){};
144156

145157
ArtifactBundler(bool isValidating = false)
146158
: storedSnapshot(kj::none),
147159
existingSnapshot(kj::none),
148160
uploadMemorySnapshotCb(kj::none),
149161
hasUploaded(false),
150-
isValidating(isValidating),
151-
createBaselineSnapshot(false) {};
162+
isValidating(isValidating){};
152163

153164
jsg::Promise<bool> uploadMemorySnapshot(jsg::Lock& js, kj::Array<kj::byte> snapshot) {
154165
// Prevent multiple uploads.
@@ -228,7 +239,6 @@ class ArtifactBundler : public jsg::Object {
228239
kj::Maybe<kj::Function<kj::Promise<bool>(kj::Array<kj::byte> snapshot)>> uploadMemorySnapshotCb;
229240
bool hasUploaded;
230241
bool isValidating;
231-
bool createBaselineSnapshot;
232242
};
233243

234244

@@ -243,13 +253,19 @@ class DisabledInternalJaeger : public jsg::Object {
243253

244254
// This cache is used by Pyodide to store wheels fetched over the internet across workerd restarts in local dev only
245255
class DiskCache: public jsg::Object {
256+
private:
246257
static const kj::Maybe<kj::Own<const kj::Directory>> NULL_CACHE_ROOT; // always set to kj::none
247258

248259
const kj::Maybe<kj::Own<const kj::Directory>> &cacheRoot;
260+
249261
public:
250262
DiskCache(): cacheRoot(NULL_CACHE_ROOT) {}; // Disabled disk cache
251263
DiskCache(const kj::Maybe<kj::Own<const kj::Directory>> &cacheRoot): cacheRoot(cacheRoot) {};
252264

265+
static jsg::Ref<DiskCache> makeDisabled() {
266+
return jsg::alloc<DiskCache>();
267+
}
268+
253269
jsg::Optional<kj::Array<kj::byte>> get(jsg::Lock& js, kj::String key);
254270
void put(jsg::Lock& js, kj::String key, kj::Array<kj::byte> data);
255271

@@ -275,14 +291,14 @@ class SimplePythonLimiter : public jsg::Object {
275291

276292

277293
public:
278-
SimplePythonLimiter(int startupLimitMs, kj::Function<kj::TimePoint()> getTimeCb) :
279-
startupLimitMs(startupLimitMs),
280-
getTimeCb(kj::mv(getTimeCb)) {}
281-
282294
SimplePythonLimiter() :
283295
startupLimitMs(0),
284296
getTimeCb(kj::none) {}
285297

298+
SimplePythonLimiter(int startupLimitMs, kj::Function<kj::TimePoint()> getTimeCb) :
299+
startupLimitMs(startupLimitMs),
300+
getTimeCb(kj::mv(getTimeCb)) {}
301+
286302
static jsg::Ref<SimplePythonLimiter> makeDisabled() {
287303
return jsg::alloc<SimplePythonLimiter>();
288304
}
@@ -313,7 +329,7 @@ class SimplePythonLimiter : public jsg::Object {
313329

314330
using Worker = server::config::Worker;
315331

316-
jsg::Ref<PyodideMetadataReader> makePyodideMetadataReader(Worker::Reader conf);
332+
jsg::Ref<PyodideMetadataReader> makePyodideMetadataReader(Worker::Reader conf, const PythonConfig& pythonConfig);
317333

318334
bool hasPythonModules(capnp::List<server::config::Worker::Module>::Reader modules);
319335

src/workerd/server/server.c++

+2-1
Original file line numberDiff line numberDiff line change
@@ -2680,7 +2680,8 @@ kj::Own<Server::Service> Server::makeWorker(kj::StringPtr name, config::Worker::
26802680
*limitEnforcer,
26812681
kj::atomicAddRef(*observer),
26822682
*memoryCacheProvider,
2683-
diskCacheRoot);
2683+
pythonConfig
2684+
);
26842685
auto inspectorPolicy = Worker::Isolate::InspectorPolicy::DISALLOW;
26852686
if (inspectorOverride != kj::none) {
26862687
// For workerd, if the inspector is enabled, it is always fully trusted.

src/workerd/server/server.h

+16-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <kj/async-io.h>
1111
#include <workerd/io/worker.h>
1212
#include <workerd/api/memory-cache.h>
13+
#include <workerd/api/pyodide/pyodide.h>
1314
#include <workerd/server/workerd.capnp.h>
1415
#include <workerd/util/sqlite.h>
1516
#include <workerd/server/alarm-scheduler.h>
@@ -25,6 +26,8 @@ namespace workerd::jsg {
2526

2627
namespace workerd::server {
2728

29+
using api::pyodide::PythonConfig;
30+
2831
// Implements the single-tenant Workers Runtime server / CLI.
2932
//
3033
// The purpose of this class is to implement the core logic independently of the CLI itself,
@@ -58,8 +61,14 @@ class Server: private kj::TaskSet::ErrorHandler {
5861
void enableControl(uint fd) {
5962
controlOverride = kj::heap<kj::FdOutputStream>(fd);
6063
}
61-
void setDiskCacheRoot(kj::Maybe<kj::Own<const kj::Directory>> &&dkr) {
62-
diskCacheRoot = kj::mv(dkr);
64+
void setPythonDiskCacheRoot(kj::Maybe<kj::Own<const kj::Directory>> &&dkr) {
65+
pythonConfig.diskCacheRoot = kj::mv(dkr);
66+
}
67+
void setPythonCreateSnapshot() {
68+
pythonConfig.createSnapshot = true;
69+
}
70+
void setPythonCreateBaselineSnapshot() {
71+
pythonConfig.createBaselineSnapshot = true;
6372
}
6473

6574
// Runs the server using the given config.
@@ -93,7 +102,11 @@ class Server: private kj::TaskSet::ErrorHandler {
93102
kj::Network& network;
94103
kj::EntropySource& entropySource;
95104
kj::Function<void(kj::String)> reportConfigError;
96-
kj::Maybe<kj::Own<const kj::Directory>> diskCacheRoot;
105+
PythonConfig pythonConfig = PythonConfig {
106+
.diskCacheRoot = kj::none,
107+
.createSnapshot = false,
108+
.createBaselineSnapshot = false
109+
};
97110

98111
bool experimental = false;
99112

0 commit comments

Comments
 (0)