Added gemini thinking model support, with a default of gemini-2.0-flash-thinking-exp-01-21 #56

Open: wants to merge 1 commit into base: main
16 changes: 11 additions & 5 deletions README.md
@@ -80,6 +80,7 @@ flowchart TB
- API keys for:
- Firecrawl API (for web search and content extraction)
- OpenAI API (for o3 mini model)
- Gemini API (for gemini thinking model)

## Setup

@@ -100,9 +101,11 @@ FIRECRAWL_KEY="your_firecrawl_key"
# FIRECRAWL_BASE_URL="http://localhost:3002"

OPENAI_KEY="your_openai_key"
GOOGLE_API_KEY="your_google_key" # Required for gemini thinking model
```

To use a local LLM, comment out `OPENAI_KEY` and instead uncomment `OPENAI_ENDPOINT` and `OPENAI_MODEL`:

- Set `OPENAI_ENDPOINT` to the address of your local server (e.g. "http://localhost:1234/v1").
- Set `OPENAI_MODEL` to the name of the model loaded in your local server.
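For example, a local server setup might look like the following (the endpoint and model name are placeholders for whatever your server actually exposes):

```bash
# OPENAI_KEY="your_openai_key"
OPENAI_ENDPOINT="http://localhost:1234/v1"
OPENAI_MODEL="your-local-model-name"
```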

@@ -127,10 +130,13 @@ npm start

You'll be prompted to:

1. Enter your research query
2. Specify research breadth (recommended: 3-10, default: 4)
3. Specify research depth (recommended: 1-5, default: 2)
4. Answer follow-up questions to refine the research direction
1. Select an AI model:
   - OpenAI (default): uses the o3-mini model (requires OPENAI_KEY)
   - Gemini: uses the Gemini thinking model (requires GOOGLE_API_KEY)
2. Enter your research query
3. Specify research breadth (recommended: 3-10, default: 4)
4. Specify research depth (recommended: 1-5, default: 2)
5. Answer follow-up questions to refine the research direction

The system will then:

@@ -149,7 +155,7 @@ If you have a free version, you may sometime run into rate limit errors, you can

### Custom endpoints and models

There are 2 other optional env vars that lets you tweak the endpoint (for other OpenAI compatible APIs like OpenRouter or Gemini) as well as the model string.
There are 2 optional env vars that let you tweak the endpoint (for other OpenAI-compatible APIs like OpenRouter) as well as the model string. Note that these are only applicable when using OpenAI models.

```bash
OPENAI_ENDPOINT="custom_endpoint"
```
10 changes: 10 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions package.json
@@ -22,6 +22,7 @@
},
"dependencies": {
"@ai-sdk/openai": "^1.1.9",
"@google/generative-ai": "^0.21.0",
"@mendable/firecrawl-js": "^1.16.0",
"ai": "^4.1.17",
"js-tiktoken": "^1.0.17",
23 changes: 23 additions & 0 deletions src/ai/index.ts
@@ -0,0 +1,23 @@
import { z } from 'zod';
import { modelProvider } from './providers';

export type GenerateObjectParams<T extends z.ZodType> = {
system: string;
prompt: string;
schema: T;
abortSignal?: AbortSignal;
model?: any;
provider?: any;
};

export async function generateObject<T extends z.ZodType>(
params: GenerateObjectParams<T>
): Promise<{ object: z.infer<T> }> {
const enhancedParams = {
...params,
provider: params.provider || { id: 'openai' },
};
return modelProvider.getCurrentProvider().generateObject(enhancedParams);
}

export { modelProvider, trimPrompt } from './providers';
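The wrapper above mainly injects a default provider id before delegating. A minimal stand-in sketch of that merging behavior (plain types here, no zod, so this is illustrative rather than the real signature):

```typescript
// Stand-in for the real params type; the actual code also carries a zod schema.
type Params = { system: string; prompt: string; provider?: { id: string } };

function withDefaultProvider(params: Params): Params {
  // Mirrors the spread in generateObject: an explicit provider wins,
  // otherwise the OpenAI provider is assumed.
  return { ...params, provider: params.provider || { id: 'openai' } };
}

const defaulted = withDefaultProvider({ system: 'sys', prompt: 'p' });
const explicit = withDefaultProvider({
  system: 'sys',
  prompt: 'p',
  provider: { id: 'gemini' },
});
console.log(defaulted.provider!.id); // openai
console.log(explicit.provider!.id); // gemini
```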
273 changes: 260 additions & 13 deletions src/ai/providers.ts
@@ -1,30 +1,277 @@
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai';
import { getEncoding } from 'js-tiktoken';

import { z } from 'zod';
import {
GoogleGenerativeAI,
HarmBlockThreshold,
HarmCategory,
} from '@google/generative-ai';
import { RecursiveCharacterTextSplitter } from './text-splitter';

interface CustomOpenAIProviderSettings extends OpenAIProviderSettings {
baseURL?: string;
}

// Providers
const openai = createOpenAI({
apiKey: process.env.OPENAI_KEY!,
baseURL: process.env.OPENAI_ENDPOINT || 'https://api.openai.com/v1',
} as CustomOpenAIProviderSettings);
// Provider Interface
export interface AIProvider {
generateObject<T extends z.ZodType>(params: {
system: string;
prompt: string;
schema: T;
abortSignal?: AbortSignal;
provider?: any;
}): Promise<{ object: z.infer<T> }>;
}

// OpenAI Provider Implementation
class OpenAIProvider implements AIProvider {
private static openai = createOpenAI({
apiKey: process.env.OPENAI_KEY!,
baseURL: process.env.OPENAI_ENDPOINT || 'https://api.openai.com/v1',
} as CustomOpenAIProviderSettings);

const customModel = process.env.OPENAI_MODEL || 'o3-mini';
private static model = OpenAIProvider.openai(
process.env.OPENAI_MODEL || 'o3-mini',
{
reasoningEffort: (process.env.OPENAI_MODEL || 'o3-mini').startsWith('o')
? 'medium'
: undefined,
structuredOutputs: true,
}
) as any;

// Models
async generateObject<T extends z.ZodType>(params: {
system: string;
prompt: string;
schema: T;
abortSignal?: AbortSignal;
provider?: any;
}): Promise<{ object: z.infer<T> }> {
try {
const { provider: _, ...rest } = params;
return await OpenAIProvider.model(rest);
} catch (error) {
throw new Error(`OpenAI model error: ${error}`);
}
}
}

export const o3MiniModel = openai(customModel, {
reasoningEffort: customModel.startsWith('o') ? 'medium' : undefined,
structuredOutputs: true,
});
// Gemini Provider Implementation
class GeminiProvider implements AIProvider {
private model: ReturnType<GoogleGenerativeAI['getGenerativeModel']>;

constructor(apiKey: string) {
const googleAI = new GoogleGenerativeAI(apiKey);
this.model = googleAI.getGenerativeModel({
model: process.env.GEMINI_MODEL || 'gemini-2.0-flash-thinking-exp-01-21',
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: HarmBlockThreshold.BLOCK_NONE,
},
],
});
}

private generateExampleForSchema(schema: z.ZodType): string {
if (!(schema instanceof z.ZodObject)) return "{}";

const shape = schema._def.shape();
const example: any = {};

for (const [key, value] of Object.entries(shape)) {
if (value instanceof z.ZodArray) {
const elementType = value._def.type;
if (elementType instanceof z.ZodObject) {
// For arrays of objects, create a basic example
example[key] = [{
query: "example query",
researchGoal: "example goal"
}];
} else {
// For simple arrays (questions, learnings, etc)
example[key] = ["first example", "second example"];
}
} else if (value instanceof z.ZodString) {
example[key] = "sample content";
}
}

return JSON.stringify(example, null, 2);
}
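The placeholder-generation idea above can be sketched without zod, using a plain field-descriptor map in place of the zod shape (the descriptor names here are hypothetical):

```typescript
// 'object[]' mirrors the array-of-objects case above; 'string[]' the simple-list case.
type FieldKind = 'string' | 'string[]' | 'object[]';

function exampleForShape(shape: Record<string, FieldKind>): string {
  const example: Record<string, unknown> = {};
  for (const [key, kind] of Object.entries(shape)) {
    if (kind === 'object[]') {
      example[key] = [{ query: 'example query', researchGoal: 'example goal' }];
    } else if (kind === 'string[]') {
      example[key] = ['first example', 'second example'];
    } else {
      example[key] = 'sample content';
    }
  }
  return JSON.stringify(example, null, 2);
}

const rendered = exampleForShape({ questions: 'string[]', summary: 'string' });
console.log(rendered);
```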

async generateObject<T extends z.ZodType>(params: {
system: string;
prompt: string;
schema: T;
abortSignal?: AbortSignal;
provider?: any;
}): Promise<{ object: z.infer<T> }> {
// For report generation, use tagged format instead of JSON
const isReportGeneration = Object.keys((params.schema as any)._def.shape() || {}).includes('reportMarkdown');

const fullPrompt = isReportGeneration
? `${params.system}

Follow these steps:

1. First, think about this request and analyze it carefully:
${params.prompt}

2. Format your response using the tags shown below.
The reasoning section helps organize your thoughts.
The report section contains the actual content.

Format your response EXACTLY like this, using these exact tags:

<REASONING>
Step 1: [Initial analysis]
Step 2: [Key points considered]
Step 3: [Final reasoning]
</REASONING>

<REPORT>
[Your markdown report here]
</REPORT>

IMPORTANT INSTRUCTIONS:
- The <REPORT> section must contain properly formatted markdown content
- Format the report professionally with clear sections and headings
- Include all relevant information from the research
- Make the report as detailed and comprehensive as possible`
: `${params.system}

Follow these steps:

1. First, think about this request and analyze it carefully:
${params.prompt}

2. Format your complete response in TWO parts as shown below.
The reasoning section helps organize your thoughts.
The JSON section MUST match the schema structure EXACTLY.

Format your response EXACTLY like this, using these exact tags:

<REASONING>
Step 1: [Initial analysis]
Step 2: [Key points considered]
Step 3: [Final reasoning]
</REASONING>

<JSON>
${this.generateExampleForSchema(params.schema)}
</JSON>

IMPORTANT INSTRUCTIONS:
- The <JSON> section must contain ONLY valid JSON
- Your response must match this exact structure with fields: ${Object.keys((params.schema as any)._def.shape() || {}).join(', ')}
- Follow the example format precisely, replacing example values with real content
- No additional fields or different structure allowed
- No markdown formatting or code blocks
- Ensure all JSON syntax is valid`;

const result = await this.model.generateContent(fullPrompt);
const response = result.response;
const content = response.candidates?.[0]?.content?.parts?.[0]?.text;

if (!content) {
throw new Error('No content generated from Gemini');
}

try {
const isReportGeneration = Object.keys((params.schema as any)._def.shape() || {}).includes('reportMarkdown');

if (isReportGeneration) {
// Extract report content for report generation
const reportMatch = content?.match(/<REPORT>\s*([\s\S]*?)\s*<\/REPORT>/);
if (!reportMatch?.[1]) {
console.error('Full Gemini response:', content);
throw new Error('Could not extract valid report content');
}

const reportContent = reportMatch[1].trim();
return {
object: {
reportMarkdown: reportContent
}
} as any;
} else {
// Handle regular JSON responses
const jsonMatch = content?.match(/<JSON>\s*([\s\S]*?)\s*<\/JSON>/);
if (!jsonMatch?.[1]) {
console.error('Full Gemini response:', content);
throw new Error('Could not extract valid JSON section from response');
}

// Extract and clean up JSON content
let jsonContent = jsonMatch[1]
.trim()
// Remove code block markers
.replace(/```[a-z]*\n?/g, '')
// Clean up structural issues
.replace(/,(\s*[}\]])/g, '$1');

const parsedContent = JSON.parse(jsonContent);
const validatedObject = params.schema.parse(parsedContent);
return { object: validatedObject };
}
} catch (error) {
console.error('Full Gemini response:', content); // For debugging
throw new Error(`Failed to parse Gemini response: ${error}`);
}
}
}
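The tag-extraction and JSON-cleanup steps can be exercised in isolation. This sketch applies the same two regexes to a made-up Gemini-style response, including the code-fence markers and trailing commas the cleanup exists to handle:

```typescript
function extractJsonSection(content: string): unknown {
  const jsonMatch = content.match(/<JSON>\s*([\s\S]*?)\s*<\/JSON>/);
  if (!jsonMatch?.[1]) {
    throw new Error('Could not extract valid JSON section from response');
  }
  const cleaned = jsonMatch[1]
    .trim()
    .replace(/```[a-z]*\n?/g, '')   // strip stray code-block markers
    .replace(/,(\s*[}\]])/g, '$1'); // drop trailing commas before } or ]
  return JSON.parse(cleaned);
}

// A hypothetical response in the format the prompt asks for.
const sample = `<REASONING>
Step 1: analyze the request
</REASONING>

<JSON>
\`\`\`json
{
  "queries": ["first query", "second query",],
}
\`\`\`
</JSON>`;

console.log(extractJsonSection(sample)); // { queries: [ 'first query', 'second query' ] }
```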

class ModelProvider {
private static instance: ModelProvider;
private currentProvider: AIProvider;

private constructor() {
this.currentProvider = new OpenAIProvider();
}

static getInstance(): ModelProvider {
if (!ModelProvider.instance) {
ModelProvider.instance = new ModelProvider();
}
return ModelProvider.instance;
}

setProvider(type: 'openai' | 'gemini'): void {
if (type === 'openai') {
this.currentProvider = new OpenAIProvider();
} else {
const apiKey = process.env.GOOGLE_API_KEY;
if (!apiKey) {
throw new Error('GOOGLE_API_KEY environment variable is required for Gemini');
}
this.currentProvider = new GeminiProvider(apiKey);
}
}

getCurrentProvider(): AIProvider {
return this.currentProvider;
}
}

export const modelProvider = ModelProvider.getInstance();
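The ModelProvider singleton is a small strategy pattern: one mutable slot holding the active provider. A stripped-down sketch with stub providers standing in for the real API clients:

```typescript
interface Provider {
  name(): string;
}

class StubOpenAI implements Provider { name() { return 'openai'; } }
class StubGemini implements Provider { name() { return 'gemini'; } }

class ProviderRegistry {
  // OpenAI is the default, matching the real constructor.
  private current: Provider = new StubOpenAI();

  set(type: 'openai' | 'gemini'): void {
    this.current = type === 'openai' ? new StubOpenAI() : new StubGemini();
  }

  get(): Provider {
    return this.current;
  }
}

const registry = new ProviderRegistry();
const before = registry.get().name();
registry.set('gemini');
const after = registry.get().name();
console.log(before, after); // openai gemini
```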

const MinChunkSize = 140;
const encoder = getEncoding('o200k_base');

// trim prompt to maximum context size
export function trimPrompt(
prompt: string,