Skip to content

Commit

Permalink
feat: model aliases (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdjastrzebski authored Apr 2, 2024
1 parent 5a9acd5 commit 1592834
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 36 deletions.
22 changes: 20 additions & 2 deletions src/commands/chat/state/init.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { type ConfigFile } from '../../../config-file.js';
import { DEFAULT_SYSTEM_PROMPT } from '../../../default-config.js';
import type { ResponseStyle } from '../../../engine/providers/config.js';
import type { Message } from '../../../engine/inference.js';
import type { PromptOptions } from '../prompt-options.js';
Expand All @@ -20,16 +21,33 @@ export function initChatState(
throw new Error(`Provider config not found: ${provider.name}.`);
}

const modelOrAlias = options.model ?? providerFileConfig.model;
const model = modelOrAlias
? provider.modelAliases[modelOrAlias] ?? modelOrAlias
: provider.defaultModel;

const systemPrompt = !provider.skipSystemPrompt?.includes(model)
? providerFileConfig.systemPrompt ?? DEFAULT_SYSTEM_PROMPT
: undefined;

const providerConfig = {
model: options.model ?? providerFileConfig.model,
apiKey: providerFileConfig.apiKey,
systemPrompt: providerFileConfig.systemPrompt,
model,
systemPrompt,
responseStyle: getResponseStyle(options),
};

const contextMessages: Message[] = [];
const outputMessages: ChatMessage[] = [];

if (modelOrAlias != null && modelOrAlias !== model) {
outputMessages.push({
type: 'program',
level: 'debug',
text: `Resolved model alias "${modelOrAlias}" to "${model}".`,
});
}

outputMessages.push({
type: 'program',
level: 'debug',
Expand Down
2 changes: 1 addition & 1 deletion src/commands/chat/ui/StatusBar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export function StatusBar() {
const totalUsage = useMemo(() => calculateTotalUsage(items), [items]);
const totalTime = useMemo(() => calculateTotalResponseTime(items), [items]);

const modelPricing = provider.pricing[providerConfig.model];
const modelPricing = provider.modelPricing[providerConfig.model];
const totalCost = calculateUsageCost(totalUsage, modelPricing) ?? 0;

return (
Expand Down
2 changes: 1 addition & 1 deletion src/commands/chat/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export function handleInputFile(
const fileContent = fs.readFileSync(filePath).toString();
const fileTokens = getTokensCount(fileContent);

const pricing = provider.pricing[config.model];
const pricing = provider.modelPricing[config.model];
const fileCost = calculateUsageCost({ inputTokens: fileTokens }, pricing);

let costWarning = null;
Expand Down
21 changes: 10 additions & 11 deletions src/config-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@ import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { z } from 'zod';
import {
DEFAULT_OPEN_AI_MODEL,
DEFAULT_PERPLEXITY_MODEL,
DEFAULT_SYSTEM_PROMPT,
} from './default-config.js';

const LEGACY_CONFIG_FILENAME = '.airc';
const CONFIG_FILENAME = '.airc.json';
Expand All @@ -15,15 +10,15 @@ const ProvidersSchema = z.object({
openAi: z.optional(
z.object({
apiKey: z.string(),
model: z.string().default(DEFAULT_OPEN_AI_MODEL),
systemPrompt: z.string().default(DEFAULT_SYSTEM_PROMPT),
model: z.string().optional(),
systemPrompt: z.string().optional(),
}),
),
perplexity: z.optional(
z.object({
apiKey: z.string(),
model: z.string().default(DEFAULT_PERPLEXITY_MODEL),
systemPrompt: z.string().default(DEFAULT_SYSTEM_PROMPT),
model: z.string().optional(),
systemPrompt: z.string().optional(),
}),
),
});
Expand All @@ -42,10 +37,14 @@ export function parseConfigFile() {

const typedConfig = ConfigFileSchema.parse(json);
if (!typedConfig.providers.openAi?.apiKey && !typedConfig.providers.perplexity?.apiKey) {
throw new Error('Add your OpenAI or Perplexity API key to "~/.airc.json" and try again.');
throw new Error(
`Add your OpenAI or Perplexity API key to "~/${CONFIG_FILENAME}" and try again.`,
);
}

return typedConfig;
// Note: we return original json object, and not `typedConfig` because we want to preserve
// the original order of providers in the config file.
return json;
}

export function writeConfigFile(configContents: ConfigFile) {
Expand Down
6 changes: 0 additions & 6 deletions src/default-config.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
// OpenAI models: https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
export const DEFAULT_OPEN_AI_MODEL = 'gpt-4-turbo-preview';

// Perplexity models: https://docs.perplexity.ai/docs/model-cards
export const DEFAULT_PERPLEXITY_MODEL = 'sonar-medium-chat';

export const DEFAULT_SYSTEM_PROMPT =
'You are a helpful assistant responding in a concise manner to user questions.';

Expand Down
2 changes: 1 addition & 1 deletion src/engine/providers/config.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export interface ProviderConfig {
apiKey: string;
model: string;
systemPrompt: string;
systemPrompt?: string;
responseStyle: ResponseStyle;
}

Expand Down
10 changes: 9 additions & 1 deletion src/engine/providers/openAi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ const OpenAi: Provider = {
name: 'openAi',
apiKeyUrl: 'https://platform.openai.com/api-keys',

// OpenAI models: https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
defaultModel: 'gpt-4-turbo-preview',

// Price per 1k tokens [input, output].
// Source: https://openai.com/pricing
pricing: {
modelPricing: {
'gpt-4-turbo-preview': { inputTokensCost: 0.01, outputTokensCost: 0.03 },
'gpt-4-0125-preview': { inputTokensCost: 0.01, outputTokensCost: 0.03 },
'gpt-4-1106-preview': { inputTokensCost: 0.01, outputTokensCost: 0.03 },
Expand All @@ -23,6 +26,11 @@ const OpenAi: Provider = {
'gpt-3.5-turbo-0125': { inputTokensCost: 0.0005, outputTokensCost: 0.0015 },
},

modelAliases: {
'gpt-4-turbo': 'gpt-4-turbo-preview',
'gpt-3.5': 'gpt-3.5-turbo',
},

getChatCompletion: async (config: ProviderConfig, messages: Message[]) => {
const api = new OpenAI({
apiKey: config.apiKey,
Expand Down
14 changes: 13 additions & 1 deletion src/engine/providers/perplexity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ const Perplexity: Provider = {
name: 'perplexity',
apiKeyUrl: 'https://perplexity.ai/settings/api',

// Perplexity models: https://docs.perplexity.ai/docs/model-cards
defaultModel: 'sonar-medium-chat',

// Price per 1k tokens [input, output].
// Source: https://docs.perplexity.ai/docs/model-cards
// Source: https://docs.perplexity.ai/docs/pricing
pricing: {
modelPricing: {
'sonar-small-chat': { inputTokensCost: 0.2 / 1000, outputTokensCost: 0.2 / 1000 },
'sonar-medium-chat': { inputTokensCost: 0.6 / 1000, outputTokensCost: 0.6 / 1000 },
'sonar-small-online': {
Expand All @@ -30,6 +33,15 @@ const Perplexity: Provider = {
'mixtral-8x7b-instruct': { inputTokensCost: 0.6 / 1000, outputTokensCost: 0.6 / 1000 },
},

modelAliases: {
online: 'sonar-medium-online',
codellama: 'codellama-70b-instruct',
mistral: 'mistral-7b-instruct',
mixtral: 'mixtral-8x7b-instruct',
},

skipSystemPrompt: ['sonar-small-online', 'sonar-medium-online'],

getChatCompletion: async (config: ProviderConfig, messages: Message[]) => {
const api = new OpenAI({
apiKey: config.apiKey,
Expand Down
7 changes: 6 additions & 1 deletion src/engine/providers/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ export interface Provider {
name: ProviderName;
label: string;
apiKeyUrl: string;
pricing: Record<string, ModelPricing>;

defaultModel: string;
modelPricing: Record<string, ModelPricing>;
modelAliases: Record<string, string>;

skipSystemPrompt?: string[];

getChatCompletion: (config: ProviderConfig, messages: Message[]) => Promise<ModelResponse>;
getChatCompletionStream?: (
Expand Down
26 changes: 15 additions & 11 deletions src/engine/providers/utils/open-ai-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@ export async function getChatCompletion(
config: ProviderConfig,
messages: Message[],
): Promise<ModelResponse> {
const systemMessage: Message = {
role: 'system',
content: config.systemPrompt,
};
const allMessages = getMessages(config, messages);

const startTime = performance.now();
const response = await api.chat.completions.create({
messages: [systemMessage, ...messages],
messages: allMessages,
model: config.model,
...responseStyles[config.responseStyle],
});
Expand Down Expand Up @@ -53,12 +50,7 @@ export async function getChatCompletionStream(
messages: Message[],
onResponseUpdate: (response: ModelResponseUpdate) => void,
): Promise<ModelResponse> {
const systemMessage: Message = {
role: 'system',
content: config.systemPrompt,
};

const allMessages = [systemMessage, ...messages];
const allMessages = getMessages(config, messages);

const startTime = performance.now();
const stream = await api.chat.completions.create({
Expand Down Expand Up @@ -96,3 +88,15 @@ export async function getChatCompletionStream(
data: lastChunk,
};
}

/**
 * Builds the message list sent to the chat-completions API.
 *
 * When the provider config carries a system prompt, it is prepended as a
 * `system`-role message; otherwise the conversation history is returned
 * unchanged (some models, e.g. Perplexity's online models, reject system
 * prompts — see `skipSystemPrompt`-driven config upstream).
 */
function getMessages(config: ProviderConfig, messages: Message[]): Message[] {
  const prompt = config.systemPrompt;
  if (!prompt) {
    return messages;
  }

  return [{ role: 'system', content: prompt }, ...messages];
}

0 comments on commit 1592834

Please sign in to comment.