diff --git a/moshi/moshi/offline.py b/moshi/moshi/offline.py
index f690620d..b67a206d 100644
--- a/moshi/moshi/offline.py
+++ b/moshi/moshi/offline.py
@@ -90,7 +90,7 @@ def wrap_with_system_tags(text: str) -> str:
     return f" {cleaned} "
 
 
-def warmup(mimi: MimiModel, other_mimi: MimiModel, lm_gen: LMGen, device: str, frame_size: int):
+def warmup(mimi: MimiModel, lm_gen: LMGen, device: str, frame_size: int):
     """Run a short warmup loop to initialize CUDA graphs and streaming state.
     Replicates the same warmup behavior as server.py: zeros → encode → LMGen.step → decode.
     """
@@ -98,26 +98,23 @@ def warmup(mimi: MimiModel, other_mimi: MimiModel, lm_gen: LMGen, device: str, f
     for _ in range(4):
         chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=device)
         codes = mimi.encode(chunk)
-        _ = other_mimi.encode(chunk)
         for c in range(codes.shape[-1]):
             tokens = lm_gen.step(codes[:, :, c : c + 1])
             if tokens is None:
                 continue
             # Decode agent audio channels to ensure decode graphs/states are primed
             _ = mimi.decode(tokens[:, 1:9])
-            _ = other_mimi.decode(tokens[:, 1:9])
     if torch.cuda.is_available():
         torch.cuda.synchronize()
 
 
-def decode_tokens_to_pcm(mimi: MimiModel, other_mimi: MimiModel, lm_gen: LMGen, tokens: torch.Tensor) -> np.ndarray:
+def decode_tokens_to_pcm(mimi: MimiModel, lm_gen: LMGen, tokens: torch.Tensor) -> np.ndarray:
     """Decode a single step of model tokens to PCM using Mimi.
 
     tokens is shaped [B, dep_q+1, 1]; channels 1..dep_q are the agent audio codebooks.
     Returns a 1D float32 numpy array (mono) for the current frame.
     """
     pcm = mimi.decode(tokens[:, 1:9])
-    _ = other_mimi.decode(tokens[:, 1:9])
     pcm = pcm.detach().cpu().numpy()[0, 0]
     return pcm
 
@@ -191,7 +188,6 @@ def run_inference(
     if mimi_weight is None:
         mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME)  # type: ignore
     mimi = loaders.get_mimi(mimi_weight, device)
-    other_mimi = loaders.get_mimi(mimi_weight, device)
     log("info", "mimi loaded")
 
     # 2) Load tokenizer
@@ -224,12 +220,11 @@ def run_inference(
     )
     # Keep models in streaming mode similar to the server
     mimi.streaming_forever(1)
-    other_mimi.streaming_forever(1)
     lm_gen.streaming_forever(1)
 
     # 5) Warmup
     log("info", "warming up the model")
-    warmup(mimi, other_mimi, lm_gen, device, frame_size)
+    warmup(mimi, lm_gen, device, frame_size)
 
     # 6) Prompt configuration (text + voice)
     # System text tokens (k=0) and agent voice-prompt audio (k=1..dep_q) are forced
@@ -248,7 +243,6 @@ def run_inference(
     # - Text prompt injection
     # - Final audio silence
     mimi.reset_streaming()
-    other_mimi.reset_streaming()
     lm_gen.reset_streaming()
     lm_gen.step_system_prompts(mimi)
     # Reset mimi streaming after voice prompt encoding
@@ -280,7 +274,7 @@ def run_inference(
             if tokens is None:
                 continue
             # Decode current sampled agent frame to PCM
-            pcm = decode_tokens_to_pcm(mimi, other_mimi, lm_gen, tokens)
+            pcm = decode_tokens_to_pcm(mimi, lm_gen, tokens)
             generated_frames.append(pcm)
             # Decode text token
             text_token = tokens[0, 0, 0].item()
@@ -428,4 +422,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py
index 771f491d..3224345e 100644
--- a/moshi/moshi/server.py
+++ b/moshi/moshi/server.py
@@ -89,16 +89,14 @@ def wrap_with_system_tags(text: str) -> str:
 @dataclass
 class ServerState:
     mimi: MimiModel
-    other_mimi: MimiModel
     text_tokenizer: sentencepiece.SentencePieceProcessor
     lm_gen: LMGen
     lock: asyncio.Lock
 
-    def __init__(self, mimi: MimiModel, other_mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
+    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
                  lm: LMModel, device: str | torch.device, voice_prompt_dir: str | None = None,
                  save_voice_prompt_embeddings: bool = False):
         self.mimi = mimi
-        self.other_mimi = other_mimi
         self.text_tokenizer = text_tokenizer
         self.device = device
         self.voice_prompt_dir = voice_prompt_dir
@@ -113,20 +111,17 @@ def __init__(self, mimi: MimiModel, other_mimi: MimiModel, text_tokenizer: sente
         self.lock = asyncio.Lock()
 
         self.mimi.streaming_forever(1)
-        self.other_mimi.streaming_forever(1)
         self.lm_gen.streaming_forever(1)
 
     def warmup(self):
         for _ in range(4):
             chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
             codes = self.mimi.encode(chunk)
-            _ = self.other_mimi.encode(chunk)
             for c in range(codes.shape[-1]):
                 tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                 if tokens is None:
                     continue
                 _ = self.mimi.decode(tokens[:, 1:9])
-                _ = self.other_mimi.decode(tokens[:, 1:9])
         if self.device.type == 'cuda':
             torch.cuda.synchronize()
 
@@ -222,14 +217,12 @@ async def opus_loop():
                     chunk = torch.from_numpy(chunk)
                     chunk = chunk.to(device=self.device)[None, None]
                     codes = self.mimi.encode(chunk)
-                    _ = self.other_mimi.encode(chunk)
                     for c in range(codes.shape[-1]):
                         tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                         if tokens is None:
                             continue
                         assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
                         main_pcm = self.mimi.decode(tokens[:, 1:9])
-                        _ = self.other_mimi.decode(tokens[:, 1:9])
                         main_pcm = main_pcm.cpu()
                         opus_writer.append_pcm(main_pcm[0, 0].numpy())
                         text_token = tokens[0, 0, 0].item()
@@ -263,7 +256,6 @@ async def send_loop():
         opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
         opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
         self.mimi.reset_streaming()
-        self.other_mimi.reset_streaming()
         self.lm_gen.reset_streaming()
         async def is_alive():
             if close or ws.closed:
@@ -432,7 +424,6 @@ def main():
     if args.mimi_weight is None:
         args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
     mimi = loaders.get_mimi(args.mimi_weight, args.device)
-    other_mimi = loaders.get_mimi(args.mimi_weight, args.device)
     logger.info("mimi loaded")
 
     if args.tokenizer is None:
@@ -447,7 +438,6 @@ def main():
     logger.info("moshi loaded")
     state = ServerState(
         mimi=mimi,
-        other_mimi=other_mimi,
         text_tokenizer=text_tokenizer,
         lm=lm,
         device=args.device,
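
Note: below is a minimal, illustrative sketch (not part of the patch) of the single-codec streaming loop that offline.py and server.py now share. `stream_frames` and its `frames` argument are hypothetical names introduced here for illustration; `mimi`, `lm_gen`, and the token layout ([B, dep_q+1, 1], channel 0 = text, channels 1..8 = agent audio codebooks) follow the code above.

    import torch

    def stream_frames(mimi, lm_gen, frames, device):
        # Hypothetical driver: encode each input frame once, step the LM,
        # and decode the agent audio once -- no second, discarded Mimi pass.
        pcms = []
        for frame in frames:  # each frame: [1, 1, frame_size] float32 tensor
            codes = mimi.encode(frame.to(device))
            for c in range(codes.shape[-1]):
                tokens = lm_gen.step(codes[:, :, c : c + 1])
                if tokens is None:  # model has not emitted a frame yet
                    continue
                # Channels 1..8 are the agent audio codebooks; channel 0 is text.
                pcm = mimi.decode(tokens[:, 1:9])
                pcms.append(pcm.detach().cpu().numpy()[0, 0])
        return pcms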