-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemo_cli.py
More file actions
executable file
·273 lines (232 loc) · 8.77 KB
/
Copy pathdemo_cli.py
File metadata and controls
executable file
·273 lines (232 loc) · 8.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
#!/usr/bin/env python3
"""
Simple CLI wrapper for F5-TTS inference using demo samples
Usage:
python demo_cli.py --sample 1 --gen-text "Sample text to generate"
python demo_cli.py --sample 2 --gen-sample 3 # Use another sample's text
python demo_cli.py --list-samples # List available samples
"""
import argparse
import json
import os
import sys
# Configuration
# All paths below are relative, so they resolve against the current working
# directory the script is launched from.

# Demo inputs / outputs.
DEMO_DIR = "./demo_samples"
OUTPUT_DIR = "./demo_outputs"
METADATA_FILE = os.path.join(DEMO_DIR, "samples_metadata.json")

# Model definition and tokenizer vocabulary.
MODEL_NAME = "F5TTS_v1_Base"
VOCAB_PATH = "./data/cantonese_data_pinyin/vocab.txt"

# Candidate checkpoints, in priority order: find_checkpoint() returns the first
# entry that exists on disk.
# NOTE(review): the first candidate comes from ``cantonese_data`` while
# VOCAB_PATH is from ``cantonese_data_pinyin`` -- confirm they share a vocab.
CHECKPOINT_OPTIONS = [
    "./ckpts/cantonese_data/model_last.pt",
    "./ckpts/cantonese_data_pinyin/model_last.pt",
    "./ckpts/cantonese_data_pinyin/pretrained_model_1250000.safetensors",
]
def load_samples_metadata():
    """Parse and return the JSON document stored at ``METADATA_FILE``.

    The file is decoded as UTF-8. Errors from ``open`` (e.g. a missing file)
    or ``json.load`` (malformed JSON) propagate to the caller unchanged.
    """
    with open(METADATA_FILE, encoding="utf-8") as handle:
        metadata = json.load(handle)
    return metadata
def find_checkpoint():
    """Return the first entry of ``CHECKPOINT_OPTIONS`` that exists on disk.

    Candidates are probed in list order, so earlier entries win. Returns
    ``None`` when no candidate exists.
    """
    existing = (candidate for candidate in CHECKPOINT_OPTIONS
                if os.path.exists(candidate))
    return next(existing, None)
def list_samples():
    """Print the ID, file name, duration and transcript of every demo sample."""
    # Load before printing anything so a metadata error produces no partial header.
    entries = load_samples_metadata()
    separator = "=" * 70

    print("\n" + separator)
    print("Available Demo Samples")
    print(separator)
    for entry in entries:
        print(f"\nSample {entry['id']}: {entry['filename']}")
        print(f" Duration: {entry['duration']}s")
        print(f" Text: {entry['text']}")
    print("\n" + separator)
def load_inference_modules():
    """Import the heavy inference dependencies on demand.

    ``main`` calls this only after argument validation, so ``--help`` and
    invalid invocations do not pay for importing torch / F5-TTS.

    Returns:
        tuple: ``(torch, sf, load_model, load_vocoder, infer_process)``, where
        ``sf`` is the ``soundfile`` module and the last three come from
        ``src.f5_tts.infer.utils_infer``.
    """
    # F5-TTS is imported as the ``src.f5_tts`` package, so the directory that
    # contains ``src/`` must be on sys.path. Derive it from this file instead of
    # a machine-specific absolute path, and add it only once.
    # NOTE(review): assumes demo_cli.py lives at the repository root next to
    # ``src/`` -- confirm.
    repo_root = os.path.dirname(os.path.abspath(__file__))
    if repo_root not in sys.path:
        sys.path.append(repo_root)

    import torch
    import soundfile as sf
    from src.f5_tts.infer.utils_infer import load_model, load_vocoder, infer_process

    return torch, sf, load_model, load_vocoder, infer_process
def _build_parser():
    """Build the argparse parser for the demo CLI."""
    parser = argparse.ArgumentParser(
        description="F5-TTS Demo CLI - Easy inference with demo samples",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# List available samples
python demo_cli.py --list-samples
# Generate using custom text
python demo_cli.py --sample 1 --gen-text "ni3 hao3 shi4 jie4"
# Generate using another sample's text
python demo_cli.py --sample 2 --gen-sample 3
# Custom output filename and quality settings
python demo_cli.py --sample 1 --gen-sample 2 --output my_test.wav --nfe 64 --cfg 2.5
""",
    )
    parser.add_argument('--list-samples', action='store_true',
                        help='List available demo samples and exit')
    parser.add_argument('--sample', type=int, metavar='N',
                        help='Reference sample ID (1-based, see --list-samples)')
    parser.add_argument('--gen-text', type=str, metavar='TEXT',
                        help='Text to generate (in Pinyin format)')
    parser.add_argument('--gen-sample', type=int, metavar='N',
                        help='Generate text from another sample ID (1-based, see --list-samples)')
    parser.add_argument('--output', type=str, metavar='FILE',
                        default=None,
                        help='Output filename (default: auto-generated in demo_outputs/)')
    parser.add_argument('--nfe', type=int, metavar='N',
                        default=32,
                        help='NFE steps - quality (default: 32, range: 16-64)')
    parser.add_argument('--cfg', type=float, metavar='F',
                        default=2.0,
                        help='CFG strength - text faithfulness (default: 2.0, range: 1.0-3.0)')
    parser.add_argument('--speed', type=float, metavar='F',
                        default=1.0,
                        help='Speech speed multiplier (default: 1.0)')
    return parser


def _resolve_output_path(args):
    """Return ``args.output`` if given, else an auto-generated path under OUTPUT_DIR."""
    if args.output:
        return args.output
    if args.gen_sample is not None:
        filename = f"sample{args.sample}_to_sample{args.gen_sample}.wav"
    else:
        filename = f"sample{args.sample}_custom.wav"
    return os.path.join(OUTPUT_DIR, filename)


def _print_config(args, checkpoint_path, ref_sample, gen_desc, output_path):
    """Echo the resolved inference configuration before the slow work starts."""
    print("=" * 70)
    print("F5-TTS Demo CLI - Inference")
    print("=" * 70)
    print(f"\n📦 Model: {MODEL_NAME}")
    print(f"💾 Checkpoint: {checkpoint_path}")
    print(f"\n🎤 Reference: Sample {args.sample} ({ref_sample['duration']}s)")
    print(f" Text: {ref_sample['text']}")
    print(f"\n✍️ Generating: {gen_desc}")
    print("\n⚙️ Settings:")
    print(f" NFE Steps: {args.nfe}")
    print(f" CFG Strength: {args.cfg}")
    print(f" Speed: {args.speed}x")
    print(f"\n💾 Output: {output_path}")


def _load_model_and_vocoder(load_model, load_vocoder, checkpoint_path, device):
    """Build the DiT model from *checkpoint_path* and load the Vocos vocoder.

    Returns:
        tuple: ``(model, vocoder)`` as returned by ``load_model`` / ``load_vocoder``.
    """
    try:
        from src.f5_tts.Models.backbones.dit import DiT
    except ModuleNotFoundError:
        # NOTE(review): upstream F5-TTS keeps the backbone under
        # ``f5_tts/model/backbones``; fall back to that layout when this
        # fork's ``Models`` package is absent.
        from src.f5_tts.model.backbones.dit import DiT

    # Architecture hyper-parameters must match the checkpoint being loaded.
    if MODEL_NAME == "F5TTS_v1_Base":
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    else:
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=4)

    model = load_model(
        model_cls=DiT,
        model_cfg=model_cfg,
        ckpt_path=checkpoint_path,
        vocab_file=VOCAB_PATH,
        device=device,
    )
    print("✓ Model loaded")
    vocoder = load_vocoder(
        vocoder_name="vocos",
        is_local=False,
        local_path=None,
    )
    print("✓ Vocoder loaded")
    return model, vocoder


def main():
    """Validate CLI arguments, load F5-TTS and synthesize one clip to a WAV file.

    Returns:
        int: process exit status. 0 on success (or after ``--list-samples``).
        1 when the metadata lists no samples, no checkpoint is found, or
        module/model loading or generation fails. Invalid command-line
        arguments exit with status 2 via ``parser.error``.
    """
    import traceback

    parser = _build_parser()
    args = parser.parse_args()

    if args.list_samples:
        list_samples()
        return 0

    # Compare against None (not truthiness) so an explicit 0 is reported as
    # out of range rather than "missing", and an explicit empty --gen-text is
    # rejected instead of silently treated as absent.
    if args.sample is None:
        parser.error("--sample is required (or use --list-samples)")
    has_gen_text = args.gen_text is not None
    has_gen_sample = args.gen_sample is not None
    if not has_gen_text and not has_gen_sample:
        parser.error("Either --gen-text or --gen-sample is required")
    if has_gen_text and has_gen_sample:
        parser.error("Use only one of --gen-text or --gen-sample")
    if has_gen_text and not args.gen_text.strip():
        parser.error("--gen-text must not be empty")
    if args.nfe < 1:
        parser.error("--nfe must be a positive integer")
    if args.speed <= 0:
        parser.error("--speed must be greater than 0")

    # Sample IDs can only be range-checked once the metadata is loaded.
    samples = load_samples_metadata()
    count = len(samples)
    if count == 0:
        print(f"Error: no samples listed in {METADATA_FILE}", file=sys.stderr)
        return 1
    if not 1 <= args.sample <= count:
        parser.error(f"--sample must be between 1 and {count}")
    if has_gen_sample and not 1 <= args.gen_sample <= count:
        parser.error(f"--gen-sample must be between 1 and {count}")

    # IDs are 1-indexed; the metadata list is 0-indexed.
    # NOTE(review): this assumes samples_metadata.json lists samples in ID
    # order (ID N at index N-1) -- confirm.
    ref_sample = samples[args.sample - 1]
    if has_gen_text:
        gen_text = args.gen_text
        gen_desc = f'"{gen_text[:50]}..."' if len(gen_text) > 50 else f'"{gen_text}"'
    else:
        gen_text = samples[args.gen_sample - 1]['text']
        gen_desc = f"Sample {args.gen_sample}'s text"

    checkpoint_path = find_checkpoint()
    if not checkpoint_path:
        print("Error: No checkpoint found!", file=sys.stderr)
        print("Searched for:", file=sys.stderr)
        for path in CHECKPOINT_OPTIONS:
            print(f" - {path}", file=sys.stderr)
        return 1

    output_path = _resolve_output_path(args)
    _print_config(args, checkpoint_path, ref_sample, gen_desc, output_path)

    print("\n" + "-" * 70)
    print("Loading inference modules...")
    print("-" * 70)
    try:
        torch, sf, load_model, load_vocoder, infer_process = load_inference_modules()
    except Exception as e:
        print(f"❌ Error loading modules: {e}", file=sys.stderr)
        print("\nMake sure all dependencies are installed:", file=sys.stderr)
        print(" pip install -r requirement.txt", file=sys.stderr)
        return 1
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading model...")
    try:
        model, vocoder = _load_model_and_vocoder(load_model, load_vocoder, checkpoint_path, device)
    except Exception as e:
        print(f"❌ Error loading model: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1

    print("\n" + "-" * 70)
    print("Generating speech...")
    print("-" * 70)
    try:
        generated_audio, sample_rate, _ = infer_process(
            ref_audio=ref_sample['audio_path'],
            ref_text=ref_sample['text'],
            gen_text=gen_text,
            model_obj=model,
            vocoder=vocoder,
            mel_spec_type="vocos",
            show_info=print,
            progress=None,
            target_rms=0.1,
            cross_fade_duration=0.15,
            nfe_step=args.nfe,
            cfg_strength=args.cfg,
            sway_sampling_coef=-1.0,
            speed=args.speed,
            fix_duration=None,
            device=device,
        )
        # Ensure the destination directory exists (also for a custom --output)
        # so writing does not fail on a missing folder.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        sf.write(output_path, generated_audio, sample_rate)
    except Exception as e:
        print(f"\n❌ Error during generation: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1

    duration = len(generated_audio) / sample_rate
    print("\n" + "=" * 70)
    print("✓ Success!")
    print("=" * 70)
    print(f"\n💾 Saved to: {output_path}")
    print(f"⏱️ Duration: {duration:.2f}s")
    print(f"🎵 Sample rate: {sample_rate}Hz")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so failures are visible to shells and scripts.
    sys.exit(main())