Streaming Change #181

Merged · 1 commit · Apr 2, 2025

2 changes: 1 addition & 1 deletion examples/tts-demo/package.json
@@ -28,4 +28,4 @@
"typescript-eslint": "^8.18.2",
"vite": "^6.0.5"
}
}
}
288 changes: 244 additions & 44 deletions examples/tts-demo/src/App.tsx
@@ -1,4 +1,4 @@
import { useState } from 'react';
import { useState, useRef, useEffect } from 'react';
import { BrowserAI } from '@browserai/browserai';
import styled from '@emotion/styled';

@@ -182,9 +182,88 @@ function App() {
const [isLoading, setIsLoading] = useState(false);
const [ttsAI] = useState(new BrowserAI());
const [isModelLoaded, setIsModelLoaded] = useState(false);
const [audioBlob, setAudioBlob] = useState<Blob | null>(null);
const [selectedVoice, setSelectedVoice] = useState('af_bella');
const [speed, setSpeed] = useState(1.0);
const [audioBlob, setAudioBlob] = useState<Blob | null>(null);

// Audio streaming references
const audioContextRef = useRef<AudioContext | null>(null);
const nextPlayTimeRef = useRef<number>(0);
const isPlayingRef = useRef<boolean>(false);
const accumulatedAudioChunksRef = useRef<Float32Array[]>([]);
const sampleRateRef = useRef<number>(24000);

// Clean up audio context on unmount
useEffect(() => {
return () => {
if (audioContextRef.current) {
audioContextRef.current.close();
}
};
}, []);

// Create WAV header
const createWAVHeader = (numChannels: number, sampleRate: number, numSamples: number): ArrayBuffer => {
const buffer = new ArrayBuffer(44);
const view = new DataView(buffer);

// "RIFF" chunk descriptor
writeString(view, 0, 'RIFF');
// File size (data size + 36 bytes of header)
view.setUint32(4, 36 + numSamples * 2, true);
writeString(view, 8, 'WAVE');

// "fmt " sub-chunk
writeString(view, 12, 'fmt ');
view.setUint32(16, 16, true); // fmt chunk size
view.setUint16(20, 1, true); // audio format (1 for PCM)
view.setUint16(22, numChannels, true);
view.setUint32(24, sampleRate, true);
view.setUint32(28, sampleRate * numChannels * 2, true); // byte rate
view.setUint16(32, numChannels * 2, true); // block align
view.setUint16(34, 16, true); // bits per sample

// "data" sub-chunk
writeString(view, 36, 'data');
view.setUint32(40, numSamples * 2, true); // data size

return buffer;
};

// Helper function to write string to DataView
const writeString = (view: DataView, offset: number, string: string) => {
for (let i = 0; i < string.length; i++) {
view.setUint8(offset + i, string.charCodeAt(i));
}
};

const initializeAudioContext = () => {
if (!audioContextRef.current || audioContextRef.current.state === 'closed') {
const context = new (window.AudioContext || (window as any).webkitAudioContext)();
audioContextRef.current = context;
nextPlayTimeRef.current = context.currentTime; // Initialize play time
return context;
}
return audioContextRef.current;
};

const playAudioChunk = (context: AudioContext, chunk: Float32Array, sampleRate: number) => {
const buffer = context.createBuffer(1, chunk.length, sampleRate);
buffer.copyToChannel(chunk, 0);

const node = context.createBufferSource();
node.buffer = buffer;
node.connect(context.destination);

// Schedule playback precisely
const scheduledTime = Math.max(context.currentTime, nextPlayTimeRef.current);
node.start(scheduledTime);

// Update the time for the next chunk
nextPlayTimeRef.current = scheduledTime + buffer.duration;

return node;
};

const loadModel = async () => {
try {
@@ -204,58 +283,171 @@ function App() {
setStatus('Please enter some text first');
return;
}
if (!isModelLoaded || isLoading) return;

setIsLoading(true);
setStatus('Generating speech stream...');

// Reset any previous audio state
accumulatedAudioChunksRef.current = [];
isPlayingRef.current = true;

const currentAudioContext = initializeAudioContext();
if (!currentAudioContext) {
setStatus('Failed to initialize Audio Context');
setIsLoading(false);
return;
}

// Ensure audio context is running (required after user interaction)
if (currentAudioContext.state === 'suspended') {
await currentAudioContext.resume();
}

// Reset nextPlayTime for new playback
nextPlayTimeRef.current = currentAudioContext.currentTime;

try {
setIsLoading(true);
setStatus('Generating speech...');
const audioData = await ttsAI.textToSpeech(text, {
// Get language from selected voice
const selectedVoiceData = VOICE_OPTIONS.find(v => v.id === selectedVoice);
if (!selectedVoiceData) {
throw new Error("Selected voice data not found.");
}
const language = selectedVoiceData.language;

const result = await ttsAI.textToSpeech(text, {
voice: selectedVoice,
speed: speed
speed: speed,
language: language // Pass explicit language code
});

if (audioData) {
// Create a blob with WAV MIME type
const blob = new Blob([audioData], { type: 'audio/wav' });
setAudioBlob(blob); // Store the blob for download
const audioUrl = URL.createObjectURL(blob);

// Create and play audio element
const audio = new Audio(audioUrl);
// Extract stream and sampleRate from the result
const { stream, sampleRate } = result;

// Store sample rate for WAV generation
sampleRateRef.current = sampleRate;

// Reset accumulated chunks
accumulatedAudioChunksRef.current = [];

// Clear any previous audio blob
setAudioBlob(null);

setStatus('Streaming audio...');
let chunksProcessed = 0;

// Process each chunk from the stream
for await (const chunk of stream) {
if (!isPlayingRef.current) break; // Allow stopping

audio.onended = () => {
setStatus('Finished playing');
setIsLoading(false);
URL.revokeObjectURL(audioUrl); // Clean up
};
// Store the chunk for potential download later
accumulatedAudioChunksRef.current.push(chunk);

audio.onerror = (e) => {
console.error('Audio playback error:', e);
setStatus('Error playing audio');
setIsLoading(false);
URL.revokeObjectURL(audioUrl);
};
// Play this chunk
playAudioChunk(currentAudioContext, chunk, sampleRate);

setStatus('Playing audio...');
await audio.play();
// Update status occasionally to show progress
chunksProcessed++;
if (chunksProcessed % 10 === 0) {
setStatus('Streaming audio...');
}
}

// Calculate when all audio will finish playing
const estimatedDuration = nextPlayTimeRef.current - currentAudioContext.currentTime;
const finishingDelay = Math.max(estimatedDuration * 1000, 100); // At least 100ms

setTimeout(() => {
if (isPlayingRef.current) {
// Create blob for download
if (accumulatedAudioChunksRef.current.length > 0) {
// Calculate total length of all chunks
const totalLength = accumulatedAudioChunksRef.current.reduce((total, chunk) => total + chunk.length, 0);

// Create a combined Float32Array
const combinedFloat32 = new Float32Array(totalLength);
let offset = 0;

// Copy all chunks into the combined array
for (const chunk of accumulatedAudioChunksRef.current) {
combinedFloat32.set(chunk, offset);
offset += chunk.length;
}

// Normalize if needed - skip this as chunks are already normalized
// const maxValue = combinedFloat32.reduce((max, val) => Math.max(max, Math.abs(val)), 0);
// const normalizedData = maxValue > 0 ? new Float32Array(combinedFloat32.length) : combinedFloat32;

// if (maxValue > 0) {
// for (let i = 0; i < combinedFloat32.length; i++) {
// normalizedData[i] = combinedFloat32[i] / maxValue;
// }
// }

// Convert to Int16Array for WAV
const int16Array = new Int16Array(combinedFloat32.length);
const int16Factor = 0x7FFF;

for (let i = 0; i < combinedFloat32.length; i++) {
const s = combinedFloat32[i];
int16Array[i] = s < 0 ? Math.max(-0x8000, s * 0x8000) : Math.min(0x7FFF, s * int16Factor);
}

// Create WAV header
const wavHeader = createWAVHeader(1, sampleRateRef.current, int16Array.length);

// Combine header with audio data
const wavBytes = new Uint8Array(44 + int16Array.byteLength);
wavBytes.set(new Uint8Array(wavHeader), 0);
wavBytes.set(new Uint8Array(int16Array.buffer), 44);

// Create blob for download
const blob = new Blob([wavBytes], { type: 'audio/wav' });
setAudioBlob(blob);
}

console.log(`Finished playing stream (${chunksProcessed} total chunks)`);
setStatus('Finished playing stream');
setIsLoading(false);
isPlayingRef.current = false;
}
}, finishingDelay);

} catch (error) {
console.error('Error in speak:', error);
setStatus('Error generating speech: ' + (error as Error).message);
console.error('Error in speech stream:', error);
setStatus('Error generating or playing stream: ' + (error as Error).message);
setIsLoading(false);
isPlayingRef.current = false;
}
};

const stopSpeak = () => {
isPlayingRef.current = false;
setIsLoading(false);
setStatus('Playback stopped.');

// Reset audio context time tracking
if (audioContextRef.current) {
nextPlayTimeRef.current = audioContextRef.current.currentTime;
}
};

const downloadAudio = () => {
if (audioBlob) {
const url = URL.createObjectURL(audioBlob);
const a = document.createElement('a');
a.href = url;
a.download = 'generated-speech.wav';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
if (!audioBlob) {
setStatus('No audio data available to download');
return;
}

const url = URL.createObjectURL(audioBlob);
const a = document.createElement('a');
a.href = url;
a.download = 'generated-speech.wav';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);

setStatus('Audio downloaded successfully');
};

return (
@@ -268,7 +460,7 @@ function App() {
<Container>
<div>
<Title>Kokoro TTS Demo</Title>
<Subtitle>A lightweight, browser-based text-to-speech engine</Subtitle>
<Subtitle>A lightweight, browser-based text-to-speech engine with streaming</Subtitle>
</div>

<Button
@@ -322,15 +514,23 @@ function App() {
<Button
onClick={speak}
disabled={!isModelLoaded || isLoading || !text.trim()}
isLoading={isLoading && isModelLoaded}
isLoading={isLoading}
>
<ButtonContent>
{(isLoading && isModelLoaded) && <Spinner />}
{isLoading ? 'Processing...' : 'Speak'}
{isLoading && <Spinner />}
{isLoading ? 'Streaming...' : 'Speak'}
</ButtonContent>
</Button>

{audioBlob && (

{isLoading && (
<Button onClick={stopSpeak}>
<ButtonContent>
Stop
</ButtonContent>
</Button>
)}

{audioBlob && !isLoading && (
<Button onClick={downloadAudio}>
<ButtonContent>
Download Audio
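
For readers skimming the diff above: the download path, spread across the setTimeout callback, boils down to three steps — concatenate the accumulated Float32Array chunks, convert them to 16-bit PCM, and prepend the 44-byte RIFF header from createWAVHeader. A minimal standalone sketch of that assembly (mirroring the PR's mono/16-bit layout; chunksToWavBlob is a hypothetical helper name, and createWAVHeader is the helper defined in App.tsx above):

// Sketch: collapse accumulated Float32 chunks into a downloadable WAV Blob.
function chunksToWavBlob(chunks: Float32Array[], sampleRate: number): Blob {
  // Concatenate all chunks into one contiguous buffer.
  const totalLength = chunks.reduce((n, c) => n + c.length, 0);
  const combined = new Float32Array(totalLength);
  let offset = 0;
  for (const c of chunks) {
    combined.set(c, offset);
    offset += c.length;
  }

  // Map [-1, 1] floats to signed 16-bit PCM, clamping at the rails.
  const pcm = new Int16Array(totalLength);
  for (let i = 0; i < totalLength; i++) {
    const s = combined[i];
    pcm[i] = s < 0 ? Math.max(-0x8000, s * 0x8000) : Math.min(0x7fff, s * 0x7fff);
  }

  // 44-byte RIFF header followed by the raw sample data.
  const header = createWAVHeader(1, sampleRate, pcm.length);
  const bytes = new Uint8Array(44 + pcm.byteLength);
  bytes.set(new Uint8Array(header), 0);
  bytes.set(new Uint8Array(pcm.buffer), 44);
  return new Blob([bytes], { type: 'audio/wav' });
}
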
22 changes: 13 additions & 9 deletions src/core/llm/index.ts
@@ -5,7 +5,6 @@ import { TransformersEngineWrapper } from '../../engines/transformer-engine-wrap
import { ModelConfig, MLCConfig, TransformersConfig } from '../../config/models/types';
import mlcModels from '../../config/models/mlc-models.json';
import transformersModels from '../../config/models/transformers-models.json';
import { TTSEngine } from '@/engines/tts-engine';

// Combine model configurations
const MODEL_CONFIG: Record<string, ModelConfig> = {
@@ -19,7 +18,6 @@ export class BrowserAI {
private mediaRecorder: MediaRecorder | null = null;
private audioChunks: Blob[] = [];
private modelIdentifier: string | null = null;
private ttsEngine: TTSEngine | null = null;
private customModels: Record<string, ModelConfig> = {};

constructor() {
@@ -170,21 +168,27 @@ export class BrowserAI {
return response as string;
}

async textToSpeech(text: string, options: Record<string, unknown> = {}): Promise<ArrayBuffer> {
if (!this.ttsEngine) {
this.ttsEngine = new TTSEngine();
await this.ttsEngine.loadModel(MODEL_CONFIG['kokoro-tts'], {
async textToSpeech(text: string, options: Record<string, unknown> = {}): Promise<any> {
// Check if engine is already loaded
if (!this.engine) {
// Load the transformers engine if not already loaded
this.engine = new TransformersEngineWrapper();
await this.engine.loadModel(MODEL_CONFIG['kokoro-tts'], {
quantized: true,
device: 'webgpu',
...options,
});
}

try {
const audioData = await this.ttsEngine.generateSpeech(text, options);
return audioData;
if (this.engine instanceof TransformersEngineWrapper) {
// Use the streaming method
return await this.engine.textToSpeechStream(text, options);
} else {
throw new Error('Current engine does not support text-to-speech streaming');
}
} catch (error) {
console.error('Error generating speech:', error);
console.error('Error generating speech stream:', error);
throw error;
}
}
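
Note the signature change: textToSpeech previously resolved to an ArrayBuffer of WAV bytes, but after this PR it resolves to a streaming result (left as Promise<any> in the diff). Based on how App.tsx consumes it, the shape appears to be an async-iterable stream of Float32Array chunks plus a sample rate. A consumer sketch under that assumption (the TTSStreamResult interface and the 'en-us' language code are illustrative, not part of the PR):

import { BrowserAI } from '@browserai/browserai';

// Inferred result shape; the PR leaves the return type as Promise<any>.
interface TTSStreamResult {
  stream: AsyncIterable<Float32Array>;
  sampleRate: number;
}

async function demo(): Promise<void> {
  const tts = new BrowserAI();
  // Per the index.ts change above, textToSpeech lazily loads kokoro-tts via
  // TransformersEngineWrapper when no engine is loaded, so no explicit
  // loadModel call is strictly required here.
  const { stream, sampleRate } = (await tts.textToSpeech('Hello, streaming world!', {
    voice: 'af_bella',
    speed: 1.0,
    language: 'en-us', // illustrative; the demo derives this from its VOICE_OPTIONS table
  })) as TTSStreamResult;

  for await (const chunk of stream) {
    // Each chunk is mono PCM at `sampleRate`; schedule it for playback
    // or accumulate it for WAV export, as the demo does.
    console.log(`received ${chunk.length} samples @ ${sampleRate} Hz`);
  }
}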