diff --git a/genkit-tools/common/src/eval/evaluate.ts b/genkit-tools/common/src/eval/evaluate.ts index 1f9ab3158e..0b2a396579 100644 --- a/genkit-tools/common/src/eval/evaluate.ts +++ b/genkit-tools/common/src/eval/evaluate.ts @@ -26,10 +26,7 @@ import { EvalKeyAugments, EvalRun, EvalRunKey, - GenerateRequest, - GenerateRequestSchema, GenerateResponseSchema, - MessageData, RunNewEvaluationRequest, SpanData, } from '../types'; @@ -37,6 +34,7 @@ import { evaluatorName, generateTestCaseId, getEvalExtractors, + getModelInput, hasAction, isEvaluator, logger, @@ -440,30 +438,3 @@ function isSupportedActionRef(actionRef: string) { actionRef.startsWith(`/${supportedType}`) ); } - -function getModelInput(data: any, modelConfig: any): GenerateRequest { - let message: MessageData; - if (typeof data === 'string') { - message = { - role: 'user', - content: [ - { - text: data, - }, - ], - } as MessageData; - return { - messages: message ? [message] : [], - config: modelConfig, - }; - } else { - const maybeRequest = GenerateRequestSchema.safeParse(data); - if (maybeRequest.success) { - return maybeRequest.data; - } else { - throw new Error( - `Unable to parse model input as MessageSchema as input. Details: ${maybeRequest.error}` - ); - } - } -} diff --git a/genkit-tools/common/src/eval/validate.ts b/genkit-tools/common/src/eval/validate.ts index 2cf0f0bf3d..911106ab40 100644 --- a/genkit-tools/common/src/eval/validate.ts +++ b/genkit-tools/common/src/eval/validate.ts @@ -25,6 +25,7 @@ import { ValidateDataRequest, ValidateDataResponse, } from '../types'; +import { getModelInput } from '../utils'; // Setup for AJV type JSONSchema = JSONSchemaType | any; @@ -61,8 +62,8 @@ export async function validateSchema( if (dataset.length === 0) { return { valid: true }; } - dataset.forEach((sample, index) => { - const response = validate(targetSchema, sample.input); + dataset.forEach((sample) => { + const response = validate(actionRef, targetSchema, sample.input); if (!response.valid) { errorsMap[sample.testCaseId] = response.errors ?? []; } @@ -74,7 +75,7 @@ export async function validateSchema( } else { const dataset = InferenceDatasetSchema.parse(data); dataset.forEach((sample, index) => { - const response = validate(targetSchema, sample.input); + const response = validate(actionRef, targetSchema, sample.input); if (!response.valid) { errorsMap[index.toString()] = response.errors ?? []; } @@ -86,11 +87,16 @@ export async function validateSchema( } function validate( + actionRef: string, jsonSchema: JSONSchema, data: unknown ): { valid: boolean; errors?: ErrorDetail[] } { + const isModelAction = actionRef.startsWith('/model'); + let input = isModelAction + ? getModelInput(data, /* modelConfig= */ undefined) + : data; const validator = ajv.compile(jsonSchema); - const valid = validator(data) as boolean; + const valid = validator(input) as boolean; const errors = validator.errors?.map((e) => e); return { valid, errors: errors?.map(toErrorDetail) }; } diff --git a/genkit-tools/common/src/utils/eval.ts b/genkit-tools/common/src/utils/eval.ts index 1c3aa4bc73..c8b4a108ae 100644 --- a/genkit-tools/common/src/utils/eval.ts +++ b/genkit-tools/common/src/utils/eval.ts @@ -37,9 +37,12 @@ import { EvaluationDatasetSchema, EvaluationSample, EvaluationSampleSchema, + GenerateRequest, + GenerateRequestSchema, InferenceDatasetSchema, InferenceSample, InferenceSampleSchema, + MessageData, } from '../types'; import { Action } from '../types/action'; import { DocumentData, RetrieverResponse } from '../types/retrievers'; @@ -330,3 +333,31 @@ export async function hasAction(params: { return actionsRecord.hasOwnProperty(actionRef); } + +/** Helper function that maps string data to GenerateRequest */ +export function getModelInput(data: any, modelConfig: any): GenerateRequest { + let message: MessageData; + if (typeof data === 'string') { + message = { + role: 'user', + content: [ + { + text: data, + }, + ], + } as MessageData; + return { + messages: message ? [message] : [], + config: modelConfig, + }; + } else { + const maybeRequest = GenerateRequestSchema.safeParse(data); + if (maybeRequest.success) { + return maybeRequest.data; + } else { + throw new Error( + `Unable to parse model input as MessageSchema as input. Details: ${maybeRequest.error}` + ); + } + } +}