diff --git a/.changeset/rude-points-lick.md b/.changeset/rude-points-lick.md new file mode 100644 index 0000000..29665ed --- /dev/null +++ b/.changeset/rude-points-lick.md @@ -0,0 +1,5 @@ +--- +'@callstack/byorg-core': minor +--- + +core: chatModel is customizable using RequestContext, removed default maxTokens and maxSteps values diff --git a/docs/src/docs/core/context.md b/docs/src/docs/core/context.md index 7774419..8e729d2 100644 --- a/docs/src/docs/core/context.md +++ b/docs/src/docs/core/context.md @@ -23,7 +23,10 @@ export type RequestContext = { /** Ids of users who are a part of conversation */ resolvedEntities: EntityInfo; - /** Function for generating a system prompt */ + /** Chat model instance */ + chatModel: ChatModel; + + /** Function generating a system prompt (bound to the context) */ systemPrompt: () => string | null; /** diff --git a/packages/core/package.json b/packages/core/package.json index 92eea09..1fc1119 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -34,12 +34,12 @@ "zod": "^3.23.8" }, "peerDependencies": { - "ai": "^4.0.3" + "ai": "^4.0.18" }, "devDependencies": { "@microsoft/api-extractor": "catalog:", "@rslib/core": "catalog:", - "ai": "^4.0.13", + "ai": "^4.0.18", "vitest": "catalog:" } } diff --git a/packages/core/src/__tests__/application.test.ts b/packages/core/src/__tests__/application.test.ts index 6191209..5b71ff5 100644 --- a/packages/core/src/__tests__/application.test.ts +++ b/packages/core/src/__tests__/application.test.ts @@ -1,6 +1,6 @@ import { expect, test, vitest } from 'vitest'; import { createApp } from '../application.js'; -import { Message } from '../domain.js'; +import { Message, MessageResponse, RequestContext } from '../domain.js'; import { createMockChatModel } from '../mock/mock-model.js'; const messages: Message[] = [{ role: 'user', content: 'Hello' }]; @@ -69,3 +69,27 @@ test('basic streaming test', async () => { }); await expect(result.pendingEffects).resolves.toEqual([]); }); + 
+test('uses chat model from context', async () => { + const baseModel = createMockChatModel({ delay: 0, seed: 3 }); + const altModel = createMockChatModel({ delay: 0, seed: 3 }); + + async function altModelMiddleware(context: RequestContext, next: () => Promise<MessageResponse>) { + context.chatModel = altModel; + return await next(); + } + + const app = createApp({ + chatModel: baseModel, + plugins: [ + { + name: 'middleware', + middleware: altModelMiddleware, + }, + ], + }); + + await app.processMessages(messages); + expect(altModel.calls.length).toBe(1); + expect(baseModel.calls.length).toBe(0); +}); diff --git a/packages/core/src/ai/vercel.ts b/packages/core/src/ai/vercel.ts index f50dee0..c2822a9 100644 --- a/packages/core/src/ai/vercel.ts +++ b/packages/core/src/ai/vercel.ts @@ -16,9 +16,6 @@ import { RequestContext, Message } from '../domain.js'; import { ApplicationTool } from '../tools.js'; import type { ChatModel, AssistantResponse, ModelUsage } from './types.js'; -const DEFAULT_MAX_TOKENS = 1024; -const DEFAULT_MAX_STEPS = 5; - // Workaround for memory issue happening when sending image attachment. The attachments get inefficiently serialised causing a memory spike. const VERCEL_AI_SHARED_OPTIONS = { experimental_telemetry: { @@ -127,9 +124,9 @@ const result = await streamText({ ...VERCEL_AI_SHARED_OPTIONS, model: this._options.languageModel, - maxTokens: this._options.maxTokens ?? DEFAULT_MAX_TOKENS, - maxSteps: this._options.maxSteps ?? DEFAULT_MAX_STEPS, messages: context.messages, + maxTokens: this._options.maxTokens, + maxSteps: this._options.maxSteps, tools: context.tools, }); @@ -160,9 +157,9 @@ const result = await generateText({ ...VERCEL_AI_SHARED_OPTIONS, model: this._options.languageModel, - maxTokens: this._options.maxTokens ?? DEFAULT_MAX_TOKENS, - maxSteps: this._options.maxSteps ?? 
DEFAULT_MAX_STEPS, messages: context.messages, + maxTokens: this._options.maxTokens, + maxSteps: this._options.maxSteps, tools: context.tools, }); const responseTime = performance.now() - startTime; diff --git a/packages/core/src/application.ts b/packages/core/src/application.ts index e422cd2..a25523b 100644 --- a/packages/core/src/application.ts +++ b/packages/core/src/application.ts @@ -32,7 +32,7 @@ export type ErrorHandler = ( ) => Promise<MessageResponse> | MessageResponse; export type ApplicationConfig = { - chatModel: ChatModel | ((context: RequestContext) => ChatModel); + chatModel: ChatModel; systemPrompt?: ((context: RequestContext) => string | null) | string; plugins?: ApplicationPlugin[]; errorHandler?: ErrorHandler; @@ -53,7 +53,7 @@ export type ProcessMessageOptions = { }; export function createApp(config: ApplicationConfig): Application { - const { plugins = [], chatModel, errorHandler = defaultErrorHandler } = config; + const { plugins = [], errorHandler = defaultErrorHandler } = config; plugins.forEach((plugin) => { logger.debug(`Plugin "${plugin.name}" registered`); @@ -100,6 +100,7 @@ export function createApp(config: ApplicationConfig): Application { return lastMessage; }, + chatModel: config.chatModel, systemPrompt: () => typeof config.systemPrompt === 'function' ? config.systemPrompt(context) @@ -117,8 +118,7 @@ performance.markEnd(PerformanceMarks.middlewareBeforeHandler); performance.markStart(PerformanceMarks.chatModel); - const resolvedChatModel = typeof chatModel === 'function' ? 
chatModel(context) : chatModel; - const response = await resolvedChatModel.generateResponse(context); + const response = await context.chatModel.generateResponse(context); performance.markEnd(PerformanceMarks.chatModel); // Opens the 'middleware:afterHandler' mark that will be closed after middlewareExecutor has run diff --git a/packages/core/src/domain.ts b/packages/core/src/domain.ts index 8ed652c..48a4920 100644 --- a/packages/core/src/domain.ts +++ b/packages/core/src/domain.ts @@ -1,5 +1,5 @@ import { ApplicationTool } from './tools.js'; -import type { AssistantResponse } from './ai/types.js'; +import type { AssistantResponse, ChatModel } from './ai/types.js'; import { ReferenceStorage } from './references.js'; import { PerformanceTimeline } from './performance.js'; @@ -30,6 +30,7 @@ export type RequestContext = { tools: ApplicationTool[]; references: ReferenceStorage; resolvedEntities: EntityInfo; + chatModel: ChatModel; systemPrompt: () => string | null; onPartialResponse?: (text: string) => void; extras: MessageRequestExtras; diff --git a/packages/core/src/mock/mock-model.ts b/packages/core/src/mock/mock-model.ts index 4eda6fb..3acc6e1 100644 --- a/packages/core/src/mock/mock-model.ts +++ b/packages/core/src/mock/mock-model.ts @@ -19,13 +19,21 @@ export type MockChatModelConfig = { seed?: number; }; -export function createMockChatModel(config?: MockChatModelConfig): ChatModel { +export type MockChatModel = ChatModel & { + calls: Parameters<ChatModel['generateResponse']>[]; +}; + +export function createMockChatModel(config?: MockChatModelConfig): MockChatModel { const responses = config?.responses ?? LOREM_IPSUM_RESPONSES; const delay = config?.delay ?? 100; + const calls: Parameters<ChatModel['generateResponse']>[] = []; + let lastRandom = config?.seed ?? 
Date.now(); return { + calls, generateResponse: async (context: RequestContext): Promise<AssistantResponse> => { + calls.push([context]); lastRandom = random(lastRandom); const response = responses[lastRandom % responses.length]; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9ed4764..f3cba6c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -203,8 +203,8 @@ importers: specifier: 'catalog:' version: 0.1.3(@microsoft/api-extractor@7.48.0(@types/node@22.9.1))(typescript@5.7.2) ai: - specifier: ^4.0.13 - version: 4.0.13(react@18.3.1)(zod@3.23.8) + specifier: ^4.0.18 + version: 4.0.18(react@18.3.1)(zod@3.23.8) vitest: specifier: 'catalog:' version: 2.1.8(@types/node@22.9.1)(sass-embedded@1.80.6)(terser@5.37.0) @@ -362,8 +362,8 @@ packages: zod: optional: true - '@ai-sdk/provider-utils@2.0.3': - resolution: {integrity: sha512-Cyk7GlFEse2jQ4I3FWYuZ1Zhr5w1mD9SHMJTYm/in1rd7r89nmEoQiOy3h8YV2ZvTa2/6aR10xZ4M0k4B3BluA==} + '@ai-sdk/provider-utils@2.0.4': + resolution: {integrity: sha512-GMhcQCZbwM6RoZCri0MWeEWXRt/T+uCxsmHEsTwNvEH3GDjNzchfX25C8ftry2MeEOOn6KfqCLSKomcgK6RoOg==} engines: {node: '>=18'} peerDependencies: zod: ^3.0.0 @@ -375,12 +375,12 @@ packages: resolution: {integrity: sha512-Sj29AzooJ7SYvhPd+AAWt/E7j63E9+AzRnoMHUaJPRYzOd/WDrVNxxv85prF9gDcQ7XPVlSk9j6oAZV9/DXYpA==} engines: {node: '>=18'} - '@ai-sdk/provider@1.0.1': - resolution: {integrity: sha512-mV+3iNDkzUsZ0pR2jG0sVzU6xtQY5DtSCBy3JFycLp6PwjyLw/iodfL3MwdmMCRJWgs3dadcHejRnMvF9nGTBg==} + '@ai-sdk/provider@1.0.2': + resolution: {integrity: sha512-YYtP6xWQyaAf5LiWLJ+ycGTOeBLWrED7LUrvc+SQIWhGaneylqbaGsyQL7VouQUeQ4JZ1qKYZuhmi3W56HADPA==} engines: {node: '>=18'} - '@ai-sdk/react@1.0.5': - resolution: {integrity: sha512-OPqYhltJE9dceWxw5pTXdYtAhs1Ca6Ly8xR7z/T+JZ0lrcgembFIMvnJ0dMBkba07P4GQBmuvd5DVTeAqPM9SQ==} + '@ai-sdk/react@1.0.6': + resolution: {integrity: sha512-8Hkserq0Ge6AEi7N4hlv2FkfglAGbkoAXEZ8YSp255c3PbnZz6+/5fppw+aROmZMOfNwallSRuy1i/iPa2rBpQ==} engines: {node: '>=18'} peerDependencies: react: ^18 || ^19 || ^19.0.0-rc @@ 
-391,8 +391,8 @@ packages: zod: optional: true - '@ai-sdk/ui-utils@1.0.4': - resolution: {integrity: sha512-P2vDvASaGsD+lmbsQ5WYjELxJBQgse3CpxyLSA+usZiZxspwYbLFsSWiYz3zhIemcnS0T6/OwQdU6UlMB4N5BQ==} + '@ai-sdk/ui-utils@1.0.5': + resolution: {integrity: sha512-DGJSbDf+vJyWmFNexSPUsS1AAy7gtsmFmoSyNbNbJjwl9hRIf2dknfA1V0ahx6pg3NNklNYFm53L8Nphjovfvg==} engines: {node: '>=18'} peerDependencies: zod: ^3.0.0 @@ -1976,8 +1976,8 @@ packages: resolution: {integrity: sha512-H0TSyFNDMomMNJQBn8wFV5YC/2eJ+VXECwOadZJT554xP6cODZHPX3H9QMQECxvrgiSOP1pHjy1sMWQVYJOUOA==} engines: {node: '>= 14'} - ai@4.0.13: - resolution: {integrity: sha512-ic+qEVPQhfLpGPnZ2M55ErofeuKaD/TQebeh0qSPwv2PF+dQwsPr2Pw+JNYXahezAOaxFNdrDPz0EF1kKcSFSw==} + ai@4.0.18: + resolution: {integrity: sha512-BTWzalLNE1LQphEka5xzJXDs5v4xXy1Uzr7dAVk+C/CnO3WNpuMBgrCymwUv0VrWaWc8xMQuh+OqsT7P7JyekQ==} engines: {node: '>=18'} peerDependencies: react: ^18 || ^19 || ^19.0.0-rc @@ -4352,11 +4352,6 @@ packages: nan@2.22.0: resolution: {integrity: sha512-nbajikzWTMwsW+eSsNm3QwlOs7het9gGJU5dDZzRTQGk03vyBOauxgI4VakDzE0PtsGTmXPsXTbbjVhRwR5mpw==} - nanoid@3.3.7: - resolution: {integrity: sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==} - engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} - hasBin: true - nanoid@3.3.8: resolution: {integrity: sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w==} engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} @@ -5920,14 +5915,14 @@ snapshots: dependencies: '@ai-sdk/provider': 1.0.0 eventsource-parser: 3.0.0 - nanoid: 3.3.7 + nanoid: 3.3.8 secure-json-parse: 2.7.0 optionalDependencies: zod: 3.23.8 - '@ai-sdk/provider-utils@2.0.3(zod@3.23.8)': + '@ai-sdk/provider-utils@2.0.4(zod@3.23.8)': dependencies: - '@ai-sdk/provider': 1.0.1 + '@ai-sdk/provider': 1.0.2 eventsource-parser: 3.0.0 nanoid: 3.3.8 secure-json-parse: 2.7.0 @@ -5938,24 +5933,24 @@ snapshots: dependencies: json-schema: 0.4.0 - 
'@ai-sdk/provider@1.0.1': + '@ai-sdk/provider@1.0.2': dependencies: json-schema: 0.4.0 - '@ai-sdk/react@1.0.5(react@18.3.1)(zod@3.23.8)': + '@ai-sdk/react@1.0.6(react@18.3.1)(zod@3.23.8)': dependencies: - '@ai-sdk/provider-utils': 2.0.3(zod@3.23.8) - '@ai-sdk/ui-utils': 1.0.4(zod@3.23.8) + '@ai-sdk/provider-utils': 2.0.4(zod@3.23.8) + '@ai-sdk/ui-utils': 1.0.5(zod@3.23.8) swr: 2.2.5(react@18.3.1) throttleit: 2.1.0 optionalDependencies: react: 18.3.1 zod: 3.23.8 - '@ai-sdk/ui-utils@1.0.4(zod@3.23.8)': + '@ai-sdk/ui-utils@1.0.5(zod@3.23.8)': dependencies: - '@ai-sdk/provider': 1.0.1 - '@ai-sdk/provider-utils': 2.0.3(zod@3.23.8) + '@ai-sdk/provider': 1.0.2 + '@ai-sdk/provider-utils': 2.0.4(zod@3.23.8) zod-to-json-schema: 3.23.5(zod@3.23.8) optionalDependencies: zod: 3.23.8 @@ -8049,12 +8044,12 @@ snapshots: transitivePeerDependencies: - supports-color - ai@4.0.13(react@18.3.1)(zod@3.23.8): + ai@4.0.18(react@18.3.1)(zod@3.23.8): dependencies: - '@ai-sdk/provider': 1.0.1 - '@ai-sdk/provider-utils': 2.0.3(zod@3.23.8) - '@ai-sdk/react': 1.0.5(react@18.3.1)(zod@3.23.8) - '@ai-sdk/ui-utils': 1.0.4(zod@3.23.8) + '@ai-sdk/provider': 1.0.2 + '@ai-sdk/provider-utils': 2.0.4(zod@3.23.8) + '@ai-sdk/react': 1.0.6(react@18.3.1)(zod@3.23.8) + '@ai-sdk/ui-utils': 1.0.5(zod@3.23.8) '@opentelemetry/api': 1.9.0 jsondiffpatch: 0.6.0 zod-to-json-schema: 3.23.5(zod@3.23.8) @@ -11266,8 +11261,6 @@ snapshots: nan@2.22.0: {} - nanoid@3.3.7: {} - nanoid@3.3.8: {} natural-compare@1.4.0: {} @@ -11526,7 +11519,7 @@ snapshots: postcss@8.4.47: dependencies: - nanoid: 3.3.7 + nanoid: 3.3.8 picocolors: 1.1.1 source-map-js: 1.2.1