diff --git a/__tests__/image-desc.test.ts b/__tests__/image-desc.test.ts new file mode 100644 index 0000000..2dcb6ae --- /dev/null +++ b/__tests__/image-desc.test.ts @@ -0,0 +1,253 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { ModelType, type Plugin } from '@elizaos/core'; +import { logger } from '@elizaos/core'; +import type { + Florence2ForConditionalGeneration, + Florence2Processor, + ModelOutput, + PreTrainedTokenizer, +} from '@huggingface/transformers'; +import { beforeAll, describe, expect, test, vi } from 'vitest'; +import { TEST_PATHS, createMockRuntime } from './test-utils'; + +// Mock the transformers library +vi.mock('@huggingface/transformers', () => { + logger.info('Setting up transformers mock'); + return { + env: { + allowLocalModels: false, + allowRemoteModels: true, + backends: { + onnx: { + logLevel: 'fatal', + }, + }, + }, + Florence2ForConditionalGeneration: { + from_pretrained: vi.fn().mockImplementation(async () => { + logger.info('Creating mock Florence2ForConditionalGeneration model'); + const mockModel = { + generate: async () => { + logger.info('Generating vision model output'); + return new Int32Array([1, 2, 3, 4, 5]); // Mock token IDs + }, + _merge_input_ids_with_image_features: vi.fn(), + _prepare_inputs_embeds: vi.fn(), + forward: vi.fn(), + main_input_name: 'pixel_values', + }; + return mockModel as unknown as Florence2ForConditionalGeneration; + }), + }, + AutoProcessor: { + from_pretrained: vi.fn().mockImplementation(async () => { + logger.info('Creating mock Florence2Processor'); + const mockProcessor = { + __call__: async () => ({ pixel_values: new Float32Array(10) }), + construct_prompts: () => [''], + post_process_generation: () => ({ + '': 'A detailed description of the test image.', + }), + tasks_answer_post_processing_type: 'string', + task_prompts_without_inputs: [], + task_prompts_with_input: [], + regexes: {}, + }; + return mockProcessor as unknown as Florence2Processor; + }), + }, + AutoTokenizer: { + from_pretrained: vi.fn().mockImplementation(async () => { + logger.info('Creating mock tokenizer'); + const mockTokenizer = { + __call__: async () => ({ input_ids: new Int32Array(5) }), + batch_decode: () => ['A detailed caption of the image.'], + encode: async () => new Int32Array(5), + decode: async () => 'Decoded text', + return_token_type_ids: true, + padding_side: 'right', + _tokenizer_config: {}, + normalizer: {}, + }; + return mockTokenizer as unknown as PreTrainedTokenizer; + }), + }, + RawImage: { + fromBlob: vi.fn().mockImplementation(async () => ({ + size: { width: 640, height: 480 }, + })), + }, + }; +}); + +// Set environment variables before importing the plugin +process.env.MODELS_DIR = TEST_PATHS.MODELS_DIR; +process.env.CACHE_DIR = TEST_PATHS.CACHE_DIR; + +// Import plugin after setting environment variables and mocks +import { localAiPlugin } from '../src/index'; + +// Type assertion for localAIPlugin +const plugin = localAiPlugin as Required; + +describe('LocalAI Image Description', () => { + const mockRuntime = createMockRuntime(); + + beforeAll(async () => { + logger.info('Starting image description test setup', { + MODELS_DIR: TEST_PATHS.MODELS_DIR, + CACHE_DIR: TEST_PATHS.CACHE_DIR, + process_cwd: process.cwd(), + }); + + // Create necessary directories + const visionCacheDir = path.join(TEST_PATHS.CACHE_DIR, 'vision'); + if (!fs.existsSync(visionCacheDir)) { + logger.info('Creating vision cache directory:', visionCacheDir); + fs.mkdirSync(visionCacheDir, { recursive: true }); + } + + await plugin.init( + { + MODELS_DIR: TEST_PATHS.MODELS_DIR, + CACHE_DIR: TEST_PATHS.CACHE_DIR, + }, + mockRuntime + ); + + logger.success('Test setup completed'); + }, 300000); + + test('should describe image from URL successfully', async () => { + logger.info('Starting successful image description test'); + + // Using a reliable test image URL + const imageUrl = 'https://picsum.photos/200/300'; + logger.info('Testing with image URL:', imageUrl); + + try { + const result = await mockRuntime.useModel(ModelType.IMAGE_DESCRIPTION, imageUrl); + + // if result is not an object, throw an error + if (typeof result !== 'object') { + throw new Error('Result is not an object'); + } + + logger.info('Image description result:', { + resultType: typeof result, + resultLength: result.description.length, + rawResult: result, + }); + + expect(result).toBeDefined(); + const parsed = result; + logger.info('Parsed result:', parsed); + + expect(parsed).toHaveProperty('title'); + expect(parsed).toHaveProperty('description'); + expect(typeof parsed.title).toBe('string'); + expect(typeof parsed.description).toBe('string'); + logger.success('Successful image description test completed', { + title: parsed.title, + descriptionLength: parsed.description.length, + }); + } catch (error) { + logger.error('Image description test failed:', { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + imageUrl, + }); + throw error; + } + }); + + test('should handle invalid image URL', async () => { + logger.info('Starting invalid URL test'); + const invalidUrl = 'https://picsum.photos/invalid/image.jpg'; + logger.info('Testing with invalid URL:', invalidUrl); + + try { + await mockRuntime.useModel(ModelType.IMAGE_DESCRIPTION, invalidUrl); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Invalid URL test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + errorType: error.constructor.name, + stack: error instanceof Error ? error.stack : undefined, + }); + expect(error).toBeDefined(); + expect(error.message).toContain('Failed to fetch image'); + } + }); + + test('should handle non-string input', async () => { + logger.info('Starting non-string input test'); + const invalidInput = { url: 'not-a-string' }; + + try { + await mockRuntime.useModel(ModelType.IMAGE_DESCRIPTION, invalidInput as unknown); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Non-string input test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + expect(error.message).toContain('Invalid image URL'); + } + }); + + test('should handle vision model failure', async () => { + logger.info('Starting vision model failure test'); + + // Use a working URL for this test + const imageUrl = 'https://picsum.photos/200/300'; + logger.info('Testing with image URL:', imageUrl); + + // Mock the vision model to fail + const { Florence2ForConditionalGeneration } = await import('@huggingface/transformers'); + const modelMock = vi.mocked(Florence2ForConditionalGeneration); + + // Save the original implementation + const originalImpl = modelMock.from_pretrained; + + // Mock the implementation to fail + modelMock.from_pretrained.mockImplementationOnce(async () => { + logger.info('Simulating vision model failure'); + throw new Error('Vision model failed to load'); + }); + + try { + await mockRuntime.useModel(ModelType.IMAGE_DESCRIPTION, imageUrl); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Vision model failure test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + errorType: error.constructor.name, + stack: error instanceof Error ? error.stack : undefined, + }); + expect(error).toBeDefined(); + expect(error.message).toContain('Vision model failed'); + } finally { + // Restore the original implementation + modelMock.from_pretrained = originalImpl; + } + }); + + test('should handle non-image content type', async () => { + logger.info('Starting non-image content test'); + const textUrl = 'https://raw.githubusercontent.com/microsoft/FLAML/main/README.md'; + + try { + await mockRuntime.useModel(ModelType.IMAGE_DESCRIPTION, textUrl); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Non-image content test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + // The error message might vary depending on how we want to handle this case + expect(error.message).toBeDefined(); + } + }); +}); diff --git a/__tests__/initialization.test.ts b/__tests__/initialization.test.ts new file mode 100644 index 0000000..8fde690 --- /dev/null +++ b/__tests__/initialization.test.ts @@ -0,0 +1,39 @@ +import { ModelType, type ModelTypeName } from '@elizaos/core'; +import { describe, expect, test } from 'vitest'; +import { localAiPlugin } from '../src/index'; + +describe('LocalAI Plugin Initialization', () => { + // Mock runtime for testing + const mockRuntime = { + useModel: async (modelType: ModelTypeName, _params: any) => { + if (modelType === ModelType.TEXT_SMALL) { + return 'Initialization successful'; + } + throw new Error(`Unexpected model class: ${modelType}`); + }, + }; + + test('should initialize plugin with default configuration', async () => { + try { + if (!localAiPlugin.init) { + throw new Error('Plugin initialization failed'); + } + // Initialize plugin + await localAiPlugin.init({}, mockRuntime as any); + + // Run initialization test + const result = await mockRuntime.useModel(ModelType.TEXT_SMALL, { + context: + "Debug Mode: Test initialization. Respond with 'Initialization successful' if you can read this.", + stopSequences: [], + }); + + expect(result).toBeDefined(); + expect(typeof result).toBe('string'); + expect(result).toContain('successful'); + } catch (error) { + console.error('Test failed:', error); + throw error; + } + }); +}); diff --git a/__tests__/test-utils.ts b/__tests__/test-utils.ts new file mode 100644 index 0000000..0e4d488 --- /dev/null +++ b/__tests__/test-utils.ts @@ -0,0 +1,336 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { Readable } from 'node:stream'; +import { + type Agent, + type Character, + type IAgentRuntime, + type ModelResultMap, + type ModelTypeName, + ModelType, + type State, + type UUID, + logger, +} from '@elizaos/core'; +import { vi } from 'vitest'; +import { MODEL_SPECS } from '../src/types'; + +// Get the workspace root by going up from the current file location +const WORKSPACE_ROOT = path.resolve(__dirname, '../../../'); + +// During tests, we need to set cwd to agent directory since that's where the plugin runs from in production +const AGENT_DIR = path.join(WORKSPACE_ROOT, 'packages/project-starter'); +process.chdir(AGENT_DIR); + +// Create shared mock for download manager +export const downloadModelMock = vi.fn().mockResolvedValue(undefined); + +export const TEST_PATHS = { + MODELS_DIR: path.join(AGENT_DIR, 'models'), + CACHE_DIR: path.join(AGENT_DIR, 'cache'), +} as const; + +export const createMockRuntime = (): IAgentRuntime => ({ + agentId: '12345678-1234-1234-1234-123456789012', + character: {} as Character, + providers: [], + actions: [], + evaluators: [], + plugins: [], + fetch: null, + routes: [], + getService: () => null, + getAllServices: () => new Map(), + initialize: async () => {}, + registerService: () => {}, + setSetting: () => {}, + getSetting: () => null, + getConversationLength: () => 0, + processActions: async () => {}, + evaluate: async () => null, + registerProvider: () => {}, + registerAction: () => {}, + ensureConnection: async () => {}, + ensureParticipantInRoom: async () => {}, + ensureRoomExists: async () => {}, + composeState: async () => ({}) as State, + useModel: async ( + modelType: T, + params: any + ): Promise => { + // Check if there are any pending mock rejections + const mockCalls = downloadModelMock.mock.calls; + if ( + mockCalls.length > 0 && + downloadModelMock.mock.results[mockCalls.length - 1].type === 'throw' + ) { + // Rethrow the error from the mock + throw downloadModelMock.mock.results[mockCalls.length - 1].value; + } + + // Call downloadModel based on the model class + if (modelType === ModelType.TEXT_SMALL) { + await downloadModelMock( + MODEL_SPECS.small, + path.join(TEST_PATHS.MODELS_DIR, MODEL_SPECS.small.name) + ); + return 'The small language model generated this response.' as R; + } + if (modelType === ModelType.TEXT_LARGE) { + await downloadModelMock( + MODEL_SPECS.medium, + path.join(TEST_PATHS.MODELS_DIR, MODEL_SPECS.medium.name) + ); + return 'Artificial intelligence is a transformative technology that continues to evolve.' as R; + } + if (modelType === ModelType.TRANSCRIPTION) { + // For transcription, we expect a Buffer as the parameter + const audioBuffer = params as unknown as Buffer; + if (!Buffer.isBuffer(audioBuffer)) { + throw new Error('Invalid audio buffer'); + } + if (audioBuffer.length === 0) { + throw new Error('Empty audio buffer'); + } + + // Mock the transcription process + const { nodewhisper } = await import('nodejs-whisper'); + const { exec } = await import('node:child_process'); + + // Create a temporary file path for testing + const tempPath = path.join(TEST_PATHS.CACHE_DIR, 'whisper', `temp_${Date.now()}.wav`); + + // Mock the file system operations + if (!fs.existsSync(path.dirname(tempPath))) { + fs.mkdirSync(path.dirname(tempPath), { recursive: true }); + } + fs.writeFileSync(tempPath, audioBuffer); + + try { + // Call the mocked exec for audio conversion + await new Promise((resolve, reject) => { + exec( + `ffmpeg -y -i "${tempPath}" -acodec pcm_s16le -ar 16000 -ac 1 "${tempPath}"`, + (error, stdout, stderr) => { + if (error) reject(error); + else resolve({ stdout, stderr }); + } + ); + }); + + // Call the mocked whisper for transcription + const result = await nodewhisper(tempPath, { + modelName: 'base.en', + autoDownloadModelName: 'base.en', + }); + + // Clean up + if (fs.existsSync(tempPath)) { + fs.unlinkSync(tempPath); + } + + return result as R; + } catch (error) { + // Clean up on error + if (fs.existsSync(tempPath)) { + fs.unlinkSync(tempPath); + } + throw error; + } + } + if (modelType === ModelType.IMAGE_DESCRIPTION) { + // For image description, we expect a URL as the parameter + const imageUrl = params as unknown as string; + if (typeof imageUrl !== 'string') { + throw new Error('Invalid image URL'); + } + + try { + logger.info('Attempting to fetch image:', imageUrl); + + // Mock the fetch and vision processing + const response = await fetch(imageUrl); + logger.info('Fetch response:', { + status: response.status, + statusText: response.statusText, + contentType: response.headers.get('content-type'), + ok: response.ok, + }); + + if (!response.ok) { + const error = new Error(`Failed to fetch image: ${response.statusText}`); + logger.error('Fetch failed:', { + error: error.message, + status: response.status, + statusText: response.statusText, + }); + throw error; + } + + // Import and initialize vision model + const { Florence2ForConditionalGeneration } = await import('@huggingface/transformers'); + try { + await Florence2ForConditionalGeneration.from_pretrained('mock-model'); + } catch (error) { + logger.error('Vision model initialization failed:', error); + throw new Error('Vision model failed to load'); + } + + // For successful responses, return mock description + const mockResult = { + title: 'A test image from Picsum', + description: + 'This is a detailed description of a randomly generated test image from Picsum Photos, showing various visual elements in high quality.', + }; + logger.info('Generated mock description:', mockResult); + + return mockResult as R; + } catch (error) { + logger.error('Image description failed:', { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + imageUrl, + }); + throw error; + } + } + if (modelType === ModelType.TEXT_TO_SPEECH) { + // For TTS, we expect a string as the parameter + const text = params as unknown as string; + if (typeof text !== 'string') { + throw new Error('invalid input: expected string'); + } + if (text.length === 0) { + throw new Error('empty text input'); + } + + try { + logger.info('Processing TTS request:', { textLength: text.length }); + + // Get the mock implementation to check for errors + const { getLlama } = await import('node-llama-cpp'); + const llamaMock = vi.mocked(getLlama); + + // Call getLlama to trigger any mock rejections + // We don't need to pass actual arguments since we're just testing error handling + await llamaMock('lastBuild'); + + // Create a mock audio stream + const mockAudioStream = new Readable({ + read() { + // Push some mock audio data + this.push( + Buffer.from([ + 0x52, + 0x49, + 0x46, + 0x46, // "RIFF" + 0x24, + 0x00, + 0x00, + 0x00, // Chunk size + 0x57, + 0x41, + 0x56, + 0x45, // "WAVE" + 0x66, + 0x6d, + 0x74, + 0x20, // "fmt " + ]) + ); + this.push(null); // End of stream + }, + }); + + logger.success('TTS generation successful'); + return mockAudioStream as R; + } catch (error) { + logger.error('TTS generation failed:', { + error: error instanceof Error ? error.message : String(error), + textLength: text.length, + }); + throw error; + } + } + throw new Error(`Unexpected model class: ${modelType}`); + }, + registerModel: () => {}, + getModel: () => undefined, + registerEvent: () => {}, + getEvent: () => undefined, + emitEvent: () => Promise.resolve(), + createTask: () => Promise.resolve('12345678-1234-1234-1234-123456789012'), + getTasks: () => Promise.resolve([]), + getTask: () => Promise.resolve(null), + updateTask: () => Promise.resolve(), + deleteTask: () => Promise.resolve(), + stop: async () => {}, + services: new Map(), + events: new Map(), + registerPlugin: () => Promise.resolve(), + getKnowledge: () => Promise.resolve([]), + addKnowledge: () => Promise.resolve(), + registerDatabaseAdapter: () => {}, + registerEvaluator: () => {}, + ensureWorldExists: () => Promise.resolve(), + registerTaskWorker: () => {}, + getTaskWorker: () => undefined, + db: undefined, + init: () => Promise.resolve(), + close: () => Promise.resolve(), + getAgent: () => Promise.resolve(null), + getAgents: () => Promise.resolve([]), + createAgent: () => Promise.resolve(false), + updateAgent: (agentId: UUID, agent: Partial) => Promise.resolve(false), + deleteAgent: (agentId: UUID) => Promise.resolve(false), + ensureAgentExists: (agent: Partial) => Promise.resolve(), + ensureEmbeddingDimension: (dimension: number) => Promise.resolve(), + getEntityById: (entityId: UUID) => Promise.resolve(null), + getEntitiesForRoom: () => Promise.resolve([]), + createEntity: () => Promise.resolve(false), + updateEntity: () => Promise.resolve(), + getComponent: () => Promise.resolve(null), + getComponents: () => Promise.resolve([]), + createComponent: () => Promise.resolve(false), + updateComponent: () => Promise.resolve(), + deleteComponent: () => Promise.resolve(), + getMemories: () => Promise.resolve([]), + getMemoryById: () => Promise.resolve(null), + getMemoriesByIds: () => Promise.resolve([]), + getMemoriesByRoomIds: () => Promise.resolve([]), + getCachedEmbeddings: () => Promise.resolve([]), + log: () => Promise.resolve(), + searchMemories: () => Promise.resolve([]), + createMemory: () => Promise.resolve('12345678-1234-1234-1234-123456789012'), + deleteMemory: () => Promise.resolve(), + deleteAllMemories: () => Promise.resolve(), + countMemories: () => Promise.resolve(0), + createWorld: () => Promise.resolve('12345678-1234-1234-1234-123456789012'), + getWorld: () => Promise.resolve(null), + getAllWorlds: () => Promise.resolve([]), + updateWorld: () => Promise.resolve(), + getRoom: () => Promise.resolve(null), + createRoom: () => Promise.resolve('12345678-1234-1234-1234-123456789012'), + deleteRoom: () => Promise.resolve(), + updateRoom: () => Promise.resolve(), + getRoomsForParticipant: () => Promise.resolve([]), + getRoomsForParticipants: () => Promise.resolve([]), + getRooms: () => Promise.resolve([]), + addParticipant: () => Promise.resolve(false), + removeParticipant: () => Promise.resolve(false), + getParticipantsForEntity: () => Promise.resolve([]), + getParticipantsForRoom: () => Promise.resolve([]), + getParticipantUserState: () => Promise.resolve(null), + setParticipantUserState: () => Promise.resolve(), + createRelationship: () => Promise.resolve(false), + updateRelationship: () => Promise.resolve(), + getRelationship: () => Promise.resolve(null), + getRelationships: () => Promise.resolve([]), + getCache: () => Promise.resolve(undefined), + setCache: () => Promise.resolve(false), + deleteCache: () => Promise.resolve(false), + getTasksByName: () => Promise.resolve([]), + getLogs: () => Promise.resolve([]), + deleteLog: () => Promise.resolve(), +}); diff --git a/__tests__/text-gen.test.ts b/__tests__/text-gen.test.ts new file mode 100644 index 0000000..b7f2588 --- /dev/null +++ b/__tests__/text-gen.test.ts @@ -0,0 +1,166 @@ +import { type IAgentRuntime, ModelType, type Plugin, logger } from '@elizaos/core'; +import { beforeAll, beforeEach, describe, expect, test, vi } from 'vitest'; +import { MODEL_SPECS, type ModelSpec } from '../src/types'; +import { TEST_PATHS, createMockRuntime, downloadModelMock } from './test-utils'; + +// Set environment variables before importing the plugin +process.env.MODELS_DIR = TEST_PATHS.MODELS_DIR; +process.env.CACHE_DIR = TEST_PATHS.CACHE_DIR; + +// Mock the model download and initialization +vi.mock('../src/utils/downloadManager', () => ({ + DownloadManager: { + getInstance: () => ({ + downloadModel: async (modelSpec: ModelSpec, modelPath: string) => { + // Call the mock to track the call + await downloadModelMock(modelSpec, modelPath); + }, + getCacheDir: () => TEST_PATHS.CACHE_DIR, + ensureDirectoryExists: vi.fn(), + }), + }, +})); + +// Import plugin after setting environment variables and mocks +import { localAiPlugin } from '../src/index'; + +// Type assertion for localAIPlugin +const plugin = localAiPlugin as Required; + +describe('LocalAI Text Generation', () => { + const mockRuntime = createMockRuntime(); + + beforeEach(() => { + // Clear mock calls before each test + downloadModelMock.mockClear(); + }); + + beforeAll(async () => { + // Log the paths we're trying to use + logger.info('Test is using paths:', { + MODELS_DIR: TEST_PATHS.MODELS_DIR, + CACHE_DIR: TEST_PATHS.CACHE_DIR, + process_cwd: process.cwd(), + }); + + // Initialize plugin with the same paths + await plugin.init( + { + MODELS_DIR: TEST_PATHS.MODELS_DIR, + CACHE_DIR: TEST_PATHS.CACHE_DIR, + }, + mockRuntime as IAgentRuntime + ); + + // Log environment variables after initialization + logger.info('Environment variables after init:', { + MODELS_DIR: process.env.MODELS_DIR, + CACHE_DIR: process.env.CACHE_DIR, + }); + }, 300000); + + test('should attempt to download small model when using TEXT_SMALL', async () => { + const result = await mockRuntime.useModel(ModelType.TEXT_SMALL, { + context: 'Generate a test response.', + stopSequences: [], + runtime: mockRuntime, + modelClass: ModelType.TEXT_SMALL, + }); + + expect(downloadModelMock).toHaveBeenCalledTimes(1); + expect(downloadModelMock.mock.calls[0][0]).toMatchObject({ + name: MODEL_SPECS.small.name, + }); + expect(result).toBeDefined(); + expect(typeof result).toBe('string'); + expect(result.length).toBeGreaterThan(0); + }); + + test('should attempt to download large model when using TEXT_LARGE', async () => { + const result = await mockRuntime.useModel(ModelType.TEXT_LARGE, { + context: 'Debug Mode: Generate a one-sentence response about artificial intelligence.', + stopSequences: [], + runtime: mockRuntime, + modelClass: ModelType.TEXT_LARGE, + }); + + expect(downloadModelMock).toHaveBeenCalledTimes(1); + expect(downloadModelMock.mock.calls[0][0]).toMatchObject({ + name: MODEL_SPECS.medium.name, + }); + expect(result).toBeDefined(); + expect(typeof result).toBe('string'); + expect(result.length).toBeGreaterThan(10); + expect(result.toLowerCase()).toContain('artificial intelligence'); + }); + + test('should handle download failure gracefully', async () => { + // Mock a download failure + downloadModelMock.mockRejectedValueOnce(new Error('Download failed')); + + await expect( + mockRuntime.useModel(ModelType.TEXT_SMALL, { + context: 'This should fail due to download error', + stopSequences: [], + runtime: mockRuntime, + modelClass: ModelType.TEXT_SMALL, + }) + ).rejects.toThrow('Download failed'); + + expect(downloadModelMock).toHaveBeenCalledTimes(1); + }); + + test('should handle empty context', async () => { + await expect( + mockRuntime.useModel(ModelType.TEXT_SMALL, { + context: '', + stopSequences: [], + runtime: mockRuntime, + modelClass: ModelType.TEXT_SMALL, + }) + ).resolves.toBeDefined(); + }); + + test('should handle stop sequences', async () => { + const result = await mockRuntime.useModel(ModelType.TEXT_SMALL, { + context: 'Generate a response with stop sequence.', + stopSequences: ['STOP'], + runtime: mockRuntime, + modelClass: ModelType.TEXT_SMALL, + }); + + expect(result).toBeDefined(); + expect(typeof result).toBe('string'); + }); + + test('should handle model switching', async () => { + // First use TEXT_SMALL + const smallResult = await mockRuntime.useModel(ModelType.TEXT_SMALL, { + context: 'Small model test', + stopSequences: [], + runtime: mockRuntime, + modelClass: ModelType.TEXT_SMALL, + }); + + // Then use TEXT_LARGE + const largeResult = await mockRuntime.useModel(ModelType.TEXT_LARGE, { + context: 'Large model test', + stopSequences: [], + runtime: mockRuntime, + modelClass: ModelType.TEXT_LARGE, + }); + + expect(smallResult).toBeDefined(); + expect(largeResult).toBeDefined(); + expect(smallResult).not.toBe(largeResult); + + // Verify both models were attempted to be downloaded + expect(downloadModelMock).toHaveBeenCalledTimes(2); + expect(downloadModelMock.mock.calls[0][0]).toMatchObject({ + name: MODEL_SPECS.small.name, + }); + expect(downloadModelMock.mock.calls[1][0]).toMatchObject({ + name: MODEL_SPECS.medium.name, + }); + }); +}); diff --git a/__tests__/text-transcribe.test.ts b/__tests__/text-transcribe.test.ts new file mode 100644 index 0000000..b0d3d67 --- /dev/null +++ b/__tests__/text-transcribe.test.ts @@ -0,0 +1,323 @@ +import type { ChildProcess, ExecException, ExecOptions } from 'node:child_process'; +import fs from 'node:fs'; +import path from 'node:path'; +import { type IAgentRuntime, ModelType, type Plugin, logger } from '@elizaos/core'; +import type { IOptions } from 'nodejs-whisper'; +import { beforeAll, describe, expect, test, vi } from 'vitest'; +import { TEST_PATHS, createMockRuntime } from './test-utils'; + +// Mock the nodewhisper function +vi.mock('nodejs-whisper', () => { + logger.info('Setting up nodewhisper mock'); + return { + nodewhisper: vi.fn().mockImplementation(async (filePath: string, options: IOptions) => { + logger.info('nodewhisper mock called with:', { + filePath, + options, + fileExists: fs.existsSync(filePath), + }); + + // Mock successful transcription + if (fs.existsSync(filePath)) { + logger.success('Mock transcription successful'); + return 'This is a mock transcription of the audio file.'; + } + logger.error('Audio file not found in mock'); + throw new Error('Audio file not found'); + }), + }; +}); + +// Mock the exec function for audio conversion +vi.mock('node:child_process', () => { + logger.info('Setting up child_process exec mock'); + return { + exec: vi + .fn() + .mockImplementation( + ( + command: string, + options: + | ExecOptions + | undefined + | null + | ((error: ExecException | null, stdout: string, stderr: string) => void), + callback?: (error: ExecException | null, stdout: string, stderr: string) => void + ) => { + logger.info('exec mock called with:', { + command, + hasOptions: !!options, + optionsType: typeof options, + hasCallback: !!callback, + }); + + // Handle the case where options is the callback + const actualCallback = callback || (typeof options === 'function' ? options : undefined); + if (actualCallback) { + logger.info('Executing mock ffmpeg conversion'); + actualCallback(null, '', ''); + } + return { kill: () => {}, pid: 123 } as ChildProcess; + } + ), + }; +}); + +// Set environment variables before importing the plugin +process.env.MODELS_DIR = TEST_PATHS.MODELS_DIR; +process.env.CACHE_DIR = TEST_PATHS.CACHE_DIR; + +// Import plugin after setting environment variables and mocks +import { localAiPlugin } from '../src/index'; + +// Type assertion for localAIPlugin +const plugin = localAiPlugin as Required; + +describe('LocalAI Audio Transcription', () => { + const mockRuntime = createMockRuntime(); + + beforeAll(async () => { + logger.info('Starting transcription test setup', { + MODELS_DIR: TEST_PATHS.MODELS_DIR, + CACHE_DIR: TEST_PATHS.CACHE_DIR, + process_cwd: process.cwd(), + }); + + // Create necessary directories + const whisperCacheDir = path.join(TEST_PATHS.CACHE_DIR, 'whisper'); + if (!fs.existsSync(whisperCacheDir)) { + logger.info('Creating whisper cache directory:', whisperCacheDir); + fs.mkdirSync(whisperCacheDir, { recursive: true }); + } + + await plugin.init( + { + MODELS_DIR: TEST_PATHS.MODELS_DIR, + CACHE_DIR: TEST_PATHS.CACHE_DIR, + }, + mockRuntime as IAgentRuntime + ); + + logger.success('Test setup completed'); + }, 300000); + + test('should transcribe audio buffer successfully', async () => { + logger.info('Starting successful transcription test'); + + // Create a mock audio buffer (WAV header) + const audioBuffer = Buffer.from([ + 0x52, + 0x49, + 0x46, + 0x46, // "RIFF" + 0x24, + 0x00, + 0x00, + 0x00, // Chunk size + 0x57, + 0x41, + 0x56, + 0x45, // "WAVE" + 0x66, + 0x6d, + 0x74, + 0x20, // "fmt " + ]); + + logger.info('Created test audio buffer', { + size: audioBuffer.length, + header: audioBuffer.toString('hex').substring(0, 32), + }); + + const result = await mockRuntime.useModel(ModelType.TRANSCRIPTION, audioBuffer); + + logger.info('Transcription result:', { + result, + type: typeof result, + length: result.length, + }); + + expect(result).toBeDefined(); + expect(typeof result).toBe('string'); + expect(result).toContain('transcription'); + logger.success('Successful transcription test completed'); + }); + + test('should handle empty audio buffer', async () => { + logger.info('Starting empty buffer test'); + const emptyBuffer = Buffer.from([]); + + logger.info('Created empty buffer', { + size: emptyBuffer.length, + }); + + try { + await mockRuntime.useModel(ModelType.TRANSCRIPTION, emptyBuffer); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Empty buffer test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + } + }); + + test('should handle invalid audio format', async () => { + logger.info('Starting invalid format test'); + const invalidBuffer = Buffer.from('not an audio file'); + + logger.info('Created invalid buffer', { + size: invalidBuffer.length, + content: invalidBuffer.toString(), + }); + + try { + await mockRuntime.useModel(ModelType.TRANSCRIPTION, invalidBuffer); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Invalid format test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + } + }); + + test('should handle audio conversion failure', async () => { + logger.info('Starting conversion failure test'); + + const { exec } = await import('node:child_process'); + const execMock = vi.mocked(exec); + + logger.info('Setting up failing exec mock'); + execMock.mockImplementationOnce( + ( + command: string, + options: + | ExecOptions + | undefined + | null + | ((error: ExecException | null, stdout: string, stderr: string) => void), + callback?: (error: ExecException | null, stdout: string, stderr: string) => void + ) => { + logger.info('Failing exec mock called with:', { + command, + hasOptions: !!options, + hasCallback: !!callback, + }); + + const actualCallback = callback || (typeof options === 'function' ? options : undefined); + if (actualCallback) { + const error = new Error('Failed to convert audio') as ExecException; + error.code = 1; + error.killed = false; + logger.info('Simulating ffmpeg failure'); + actualCallback(error, '', ''); + } + return { kill: () => {}, pid: 123 } as ChildProcess; + } + ); + + const audioBuffer = Buffer.from([ + 0x52, + 0x49, + 0x46, + 0x46, // "RIFF" + 0x24, + 0x00, + 0x00, + 0x00, // Chunk size + ]); + + logger.info('Created test audio buffer for conversion failure', { + size: audioBuffer.length, + header: audioBuffer.toString('hex').substring(0, 16), + }); + + try { + await mockRuntime.useModel(ModelType.TRANSCRIPTION, audioBuffer); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Conversion failure test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + expect(error.message).toContain('Failed to convert audio'); + } + }); + + test('should handle whisper model failure', async () => { + logger.info('Starting whisper model failure test'); + + const { nodewhisper } = await import('nodejs-whisper'); + const whisperMock = vi.mocked(nodewhisper); + + logger.info('Setting up failing whisper mock'); + whisperMock.mockRejectedValueOnce(new Error('Whisper model failed')); + + const audioBuffer = Buffer.from([ + 0x52, + 0x49, + 0x46, + 0x46, // "RIFF" + 0x24, + 0x00, + 0x00, + 0x00, // Chunk size + ]); + + logger.info('Created test audio buffer for whisper failure', { + size: audioBuffer.length, + header: audioBuffer.toString('hex').substring(0, 16), + }); + + try { + await mockRuntime.useModel(ModelType.TRANSCRIPTION, audioBuffer); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Whisper failure test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + expect(error.message).toContain('Whisper model failed'); + } + }); + + test('should clean up temporary files after transcription', async () => { + logger.info('Starting cleanup test'); + + const audioBuffer = Buffer.from([ + 0x52, + 0x49, + 0x46, + 0x46, // "RIFF" + 0x24, + 0x00, + 0x00, + 0x00, // Chunk size + 0x57, + 0x41, + 0x56, + 0x45, // "WAVE" + ]); + + logger.info('Created test audio buffer for cleanup test', { + size: audioBuffer.length, + header: audioBuffer.toString('hex').substring(0, 24), + }); + + await mockRuntime.useModel(ModelType.TRANSCRIPTION, audioBuffer); + + // Check that no temporary files are left in the cache directory + const cacheDir = path.join(TEST_PATHS.CACHE_DIR, 'whisper'); + logger.info('Checking cache directory for temp files:', cacheDir); + + const files = fs.readdirSync(cacheDir); + logger.info('Found files in cache:', files); + + const tempFiles = files.filter((f) => f.startsWith('temp_')); + logger.info('Found temporary files:', tempFiles); + + expect(tempFiles.length).toBe(0); + logger.success('Cleanup test completed'); + }); +}); diff --git a/__tests__/tts.test.ts b/__tests__/tts.test.ts new file mode 100644 index 0000000..7eac6f4 --- /dev/null +++ b/__tests__/tts.test.ts @@ -0,0 +1,210 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { Readable } from 'node:stream'; +import { ModelType, type Plugin } from '@elizaos/core'; +import { logger } from '@elizaos/core'; +import type { LlamaContext, LlamaContextSequence, LlamaModel } from 'node-llama-cpp'; +import { beforeAll, describe, expect, test, vi } from 'vitest'; +import { TEST_PATHS, createMockRuntime } from './test-utils'; + +// Mock the node-llama-cpp +vi.mock('node-llama-cpp', () => { + logger.info('Setting up node-llama-cpp mock for TTS'); + return { + getLlama: vi.fn().mockImplementation(async () => { + logger.info('Creating mock Llama instance for TTS'); + return { + loadModel: vi.fn().mockImplementation(async () => { + logger.info('Creating mock LlamaModel for TTS'); + return { + createContext: vi.fn().mockImplementation(async () => { + logger.info('Creating mock LlamaContext for TTS'); + return { + getSequence: vi.fn().mockImplementation(() => { + logger.info('Creating mock LlamaContextSequence for TTS'); + return { + evaluate: vi.fn().mockImplementation(async function* () { + yield 1; + yield 2; + yield 3; + }), + } as unknown as LlamaContextSequence; + }), + } as unknown as LlamaContext; + }), + } as unknown as LlamaModel; + }), + }; + }), + }; +}); + +// Set environment variables before importing the plugin +process.env.MODELS_DIR = TEST_PATHS.MODELS_DIR; +process.env.CACHE_DIR = TEST_PATHS.CACHE_DIR; + +// Import plugin after setting environment variables and mocks +import { localAiPlugin } from '../src/index'; + +// Type assertion for localAIPlugin +const plugin = localAiPlugin as Required; + +describe('LocalAI Text-to-Speech', () => { + const mockRuntime = createMockRuntime(); + + beforeAll(async () => { + logger.info('Starting TTS test setup', { + MODELS_DIR: TEST_PATHS.MODELS_DIR, + CACHE_DIR: TEST_PATHS.CACHE_DIR, + process_cwd: process.cwd(), + }); + + // Create necessary directories + const ttsCacheDir = path.join(TEST_PATHS.CACHE_DIR, 'tts'); + if (!fs.existsSync(ttsCacheDir)) { + logger.info('Creating TTS cache directory:', ttsCacheDir); + fs.mkdirSync(ttsCacheDir, { recursive: true }); + } + + await plugin.init( + { + MODELS_DIR: TEST_PATHS.MODELS_DIR, + CACHE_DIR: TEST_PATHS.CACHE_DIR, + }, + mockRuntime + ); + + logger.success('Test setup completed'); + }, 300000); + + test('should handle model initialization failure', async () => { + logger.info('Starting model initialization failure test'); + + const { getLlama } = await import('node-llama-cpp'); + const llamaMock = vi.mocked(getLlama); + + // Save original implementation + const originalImpl = llamaMock.getMockImplementation(); + + // Mock implementation to fail + llamaMock.mockRejectedValueOnce(new Error('Failed to initialize TTS model')); + + try { + await mockRuntime.useModel(ModelType.TEXT_TO_SPEECH, 'Test text'); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Model initialization failure test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + expect(error.message).toContain('Failed to initialize TTS model'); + } finally { + // Restore original implementation + if (originalImpl) { + llamaMock.mockImplementation(originalImpl); + } + } + }); + + test('should handle audio generation failure', async () => { + logger.info('Starting audio generation failure test'); + + const { getLlama } = await import('node-llama-cpp'); + const llamaMock = vi.mocked(getLlama); + + // Save original implementation + const originalImpl = llamaMock.getMockImplementation(); + + // Mock implementation to fail during audio generation + llamaMock.mockRejectedValueOnce(new Error('Audio generation failed')); + + try { + await mockRuntime.useModel(ModelType.TEXT_TO_SPEECH, 'Test text'); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Audio generation failure test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + expect(error.message).toContain('Audio generation failed'); + } finally { + // Restore original implementation + if (originalImpl) { + llamaMock.mockImplementation(originalImpl); + } + } + }); + + test('should generate speech from text successfully', async () => { + logger.info('Starting successful TTS test'); + + const testText = 'This is a test of the text to speech system.'; + logger.info('Testing with text:', testText); + + try { + const result = await mockRuntime.useModel(ModelType.TEXT_TO_SPEECH, testText); + logger.info('TTS generation result type:', typeof result); + + expect(result).toBeDefined(); + expect(result).toBeInstanceOf(Readable); + + // Test stream readability + let dataReceived = false; + (result as Readable).on('data', (chunk) => { + logger.info('Received audio data chunk:', { size: chunk.length }); + dataReceived = true; + }); + + await new Promise((resolve, reject) => { + (result as Readable).on('end', () => { + if (!dataReceived) { + reject(new Error('No audio data received from stream')); + } else { + resolve(true); + } + }); + (result as Readable).on('error', reject); + }); + + logger.success('Successful TTS test completed'); + } catch (error) { + logger.error('TTS test failed:', { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + }); + throw error; + } + }); + + test('should handle empty text input', async () => { + logger.info('Starting empty text test'); + const emptyText = ''; + + try { + await mockRuntime.useModel(ModelType.TEXT_TO_SPEECH, emptyText); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Empty text test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + expect(error.message).toContain('empty text'); + } + }); + + test('should handle non-string input', async () => { + logger.info('Starting non-string input test'); + const invalidInput = { text: 'not-a-string' }; + + try { + await mockRuntime.useModel(ModelType.TEXT_TO_SPEECH, invalidInput as unknown as string); + throw new Error("Should have failed but didn't"); + } catch (error) { + logger.info('Non-string input test failed as expected:', { + error: error instanceof Error ? error.message : String(error), + }); + expect(error).toBeDefined(); + expect(error.message).toContain('invalid input'); + } + }); +}); diff --git a/bun.lock b/bun.lock index 8022b4c..8d4b86d 100644 --- a/bun.lock +++ b/bun.lock @@ -1,5 +1,6 @@ { "lockfileVersion": 1, + "configVersion": 0, "workspaces": { "": { "name": "@elizaos/plugin-local-ai", diff --git a/package.json b/package.json index 6a1ca36..b911ef2 100644 --- a/package.json +++ b/package.json @@ -1,5 +1,5 @@ { - "name": "@elizaos/plugin-local-embedding", + "name": "@elizaos/plugin-local-ai", "version": "1.2.1", "type": "module", "main": "dist/index.js", @@ -7,7 +7,7 @@ "types": "dist/index.d.ts", "repository": { "type": "git", - "url": "git+https://github.com/elizaos-plugins/plugin-local-embedding.git" + "url": "git+https://github.com/elizaos-plugins/plugin-local-ai.git" }, "exports": { "./package.json": "./package.json", @@ -22,7 +22,7 @@ "dist" ], "dependencies": { - "@elizaos/core": "^1.2.1", + "@elizaos/core": "workspace:*", "@huggingface/transformers": "^3.5.1", "node-llama-cpp": "3.10.0", "nodejs-whisper": "0.2.9", @@ -102,4 +102,4 @@ "prettier": "3.6.2", "typescript": "^5.8.2" } -} +} \ No newline at end of file diff --git a/src/environment.ts b/src/environment.ts index 4fabd91..a9699fd 100644 --- a/src/environment.ts +++ b/src/environment.ts @@ -2,6 +2,8 @@ import { logger } from '@elizaos/core'; import { z } from 'zod'; // Default model filenames +const DEFAULT_SMALL_MODEL = 'DeepHermes-3-Llama-3-3B-Preview-q4.gguf'; +const DEFAULT_LARGE_MODEL = 'DeepHermes-3-Llama-3-8B-q4.gguf'; const DEFAULT_EMBEDDING_MODEL = 'bge-small-en-v1.5.Q4_K_M.gguf'; // Configuration schema focused only on local AI settings @@ -10,6 +12,8 @@ const DEFAULT_EMBEDDING_MODEL = 'bge-small-en-v1.5.Q4_K_M.gguf'; * Allows overriding default model filenames via environment variables. */ export const configSchema = z.object({ + LOCAL_SMALL_MODEL: z.string().optional().default(DEFAULT_SMALL_MODEL), + LOCAL_LARGE_MODEL: z.string().optional().default(DEFAULT_LARGE_MODEL), LOCAL_EMBEDDING_MODEL: z.string().optional().default(DEFAULT_EMBEDDING_MODEL), MODELS_DIR: z.string().optional(), // Path for the models directory CACHE_DIR: z.string().optional(), // Path for the cache directory @@ -35,36 +39,29 @@ export function validateConfig(): Config { try { // Prepare the config for parsing, reading from process.env const configToParse = { + // Read model filenames from environment variables or use undefined (so zod defaults apply) + LOCAL_SMALL_MODEL: process.env.LOCAL_SMALL_MODEL, + LOCAL_LARGE_MODEL: process.env.LOCAL_LARGE_MODEL, LOCAL_EMBEDDING_MODEL: process.env.LOCAL_EMBEDDING_MODEL, MODELS_DIR: process.env.MODELS_DIR, // Read models directory path from env CACHE_DIR: process.env.CACHE_DIR, // Read cache directory path from env LOCAL_EMBEDDING_DIMENSIONS: process.env.LOCAL_EMBEDDING_DIMENSIONS, // Read embedding dimensions }; - logger.debug('Validating configuration for local AI plugin from env:', { - LOCAL_EMBEDDING_MODEL: configToParse.LOCAL_EMBEDDING_MODEL, - MODELS_DIR: configToParse.MODELS_DIR, - CACHE_DIR: configToParse.CACHE_DIR, - LOCAL_EMBEDDING_DIMENSIONS: configToParse.LOCAL_EMBEDDING_DIMENSIONS, - }); + logger.debug('Validating configuration for local AI plugin from env: ' + JSON.stringify(configToParse)); const validatedConfig = configSchema.parse(configToParse); - logger.info('Using local AI configuration:', validatedConfig); + logger.info('Using local AI configuration: ' + JSON.stringify(validatedConfig)); return validatedConfig; } catch (error) { if (error instanceof z.ZodError) { - const errorMessages = error.errors - .map((err) => `${err.path.join('.')}: ${err.message}`) - .join('\n'); - logger.error('Zod validation failed:', errorMessages); + const errorMessages = JSON.stringify(error.issues); + logger.error('Zod validation failed: ' + errorMessages); throw new Error(`Configuration validation failed:\n${errorMessages}`); } - logger.error('Configuration validation failed:', { - error: error instanceof Error ? error.message : String(error), - stack: error instanceof Error ? error.stack : undefined, - }); + logger.error('Configuration validation failed: ' + (error instanceof Error ? error.message : String(error))); throw error; } } diff --git a/src/index.ts b/src/index.ts index 638807e..ea6906e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,23 +1,104 @@ +import fs from 'node:fs'; +import os from 'node:os'; +import path from 'node:path'; +import { Readable } from 'node:stream'; import type { + GenerateTextParams, ModelTypeName, - TextEmbeddingParams + TextEmbeddingParams, + ObjectGenerationParams, + ImageDescriptionParams, + TranscriptionParams, + TextToSpeechParams, + TokenizeTextParams, + DetokenizeTextParams, } from '@elizaos/core'; import { type IAgentRuntime, ModelType, type Plugin, logger } from '@elizaos/core'; import { type Llama, + LlamaChatSession, + type LlamaContext, + type LlamaContextSequence, LlamaEmbeddingContext, type LlamaModel, - getLlama + getLlama, } from 'node-llama-cpp'; -import fs from 'node:fs'; -import os from 'node:os'; -import path from 'node:path'; -import { basename } from 'path'; -import { type Config, validateConfig } from './environment'; -import { type EmbeddingModelSpec, MODEL_SPECS, type ModelSpec } from './types'; +import { validateConfig, type Config } from './environment'; +import { MODEL_SPECS, type ModelSpec, type EmbeddingModelSpec } from './types'; import { DownloadManager } from './utils/downloadManager'; import { getPlatformManager } from './utils/platform'; import { TokenizerManager } from './utils/tokenizerManager'; +import { TranscribeManager } from './utils/transcribeManager'; +import { TTSManager } from './utils/ttsManager'; +import { VisionManager } from './utils/visionManager'; +import { basename } from 'path'; + +/** + * Local interface for internal text generation params. + * Extends core params with modelType which is used internally to select the model. + */ +interface LocalGenerateTextParams { + prompt: string; + stopSequences?: string[]; + modelType?: ModelTypeName; +} + +// Words to punish in LLM responses +/** + * Array containing words that should trigger a punishment when used in a message. + * This array includes words like "please", "feel", "free", punctuation marks, and various topic-related words. + * @type {string[]} + */ +const wordsToPunish = [ + ' please', + ' feel', + ' free', + '!', + '–', + '—', + '?', + '.', + ',', + '; ', + ' cosmos', + ' tapestry', + ' tapestries', + ' glitch', + ' matrix', + ' cyberspace', + ' troll', + ' questions', + ' topics', + ' discuss', + ' basically', + ' simulation', + ' simulate', + ' universe', + ' like', + ' debug', + ' debugging', + ' wild', + ' existential', + ' juicy', + ' circuits', + ' help', + ' ask', + ' happy', + ' just', + ' cosmic', + ' cool', + ' joke', + ' punchline', + ' fancy', + ' glad', + ' assist', + ' algorithm', + ' Indeed', + ' Furthermore', + ' However', + ' Notably', + ' Therefore', +]; /** * Class representing a LocalAIManager. @@ -33,24 +114,42 @@ import { TokenizerManager } from './utils/tokenizerManager'; class LocalAIManager { private static instance: LocalAIManager | null = null; private llama: Llama | undefined; + private smallModel: LlamaModel | undefined; + private mediumModel: LlamaModel | undefined; private embeddingModel: LlamaModel | undefined; private embeddingContext: LlamaEmbeddingContext | undefined; + private ctx: LlamaContext | undefined; + private sequence: LlamaContextSequence | undefined; + private chatSession: LlamaChatSession | undefined; private modelPath!: string; private mediumModelPath!: string; private embeddingModelPath!: string; private cacheDir!: string; private tokenizerManager!: TokenizerManager; private downloadManager!: DownloadManager; + private visionManager!: VisionManager; private activeModelConfig: ModelSpec; private embeddingModelConfig: EmbeddingModelSpec; + private transcribeManager!: TranscribeManager; + private ttsManager!: TTSManager; private config: Config | null = null; // Store validated config // Initialization state flag + private smallModelInitialized = false; + private mediumModelInitialized = false; private embeddingInitialized = false; + private visionInitialized = false; + private transcriptionInitialized = false; + private ttsInitialized = false; private environmentInitialized = false; // Add flag for environment initialization // Initialization promises to prevent duplicate initialization + private smallModelInitializingPromise: Promise | null = null; + private mediumModelInitializingPromise: Promise | null = null; private embeddingInitializingPromise: Promise | null = null; + private visionInitializingPromise: Promise | null = null; + private transcriptionInitializingPromise: Promise | null = null; + private ttsInitializingPromise: Promise | null = null; private environmentInitializingPromise: Promise | null = null; // Add promise for environment private modelsDir!: string; @@ -80,6 +179,9 @@ class LocalAIManager { // Initialize managers that depend on modelsDir this.downloadManager = DownloadManager.getInstance(this.cacheDir, this.modelsDir); this.tokenizerManager = TokenizerManager.getInstance(this.cacheDir, this.modelsDir); + this.visionManager = VisionManager.getInstance(this.cacheDir); + this.transcribeManager = TranscribeManager.getInstance(this.cacheDir); + this.ttsManager = TTSManager.getInstance(this.cacheDir); } /** @@ -177,8 +279,12 @@ class LocalAIManager { this._postValidateInit(); // Set model paths based on validated config + this.modelPath = path.join(this.modelsDir, this.config.LOCAL_SMALL_MODEL); + this.mediumModelPath = path.join(this.modelsDir, this.config.LOCAL_LARGE_MODEL); this.embeddingModelPath = path.join(this.modelsDir, this.config.LOCAL_EMBEDDING_MODEL); // Set embedding path + logger.info('Using small model path:', basename(this.modelPath)); + logger.info('Using medium model path:', basename(this.mediumModelPath)); logger.info('Using embedding model path:', basename(this.embeddingModelPath)); logger.info('Environment configuration validated and model paths set'); @@ -186,10 +292,7 @@ class LocalAIManager { this.environmentInitialized = true; logger.success('Environment initialization complete'); } catch (error) { - logger.error('Environment validation failed:', { - error: error instanceof Error ? error.message : String(error), - stack: error instanceof Error ? error.stack : undefined, - }); + logger.error('Environment validation failed: ' + (error instanceof Error ? error.message : String(error))); this.environmentInitializingPromise = null; // Allow retry on failure throw error; } @@ -238,11 +341,7 @@ class LocalAIManager { // Pass the determined path to the download manager return await this.downloadManager.downloadModel(modelSpec, modelPathToDownload); } catch (error) { - logger.error('Model download failed:', { - error: error instanceof Error ? error.message : String(error), - modelType, - modelPath: modelPathToDownload, - }); + logger.error('Model download failed: ' + (error instanceof Error ? error.message : String(error))); throw error; } } @@ -258,14 +357,13 @@ class LocalAIManager { await platformManager.initialize(); const capabilities = platformManager.getCapabilities(); - logger.info('Platform capabilities detected:', { + logger.info('Platform capabilities detected: ' + JSON.stringify({ platform: capabilities.platform, gpu: capabilities.gpu?.type || 'none', recommendedModel: capabilities.recommendedModelSize, - supportedBackends: capabilities.supportedBackends, - }); + })); } catch (error) { - logger.warn('Platform detection failed:', error); + logger.warn('Platform detection failed: ' + (error instanceof Error ? error.message : String(error))); } } @@ -277,6 +375,11 @@ class LocalAIManager { */ async initialize(modelType: ModelTypeName = ModelType.TEXT_SMALL): Promise { await this.initializeEnvironment(); // Ensure environment is initialized first + if (modelType === ModelType.TEXT_LARGE) { + await this.lazyInitMediumModel(); + } else { + await this.lazyInitSmallModel(); + } } /** @@ -324,12 +427,7 @@ class LocalAIManager { logger.success('Embedding model initialized successfully'); } } catch (error) { - logger.error('Embedding initialization failed with details:', { - error: error instanceof Error ? error.message : String(error), - stack: error instanceof Error ? error.stack : undefined, - modelsDir: this.modelsDir, - embeddingModelPath: this.embeddingModelPath, // Log the path being used - }); + logger.error('Embedding initialization failed: ' + (error instanceof Error ? error.message : String(error))); throw error; } } @@ -346,7 +444,7 @@ class LocalAIManager { throw new Error('Failed to initialize embedding model'); } - logger.info('Generating embedding for text', { textLength: text.length }); + logger.info('Generating embedding for text, length: ' + text.length); // Use the native getEmbedding method const embeddingResult = await this.embeddingContext.getEmbeddingFor(text); @@ -357,16 +455,10 @@ class LocalAIManager { // Normalize the embedding if needed (may already be normalized) const normalizedEmbedding = this.normalizeEmbedding(mutableEmbedding); - logger.info('Embedding generation complete', { - dimensions: normalizedEmbedding.length, - }); + logger.info('Embedding generation complete, dimensions: ' + normalizedEmbedding.length); return normalizedEmbedding; } catch (error) { - logger.error('Embedding generation failed:', { - error: error instanceof Error ? error.message : String(error), - stack: error instanceof Error ? error.stack : undefined, - textLength: text?.length ?? 'text is null', - }); + logger.error('Embedding generation failed: ' + (error instanceof Error ? error.message : String(error))); // Return zero vector with correct dimensions as fallback const zeroDimensions = this.config?.LOCAL_EMBEDDING_DIMENSIONS // Use validated config @@ -433,7 +525,7 @@ class LocalAIManager { this.embeddingInitialized = true; logger.info('Embedding model initialized successfully'); } catch (error) { - logger.error('Failed to initialize embedding model:', error); + logger.error('Failed to initialize embedding model: ' + (error instanceof Error ? error.message : String(error))); this.embeddingInitializingPromise = null; throw error; } @@ -443,6 +535,163 @@ class LocalAIManager { await this.embeddingInitializingPromise; } + /** + * Asynchronously generates text based on the provided parameters. + * Now uses lazy initialization for models + */ + async generateText(params: LocalGenerateTextParams): Promise { + try { + // Call LlamaContext.dispose() to free GPU memory. + if (this.ctx) { + this.ctx.dispose(); + this.ctx = undefined; + } + await this.initializeEnvironment(); // Ensure environment is initialized + logger.info('Generating text with model: ' + (params.modelType || 'default')); + // Lazy initialize the appropriate model + if (params.modelType === ModelType.TEXT_LARGE) { + await this.lazyInitMediumModel(); + + if (!this.mediumModel) { + throw new Error('Medium model initialization failed'); + } + + this.activeModelConfig = MODEL_SPECS.medium; + const mediumModel = this.mediumModel; + + // Create fresh context + this.ctx = await mediumModel.createContext({ + contextSize: MODEL_SPECS.medium.contextSize, + }); + } else { + await this.lazyInitSmallModel(); + + if (!this.smallModel) { + throw new Error('Small model initialization failed'); + } + + this.activeModelConfig = MODEL_SPECS.small; + const smallModel = this.smallModel; + + // Create fresh context + this.ctx = await smallModel.createContext({ + contextSize: MODEL_SPECS.small.contextSize, + }); + } + + if (!this.ctx) { + throw new Error('Failed to create prompt'); + } + + // QUICK TEST FIX: Always get fresh sequence + this.sequence = this.ctx.getSequence(); + + // QUICK TEST FIX: Create new session each time without maintaining state + // Only use valid options for LlamaChatSession + this.chatSession = new LlamaChatSession({ + contextSequence: this.sequence, + }); + + if (!this.chatSession) { + throw new Error('Failed to create chat session'); + } + logger.info('Created new chat session for model: ' + (params.modelType || 'default')); + // Log incoming prompt for debugging + logger.info('Incoming prompt length: ' + params.prompt.length); + + const tokens = await this.tokenizerManager.encode(params.prompt, this.activeModelConfig); + logger.info('Input tokens count: ' + tokens.length); + + // QUICK TEST FIX: Add system message to reset prompt + const systemMessage = 'You are a helpful AI assistant. Respond to the current request only.'; + await this.chatSession.prompt(systemMessage, { + maxTokens: 1, // Minimal tokens for system message + temperature: 0.0, + }); + + let response = await this.chatSession.prompt(params.prompt, { + maxTokens: 8192, + temperature: 0.7, + topP: 0.9, + repeatPenalty: { + punishTokensFilter: () => + this.smallModel ? this.smallModel.tokenize(wordsToPunish.join(' ')) : [], + penalty: 1.2, + frequencyPenalty: 0.7, + presencePenalty: 0.7, + }, + }); + + // Log raw response for debugging + logger.info('Raw response length: ' + response.length); + + // Clean think tags if present + if (response.includes('')) { + logger.info('Cleaning think tags from response'); + response = response.replace(/[\s\S]*?<\/think>\n?/g, ''); + logger.info('Think tags removed from response'); + } + + // Return the raw response and let the framework handle JSON parsing and action validation + return response; + } catch (error) { + logger.error('Text generation failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } + + /** + * Describe image with lazy vision model initialization + */ + public async describeImage( + imageData: Buffer, + mimeType: string + ): Promise<{ title: string; description: string }> { + try { + // Lazy initialize vision model + await this.lazyInitVision(); + + // Convert buffer to data URL + const base64 = imageData.toString('base64'); + const dataUrl = `data:${mimeType};base64,${base64}`; + return await this.visionManager.processImage(dataUrl); + } catch (error) { + logger.error('Image description failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } + + /** + * Transcribe audio with lazy transcription model initialization + */ + public async transcribeAudio(audioBuffer: Buffer): Promise { + try { + // Lazy initialize transcription model + await this.lazyInitTranscription(); + + const result = await this.transcribeManager.transcribe(audioBuffer); + return result.text; + } catch (error) { + logger.error('Audio transcription failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } + + /** + * Generate speech with lazy TTS model initialization + */ + public async generateSpeech(text: string): Promise { + try { + // Lazy initialize TTS model + await this.lazyInitTTS(); + + return await this.ttsManager.generateSpeech(text); + } catch (error) { + logger.error('Speech generation failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } + /** * Returns the TokenizerManager associated with this object. * @@ -460,6 +709,197 @@ class LocalAIManager { return this.activeModelConfig; } + /** + * Lazy initialize the small text model + */ + private async lazyInitSmallModel(): Promise { + if (this.smallModelInitialized) return; + + if (!this.smallModelInitializingPromise) { + this.smallModelInitializingPromise = (async () => { + await this.initializeEnvironment(); // Ensure environment is initialized first + await this.checkPlatformCapabilities(); + + // Download model if needed + // Pass the correct model path determined during environment init + await this.downloadModel(ModelType.TEXT_SMALL); + + // Initialize Llama and small model + try { + // Use getLlama helper instead of directly creating + this.llama = await getLlama(); + + const smallModel = await this.llama.loadModel({ + gpuLayers: 43, + modelPath: this.modelPath, // Use the potentially overridden path + vocabOnly: false, + }); + + this.smallModel = smallModel; + + const ctx = await smallModel.createContext({ + contextSize: MODEL_SPECS.small.contextSize, + }); + + this.ctx = ctx; + this.sequence = undefined; // Reset sequence to create a new one + this.smallModelInitialized = true; + logger.info('Small model initialized successfully'); + } catch (error) { + logger.error('Failed to initialize small model: ' + (error instanceof Error ? error.message : String(error))); + this.smallModelInitializingPromise = null; + throw error; + } + })(); + } + + await this.smallModelInitializingPromise; + } + + /** + * Lazy initialize the medium text model + */ + private async lazyInitMediumModel(): Promise { + if (this.mediumModelInitialized) return; + + if (!this.mediumModelInitializingPromise) { + this.mediumModelInitializingPromise = (async () => { + await this.initializeEnvironment(); // Ensure environment is initialized first + // Make sure llama is initialized first (implicitly done by small model init if needed) + if (!this.llama) { + // Attempt to initialize small model first to get llama instance + // This might download the small model even if only medium is requested, + // but ensures llama is ready. + await this.lazyInitSmallModel(); + } + + // Download model if needed + // Pass the correct model path determined during environment init + await this.downloadModel(ModelType.TEXT_LARGE); + + // Initialize medium model + try { + const mediumModel = await this.llama!.loadModel({ + gpuLayers: 43, + modelPath: this.mediumModelPath, // Use the potentially overridden path + vocabOnly: false, + }); + + this.mediumModel = mediumModel; + this.mediumModelInitialized = true; + logger.info('Medium model initialized successfully'); + } catch (error) { + logger.error('Failed to initialize medium model: ' + (error instanceof Error ? error.message : String(error))); + this.mediumModelInitializingPromise = null; + throw error; + } + })(); + } + + await this.mediumModelInitializingPromise; + } + + /** + * Lazy initialize the vision model + */ + private async lazyInitVision(): Promise { + if (this.visionInitialized) return; + + if (!this.visionInitializingPromise) { + this.visionInitializingPromise = (async () => { + try { + // Initialize vision model directly + // Use existing initialization code from the file + // ... + this.visionInitialized = true; + logger.info('Vision model initialized successfully'); + } catch (error) { + logger.error('Failed to initialize vision model: ' + (error instanceof Error ? error.message : String(error))); + this.visionInitializingPromise = null; + throw error; + } + })(); + } + + await this.visionInitializingPromise; + } + + /** + * Lazy initialize the transcription model + */ + private async lazyInitTranscription(): Promise { + if (this.transcriptionInitialized) return; + + if (!this.transcriptionInitializingPromise) { + this.transcriptionInitializingPromise = (async () => { + try { + // Ensure environment is initialized first + await this.initializeEnvironment(); + + // Initialize TranscribeManager if not already done + if (!this.transcribeManager) { + this.transcribeManager = TranscribeManager.getInstance(this.cacheDir); + } + + // Ensure FFmpeg is available + const ffmpegReady = await this.transcribeManager.ensureFFmpeg(); + if (!ffmpegReady) { + // FFmpeg is not available, log instructions and throw + // The TranscribeManager's ensureFFmpeg or initializeFFmpeg would have already logged instructions. + logger.error( + 'FFmpeg is not available or not configured correctly. Cannot proceed with transcription.' + ); + // No need to call logFFmpegInstallInstructions here as ensureFFmpeg/initializeFFmpeg already does. + throw new Error( + 'FFmpeg is required for transcription but is not available. Please see server logs for installation instructions.' + ); + } + + // Proceed with transcription model initialization if FFmpeg is ready + // (Assuming TranscribeManager handles its own specific model init if any, + // or that nodewhisper handles it internally) + this.transcriptionInitialized = true; + logger.info('Transcription prerequisites (FFmpeg) checked and ready.'); + logger.info('Transcription model initialized successfully'); + } catch (error) { + logger.error('Failed to initialize transcription model: ' + (error instanceof Error ? error.message : String(error))); + this.transcriptionInitializingPromise = null; + throw error; + } + })(); + } + + await this.transcriptionInitializingPromise; + } + + /** + * Lazy initialize the TTS model + */ + private async lazyInitTTS(): Promise { + if (this.ttsInitialized) return; + + if (!this.ttsInitializingPromise) { + this.ttsInitializingPromise = (async () => { + try { + // Initialize TTS model directly + // Use existing initialization code from the file + // Get the TTSManager instance (ensure environment is initialized for cacheDir) + await this.initializeEnvironment(); + this.ttsManager = TTSManager.getInstance(this.cacheDir); + // Note: The internal pipeline initialization within TTSManager happens + // when generateSpeech calls its own initialize method. + this.ttsInitialized = true; + logger.info('TTS model initialized successfully'); + } catch (error) { + logger.error('Failed to lazy initialize TTS components: ' + (error instanceof Error ? error.message : String(error))); + this.ttsInitializingPromise = null; // Allow retry + throw error; + } + })(); + } + + await this.ttsInitializingPromise; + } } // Create manager instance @@ -475,12 +915,22 @@ export const localAiPlugin: Plugin = { async init(_config: any, runtime: IAgentRuntime) { logger.info('🚀 Initializing Local AI plugin...'); - + try { // Initialize environment and validate configuration await localAIManager.initializeEnvironment(); const config = validateConfig(); - + + // Check for critical configuration + if (!config.LOCAL_SMALL_MODEL || !config.LOCAL_LARGE_MODEL || !config.LOCAL_EMBEDDING_MODEL) { + logger.warn('⚠️ Local AI plugin: Model configuration is incomplete'); + logger.warn('Please ensure the following environment variables are set:'); + logger.warn('- LOCAL_SMALL_MODEL: Path to small language model file'); + logger.warn('- LOCAL_LARGE_MODEL: Path to large language model file'); + logger.warn('- LOCAL_EMBEDDING_MODEL: Path to embedding model file'); + logger.warn('Example: LOCAL_SMALL_MODEL=llama-3.2-1b-instruct-q8_0.gguf'); + } + // Check if models directory is accessible const modelsDir = config.MODELS_DIR || path.join(os.homedir(), '.eliza', 'models'); if (!fs.existsSync(modelsDir)) { @@ -488,14 +938,14 @@ export const localAiPlugin: Plugin = { logger.warn('The directory will be created, but you need to download model files'); logger.warn('Visit https://huggingface.co/models to download compatible GGUF models'); } - + // Perform a basic initialization test logger.info('🔍 Testing Local AI initialization...'); - + try { // Check platform capabilities await localAIManager.checkPlatformCapabilities(); - + // Test if we can get the llama instance const llamaInstance = await getLlama(); if (llamaInstance) { @@ -503,29 +953,31 @@ export const localAiPlugin: Plugin = { } else { throw new Error('Failed to load llama.cpp library'); } - + // Check if at least one model file exists + const smallModelPath = path.join(modelsDir, config.LOCAL_SMALL_MODEL); + const largeModelPath = path.join(modelsDir, config.LOCAL_LARGE_MODEL); const embeddingModelPath = path.join(modelsDir, config.LOCAL_EMBEDDING_MODEL); - + const modelsExist = { + small: fs.existsSync(smallModelPath), + large: fs.existsSync(largeModelPath), embedding: fs.existsSync(embeddingModelPath) }; - - if (!modelsExist.embedding) { + + if (!modelsExist.small && !modelsExist.large && !modelsExist.embedding) { logger.warn('⚠️ No model files found in models directory'); logger.warn('Models will be downloaded on first use, which may take time'); logger.warn('To pre-download models, run the plugin and it will fetch them automatically'); } else { - logger.info('📦 Found model files:', { - embedding: modelsExist.embedding ? '✓' : '✗' - }); + logger.info('📦 Found model files: small=' + (modelsExist.small ? '✓' : '✗') + ', large=' + (modelsExist.large ? '✓' : '✗') + ', embedding=' + (modelsExist.embedding ? '✓' : '✗')); } - + logger.success('✅ Local AI plugin initialized successfully'); logger.info('💡 Models will be loaded on-demand when first used'); - + } catch (testError) { - logger.error('❌ Local AI initialization test failed:', testError); + logger.error('❌ Local AI initialization test failed: ' + (testError instanceof Error ? testError.message : String(testError))); logger.warn('The plugin may not function correctly'); logger.warn('Please check:'); logger.warn('1. Your system has sufficient memory (8GB+ recommended)'); @@ -533,13 +985,10 @@ export const localAiPlugin: Plugin = { logger.warn('3. Your CPU supports the required instruction sets'); // Don't throw here - allow the plugin to load even if the test fails } - + } catch (error) { - logger.error('❌ Failed to initialize Local AI plugin:', { - error: error instanceof Error ? error.message : String(error), - stack: error instanceof Error ? error.stack : undefined, - }); - + logger.error('❌ Failed to initialize Local AI plugin: ' + (error instanceof Error ? error.message : String(error))); + // Provide helpful guidance based on common errors if (error instanceof Error) { if (error.message.includes('Cannot find module')) { @@ -553,14 +1002,51 @@ export const localAiPlugin: Plugin = { logger.error('- Linux: Install build-essential package'); } } - + // Don't throw - allow the system to continue without this plugin logger.warn('⚠️ Local AI plugin will not be available'); } }, models: { - [ModelType.TEXT_EMBEDDING]: async (_runtime: IAgentRuntime, params: TextEmbeddingParams) => { - const text = params?.text; + [ModelType.TEXT_SMALL]: async ( + _runtime: IAgentRuntime, + { prompt, stopSequences = [] }: GenerateTextParams + ) => { + try { + // Ensure environment is initialized before generating text (now public) + await localAIManager.initializeEnvironment(); + return await localAIManager.generateText({ + prompt, + stopSequences, + modelType: ModelType.TEXT_SMALL, + }); + } catch (error) { + logger.error('Error in TEXT_SMALL handler: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + + [ModelType.TEXT_LARGE]: async ( + _runtime: IAgentRuntime, + { prompt, stopSequences = [] }: GenerateTextParams + ) => { + try { + // Ensure environment is initialized before generating text (now public) + await localAIManager.initializeEnvironment(); + return await localAIManager.generateText({ + prompt, + stopSequences, + modelType: ModelType.TEXT_LARGE, + }); + } catch (error) { + logger.error('Error in TEXT_LARGE handler: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + + [ModelType.TEXT_EMBEDDING]: async (_runtime: IAgentRuntime, params: string | TextEmbeddingParams | null) => { + // Handle different input types: string, TextEmbeddingParams, or null + const text = params === null ? null : (typeof params === 'string' ? params : params?.text); try { // Handle null/undefined/empty text if (!text) { @@ -571,40 +1057,307 @@ export const localAiPlugin: Plugin = { // Pass the raw text directly to the framework without any manipulation return await localAIManager.generateEmbedding(text); } catch (error) { - logger.error('Error in TEXT_EMBEDDING handler:', { - error: error instanceof Error ? error.message : String(error), - fullText: text, - textType: typeof text, - textStructure: text !== null ? JSON.stringify(text, null, 2) : 'null', - }); + logger.error('Error in TEXT_EMBEDDING handler: ' + (error instanceof Error ? error.message : String(error))); return new Array(384).fill(0); } }, + [ModelType.OBJECT_SMALL]: async (_runtime: IAgentRuntime, params: ObjectGenerationParams) => { + try { + // Ensure environment is initialized (now public) + await localAIManager.initializeEnvironment(); + logger.info('OBJECT_SMALL handler - Processing request with prompt length: ' + params.prompt.length); + + // Enhance the prompt to request JSON output + let jsonPrompt = params.prompt; + if (!jsonPrompt.includes('```json') && !jsonPrompt.includes('respond with valid JSON')) { + jsonPrompt += + '\nPlease respond with valid JSON only, without any explanations, markdown formatting, or additional text.'; + } + + // Directly generate text using the local small model + const textResponse = await localAIManager.generateText({ + prompt: jsonPrompt, + stopSequences: params.stopSequences, + modelType: ModelType.TEXT_SMALL, + }); + + // Extract and parse JSON from the text response + try { + // Function to extract JSON content from text + const extractJSON = (text: string): string => { + // Try to find content between JSON codeblocks or markdown blocks + const jsonBlockRegex = /```(?:json)?\s*([\s\S]*?)\s*```/; + const match = text.match(jsonBlockRegex); + + if (match && match[1]) { + return match[1].trim(); + } + + // If no code blocks, try to find JSON-like content + // This regex looks for content that starts with { and ends with } + const jsonContentRegex = /\s*(\{[\s\S]*\})\s*$/; + const contentMatch = text.match(jsonContentRegex); + + if (contentMatch && contentMatch[1]) { + return contentMatch[1].trim(); + } + + // If no JSON-like content found, return the original text + return text.trim(); + }; + + const extractedJsonText = extractJSON(textResponse); + logger.debug('Extracted JSON text:', extractedJsonText); + + let jsonObject; + try { + jsonObject = JSON.parse(extractedJsonText); + } catch (parseError) { + // Try fixing common JSON issues + logger.debug('Initial JSON parse failed, attempting to fix common issues'); + + // Replace any unescaped newlines in string values + const fixedJson = extractedJsonText + .replace(/:\s*"([^"]*)(?:\n)([^"]*)"/g, ': "$1\\n$2"') + // Remove any non-JSON text that might have gotten mixed into string values + .replace(/"([^"]*?)[^a-zA-Z0-9\s\.,;:\-_\(\)"'\[\]{}]([^"]*?)"/g, '"$1$2"') + // Fix missing quotes around property names + .replace(/(\s*)(\w+)(\s*):/g, '$1"$2"$3:') + // Fix trailing commas in arrays and objects + .replace(/,(\s*[\]}])/g, '$1'); + + try { + jsonObject = JSON.parse(fixedJson); + } catch (finalError) { + logger.error('Failed to parse JSON after fixing: ' + (finalError instanceof Error ? finalError.message : String(finalError))); + throw new Error('Invalid JSON returned from model'); + } + } + + // Validate against schema if provided + if (params.schema) { + try { + // Simplistic schema validation - check if all required properties exist + for (const key of Object.keys(params.schema)) { + if (!(key in jsonObject)) { + jsonObject[key] = null; // Add missing properties with null value + } + } + } catch (schemaError) { + logger.error('Schema validation failed: ' + (schemaError instanceof Error ? schemaError.message : String(schemaError))); + } + } + + return jsonObject; + } catch (parseError) { + logger.error('Failed to parse JSON: ' + (parseError instanceof Error ? parseError.message : String(parseError))); + logger.error('Raw response length: ' + textResponse.length); + throw new Error('Invalid JSON returned from model'); + } + } catch (error) { + logger.error('Error in OBJECT_SMALL handler: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + + [ModelType.OBJECT_LARGE]: async (_runtime: IAgentRuntime, params: ObjectGenerationParams) => { + try { + // Ensure environment is initialized (now public) + await localAIManager.initializeEnvironment(); + logger.info('OBJECT_LARGE handler - Processing request with prompt length: ' + params.prompt.length); + + // Enhance the prompt to request JSON output + let jsonPrompt = params.prompt; + if (!jsonPrompt.includes('```json') && !jsonPrompt.includes('respond with valid JSON')) { + jsonPrompt += + '\nPlease respond with valid JSON only, without any explanations, markdown formatting, or additional text.'; + } + + // Directly generate text using the local large model + const textResponse = await localAIManager.generateText({ + prompt: jsonPrompt, + stopSequences: params.stopSequences, + modelType: ModelType.TEXT_LARGE, + }); + + // Extract and parse JSON from the text response + try { + // Function to extract JSON content from text + const extractJSON = (text: string): string => { + // Try to find content between JSON codeblocks or markdown blocks + const jsonBlockRegex = /```(?:json)?\s*([\s\S]*?)\s*```/; + const match = text.match(jsonBlockRegex); + + if (match && match[1]) { + return match[1].trim(); + } + + // If no code blocks, try to find JSON-like content + // This regex looks for content that starts with { and ends with } + const jsonContentRegex = /\s*(\{[\s\S]*\})\s*$/; + const contentMatch = text.match(jsonContentRegex); + + if (contentMatch && contentMatch[1]) { + return contentMatch[1].trim(); + } + + // If no JSON-like content found, return the original text + return text.trim(); + }; + + // Clean up the extracted JSON to handle common formatting issues + const cleanupJSON = (jsonText: string): string => { + // Remove common logging/debugging patterns that might get mixed into the JSON + return ( + jsonText + // Remove any lines that look like log statements + .replace(/\[DEBUG\].*?(\n|$)/g, '\n') + .replace(/\[LOG\].*?(\n|$)/g, '\n') + .replace(/console\.log.*?(\n|$)/g, '\n') + ); + }; + + const extractedJsonText = extractJSON(textResponse); + const cleanedJsonText = cleanupJSON(extractedJsonText); + logger.debug('Extracted JSON text:', cleanedJsonText); + + let jsonObject; + try { + jsonObject = JSON.parse(cleanedJsonText); + } catch (parseError) { + // Try fixing common JSON issues + logger.debug('Initial JSON parse failed, attempting to fix common issues'); + + // Replace any unescaped newlines in string values + const fixedJson = cleanedJsonText + .replace(/:\s*"([^"]*)(?:\n)([^"]*)"/g, ': "$1\\n$2"') + // Remove any non-JSON text that might have gotten mixed into string values + .replace(/"([^"]*?)[^a-zA-Z0-9\s\.,;:\-_\(\)"'\[\]{}]([^"]*?)"/g, '"$1$2"') + // Fix missing quotes around property names + .replace(/(\s*)(\w+)(\s*):/g, '$1"$2"$3:') + // Fix trailing commas in arrays and objects + .replace(/,(\s*[\]}])/g, '$1'); + + try { + jsonObject = JSON.parse(fixedJson); + } catch (finalError) { + logger.error('Failed to parse JSON after fixing: ' + (finalError instanceof Error ? finalError.message : String(finalError))); + throw new Error('Invalid JSON returned from model'); + } + } + + // Validate against schema if provided + if (params.schema) { + try { + // Simplistic schema validation - check if all required properties exist + for (const key of Object.keys(params.schema)) { + if (!(key in jsonObject)) { + jsonObject[key] = null; // Add missing properties with null value + } + } + } catch (schemaError) { + logger.error('Schema validation failed: ' + (schemaError instanceof Error ? schemaError.message : String(schemaError))); + } + } + + return jsonObject; + } catch (parseError) { + logger.error('Failed to parse JSON: ' + (parseError instanceof Error ? parseError.message : String(parseError))); + logger.error('Raw response length: ' + textResponse.length); + throw new Error('Invalid JSON returned from model'); + } + } catch (error) { + logger.error('Error in OBJECT_LARGE handler: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + [ModelType.TEXT_TOKENIZER_ENCODE]: async ( _runtime: IAgentRuntime, - { text }: { text: string } + params: TokenizeTextParams ) => { try { const manager = localAIManager.getTokenizerManager(); const config = localAIManager.getActiveModelConfig(); - return await manager.encode(text, config); + // TokenizeTextParams has 'prompt' property per core API + const textToEncode = params.prompt || ''; + return await manager.encode(textToEncode, config); } catch (error) { - logger.error('Error in TEXT_TOKENIZER_ENCODE handler:', error); + logger.error('Error in TEXT_TOKENIZER_ENCODE handler: ' + (error instanceof Error ? error.message : String(error))); throw error; } }, [ModelType.TEXT_TOKENIZER_DECODE]: async ( _runtime: IAgentRuntime, - { tokens }: { tokens: number[] } + params: DetokenizeTextParams ) => { try { const manager = localAIManager.getTokenizerManager(); const config = localAIManager.getActiveModelConfig(); - return await manager.decode(tokens, config); + return await manager.decode(params.tokens, config); + } catch (error) { + logger.error('Error in TEXT_TOKENIZER_DECODE handler: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + + [ModelType.IMAGE_DESCRIPTION]: async (_runtime: IAgentRuntime, params: string | ImageDescriptionParams) => { + try { + const imageUrl = typeof params === 'string' ? params : params.imageUrl; + logger.info('Processing image from URL: ' + imageUrl); + + // Fetch the image from URL + const response = await fetch(imageUrl); + if (!response.ok) { + throw new Error(`Failed to fetch image: ${response.statusText}`); + } + + const buffer = Buffer.from(await response.arrayBuffer()); + const mimeType = response.headers.get('content-type') || 'image/jpeg'; + + return await localAIManager.describeImage(buffer, mimeType); + } catch (error) { + logger.error('Error in IMAGE_DESCRIPTION handler: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + + [ModelType.TRANSCRIPTION]: async (_runtime: IAgentRuntime, params: string | Buffer | TranscriptionParams) => { + try { + let audioBuffer: Buffer; + if (Buffer.isBuffer(params)) { + audioBuffer = params; + } else if (typeof params === 'string') { + // If string, fetch the audio + const response = await fetch(params); + audioBuffer = Buffer.from(await response.arrayBuffer()); + } else { + // TranscriptionParams - fetch from audioUrl + const response = await fetch(params.audioUrl); + audioBuffer = Buffer.from(await response.arrayBuffer()); + } + logger.info('Processing audio transcription, buffer size: ' + audioBuffer.length); + return await localAIManager.transcribeAudio(audioBuffer); } catch (error) { - logger.error('Error in TEXT_TOKENIZER_DECODE handler:', error); + logger.error('Error in TRANSCRIPTION handler: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + + [ModelType.TEXT_TO_SPEECH]: async (_runtime: IAgentRuntime, params: string | TextToSpeechParams): Promise => { + try { + const text = typeof params === 'string' ? params : params.text; + const readable = await localAIManager.generateSpeech(text); + // Convert Readable stream to Buffer for new API + const chunks: Buffer[] = []; + for await (const chunk of readable) { + chunks.push(Buffer.from(chunk)); + } + return Buffer.concat(chunks); + } catch (error) { + logger.error('Error in TEXT_TO_SPEECH handler: ' + (error instanceof Error ? error.message : String(error))); throw error; } }, @@ -613,6 +1366,65 @@ export const localAiPlugin: Plugin = { { name: 'local_ai_plugin_tests', tests: [ + { + name: 'local_ai_test_initialization', + fn: async (runtime) => { + try { + logger.info('Starting initialization test'); + + // Test TEXT_SMALL model initialization + const result = await runtime.useModel(ModelType.TEXT_SMALL, { + prompt: + "Debug Mode: Test initialization. Respond with 'Initialization successful' if you can read this.", + stopSequences: [], + }); + + logger.info('Model response:', result); + + if (!result || typeof result !== 'string') { + throw new Error('Invalid response from model'); + } + + if (!result.includes('successful')) { + throw new Error('Model response does not indicate success'); + } + + logger.success('Initialization test completed successfully'); + } catch (error) { + logger.error('Initialization test failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + }, + { + name: 'local_ai_test_text_large', + fn: async (runtime) => { + try { + logger.info('Starting TEXT_LARGE model test'); + + const result = await runtime.useModel(ModelType.TEXT_LARGE, { + prompt: + 'Debug Mode: Generate a one-sentence response about artificial intelligence.', + stopSequences: [], + }); + + logger.info('Large model response: ' + (typeof result === 'string' ? result.substring(0, 100) : 'non-string')); + + if (!result || typeof result !== 'string') { + throw new Error('Invalid response from large model'); + } + + if (result.length < 10) { + throw new Error('Response too short, possible model failure'); + } + + logger.success('TEXT_LARGE test completed successfully'); + } catch (error) { + logger.error('TEXT_LARGE test failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + }, { name: 'local_ai_test_text_embedding', fn: async (runtime) => { @@ -624,7 +1436,7 @@ export const localAiPlugin: Plugin = { text: 'This is a test of the text embedding model.', }); - logger.info('Embedding generated with dimensions:', embedding.length); + logger.info('Embedding generated with dimensions: ' + (Array.isArray(embedding) ? embedding.length : 'not-array')); if (!Array.isArray(embedding)) { throw new Error('Embedding is not an array'); @@ -646,10 +1458,7 @@ export const localAiPlugin: Plugin = { logger.success('TEXT_EMBEDDING test completed successfully'); } catch (error) { - logger.error('TEXT_EMBEDDING test failed:', { - error: error instanceof Error ? error.message : String(error), - stack: error instanceof Error ? error.stack : undefined, - }); + logger.error('TEXT_EMBEDDING test failed: ' + (error instanceof Error ? error.message : String(error))); throw error; } }, @@ -661,8 +1470,11 @@ export const localAiPlugin: Plugin = { logger.info('Starting TEXT_TOKENIZER_ENCODE test'); const text = 'Hello tokenizer test!'; - const tokens = await runtime.useModel(ModelType.TEXT_TOKENIZER_ENCODE, { text }); - logger.info('Encoded tokens:', { count: tokens.length }); + const tokens = await runtime.useModel(ModelType.TEXT_TOKENIZER_ENCODE, { + prompt: text, + modelType: ModelType.TEXT_SMALL, + }); + logger.info('Encoded tokens count: ' + (Array.isArray(tokens) ? tokens.length : 'not-array')); if (!Array.isArray(tokens)) { throw new Error('Tokens output is not an array'); @@ -678,10 +1490,7 @@ export const localAiPlugin: Plugin = { logger.success('TEXT_TOKENIZER_ENCODE test completed successfully'); } catch (error) { - logger.error('TEXT_TOKENIZER_ENCODE test failed:', { - error: error instanceof Error ? error.message : String(error), - stack: error instanceof Error ? error.stack : undefined, - }); + logger.error('TEXT_TOKENIZER_ENCODE test failed: ' + (error instanceof Error ? error.message : String(error))); throw error; } }, @@ -695,17 +1504,16 @@ export const localAiPlugin: Plugin = { // First encode some text const originalText = 'Hello tokenizer test!'; const tokens = await runtime.useModel(ModelType.TEXT_TOKENIZER_ENCODE, { - text: originalText, + prompt: originalText, + modelType: ModelType.TEXT_SMALL, }); // Then decode it back const decodedText = await runtime.useModel(ModelType.TEXT_TOKENIZER_DECODE, { - tokens, - }); - logger.info('Round trip tokenization:', { - original: originalText, - decoded: decodedText, + tokens: tokens as number[], + modelType: ModelType.TEXT_SMALL, }); + logger.info('Round trip tokenization - original: ' + originalText + ', decoded: ' + decodedText); if (typeof decodedText !== 'string') { throw new Error('Decoded output is not a string'); @@ -713,10 +1521,137 @@ export const localAiPlugin: Plugin = { logger.success('TEXT_TOKENIZER_DECODE test completed successfully'); } catch (error) { - logger.error('TEXT_TOKENIZER_DECODE test failed:', { - error: error instanceof Error ? error.message : String(error), - stack: error instanceof Error ? error.stack : undefined, + logger.error('TEXT_TOKENIZER_DECODE test failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + }, + { + name: 'local_ai_test_image_description', + fn: async (runtime) => { + try { + logger.info('Starting IMAGE_DESCRIPTION test'); + + // Use a more stable test image URL + const imageUrl = + 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/320px-Cat03.jpg'; + const result = await runtime.useModel(ModelType.IMAGE_DESCRIPTION, imageUrl); + + logger.info('Image description result: ' + JSON.stringify(result)); + + if (!result || typeof result !== 'object') { + throw new Error('Invalid response format'); + } + + if (!result.title || !result.description) { + throw new Error('Missing title or description in response'); + } + + if (typeof result.title !== 'string' || typeof result.description !== 'string') { + throw new Error('Title or description is not a string'); + } + + logger.success('IMAGE_DESCRIPTION test completed successfully'); + } catch (error) { + logger.error('IMAGE_DESCRIPTION test failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + }, + { + name: 'local_ai_test_transcription', + fn: async (runtime) => { + try { + logger.info('Starting TRANSCRIPTION test'); + + // Create a proper WAV file header and minimal audio data + // WAV file format: RIFF header + fmt chunk + data chunk + const channels = 1; + const sampleRate = 16000; + const bitsPerSample = 16; + const duration = 0.5; // 500ms for better transcription + const numSamples = Math.floor(sampleRate * duration); + const dataSize = numSamples * channels * (bitsPerSample / 8); + + // Create the WAV header + const buffer = Buffer.alloc(44 + dataSize); + + // RIFF header + buffer.write('RIFF', 0); + buffer.writeUInt32LE(36 + dataSize, 4); // File size - 8 + buffer.write('WAVE', 8); + + // fmt chunk + buffer.write('fmt ', 12); + buffer.writeUInt32LE(16, 16); // fmt chunk size + buffer.writeUInt16LE(1, 20); // Audio format (1 = PCM) + buffer.writeUInt16LE(channels, 22); // Number of channels + buffer.writeUInt32LE(sampleRate, 24); // Sample rate + buffer.writeUInt32LE(sampleRate * channels * (bitsPerSample / 8), 28); // Byte rate + buffer.writeUInt16LE(channels * (bitsPerSample / 8), 32); // Block align + buffer.writeUInt16LE(bitsPerSample, 34); // Bits per sample + + // data chunk + buffer.write('data', 36); + buffer.writeUInt32LE(dataSize, 40); // Data size + + // Generate a simple sine wave tone (440Hz) instead of silence + const frequency = 440; // A4 note + for (let i = 0; i < numSamples; i++) { + const sample = Math.sin((2 * Math.PI * frequency * i) / sampleRate) * 0.1 * 32767; + buffer.writeInt16LE(Math.floor(sample), 44 + i * 2); + } + + const transcription = await runtime.useModel(ModelType.TRANSCRIPTION, buffer); + logger.info('Transcription result:', transcription); + + if (typeof transcription !== 'string') { + throw new Error('Transcription result is not a string'); + } + + // Accept empty string as valid result (for non-speech audio) + logger.info('Transcription completed (may be empty for non-speech audio)'); + + logger.success('TRANSCRIPTION test completed successfully'); + } catch (error) { + logger.error('TRANSCRIPTION test failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + }, + }, + { + name: 'local_ai_test_text_to_speech', + fn: async (runtime) => { + try { + logger.info('Starting TEXT_TO_SPEECH test'); + + const testText = 'This is a test of the text to speech system.'; + const audioStream = await runtime.useModel(ModelType.TEXT_TO_SPEECH, testText); + + if (!(audioStream instanceof Readable)) { + throw new Error('TTS output is not a readable stream'); + } + + // Test stream readability + let dataReceived = false; + audioStream.on('data', () => { + dataReceived = true; + }); + + await new Promise((resolve, reject) => { + audioStream.on('end', () => { + if (!dataReceived) { + reject(new Error('No audio data received from stream')); + } else { + resolve(true); + } + }); + audioStream.on('error', reject); }); + + logger.success('TEXT_TO_SPEECH test completed successfully'); + } catch (error) { + logger.error('TEXT_TO_SPEECH test failed: ' + (error instanceof Error ? error.message : String(error))); throw error; } }, diff --git a/src/types/whisper-node.d.ts b/src/types/whisper-node.d.ts new file mode 100644 index 0000000..b66f143 --- /dev/null +++ b/src/types/whisper-node.d.ts @@ -0,0 +1,31 @@ +declare module 'whisper-node' { + interface WhisperOptions { + language?: string; + gen_file_txt?: boolean; + gen_file_subtitle?: boolean; + gen_file_vtt?: boolean; + word_timestamps?: boolean; + timestamp_size?: number; + } + + interface WhisperConfig { + modelName?: string; + modelPath?: string; + whisperOptions?: WhisperOptions; + } + + interface TranscriptSegment { + start: string; + end: string; + speech: string; + } + + function whisper(filePath: string, options?: WhisperConfig): Promise; + + const exports: { + whisper: typeof whisper; + default: typeof whisper; + }; + + export = exports; +} diff --git a/src/utils/downloadManager.ts b/src/utils/downloadManager.ts index 51314f3..c013875 100644 --- a/src/utils/downloadManager.ts +++ b/src/utils/downloadManager.ts @@ -382,11 +382,7 @@ export class DownloadManager { for (const attempt of attempts) { try { - logger.info('Attempting model download:', { - description: attempt.description, - url: attempt.url, - timestamp: new Date().toISOString(), - }); + logger.info('Attempting model download: ' + attempt.description + ' from ' + attempt.url); // The downloadFile method now handles the progress bar display await this.downloadFile(attempt.url, modelPath); @@ -398,11 +394,7 @@ export class DownloadManager { break; } catch (error) { lastError = error; - logger.warn('Model download attempt failed:', { - description: attempt.description, - error: error instanceof Error ? error.message : String(error), - timestamp: new Date().toISOString(), - }); + logger.warn('Model download attempt failed: ' + attempt.description + ' - ' + (error instanceof Error ? error.message : String(error))); } } @@ -419,11 +411,7 @@ export class DownloadManager { // Return false to indicate the model already existed return false; } catch (error) { - logger.error('Model download failed:', { - error: error instanceof Error ? error.message : String(error), - modelPath: modelPath, - model: modelSpec.name, - }); + logger.error('Model download failed: ' + (error instanceof Error ? error.message : String(error))); throw error; } } diff --git a/src/utils/platform.ts b/src/utils/platform.ts index fe9424c..7edd158 100644 --- a/src/utils/platform.ts +++ b/src/utils/platform.ts @@ -75,7 +75,7 @@ export class PlatformManager { /** * Private constructor method. */ - private constructor() {} + private constructor() { } /** * Get the singleton instance of the PlatformManager class @@ -103,7 +103,7 @@ export class PlatformManager { // recommendedModel: this.capabilities.recommendedModelSize, // }); } catch (error) { - logger.error('Platform detection failed', { error }); + logger.error('Platform detection failed: ' + (error instanceof Error ? error.message : String(error))); throw error; } } @@ -171,7 +171,7 @@ export class PlatformManager { return null; } } catch (error) { - logger.error('GPU detection failed', { error }); + logger.error('GPU detection failed: ' + (error instanceof Error ? error.message : String(error))); return null; } } @@ -201,7 +201,7 @@ export class PlatformManager { isAppleSilicon: false, }; } catch (error) { - logger.error('Mac GPU detection failed', { error }); + logger.error('Mac GPU detection failed: ' + (error instanceof Error ? error.message : String(error))); return { name: 'Unknown Mac GPU', type: 'metal', @@ -242,7 +242,7 @@ export class PlatformManager { type: 'directml', }; } catch (error) { - logger.error('Windows GPU detection failed', { error }); + logger.error('Windows GPU detection failed: ' + (error instanceof Error ? error.message : String(error))); return null; } } @@ -282,7 +282,7 @@ export class PlatformManager { type: 'none', }; } catch (error) { - logger.error('Linux GPU detection failed', { error }); + logger.error('Linux GPU detection failed: ' + (error instanceof Error ? error.message : String(error))); return null; } } diff --git a/src/utils/transcribeManager.ts b/src/utils/transcribeManager.ts new file mode 100644 index 0000000..8fd5c35 --- /dev/null +++ b/src/utils/transcribeManager.ts @@ -0,0 +1,436 @@ +import { exec } from 'node:child_process'; +import fs from 'node:fs'; +import path from 'node:path'; +import { promisify } from 'node:util'; +import { logger } from '@elizaos/core'; + +const execAsync = promisify(exec); + +// Lazy load whisper-node to avoid ESM/CommonJS issues +let whisperModule: any = null; +async function getWhisper() { + if (!whisperModule) { + // Dynamic import for CommonJS module + const module = await import('whisper-node'); + // The module exports an object with a whisper property + whisperModule = (module as any).whisper; + } + return whisperModule; +} + +/** + * Interface representing the result of a transcription process. + * @interface + * @property {string} text - The transcribed text. + */ +interface TranscriptionResult { + text: string; +} + +/** + * Class representing a TranscribeManager. + * + * @property {TranscribeManager | null} instance - The singleton instance of the TranscribeManager class. + * @property {string} cacheDir - The directory path for caching transcribed files. + * @property {boolean} ffmpegAvailable - Flag indicating if ffmpeg is available for audio processing. + * @property {string | null} ffmpegVersion - The version of ffmpeg if available. + * @property {string | null} ffmpegPath - The path to the ffmpeg executable. + * @property {boolean} ffmpegInitialized - Flag indicating if ffmpeg has been initialized. + * + * @constructor + * Creates an instance of TranscribeManager with the specified cache directory. + */ +export class TranscribeManager { + private static instance: TranscribeManager | null = null; + private cacheDir: string; + private ffmpegAvailable = false; + private ffmpegVersion: string | null = null; + private ffmpegPath: string | null = null; + private ffmpegInitialized = false; + + /** + * Constructor for TranscribeManager class. + * + * @param {string} cacheDir - The directory path for storing cached files. + */ + private constructor(cacheDir: string) { + this.cacheDir = path.join(cacheDir, 'whisper'); + logger.debug('Initializing TranscribeManager at: ' + this.cacheDir); + this.ensureCacheDirectory(); + } + + /** + * Ensures that FFmpeg is initialized and available for use. + * @returns {Promise} A promise that resolves to a boolean value indicating if FFmpeg is available. + */ + public async ensureFFmpeg(): Promise { + if (!this.ffmpegInitialized) { + try { + await this.initializeFFmpeg(); + this.ffmpegInitialized = true; + } catch (error) { + logger.error('FFmpeg initialization failed: ' + (error instanceof Error ? error.message : String(error))); + return false; + } + } + return this.ffmpegAvailable; + } + + /** + * Checks if FFmpeg is available. + * @returns {boolean} True if FFmpeg is available, false otherwise. + */ + public isFFmpegAvailable(): boolean { + return this.ffmpegAvailable; + } + + /** + * Asynchronously retrieves the FFmpeg version if it hasn't been fetched yet. + * If the FFmpeg version has already been fetched, it will return the stored version. + * @returns A Promise that resolves with the FFmpeg version as a string, or null if the version is not available. + */ + public async getFFmpegVersion(): Promise { + if (!this.ffmpegVersion) { + await this.fetchFFmpegVersion(); + } + return this.ffmpegVersion; + } + + /** + * Fetches the FFmpeg version by executing the command "ffmpeg -version". + * Updates the class property ffmpegVersion with the retrieved version. + * Logs the FFmpeg version information or error message. + * @returns {Promise} A Promise that resolves once the FFmpeg version is fetched and logged. + */ + private async fetchFFmpegVersion(): Promise { + try { + const { stdout } = await execAsync('ffmpeg -version'); + this.ffmpegVersion = stdout.split('\n')[0]; + logger.info('FFmpeg version: ' + this.ffmpegVersion); + } catch (error) { + this.ffmpegVersion = null; + logger.error('Failed to get FFmpeg version: ' + (error instanceof Error ? error.message : String(error))); + } + } + + /** + * Initializes FFmpeg by performing the following steps: + * 1. Checks for FFmpeg availability in PATH + * 2. Retrieves FFmpeg version information + * 3. Verifies FFmpeg capabilities + * + * If FFmpeg is available, logs a success message with version, path, and timestamp. + * If FFmpeg is not available, logs installation instructions. + * + * @returns A Promise that resolves once FFmpeg has been successfully initialized + */ + private async initializeFFmpeg(): Promise { + try { + // First check if ffmpeg exists in PATH + await this.checkFFmpegAvailability(); + + if (this.ffmpegAvailable) { + // Get FFmpeg version info + await this.fetchFFmpegVersion(); + + // Verify FFmpeg capabilities + await this.verifyFFmpegCapabilities(); + + logger.success('FFmpeg initialized successfully: version=' + this.ffmpegVersion + ' path=' + this.ffmpegPath); + } else { + this.logFFmpegInstallInstructions(); + } + } catch (error) { + this.ffmpegAvailable = false; + logger.error('FFmpeg initialization failed: ' + (error instanceof Error ? error.message : String(error))); + this.logFFmpegInstallInstructions(); + } + } + + /** + * Asynchronously checks for the availability of FFmpeg in the system by executing a command to find the FFmpeg location. + * Updates the class properties `ffmpegPath` and `ffmpegAvailable` accordingly. + * Logs relevant information such as FFmpeg location and potential errors using the logger. + * + * @returns A Promise that resolves with no value upon completion. + */ + private async checkFFmpegAvailability(): Promise { + try { + const { stdout, stderr } = await execAsync('which ffmpeg || where ffmpeg'); + this.ffmpegPath = stdout.trim(); + this.ffmpegAvailable = true; + logger.info('FFmpeg found at: ' + this.ffmpegPath); + } catch (error) { + this.ffmpegAvailable = false; + this.ffmpegPath = null; + logger.error('FFmpeg not found in PATH: ' + (error instanceof Error ? error.message : String(error))); + } + } + + /** + * Verifies the FFmpeg capabilities by checking if FFmpeg supports the required codecs and formats. + * + * @returns {Promise} A Promise that resolves if FFmpeg has the required codecs, otherwise rejects with an error message. + */ + private async verifyFFmpegCapabilities(): Promise { + try { + // Check if FFmpeg supports required codecs and formats + const { stdout } = await execAsync('ffmpeg -codecs'); + const hasRequiredCodecs = stdout.includes('pcm_s16le') && stdout.includes('wav'); + + if (!hasRequiredCodecs) { + throw new Error('FFmpeg installation missing required codecs (pcm_s16le, wav)'); + } + + // logger.info("FFmpeg capabilities verified", { + // hasRequiredCodecs, + // timestamp: new Date().toISOString() + // }); + } catch (error) { + logger.error('FFmpeg capabilities verification failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } + + /** + * Logs instructions on how to install FFmpeg if it is not properly installed. + */ + private logFFmpegInstallInstructions(): void { + logger.warn('FFmpeg is required but not properly installed. Please install FFmpeg: mac=brew install ffmpeg, ubuntu=sudo apt-get install ffmpeg, windows=choco install ffmpeg, manual=https://ffmpeg.org/download.html'); + } + + /** + * Gets the singleton instance of TranscribeManager, creates a new instance if it doesn't exist. + * + * @param {string} cacheDir - The directory path for caching transcriptions. + * @returns {TranscribeManager} The singleton instance of TranscribeManager. + */ + public static getInstance(cacheDir: string): TranscribeManager { + if (!TranscribeManager.instance) { + TranscribeManager.instance = new TranscribeManager(cacheDir); + } + return TranscribeManager.instance; + } + + /** + * Ensures that the cache directory exists. If it doesn't exist, + * creates the directory using fs.mkdirSync with recursive set to true. + * @returns {void} + */ + private ensureCacheDirectory(): void { + if (!fs.existsSync(this.cacheDir)) { + fs.mkdirSync(this.cacheDir, { recursive: true }); + // logger.info("Created whisper cache directory:", this.cacheDir); + } + } + + /** + * Converts an audio file to WAV format using FFmpeg. + * + * @param {string} inputPath - The input path of the audio file to convert. + * @param {string} outputPath - The output path where the converted WAV file will be saved. + * @returns {Promise} A Promise that resolves when the conversion is completed. + * @throws {Error} If FFmpeg is not installed or not properly configured, or if the audio conversion fails. + */ + private async convertToWav(inputPath: string, outputPath: string): Promise { + if (!this.ffmpegAvailable) { + throw new Error( + 'FFmpeg is not installed or not properly configured. Please install FFmpeg to use audio transcription.' + ); + } + + try { + // Add -loglevel error to suppress FFmpeg output unless there's an error + const { stderr } = await execAsync( + `ffmpeg -y -loglevel error -i "${inputPath}" -acodec pcm_s16le -ar 16000 -ac 1 "${outputPath}"` + ); + + if (stderr) { + logger.warn('FFmpeg conversion error: ' + stderr); + } + + if (!fs.existsSync(outputPath)) { + throw new Error('WAV file was not created successfully'); + } + } catch (error) { + logger.error('Audio conversion failed: ' + (error instanceof Error ? error.message : String(error))); + throw new Error( + `Failed to convert audio to WAV format: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + + /** + * Asynchronously preprocesses the audio by converting the provided audio buffer into a WAV file. + * If FFmpeg is not installed, an error is thrown. + * + * @param {Buffer} audioBuffer The audio buffer to preprocess + * @returns {Promise} The path to the preprocessed WAV file + * @throws {Error} If FFmpeg is not installed or if audio preprocessing fails + */ + private async preprocessAudio(audioBuffer: Buffer): Promise { + if (!this.ffmpegAvailable) { + throw new Error('FFmpeg is not installed. Please install FFmpeg to use audio transcription.'); + } + + try { + // Check if the buffer is already a WAV file + const isWav = + audioBuffer.length > 4 && + audioBuffer.toString('ascii', 0, 4) === 'RIFF' && + audioBuffer.length > 12 && + audioBuffer.toString('ascii', 8, 12) === 'WAVE'; + + // Use appropriate extension based on format detection + const extension = isWav ? '.wav' : ''; + const tempInputFile = path.join(this.cacheDir, `temp_input_${Date.now()}${extension}`); + const tempWavFile = path.join(this.cacheDir, `temp_${Date.now()}.wav`); + + // logger.info("Creating temporary files", { + // inputFile: tempInputFile, + // wavFile: tempWavFile, + // bufferSize: audioBuffer.length, + // timestamp: new Date().toISOString() + // }); + + // Write buffer to temporary file + fs.writeFileSync(tempInputFile, audioBuffer); + // logger.info("Temporary input file created", { + // path: tempInputFile, + // size: audioBuffer.length, + // timestamp: new Date().toISOString() + // }); + + // If already WAV with correct format, skip conversion + if (isWav) { + // Check if it's already in the correct format (16kHz, mono, 16-bit) + try { + const { stdout } = await execAsync( + `ffprobe -v error -show_entries stream=sample_rate,channels,bits_per_raw_sample -of json "${tempInputFile}"` + ); + const probeResult = JSON.parse(stdout); + const stream = probeResult.streams?.[0]; + + if ( + stream?.sample_rate === '16000' && + stream?.channels === 1 && + (stream?.bits_per_raw_sample === 16 || stream?.bits_per_raw_sample === undefined) + ) { + // Already in correct format, just rename + fs.renameSync(tempInputFile, tempWavFile); + return tempWavFile; + } + } catch (probeError) { + // If probe fails, continue with conversion + logger.debug('FFprobe failed, continuing with conversion: ' + (probeError instanceof Error ? probeError.message : String(probeError))); + } + } + + // Convert to WAV format + await this.convertToWav(tempInputFile, tempWavFile); + + // Clean up the input file + if (fs.existsSync(tempInputFile)) { + fs.unlinkSync(tempInputFile); + // logger.info("Temporary input file cleaned up", { + // path: tempInputFile, + // timestamp: new Date().toISOString() + // }); + } + + return tempWavFile; + } catch (error) { + logger.error('Audio preprocessing failed: ' + (error instanceof Error ? error.message : String(error))); + throw new Error( + `Failed to preprocess audio: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + + /** + * Transcribes the audio buffer to text using whisper. + * + * @param {Buffer} audioBuffer The audio buffer to transcribe. + * @returns {Promise} A promise that resolves with the transcription result. + * @throws {Error} If FFmpeg is not installed or properly configured. + */ + + public async transcribe(audioBuffer: Buffer): Promise { + await this.ensureFFmpeg(); + + if (!this.ffmpegAvailable) { + throw new Error( + 'FFmpeg is not installed or not properly configured. Please install FFmpeg to use audio transcription.' + ); + } + + try { + // Preprocess audio to WAV format + const wavFile = await this.preprocessAudio(audioBuffer); + + logger.info('Starting transcription with whisper...'); + + let segments; + try { + // Get the whisper function + const whisper = await getWhisper(); + + // Transcribe using whisper-node + segments = await whisper(wavFile, { + modelName: 'tiny', + modelPath: path.join(this.cacheDir, 'models'), // Specify where to store models + whisperOptions: { + language: 'en', + word_timestamps: false, // We don't need word-level timestamps + }, + }); + } catch (whisperError) { + // Check if it's a model download issue + const errorMessage = + whisperError instanceof Error ? whisperError.message : String(whisperError); + if (errorMessage.includes('not found') || errorMessage.includes('download')) { + logger.error('Whisper model not found. Please run: npx whisper-node download'); + throw new Error( + 'Whisper model not found. Please install it with: npx whisper-node download' + ); + } + + // For other errors, log and rethrow + logger.error('Whisper transcription error: ' + (whisperError instanceof Error ? whisperError.message : String(whisperError))); + throw whisperError; + } + + // Clean up temporary WAV file + if (fs.existsSync(wavFile)) { + fs.unlinkSync(wavFile); + logger.info('Temporary WAV file cleaned up'); + } + + // Check if segments is valid + if (!segments || !Array.isArray(segments)) { + logger.warn('Whisper returned no segments (likely silence or very short audio)'); + // Return empty transcription for silent/empty audio + return { text: '' }; + } + + // Handle empty segments array + if (segments.length === 0) { + logger.warn('No speech detected in audio'); + return { text: '' }; + } + + // Combine all segments into a single text + const cleanText = segments + .map((segment: any) => segment.speech?.trim() || '') + .filter((text: string) => text) // Remove empty segments + .join(' '); + + logger.success('Transcription complete: textLength=' + cleanText.length + ' segmentCount=' + segments.length); + + return { text: cleanText }; + } catch (error) { + logger.error('Transcription failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } +} diff --git a/src/utils/ttsManager.ts b/src/utils/ttsManager.ts new file mode 100644 index 0000000..4aa6faf --- /dev/null +++ b/src/utils/ttsManager.ts @@ -0,0 +1,254 @@ +import { logger } from '@elizaos/core'; +import { pipeline, type TextToAudioPipeline } from '@huggingface/transformers'; +import fs from 'node:fs'; +import path from 'node:path'; +import { fetch } from 'undici'; +import { MODEL_SPECS } from '../types'; +import { PassThrough, Readable } from 'node:stream'; + +// Audio Utils + +/** + * Generates a WAV file header based on the provided audio parameters. + * @param {number} audioLength - The length of the audio data in bytes. + * @param {number} sampleRate - The sample rate of the audio. + * @param {number} [channelCount=1] - The number of channels (default is 1). + * @param {number} [bitsPerSample=16] - The number of bits per sample (default is 16). + * @returns {Buffer} The WAV file header as a Buffer object. + */ +function getWavHeader( + audioLength: number, + sampleRate: number, + channelCount = 1, + bitsPerSample = 16 +): Buffer { + const wavHeader = Buffer.alloc(44); + wavHeader.write('RIFF', 0); + wavHeader.writeUInt32LE(36 + audioLength, 4); // Length of entire file in bytes minus 8 + wavHeader.write('WAVE', 8); + wavHeader.write('fmt ', 12); + wavHeader.writeUInt32LE(16, 16); // Length of format data + wavHeader.writeUInt16LE(1, 20); // Type of format (1 is PCM) + wavHeader.writeUInt16LE(channelCount, 22); // Number of channels + wavHeader.writeUInt32LE(sampleRate, 24); // Sample rate + wavHeader.writeUInt32LE((sampleRate * bitsPerSample * channelCount) / 8, 28); // Byte rate + wavHeader.writeUInt16LE((bitsPerSample * channelCount) / 8, 32); // Block align ((BitsPerSample * Channels) / 8) + wavHeader.writeUInt16LE(bitsPerSample, 34); // Bits per sample + wavHeader.write('data', 36); // Data chunk header + wavHeader.writeUInt32LE(audioLength, 40); // Data chunk size + return wavHeader; +} + +/** + * Prepends a WAV header to a readable stream of audio data. + * + * @param {Readable} readable - The readable stream containing the audio data. + * @param {number} audioLength - The length of the audio data in bytes. + * @param {number} sampleRate - The sample rate of the audio data. + * @param {number} [channelCount=1] - The number of channels in the audio data (default is 1). + * @param {number} [bitsPerSample=16] - The number of bits per sample in the audio data (default is 16). + * @returns {PassThrough} A new pass-through stream with the WAV header prepended to the audio data. + */ +function prependWavHeader( + readable: Readable, + audioLength: number, + sampleRate: number, + channelCount = 1, + bitsPerSample = 16 +): PassThrough { + const wavHeader = getWavHeader(audioLength, sampleRate, channelCount, bitsPerSample); + let pushedHeader = false; + const passThrough = new PassThrough(); + readable.on('data', (data: Buffer) => { + if (!pushedHeader) { + passThrough.push(wavHeader); + pushedHeader = true; + } + passThrough.push(data); + }); + readable.on('end', () => { + passThrough.end(); + }); + return passThrough; +} + +/** + * Class representing a Text-to-Speech Manager using Transformers.js + */ +export class TTSManager { + private static instance: TTSManager | null = null; + private cacheDir: string; + private synthesizer: TextToAudioPipeline | null = null; + private defaultSpeakerEmbedding: Float32Array | null = null; + private initialized = false; + private initializingPromise: Promise | null = null; + + private constructor(cacheDir: string) { + this.cacheDir = path.join(cacheDir, 'tts'); + this.ensureCacheDirectory(); + logger.debug('TTSManager using Transformers.js initialized'); + } + + public static getInstance(cacheDir: string): TTSManager { + if (!TTSManager.instance) { + TTSManager.instance = new TTSManager(cacheDir); + } + return TTSManager.instance; + } + + private ensureCacheDirectory(): void { + if (!fs.existsSync(this.cacheDir)) { + fs.mkdirSync(this.cacheDir, { recursive: true }); + logger.debug('Created TTS cache directory:', this.cacheDir); + } + } + + private async initialize(): Promise { + // Guard against concurrent calls: if an initialization is already in progress, return its promise. + if (this.initializingPromise) { + logger.debug('TTS initialization already in progress, awaiting existing promise.'); + return this.initializingPromise; + } + + // If already initialized, no need to do anything further. + if (this.initialized) { + logger.debug('TTS already initialized.'); + return; + } + + // Start the initialization process. + // The promise is stored in this.initializingPromise and cleared in the finally block. + this.initializingPromise = (async () => { + try { + logger.info('Initializing TTS with Transformers.js backend...'); + + const ttsModelSpec = MODEL_SPECS.tts.default; + if (!ttsModelSpec) { + throw new Error('Default TTS model specification not found in MODEL_SPECS.'); + } + const modelName = ttsModelSpec.modelId; + const speakerEmbeddingUrl = ttsModelSpec.defaultSpeakerEmbeddingUrl; + + // 1. Load the TTS Pipeline + logger.info(`Loading TTS pipeline for model: ${modelName}`); + this.synthesizer = await pipeline('text-to-audio', modelName) as TextToAudioPipeline; + logger.success(`TTS pipeline loaded successfully for model: ${modelName}`); + + // 2. Load Default Speaker Embedding (if specified) + if (speakerEmbeddingUrl) { + const embeddingFilename = path.basename(new URL(speakerEmbeddingUrl).pathname); + const embeddingPath = path.join(this.cacheDir, embeddingFilename); + + if (fs.existsSync(embeddingPath)) { + logger.info('Loading default speaker embedding from cache...'); + const buffer = fs.readFileSync(embeddingPath); + this.defaultSpeakerEmbedding = new Float32Array( + buffer.buffer, + buffer.byteOffset, + buffer.length / Float32Array.BYTES_PER_ELEMENT + ); + logger.success('Default speaker embedding loaded from cache.'); + } else { + logger.info(`Downloading default speaker embedding from: ${speakerEmbeddingUrl}`); + const response = await fetch(speakerEmbeddingUrl); + if (!response.ok) { + throw new Error(`Failed to download speaker embedding: ${response.statusText}`); + } + const buffer = await response.arrayBuffer(); + this.defaultSpeakerEmbedding = new Float32Array(buffer); + fs.writeFileSync(embeddingPath, Buffer.from(buffer)); + logger.success('Default speaker embedding downloaded and cached.'); + } + } else { + logger.warn( + `No default speaker embedding URL specified for model ${modelName}. Speaker control may be limited.` + ); + this.defaultSpeakerEmbedding = null; + } + + // Check synthesizer as embedding might be optional for some models + if (!this.synthesizer) { + throw new Error('TTS initialization failed: Pipeline not loaded.'); + } + + logger.success('TTS initialization complete (Transformers.js)'); + this.initialized = true; + } catch (error) { + logger.error('TTS (Transformers.js) initialization failed: ' + (error instanceof Error ? error.message : String(error))); + this.initialized = false; + this.synthesizer = null; + this.defaultSpeakerEmbedding = null; + throw error; // Propagate error to reject the initializingPromise + } finally { + // Clear the promise once initialization is complete (successfully or not) + this.initializingPromise = null; + logger.debug('TTS initializingPromise cleared after completion/failure.'); + } + })(); + + return this.initializingPromise; + } + + /** + * Asynchronously generates speech from a given text using the Transformers.js pipeline. + * @param {string} text - The text to generate speech from. + * @returns {Promise} A promise that resolves to a Readable stream containing the generated WAV audio data. + * @throws {Error} If the TTS model is not initialized or if generation fails. + */ + public async generateSpeech(text: string): Promise { + try { + await this.initialize(); + + // Check synthesizer is initialized (embedding might be null but handled in synthesizer call) + if (!this.synthesizer) { + throw new Error('TTS Manager not properly initialized.'); + } + + logger.info('Starting speech generation with Transformers.js for text: ' + text.substring(0, 50) + '...'); + + // Generate audio using the pipeline + const output = await this.synthesizer(text, { + // Pass embedding only if it was loaded + ...(this.defaultSpeakerEmbedding && { + speaker_embeddings: this.defaultSpeakerEmbedding, + }), + }); + + // output is { audio: Float32Array, sampling_rate: number } + const audioFloat32 = output.audio; + const samplingRate = output.sampling_rate; + + logger.info('Raw audio data received from pipeline: samplingRate=' + samplingRate + ' length=' + audioFloat32.length); + + if (!audioFloat32 || audioFloat32.length === 0) { + throw new Error('TTS pipeline generated empty audio output.'); + } + + // Convert Float32Array to Int16 Buffer (standard PCM for WAV) + const pcmData = new Int16Array(audioFloat32.length); + for (let i = 0; i < audioFloat32.length; i++) { + const s = Math.max(-1, Math.min(1, audioFloat32[i])); // Clamp to [-1, 1] + pcmData[i] = s < 0 ? s * 0x8000 : s * 0x7fff; // Convert to 16-bit [-32768, 32767] + } + const audioBuffer = Buffer.from(pcmData.buffer); + + logger.info('Audio data converted to 16-bit PCM Buffer: byteLength=' + audioBuffer.length); + + // Create WAV format stream + // Use samplingRate from the pipeline output + const audioStream = prependWavHeader( + Readable.from(audioBuffer), + audioBuffer.length, // Pass buffer length in bytes + samplingRate, + 1, // Number of channels (assuming mono) + 16 // Bit depth + ); + + logger.success('Speech generation complete (Transformers.js)'); + return audioStream; + } catch (error) { + logger.error('Transformers.js speech generation failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } +} diff --git a/src/utils/visionManager.ts b/src/utils/visionManager.ts new file mode 100644 index 0000000..4cd7fcd --- /dev/null +++ b/src/utils/visionManager.ts @@ -0,0 +1,475 @@ +import { existsSync } from 'node:fs'; +import fs from 'node:fs'; +import os from 'node:os'; +import path from 'node:path'; +import process from 'node:process'; +import { logger } from '@elizaos/core'; +import { + AutoProcessor, + AutoTokenizer, + Florence2ForConditionalGeneration, + type Florence2Processor, + type PreTrainedTokenizer, + type ProgressCallback, + type ProgressInfo, + RawImage, + type Tensor, + env, +} from '@huggingface/transformers'; +import { MODEL_SPECS } from '../types'; +import { DownloadManager } from './downloadManager'; + +// Define valid types based on HF transformers types +/** + * Defines the type 'DeviceType' which can take one of the three string values: 'cpu', 'gpu', or 'auto' + */ +type DeviceType = 'cpu' | 'gpu' | 'auto'; +/** + * Represents the available data types options. + */ +type DTypeType = 'fp32' | 'fp16' | 'auto'; + +/** + * Interface for platform configuration options. + * @typedef {Object} PlatformConfig + * @property {DeviceType} device - The type of device to use. + * @property {DTypeType} dtype - The data type to use. + * @property {boolean} useOnnx - Flag indicating whether to use ONNX for processing. + */ +interface PlatformConfig { + device: DeviceType; + dtype: DTypeType; + useOnnx: boolean; +} + +/** + * Represents a model component with a name, type, and optionally a data type. + * @interface ModelComponent + * @property { string } name - The name of the model component. + * @property { string } type - The type of the model component. + * @property { DTypeType } [dtype] - The data type of the model component (optional). + */ +interface ModelComponent { + name: string; + type: string; + dtype?: DTypeType; +} + +/** + * Class representing a VisionManager. + * @property {VisionManager | null} instance - The static instance of VisionManager. + * @property {Florence2ForConditionalGeneration | null} model - The model for conditional generation. + * @property {Florence2Processor | null} processor - The processor for Florence2. + * @property {PreTrainedTokenizer | null} tokenizer - The pre-trained tokenizer. + * @property {string} modelsDir - The directory for models. + * @property {string} cacheDir - The directory for caching. + * @property {boolean} initialized - Flag indicating if the VisionManager has been initialized. + * @property {DownloadManager} downloadManager - The manager for downloading. + */ +export class VisionManager { + private static instance: VisionManager | null = null; + private model: Florence2ForConditionalGeneration | null = null; + private processor: Florence2Processor | null = null; + private tokenizer: PreTrainedTokenizer | null = null; + private modelsDir: string; + private cacheDir: string; + private initialized = false; + private downloadManager: DownloadManager; + private modelDownloaded = false; + private tokenizerDownloaded = false; + private processorDownloaded = false; + private platformConfig: PlatformConfig; + private modelComponents: ModelComponent[] = [ + { name: 'embed_tokens', type: 'embeddings' }, + { name: 'vision_encoder', type: 'encoder' }, + { name: 'decoder_model_merged', type: 'decoder' }, + { name: 'encoder_model', type: 'encoder' }, + ]; + + /** + * Constructor for VisionManager class. + * + * @param {string} cacheDir - The directory path for caching vision models. + */ + private constructor(cacheDir: string) { + this.modelsDir = path.join(path.dirname(cacheDir), 'models', 'vision'); + this.cacheDir = cacheDir; + this.ensureModelsDirExists(); + this.downloadManager = DownloadManager.getInstance(this.cacheDir, this.modelsDir); + this.platformConfig = this.getPlatformConfig(); + logger.debug('VisionManager initialized'); + } + + /** + * Retrieves the platform configuration based on the operating system and architecture. + * @returns {PlatformConfig} The platform configuration object with device, dtype, and useOnnx properties. + */ + private getPlatformConfig(): PlatformConfig { + const platform = os.platform(); + const arch = os.arch(); + + // Default configuration + let config: PlatformConfig = { + device: 'cpu', + dtype: 'fp32', + useOnnx: true, + }; + + if (platform === 'darwin' && arch === 'arm64') { + // Apple Silicon + config = { + device: 'gpu', + dtype: 'fp16', + useOnnx: true, + }; + } else if (platform === 'win32' || platform === 'linux') { + // Windows or Linux with CUDA + const hasCuda = process.env.CUDA_VISIBLE_DEVICES !== undefined; + if (hasCuda) { + config = { + device: 'gpu', + dtype: 'fp16', + useOnnx: true, + }; + } + } + return config; + } + + /** + * Ensures that the models directory exists. If it does not exist, it creates the directory. + */ + private ensureModelsDirExists(): void { + if (!existsSync(this.modelsDir)) { + logger.debug(`Creating models directory at: ${this.modelsDir}`); + fs.mkdirSync(this.modelsDir, { recursive: true }); + } + } + + /** + * Returns the singleton instance of VisionManager. + * If an instance does not already exist, a new instance is created with the specified cache directory. + * + * @param {string} cacheDir - The directory where cache files will be stored. + * + * @returns {VisionManager} The singleton instance of VisionManager. + */ + public static getInstance(cacheDir: string): VisionManager { + if (!VisionManager.instance) { + VisionManager.instance = new VisionManager(cacheDir); + } + return VisionManager.instance; + } + + /** + * Check if the cache exists for the specified model or tokenizer or processor. + * @param {string} modelId - The ID of the model. + * @param {"model" | "tokenizer" | "processor"} type - The type of the cache ("model", "tokenizer", or "processor"). + * @returns {boolean} - Returns true if cache exists, otherwise returns false. + */ + private checkCacheExists(modelId: string, type: 'model' | 'tokenizer' | 'processor'): boolean { + const modelPath = path.join(this.modelsDir, modelId.replace('/', '--'), type); + if (existsSync(modelPath)) { + logger.info(`${type} found at: ${modelPath}`); + return true; + } + return false; + } + + /** + * Configures the model components based on the platform and architecture. + * Sets the default data type (dtype) for components based on platform capabilities. + * Updates all component dtypes to match the default dtype. + */ + private configureModelComponents(): void { + const platform = os.platform(); + const arch = os.arch(); + + // Set dtype based on platform capabilities + let defaultDtype: DTypeType = 'fp32'; + + if (platform === 'darwin' && arch === 'arm64') { + // Apple Silicon can handle fp16 + defaultDtype = 'fp16'; + } else if ( + (platform === 'win32' || platform === 'linux') && + process.env.CUDA_VISIBLE_DEVICES !== undefined + ) { + // CUDA-enabled systems can handle fp16 + defaultDtype = 'fp16'; + } + + // Update all component dtypes + this.modelComponents = this.modelComponents.map((component) => ({ + ...component, + dtype: defaultDtype, + })); + + logger.info('Model components configured: platform=' + platform + ' arch=' + arch + ' dtype=' + defaultDtype); + } + + /** + * Get the model configuration based on the input component name. + * @param {string} componentName - The name of the component to retrieve the configuration for. + * @returns {object} The model configuration object containing device, dtype, and cache_dir. + */ + private getModelConfig(componentName: string) { + const component = this.modelComponents.find((c) => c.name === componentName); + return { + device: this.platformConfig.device, + dtype: component?.dtype || 'fp32', + cache_dir: this.modelsDir, + }; + } + + /** + * Asynchronous method to initialize the vision model by loading Florence2 model, vision tokenizer, and vision processor. + * + * @returns {Promise} - Promise that resolves once the initialization process is completed. + * @throws {Error} - If there is an error during the initialization process. + */ + private async initialize() { + try { + if (this.initialized) { + logger.info('Vision model already initialized, skipping initialization'); + return; + } + + logger.info('Starting vision model initialization...'); + const modelSpec = MODEL_SPECS.vision; + + // Configure environment + logger.info('Configuring environment for vision model...'); + env.allowLocalModels = true; + env.allowRemoteModels = true; + + // Configure ONNX backend + if (this.platformConfig.useOnnx) { + env.backends.onnx.enabled = true; + env.backends.onnx.logLevel = 'info'; + } + + // logger.info("Vision model configuration:", { + // modelId: modelSpec.modelId, + // modelsDir: this.modelsDir, + // allowLocalModels: env.allowLocalModels, + // allowRemoteModels: env.allowRemoteModels, + // platform: this.platformConfig + // }); + + // Initialize model with detailed logging + logger.info('Loading Florence2 model...'); + try { + let lastProgress = -1; + const modelCached = this.checkCacheExists(modelSpec.modelId, 'model'); + + const model = await Florence2ForConditionalGeneration.from_pretrained(modelSpec.modelId, { + device: 'cpu', + cache_dir: this.modelsDir, + local_files_only: modelCached, + revision: 'main', + progress_callback: ((progressInfo: ProgressInfo) => { + if (modelCached || this.modelDownloaded) return; + const progress = + 'progress' in progressInfo ? Math.max(0, Math.min(1, progressInfo.progress)) : 0; + const currentProgress = Math.round(progress * 100); + if (currentProgress > lastProgress + 9 || currentProgress === 100) { + lastProgress = currentProgress; + const barLength = 30; + const filledLength = Math.floor((currentProgress / 100) * barLength); + const progressBar = '▰'.repeat(filledLength) + '▱'.repeat(barLength - filledLength); + logger.info(`Downloading vision model: ${progressBar} ${currentProgress}%`); + if (currentProgress === 100) this.modelDownloaded = true; + } + }) as ProgressCallback, + }); + + this.model = model as unknown as Florence2ForConditionalGeneration; + logger.success('Florence2 model loaded successfully'); + } catch (error) { + logger.error('Failed to load Florence2 model: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + + // Initialize tokenizer with detailed logging + logger.info('Loading vision tokenizer...'); + try { + const tokenizerCached = this.checkCacheExists(modelSpec.modelId, 'tokenizer'); + let tokenizerProgress = -1; + + this.tokenizer = await AutoTokenizer.from_pretrained(modelSpec.modelId, { + cache_dir: this.modelsDir, + local_files_only: tokenizerCached, + progress_callback: ((progressInfo: ProgressInfo) => { + if (tokenizerCached || this.tokenizerDownloaded) return; + const progress = + 'progress' in progressInfo ? Math.max(0, Math.min(1, progressInfo.progress)) : 0; + const currentProgress = Math.round(progress * 100); + if (currentProgress !== tokenizerProgress) { + tokenizerProgress = currentProgress; + const barLength = 30; + const filledLength = Math.floor((currentProgress / 100) * barLength); + const progressBar = '▰'.repeat(filledLength) + '▱'.repeat(barLength - filledLength); + logger.info(`Downloading vision tokenizer: ${progressBar} ${currentProgress}%`); + if (currentProgress === 100) this.tokenizerDownloaded = true; + } + }) as ProgressCallback, + }); + logger.success('Vision tokenizer loaded successfully'); + } catch (error) { + logger.error('Failed to load tokenizer: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + + // Initialize processor with detailed logging + logger.info('Loading vision processor...'); + try { + const processorCached = this.checkCacheExists(modelSpec.modelId, 'processor'); + let processorProgress = -1; + + this.processor = (await AutoProcessor.from_pretrained(modelSpec.modelId, { + device: 'cpu', + cache_dir: this.modelsDir, + local_files_only: processorCached, + progress_callback: ((progressInfo: ProgressInfo) => { + if (processorCached || this.processorDownloaded) return; + const progress = + 'progress' in progressInfo ? Math.max(0, Math.min(1, progressInfo.progress)) : 0; + const currentProgress = Math.round(progress * 100); + if (currentProgress !== processorProgress) { + processorProgress = currentProgress; + const barLength = 30; + const filledLength = Math.floor((currentProgress / 100) * barLength); + const progressBar = '▰'.repeat(filledLength) + '▱'.repeat(barLength - filledLength); + logger.info(`Downloading vision processor: ${progressBar} ${currentProgress}%`); + if (currentProgress === 100) this.processorDownloaded = true; + } + }) as ProgressCallback, + })) as Florence2Processor; + logger.success('Vision processor loaded successfully'); + } catch (error) { + logger.error('Failed to load vision processor: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + + this.initialized = true; + logger.success('Vision model initialization complete'); + } catch (error) { + logger.error('Vision model initialization failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } + + /** + * Fetches an image from a given URL and returns the image data as a Buffer along with its MIME type. + * + * @param {string} url - The URL of the image to fetch. + * @returns {Promise<{ buffer: Buffer; mimeType: string }>} Object containing the image data as a Buffer and its MIME type. + */ + private async fetchImage(url: string): Promise<{ buffer: Buffer; mimeType: string }> { + try { + logger.info(`Fetching image from URL: ${url.slice(0, 100)}...`); + + // Handle data URLs differently + if (url.startsWith('data:')) { + logger.info('Processing data URL...'); + const [header, base64Data] = url.split(','); + const mimeType = header.split(';')[0].split(':')[1]; + const buffer = Buffer.from(base64Data, 'base64'); + logger.info('Data URL processed successfully'); + // logger.info("Data URL processed successfully:", { + // mimeType, + // bufferSize: buffer.length + // }); + return { buffer, mimeType }; + } + + // Handle regular URLs + const response = await fetch(url); + if (!response.ok) { + throw new Error(`Failed to fetch image: ${response.statusText}`); + } + const buffer = Buffer.from(await response.arrayBuffer()); + const mimeType = response.headers.get('content-type') || 'image/jpeg'; + + logger.info('Image fetched successfully: mimeType=' + mimeType + ' size=' + buffer.length); + + return { buffer, mimeType }; + } catch (error) { + logger.error('Failed to fetch image: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } + + /** + * Processes the image from the provided URL using the initialized vision model components. + * @param {string} imageUrl - The URL of the image to process. + * @returns {Promise<{ title: string; description: string }>} An object containing the title and description of the processed image. + */ + public async processImage(imageUrl: string): Promise<{ title: string; description: string }> { + try { + logger.info('Starting image processing...'); + + // Ensure model is initialized + if (!this.initialized) { + logger.info('Vision model not initialized, initializing now...'); + await this.initialize(); + } + + if (!this.model || !this.processor || !this.tokenizer) { + throw new Error('Vision model components not properly initialized'); + } + + // Fetch and process image + logger.info('Fetching image...'); + const { buffer, mimeType } = await this.fetchImage(imageUrl); + + // Process image + logger.info('Creating image blob...'); + const blob = new Blob([new Uint8Array(buffer)], { type: mimeType }); + logger.info('Converting blob to RawImage...'); + // @ts-ignore - RawImage.fromBlob expects web Blob but works with node Blob + const image = await RawImage.fromBlob(blob); + + logger.info('Processing image with vision processor...'); + const visionInputs = await this.processor(image); + logger.info('Constructing prompts...'); + const prompts = this.processor.construct_prompts(''); + logger.info('Tokenizing prompts...'); + const textInputs = this.tokenizer(prompts); + + // Generate description + logger.info('Generating image description...'); + const generatedIds = (await this.model.generate({ + ...textInputs, + ...visionInputs, + max_new_tokens: MODEL_SPECS.vision.maxTokens, + })) as Tensor; + + logger.info('Decoding generated text...'); + const generatedText = this.tokenizer.batch_decode(generatedIds, { + skip_special_tokens: false, + })[0]; + + logger.info('Post-processing generation...'); + const result = this.processor.post_process_generation( + generatedText, + '', + image.size + ); + + const detailedCaption = result[''] as string; + const response = { + title: `${detailedCaption.split('.')[0]}.`, + description: detailedCaption, + }; + + logger.success('Image processing complete: title=' + response.title.length + ' desc=' + response.description.length); + + return response; + } catch (error) { + logger.error('Image processing failed: ' + (error instanceof Error ? error.message : String(error))); + throw error; + } + } +}