Skip to content

Commit 39a17bb

Browse files
pavelgjinlined
authored and committed
feat(js/ai): allow explicit constrained gen option, allow simulation with tools (#1872)
1 parent c3b91ab commit 39a17bb

12 files changed

Lines changed: 195 additions & 41 deletions

File tree

genkit-tools/common/src/types/model.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ export const ModelInfoSchema = z.object({
121121
/** Model can natively support document-based context grounding. */
122122
context: z.boolean().optional(),
123123
/** Model can natively support constrained generation. */
124-
constrained: z.boolean().optional(),
124+
constrained: z.enum(['none', 'all', 'no-tools']).optional(),
125125
})
126126
.optional(),
127127
});
@@ -295,6 +295,7 @@ export const GenerateActionOptionsSchema = z.object({
295295
contentType: z.string().optional(),
296296
instructions: z.union([z.boolean(), z.string()]).optional(),
297297
jsonSchema: z.any().optional(),
298+
constrained: z.boolean().optional(),
298299
})
299300
.optional(),
300301
/** When true, return tool calls for manual processing instead of automatically resolving them. */

genkit-tools/genkit-schema.json

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,10 @@
452452
"string"
453453
]
454454
},
455-
"jsonSchema": {}
455+
"jsonSchema": {},
456+
"constrained": {
457+
"type": "boolean"
458+
}
456459
},
457460
"additionalProperties": false
458461
},
@@ -757,7 +760,12 @@
757760
"type": "boolean"
758761
},
759762
"constrained": {
760-
"type": "boolean"
763+
"type": "string",
764+
"enum": [
765+
"none",
766+
"all",
767+
"no-tools"
768+
]
761769
}
762770
},
763771
"additionalProperties": false

js/ai/src/generate.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ export interface OutputOptions<O extends z.ZodTypeAny = z.ZodTypeAny> {
6262
instructions?: boolean | string;
6363
schema?: O;
6464
jsonSchema?: any;
65+
constrained?: boolean;
6566
}
6667

6768
/** ResumeOptions configure how to resume generation after an interrupt. */
@@ -217,9 +218,10 @@ export async function toGenerateRequest(
217218
output: {
218219
...(resolvedFormat?.config || {}),
219220
schema: resolvedSchema,
221+
...options.output,
220222
},
221-
};
222-
if (!out.output.schema) delete out.output.schema;
223+
} as GenerateRequest;
224+
if (!out?.output?.schema) delete out?.output?.schema;
223225
return out;
224226
}
225227

@@ -359,6 +361,7 @@ export async function generate<
359361
...stripUndefinedOptions(resolvedOptions.config),
360362
},
361363
output: resolvedOptions.output && {
364+
...resolvedOptions.output,
362365
format: resolvedOptions.output.format,
363366
jsonSchema: resolvedSchema,
364367
},

js/ai/src/generate/action.ts

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import {
2525
} from '@genkit-ai/core';
2626
import { logger } from '@genkit-ai/core/logging';
2727
import { Registry } from '@genkit-ai/core/registry';
28-
import { toJsonSchema } from '@genkit-ai/core/schema';
2928
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
3029
import {
3130
injectInstructions,
@@ -400,12 +399,12 @@ async function actionToGenerateRequest(
400399
config: options.config,
401400
docs: options.docs,
402401
tools: resolvedTools?.map(toToolDefinition) || [],
403-
output: {
404-
...(resolvedFormat?.config || {}),
405-
schema: toJsonSchema({
406-
jsonSchema: options.output?.jsonSchema,
407-
}),
408-
},
402+
output: stripUndefinedProps({
403+
constrained: options.output?.constrained,
404+
contentType: options.output?.contentType,
405+
format: options.output?.format,
406+
schema: options.output?.jsonSchema,
407+
}),
409408
};
410409
if (options.toolChoice) {
411410
out.toolChoice = options.toolChoice;

js/ai/src/model.ts

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ export const ModelInfoSchema = z.object({
212212
/** Model can natively support document-based context grounding. */
213213
context: z.boolean().optional(),
214214
/** Model can natively support constrained generation. */
215-
constrained: z.boolean().optional(),
215+
constrained: z.enum(['none', 'all', 'no-tools']).optional(),
216216
/** Model supports controlling tool choice, e.g. forced tool calling. */
217217
toolChoice: z.boolean().optional(),
218218
})
@@ -487,8 +487,18 @@ export function defineModel<
487487
validateSupport(options),
488488
];
489489
if (!options?.supports?.context) middleware.push(augmentWithContext());
490-
if (!options?.supports?.constrained)
491-
middleware.push(simulateConstrainedGeneration());
490+
const constratedSimulator = simulateConstrainedGeneration();
491+
middleware.push((req, next) => {
492+
if (
493+
!options?.supports?.constrained ||
494+
options?.supports?.constrained === 'none' ||
495+
(options?.supports?.constrained === 'no-tools' &&
496+
(req.tools?.length ?? 0) > 0)
497+
) {
498+
return constratedSimulator(req, next);
499+
}
500+
return next(req);
501+
});
492502
const act = defineAction(
493503
registry,
494504
{
@@ -709,6 +719,7 @@ export const GenerateActionOptionsSchema = z.object({
709719
contentType: z.string().optional(),
710720
instructions: z.union([z.boolean(), z.string()]).optional(),
711721
jsonSchema: z.any().optional(),
722+
constrained: z.boolean().optional(),
712723
})
713724
.optional(),
714725
/** Options for resuming an interrupted generation. */

js/ai/src/model/middleware.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ export function simulateConstrainedGeneration(
270270
...req.output,
271271
// we're simulating it, so to the underlying model it's unconstrained.
272272
constrained: false,
273+
format: undefined,
274+
contentType: undefined,
273275
schema: undefined,
274276
},
275277
};

js/ai/tests/generate/generate_test.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,29 @@ describe('toGenerateRequest', () => {
279279
},
280280
throws: 'FAILED_PRECONDITION',
281281
},
282+
{
283+
should: 'passes through output options',
284+
prompt: {
285+
model: 'vertexai/gemini-1.0-pro',
286+
prompt: 'Tell a joke about dogs.',
287+
output: {
288+
constrained: true,
289+
format: 'banana',
290+
},
291+
},
292+
expectedOutput: {
293+
messages: [
294+
{ role: 'user', content: [{ text: 'Tell a joke about dogs.' }] },
295+
],
296+
config: undefined,
297+
docs: undefined,
298+
tools: [],
299+
output: {
300+
constrained: true,
301+
format: 'banana',
302+
},
303+
},
304+
},
282305
];
283306
for (const test of testCases) {
284307
it(test.should, async () => {

js/ai/tests/model/middleware_test.ts

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,6 @@ describe('simulateConstrainedGeneration', () => {
449449
],
450450
output: {
451451
constrained: false,
452-
contentType: 'application/json',
453-
format: 'json',
454452
},
455453
tools: [],
456454
});
@@ -502,16 +500,14 @@ describe('simulateConstrainedGeneration', () => {
502500
],
503501
output: {
504502
constrained: false,
505-
contentType: 'application/json',
506-
format: 'json',
507503
},
508504
tools: [],
509505
});
510506
});
511507

512508
it('relies on native support -- no instructions', async () => {
513509
let pm = defineProgrammableModel(registry, {
514-
supports: { constrained: true },
510+
supports: { constrained: 'all' },
515511
});
516512
pm.handleResponse = async (req, sc) => {
517513
return {

js/genkit/tests/formats_test.ts

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17+
import { stripUndefinedProps } from '@genkit-ai/core';
1718
import * as assert from 'assert';
1819
import { beforeEach, describe, it } from 'node:test';
1920
import { GenkitBeta, genkit } from '../src/beta';
@@ -24,14 +25,11 @@ describe('formats', () => {
2425

2526
beforeEach(() => {
2627
ai = genkit({});
27-
defineEchoModel(ai);
28-
});
29-
30-
it('lets you define and use a custom output format', async () => {
3128
ai.defineFormat(
3229
{
3330
name: 'banana',
3431
format: 'banana',
32+
constrained: true,
3533
},
3634
(schema) => {
3735
let instructions: string | undefined;
@@ -53,6 +51,10 @@ describe('formats', () => {
5351
};
5452
}
5553
);
54+
});
55+
56+
it('lets you define and use a custom output format with native constrained generation', async () => {
57+
defineEchoModel(ai, { supports: { constrained: 'all' } });
5658

5759
const { output } = await ai.generate({
5860
model: 'echoModel',
@@ -73,5 +75,110 @@ describe('formats', () => {
7375
}
7476
assert.deepStrictEqual(chunks, ['banana: 3', 'banana: 2', 'banana: 1']);
7577
assert.strictEqual((await response).output, 'banana: Echo: hi');
78+
assert.deepStrictEqual(stripUndefinedProps((await response).request), {
79+
config: {},
80+
messages: [
81+
{
82+
content: [{ text: 'hi' }],
83+
role: 'user',
84+
},
85+
],
86+
output: {
87+
constrained: true,
88+
format: 'banana',
89+
schema: {},
90+
},
91+
tools: [],
92+
});
93+
});
94+
95+
it('lets you define and use a custom output format with simulated constrained generation', async () => {
96+
defineEchoModel(ai, { supports: { constrained: false } });
97+
98+
const { output } = await ai.generate({
99+
model: 'echoModel',
100+
prompt: 'hi',
101+
output: { format: 'banana' },
102+
});
103+
104+
assert.strictEqual(
105+
output,
106+
'banana: Echo: hi,Output should be in JSON format and conform to the following schema:\n' +
107+
'\n' +
108+
'```\n' +
109+
'{}\n' +
110+
'```\n'
111+
);
112+
113+
const { response, stream } = await ai.generateStream({
114+
model: 'echoModel',
115+
prompt: 'hi',
116+
output: { format: 'banana' },
117+
});
118+
const chunks: string[] = [];
119+
for await (const chunk of stream) {
120+
chunks.push(`${chunk.output}`);
121+
}
122+
assert.deepStrictEqual(chunks, ['banana: 3', 'banana: 2', 'banana: 1']);
123+
assert.strictEqual(
124+
(await response).output,
125+
'banana: Echo: hi,Output should be in JSON format and conform to the following schema:\n' +
126+
'\n' +
127+
'```\n' +
128+
'{}\n' +
129+
'```\n'
130+
);
131+
assert.deepStrictEqual(stripUndefinedProps((await response).request), {
132+
config: {},
133+
messages: [
134+
{
135+
content: [{ text: 'hi' }],
136+
role: 'user',
137+
},
138+
],
139+
output: {
140+
constrained: true,
141+
format: 'banana',
142+
schema: {},
143+
},
144+
tools: [],
145+
});
146+
});
147+
148+
it('used explicitly specified output options overriding format options', async () => {
149+
defineEchoModel(ai, { supports: { constrained: 'all' } });
150+
const response = await ai.generate({
151+
model: 'echoModel',
152+
prompt: 'hi',
153+
output: {
154+
format: 'banana',
155+
// Explicitly specified, should ignore whatever format sets
156+
constrained: false,
157+
jsonSchema: { type: 'string' },
158+
},
159+
});
160+
assert.deepStrictEqual(stripUndefinedProps(response.request), {
161+
config: {},
162+
messages: [
163+
{
164+
content: [
165+
{ text: 'hi' },
166+
{
167+
text: 'Output should be in banana format',
168+
metadata: { purpose: 'output' },
169+
},
170+
],
171+
role: 'user',
172+
},
173+
],
174+
output: {
175+
constrained: false,
176+
format: 'banana',
177+
schema: {
178+
type: 'string',
179+
},
180+
},
181+
tools: [],
182+
});
76183
});
77184
});

js/genkit/tests/helpers.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import { MessageData } from '@genkit-ai/ai';
1818
import { BaseEvalDataPoint } from '@genkit-ai/ai/evaluator';
19-
import { ModelAction } from '@genkit-ai/ai/model';
19+
import { ModelAction, ModelInfo } from '@genkit-ai/ai/model';
2020
import { SessionData, SessionStore } from '@genkit-ai/ai/session';
2121
import { StreamingCallback } from '@genkit-ai/core';
2222
import { Genkit } from '../src/genkit';
@@ -26,10 +26,14 @@ import {
2626
GenerateResponseData,
2727
} from '../src/model';
2828

29-
export function defineEchoModel(ai: Genkit): ModelAction {
29+
export function defineEchoModel(
30+
ai: Genkit,
31+
modelInfo?: ModelInfo
32+
): ModelAction {
3033
const model = ai.defineModel(
3134
{
3235
name: 'echoModel',
36+
...modelInfo,
3337
},
3438
async (request, sendChunk) => {
3539
(model as any).__test__lastRequest = request;

0 commit comments

Comments
 (0)