Skip to content

Commit

Permalink
Convert DigestStream into JS-backed stream
Browse files Browse the repository at this point in the history
The original implementation of `DigestStream` used a WritableStreamSink
and the old "internal" streams implementation. This is unnecessary and
makes the implementation less efficient that it otherwise could be due
to the operations having to bounce between the kj event loop and the
JS isolate lock. This change converts `DigestStream` into a JS-backed
stream where the operations occur within the isolate lock, avoiding
the bounce and avoiding having to reacquire the isolate lock after
each chunk, etc. Also avoids the need for IoContext addObject / IoOwn
dereferencing. Implementation should be safer and more efficient.

Also allow DigestStream to accept string inputs. If a string is written
to a DigestStream, it will be converted to UTF-8 bytes and included in
the digest.

Also allow DigestStream to support `Symbol.dispose`.

Improves tests a bit also.
  • Loading branch information
jasnell committed Apr 24, 2024
1 parent 4089c39 commit 618ba24
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 103 deletions.
193 changes: 133 additions & 60 deletions src/workerd/api/crypto.c++
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "crypto.h"
#include "crypto-impl.h"
#include "streams/standard.h"
#include <array>
#include <openssl/crypto.h>
#include <openssl/err.h>
Expand Down Expand Up @@ -655,100 +656,172 @@ kj::String Crypto::randomUUID() {
// =======================================================================================
// Crypto Streams implementation

namespace {
DigestStreamSink::DigestContextPtr initContext(DigestStreamSink::HashAlgorithm& algorithm) {
DigestStream::DigestContextPtr DigestStream::initContext(SubtleCrypto::HashAlgorithm& algorithm) {
auto checkErrorsOnFinish = webCryptoOperationBegin(__func__, algorithm.name);
auto type = lookupDigestAlgorithm(algorithm.name).second;
auto context = makeDigestContext();
KJ_ASSERT(context != nullptr);
OSSLCALL(EVP_DigestInit_ex(context.get(), type, nullptr));
return kj::mv(context);
}
} // namespace

DigestStreamSink::DigestStreamSink(
HashAlgorithm algorithm,
kj::Own<kj::PromiseFulfiller<kj::Array<kj::byte>>> fulfiller)
: algorithm(kj::mv(algorithm)),
state(initContext(this->algorithm)),
fulfiller(kj::mv(fulfiller)) {}

DigestStreamSink::~DigestStreamSink() {
if (fulfiller && fulfiller->isWaiting()) {
fulfiller->reject(JSG_KJ_EXCEPTION(FAILED, Error,
"The digest was never completed. The DigestStream was created but possibly never "
"used or finished."));
DigestStream::DigestStream(
kj::Own<WritableStreamController> controller,
SubtleCrypto::HashAlgorithm algorithm,
jsg::Promise<kj::Array<kj::byte>>::Resolver resolver,
jsg::Promise<kj::Array<kj::byte>> promise)
: WritableStream(kj::mv(controller)),
promise(kj::mv(promise)),
state(Ready(kj::mv(algorithm), kj::mv(resolver))) {}

void DigestStream::dispose(jsg::Lock& js) {
js.tryCatch([&] {
KJ_IF_SOME(ready, state.tryGet<Ready>()) {
auto reason = js.typeError("The DigestStream was disposed.");
ready.resolver.reject(js, reason);
state.init<StreamStates::Errored>(js.v8Ref<v8::Value>(reason));
}
}, [&](jsg::Value exception) {
js.throwException(kj::mv(exception));
});
}

void DigestStream::visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
tracker.trackField("promise", promise);
KJ_IF_SOME(ready, state.tryGet<Ready>()) {
tracker.trackField("resolver", ready.resolver);
}
}

void DigestStream::visitForGc(jsg::GcVisitor& visitor) {
visitor.visit(promise);
KJ_IF_SOME(ready, state.tryGet<Ready>()) {
visitor.visit(ready.resolver);
}
}

kj::Promise<void> DigestStreamSink::write(const void* buffer, size_t size) {
kj::Maybe<StreamStates::Errored> DigestStream::write(jsg::Lock& js, kj::ArrayPtr<kj::byte> buffer) {
KJ_SWITCH_ONEOF(state) {
KJ_CASE_ONEOF(closed, Closed) {
return kj::READY_NOW;
KJ_CASE_ONEOF(closed, StreamStates::Closed) {
return kj::none;
}
KJ_CASE_ONEOF(errored, Errored) {
return kj::cp(errored);
KJ_CASE_ONEOF(errored, StreamStates::Errored) {
return errored.addRef(js);
}
KJ_CASE_ONEOF(context, DigestContextPtr) {
auto checkErrorsOnFinish = webCryptoOperationBegin(__func__, algorithm.name);
OSSLCALL(EVP_DigestUpdate(context.get(), buffer, size));
return kj::READY_NOW;
KJ_CASE_ONEOF(ready, Ready) {
auto checkErrorsOnFinish = webCryptoOperationBegin(__func__, ready.algorithm.name);
OSSLCALL(EVP_DigestUpdate(ready.context.get(), buffer.begin(), buffer.size()));
return kj::none;
}
}
KJ_UNREACHABLE;
}

kj::Promise<void> DigestStreamSink::write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
for (auto& piece : pieces) {
co_await write(piece.begin(), piece.size());
}
}

kj::Promise<void> DigestStreamSink::end() {
kj::Maybe<StreamStates::Errored> DigestStream::close(jsg::Lock& js) {
KJ_SWITCH_ONEOF(state) {
KJ_CASE_ONEOF(closed, Closed) {
return kj::READY_NOW;
KJ_CASE_ONEOF(closed, StreamStates::Closed) {
return kj::none;
}
KJ_CASE_ONEOF(errored, Errored) {
return kj::cp(errored);
KJ_CASE_ONEOF(errored, StreamStates::Errored) {
return errored.addRef(js);
}
KJ_CASE_ONEOF(context, DigestContextPtr) {
auto checkErrorsOnFinish = webCryptoOperationBegin(__func__, algorithm.name);
KJ_CASE_ONEOF(ready, Ready) {
auto checkErrorsOnFinish = webCryptoOperationBegin(__func__, ready.algorithm.name);
uint size = 0;
auto digest = kj::heapArray<kj::byte>(EVP_MD_CTX_size(context.get()));
OSSLCALL(EVP_DigestFinal_ex(context.get(), digest.begin(), &size));
auto digest = kj::heapArray<kj::byte>(EVP_MD_CTX_size(ready.context.get()));
OSSLCALL(EVP_DigestFinal_ex(ready.context.get(), digest.begin(), &size));
KJ_ASSERT(size, digest.size());
state.init<Closed>();
fulfiller->fulfill(kj::mv(digest));
return kj::READY_NOW;
ready.resolver.resolve(js, kj::mv(digest));
state.init<StreamStates::Closed>();
return kj::none;
}
}
KJ_UNREACHABLE;
}

void DigestStreamSink::abort(kj::Exception reason) {
fulfiller->reject(kj::cp(reason));
state.init<Errored>(kj::mv(reason));
void DigestStream::abort(jsg::Lock& js, jsg::JsValue reason) {
// If the state is already closed or errored, then this is a non-op
KJ_IF_SOME(ready, state.tryGet<Ready>()) {
ready.resolver.reject(js, reason);
state.init<StreamStates::Errored>(js.v8Ref<v8::Value>(reason));
}
}

DigestStream::DigestStream(
HashAlgorithm algorithm,
kj::Own<kj::PromiseFulfiller<kj::Array<kj::byte>>> fulfiller,
jsg::Promise<kj::Array<kj::byte>> promise)
: WritableStream(IoContext::current(),
kj::heap<DigestStreamSink>(kj::mv(algorithm), kj::mv(fulfiller))),
promise(kj::mv(promise)) {}

jsg::Ref<DigestStream> DigestStream::constructor(jsg::Lock& js, Algorithm algorithm) {
auto paf = kj::newPromiseAndFulfiller<kj::Array<kj::byte>>();
auto paf = js.newPromiseAndResolver<kj::Array<kj::byte>>();

auto jsPromise = IoContext::current().awaitIoLegacy(js, kj::mv(paf.promise));
jsPromise.markAsHandled(js);

return jsg::alloc<DigestStream>(
auto stream = jsg::alloc<DigestStream>(
newWritableStreamJsController(),
interpretAlgorithmParam(kj::mv(algorithm)),
kj::mv(paf.fulfiller),
kj::mv(jsPromise));
kj::mv(paf.resolver),
kj::mv(paf.promise));

stream->getController().setup(js, UnderlyingSink {
.write = [&stream=*stream](jsg::Lock& js, v8::Local<v8::Value> chunk, auto c) mutable {
return js.tryCatch([&] {
// Make sure what we got can be interpreted as bytes...
std::shared_ptr<v8::BackingStore> backing;
size_t byteLength = 0;
size_t byteOffset = 0;
if (chunk->IsArrayBuffer()) {
auto ab = chunk.As<v8::ArrayBuffer>();
backing = ab->GetBackingStore();
byteLength = ab->ByteLength();
} else if (chunk->IsArrayBufferView()) {
auto abv = chunk.As<v8::ArrayBufferView>();
backing = abv->Buffer()->GetBackingStore();
byteLength = abv->ByteLength();
byteOffset = abv->ByteOffset();
} else if (chunk->IsString()) {
// If we receive a string, we'll convert that to UTF-8 bytes and digest that.
auto str = js.toString(chunk);
if (str.size() == 0) return js.resolvedPromise();
KJ_IF_SOME(error, stream.write(js, str.asBytes())) {
return js.rejectedPromise<void>(kj::mv(error));
}
stream.bytesWritten += str.size();
return js.resolvedPromise();
} else {
return js.rejectedPromise<void>(js.typeError(
"DigestStream is a byte stream but received an object of "
"non-ArrayBuffer/ArrayBufferView/string type on its writable side."));
}

if (byteLength == 0) return js.resolvedPromise();
kj::ArrayPtr<kj::byte> ptr(static_cast<kj::byte*>(backing->Data()) + byteOffset, byteLength);
// If sink.write returns a non kj::none value, that means the sink was errored
// and we return a rejected promise here. Otherwise, we return resolved.
KJ_IF_SOME(error, stream.write(js, ptr)) {
return js.rejectedPromise<void>(kj::mv(error));
} else {} // Here to silence a compiler warning
stream.bytesWritten += byteLength;
return js.resolvedPromise();
}, [&](jsg::Value exception) {
return js.rejectedPromise<void>(kj::mv(exception));
});
},
.abort = [&stream=*stream](jsg::Lock& js, auto reason) mutable {
return js.tryCatch([&] {
stream.abort(js, jsg::JsValue(reason));
return js.resolvedPromise();
}, [&](jsg::Value exception) {
return js.rejectedPromise<void>(kj::mv(exception));
});
},
.close = [&stream=*stream](jsg::Lock& js) mutable {
return js.tryCatch([&] {
// If sink.close returns a non kj::none value, that means the sink was errored
// and we return a rejected promise here. Otherwise, we return resolved.
KJ_IF_SOME(error, stream.close(js)) {
return js.rejectedPromise<void>(kj::mv(error));
} else {} // Here to silence a compiler warning
return js.resolvedPromise();
}, [&](jsg::Value exception) {
return js.rejectedPromise<void>(kj::mv(exception));
});
}}, kj::none);

return kj::mv(stream);
}

} // namespace workerd::api
71 changes: 29 additions & 42 deletions src/workerd/api/crypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -600,54 +600,25 @@ class SubtleCrypto: public jsg::Object {
};

// =======================================================================================
class DigestStreamSink: public WritableStreamSink {
public:
using HashAlgorithm = SubtleCrypto::HashAlgorithm;
using DigestContextPtr = std::unique_ptr<EVP_MD_CTX, void(*)(EVP_MD_CTX*)>;

explicit DigestStreamSink(
HashAlgorithm algorithm,
kj::Own<kj::PromiseFulfiller<kj::Array<kj::byte>>> fulfiller);

virtual ~DigestStreamSink();

kj::Promise<void> write(const void* buffer, size_t size) override;

kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override;

kj::Promise<void> end() override;

void abort(kj::Exception reason) override;

private:
struct Closed {};
using Errored = kj::Exception;

SubtleCrypto::HashAlgorithm algorithm;
kj::OneOf<DigestContextPtr, Closed, Errored> state;
kj::Own<kj::PromiseFulfiller<kj::Array<kj::byte>>> fulfiller;
};

// DigestStream is a non-standard extension that provides a way of generating
// a hash digest from streaming data. It combines Web Crypto concepts into a
// WritableStream and is compatible with both APIs.
class DigestStream: public WritableStream {
public:
using HashAlgorithm = DigestStreamSink::HashAlgorithm;
using Algorithm = kj::OneOf<kj::String, HashAlgorithm>;
using DigestContextPtr = std::unique_ptr<EVP_MD_CTX, void(*)(EVP_MD_CTX*)>;
using Algorithm = kj::OneOf<kj::String, SubtleCrypto::HashAlgorithm>;

explicit DigestStream(
HashAlgorithm algorithm,
kj::Own<kj::PromiseFulfiller<kj::Array<kj::byte>>> fulfiller,
kj::Own<WritableStreamController> controller,
SubtleCrypto::HashAlgorithm algorithm,
jsg::Promise<kj::Array<kj::byte>>::Resolver resolver,
jsg::Promise<kj::Array<kj::byte>> promise);

static jsg::Ref<DigestStream> constructor(jsg::Lock& js, Algorithm algorithm);

jsg::MemoizedIdentity<jsg::Promise<kj::Array<kj::byte>>>& getDigest() { return promise; }

kj::Own<WritableStreamSink> removeSink(jsg::Lock& js) override {
KJ_UNIMPLEMENTED("DigestStream::removeSink is not implemented");
}
void dispose(jsg::Lock& js);
uint64_t getBytesWritten() const { return bytesWritten; }

JSG_RESOURCE_TYPE(DigestStream, CompatibilityFlags::Reader flags) {
JSG_INHERIT(WritableStream);
Expand All @@ -656,20 +627,36 @@ class DigestStream: public WritableStream {
} else {
JSG_READONLY_INSTANCE_PROPERTY(digest, getDigest);
}
JSG_READONLY_PROTOTYPE_PROPERTY(bytesWritten, getBytesWritten);
JSG_DISPOSE(dispose);

JSG_TS_OVERRIDE(extends WritableStream<ArrayBuffer | ArrayBufferView>);
}

void visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
tracker.trackField("promise", promise);
}
void visitForMemoryInfo(jsg::MemoryTracker& tracker) const;

private:
static DigestContextPtr initContext(SubtleCrypto::HashAlgorithm& algorithm);

struct Ready {
SubtleCrypto::HashAlgorithm algorithm;
jsg::Promise<kj::Array<kj::byte>>::Resolver resolver;
DigestContextPtr context;
Ready(SubtleCrypto::HashAlgorithm algorithm,
jsg::Promise<kj::Array<kj::byte>>::Resolver resolver)
: algorithm(kj::mv(algorithm)),
resolver(kj::mv(resolver)),
context(initContext(this->algorithm)) {}
};
jsg::MemoizedIdentity<jsg::Promise<kj::Array<kj::byte>>> promise;
kj::OneOf<Ready, StreamStates::Closed, StreamStates::Errored> state;
uint64_t bytesWritten = 0;

void visitForGc(jsg::GcVisitor& visitor) {
visitor.visit(promise);
}
kj::Maybe<StreamStates::Errored> write(jsg::Lock& js, kj::ArrayPtr<kj::byte> buffer);
kj::Maybe<StreamStates::Errored> close(jsg::Lock& js);
void abort(jsg::Lock& js, jsg::JsValue reason);

void visitForGc(jsg::GcVisitor& visitor);
};

// =======================================================================================
Expand Down
Loading

0 comments on commit 618ba24

Please sign in to comment.