diff --git a/README.md b/README.md index 12cf6d0..bd2b48b 100644 --- a/README.md +++ b/README.md @@ -1,144 +1,9 @@ -# LuxTTS -

- - Hugging Face Model - -   - - Hugging Face Space - -   - - Colab Notebook - -

+Project quick start :https://github.com/ysharma3501/LuxTTS -LuxTTS is an lightweight zipvoice based text-to-speech model designed for high quality voice cloning and realistic generation at speeds exceeding 150x realtime. +update note: -https://github.com/user-attachments/assets/a3b57152-8d97-43ce-bd99-26dc9a145c29 - - -### The main features are -- Voice cloning: SOTA voice cloning on par with models 10x larger. -- Clarity: Clear 48khz speech generation unlike most TTS models which are limited to 24khz. -- Speed: Reaches speeds of 150x realtime on a single GPU and faster then realtime on CPU's as well. -- Efficiency: Fits within 1gb vram meaning it can fit in any local gpu. - -## Usage -You can try it locally, colab, or spaces. - -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1cDaxtbSDLRmu6tRV_781Of_GSjHSo1Cu?usp=sharing) -[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/YatharthS/LuxTTS) - -#### Simple installation: -``` -git clone https://github.com/ysharma3501/LuxTTS.git -cd LuxTTS -pip install -r requirements.txt -``` - -#### Load model: -```python -from zipvoice.luxvoice import LuxTTS - -# load model on GPU -lux_tts = LuxTTS('YatharthS/LuxTTS', device='cuda') - -# load model on CPU -# lux_tts = LuxTTS('YatharthS/LuxTTS', device='cpu', threads=2) - -# load model on MPS for macs -# lux_tts = LuxTTS('YatharthS/LuxTTS', device='mps') -``` - -#### Simple inference -```python -import soundfile as sf -from IPython.display import Audio - -text = "Hey, what's up? I'm feeling really great if you ask me honestly!" - -## change this to your reference file path, can be wav/mp3 -prompt_audio = 'audio_file.wav' - -## encode audio(takes 10s to init because of librosa first time) -encoded_prompt = lux_tts.encode_prompt(prompt_audio, rms=0.01) - -## generate speech -final_wav = lux_tts.generate_speech(text, encoded_prompt, num_steps=4) - -## save audio -final_wav = final_wav.numpy().squeeze() -sf.write('output.wav', final_wav, 48000) - -## display speech -if display is not None: - display(Audio(final_wav, rate=48000)) -``` - -#### Inference with sampling params: -```python -import soundfile as sf -from IPython.display import Audio - -text = "Hey, what's up? I'm feeling really great if you ask me honestly!" - -## change this to your reference file path, can be wav/mp3 -prompt_audio = 'audio_file.wav' - -rms = 0.01 ## higher makes it sound louder(0.01 or so recommended) -t_shift = 0.9 ## sampling param, higher can sound better but worse WER -num_steps = 4 ## sampling param, higher sounds better but takes longer(3-4 is best for efficiency) -speed = 1.0 ## sampling param, controls speed of audio(lower=slower) -return_smooth = False ## sampling param, makes it sound smoother possibly but less cleaner -ref_duration = 5 ## Setting it lower can speedup inference, set to 1000 if you find artifacts. - -## encode audio(takes 10s to init because of librosa first time) -encoded_prompt = lux_tts.encode_prompt(prompt_audio, duration=ref_duration, rms=rms) - -## generate speech -final_wav = lux_tts.generate_speech(text, encoded_prompt, num_steps=num_steps, t_shift=t_shift, speed=speed, return_smooth=return_smooth) - -## save audio -final_wav = final_wav.numpy().squeeze() -sf.write('output.wav', final_wav, 48000) - -## display speech -if display is not None: - display(Audio(final_wav, rate=48000)) -``` -## Tips -- Please use at minimum a 3 second audio file for voice cloning. -- You can use return_smooth = True if you hear metallic sounds. -- Lower t_shift for less possible pronunciation errors but worse quality and vice versa. - - -## Info - -Q: How is this different from ZipVoice? - -A: LuxTTS uses the same architecture but distilled to 4 steps with an improved sampling technique. It also uses a custom 48khz vocoder instead of the default 24khz version. - -Q: Can it be even faster? - -A: Yes, currently it uses float32. Float16 should be significantly faster(almost 2x). - -## Roadmap - -- [x] Release model and code -- [x] Huggingface spaces demo -- [x] Release MPS support (thanks to @builtbybasit) -- [ ] Release code for float16 inference - -## Acknowledgments - -- [ZipVoice](https://github.com/k2-fsa/ZipVoice) for their excellent code and model. -- [Vocos](https://github.com/gemelo-ai/vocos.git) for their great vocoder. - -## Final Notes - -The model and code are licensed under the Apache-2.0 license. See LICENSE for details. - -Stars/Likes would be appreciated, thank you. - -Email: yatharthsharma350@gmail.com +2026-01-29 +1. Skip Whisper recognition if text exists in speaker.yml; +2. Add 50ms/80ms silence via NumPy in post-processing; +3. Language-specific t_shift/guidance_scale by text proportion; +4. Independent Chinese speech rate with token padding coefficient. diff --git a/lux_server.py b/lux_server.py new file mode 100644 index 0000000..2d75999 --- /dev/null +++ b/lux_server.py @@ -0,0 +1,672 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Lux-TTS 服务端实现 +支持HTTP接口和SSE接口 +参数包括speaker、text +""" + +import os +import yaml +import base64 +import torch +import torchaudio +import numpy as np +from fastapi import FastAPI, Query, HTTPException +from fastapi.responses import StreamingResponse, JSONResponse +from pydantic import BaseModel +import time +from typing import Generator, Dict, Any, Tuple, List +from contextlib import asynccontextmanager +import soundfile as sf +from io import BytesIO +import logging +from zipvoice.luxvoice import LuxTTS +import asyncio +from dataclasses import dataclass +from datetime import datetime + +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# ==================== 配置参数区域 ==================== +# 服务配置 +NUM_MODEL_INSTANCES = int(os.getenv('NUM_MODEL_INSTANCES', '24')) # 默认24个实例(适合40GB显存) +MODEL_ACQUIRE_TIMEOUT = int(os.getenv('MODEL_ACQUIRE_TIMEOUT', '10')) # 获取模型超时时间(秒) + +# 线程池配置 +import concurrent.futures +CPU_CORES = 24 # vCPU核心数 +MAX_WORKERS = min(int(CPU_CORES * 1.2), 36) # 线程池大小 +executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) + +# 音频配置 +ORIGINAL_SAMPLE_RATE = 48000 # 原始采样率 +TARGET_SAMPLE_RATE = 16000 # 目标采样率 + +# 模型参数 +rms = 0.01 ## higher makes it sound louder(0.01 or so recommended) +t_shift = 0.65 ## sampling param, higher can sound better but worse WER +num_steps = 4 ## sampling param, higher sounds better but takes longer(3-4 is best for efficiency) +return_smooth = True ## sampling param, makes it sound smoother possibly but less cleaner + +# 静音配置 +PRE_SILENCE_MS = 50 # 前静音(毫秒) +POST_SILENCE_MS = 80 # 后静音(毫秒) + +# 语言相关配置 +CHINESE_T_SHIFT = 0.65 # 中文t_shift值 +CHINESE_GUIDANCE_SCALE = 3.0 # 中文guidance_scale值 +ENGLISH_T_SHIFT = 0.85 # 英文t_shift值 +ENGLISH_GUIDANCE_SCALE = 3.0 # 英文guidance_scale值 +CHINESE_SPEED_ADJUSTMENT = 0.8 # 中文速度调整系数 +# ===================================================== + +# 存储发音人特征数据 +speaker_embeddings = {} + +# 全局重采样器(在启动时初始化) +gpu_resampler = None +cpu_resampler = None + + +class TTSRequest(BaseModel): + """TTS请求体模型""" + text: str + speaker: str + speed: float = 1.0 # 语速控制参数,默认1.0,范围建议0.5-2.0 + + +class SpeakerInfo: + """发音人信息类""" + + def __init__(self, key: str, wav_path: str, text: str = None): + self.key = key + self.wav_path = wav_path + self.text = text + self.embedding = None + + +@dataclass +class ModelInstanceStats: + """模型实例统计信息""" + instance_id: int + total_requests: int = 0 + total_inference_time: float = 0.0 + last_used_time: datetime = None + is_busy: bool = False + + +class ModelInstance: + """单个模型实例封装""" + + def __init__(self, instance_id: int, model: Any): + self.instance_id = instance_id + self.model = model + self.stats = ModelInstanceStats(instance_id=instance_id) + # 不再需要锁,由队列管理并发 + + def update_stats(self, inference_time: float): + """更新统计信息""" + self.stats.total_requests += 1 + self.stats.total_inference_time += inference_time + self.stats.last_used_time = datetime.now() + self.stats.is_busy = False + + +class ModelPool: + """模型实例池管理器(使用 asyncio.Queue 实现真正的异步)""" + + def __init__(self, num_instances: int): + self.num_instances = num_instances + self.instances: List[ModelInstance] = [] + self.queue: asyncio.Queue = None # 在 lifespan 中初始化 + self._initialized = False + self._lock = asyncio.Lock() + + async def initialize(self, model_path: str, device: str = 'cuda', threads: int = 4): + """初始化所有模型实例""" + async with self._lock: + if self._initialized: + logger.warning("ModelPool already initialized") + return + + # 创建异步队列 + self.queue = asyncio.Queue(maxsize=self.num_instances) + + logger.info(f"Initializing {self.num_instances} model instances...") + logger.info(f"Using device: {device}, threads per instance: {threads}") + + for i in range(self.num_instances): + try: + logger.info(f"Loading model instance {i+1}/{self.num_instances}...") + model = LuxTTS(model_path, device=device, threads=threads) + instance = ModelInstance(instance_id=i, model=model) + self.instances.append(instance) + await self.queue.put(instance) # 使用 await + logger.info(f"Model instance {i} loaded successfully") + except Exception as e: + logger.error(f"Failed to load model instance {i}: {e}") + raise + + self._initialized = True + logger.info(f"Successfully initialized {len(self.instances)} model instances") + + async def acquire(self, timeout: int = MODEL_ACQUIRE_TIMEOUT) -> ModelInstance: + """获取一个空闲的模型实例(真正的异步等待)""" + if not self._initialized: + raise RuntimeError("ModelPool not initialized") + + try: + # 使用 asyncio.wait_for 实现超时控制 + instance = await asyncio.wait_for( + self.queue.get(), + timeout=timeout + ) + instance.stats.is_busy = True + logger.debug(f"Acquired model instance {instance.instance_id}") + return instance + except asyncio.TimeoutError: + raise TimeoutError(f"Failed to acquire model instance within {timeout} seconds. All instances are busy.") + + def release(self, instance: ModelInstance): + """释放模型实例回池中(同步方法,立即返回)""" + instance.stats.is_busy = False + # 使用 put_nowait 立即放回,不会阻塞 + self.queue.put_nowait(instance) + logger.debug(f"Released model instance {instance.instance_id}") + + def get_stats(self) -> Dict[str, Any]: + """获取实例池统计信息""" + total_requests = sum(inst.stats.total_requests for inst in self.instances) + total_inference_time = sum(inst.stats.total_inference_time for inst in self.instances) + busy_count = sum(1 for inst in self.instances if inst.stats.is_busy) + + instance_details = [] + for inst in self.instances: + avg_time = (inst.stats.total_inference_time / inst.stats.total_requests + if inst.stats.total_requests > 0 else 0) + instance_details.append({ + "instance_id": inst.instance_id, + "total_requests": inst.stats.total_requests, + "total_inference_time": round(inst.stats.total_inference_time, 2), + "avg_inference_time": round(avg_time, 3), + "last_used": inst.stats.last_used_time.isoformat() if inst.stats.last_used_time else None, + "is_busy": inst.stats.is_busy + }) + + return { + "total_instances": self.num_instances, + "busy_instances": busy_count, + "available_instances": self.num_instances - busy_count, + "total_requests": total_requests, + "total_inference_time": round(total_inference_time, 2), + "avg_inference_time": round(total_inference_time / total_requests, 3) if total_requests > 0 else 0, + "instances": instance_details + } + + +def load_speakers_config(config_path: str) -> Dict[str, SpeakerInfo]: + """加载发音人配置""" + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + speakers_info = {} + for speaker_data in config['speaker']: + key = speaker_data['key'] + wav_path = os.path.join(os.path.dirname(config_path), speaker_data['wav']) + # 读取配置文件中的参考文本 + text = speaker_data.get('text', None) + + # 创建SpeakerInfo对象 + speaker_info = SpeakerInfo(key, wav_path, text) + speakers_info[key] = speaker_info + + logger.info(f"Loaded speaker: {key}, wav: {wav_path}, text: {text}") + + return speakers_info + + +def load_speaker_embeddings(model: Any, speakers_info: Dict[str, SpeakerInfo]): + """加载发音人特征嵌入""" + global speaker_embeddings + + for key, speaker_info in speakers_info.items(): + try: + # 使用模型提取发音人特征 + # 从参考音频和文本生成特征 + data, sample_rate = sf.read(speaker_info.wav_path) + # 计算时长(以秒为单位) + ref_duration = len(data) / float(sample_rate) + # 传递配置文件中指定的参考文本 + ref_embedding = model.encode_prompt(speaker_info.wav_path, duration=ref_duration, rms=rms, text=speaker_info.text) + speaker_embeddings[key] = ref_embedding + logger.info(f"Generated embedding for speaker: {key},duration: {ref_duration}, text: {speaker_info.text}") + + except Exception as e: + logger.error(f"Error loading embedding for speaker {key}: {str(e)}") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + global speaker_embeddings + + logger.info(f"Initializing Lux-TTS with {NUM_MODEL_INSTANCES} model instances...") + + try: + # 加载发音人配置 + config_path = os.path.join(os.path.dirname(__file__), 'audios', 'speaker.yml') + speakers_info = load_speakers_config(config_path) + + # 初始化全局重采样器 + logger.info("Initializing global audio resamplers...") + initialize_resamplers() + + # 初始化模型实例池 + model_path = '/opt/LuxTTS/models' + model_pool = ModelPool(num_instances=NUM_MODEL_INSTANCES) + + try: + # 初始化所有模型实例 + await model_pool.initialize(model_path, device='cuda', threads=2) + + # 使用第一个实例加载发音人特征(所有实例共享) + first_instance = await model_pool.acquire() + try: + load_speaker_embeddings(first_instance.model, speakers_info) + finally: + model_pool.release(first_instance) + + except Exception as e: + logger.error(f"Failed to initialize model pool: {e}") + raise + + # 将模型池和配置存储在应用状态中 + app.state.model_pool = model_pool + app.state.speakers_info = speakers_info + + logger.info(f"Successfully initialized {NUM_MODEL_INSTANCES} model instances and loaded speaker configurations") + + except Exception as e: + logger.error(f"Failed to initialize Lux-TTS: {str(e)}") + raise + + yield + + # 清理工作 + logger.info("Shutting down model pool...") + + +app = FastAPI( + title="Lux-TTS Server", + description="Lux-TTS API Service", + lifespan=lifespan +) + + +def initialize_resamplers(): + """初始化全局重采样器""" + global gpu_resampler, cpu_resampler + + try: + # 创建 GPU 重采样器 + gpu_resampler = torchaudio.transforms.Resample( + orig_freq=ORIGINAL_SAMPLE_RATE, + new_freq=TARGET_SAMPLE_RATE, + dtype=torch.float32 + ).to('cuda') + logger.info(f"GPU resampler initialized: {ORIGINAL_SAMPLE_RATE}Hz -> {TARGET_SAMPLE_RATE}Hz") + except Exception as e: + logger.warning(f"Failed to initialize GPU resampler: {e}") + gpu_resampler = None + + # 创建 CPU 重采样器(作为后备) + cpu_resampler = torchaudio.transforms.Resample( + orig_freq=ORIGINAL_SAMPLE_RATE, + new_freq=TARGET_SAMPLE_RATE, + dtype=torch.float32 + ) + logger.info("CPU resampler initialized as fallback") + + +def resample_audio_gpu(audio_tensor: torch.Tensor, device: str = 'cuda') -> torch.Tensor: + """使用预初始化的GPU重采样器进行音频重采样 + + Args: + audio_tensor: 输入音频张量 (channels, samples) 或 (samples,) + device: 设备 ('cuda' 或 'cpu') + + Returns: + 重采样后的音频张量 + """ + global gpu_resampler, cpu_resampler + + try: + # 确保数据类型为 float32 + if audio_tensor.dtype != torch.float32: + audio_tensor = audio_tensor.float() + + # 确保音频在 GPU 上 + if not audio_tensor.is_cuda and device == 'cuda': + audio_tensor = audio_tensor.to(device) + + # 确保是 2D 张量 (channels, samples) + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + + # 使用预初始化的 GPU 重采样器 + if gpu_resampler is not None: + resampled = gpu_resampler(audio_tensor) + return resampled + else: + raise RuntimeError("GPU resampler not initialized") + + except Exception as e: + logger.warning(f"GPU resampling failed: {e}, falling back to CPU") + # 降级到 CPU 重采样 + if audio_tensor.is_cuda: + audio_tensor = audio_tensor.cpu() + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + if audio_tensor.dtype != torch.float32: + audio_tensor = audio_tensor.float() + + if cpu_resampler is not None: + return cpu_resampler(audio_tensor) + else: + # 最后的兜底方案 + fallback_resampler = torchaudio.transforms.Resample( + orig_freq=ORIGINAL_SAMPLE_RATE, + new_freq=TARGET_SAMPLE_RATE, + dtype=torch.float32 + ) + return fallback_resampler(audio_tensor) + + +def detect_language(text: str) -> str: + """检测文本语言类型""" + # 统计中文字符数量 + chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff') + # 如果中文字符占比超过30%,认为是中文 + if chinese_chars / len(text) > 0.3: + return 'chinese' + else: + return 'english' + +def synthesize_speech(text: str, speaker: str, model: Any, speed: float = 1.0) -> Tuple[torch.Tensor, float, float]: + """语音合成(仅模型推理,不包含后处理) + + Returns: + Tuple[torch.Tensor, float, float]: (final_wav, inference_time, audio_duration) + """ + try: + # 检查发音人是否存在 + if speaker not in app.state.speakers_info: + raise ValueError(f"Speaker '{speaker}' not found in configuration") + + # 检测语言类型并设置相应参数 + language = detect_language(text) + if language == 'chinese': + current_t_shift = CHINESE_T_SHIFT + current_guidance_scale = CHINESE_GUIDANCE_SCALE + # 中文需要调整速度,因为EmiliaTokenizer对中文的token化产生较少token + # 降低速度以补偿token数量较少的问题 + adjusted_speed = speed * CHINESE_SPEED_ADJUSTMENT + else: + current_t_shift = ENGLISH_T_SHIFT + current_guidance_scale = ENGLISH_GUIDANCE_SCALE + # 英文保持原始速度 + adjusted_speed = speed + + start_time = time.time() + # 使用预生成的发音人特征进行语音合成 + final_wav = model.generate_speech(text, speaker_embeddings[speaker], num_steps=num_steps, t_shift=current_t_shift, + guidance_scale=current_guidance_scale, speed=adjusted_speed, return_smooth=return_smooth) + inference_time = time.time() - start_time + + # 估算音频时长(基于原始采样率) + audio_duration = final_wav.shape[-1] / ORIGINAL_SAMPLE_RATE + + # 返回 GPU 上的 tensor,推理时间和音频时长 + return final_wav, inference_time, audio_duration + + except Exception as e: + logger.error(f"Error synthesizing speech: {str(e)}") + raise + + +def post_process_audio(final_wav: torch.Tensor, inference_time: float, audio_duration: float) -> Tuple[bytes, float]: + """音频后处理(重采样、转换,不占用模型) + + Args: + final_wav: GPU 上的音频 tensor + inference_time: 推理时间 + audio_duration: 音频时长 + + Returns: + Tuple[bytes, float]: (pcm_bytes, rtf) + """ + try: + # GPU 重采样:48kHz -> 16kHz + resampled_wav = resample_audio_gpu(final_wav, device='cuda') + + # 转换为 numpy 数组 + resampled_wav = resampled_wav.cpu().numpy().squeeze() + + # 添加前后静音 + sample_rate = TARGET_SAMPLE_RATE + pre_silence_ms = PRE_SILENCE_MS + post_silence_ms = POST_SILENCE_MS + + # 计算静音的样本数 + pre_silence_samples = int(sample_rate * pre_silence_ms / 1000) + post_silence_samples = int(sample_rate * post_silence_ms / 1000) + + # 创建静音数组并拼接 + if pre_silence_samples > 0 or post_silence_samples > 0: + pre_silence = np.zeros(pre_silence_samples, dtype=np.float32) + post_silence = np.zeros(post_silence_samples, dtype=np.float32) + resampled_wav = np.concatenate([pre_silence, resampled_wav, post_silence]) + + # 重新计算采样后的音频时长 + resampled_duration = len(resampled_wav) / TARGET_SAMPLE_RATE + rtf = inference_time / resampled_duration if resampled_duration > 0 else 0 + + # 转换为 PCM 格式(16-bit signed integer) + resampled_wav = np.clip(resampled_wav, -1.0, 1.0) + pcm_data = (resampled_wav * 32767.0).astype(np.int16) + + # 直接返回 PCM 字节流(无 WAV 头) + pcm_bytes = pcm_data.tobytes() + + return pcm_bytes, rtf + + except Exception as e: + logger.error(f"Error in post-processing audio: {str(e)}") + raise + + +@app.get("/") +async def root(): + """根路径,返回API信息""" + return { + "message": "Welcome to Lux-TTS Server", + "endpoints": { + "http": "/tts/http" + } + } + + +@app.post("/tts/http") +async def tts_http(request: TTSRequest): + """HTTP接口:接收text和speaker参数,返回base64编码的音频数据""" + try: + if not request.text.strip(): + raise HTTPException(status_code=400, detail="Text parameter is required and cannot be empty") + + if not request.speaker.strip(): + raise HTTPException(status_code=400, detail="Speaker parameter is required and cannot be empty") + + # 阶段 1:获取模型实例并进行推理 + model_instance = None + final_wav = None + inference_time = 0 + audio_duration = 0 + instance_id = -1 + + try: + model_instance = await app.state.model_pool.acquire() + instance_id = model_instance.instance_id + + logger.debug(f"Instance {instance_id} acquired for inference") + + # 在线程池中执行模型推理(仅推理) + loop = asyncio.get_event_loop() + final_wav, inference_time, audio_duration = await loop.run_in_executor( + executor, # 使用自定义线程池 + synthesize_speech, + request.text, + request.speaker, + model_instance.model, + request.speed + ) + + # 更新实例统计信息 + model_instance.update_stats(inference_time) + + logger.debug(f"Instance {instance_id} inference completed, releasing...") + + except TimeoutError as te: + logger.error(f"Timeout acquiring model instance: {str(te)}") + raise HTTPException(status_code=503, detail="Server is busy. Please try again later.") + + finally: + # 关键:推理完成后立即释放模型实例 + if model_instance is not None: + app.state.model_pool.release(model_instance) + logger.debug(f"Instance {instance_id} released") + + # 阶段 2:后处理(重采样、转换、编码),不占用模型 + if final_wav is None: + raise HTTPException(status_code=500, detail="TTS synthesis failed") + + # 在线程池中执行后处理 + loop = asyncio.get_event_loop() + audio_bytes, rtf = await loop.run_in_executor( + executor, # 使用自定义线程池 + post_process_audio, + final_wav, + inference_time, + audio_duration + ) + + # 编码为base64 + base64_audio = base64.b64encode(audio_bytes).decode('utf-8') + + return JSONResponse(content={ + "audio": base64_audio, + "format": "pcm", + "sample_rate": TARGET_SAMPLE_RATE, + "bit_depth": 16, + "channels": 1, + "speaker": request.speaker, + "text": request.text, + "speed": request.speed, + "rtf": rtf, + "instance_id": instance_id + }) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error in HTTP TTS endpoint: {str(e)}") + raise HTTPException(status_code=500, detail=f"TTS synthesis failed: {str(e)}") + + +@app.get("/speakers") +async def get_speakers(): + """获取可用的发音人列表""" + if hasattr(app.state, 'speakers_info'): + speakers = list(app.state.speakers_info.keys()) + else: + speakers = [] + + return {"speakers": speakers} + + +@app.get("/health") +async def health_check(): + """健康检查接口""" + try: + if not hasattr(app.state, 'model_pool'): + return JSONResponse( + status_code=503, + content={"status": "unhealthy", "message": "Model pool not initialized"} + ) + + stats = app.state.model_pool.get_stats() + + return { + "status": "healthy", + "model_instances": stats["total_instances"], + "available_instances": stats["available_instances"], + "busy_instances": stats["busy_instances"] + } + except Exception as e: + logger.error(f"Health check failed: {str(e)}") + return JSONResponse( + status_code=503, + content={"status": "unhealthy", "message": str(e)} + ) + + +@app.get("/stats") +async def get_statistics(): + """获取详细的实例池统计信息""" + try: + if not hasattr(app.state, 'model_pool'): + raise HTTPException(status_code=503, detail="Model pool not initialized") + + stats = app.state.model_pool.get_stats() + return stats + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get statistics: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +if __name__ == "__main__": + import uvicorn + import os + + # 从环境变量获取 workers 数量,默认 1 + workers = int(os.getenv('UVICORN_WORKERS', '1')) + + # 多 worker 模式仅适用于 Linux/Unix,Windows 不支持 + if workers > 1 and os.name == 'nt': + logger.warning("Multi-worker mode is not supported on Windows, falling back to single worker") + workers = 1 + + logger.info(f"Starting server with {workers} worker(s)...") + + # 多 worker 模式需要使用导入字符串 + if workers > 1: + uvicorn.run( + "lux_server:app", # 使用导入字符串 + host="0.0.0.0", + port=7000, + workers=workers, + log_level="info" + ) + else: + # 单 worker 模式可以直接使用 app 对象 + uvicorn.run( + app, + host="0.0.0.0", + port=7000, + log_level="info" + ) + diff --git a/zipvoice/luxvoice.py b/zipvoice/luxvoice.py index 1ead207..e627d67 100644 --- a/zipvoice/luxvoice.py +++ b/zipvoice/luxvoice.py @@ -1,62 +1,62 @@ -import torch -from zipvoice.modeling_utils import process_audio, generate, load_models_gpu, load_models_cpu -from zipvoice.onnx_modeling import generate_cpu - -class LuxTTS: - """ - LuxTTS class for encoding prompt and generating speech on cpu/cuda/mps. - """ - - def __init__(self, model_path='YatharthS/LuxTTS', device='cuda', threads=4): - if model_path == 'YatharthS/LuxTTS': - model_path = None - - # Auto-detect better device if cuda is requested but not available - if device == 'cuda' and not torch.cuda.is_available(): - if torch.backends.mps.is_available(): - print("CUDA not available, switching to MPS") - device = 'mps' - else: - print("CUDA not available, switching to CPU") - device = 'cpu' - - if device == 'cpu': - model, feature_extractor, vocos, tokenizer, transcriber = load_models_cpu(model_path, threads) - print("Loading model on CPU") - else: - model, feature_extractor, vocos, tokenizer, transcriber = load_models_gpu(model_path, device=device) - print("Loading model on GPU") - - self.model = model - self.feature_extractor = feature_extractor - self.vocos = vocos - self.tokenizer = tokenizer - self.transcriber = transcriber - self.device = device - self.vocos.freq_range = 12000 - - - - def encode_prompt(self, prompt_audio, duration=5, rms=0.001): - """encodes audio prompt according to duration and rms(volume control)""" - prompt_tokens, prompt_features_lens, prompt_features, prompt_rms = process_audio(prompt_audio, self.transcriber, self.tokenizer, self.feature_extractor, self.device, target_rms=rms, duration=duration) - encode_dict = {"prompt_tokens": prompt_tokens, 'prompt_features_lens': prompt_features_lens, 'prompt_features': prompt_features, 'prompt_rms': prompt_rms} - - return encode_dict - - def generate_speech(self, text, encode_dict, num_steps=4, guidance_scale=3.0, t_shift=0.5, speed=1.0, return_smooth=False): - """encodes text and generates speech using flow matching model according to steps, guidance scale, and t_shift(like temp)""" - - prompt_tokens, prompt_features_lens, prompt_features, prompt_rms = encode_dict.values() - - if return_smooth == True: - self.vocos.return_48k = False - else: - self.vocos.return_48k = True - - if self.device == 'cpu': - final_wav = generate_cpu(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, self.model, self.vocos, self.tokenizer, num_step=num_steps, guidance_scale=guidance_scale, t_shift=t_shift, speed=speed) - else: - final_wav = generate(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, self.model, self.vocos, self.tokenizer, num_step=num_steps, guidance_scale=guidance_scale, t_shift=t_shift, speed=speed) - - return final_wav.cpu() +import torch +from zipvoice.modeling_utils import process_audio, generate, load_models_gpu, load_models_cpu +from zipvoice.onnx_modeling import generate_cpu + +class LuxTTS: + """ + LuxTTS class for encoding prompt and generating speech on cpu/cuda/mps. + """ + + def __init__(self, model_path='YatharthS/LuxTTS', device='cuda', threads=4): + if model_path == 'YatharthS/LuxTTS': + model_path = None + + # Auto-detect better device if cuda is requested but not available + if device == 'cuda' and not torch.cuda.is_available(): + if torch.backends.mps.is_available(): + print("CUDA not available, switching to MPS") + device = 'mps' + else: + print("CUDA not available, switching to CPU") + device = 'cpu' + + if device == 'cpu': + model, feature_extractor, vocos, tokenizer, transcriber = load_models_cpu(model_path, threads) + print("Loading model on CPU") + else: + model, feature_extractor, vocos, tokenizer, transcriber = load_models_gpu(model_path, device=device) + print("Loading model on GPU") + + self.model = model + self.feature_extractor = feature_extractor + self.vocos = vocos + self.tokenizer = tokenizer + self.transcriber = transcriber + self.device = device + self.vocos.freq_range = 12000 + + + + def encode_prompt(self, prompt_audio, duration=5, rms=0.001, text=None): + """encodes audio prompt according to duration and rms(volume control)""" + prompt_tokens, prompt_features_lens, prompt_features, prompt_rms = process_audio(prompt_audio, self.transcriber, self.tokenizer, self.feature_extractor, self.device, target_rms=rms, duration=duration, text=text) + encode_dict = {"prompt_tokens": prompt_tokens, 'prompt_features_lens': prompt_features_lens, 'prompt_features': prompt_features, 'prompt_rms': prompt_rms} + + return encode_dict + + def generate_speech(self, text, encode_dict, num_steps=4, guidance_scale=3.0, t_shift=0.5, speed=1.0, return_smooth=False): + """encodes text and generates speech using flow matching model according to steps, guidance scale, and t_shift(like temp)""" + + prompt_tokens, prompt_features_lens, prompt_features, prompt_rms = encode_dict.values() + + if return_smooth == True: + self.vocos.return_48k = False + else: + self.vocos.return_48k = True + + if self.device == 'cpu': + final_wav = generate_cpu(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, self.model, self.vocos, self.tokenizer, num_step=num_steps, guidance_scale=guidance_scale, t_shift=t_shift, speed=speed) + else: + final_wav = generate(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, self.model, self.vocos, self.tokenizer, num_step=num_steps, guidance_scale=guidance_scale, t_shift=t_shift, speed=speed) + + return final_wav.cpu() diff --git a/zipvoice/modeling_utils.py b/zipvoice/modeling_utils.py index f4e621d..ac34053 100644 --- a/zipvoice/modeling_utils.py +++ b/zipvoice/modeling_utils.py @@ -1,157 +1,164 @@ -import argparse -import datetime as dt -import json -import logging -import os -from pathlib import Path -from typing import Optional - -import numpy as np -import safetensors.torch -import torch -import librosa -import torchaudio -from transformers import pipeline -from huggingface_hub import snapshot_download -from lhotse.utils import fix_random_seed - -from zipvoice.models.zipvoice_distill import ZipVoiceDistill -from zipvoice.tokenizer.tokenizer import EmiliaTokenizer -from zipvoice.utils.checkpoint import load_checkpoint -from zipvoice.utils.common import AttributeDict, str2bool -from zipvoice.utils.feature import VocosFbank -from zipvoice.utils.infer import rms_norm - -from dataclasses import dataclass, field -from typing import Optional, List - -from linacodec.vocoder.vocos import Vocos -from zipvoice.onnx_modeling import OnnxModel -from torch.nn.utils import parametrize - - -@dataclass -class LuxTTSConfig: - # Model Setup - model_dir: Optional[str] = None - checkpoint_name: str = "model.pt" - vocoder_path: Optional[str] = None - trt_engine_path: Optional[str] = None - - # Tokenizer & Language - tokenizer: str = "emilia" # choices: ["emilia", "libritts", "espeak", "simple"] - lang: str = "en-us" - - -@torch.inference_mode -def process_audio(audio, transcriber, tokenizer, feature_extractor, device, target_rms=0.1, duration=4, feat_scale=0.1): - prompt_wav, sr = librosa.load(audio, sr=24000, duration=duration) - prompt_wav2, sr = librosa.load(audio, sr=16000, duration=duration) - prompt_text = transcriber(prompt_wav2)["text"] - print(prompt_text) - - prompt_wav = torch.from_numpy(prompt_wav).unsqueeze(0) - prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms) - - prompt_features = feature_extractor.extract( - prompt_wav, sampling_rate=24000 - ).to(device) - prompt_features = prompt_features.unsqueeze(0) * feat_scale - prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device) - prompt_tokens = tokenizer.texts_to_token_ids([prompt_text]) - return prompt_tokens, prompt_features_lens, prompt_features, prompt_rms - -def generate(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, model, vocoder, tokenizer, num_step=4, guidance_scale=3.0, speed=1.0, t_shift=0.5, target_rms=0.1): - tokens = tokenizer.texts_to_token_ids([text]) - device = next(model.parameters()).device # Auto-detect device - - speed = speed * 1.3 - - with torch.inference_mode(): - (pred_features, _, _, _) = model.sample( - tokens=tokens, - prompt_tokens=prompt_tokens, - prompt_features=prompt_features, - prompt_features_lens=prompt_features_lens, - speed=speed, - t_shift=t_shift, - duration='predict', - num_step=num_step, - guidance_scale=guidance_scale, - ) - - # Convert to waveform - pred_features = pred_features.permute(0, 2, 1) / 0.1 - wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1) - - # Volume matching - if prompt_rms < target_rms: - wav = wav * (prompt_rms / target_rms) - - return wav - -def load_models_gpu(model_path=None, device="cuda"): - params = LuxTTSConfig() - if model_path is None: - model_path = snapshot_download("YatharthS/LuxTTS") - - token_file = f"{model_path}/tokens.txt" - model_ckpt = f"{model_path}/model.pt" - model_config = f"{model_path}/config.json" - - transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device) - tokenizer = EmiliaTokenizer(token_file=token_file) - tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id} - - with open(model_config, "r") as f: - model_config = json.load(f) - - model = ZipVoiceDistill( - **model_config["model"], - **tokenizer_config, - ) - load_checkpoint(filename=model_ckpt, model=model, strict=True) - params.device = torch.device(device, 0) - - model = model.to(params.device).eval() - feature_extractor = VocosFbank() - - vocos = Vocos.from_hparams(f'{model_path}/vocoder/config.yaml').to(device) - parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[0], "weight") - parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[1], "weight") - vocos.load_state_dict(torch.load(f'{model_path}/vocoder/vocos.bin', map_location=params.device)) - - params.sampling_rate = model_config["feature"]["sampling_rate"] - return model, feature_extractor, vocos, tokenizer, transcriber - -def load_models_cpu(model_path = None, num_thread=2): - params = LuxTTSConfig() - params.seed = 42 - - model_path = snapshot_download('YatharthS/LuxTTS') - - token_file = f"{model_path}/tokens.txt" - text_encoder_path = f"{model_path}/text_encoder.onnx" - fm_decoder_path = f"{model_path}/fm_decoder.onnx" - model_config = f"{model_path}/config.json" - - transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-tiny", device='cpu') - - tokenizer = EmiliaTokenizer(token_file=token_file) - tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id} - - with open(model_config, "r") as f: - model_config = json.load(f) - - model = OnnxModel(text_encoder_path, fm_decoder_path, num_thread=num_thread) - - vocos = Vocos.from_hparams(f'{model_path}/vocoder/config.yaml').eval() - parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[0], "weight") - parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[1], "weight") - vocos.load_state_dict(torch.load(f'{model_path}/vocoder/vocos.bin', map_location=torch.device('cpu'))) - - feature_extractor = VocosFbank() - - params.sampling_rate = model_config["feature"]["sampling_rate"] - params.onnx_int8 = True - return model, feature_extractor, vocos, tokenizer, transcriber +import argparse +import datetime as dt +import json +import logging +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import safetensors.torch +import torch +import librosa +import torchaudio +from transformers import pipeline +from huggingface_hub import snapshot_download +from lhotse.utils import fix_random_seed + +from zipvoice.models.zipvoice_distill import ZipVoiceDistill +from zipvoice.tokenizer.tokenizer import EmiliaTokenizer +from zipvoice.utils.checkpoint import load_checkpoint +from zipvoice.utils.common import AttributeDict, str2bool +from zipvoice.utils.feature import VocosFbank +from zipvoice.utils.infer import rms_norm + +from dataclasses import dataclass, field +from typing import Optional, List + +from linacodec.vocoder.vocos import Vocos +from zipvoice.onnx_modeling import OnnxModel +from torch.nn.utils import parametrize + + +@dataclass +class LuxTTSConfig: + # Model Setup + model_dir: Optional[str] = None + checkpoint_name: str = "model.pt" + vocoder_path: Optional[str] = None + trt_engine_path: Optional[str] = None + + # Tokenizer & Language + tokenizer: str = "emilia" # choices: ["emilia", "libritts", "espeak", "simple"] + lang: str = "en-us" + + +@torch.inference_mode +def process_audio(audio, transcriber, tokenizer, feature_extractor, device, target_rms=0.1, duration=4, feat_scale=0.1, text=None): + prompt_wav, sr = librosa.load(audio, sr=24000, duration=duration) + + # 如果提供了文本,使用它作为参考文本 + if text is not None: + prompt_text = text + print(f"Using provided text: {prompt_text}") + else: + # 否则使用Whisper自动识别 + prompt_wav2, sr = librosa.load(audio, sr=16000, duration=duration) + prompt_text = transcriber(prompt_wav2)["text"] + print(f"Using Whisper-recognized text: {prompt_text}") + + prompt_wav = torch.from_numpy(prompt_wav).unsqueeze(0) + prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms) + + prompt_features = feature_extractor.extract( + prompt_wav, sampling_rate=24000 + ).to(device) + prompt_features = prompt_features.unsqueeze(0) * feat_scale + prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device) + prompt_tokens = tokenizer.texts_to_token_ids([prompt_text]) + return prompt_tokens, prompt_features_lens, prompt_features, prompt_rms + +def generate(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, model, vocoder, tokenizer, num_step=4, guidance_scale=3.0, speed=1.0, t_shift=0.5, target_rms=0.1): + tokens = tokenizer.texts_to_token_ids([text]) + device = next(model.parameters()).device # Auto-detect device + + speed = speed * 1.3 + + with torch.inference_mode(): + (pred_features, _, _, _) = model.sample( + tokens=tokens, + prompt_tokens=prompt_tokens, + prompt_features=prompt_features, + prompt_features_lens=prompt_features_lens, + speed=speed, + t_shift=t_shift, + duration='predict', + num_step=num_step, + guidance_scale=guidance_scale, + ) + + # Convert to waveform + pred_features = pred_features.permute(0, 2, 1) / 0.1 + wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1) + + # Volume matching + if prompt_rms < target_rms: + wav = wav * (prompt_rms / target_rms) + + return wav + +def load_models_gpu(model_path=None, device="cuda"): + params = LuxTTSConfig() + if model_path is None: + model_path = snapshot_download("YatharthS/LuxTTS") + + token_file = f"{model_path}/tokens.txt" + model_ckpt = f"{model_path}/model.pt" + model_config = f"{model_path}/config.json" + + transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device) + tokenizer = EmiliaTokenizer(token_file=token_file) + tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id} + + with open(model_config, "r") as f: + model_config = json.load(f) + + model = ZipVoiceDistill( + **model_config["model"], + **tokenizer_config, + ) + load_checkpoint(filename=model_ckpt, model=model, strict=True) + params.device = torch.device(device, 0) + + model = model.to(params.device).eval() + feature_extractor = VocosFbank() + + vocos = Vocos.from_hparams(f'{model_path}/vocoder/config.yaml').to(device) + parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[0], "weight") + parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[1], "weight") + vocos.load_state_dict(torch.load(f'{model_path}/vocoder/vocos.bin', map_location=params.device)) + + params.sampling_rate = model_config["feature"]["sampling_rate"] + return model, feature_extractor, vocos, tokenizer, transcriber + +def load_models_cpu(model_path = None, num_thread=2): + params = LuxTTSConfig() + params.seed = 42 + + model_path = snapshot_download('YatharthS/LuxTTS') + + token_file = f"{model_path}/tokens.txt" + text_encoder_path = f"{model_path}/text_encoder.onnx" + fm_decoder_path = f"{model_path}/fm_decoder.onnx" + model_config = f"{model_path}/config.json" + + transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-tiny", device='cpu') + + tokenizer = EmiliaTokenizer(token_file=token_file) + tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id} + + with open(model_config, "r") as f: + model_config = json.load(f) + + model = OnnxModel(text_encoder_path, fm_decoder_path, num_thread=num_thread) + + vocos = Vocos.from_hparams(f'{model_path}/vocoder/config.yaml').eval() + parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[0], "weight") + parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[1], "weight") + vocos.load_state_dict(torch.load(f'{model_path}/vocoder/vocos.bin', map_location=torch.device('cpu'))) + + feature_extractor = VocosFbank() + + params.sampling_rate = model_config["feature"]["sampling_rate"] + params.onnx_int8 = True + return model, feature_extractor, vocos, tokenizer, transcriber diff --git a/zipvoice/models/modules/zipformer.py b/zipvoice/models/modules/zipformer.py index aed6cd3..a975759 100644 --- a/zipvoice/models/modules/zipformer.py +++ b/zipvoice/models/modules/zipformer.py @@ -1,1687 +1,1687 @@ -#!/usr/bin/env python3 -# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao, -# Wei Kang -# Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import math -import random -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor, nn - -if torch.cuda.is_available(): - DEVICE_TYPE = "cuda" -elif torch.backends.mps.is_available(): - DEVICE_TYPE = "mps" -else: - DEVICE_TYPE = "cpu" - -from zipvoice.models.modules.scaling import ( - ActivationDropoutAndLinear, - Balancer, - BiasNorm, - Dropout2, - FloatLike, - Identity, - ScaledLinear, - ScheduledFloat, - SwooshR, - Whiten, - limit_param_value, - penalize_abs_values_gt, - softmax, -) - - -def timestep_embedding(timesteps, dim, max_period=10000): - """Create sinusoidal timestep embeddings. - - :param timesteps: shape of (N) or (N, T) - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim) - """ - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) - / half - ) - - if timesteps.dim() == 2: - timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N) - - args = timesteps[..., None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1) - return embedding - - -class TTSZipformer(nn.Module): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same - length as downsampling_factor if they are single ints or one-element tuples. - The length of downsampling_factor defines the number of stacks. - - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, - one per encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection - per attention head - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to - projection, e.g. 128. - - dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. - use_time_embed: (bool): if True, take time embedding as an additional input. - time_embed_dim: (int): the dimension of the time embedding. - use_guidance_scale_embed (bool): if True, take guidance scale embedding as - an additional input. - guidance_scale_embed_dim: (int): the dimension of the guidance scale embedding. - """ - - def __init__( - self, - in_dim: int, - out_dim: int, - downsampling_factor: Union[int, Tuple[int]] = (2, 4), - num_encoder_layers: Union[int, Tuple[int]] = 4, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - encoder_dim: int = 384, - query_head_dim: int = 24, - pos_head_dim: int = 4, - value_head_dim: int = 12, - num_heads: int = 8, - feedforward_dim: int = 1536, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - use_time_embed: bool = True, - time_embed_dim: int = 192, - use_guidance_scale_embed: bool = False, - guidance_scale_embed_dim: int = 192, - use_conv: bool = True, - ) -> None: - super(TTSZipformer, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) - if isinstance(downsampling_factor, int): - downsampling_factor = (downsampling_factor,) - - def _to_tuple(x): - """Converts a single int or a 1-tuple of an int to a tuple with the same - length as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - def _assert_downsampling_factor(factors): - """assert downsampling_factor follows u-net style""" - assert factors[0] == 1 and factors[-1] == 1 - - for i in range(1, len(factors) // 2 + 1): - assert factors[i] == factors[i - 1] * 2 - - for i in range(len(factors) // 2 + 1, len(factors)): - assert factors[i] * 2 == factors[i - 1] - - _assert_downsampling_factor(downsampling_factor) - self.downsampling_factor = downsampling_factor # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - self.encoder_dim = encoder_dim - self.num_encoder_layers = num_encoder_layers - self.query_head_dim = query_head_dim - self.value_head_dim = value_head_dim - self.num_heads = num_heads - - self.use_time_embed = use_time_embed - self.use_guidance_scale_embed = use_guidance_scale_embed - - self.time_embed_dim = time_embed_dim - if self.use_time_embed: - assert time_embed_dim != -1 - else: - time_embed_dim = -1 - self.guidance_scale_embed_dim = guidance_scale_embed_dim - - self.in_proj = nn.Linear(in_dim, encoder_dim) - self.out_proj = nn.Linear(encoder_dim, out_dim) - - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder - encoders = [] - - num_encoders = len(downsampling_factor) - for i in range(num_encoders): - encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim, - pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - value_head_dim=value_head_dim, - feedforward_dim=feedforward_dim, - use_conv=use_conv, - cnn_module_kernel=cnn_module_kernel[i], - dropout=dropout, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = Zipformer2Encoder( - encoder_layer, - num_encoder_layers[i], - embed_dim=encoder_dim, - time_embed_dim=time_embed_dim, - pos_dim=pos_dim, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), - ) - - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim, - downsample=downsampling_factor[i], - ) - - encoders.append(encoder) - - self.encoders = nn.ModuleList(encoders) - if self.use_time_embed: - self.time_embed = nn.Sequential( - nn.Linear(time_embed_dim, time_embed_dim * 2), - SwooshR(), - nn.Linear(time_embed_dim * 2, time_embed_dim), - ) - else: - self.time_embed = None - - if self.use_guidance_scale_embed: - self.guidance_scale_embed = ScaledLinear( - guidance_scale_embed_dim, - time_embed_dim, - bias=False, - initial_scale=0.1, - ) - else: - self.guidance_scale_embed = None - - def forward( - self, - x: Tensor, - t: Optional[Tensor] = None, - padding_mask: Optional[Tensor] = None, - guidance_scale: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - t: - A t tensor of shape (batch_size,) or (batch_size, seq_len) - padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - guidance_scale: - The guidance scale in classifier-free guidance of distillation model. - Returns: - Return the output embeddings. its shape is - (batch_size, output_seq_len, encoder_dim) - """ - x = x.permute(1, 0, 2) - x = self.in_proj(x) - - if t is not None: - assert t.dim() == 1 or t.dim() == 2, t.shape - time_emb = timestep_embedding(t, self.time_embed_dim) - if guidance_scale is not None: - assert ( - guidance_scale.dim() == 1 or guidance_scale.dim() == 2 - ), guidance_scale.shape - guidance_scale_emb = self.guidance_scale_embed( - timestep_embedding(guidance_scale, self.guidance_scale_embed_dim) - ) - time_emb = time_emb + guidance_scale_emb - time_emb = self.time_embed(time_emb) - else: - time_emb = None - - attn_mask = None - - for i, module in enumerate(self.encoders): - x = module( - x, - time_emb=time_emb, - src_key_padding_mask=padding_mask, - attn_mask=attn_mask, - ) - x = self.out_proj(x) - x = x.permute(1, 0, 2) - return x - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) - - -class Zipformer2EncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (required). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module (default=31). - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - use_conv: bool = True, - attention_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - conv_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - const_attention_rate: FloatLike = ScheduledFloat( - (0.0, 0.25), (4000.0, 0.025), default=0 - ), - ff2_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - ff3_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - bypass_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.5), (4000.0, 0.02), default=0 - ), - ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - - # self.bypass implements layer skipping as well as bypass. - self.bypass = BypassModule( - embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 - ) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) - - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) - - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. - self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) - - self.const_attention_rate = copy.deepcopy(const_attention_rate) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, - pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - dropout=0.0, - ) - - self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) - - self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) - - self.feed_forward1 = FeedforwardModule( - embed_dim, (feedforward_dim * 3) // 4, dropout - ) - - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule( - embed_dim, (feedforward_dim * 5) // 4, dropout - ) - - self.nonlin_attention = NonlinAttention( - embed_dim, hidden_channels=3 * embed_dim // 4 - ) - - self.use_conv = use_conv - - if self.use_conv: - self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel) - - self.norm = BiasNorm(embed_dim) - - self.balancer1 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.2, - max_abs=4.0, - ) - - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) - - # balancer for output of feedforward2, prevent it from staying too - # small. give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), - max_abs=2.0, - prob=0.05, - ) - - self.balancer_ff3 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.balancer2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.1, - max_abs=4.0, - ) - - def get_sequence_dropout_mask( - self, x: Tensor, dropout_rate: float - ) -> Optional[Tensor]: - if ( - dropout_rate == 0.0 - or not self.training - or torch.jit.is_scripting() - or torch.jit.is_tracing() - ): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. - x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - time_emb: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): - shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or - (batch_size, 2*seq_len-1, pos_emb_dim) - time_emb: the embedding representing the current timestep - shape (batch_size, embedding_dim) or (seq_len, batch_size, embedding_dim). - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) - or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) - or (tgt_seq_len, src_seq_len). True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); - True means masked position. May be None. - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # dropout rate for non-feedforward submodules - if torch.jit.is_scripting() or torch.jit.is_tracing(): - attention_skip_rate = 0.0 - else: - attention_skip_rate = ( - float(self.attention_skip_rate) if self.training else 0.0 - ) - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - if time_emb is not None: - - src = src + time_emb - - src = src + self.feed_forward1(src) - - self_attn_dropout_mask = self.get_sequence_dropout_mask( - src, attention_skip_rate - ) - - selected_attn_weights = attn_weights[0:1] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. - # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to( - selected_attn_weights.dtype - ) - selected_attn_weights = selected_attn_weights * ( - 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) - ) - - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - - src = src + ( - na if self_attn_dropout_mask is None else na * self_attn_dropout_mask - ) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if self.use_conv: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - - if time_emb is not None: - src = src + time_emb - - src = src + self.sequence_dropout( - self.conv_module1( - src, - src_key_padding_mask=src_key_padding_mask, - ), - conv_skip_rate, - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff2_skip_rate = 0.0 - else: - ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate - ) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn = self.self_attn2(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if self.use_conv: - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - - if time_emb is not None: - src = src + time_emb - - src = src + self.sequence_dropout( - self.conv_module2( - src, - src_key_padding_mask=src_key_padding_mask, - ), - conv_skip_rate, - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff3_skip_rate = 0.0 - else: - ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate - ) - - src = self.balancer1(src) - src = self.norm(src) - - src = self.bypass(src_orig, src) - - src = self.balancer2(src) - src = self.whiten(src) - - return src - - -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - pos_dim: the dimension for the relative positional encoding - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - embed_dim: int, - time_embed_dim: int, - pos_dim: int, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, - ) -> None: - super().__init__() - self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, dropout_rate=0.15, length_factor=1.0 - ) - if time_embed_dim != -1: - self.time_emb = nn.Sequential( - SwooshR(), - nn.Linear(time_embed_dim, embed_dim), - ) - else: - self.time_emb = None - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert 0 <= warmup_begin <= warmup_end - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat( - (cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0, - ) - cur_begin = cur_end - - def forward( - self, - src: Tensor, - time_emb: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): - shape (seq_len, batch_size, embedding_dim). - time_emb: the embedding representing the current timestep: - shape (batch_size, embedding_dim) - or (seq_len, batch_size, embedding_dim) . - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) - or (seq_len, seq_len), interpreted as - (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); - True means masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - pos_emb = self.encoder_pos(src) - if self.time_emb is not None: - assert time_emb is not None - time_emb = self.time_emb(time_emb) - else: - assert time_emb is None - - output = src - - for i, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - time_emb=time_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - return output - - -class BypassModule(nn.Module): - """ - An nn.Module that implements a learnable bypass scale, and also randomized - per-sequence layer-skipping. The bypass is limited during early stages of training - to be close to "straight-through", i.e. to not do the bypass operation much - initially, in order to force all the modules to learn something. - """ - - def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0, - ): - super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 corresponds to bypassing - # this module. - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value( - self.bypass_scale, - min=float(self.scale_min), - max=float(self.scale_max), - ) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for - # sequences on which we have randomly chosen to do layer-skipping. - straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = ( - torch.rand((batch_size, 1), device=ans.device) - < straight_through_rate - ) - ans = torch.maximum(ans, mask.to(ans.dtype)) - return ans - - def forward(self, src_orig: Tensor, src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale - - -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame - rate, after convolutional downsampling, and then upsampled again at the output, and - combined with the origin input, so that the output has the same shape as the input. - """ - - def __init__(self, encoder: nn.Module, dim: int, downsample: int): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(downsample) - self.num_layers = encoder.num_layers - self.encoder = encoder - self.upsample = SimpleUpsample(downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0) - - def forward( - self, - src: Tensor, - time_emb: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): - shape (seq_len, batch_size, embedding_dim). - time_emb: the embedding representing the current timestep: - shape (batch_size, embedding_dim) - or (seq_len, batch_size, embedding_dim) . - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape - (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) - or (seq_len, seq_len), interpreted as - (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); - True means masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if time_emb is not None and time_emb.dim() == 3: - time_emb = time_emb[::ds] - if attn_mask is not None: - attn_mask = attn_mask[::ds, ::ds] - if src_key_padding_mask is not None: - src_key_padding_mask = src_key_padding_mask[..., ::ds] - - src = self.encoder( - src, - time_emb=time_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - -class SimpleDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum. - """ - - def __init__(self, downsample: int): - super(SimpleDownsample, self).__init__() - - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that just repeats the input. - """ - - def __init__(self, upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able - to encode the important information about the relative position in a relatively - small number of dimensions. The goal is to make it so that small differences between - large relative offsets (e.g. 1000 vs. 1001) make very little difference to the - embedding. Such differences were potentially important when encoding absolute - position, but not important when encoding relative position because there is now no - need to compare two large offsets with each other. - - Our embedding works by projecting the interval [-infinity,infinity] to a finite - interval using the atan() function, before doing the Fourier transform of that fixed - interval. The atan() function would compress the "long tails" too small, making it - hard to distinguish between different magnitudes of large offsets, so we use a - logarithmic function to compress large offsets to a smaller range before applying - atan(). Scalings are chosen in such a way that the embedding can clearly distinguish - individual offsets as long as they are quite close to the origin, e.g. abs(offset) - <= about sqrt(embedding_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - - def __init__( - self, - embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0, embed_dim - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0, length_factor - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more - # resolution for small time offsets but less resolution for large time offsets. - compression_length = self.embed_dim**0.5 - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity - # to infinity; but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which is - # important. - x_compressed = ( - compression_length - * x.sign() - * ((x.abs() + compression_length).log() - math.log(compression_length)) - ) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , - # atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. - - self.pe = pe.to(dtype=x.dtype) - - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. - - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position - encoding. Various other modules consume the resulting attention weights: - see, for example, the SimpleAttention module which allows you to compute - conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language - Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear( - embed_dim, - in_proj_dim, - bias=True, - initial_scale=query_head_dim**-0.25, - ) - - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # add a balancer for the keys that runs with very small probability, and - # tries to enforce that all dimensions have mean around zero. The - # weights produced by this module are invariant to adding a constant to - # the keys, so the derivative of the bias is mathematically zero; but - # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero - # bias because the small numerical roundoff tends to have a non-random - # sign. This module is intended to prevent that. Use a very small - # probability; that should be sufficient to fix the problem. - self.balance_keys = Balancer( - key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025, - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnostics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). - Positions that are True in this mask will be ignored as sources in the - attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or - (batch_size, seq_len, seq_len), interpreted as - ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of - shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim, ( - p.shape[-1], - num_heads, - pos_head_dim, - ) - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - use_pos_scores = False - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # We can't put random.random() in the same line - use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): - use_pos_scores = True - - if use_pos_scores: - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, - # batch, time1, seq_len2) [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of - # pos_scores from relative to absolute position. I don't know whether I - # might have got the time-offsets backwards or not, but let this code define - # which way round it is supposed to be. - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) - else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, seq_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt( - attn_scores, limit=25.0, penalty=1.0e-04, name=self.name - ) - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - batch_size, - seq_len, - ), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif random.random() < 0.001 and not self.training: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.amp.autocast(DEVICE_TYPE, enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.debug( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) - - -class SelfAttention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed - attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) - - self.out_proj = ScaledLinear( - num_heads * value_head_dim, - embed_dim, - bias=True, - initial_scale=0.05, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - Returns: - a tensor with the same shape as x. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - x = self.whiten(x) - - return x - - -class FeedforwardModule(nn.Module): - """Feedforward module in TTSZipformer model.""" - - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) - - self.hidden_balancer = Balancer( - feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( - feedforward_dim, - embed_dim, - activation="SwooshL", - dropout_p=dropout, - dropout_shared_dim=0, - bias=True, - initial_scale=0.1, - ) - - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.hidden_balancer(x) - # out_proj contains SwooshL activation, then dropout, then linear. - x = self.out_proj(x) - x = self.out_whiten(x) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication - by attention weights (borrowed from the attention module) in place of actual - convolution. We also took out the second nonlinearity, the one after the - attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - # balancer that goes before the sigmoid. Have quite a large min_abs value, at - # 2.0, because we noticed that well-trained instances of this module have - # abs-value before the sigmoid starting from about 3, and poorly-trained - # instances of the module have smaller abs values before the sigmoid. - self.balancer = Balancer( - hidden_channels, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear( - hidden_channels, channels, bias=True, initial_scale=0.05 - ) - - self.whiten1 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.whiten2 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - - s = self.balancer(s) - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer2 model. - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - - self.in_proj = nn.Linear( - channels, - 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - # after in_proj we put x through a gated linear unit (nn.functional.glu). For - # most layers the normal rms value of channels of x seems to be in the range 1 - # to 4, but sometimes, for some reason, for layer 0 the rms ends up being very - # large, between 50 and 100 for different channels. This will cause very peaky - # and sparse derivatives for the sigmoid gating function, which will tend to - # make the loss function not learn effectively. (for most layers the average - # absolute values are in the range 0.5..9.0, and the average p(x>0), i.e. - # positive proportion, at the output of pointwise_conv1.output is around 0.35 to - # 0.45 for different layers, which likely breaks down as 0.5 for the "linear" - # half and 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that - # if we constrain the rms values to a reasonable range via a constraint of - # max_abs=10.0, it will be in a better position to start learning something, - # i.e. to latch onto the correct range. - self.balancer1 = Balancer( - bottleneck_dim, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - - self.balancer2 = Balancer( - bottleneck_dim, - channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, - channels, - activation="SwooshR", - dropout_p=0.0, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.balancer1(s) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - x = self.depthwise_conv(x) - - x = self.balancer2(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x +#!/usr/bin/env python3 +# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao, +# Wei Kang +# Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor, nn + +if torch.cuda.is_available(): + DEVICE_TYPE = "cuda" +elif torch.backends.mps.is_available(): + DEVICE_TYPE = "mps" +else: + DEVICE_TYPE = "cpu" + +from zipvoice.models.modules.scaling import ( + ActivationDropoutAndLinear, + Balancer, + BiasNorm, + Dropout2, + FloatLike, + Identity, + ScaledLinear, + ScheduledFloat, + SwooshR, + Whiten, + limit_param_value, + penalize_abs_values_gt, + softmax, +) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """Create sinusoidal timestep embeddings. + + :param timesteps: shape of (N) or (N, T) + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim) + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) + / half + ) + + if timesteps.dim() == 2: + timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N) + + args = timesteps[..., None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1) + return embedding + + +class TTSZipformer(nn.Module): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same + length as downsampling_factor if they are single ints or one-element tuples. + The length of downsampling_factor defines the number of stacks. + + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, + one per encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection + per attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to + projection, e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + use_time_embed: (bool): if True, take time embedding as an additional input. + time_embed_dim: (int): the dimension of the time embedding. + use_guidance_scale_embed (bool): if True, take guidance scale embedding as + an additional input. + guidance_scale_embed_dim: (int): the dimension of the guidance scale embedding. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + downsampling_factor: Union[int, Tuple[int]] = (2, 4), + num_encoder_layers: Union[int, Tuple[int]] = 4, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + encoder_dim: int = 384, + query_head_dim: int = 24, + pos_head_dim: int = 4, + value_head_dim: int = 12, + num_heads: int = 8, + feedforward_dim: int = 1536, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + use_time_embed: bool = True, + time_embed_dim: int = 192, + use_guidance_scale_embed: bool = False, + guidance_scale_embed_dim: int = 192, + use_conv: bool = True, + ) -> None: + super(TTSZipformer, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + if isinstance(downsampling_factor, int): + downsampling_factor = (downsampling_factor,) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same + length as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + def _assert_downsampling_factor(factors): + """assert downsampling_factor follows u-net style""" + assert factors[0] == 1 and factors[-1] == 1 + + for i in range(1, len(factors) // 2 + 1): + assert factors[i] == factors[i - 1] * 2 + + for i in range(len(factors) // 2 + 1, len(factors)): + assert factors[i] * 2 == factors[i - 1] + + _assert_downsampling_factor(downsampling_factor) + self.downsampling_factor = downsampling_factor # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + self.encoder_dim = encoder_dim + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim + self.value_head_dim = value_head_dim + self.num_heads = num_heads + + self.use_time_embed = use_time_embed + self.use_guidance_scale_embed = use_guidance_scale_embed + + self.time_embed_dim = time_embed_dim + if self.use_time_embed: + assert time_embed_dim != -1 + else: + time_embed_dim = -1 + self.guidance_scale_embed_dim = guidance_scale_embed_dim + + self.in_proj = nn.Linear(in_dim, encoder_dim) + self.out_proj = nn.Linear(encoder_dim, out_dim) + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + value_head_dim=value_head_dim, + feedforward_dim=feedforward_dim, + use_conv=use_conv, + cnn_module_kernel=cnn_module_kernel[i], + dropout=dropout, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + embed_dim=encoder_dim, + time_embed_dim=time_embed_dim, + pos_dim=pos_dim, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim, + downsample=downsampling_factor[i], + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + if self.use_time_embed: + self.time_embed = nn.Sequential( + nn.Linear(time_embed_dim, time_embed_dim * 2), + SwooshR(), + nn.Linear(time_embed_dim * 2, time_embed_dim), + ) + else: + self.time_embed = None + + if self.use_guidance_scale_embed: + self.guidance_scale_embed = ScaledLinear( + guidance_scale_embed_dim, + time_embed_dim, + bias=False, + initial_scale=0.1, + ) + else: + self.guidance_scale_embed = None + + def forward( + self, + x: Tensor, + t: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + guidance_scale: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + t: + A t tensor of shape (batch_size,) or (batch_size, seq_len) + padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + guidance_scale: + The guidance scale in classifier-free guidance of distillation model. + Returns: + Return the output embeddings. its shape is + (batch_size, output_seq_len, encoder_dim) + """ + x = x.permute(1, 0, 2) + x = self.in_proj(x) + + if t is not None: + assert t.dim() == 1 or t.dim() == 2, t.shape + time_emb = timestep_embedding(t, self.time_embed_dim) + if guidance_scale is not None: + assert ( + guidance_scale.dim() == 1 or guidance_scale.dim() == 2 + ), guidance_scale.shape + guidance_scale_emb = self.guidance_scale_embed( + timestep_embedding(guidance_scale, self.guidance_scale_embed_dim) + ) + time_emb = time_emb + guidance_scale_emb + time_emb = self.time_embed(time_emb) + else: + time_emb = None + + attn_mask = None + + for i, module in enumerate(self.encoders): + x = module( + x, + time_emb=time_emb, + src_key_padding_mask=padding_mask, + attn_mask=attn_mask, + ) + x = self.out_proj(x) + x = x.permute(1, 0, 2) + return x + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (required). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module (default=31). + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + use_conv: bool = True, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. + self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.use_conv = use_conv + + if self.use_conv: + self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + time_emb: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): + shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or + (batch_size, 2*seq_len-1, pos_emb_dim) + time_emb: the embedding representing the current timestep + shape (batch_size, embedding_dim) or (seq_len, batch_size, embedding_dim). + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) + or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) + or (tgt_seq_len, src_seq_len). True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); + True means masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + if time_emb is not None: + + src = src + time_emb + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if self.use_conv: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + + if time_emb is not None: + src = src + time_emb + + src = src + self.sequence_dropout( + self.conv_module1( + src, + src_key_padding_mask=src_key_padding_mask, + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if self.use_conv: + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + + if time_emb is not None: + src = src + time_emb + + src = src + self.sequence_dropout( + self.conv_module2( + src, + src_key_padding_mask=src_key_padding_mask, + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + embed_dim: int, + time_embed_dim: int, + pos_dim: int, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + if time_embed_dim != -1: + self.time_emb = nn.Sequential( + SwooshR(), + nn.Linear(time_embed_dim, embed_dim), + ) + else: + self.time_emb = None + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + time_emb: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): + shape (seq_len, batch_size, embedding_dim). + time_emb: the embedding representing the current timestep: + shape (batch_size, embedding_dim) + or (seq_len, batch_size, embedding_dim) . + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) + or (seq_len, seq_len), interpreted as + (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); + True means masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + if self.time_emb is not None: + assert time_emb is not None + time_emb = self.time_emb(time_emb) + else: + assert time_emb is None + + output = src + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + time_emb=time_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + return output + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized + per-sequence layer-skipping. The bypass is limited during early stages of training + to be close to "straight-through", i.e. to not do the bypass operation much + initially, in order to force all the modules to learn something. + """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 corresponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, + min=float(self.scale_min), + max=float(self.scale_max), + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for + # sequences on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame + rate, after convolutional downsampling, and then upsampled again at the output, and + combined with the origin input, so that the output has the same shape as the input. + """ + + def __init__(self, encoder: nn.Module, dim: int, downsample: int): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(downsample) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + time_emb: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): + shape (seq_len, batch_size, embedding_dim). + time_emb: the embedding representing the current timestep: + shape (batch_size, embedding_dim) + or (seq_len, batch_size, embedding_dim) . + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape + (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) + or (seq_len, seq_len), interpreted as + (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); + True means masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if time_emb is not None and time_emb.dim() == 3: + time_emb = time_emb[::ds] + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask[..., ::ds] + + src = self.encoder( + src, + time_emb=time_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum. + """ + + def __init__(self, downsample: int): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that just repeats the input. + """ + + def __init__(self, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able + to encode the important information about the relative position in a relatively + small number of dimensions. The goal is to make it so that small differences between + large relative offsets (e.g. 1000 vs. 1001) make very little difference to the + embedding. Such differences were potentially important when encoding absolute + position, but not important when encoding relative position because there is now no + need to compare two large offsets with each other. + + Our embedding works by projecting the interval [-infinity,infinity] to a finite + interval using the atan() function, before doing the Fourier transform of that fixed + interval. The atan() function would compress the "long tails" too small, making it + hard to distinguish between different magnitudes of large offsets, so we use a + logarithmic function to compress large offsets to a smaller range before applying + atan(). Scalings are chosen in such a way that the embedding can clearly distinguish + individual offsets as long as they are quite close to the origin, e.g. abs(offset) + <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0, embed_dim + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0, length_factor + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more + # resolution for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity + # to infinity; but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which is + # important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , + # atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position + encoding. Various other modules consume the resulting attention weights: + see, for example, the SimpleAttention module which allows you to compute + conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language + Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, + in_proj_dim, + bias=True, + initial_scale=query_head_dim**-0.25, + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be sufficient to fix the problem. + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnostics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). + Positions that are True in this mask will be ignored as sources in the + attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or + (batch_size, seq_len, seq_len), interpreted as + ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of + shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim, ( + p.shape[-1], + num_heads, + pos_head_dim, + ) + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, + # batch, time1, seq_len2) [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of + # pos_scores from relative to absolute position. I don't know whether I + # might have got the time-offsets backwards or not, but let this code define + # which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.amp.autocast(DEVICE_TYPE, enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.debug( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed + attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, + embed_dim, + bias=True, + initial_scale=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + +class FeedforwardModule(nn.Module): + """Feedforward module in TTSZipformer model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication + by attention weights (borrowed from the attention module) in place of actual + convolution. We also took out the second nonlinearity, the one after the + attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at + # 2.0, because we noticed that well-trained instances of this module have + # abs-value before the sigmoid starting from about 3, and poorly-trained + # instances of the module have smaller abs values before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). For + # most layers the normal rms value of channels of x seems to be in the range 1 + # to 4, but sometimes, for some reason, for layer 0 the rms ends up being very + # large, between 50 and 100 for different channels. This will cause very peaky + # and sparse derivatives for the sigmoid gating function, which will tend to + # make the loss function not learn effectively. (for most layers the average + # absolute values are in the range 0.5..9.0, and the average p(x>0), i.e. + # positive proportion, at the output of pointwise_conv1.output is around 0.35 to + # 0.45 for different layers, which likely breaks down as 0.5 for the "linear" + # half and 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that + # if we constrain the rms values to a reasonable range via a constraint of + # max_abs=10.0, it will be in a better position to start learning something, + # i.e. to latch onto the correct range. + self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x