Skip to content

Commit f35d5a8

Browse files
authored
Merge pull request #1532 from cloudflare/dominik/flush-on-socket-close-2
Await connection and flush before closing socket.
2 parents 20b9eb5 + 9b89c8a commit f35d5a8

8 files changed

Lines changed: 134 additions & 6 deletions

File tree

src/workerd/api/sockets.c++

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "system-streams.h"
77
#include <workerd/io/worker-interface.h>
88
#include "url-standard.h"
9+
#include <workerd/util/autogate.h>
910

1011

1112
namespace workerd::api {
@@ -262,7 +263,11 @@ jsg::Ref<Socket> connectImpl(
262263
return connectImplNoOutputLock(js, kj::mv(fetcher), kj::mv(address), kj::mv(options));
263264
}
264265

265-
jsg::Promise<void> Socket::close(jsg::Lock& js) {
266+
// Closes the underlying socket connection. This is an old implementation and will be removed soon.
267+
// See closeImplNew below for the new implementation.
268+
//
269+
// TODO(later): remove once safe
270+
jsg::Promise<void> Socket::closeImplOld(jsg::Lock& js) {
266271
// Forcibly close the readable/writable streams.
267272
auto cancelPromise = readable->getController().cancel(js, kj::none);
268273
auto abortPromise = writable->getController().abort(js, kj::none);
@@ -271,8 +276,58 @@ jsg::Promise<void> Socket::close(jsg::Lock& js) {
271276
return abortPromise.then(js, [this](jsg::Lock& js) {
272277
resolveFulfiller(js, kj::none);
273278
return js.resolvedPromise();
274-
}, [this](jsg::Lock& js, jsg::Value err) { return errorHandler(js, kj::mv(err)); });
275-
}, [this](jsg::Lock& js, jsg::Value err) { return errorHandler(js, kj::mv(err)); });
279+
}, [this](jsg::Lock& js, jsg::Value err) {
280+
errorHandler(js, kj::mv(err));
281+
return js.resolvedPromise();
282+
});
283+
}, [this](jsg::Lock& js, jsg::Value err) {
284+
errorHandler(js, kj::mv(err));
285+
return js.resolvedPromise();
286+
});
287+
}
288+
289+
// Closes the underlying socket connection, but only after the socket connection is properly
290+
// established through any configured proxy. This method also flushes the writable stream prior to
291+
// closing.
292+
jsg::Promise<void> Socket::closeImplNew(jsg::Lock& js) {
293+
if (isClosing) {
294+
return closedPromiseCopy.whenResolved(js);
295+
}
296+
297+
isClosing = true;
298+
writable->getController().setPendingClosure();
299+
readable->getController().setPendingClosure();
300+
301+
// Wait until the socket connects (successfully or otherwise)
302+
return openedPromiseCopy.whenResolved(js).then(js, [this](jsg::Lock& js) {
303+
if (!writable->getController().isClosedOrClosing()) {
304+
return writable->getController().flush(js);
305+
} else {
306+
return js.resolvedPromise();
307+
}
308+
}).then(js, [this](jsg::Lock& js) {
309+
// Forcibly abort the readable/writable streams.
310+
auto cancelPromise = readable->getController().cancel(js, kj::none);
311+
auto abortPromise = writable->getController().abort(js, kj::none);
312+
// The below is effectively `Promise.all(cancelPromise, abortPromise)`
313+
return cancelPromise.then(js,
314+
[abortPromise = kj::mv(abortPromise)](jsg::Lock& js) mutable {
315+
return kj::mv(abortPromise);
316+
});
317+
}).then(js, [this](jsg::Lock& js) {
318+
resolveFulfiller(js, kj::none);
319+
return js.resolvedPromise();
320+
}).catch_(js, [this](jsg::Lock& js, jsg::Value err) {
321+
errorHandler(js, kj::mv(err));
322+
});
323+
}
324+
325+
jsg::Promise<void> Socket::close(jsg::Lock& js) {
326+
if (util::Autogate::isEnabled(util::AutogateKey::SOCKETS_AWAIT_PROXY_BEFORE_CLOSE)) {
327+
return closeImplNew(js);
328+
} else {
329+
return closeImplOld(js);
330+
}
276331
}
277332

278333
jsg::Ref<Socket> Socket::startTls(jsg::Lock& js, jsg::Optional<TlsOptions> tlsOptions) {

src/workerd/api/sockets.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class Socket: public jsg::Object {
5858
: connectionStream(context.addObject(kj::mv(connectionStream))),
5959
readable(kj::mv(readableParam)), writable(kj::mv(writable)),
6060
closedResolver(kj::mv(closedPrPair.resolver)),
61+
closedPromiseCopy(closedPrPair.promise.whenResolved(js)),
6162
closedPromise(kj::mv(closedPrPair.promise)),
6263
watchForDisconnectTask(context.addObject(kj::heap(kj::mv(watchForDisconnectTask)))),
6364
options(kj::mv(options)),
@@ -66,7 +67,9 @@ class Socket: public jsg::Object {
6667
domain(kj::mv(domain)),
6768
isDefaultFetchPort(isDefaultFetchPort),
6869
openedResolver(kj::mv(openedPrPair.resolver)),
69-
openedPromise(kj::mv(openedPrPair.promise)) { };
70+
openedPromiseCopy(openedPrPair.promise.whenResolved(js)),
71+
openedPromise(kj::mv(openedPrPair.promise)),
72+
isClosing(false) { };
7073

7174
jsg::Ref<ReadableStream> getReadable() { return readable.addRef(); }
7275
jsg::Ref<WritableStream> getWritable() { return writable.addRef(); }
@@ -119,6 +122,9 @@ class Socket: public jsg::Object {
119122
jsg::Ref<WritableStream> writable;
120123
// This fulfiller is used to resolve the `closedPromise` below.
121124
jsg::Promise<void>::Resolver closedResolver;
125+
// Copy kept so that it can be returned from `close`.
126+
jsg::Promise<void> closedPromiseCopy;
127+
// Memoized copy that is returned by the `closed` attribute.
122128
jsg::MemoizedIdentity<jsg::Promise<void>> closedPromise;
123129
IoOwn<kj::Promise<void>> watchForDisconnectTask;
124130
jsg::Optional<SocketOptions> options;
@@ -133,10 +139,16 @@ class Socket: public jsg::Object {
133139
bool isDefaultFetchPort;
134140
// This fulfiller is used to resolve the `openedPromise` below.
135141
jsg::Promise<SocketInfo>::Resolver openedResolver;
142+
// Copy kept so that it can be used in `close`.
143+
jsg::Promise<void> openedPromiseCopy;
136144
jsg::MemoizedIdentity<jsg::Promise<SocketInfo>> openedPromise;
145+
// Used to keep track of a pending `close` operation on the socket.
146+
bool isClosing;
137147

138148
kj::Promise<kj::Own<kj::AsyncIoStream>> processConnection();
139149
jsg::Promise<void> maybeCloseWriteSide(jsg::Lock& js);
150+
jsg::Promise<void> closeImplOld(jsg::Lock& js);
151+
jsg::Promise<void> closeImplNew(jsg::Lock& js);
140152

141153
// Helper method for handleProxyStatus implementations.
142154
void handleProxyError(jsg::Lock& js, kj::Exception e);
@@ -149,10 +161,9 @@ class Socket: public jsg::Object {
149161
}
150162
};
151163

152-
jsg::Promise<void> errorHandler(jsg::Lock& js, jsg::Value err) {
164+
void errorHandler(jsg::Lock& js, jsg::Value err) {
153165
auto jsException = err.getHandle(js);
154166
resolveFulfiller(js, jsg::createTunneledException(js.v8Isolate, jsException));
155-
return js.resolvedPromise();
156167
};
157168

158169
void visitForGc(jsg::GcVisitor& visitor) {

src/workerd/api/streams/common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,10 @@ class ReadableStreamController {
510510
jsg::Lock& js, kj::Own<WritableStreamSink> sink, bool end) = 0;
511511

512512
virtual kj::Own<ReadableStreamController> detach(jsg::Lock& js, bool ignoreDisturbed) = 0;
513+
514+
// Used by sockets to signal that the ReadableStream shouldn't allow reads due to pending
515+
// closure.
516+
virtual void setPendingClosure() = 0;
513517
};
514518

515519
kj::Own<ReadableStreamController> newReadableStreamJsController();
@@ -679,6 +683,10 @@ class WritableStreamController {
679683

680684
// True is this controller requires ArrayBuffer(Views) to be written to it.
681685
virtual bool isByteOriented() const = 0;
686+
687+
// Used by sockets to signal that the WritableStream shouldn't allow writes due to pending
688+
// closure.
689+
virtual void setPendingClosure() = 0;
682690
};
683691

684692
kj::Own<WritableStreamController> newWritableStreamJsController();

src/workerd/api/streams/internal.c++

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,12 @@ jsg::Ref<ReadableStream> ReadableStreamInternalController::addRef() {
485485
kj::Maybe<jsg::Promise<ReadResult>> ReadableStreamInternalController::read(
486486
jsg::Lock& js,
487487
kj::Maybe<ByobOptions> maybeByobOptions) {
488+
489+
if (isPendingClosure) {
490+
return js.rejectedPromise<ReadResult>(
491+
js.v8TypeError("This ReadableStream belongs to an object that is closing."_kj));
492+
}
493+
488494
std::shared_ptr<v8::BackingStore> store;
489495
size_t byteLength = 0;
490496
size_t byteOffset = 0;
@@ -596,6 +602,11 @@ jsg::Promise<void> ReadableStreamInternalController::pipeTo(
596602
KJ_DASSERT(!isLockedToReader());
597603
KJ_DASSERT(!destination.isLockedToWriter());
598604

605+
if (isPendingClosure) {
606+
return js.rejectedPromise<void>(
607+
js.v8TypeError("This ReadableStream belongs to an object that is closing."_kj));
608+
}
609+
599610
disturbed = true;
600611
KJ_IF_SOME(promise, destination.tryPipeFrom(js,
601612
KJ_ASSERT_NONNULL(owner).addRef(),
@@ -655,6 +666,8 @@ void ReadableStreamInternalController::doError(jsg::Lock& js, v8::Local<v8::Valu
655666
ReadableStreamController::Tee ReadableStreamInternalController::tee(jsg::Lock& js) {
656667
JSG_REQUIRE(!isLockedToReader(), TypeError,
657668
"This ReadableStream is currently locked to a reader.");
669+
JSG_REQUIRE(!isPendingClosure, TypeError,
670+
"This ReadableStream belongs to an object that is closing.");
658671
readState.init<Locked>();
659672
disturbed = true;
660673
KJ_SWITCH_ONEOF(state) {
@@ -815,6 +828,10 @@ jsg::Ref<WritableStream> WritableStreamInternalController::addRef() {
815828
jsg::Promise<void> WritableStreamInternalController::write(
816829
jsg::Lock& js,
817830
jsg::Optional<v8::Local<v8::Value>> value) {
831+
if (isPendingClosure) {
832+
return js.rejectedPromise<void>(
833+
js.v8TypeError("This WritableStream belongs to an object that is closing."_kj));
834+
}
818835
if (isClosedOrClosing()) {
819836
return js.rejectedPromise<void>(
820837
js.v8TypeError("This WritableStream has been closed."_kj));
@@ -1916,6 +1933,10 @@ jsg::Promise<kj::Array<byte>> ReadableStreamInternalController::readAllBytes(
19161933
return js.rejectedPromise<kj::Array<byte>>(KJ_EXCEPTION(FAILED,
19171934
"jsg.TypeError: This ReadableStream is currently locked to a reader."));
19181935
}
1936+
if (isPendingClosure) {
1937+
return js.rejectedPromise<kj::Array<byte>>(
1938+
js.v8TypeError("This ReadableStream belongs to an object that is closing."_kj));
1939+
}
19191940
KJ_SWITCH_ONEOF(state) {
19201941
KJ_CASE_ONEOF(closed, StreamStates::Closed) {
19211942
return js.resolvedPromise(kj::Array<byte>());
@@ -1939,6 +1960,10 @@ jsg::Promise<kj::String> ReadableStreamInternalController::readAllText(
19391960
return js.rejectedPromise<kj::String>(KJ_EXCEPTION(FAILED,
19401961
"jsg.TypeError: This ReadableStream is currently locked to a reader."));
19411962
}
1963+
if (isPendingClosure) {
1964+
return js.rejectedPromise<kj::String>(
1965+
js.v8TypeError("This ReadableStream belongs to an object that is closing."_kj));
1966+
}
19421967
KJ_SWITCH_ONEOF(state) {
19431968
KJ_CASE_ONEOF(closed, StreamStates::Closed) {
19441969
return js.resolvedPromise(kj::String());

src/workerd/api/streams/internal.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ class ReadableStreamInternalController: public ReadableStreamController {
101101

102102
kj::Own<ReadableStreamController> detach(jsg::Lock& js, bool ignoreDisturbed) override;
103103

104+
void setPendingClosure() override {
105+
isPendingClosure = true;
106+
}
107+
104108
private:
105109
void doCancel(jsg::Lock& js, jsg::Optional<v8::Local<v8::Value>> reason);
106110
void doClose(jsg::Lock& js);
@@ -143,6 +147,10 @@ class ReadableStreamInternalController: public ReadableStreamController {
143147
bool disturbed = false;
144148
bool readPending = false;
145149

150+
// Used by Sockets code to signal to the ReadableStream that it should error when read from
151+
// because the socket is currently being closed.
152+
bool isPendingClosure = false;
153+
146154
friend class ReadableStream;
147155
friend class WritableStreamInternalController;
148156
friend class PipeLocked;
@@ -212,6 +220,10 @@ class WritableStreamInternalController: public WritableStreamController {
212220
bool isErrored() override;
213221

214222
inline bool isByteOriented() const override { return true; }
223+
224+
void setPendingClosure() override {
225+
isPendingClosure = true;
226+
}
215227
private:
216228

217229
struct AbortOptions {
@@ -259,6 +271,10 @@ class WritableStreamInternalController: public WritableStreamController {
259271
kj::Maybe<jsg::Promise<void>> maybeClosureWaitable;
260272
bool waitingOnClosureWritableAlready = false;
261273

274+
// Used by Sockets code to signal to the WritableStream that it should error when written to
275+
// because the socket is currently being closed.
276+
bool isPendingClosure = false;
277+
262278
void increaseCurrentWriteBufferSize(jsg::Lock& js, uint64_t amount);
263279
void decreaseCurrentWriteBufferSize(jsg::Lock& js, uint64_t amount);
264280
void updateBackpressure(jsg::Lock& js, bool backpressure);

src/workerd/api/streams/standard.c++

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// https://opensource.org/licenses/Apache-2.0
44

55
#include "standard.h"
6+
#include <kj/debug.h>
67
#include "readable.h"
78
#include "writable.h"
89
#include <workerd/jsg/buffersource.h>
@@ -684,6 +685,10 @@ public:
684685

685686
kj::Own<ReadableStreamController> detach(jsg::Lock& js, bool ignoreDisturbed) override;
686687

688+
void setPendingClosure() override {
689+
KJ_UNIMPLEMENTED("only implemented for WritableStreamInternalController");
690+
}
691+
687692
private:
688693
bool hasPendingReadRequests();
689694

@@ -816,6 +821,10 @@ public:
816821

817822
inline bool isByteOriented() const override { return false; }
818823

824+
void setPendingClosure() override {
825+
KJ_UNIMPLEMENTED("only implemented for WritableStreamInternalController");
826+
}
827+
819828
private:
820829
jsg::Promise<void> pipeLoop(jsg::Lock& js);
821830

src/workerd/util/autogate.c++

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ kj::StringPtr KJ_STRINGIFY(AutogateKey key) {
1515
return "test-workerd"_kj;
1616
case AutogateKey::BUILTIN_WASM_MODULES:
1717
return "builtin-wasm-modules"_kj;
18+
case AutogateKey::SOCKETS_AWAIT_PROXY_BEFORE_CLOSE:
19+
return "sockets-await-proxy-before-close"_kj;
1820
case AutogateKey::NumOfKeys:
1921
KJ_FAIL_ASSERT("NumOfKeys should not be used in getName");
2022
}

src/workerd/util/autogate.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ enum class AutogateKey {
1616
// Allow builtin modules to be wasm modules. Used for Python project.
1717
// Gates code in jsg/modules.h
1818
BUILTIN_WASM_MODULES,
19+
// Enable new behaviour of Socket::close (specifically waiting for proxy result before closing).
20+
SOCKETS_AWAIT_PROXY_BEFORE_CLOSE,
1921
NumOfKeys // Reserved for iteration.
2022
};
2123

0 commit comments

Comments
 (0)