diff --git a/src/adapters/openai.ts b/src/adapters/openai.ts index 7da2481..7a825d1 100644 --- a/src/adapters/openai.ts +++ b/src/adapters/openai.ts @@ -1,6 +1,7 @@ import type OpenAI from 'openai'; import { AnyEventObject, + AnyMachineSnapshot, Observer, fromObservable, fromPromise, @@ -11,6 +12,7 @@ import { getAllTransitions } from '../utils'; import { ChatCompletionCreateParamsNonStreaming } from 'openai/resources'; import { ChatCompletionCreateParamsBase } from 'openai/resources/chat/completions'; import { StatelyAgentAdapter, Tool } from '../types'; +import { ZodEventTypes } from '../schemas'; /** * Creates [promise actor logic](https://stately.ai/docs/promise-actors) that uses the OpenAI API to generate a completion. @@ -107,6 +109,83 @@ export function fromChatStream( ); } +export async function getToolCalls( + openai: OpenAI, + goal: string, + snapshot: AnyMachineSnapshot, + model: string, + eventSchemas: ZodEventTypes = {} +): Promise< + Array<{ + type: `agent.${string}`; + params: Record; + }> +> { + const eventSchemaMap = eventSchemas; + const transitions = getAllTransitions(snapshot); + const functionNameMapping: Record = {}; + const tools = transitions + .filter((t) => { + // return !t.eventType.startsWith('xstate.'); + return t.eventType.startsWith('agent.'); + }) + .map((t) => { + const name = t.eventType.replace(/\./g, '_'); + functionNameMapping[name] = t.eventType; + const eventSchema = eventSchemaMap[t.eventType]; + const { + description, + properties: { type, ...properties }, + } = (eventSchema as any) ?? {}; + + return { + type: 'function', + function: { + name, + description: t.description ?? description, + parameters: { + type: 'object', + properties: properties ?? {}, + }, + }, + } as const; + }); + + if (!tools.length) { + return []; + } + + const completionParams: ChatCompletionCreateParamsNonStreaming = { + model, + messages: [ + { + role: 'user', + content: goal, + }, + ], + }; + + const completion = await openai.chat.completions.create({ + ...completionParams, + tools, + }); + + const toolCalls = completion.choices[0]?.message.tool_calls; + + if (toolCalls?.length) { + const events = toolCalls.map((tc) => { + return { + type: functionNameMapping[tc.function.name], + ...JSON.parse(tc.function.arguments), + }; + }); + + return events; + } + + return []; +} + /** * Creates [promise actor logic](https://stately.ai/docs/promise-actors) that passes the next possible transitions as functions to [OpenAI tool calls](https://platform.openai.com/docs/guides/function-calling) and returns an array of potential next events. * diff --git a/src/agent.ts b/src/agent.ts index 29f00f3..1a0b0bc 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -10,6 +10,11 @@ import { SnapshotFrom, createActor, } from 'xstate'; +import { getAllTransitions } from './utils'; +import OpenAI from 'openai'; +import { StatelyAgentAdapter } from './types'; +import { getToolCalls } from './adapters/openai'; +import { ZodEventTypes } from './schemas'; // export type AgentExperiences = Record< // string, // serialized state @@ -39,7 +44,7 @@ export type AgentPlan = Array<{ /** * The expected next state */ - nextState: SnapshotFrom; + nextState?: SnapshotFrom; }>; export interface AgentModel< @@ -80,15 +85,6 @@ export interface AgentModel< state: TState; goal: string; }) => Promise>>; - getNextPlan: ({ - logic, - state, - goal, - }: { - logic: TLogic; - state: TState; - goal: string; - }) => AgentPlan; getReward: ({ logic, state, @@ -118,17 +114,46 @@ export interface Agent } export function createAgent( + openai: OpenAI, logic: TLogic, input: InputFrom, - goal: string // TODO: () => string ? + goal: string, // TODO: () => string ? + schemas: ZodEventTypes ): Agent { const experiences: Array> = []; - const agentModel: AgentModel = { - // addExperience: (experience) => { - // experiences.push(experience); - // }, - } as unknown as AgentModel; + const agentModel = { + policy: async ({ logic, state, goal }) => { + const toolEvents = await getToolCalls( + openai, + goal, + state, + 'gpt-3.5-turbo-16k-0613', + schemas + ); + + return toolEvents.map((te) => ({ + state, + event: te as EventFromLogic, + })); + }, + addExperience: (experience) => { + experiences.push(experience); + }, + getExperiences: async () => experiences, + getLogic: async ({ experiences }) => { + return logic; + }, + getNextEvents: async ({ logic, state }) => { + return []; + }, + getReward: async ({ logic, state, goal, action }) => { + return 0; + }, + getPlans: async ({ logic, state, goal }) => { + return []; + }, + } satisfies AgentModel; const actor = createActor(logic, { input,