Skip to content

Commit 6af64b9

Browse files
committed
Schedule suspensions only when the promises are awaited (then-ed)
1 parent 7d0ddd0 commit 6af64b9

File tree

3 files changed

+86
-10
lines changed

3 files changed

+86
-10
lines changed

src/journal.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ export class Journal<I, O> {
160160
}
161161
}
162162

163+
public isUnResolved(index: number): boolean {
164+
const journalEntry = this.pendingJournalEntries.get(index);
165+
return journalEntry !== undefined;
166+
}
167+
163168
public handleRuntimeCompletionMessage(m: CompletionMessage) {
164169
// Get message at that entryIndex in pendingJournalEntries
165170
const journalEntry = this.pendingJournalEntries.get(m.entryIndex);

src/restate_context_impl.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ export class RestateGrpcContextImpl implements RestateGrpcContext {
311311
const msg = AwakeableEntryMessage.create();
312312
const promise = this.stateMachine
313313
.handleUserCodeMessage<Buffer>(AWAKEABLE_ENTRY_MESSAGE_TYPE, msg)
314-
.then((result: Buffer | void) => {
314+
.transform((result: Buffer | void) => {
315315
if (!(result instanceof Buffer)) {
316316
// This should either be a filled buffer or an empty buffer but never anything else.
317317
throw RetryableError.internal(

src/state_machine.ts

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,15 @@ export class StateMachine<I, O> implements RestateStreamConsumer {
114114
completedFlag?: boolean,
115115
protocolVersion?: number,
116116
requiresAckFlag?: boolean
117-
): Promise<T | void> {
117+
): WrappedPromise<T | void> {
118118
// if the state machine is already closed, return a promise that never
119119
// completes, so that the user code does not resume
120120
if (this.stateMachineClosed) {
121-
return new CompletablePromise<T>().promise;
121+
return wrapDeeply(new CompletablePromise<T>().promise);
122122
}
123123

124124
const promise = this.journal.handleUserSideMessage(messageType, message);
125+
const journalIndex = this.journal.getUserCodeJournalIndex();
125126

126127
// Only send the message to restate if we are not in replaying mode
127128
if (this.journal.isProcessing()) {
@@ -150,13 +151,14 @@ export class StateMachine<I, O> implements RestateStreamConsumer {
150151
);
151152
}
152153

153-
if (
154-
p.SUSPENSION_TRIGGERS.includes(messageType) &&
155-
this.journal.getCompletableIndices().length > 0
156-
) {
157-
this.scheduleSuspension();
158-
}
159-
return promise;
154+
return wrapDeeply(promise, () => {
155+
if (!p.SUSPENSION_TRIGGERS.includes(messageType)) {
156+
return;
157+
}
158+
if (this.journal.isUnResolved(journalIndex)) {
159+
this.scheduleSuspension();
160+
}
161+
});
160162
}
161163

162164
/**
@@ -490,3 +492,72 @@ export class StateMachine<I, O> implements RestateStreamConsumer {
490492
}
491493
}
492494
}
495+
/**
496+
* Returns a promise that wraps the original promise and calls cb() at the first time
497+
* this promise or any nested promise that is chained to it is awaited. (then-ed)
498+
*/
499+
500+
/* eslint-disable @typescript-eslint/no-explicit-any */
501+
export type WrappedPromise<T> = Promise<T> & {
502+
transform: <TResult1 = T, TResult2 = never>(
503+
onfulfilled?:
504+
| ((value: T) => TResult1 | PromiseLike<TResult1>)
505+
| null
506+
| undefined,
507+
onrejected?:
508+
| ((reason: any) => TResult2 | PromiseLike<TResult2>)
509+
| null
510+
| undefined
511+
) => Promise<TResult1 | TResult2>;
512+
};
513+
514+
const wrapDeeply = <T>(
515+
promise: Promise<T>,
516+
cb?: () => void
517+
): WrappedPromise<T> => {
518+
/* eslint-disable @typescript-eslint/no-explicit-any */
519+
return {
520+
transform: function <TResult1 = T, TResult2 = never>(
521+
onfulfilled?:
522+
| ((value: T) => TResult1 | PromiseLike<TResult1>)
523+
| null
524+
| undefined,
525+
onrejected?:
526+
| ((reason: any) => TResult2 | PromiseLike<TResult2>)
527+
| null
528+
| undefined
529+
): Promise<TResult1 | TResult2> {
530+
return wrapDeeply(promise.then(onfulfilled, onrejected), cb);
531+
},
532+
533+
then: function <TResult1 = T, TResult2 = never>(
534+
onfulfilled?:
535+
| ((value: T) => TResult1 | PromiseLike<TResult1>)
536+
| null
537+
| undefined,
538+
onrejected?:
539+
| ((reason: any) => TResult2 | PromiseLike<TResult2>)
540+
| null
541+
| undefined
542+
): Promise<TResult1 | TResult2> {
543+
if (cb !== undefined) {
544+
cb();
545+
}
546+
return promise.then(onfulfilled, onrejected);
547+
},
548+
catch: function <TResult = never>(
549+
onrejected?:
550+
| ((reason: any) => TResult | PromiseLike<TResult>)
551+
| null
552+
| undefined
553+
): Promise<T | TResult> {
554+
return wrapDeeply(promise.catch(onrejected), cb);
555+
},
556+
finally: function (
557+
onfinally?: (() => void) | null | undefined
558+
): Promise<T> {
559+
return wrapDeeply(promise.finally(onfinally), cb);
560+
},
561+
[Symbol.toStringTag]: "",
562+
};
563+
};

0 commit comments

Comments
 (0)