diff --git a/lib/lib-storage/src/Upload.ts b/lib/lib-storage/src/Upload.ts index 0ff6b1468ab3..fe8f11c773d2 100644 --- a/lib/lib-storage/src/Upload.ts +++ b/lib/lib-storage/src/Upload.ts @@ -16,7 +16,6 @@ import { UploadPartCommand, UploadPartCommandInput, } from "@aws-sdk/client-s3"; -import { AbortController } from "@smithy/abort-controller"; import { EndpointParameterInstructionsSupplier, getEndpointFromInstructions, @@ -24,12 +23,13 @@ import { } from "@smithy/middleware-endpoint"; import { HttpRequest } from "@smithy/protocol-http"; import { extendedEncodeURIComponent } from "@smithy/smithy-client"; -import type { AbortController as IAbortController, AbortSignal as IAbortSignal, Endpoint } from "@smithy/types"; +import type { Endpoint } from "@smithy/types"; import { EventEmitter } from "events"; import { byteLength } from "./byteLength"; import { BYTE_LENGTH_SOURCE, byteLengthSource } from "./byteLengthSource"; import { getChunk } from "./chunker"; +import { wireSignal } from "./signal"; import { BodyDataTypes, Options, Progress } from "./types"; export interface RawDataPart { @@ -65,8 +65,7 @@ export class Upload extends EventEmitter { private bytesUploadedSoFar: number; // used in the upload. - private abortController: IAbortController; - private concurrentUploaders: Promise[] = []; + private abortController = new AbortController(); private createMultiPartPromise?: Promise; private abortMultipartUploadCommand: AbortMultipartUploadCommand | null = null; @@ -98,12 +97,13 @@ export class Upload extends EventEmitter { throw new Error(`InputError: Upload requires params to be passed to upload.`); } + wireSignal(this.abortController, options.abortSignal); + wireSignal(this.abortController, options.abortController?.signal); + // set progress defaults this.totalBytes = this.params.ContentLength ?? byteLength(this.params.Body); this.totalBytesSource = byteLengthSource(this.params.Body, this.params.ContentLength); this.bytesUploadedSoFar = 0; - this.abortController = options.abortController ?? new AbortController(); - this.partSize = options.partSize || Math.max(Upload.MIN_PART_SIZE, Math.floor((this.totalBytes || 0) / this.MAX_PARTS)); @@ -129,7 +129,12 @@ export class Upload extends EventEmitter { ); } this.sent = true; - return await Promise.race([this.__doMultipartUpload(), this.__abortTimeout(this.abortController.signal)]); + + try { + return await this.__doMultipartUpload(); + } finally { + this.abortController.abort(); + } } public on(event: "httpUploadProgress", listener: (progress: Progress) => void): this { @@ -161,7 +166,12 @@ export class Upload extends EventEmitter { eventEmitter.on("xhr.upload.progress", uploadEventListener); } - const resolved = await Promise.all([this.client.send(new PutObjectCommand(params)), clientConfig?.endpoint?.()]); + const resolved = await Promise.all([ + this.client.send(new PutObjectCommand(params), { + abortSignal: this.abortController.signal, + }), + clientConfig?.endpoint?.(), + ]); const putResult = resolved[0]; let endpoint: Endpoint | undefined = resolved[1]; @@ -311,7 +321,10 @@ export class Upload extends EventEmitter { UploadId: this.uploadId, Body: dataPart.data, PartNumber: dataPart.partNumber, - }) + }), + { + abortSignal: this.abortController.signal, + } ); if (eventEmitter !== null) { @@ -353,28 +366,27 @@ export class Upload extends EventEmitter { private async __doMultipartUpload(): Promise { const dataFeeder = getChunk(this.params.Body, this.partSize); - const concurrentUploaderFailures: Error[] = []; + const concurrentUploads: Promise[] = []; for (let index = 0; index < this.queueSize; index++) { - const currentUpload = this.__doConcurrentUpload(dataFeeder).catch((err) => { - concurrentUploaderFailures.push(err); - }); - this.concurrentUploaders.push(currentUpload); + const currentUpload = this.__doConcurrentUpload(dataFeeder); + concurrentUploads.push(currentUpload); } - await Promise.all(this.concurrentUploaders); - if (concurrentUploaderFailures.length >= 1) { + /** + * Previously, each promise in concurrentUploads could potentially throw + * and immediately return control to user code. However, we want to wait for + * all uploaders to finish before calling AbortMultipartUpload to avoid + * stranding uploaded parts. + * + * We throw only the first error to be consistent with prior behavior, + * but may consider combining the errors into a report in the future. + */ + const results = await Promise.allSettled(concurrentUploads); + const firstFailure = results.find((result) => result.status === "rejected"); + if (firstFailure) { await this.markUploadAsAborted(); - /** - * Previously, each promise in concurrentUploaders could potentially throw - * and immediately return control to user code. However, we want to wait for - * all uploaders to finish before calling AbortMultipartUpload to avoid - * stranding uploaded parts. - * - * We throw only the first error to be consistent with prior behavior, - * but may consider combining the errors into a report in the future. - */ - throw concurrentUploaderFailures[0]; + throw firstFailure.reason; } if (this.abortController.signal.aborted) { @@ -447,16 +459,6 @@ to input.params.ContentLength in bytes. } } - private async __abortTimeout(abortSignal: IAbortSignal): Promise { - return new Promise((resolve, reject) => { - abortSignal.onabort = () => { - const abortError = new Error("Upload aborted."); - abortError.name = "AbortError"; - reject(abortError); - }; - }); - } - private __validateUploadPart(dataPart: RawDataPart): void { const actualPartSize = byteLength(dataPart.data); diff --git a/lib/lib-storage/src/signal.ts b/lib/lib-storage/src/signal.ts new file mode 100644 index 000000000000..8e241e5a47dc --- /dev/null +++ b/lib/lib-storage/src/signal.ts @@ -0,0 +1,51 @@ +import { IAbortSignal } from "./types"; + +/** + * This function wires an external abort signal to an internal abort controller. + * The internal abort controller will be aborted when the external signal is + * aborted. + * + * Every callback created will be removed as soon as either the internal or + * external signal is aborted. This allows to avoid memory leaks, especially if + * the external signal has a (significantly) longer lifespan than the internal + * one. + * + * In order to ensure that any references are removed, make sure to always + * `abort()` the internal controller when you are done with it. + */ +export function wireSignal(internalController: AbortController, externalSignal?: IAbortSignal): void { + if (!externalSignal || internalController.signal.aborted) { + return; + } + if (externalSignal.aborted) { + internalController.abort(); + return; + } + + if (isNativeSignal(externalSignal)) { + externalSignal.addEventListener("abort", () => internalController.abort(), { + once: true, + signal: internalController.signal, + }); + } else { + // backwards compatibility + const origOnabort = externalSignal.onabort; + const restore = () => { + externalSignal.onabort = origOnabort; + }; + + externalSignal.onabort = function () { + internalController.abort(); + restore(); + origOnabort?.call(this); + }; + + // Let's clear any reference to the internal controller when it is aborted, + // avoiding potential memory leaks. + internalController.signal.addEventListener("abort", restore, { once: true }); + } +} + +export function isNativeSignal(signal: IAbortSignal): signal is globalThis.AbortSignal { + return "addEventListener" in signal && typeof signal.addEventListener === "function"; +} diff --git a/lib/lib-storage/src/types.ts b/lib/lib-storage/src/types.ts index 61c87af0dc9c..a2856a38600b 100644 --- a/lib/lib-storage/src/types.ts +++ b/lib/lib-storage/src/types.ts @@ -6,7 +6,7 @@ import { Tag, UploadPartCommandInput, } from "@aws-sdk/client-s3"; -import type { AbortController } from "@smithy/types"; +import type { AbortController, AbortSignal } from "@smithy/types"; export interface Progress { loaded?: number; @@ -19,6 +19,9 @@ export interface Progress { // string | Uint8Array | Buffer | Readable | ReadableStream | Blob. export type BodyDataTypes = PutObjectCommandInput["Body"]; +export type IAbortController = AbortController | globalThis.AbortController; +export type IAbortSignal = AbortSignal | globalThis.AbortSignal; + /** * @deprecated redundant, use {@link S3Client} directly. */ @@ -51,8 +54,15 @@ export interface Configuration { /** * Optional abort controller for controlling this upload's abort signal externally. + * + * @deprecated use `abortSignal` instead. + */ + abortController?: IAbortController; + + /** + * Optional abort signal for controlling this upload's abort signal externally. */ - abortController?: AbortController; + abortSignal?: globalThis.AbortSignal; } export interface Options extends Partial {