Skip to content

Commit

Permalink
chore: moved generate tests into shareable spec file (for go and py) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Feb 4, 2025
1 parent 220d9b6 commit 175bcb4
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 257 deletions.
2 changes: 1 addition & 1 deletion go/tests/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type test struct {
const hostPort = "http://localhost:3100"

func TestReflectionAPI(t *testing.T) {
filenames, err := filepath.Glob(filepath.FromSlash("../../tests/reflection_api_tests.yaml"))
filenames, err := filepath.Glob(filepath.FromSlash("../../tests/specs/reflection_api.yaml"))
if err != nil {
t.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion js/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
"rimraf": "^6.0.1",
"tsup": "^8.3.5",
"tsx": "^4.19.2",
"typescript": "^4.9.0"
"typescript": "^4.9.0",
"yaml": "^2.7.0"
},
"types": "lib/index.d.ts",
"exports": {
Expand Down
282 changes: 71 additions & 211 deletions js/ai/tests/generate/action_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,239 +14,99 @@
* limitations under the License.
*/

import { stripUndefinedProps, z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import * as assert from 'assert';
import { readFileSync } from 'fs';
import { beforeEach, describe, it } from 'node:test';
import { parse } from 'yaml';
import {
GenerateAction,
defineGenerateAction,
} from '../../src/generate/action.js';
import { GenerateResponseChunkData } from '../../src/model.js';
import { defineTool } from '../../src/tool.js';
import {
ProgrammableModel,
defineEchoModel,
defineProgrammableModel,
} from '../helpers.js';

describe('generate', () => {
GenerateActionOptionsSchema,
GenerateResponseChunkData,
GenerateResponseChunkSchema,
GenerateResponseSchema,
} from '../../src/model.js';
import { defineTool } from '../../src/tool.js';
import { ProgrammableModel, defineProgrammableModel } from '../helpers.js';

const SpecSuiteSchema = z
.object({
tests: z.array(
z
.object({
name: z.string(),
input: GenerateActionOptionsSchema,
streamChunks: z
.array(z.array(GenerateResponseChunkSchema))
.optional(),
modelResponses: z.array(GenerateResponseSchema),
expectResponse: GenerateResponseSchema.optional(),
stream: z.boolean().optional(),
expectChunks: z.array(GenerateResponseChunkSchema).optional(),
})
.strict()
),
})
.strict();

describe('spec', () => {
let registry: Registry;
let pm: ProgrammableModel;

beforeEach(() => {
registry = new Registry();
defineGenerateAction(registry);
defineEchoModel(registry);
pm = defineProgrammableModel(registry);
});

it('registers the action', async () => {
const action = await registry.lookupAction('/util/generate');
assert.ok(action);
});

it('generate simple response', async () => {
const action = (await registry.lookupAction(
'/util/generate'
)) as GenerateAction;

const response = await action({
model: 'echoModel',
messages: [{ role: 'user', content: [{ text: 'hi' }] }],
config: { temperature: 11 },
});

assert.deepStrictEqual(response, {
custom: {},
finishReason: 'stop',
message: {
role: 'model',
content: [
{ text: 'Echo: hi' },
{ text: '; config: {"temperature":11}' },
],
},
request: {
messages: [
{
role: 'user',
content: [{ text: 'hi' }],
},
],
output: {},
tools: [],
config: {
temperature: 11,
},
docs: undefined,
},
usage: {},
});
});

it('should call tools', async () => {
const action = (await registry.lookupAction(
'/util/generate'
)) as GenerateAction;

defineTool(
registry,
{ name: 'testTool', description: 'description' },
async () => 'tool called'
);

// first response be tools call, the subsequent just text response from agent b.
let reqCounter = 0;
pm.handleResponse = async (req, sc) => {
return {
message: {
role: 'model',
content: [
reqCounter++ === 0
? {
toolRequest: {
name: 'testTool',
input: {},
ref: 'ref123',
},
}
: {
text: req.messages
.map((m) =>
m.content
.map(
(c) =>
c.text || JSON.stringify(c.toolResponse?.output)
)
.join()
)
.join(),
},
],
},
};
};

const response = await action({
model: 'programmableModel',
messages: [{ role: 'user', content: [{ text: 'hi' }] }],
tools: ['testTool'],
config: { temperature: 11 },
});

assert.deepStrictEqual(response, {
custom: {},
finishReason: undefined,
message: {
role: 'model',
content: [{ text: 'hi,,"tool called"' }],
},
request: {
messages: [
{
role: 'user',
content: [{ text: 'hi' }],
},
{
content: [
{
toolRequest: {
input: {},
name: 'testTool',
ref: 'ref123',
},
},
],
role: 'model',
},
{
content: [
{
toolResponse: {
name: 'testTool',
output: 'tool called',
ref: 'ref123',
},
},
],
role: 'tool',
},
],
output: {},
tools: [
{
description: 'description',
inputSchema: {
$schema: 'http://json-schema.org/draft-07/schema#',
},
name: 'testTool',
outputSchema: {
$schema: 'http://json-schema.org/draft-07/schema#',
},
},
],
config: {
temperature: 11,
},
docs: undefined,
},
usage: {},
});
});

it('streams simple response', async () => {
const action = (await registry.lookupAction(
'/util/generate'
)) as GenerateAction;

const { output, stream } = action.stream({
model: 'echoModel',
messages: [{ role: 'user', content: [{ text: 'hi' }] }],
});

const chunks = [] as GenerateResponseChunkData[];
for await (const chunk of stream) {
chunks.push(chunk);
}

assert.deepStrictEqual(chunks, [
{
index: 0,
role: 'model',
content: [{ text: '3' }],
},
{
index: 0,
role: 'model',
content: [{ text: '2' }],
},
{
index: 0,
role: 'model',
content: [{ text: '1' }],
},
]);

assert.deepStrictEqual(await output, {
custom: {},
finishReason: 'stop',
message: {
role: 'model',
content: [{ text: 'Echo: hi' }, { text: '; config: undefined' }],
},
request: {
messages: [
{
role: 'user',
content: [{ text: 'hi' }],
},
],
output: {},
tools: [],
config: undefined,
docs: undefined,
},
usage: {},
SpecSuiteSchema.parse(
parse(readFileSync('../../tests/specs/generate.yaml', 'utf-8'))
).tests.forEach((test) => {
it(test.name, async () => {
if (test.modelResponses || test.streamChunks) {
let reqCounter = 0;
pm.handleResponse = async (req, sc) => {
if (test.streamChunks && sc) {
test.streamChunks[reqCounter].forEach(sc);
}
return test.modelResponses?.[reqCounter++]!;
};
}
const action = (await registry.lookupAction(
'/util/generate'
)) as GenerateAction;

if (test.stream) {
const { output, stream } = action.stream(test.input);

const chunks = [] as GenerateResponseChunkData[];
for await (const chunk of stream) {
chunks.push(stripUndefinedProps(chunk));
}

assert.deepStrictEqual(chunks, test.expectChunks);

assert.deepStrictEqual(
stripUndefinedProps(await output),
test.expectResponse
);
} else {
const response = await action(test.input);

assert.deepStrictEqual(
stripUndefinedProps(response),
test.expectResponse
);
}
});
});
});
Loading

0 comments on commit 175bcb4

Please sign in to comment.