Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use weak refs between WebSockets in a WebSocketPair #2161

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ WebSocket::WebSocket(jsg::Lock& js,
IoContext& ioContext,
kj::WebSocket& ws,
HibernationPackage package)
: url(kj::mv(package.url)),
: weakRef(kj::refcounted<WeakRef<WebSocket>>(kj::Badge<WebSocket> {}, *this)),
url(kj::mv(package.url)),
protocol(kj::mv(package.protocol)),
extensions(kj::mv(package.extensions)),
serializedAttachment(kj::mv(package.serializedAttachment)),
Expand All @@ -71,7 +72,8 @@ jsg::Ref<WebSocket> WebSocket::hibernatableFromNative(
}

WebSocket::WebSocket(kj::Own<kj::WebSocket> native, Locality locality)
: url(kj::none),
: weakRef(kj::refcounted<WeakRef<WebSocket>>(kj::Badge<WebSocket> {}, *this)),
url(kj::none),
farNative(nullptr),
outgoingMessages(IoContext::current().addObject(kj::heap<OutgoingMessagesMap>())),
locality(locality) {
Expand All @@ -81,7 +83,8 @@ WebSocket::WebSocket(kj::Own<kj::WebSocket> native, Locality locality)
}

WebSocket::WebSocket(kj::String url, Locality locality)
: url(kj::mv(url)),
: weakRef(kj::refcounted<WeakRef<WebSocket>>(kj::Badge<WebSocket> {}, *this)),
url(kj::mv(url)),
farNative(nullptr),
outgoingMessages(IoContext::current().addObject(kj::heap<OutgoingMessagesMap>())),
locality(locality) {
Expand Down Expand Up @@ -968,8 +971,8 @@ jsg::Ref<WebSocketPair> WebSocketPair::constructor() {
auto first = pair->getFirst();
auto second = pair->getSecond();

first->setMaybePair(second.addRef());
second->setMaybePair(first.addRef());
first->setMaybePair(second->addWeakRef());
second->setMaybePair(first->addWeakRef());
return kj::mv(pair);
}

Expand Down Expand Up @@ -1015,8 +1018,8 @@ void WebSocket::assertNoError(jsg::Lock& js) {
}
}

void WebSocket::setMaybePair(jsg::Ref<WebSocket> other) {
maybePair = other.addRef();
void WebSocket::setMaybePair(kj::Own<WeakRef<WebSocket>> other) {
maybePair = kj::mv(other);
}

kj::Own<kj::WebSocket> WebSocket::acceptAsHibernatable(kj::Array<kj::StringPtr> tags) {
Expand Down Expand Up @@ -1068,15 +1071,19 @@ bool WebSocket::awaitingHibernatableRelease() {
}

void WebSocket::setRemoteOnPair() {
JSG_REQUIRE_NONNULL(maybePair, Error,
"this WebSocket is not one end of a WebSocketPair")->locality = REMOTE;
auto& ref = JSG_REQUIRE_NONNULL(maybePair, Error,
"this WebSocket is not one end of a WebSocketPair");
ref->runIfAlive([](WebSocket& ref) { ref.locality = REMOTE; });
}

bool WebSocket::pairIsAwaitingCoupling() {
bool answer = false;
KJ_IF_SOME(pair, maybePair) {
return pair->farNative->state.is<AwaitingAcceptanceOrCoupling>();
pair->runIfAlive([&answer](WebSocket& pair) {
answer = pair.farNative->state.is<AwaitingAcceptanceOrCoupling>();
});
}
return false;
return answer;
}

WebSocket::HibernationPackage WebSocket::buildPackageForHibernation() {
Expand Down Expand Up @@ -1199,7 +1206,6 @@ void WebSocket::visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
tracker.trackField("error", error);
tracker.trackFieldWithSize("IoOwn<OutgoingMessagesMap>", sizeof(IoOwn<OutgoingMessagesMap>));
tracker.trackField("autoResponseStatus", autoResponseStatus);
tracker.trackField("maybePair", maybePair);
}

} // namespace workerd::api
22 changes: 20 additions & 2 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <kj/compat/http.h>
#include "basics.h"
#include <workerd/io/io-context.h>
#include <workerd/util/weak-refs.h>
#include <stdlib.h>

namespace workerd {
Expand Down Expand Up @@ -230,6 +231,10 @@ class WebSocket: public EventTarget {
bool closedOutgoingConnection = false;
};

~WebSocket() noexcept(false) {
weakRef->invalidate();
}

// This WebSocket constructor is only used when WebSockets wake up from hibernation.
// It will immediately set the `state` to `Accepted`, but it limits the behavior by specifying it
// as `Hibernatable` -- thereby making most api::WebSocket methods inaccessible.
Expand Down Expand Up @@ -409,7 +414,12 @@ class WebSocket: public EventTarget {

void visitForMemoryInfo(jsg::MemoryTracker& tracker) const;

kj::Own<WeakRef<WebSocket>> addWeakRef() {
return weakRef->addRef();
}

private:
kj::Own<WeakRef<WebSocket>> weakRef;
kj::Maybe<kj::String> url;
kj::Maybe<kj::String> protocol = kj::String();
kj::Maybe<kj::String> extensions = kj::String();
Expand Down Expand Up @@ -599,9 +609,17 @@ class WebSocket: public EventTarget {
};

// So that each end of a WebSocketPair can keep track of its pair.
kj::Maybe<jsg::Ref<WebSocket>> maybePair;
// We use a weak ref to track the pair to avoid having a strong ref cycle
// between the two WebSocket instances that would cause them to leak. This
// can mean, however, that it's possible for one side of the pair to be garbage
// collected while the other still exists. This should be fairly unusual tho.
kj::Maybe<kj::Own<WeakRef<WebSocket>>> maybePair;

void visitForGc(jsg::GcVisitor& visitor) {
visitor.visit(error);
}

void setMaybePair(jsg::Ref<WebSocket> other);
void setMaybePair(kj::Own<WeakRef<WebSocket>> other);

friend jsg::Ref<WebSocketPair> WebSocketPair::constructor();

Expand Down
Loading