feat: o1 model support (#37)
mdjastrzebski authored Dec 16, 2024
1 parent d6b3e3a commit 3770dfb
Showing 9 changed files with 83 additions and 52 deletions.
5 changes: 5 additions & 0 deletions .changeset/rude-points-lick.md
@@ -0,0 +1,5 @@
+---
+'@callstack/byorg-core': minor
+---
+
+core: chatModel is customizable using RequestContext, removed default maxTokens and maxSteps values
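The changeset above summarizes the two API changes in this commit: the chat model becomes swappable per request through `RequestContext`, and the adapter-level defaults for `maxTokens`/`maxSteps` are dropped. A rough sketch of the per-request override follows (the `@callstack/byorg-core` and `@ai-sdk/openai` import paths, the `openai('o1')` model id, the adapter constructor shape, and the plugin name are assumptions for illustration; the middleware signature and the `context.chatModel` assignment follow the test added further down):

```ts
import { createApp, VercelChatModelAdapter } from '@callstack/byorg-core';
import type { RequestContext, MessageResponse } from '@callstack/byorg-core';
import { openai } from '@ai-sdk/openai';

// Default model for most requests, plus an o1-backed model to switch to.
const defaultModel = new VercelChatModelAdapter({ languageModel: openai('gpt-4o') });
const o1Model = new VercelChatModelAdapter({ languageModel: openai('o1') });

async function o1Middleware(context: RequestContext, next: () => Promise<MessageResponse>) {
  // Swap the model for this request only; the core handler reads context.chatModel.
  context.chatModel = o1Model;
  return await next();
}

const app = createApp({
  chatModel: defaultModel,
  plugins: [{ name: 'o1-router', middleware: o1Middleware }],
});
```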
5 changes: 4 additions & 1 deletion docs/src/docs/core/context.md
@@ -23,7 +23,10 @@ export type RequestContext = {
/** Ids of users who are a part of conversation */
resolvedEntities: EntityInfo;

-/** Function for generating a system prompt */
+/** Chat model instance */
+chatModel: ChatModel;
+
+/** Function generating a system prompt (bound to the context) */
systemPrompt: () => string | null;

/**
4 changes: 2 additions & 2 deletions packages/core/package.json
@@ -34,12 +34,12 @@
"zod": "^3.23.8"
},
"peerDependencies": {
"ai": "^4.0.3"
"ai": "^4.0.18"
},
"devDependencies": {
"@microsoft/api-extractor": "catalog:",
"@rslib/core": "catalog:",
"ai": "^4.0.13",
"ai": "^4.0.18",
"vitest": "catalog:"
}
}
26 changes: 25 additions & 1 deletion packages/core/src/__tests__/application.test.ts
@@ -1,6 +1,6 @@
import { expect, test, vitest } from 'vitest';
import { createApp } from '../application.js';
-import { Message } from '../domain.js';
+import { Message, MessageResponse, RequestContext } from '../domain.js';
import { createMockChatModel } from '../mock/mock-model.js';

const messages: Message[] = [{ role: 'user', content: 'Hello' }];
@@ -69,3 +69,27 @@ test('basic streaming test', async () => {
});
await expect(result.pendingEffects).resolves.toEqual([]);
});

+test('uses chat model from context', async () => {
+const baseModel = createMockChatModel({ delay: 0, seed: 3 });
+const altModel = createMockChatModel({ delay: 0, seed: 3 });
+
+async function altModelMiddleware(context: RequestContext, next: () => Promise<MessageResponse>) {
+context.chatModel = altModel;
+return await next();
+}
+
+const app = createApp({
+chatModel: baseModel,
+plugins: [
+{
+name: 'middleware',
+middleware: altModelMiddleware,
+},
+],
+});
+
+await app.processMessages(messages);
+expect(altModel.calls.length).toBe(1);
+expect(baseModel.calls.length).toBe(0);
+});
11 changes: 4 additions & 7 deletions packages/core/src/ai/vercel.ts
@@ -16,9 +16,6 @@ import { RequestContext, Message } from '../domain.js';
import { ApplicationTool } from '../tools.js';
import type { ChatModel, AssistantResponse, ModelUsage } from './types.js';

-const DEFAULT_MAX_TOKENS = 1024;
-const DEFAULT_MAX_STEPS = 5;
-
// Workaround for memory issue happening when sending image attachment. The attachments get inefficiently serialised causing a memory spike.
const VERCEL_AI_SHARED_OPTIONS = {
experimental_telemetry: {
@@ -127,9 +124,9 @@ export class VercelChatModelAdapter implements ChatModel {
const result = await streamText({
...VERCEL_AI_SHARED_OPTIONS,
model: this._options.languageModel,
-maxTokens: this._options.maxTokens ?? DEFAULT_MAX_TOKENS,
-maxSteps: this._options.maxSteps ?? DEFAULT_MAX_STEPS,
messages: context.messages,
+maxTokens: this._options.maxTokens,
+maxSteps: this._options.maxSteps,
tools: context.tools,
});

@@ -160,9 +157,9 @@
const result = await generateText({
...VERCEL_AI_SHARED_OPTIONS,
model: this._options.languageModel,
-maxTokens: this._options.maxTokens ?? DEFAULT_MAX_TOKENS,
-maxSteps: this._options.maxSteps ?? DEFAULT_MAX_STEPS,
messages: context.messages,
+maxTokens: this._options.maxTokens,
+maxSteps: this._options.maxSteps,
tools: context.tools,
});
const responseTime = performance.now() - startTime;
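With the built-in fallbacks removed, `maxTokens` and `maxSteps` are forwarded to `streamText`/`generateText` only when set on the adapter, which is presumably what lets models like o1 (which handle token limits differently) work without a forced cap. A sketch of setting them explicitly, assuming the adapter accepts these options at construction (import paths and model id are illustrative):

```ts
import { VercelChatModelAdapter } from '@callstack/byorg-core';
import { openai } from '@ai-sdk/openai';

// The adapter previously defaulted to maxTokens: 1024 and maxSteps: 5.
// Both are now optional and passed through as-is: set them explicitly if
// you relied on the old defaults, or omit them entirely (e.g. for o1).
const chatModel = new VercelChatModelAdapter({
  languageModel: openai('gpt-4o'),
  maxTokens: 1024,
  maxSteps: 5,
});
```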
8 changes: 4 additions & 4 deletions packages/core/src/application.ts
@@ -32,7 +32,7 @@ export type ErrorHandler = (
) => Promise<MessageResponse> | MessageResponse;

export type ApplicationConfig = {
-chatModel: ChatModel | ((context: RequestContext) => ChatModel);
+chatModel: ChatModel;
systemPrompt?: ((context: RequestContext) => string | null) | string;
plugins?: ApplicationPlugin[];
errorHandler?: ErrorHandler;
@@ -53,7 +53,7 @@ export type ProcessMessageOptions = {
};

export function createApp(config: ApplicationConfig): Application {
-const { plugins = [], chatModel, errorHandler = defaultErrorHandler } = config;
+const { plugins = [], errorHandler = defaultErrorHandler } = config;

plugins.forEach((plugin) => {
logger.debug(`Plugin "${plugin.name}" registered`);
@@ -100,6 +100,7 @@ export function createApp(config: ApplicationConfig): Application {

return lastMessage;
},
+chatModel: config.chatModel,
systemPrompt: () =>
typeof config.systemPrompt === 'function'
? config.systemPrompt(context)
@@ -117,8 +118,7 @@
performance.markEnd(PerformanceMarks.middlewareBeforeHandler);

performance.markStart(PerformanceMarks.chatModel);
-const resolvedChatModel = typeof chatModel === 'function' ? chatModel(context) : chatModel;
-const response = await resolvedChatModel.generateResponse(context);
+const response = await context.chatModel.generateResponse(context);
performance.markEnd(PerformanceMarks.chatModel);

// Opens the 'middleware:afterHandler' mark that will be closed after middlewareExecutor has run
3 changes: 2 additions & 1 deletion packages/core/src/domain.ts
@@ -1,5 +1,5 @@
import { ApplicationTool } from './tools.js';
-import type { AssistantResponse } from './ai/types.js';
+import type { AssistantResponse, ChatModel } from './ai/types.js';
import { ReferenceStorage } from './references.js';
import { PerformanceTimeline } from './performance.js';

@@ -30,6 +30,7 @@ export type RequestContext = {
tools: ApplicationTool[];
references: ReferenceStorage;
resolvedEntities: EntityInfo;
+chatModel: ChatModel;
systemPrompt: () => string | null;
onPartialResponse?: (text: string) => void;
extras: MessageRequestExtras;
10 changes: 9 additions & 1 deletion packages/core/src/mock/mock-model.ts
@@ -19,13 +19,21 @@ export type MockChatModelConfig = {
seed?: number;
};

-export function createMockChatModel(config?: MockChatModelConfig): ChatModel {
+export type MockChatModel = ChatModel & {
+calls: Parameters<ChatModel['generateResponse']>[];
+};
+
+export function createMockChatModel(config?: MockChatModelConfig): MockChatModel {
const responses = config?.responses ?? LOREM_IPSUM_RESPONSES;
const delay = config?.delay ?? 100;

+const calls: Parameters<ChatModel['generateResponse']>[] = [];
+
let lastRandom = config?.seed ?? Date.now();
return {
+calls,
generateResponse: async (context: RequestContext): Promise<AssistantResponse> => {
+calls.push([context]);
lastRandom = random(lastRandom);

const response = responses[lastRandom % responses.length];
63 changes: 28 additions & 35 deletions pnpm-lock.yaml

Some generated files are not rendered by default.
