Skip to content

Commit f89e2f7

Browse files
Fix service handler options propagation (#405)
* Patch wrong propagation of ServiceHandlerOpts * Add test for the service handler with raw input. Also refactored the internal unit test infra to decouple the execution of the state machine from the service definition.
1 parent d7717ae commit f89e2f7

File tree

3 files changed

+180
-91
lines changed

3 files changed

+180
-91
lines changed

packages/restate-sdk/src/types/rpc.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,10 +438,10 @@ export const service = <P extends string, M>(service: {
438438
throw new Error("service must be defined");
439439
}
440440
const handlers = Object.entries(service.handlers).map(([name, handler]) => {
441-
if (handler instanceof HandlerWrapper) {
442-
return [name, handler.transpose()];
443-
}
444441
if (handler instanceof Function) {
442+
if (HandlerWrapper.fromHandler(handler) !== undefined) {
443+
return [name, handler];
444+
}
445445
return [
446446
name,
447447
HandlerWrapper.from(HandlerKind.SERVICE, handler).transpose(),

packages/restate-sdk/test/service_bind.test.ts

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,16 @@
1010
*/
1111

1212
import type { TestGreeter, TestRequest } from "./testdriver.js";
13-
import { TestDriver, TestResponse } from "./testdriver.js";
14-
import type * as restate from "../src/public_api.js";
15-
import { greetRequest, inputMessage, startMessage } from "./protoutils.js";
16-
import { describe, it } from "vitest";
13+
import { TestDriver, TestResponse, testService } from "./testdriver.js";
14+
import * as restate from "../src/public_api.js";
15+
import {
16+
END_MESSAGE,
17+
greetRequest,
18+
inputMessage,
19+
outputMessage,
20+
startMessage,
21+
} from "./protoutils.js";
22+
import { describe, expect, it } from "vitest";
1723

1824
const greeter: TestGreeter = {
1925
// eslint-disable-next-line @typescript-eslint/require-await
@@ -57,3 +63,32 @@ describe("BindService", () => {
5763
]).run();
5864
});
5965
});
66+
67+
const acceptBytes = restate.service({
68+
name: "acceptBytes",
69+
handlers: {
70+
greeter: restate.handlers.handler(
71+
{
72+
accept: "application/octet-stream",
73+
contentType: "application/json",
74+
},
75+
// eslint-disable-next-line @typescript-eslint/require-await
76+
async (_ctx: restate.Context, audio: Uint8Array) => {
77+
return { length: audio.length };
78+
}
79+
),
80+
},
81+
});
82+
83+
describe("AcceptBytes", () => {
84+
it("should accept bytes", async () => {
85+
const result = await testService(acceptBytes).run({
86+
input: [startMessage(), inputMessage(new Uint8Array([0, 1, 2, 3, 4]))],
87+
});
88+
89+
expect(result).toStrictEqual([
90+
outputMessage(Buffer.from(JSON.stringify({ length: 5 }))),
91+
END_MESSAGE,
92+
]);
93+
});
94+
});

packages/restate-sdk/test/testdriver.ts

Lines changed: 138 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ import {
1616
StartMessage,
1717
} from "../src/types/protocol.js";
1818
import type { Connection } from "../src/connection/connection.js";
19-
import { formatMessageAsJson } from "../src/utils/utils.js";
2019
import { Message } from "../src/types/types.js";
2120
import { StateMachine } from "../src/state_machine.js";
2221
import { InvocationBuilder } from "../src/invocation.js";
2322
import type { ObjectContext } from "../src/context.js";
24-
import type { VirtualObjectDefinition } from "../src/public_api.js";
23+
import type {
24+
ServiceDefinition,
25+
VirtualObjectDefinition,
26+
WorkflowDefinition,
27+
} from "../src/public_api.js";
2528
import { object } from "../src/public_api.js";
26-
import { HandlerKind } from "../src/types/rpc.js";
2729
import { NodeEndpoint } from "../src/endpoint/node_endpoint.js";
2830
import type { EndpointBuilder } from "../src/endpoint/endpoint_builder.js";
2931

@@ -47,53 +49,115 @@ export interface TestGreeter {
4749
greet(ctx: ObjectContext, message: TestRequest): Promise<TestResponse>;
4850
}
4951

50-
export class TestDriver implements Connection {
51-
private readonly result: Message[] = [];
52-
53-
private restateServer: TestRestateServer;
54-
private stateMachine: StateMachine;
55-
private completionMessages: Message[];
52+
export class TestDriver {
53+
private readonly uut: UUT<string, unknown>;
54+
private readonly input: Message[];
5655

56+
// Deprecated, please use testService below
5757
constructor(instance: TestGreeter, entries: Message[]) {
58-
this.restateServer = new TestRestateServer();
59-
60-
const svc = object({
61-
name: "greeter",
62-
handlers: {
63-
greet: async (ctx: ObjectContext, arg: TestRequest) => {
64-
return instance.greet(ctx, arg);
58+
this.uut = testService(
59+
object({
60+
name: "greeter",
61+
handlers: {
62+
greet: async (ctx: ObjectContext, arg: TestRequest) => {
63+
return instance.greet(ctx, arg);
64+
},
6565
},
66-
},
66+
})
67+
);
68+
this.input = entries;
69+
}
70+
71+
async run(): Promise<Message[]> {
72+
return await this.uut.run({
73+
input: this.input,
6774
});
75+
}
76+
}
6877

69-
this.restateServer.bind(svc);
78+
/**
79+
* This class' only purpose is to make certain methods accessible in tests.
80+
* Those methods are otherwise protected, to reduce the public interface and
81+
* make it simpler for users to understand what methods are relevant for them,
82+
* and which ones are not.
83+
*/
84+
class TestRestateServer extends NodeEndpoint {}
85+
86+
interface RunOptions {
87+
/// If not provided, will call the first service
88+
service?: string;
89+
/// If not provided, will call the first handler
90+
handler?: string;
91+
input: Message[];
92+
}
93+
94+
export class UUT<N extends string, T> {
95+
private readonly defaultService: string;
96+
private readonly defaultHandler: string;
97+
98+
// eslint-disable-next-line @typescript-eslint/no-redundant-type-constituents
99+
constructor(
100+
private readonly definition:
101+
| ServiceDefinition<N, T>
102+
| VirtualObjectDefinition<N, T>
103+
| WorkflowDefinition<N, T>
104+
) {
105+
// Infer service name and handler
106+
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access,@typescript-eslint/no-unsafe-assignment
107+
this.defaultService = definition.name;
108+
109+
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
110+
const definitionRecord: Record<string, never> =
111+
definition as unknown as Record<string, never>;
112+
if (definitionRecord && definitionRecord.service != undefined) {
113+
this.defaultHandler = Object.keys(
114+
definitionRecord.service as { [s: string]: unknown }
115+
)[0];
116+
} else if (definitionRecord && definitionRecord.object != undefined) {
117+
this.defaultHandler = Object.keys(
118+
definitionRecord.object as { [s: string]: unknown }
119+
)[0];
120+
} else if (definitionRecord && definitionRecord.workflow != undefined) {
121+
this.defaultHandler = Object.keys(
122+
definitionRecord.workflow as { [s: string]: unknown }
123+
)[0];
124+
} else {
125+
throw new TypeError(
126+
"supports only a service or a virtual object or a workflow definition"
127+
);
128+
}
129+
}
70130

71-
if (entries.length < 2) {
131+
public async run(options: RunOptions): Promise<Message[]> {
132+
const restateServer = new TestRestateServer();
133+
restateServer.bind(this.definition);
134+
135+
// Sanity check on input messages
136+
if (options.input.length < 2) {
72137
throw new Error(
73138
"Less than two runtime messages supplied for test. Need to have at least start message and input message."
74139
);
75140
}
76-
77-
if (entries[0].messageType !== START_MESSAGE_TYPE) {
141+
if (options.input[0].messageType !== START_MESSAGE_TYPE) {
78142
throw new Error("First message has to be start message.");
79143
}
80144

81145
// Get the index of where the completion messages start in the entries list
82-
const firstCompletionIndex = entries.findIndex(
146+
const firstCompletionIndex = options.input.findIndex(
83147
(value) =>
84148
value.messageType === COMPLETION_MESSAGE_TYPE ||
85149
value.messageType === ENTRY_ACK_MESSAGE_TYPE
86150
);
87151

88152
// The last message of the replay is the one right before the first completion
89153
const endOfReplay =
90-
firstCompletionIndex !== -1 ? firstCompletionIndex : entries.length;
154+
firstCompletionIndex !== -1 ? firstCompletionIndex : options.input.length;
91155

92-
const msg = entries[0];
93-
// We need to set the right number for known entries. Copy the rest
94-
const startEntry = msg.message as StartMessage;
95-
entries[0] = new Message(
96-
msg.messageType,
156+
// --- Patch StartMessage with the right number of entries
157+
const startMsg = options.input[0];
158+
const startEntry = startMsg.message as StartMessage;
159+
options.input[0] = new Message(
160+
startMsg.messageType,
97161
new StartMessage({
98162
id: startEntry.id,
99163
debugId: startEntry.debugId,
@@ -102,13 +166,18 @@ export class TestDriver implements Connection {
102166
partialState: startEntry.partialState,
103167
key: startEntry.key,
104168
}),
105-
msg.completed,
106-
msg.requiresAck
169+
startMsg.completed,
170+
startMsg.requiresAck
107171
);
108172

109-
const replayMessages = entries.slice(0, endOfReplay);
110-
this.completionMessages = entries.slice(endOfReplay);
111-
173+
// TODO the production code here is doing some bad assumption,
174+
// by assuming that during the initial replay phase no CompletionMessages are sent.
175+
// Although this is currently correct, it is correct only due to how the runtime is implemented,
176+
// and might not be generally true if we change the runtime.
177+
// This should probably be fixed in the production code, and subsequently the test should
178+
// stop splitting the input messages here.
179+
const replayMessages = options.input.slice(0, endOfReplay);
180+
const completionMessages = options.input.slice(endOfReplay);
112181
if (
113182
replayMessages.filter(
114183
(value) =>
@@ -120,9 +189,8 @@ export class TestDriver implements Connection {
120189
"You cannot interleave replay messages with completion or ack messages. First define the replay messages, then the completion messages."
121190
);
122191
}
123-
124192
if (
125-
this.completionMessages.filter(
193+
completionMessages.filter(
126194
(value) =>
127195
value.messageType !== COMPLETION_MESSAGE_TYPE &&
128196
value.messageType !== ENTRY_ACK_MESSAGE_TYPE
@@ -133,90 +201,76 @@ export class TestDriver implements Connection {
133201
);
134202
}
135203

136-
const method = this.restateServer
137-
.componentByName("greeter")
204+
const method = restateServer
205+
.componentByName(options.service ? options.service : this.defaultService)
138206
?.handlerMatching({
139-
componentName: "greeter",
140-
handlerName: "greet",
207+
componentName: options.service ? options.service : this.defaultService,
208+
handlerName: options.handler ? options.handler : this.defaultHandler,
141209
});
142-
143210
if (!method) {
144-
throw new Error("Something is wrong with the test setup");
211+
throw new Error("Can't find the handler to execute");
145212
}
146213

147214
const invocationBuilder = new InvocationBuilder(method);
148215
replayMessages.forEach((el) => invocationBuilder.handleMessage(el));
149216
const invocation = invocationBuilder.build();
150217

151-
this.stateMachine = new StateMachine(
152-
this,
218+
const testConnection = new TestConnection();
219+
const stateMachine = new StateMachine(
220+
testConnection,
153221
invocation,
154-
HandlerKind.EXCLUSIVE,
222+
method.kind(),
155223
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
156-
(
157-
this.restateServer as unknown as { builder: EndpointBuilder }
158-
).builder.logger,
224+
(restateServer as unknown as { builder: EndpointBuilder }).builder.logger,
159225
invocation.inferLoggerContext()
160226
);
161-
}
162227

163-
headers(): ReadonlyMap<string, string | string[] | undefined> {
164-
return new Map();
165-
}
166-
167-
async run(): Promise<Message[]> {
168-
const completed = this.stateMachine.invoke();
228+
const completed = stateMachine.invoke();
169229

170230
// we send the completions here. Because we don't await the messages that we send the completions for,
171231
// we enqueue those completions in the event loop, so they get processed when everything else is done.
172232
// This is highly fragile!!!
173-
this.completionMessages.forEach((el) => {
174-
setTimeout(() => this.stateMachine.handleMessage(el));
233+
completionMessages.forEach((el) => {
234+
setTimeout(() => stateMachine.handleMessage(el));
175235
});
176236
// Set the input channel to closed a bit after sending the messages
177237
// to make the service finish up the work it can do and suspend or send back a response.
178-
setTimeout(() => this.stateMachine.handleInputClosed());
238+
setTimeout(() => stateMachine.handleInputClosed());
179239

180240
await completed;
181241

182-
return Promise.resolve(this.result);
242+
return Promise.resolve(testConnection.sentMessages());
243+
}
244+
}
245+
246+
// eslint-disable-next-line @typescript-eslint/no-redundant-type-constituents
247+
export function testService<N extends string, T>(
248+
definition:
249+
| ServiceDefinition<N, T>
250+
| VirtualObjectDefinition<N, T>
251+
| WorkflowDefinition<N, T>
252+
): UUT<N, T> {
253+
return new UUT<N, T>(definition);
254+
}
255+
256+
class TestConnection implements Connection {
257+
private result: Message[] = [];
258+
259+
headers(): ReadonlyMap<string, string | string[] | undefined> {
260+
return new Map();
183261
}
184262

185263
send(msg: Message): Promise<void> {
186264
this.result.push(msg);
187-
(
188-
this.restateServer as unknown as { builder: EndpointBuilder }
189-
).builder.rlog.debug(
190-
`Adding result to the result array. Message type: ${
191-
msg.messageType
192-
}, message:
193-
${
194-
msg.message instanceof Uint8Array
195-
? (msg.message as Uint8Array).toString()
196-
: formatMessageAsJson(msg.message)
197-
}`
198-
);
199265
return Promise.resolve();
200266
}
201267

202-
onClose() {
203-
// nothing to do
204-
}
205-
206268
async end(): Promise<void> {
207269
// nothing to do
208270
return Promise.resolve();
209271
}
210272

211-
onError() {
212-
// nothing to do
273+
sentMessages(): Message[] {
274+
return this.result;
213275
}
214276
}
215-
216-
/**
217-
* This class' only purpose is to make certain methods accessible in tests.
218-
* Those methods are otherwise protected, to reduce the public interface and
219-
* make it simpler for users to understand what methods are relevant for them,
220-
* and which ones are not.
221-
*/
222-
class TestRestateServer extends NodeEndpoint {}

0 commit comments

Comments
 (0)