1414import torch
1515import torchaudio as ta
1616from chatterbox .tts import ChatterboxTTS
17-
17+ from chatterbox . mtl_tts import ChatterboxMultilingualTTS
1818import grpc
1919
def is_float(s):
    """Return True if *s* can be converted to ``float``, otherwise False.

    Anything accepted by the built-in ``float()`` counts, which includes
    integer-looking strings such as ``"3"`` and the special spellings
    ``"nan"`` / ``"inf"``.

    Args:
        s: The value to probe, normally an option value read as a string.

    Returns:
        bool: True when ``float(s)`` succeeds, False when it fails.
    """
    try:
        float(s)
    except (TypeError, ValueError, OverflowError):
        # ValueError: malformed text; TypeError: None or another
        # unsupported type; OverflowError: an int too large for a float.
        return False
    return True
def is_int(s):
    """Return True if *s* can be converted to ``int``, otherwise False.

    The check defers to the built-in ``int()``, so a string with a decimal
    point such as ``"1.5"`` is rejected.

    Args:
        s: The value to probe, normally an option value read as a string.

    Returns:
        bool: True when ``int(s)`` succeeds, False when it fails.
    """
    try:
        int(s)
    except (TypeError, ValueError, OverflowError):
        # ValueError: malformed text; TypeError: None or another
        # unsupported type; OverflowError: an infinite float.
        return False
    return True
2034
2135_ONE_DAY_IN_SECONDS = 60 * 60 * 24
2236
@@ -47,6 +61,28 @@ def LoadModel(self, request, context):
4761 if not torch .cuda .is_available () and request .CUDA :
4862 return backend_pb2 .Result (success = False , message = "CUDA is not available" )
4963
64+
65+ options = request .Options
66+
67+ # empty dict
68+ self .options = {}
69+
70+ # The options are a list of strings in this form optname:optvalue
71+ # We are storing all the options in a dict so we can use it later when
72+ # generating the audio
73+ for opt in options :
74+ if ":" not in opt :
75+ continue
76+ key , value = opt .split (":" )
77+ # if value is a number, convert it to the appropriate type
78+ if is_float (value ):
79+ value = float (value )
80+ elif is_int (value ):
81+ value = int (value )
82+ elif value .lower () in ["true" , "false" ]:
83+ value = value .lower () == "true"
84+ self .options [key ] = value
85+
5086 self .AudioPath = None
5187
5288 if os .path .isabs (request .AudioPath ):
@@ -56,10 +92,14 @@ def LoadModel(self, request, context):
5692 modelFileBase = os .path .dirname (request .ModelFile )
5793 # make AudioPath relative to modelFileBase
5894 self .AudioPath = os .path .join (modelFileBase , request .AudioPath )
59-
6095 try :
6196 print ("Preparing models, please wait" , file = sys .stderr )
62- self .model = ChatterboxTTS .from_pretrained (device = device )
97+ if "multilingual" in self .options :
98+ # remove key from options
99+ del self .options ["multilingual" ]
100+ self .model = ChatterboxMultilingualTTS .from_pretrained (device = device )
101+ else :
102+ self .model = ChatterboxTTS .from_pretrained (device = device )
63103 except Exception as err :
64104 return backend_pb2 .Result (success = False , message = f"Unexpected { err = } , { type (err )= } " )
65105 # Implement your logic here for the LoadModel service
@@ -68,12 +108,18 @@ def LoadModel(self, request, context):
68108
69109 def TTS (self , request , context ):
70110 try :
71- # Generate audio using ChatterboxTTS
111+ kwargs = {}
112+
113+ if "language" in self .options :
114+ kwargs ["language_id" ] = self .options ["language" ]
72115 if self .AudioPath is not None :
73- wav = self .model .generate (request .text , audio_prompt_path = self .AudioPath )
74- else :
75- wav = self .model .generate (request .text )
76-
116+ kwargs ["audio_prompt_path" ] = self .AudioPath
117+
118+ # add options to kwargs
119+ kwargs .update (self .options )
120+
121+ # Generate audio using ChatterboxTTS
122+ wav = self .model .generate (request .text , ** kwargs )
77123 # Save the generated audio
78124 ta .save (request .dst , wav , self .model .sr )
79125
0 commit comments