Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkpiano committed Apr 12, 2024
1 parent 5133f85 commit 958cb25
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 16 deletions.
79 changes: 79 additions & 0 deletions src/adapters/openai.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type OpenAI from 'openai';
import {
AnyEventObject,
AnyMachineSnapshot,
Observer,
fromObservable,
fromPromise,
Expand All @@ -11,6 +12,7 @@ import { getAllTransitions } from '../utils';
import { ChatCompletionCreateParamsNonStreaming } from 'openai/resources';
import { ChatCompletionCreateParamsBase } from 'openai/resources/chat/completions';
import { StatelyAgentAdapter, Tool } from '../types';
import { ZodEventTypes } from '../schemas';

/**
* Creates [promise actor logic](https://stately.ai/docs/promise-actors) that uses the OpenAI API to generate a completion.
Expand Down Expand Up @@ -107,6 +109,83 @@ export function fromChatStream<TInput>(
);
}

export async function getToolCalls(
openai: OpenAI,
goal: string,
snapshot: AnyMachineSnapshot,
model: string,
eventSchemas: ZodEventTypes = {}
): Promise<
Array<{
type: `agent.${string}`;
params: Record<string, any>;
}>
> {
const eventSchemaMap = eventSchemas;
const transitions = getAllTransitions(snapshot);
const functionNameMapping: Record<string, string> = {};
const tools = transitions
.filter((t) => {
// return !t.eventType.startsWith('xstate.');
return t.eventType.startsWith('agent.');
})
.map((t) => {
const name = t.eventType.replace(/\./g, '_');
functionNameMapping[name] = t.eventType;
const eventSchema = eventSchemaMap[t.eventType];
const {
description,
properties: { type, ...properties },
} = (eventSchema as any) ?? {};

return {
type: 'function',
function: {
name,
description: t.description ?? description,
parameters: {
type: 'object',
properties: properties ?? {},
},
},
} as const;
});

if (!tools.length) {
return [];
}

const completionParams: ChatCompletionCreateParamsNonStreaming = {
model,
messages: [
{
role: 'user',
content: goal,
},
],
};

const completion = await openai.chat.completions.create({
...completionParams,
tools,
});

const toolCalls = completion.choices[0]?.message.tool_calls;

if (toolCalls?.length) {
const events = toolCalls.map((tc) => {
return {
type: functionNameMapping[tc.function.name],
...JSON.parse(tc.function.arguments),
};
});

return events;
}

return [];
}

/**
* Creates [promise actor logic](https://stately.ai/docs/promise-actors) that passes the next possible transitions as functions to [OpenAI tool calls](https://platform.openai.com/docs/guides/function-calling) and returns an array of potential next events.
*
Expand Down
57 changes: 41 additions & 16 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import {
SnapshotFrom,
createActor,
} from 'xstate';
import { getAllTransitions } from './utils';
import OpenAI from 'openai';
import { StatelyAgentAdapter } from './types';
import { getToolCalls } from './adapters/openai';
import { ZodEventTypes } from './schemas';

// export type AgentExperiences<TState, TReward> = Record<
// string, // serialized state
Expand Down Expand Up @@ -39,7 +44,7 @@ export type AgentPlan<TLogic extends AnyActorLogic> = Array<{
/**
* The expected next state
*/
nextState: SnapshotFrom<TLogic>;
nextState?: SnapshotFrom<TLogic>;
}>;

export interface AgentModel<
Expand Down Expand Up @@ -80,15 +85,6 @@ export interface AgentModel<
state: TState;
goal: string;
}) => Promise<Array<AgentPlan<TState>>>;
getNextPlan: ({
logic,
state,
goal,
}: {
logic: TLogic;
state: TState;
goal: string;
}) => AgentPlan<TState>;
getReward: ({
logic,
state,
Expand Down Expand Up @@ -118,17 +114,46 @@ export interface Agent<TLogic extends AnyActorLogic>
}

export function createAgent<TLogic extends AnyActorLogic>(
openai: OpenAI,
logic: TLogic,
input: InputFrom<TLogic>,
goal: string // TODO: () => string ?
goal: string, // TODO: () => string ?
schemas: ZodEventTypes
): Agent<TLogic> {
const experiences: Array<AgentExperience<any, any>> = [];

const agentModel: AgentModel<TLogic, any> = {
// addExperience: (experience) => {
// experiences.push(experience);
// },
} as unknown as AgentModel<TLogic, any>;
const agentModel = {
policy: async ({ logic, state, goal }) => {
const toolEvents = await getToolCalls(
openai,
goal,
state,
'gpt-3.5-turbo-16k-0613',
schemas
);

return toolEvents.map((te) => ({
state,
event: te as EventFromLogic<TLogic>,
}));
},
addExperience: (experience) => {
experiences.push(experience);
},
getExperiences: async () => experiences,
getLogic: async ({ experiences }) => {
return logic;
},
getNextEvents: async ({ logic, state }) => {
return [];
},
getReward: async ({ logic, state, goal, action }) => {
return 0;
},
getPlans: async ({ logic, state, goal }) => {
return [];
},
} satisfies AgentModel<TLogic, any>;

const actor = createActor(logic, {
input,
Expand Down

0 comments on commit 958cb25

Please sign in to comment.