Skip to content

Commit

Permalink
Merge pull request #1591 from cloudflare/milan/cs-own-ws
Browse files Browse the repository at this point in the history
Give WS a strong ref to `InputGate::CriticalSection`
  • Loading branch information
MellowYarker authored Feb 1, 2024
2 parents 2f37146 + 649c127 commit f0d3478
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
17 changes: 9 additions & 8 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void WebSocket::initConnection(jsg::Lock& js, kj::Promise<PackedWebSocket> prom)
}

// Fire open event.
internalAccept(js);
internalAccept(js, IoContext::current().getCriticalSection());
dispatchOpen(js);
}).catch_(js, [this, self = JSG_THIS](jsg::Lock& js, jsg::Value&& e) mutable {
// Fire error event.
Expand Down Expand Up @@ -375,14 +375,14 @@ void WebSocket::accept(jsg::Lock& js) {
return;
}

internalAccept(js);
internalAccept(js, IoContext::current().getCriticalSection());
}

void WebSocket::internalAccept(jsg::Lock& js) {
void WebSocket::internalAccept(jsg::Lock& js, kj::Maybe<kj::Own<InputGate::CriticalSection>> cs) {
auto& native = *farNative;
auto nativeWs = kj::mv(KJ_ASSERT_NONNULL(native.state.tryGet<AwaitingAcceptanceOrCoupling>()).ws);
native.state.init<Accepted>(kj::mv(nativeWs), native, IoContext::current());
return startReadLoop(js);
return startReadLoop(js, kj::mv(cs));
}

WebSocket::Accepted::Accepted(kj::Own<kj::WebSocket> wsParam, Native& native, IoContext& context)
Expand Down Expand Up @@ -448,15 +448,15 @@ WebSocket::Accepted::~Accepted() noexcept(false) {
}
}

void WebSocket::startReadLoop(jsg::Lock& js) {
void WebSocket::startReadLoop(jsg::Lock& js, kj::Maybe<kj::Own<InputGate::CriticalSection>> cs) {
// If the kj::WebSocket happens to be an AbortableWebSocket (see util/abortable.h), then
// calling readLoop here could throw synchronously if the canceler has already been tripped.
// Using kj::evalNow() here let's us capture that and handle correctly.
//
// We catch exceptions and return Maybe<Exception> instead since we want to handle the exceptions
// in awaitIo() below, but we don't want the KJ exception converted to JavaScript before we can
// examine it.
kj::Promise<kj::Maybe<kj::Exception>> promise = readLoop();
kj::Promise<kj::Maybe<kj::Exception>> promise = readLoop(kj::mv(cs));

auto& context = IoContext::current();

Expand Down Expand Up @@ -901,7 +901,8 @@ kj::Array<kj::StringPtr> WebSocket::getHibernatableTags() {
return accepted.ws.getHibernatableTags();
}

kj::Promise<kj::Maybe<kj::Exception>> WebSocket::readLoop() {
kj::Promise<kj::Maybe<kj::Exception>> WebSocket::readLoop(
kj::Maybe<kj::Own<InputGate::CriticalSection>> cs) {
try {
// Note that we'll throw if the websocket has enabled hibernation.
auto& ws = *KJ_REQUIRE_NONNULL(
Expand Down Expand Up @@ -947,7 +948,7 @@ kj::Promise<kj::Maybe<kj::Exception>> WebSocket::readLoop() {
}

return true;
});
}, mapAddRef(cs));

if (!result) co_return kj::none;
}
Expand Down
6 changes: 3 additions & 3 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,11 @@ class WebSocket: public EventTarget {
// Same as accept(), but websockets that are created with `new WebSocket()` in JS cannot call
// accept(). Instead, we only permit the C++ constructor to call this "internal" version of accept()
// so that the websocket can start processing messages once the connection has been established.
void internalAccept(jsg::Lock& js);
void internalAccept(jsg::Lock& js, kj::Maybe<kj::Own<InputGate::CriticalSection>> cs);

// We defer the actual logic of accept() and internalAccept() to this method, since they largely
// share code.
void startReadLoop(jsg::Lock& js);
void startReadLoop(jsg::Lock& js, kj::Maybe<kj::Own<InputGate::CriticalSection>> cs);

void send(jsg::Lock& js, kj::OneOf<kj::Array<byte>, kj::String> message);
void close(jsg::Lock& js, jsg::Optional<int> code, jsg::Optional<kj::String> reason);
Expand Down Expand Up @@ -630,7 +630,7 @@ class WebSocket: public EventTarget {
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse);

kj::Promise<kj::Maybe<kj::Exception>> readLoop();
kj::Promise<kj::Maybe<kj::Exception>> readLoop(kj::Maybe<kj::Own<InputGate::CriticalSection>> cs);

void reportError(jsg::Lock& js, kj::Exception&& e);
void reportError(jsg::Lock& js, jsg::JsRef<jsg::JsValue> err);
Expand Down

0 comments on commit f0d3478

Please sign in to comment.