Skip to content

Commit 0be9263

Browse files
committed
feat: add support for MMS TTS models via Modal
1 parent e3b9171 commit 0be9263

File tree

4 files changed

+198
-0
lines changed

4 files changed

+198
-0
lines changed

daras_ai_v2/mms_tts.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""
2+
To deploy changes to remote functions, run this file directly as a script:
3+
4+
```bash
5+
poetry run python daras_ai_v2/mms_tts.py
6+
```
7+
"""
8+
9+
import modal
10+
from decouple import config
11+
12+
13+
MMS_TTS_SUPPORTED_LANGUAGES = {
14+
"abi", "abp", "aca", "acd", "ace", "acf", "ach", "acn", "acr", "acu", "ade", "adh", "adj", "adx", "aeu", "agd",
15+
"agg", "agn", "agr", "agu", "agx", "aha", "ahk", "aia", "aka", "akb", "ake", "akp", "alj", "alp", "alt", "alz",
16+
"ame", "amf", "amh", "ami", "amk", "ann", "any", "aoz", "apb", "apr", "ara", "arl", "asa", "asg", "asm", "ata",
17+
"atb", "atg", "ati", "atq", "ava", "avn", "avu", "awa", "awb", "ayo", "ayr", "ayz", "azb", "azg", "azj", "azz",
18+
"bak", "bam", "ban", "bao", "bav", "bba", "bbb", "bbc", "bbo", "bcc", "bcl", "bcw", "bdg", "bdh", "bdq", "bdu",
19+
"bdv", "beh", "bem", "ben", "bep", "bex", "bfa", "bfo", "bfy", "bfz", "bgc", "bgq", "bgr", "bgt", "bgw", "bha",
20+
"bht", "bhz", "bib", "bim", "bis", "biv", "bjr", "bjv", "bjw", "bjz", "bkd", "bkv", "blh", "blt", "blx", "blz",
21+
"bmq", "bmr", "bmu", "bmv", "bng", "bno", "bnp", "boa", "bod", "boj", "bom", "bor", "bov", "box", "bpr", "bps",
22+
"bqc", "bqi", "bqj", "bqp", "bru", "bsc", "bsq", "bss", "btd", "bts", "btt", "btx", "bud", "bul", "bus", "bvc",
23+
"bvz", "bwq", "bwu", "byr", "bzh", "bzi", "bzj", "caa", "cab", "cac", "cak", "cap", "car", "cas", "cat", "cax",
24+
"cbc", "cbi", "cbr", "cbs", "cbt", "cbu", "cbv", "cce", "cco", "cdj", "ceb", "ceg", "cek", "cfm", "cgc", "che",
25+
"chf", "chv", "chz", "cjo", "cjp", "cjs", "cko", "ckt", "cla", "cle", "cly", "cme", "cmo", "cmr", "cnh", "cni",
26+
"cnl", "cnt", "coe", "cof", "cok", "con", "cot", "cou", "cpa", "cpb", "cpu", "crh", "crk", "crn", "crq", "crs",
27+
"crt", "csk", "cso", "ctd", "ctg", "cto", "ctu", "cuc", "cui", "cuk", "cul", "cwa", "cwe", "cwt", "cya", "cym",
28+
"daa", "dah", "dar", "dbj", "dbq", "ddn", "ded", "des", "deu", "dga", "dgi", "dgk", "dgo", "dgr", "dhi", "did",
29+
"dig", "dik", "dip", "div", "djk", "dnj", "dnt", "dnw", "dop", "dos", "dsh", "dso", "dtp", "dts", "dug", "dwr",
30+
"dyi", "dyo", "dyu", "dzo", "eip", "eka", "ell", "emp", "enb", "eng", "enx", "ese", "ess", "eus", "evn", "ewe",
31+
"eza", "fal", "fao", "far", "fas", "fij", "fin", "flr", "fmu", "fon", "fra", "frd", "ful", "gag", "gai", "gam",
32+
"gau", "gbi", "gbk", "gbm", "gbo", "gde", "geb", "gej", "gil", "gjn", "gkn", "gld", "glk", "gmv", "gna", "gnd",
33+
"gng", "gof", "gog", "gor", "gqr", "grc", "gri", "grn", "grt", "gso", "gub", "guc", "gud", "guh", "guj", "guk",
34+
"gum", "guo", "guq", "guu", "gux", "gvc", "gvl", "gwi", "gwr", "gym", "gyr", "had", "hag", "hak", "hap", "hat",
35+
"hau", "hay", "heb", "heh", "hif", "hig", "hil", "hin", "hlb", "hlt", "hne", "hnn", "hns", "hoc", "hoy", "hto",
36+
"hub", "hui", "hun", "hus", "huu", "huv", "hvn", "hwc", "hyw", "iba", "icr", "idd", "ifa", "ifb", "ife", "ifk",
37+
"ifu", "ify", "ign", "ikk", "ilb", "ilo", "imo", "inb", "ind", "iou", "ipi", "iqw", "iri", "irk", "isl", "itl",
38+
"itv", "ixl", "izr", "izz", "jac", "jam", "jav", "jbu", "jen", "jic", "jiv", "jmc", "jmd", "jun", "juy", "jvn",
39+
"kaa", "kab", "kac", "kak", "kan", "kao", "kaq", "kay", "kaz", "kbo", "kbp", "kbq", "kbr", "kby", "kca", "kcg",
40+
"kdc", "kde", "kdh", "kdi", "kdj", "kdl", "kdn", "kdt", "kek", "ken", "keo", "ker", "key", "kez", "kfb", "kff",
41+
"kfw", "kfx", "khg", "khm", "khq", "kia", "kij", "kik", "kin", "kir", "kjb", "kje", "kjg", "kjh", "kki", "kkj",
42+
"kle", "klu", "klv", "klw", "kma", "kmd", "kml", "kmr", "kmu", "knb", "kne", "knf", "knj", "knk", "kno", "kog",
43+
"kor", "kpq", "kps", "kpv", "kpy", "kpz", "kqe", "kqp", "kqr", "kqy", "krc", "kri", "krj", "krl", "krr", "krs",
44+
"kru", "ksb", "ksr", "kss", "ktb", "ktj", "kub", "kue", "kum", "kus", "kvn", "kvw", "kwd", "kwf", "kwi", "kxc",
45+
"kxf", "kxm", "kxv", "kyb", "kyc", "kyf", "kyg", "kyo", "kyq", "kyu", "kyz", "kzf", "lac", "laj", "lam", "lao",
46+
"las", "lat", "lav", "law", "lbj", "lbw", "lcp", "lee", "lef", "lem", "lew", "lex", "lgg", "lgl", "lhu", "lia",
47+
"lid", "lif", "lip", "lis", "lje", "ljp", "llg", "lln", "lme", "lnd", "lns", "lob", "lok", "lom", "lon", "loq",
48+
"lsi", "lsm", "luc", "lug", "lwo", "lww", "lzz", "maa", "mad", "mag", "mah", "mai", "maj", "mak", "mal", "mam",
49+
"maq", "mar", "maw", "maz", "mbb", "mbc", "mbh", "mbj", "mbt", "mbu", "mbz", "mca", "mcb", "mcd", "mco", "mcp",
50+
"mcq", "mcu", "mda", "mdv", "mdy", "med", "mee", "mej", "men", "meq", "met", "mev", "mfe", "mfh", "mfi", "mfk",
51+
"mfq", "mfy", "mfz", "mgd", "mge", "mgh", "mgo", "mhi", "mhr", "mhu", "mhx", "mhy", "mib", "mie", "mif", "mih",
52+
"mil", "mim", "min", "mio", "mip", "miq", "mit", "miy", "miz", "mjl", "mjv", "mkl", "mkn", "mlg", "mmg", "mnb",
53+
"mnf", "mnk", "mnw", "mnx", "moa", "mog", "mon", "mop", "mor", "mos", "mox", "moz", "mpg", "mpm", "mpp", "mpx",
54+
"mqb", "mqf", "mqj", "mqn", "mrw", "msy", "mtd", "mtj", "mto", "muh", "mup", "mur", "muv", "muy", "mvp", "mwq",
55+
"mwv", "mxb", "mxq", "mxt", "mxv", "mya", "myb", "myk", "myl", "myv", "myx", "myy", "mza", "mzi", "mzj", "mzk",
56+
"mzm", "mzw", "nab", "nag", "nan", "nas", "naw", "nca", "nch", "ncj", "ncl", "ncu", "ndj", "ndp", "ndv", "ndy",
57+
"ndz", "neb", "new", "nfa", "nfr", "nga", "ngl", "ngp", "ngu", "nhe", "nhi", "nhu", "nhw", "nhx", "nhy", "nia",
58+
"nij", "nim", "nin", "nko", "nlc", "nld", "nlg", "nlk", "nmz", "nnb", "nnq", "nnw", "noa", "nod", "nog", "not",
59+
"npl", "npy", "nst", "nsu", "ntm", "ntr", "nuj", "nus", "nuz", "nwb", "nxq", "nya", "nyf", "nyn", "nyo", "nyy",
60+
"nzi", "obo", "ojb", "oku", "old", "omw", "onb", "ood", "orm", "ory", "oss", "ote", "otq", "ozm", "pab", "pad",
61+
"pag", "pam", "pan", "pao", "pap", "pau", "pbb", "pbc", "pbi", "pce", "pcm", "peg", "pez", "pib", "pil", "pir",
62+
"pis", "pjt", "pkb", "pls", "plw", "pmf", "pny", "poh", "poi", "pol", "por", "poy", "ppk", "pps", "prf", "prk",
63+
"prt", "pse", "pss", "ptu", "pui", "pwg", "pww", "pxm", "qub", "quc", "quf", "quh", "qul", "quw", "quy", "quz",
64+
"qvc", "qve", "qvh", "qvm", "qvn", "qvo", "qvs", "qvw", "qvz", "qwh", "qxh", "qxl", "qxn", "qxo", "qxr", "rah",
65+
"rai", "rap", "rav", "raw", "rej", "rel", "rgu", "rhg", "rif", "ril", "rim", "rjs", "rkt", "rmc", "rmo", "rmy",
66+
"rng", "rnl", "rol", "ron", "rop", "rro", "rub", "ruf", "rug", "run", "rus", "sab", "sag", "sah", "saj", "saq",
67+
"sas", "sba", "sbd", "sbl", "sbp", "sch", "sck", "sda", "sea", "seh", "ses", "sey", "sgb", "sgj", "sgw", "shi",
68+
"shk", "shn", "sho", "shp", "sid", "sig", "sil", "sja", "sjm", "sld", "slu", "sml", "smo", "sna", "sne", "snn",
69+
"snp", "snw", "som", "soy", "spa", "spp", "spy", "sqi", "sri", "srm", "srn", "srx", "stn", "stp", "suc", "suk",
70+
"sun", "sur", "sus", "suv", "suz", "swe", "swh", "sxb", "sxn", "sya", "syl", "sza", "tac", "taj", "tam", "tao",
71+
"tap", "taq", "tat", "tav", "tbc", "tbg", "tbk", "tbl", "tby", "tbz", "tca", "tcc", "tcs", "tcz", "tdj", "ted",
72+
"tee", "tel", "tem", "teo", "ter", "tes", "tew", "tex", "tfr", "tgj", "tgk", "tgl", "tgo", "tgp", "tha", "thk",
73+
"thl", "tih", "tik", "tir", "tkr", "tlb", "tlj", "tly", "tmc", "tmf", "tna", "tng", "tnk", "tnn", "tnp", "tnr",
74+
"tnt", "tob", "toc", "toh", "tom", "tos", "tpi", "tpm", "tpp", "tpt", "trc", "tri", "trn", "trs", "tso", "tsz",
75+
"ttc", "tte", "ttq", "tue", "tuf", "tuk", "tuo", "tur", "tvw", "twb", "twe", "twu", "txa", "txq", "txu", "tye",
76+
"tzh", "tzj", "tzo", "ubl", "ubu", "udm", "udu", "uig", "ukr", "unr", "upv", "ura", "urb", "urd", "urk", "urt",
77+
"ury", "usp", "uzb", "vag", "vid", "vie", "vif", "vmw", "vmy", "vun", "vut", "wal", "wap", "war", "waw", "way",
78+
"wba", "wlo", "wlx", "wmw", "wob", "wsg", "wwa", "xal", "xdy", "xed", "xer", "xmm", "xnj", "xnr", "xog", "xon",
79+
"xrb", "xsb", "xsm", "xsr", "xsu", "xta", "xtd", "xte", "xtm", "xtn", "xua", "xuo", "yaa", "yad", "yal", "yam",
80+
"yao", "yas", "yat", "yaz", "yba", "ybb", "ycl", "ycn", "yea", "yka", "yli", "yor", "yre", "yua", "yuz", "yva",
81+
"zaa", "zab", "zac", "zad", "zae", "zai", "zam", "zao", "zaq", "zar", "zas", "zav", "zaw", "zca", "zga", "zim",
82+
"ziw", "zlm", "zmz", "zne", "zos", "zpc", "zpg", "zpi", "zpl", "zpm", "zpo", "zpt", "zpu", "zpz", "ztq", "zty",
83+
"zyb", "zyp", "zza"
84+
} # fmt: skip
85+
86+
87+
app = modal.App("gooey-mms-tts-runner")
88+
89+
cache_dir = "/cache"
90+
model_cache = modal.Volume.from_name("hf-model-cache", create_if_missing=True)
91+
image = (
92+
modal.Image.debian_slim()
93+
.pip_install(
94+
"transformers~=4.44",
95+
"huggingface_hub[hf_transfer]~=0.34.4",
96+
"torch~=2.5.1",
97+
"scipy~=1.11",
98+
"python-decouple~=3.6",
99+
)
100+
.env({"HF_HUB_CACHE": cache_dir, "HF_TOKEN": config("HF_TOKEN", cast=str)})
101+
)
102+
103+
104+
def load_pipe(language: str):
105+
import torch
106+
from transformers import pipeline
107+
108+
has_cuda = torch.cuda.is_available()
109+
if has_cuda:
110+
print("Using GPU")
111+
else:
112+
print("GPU not available, using CPU")
113+
114+
pipe = pipeline(
115+
"text-to-speech",
116+
model=f"facebook/mms-tts-{language}",
117+
tokenizer=f"facebook/mms-tts-{language}",
118+
device=0 if torch.cuda.is_available() else -1,
119+
)
120+
121+
return pipe
122+
123+
124+
@app.function(
125+
image=image,
126+
gpu="a10g",
127+
timeout=30 * 60,
128+
volumes={"/cache": model_cache},
129+
enable_memory_snapshot=True,
130+
)
131+
def run_mms_tts(language: str, text: str) -> bytes:
132+
import io
133+
import torch
134+
import scipy
135+
136+
pipe = load_pipe(language)
137+
138+
print("Running inference")
139+
with torch.no_grad():
140+
output = pipe(text)
141+
142+
b = io.BytesIO()
143+
scipy.io.wavfile.write(b, rate=output["sampling_rate"], data=output["audio"][0])
144+
return b.getvalue()
145+
146+
147+
if __name__ == "__main__":
148+
with modal.enable_output():
149+
app.deploy()

daras_ai_v2/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,3 +517,5 @@
517517
if MODAL_TOKEN_ID and MODAL_TOKEN_SECRET:
518518
os.environ["MODAL_TOKEN_ID"] = MODAL_TOKEN_ID
519519
os.environ["MODAL_TOKEN_SECRET"] = MODAL_TOKEN_SECRET
520+
521+
HF_TOKEN = config("HF_TOKEN", "")

daras_ai_v2/text_to_speech_settings_widgets.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class TextToSpeechProviders(Enum):
7474
AZURE_TTS = "Azure Text-to-Speech"
7575
OPEN_AI = "OpenAI"
7676
GHANA_NLP = "GhanaNLP Text-To-Speech"
77+
MMS_TTS = "MMS TTS (Meta)"
7778

7879

7980
# This exists only for backwards compatiblity
@@ -170,6 +171,8 @@ def text_to_speech_provider_selector(page):
170171
openai_tts_selector()
171172
case TextToSpeechProviders.GHANA_NLP.name:
172173
ghana_nlp_tts_selector()
174+
case TextToSpeechProviders.MMS_TTS.name:
175+
mms_tts_selector()
173176
return tts_provider
174177

175178

@@ -198,6 +201,29 @@ def ghana_nlp_tts_selector():
198201
)
199202

200203

204+
def mms_tts_selector():
205+
options = mms_tts_language_options()
206+
gui.selectbox(
207+
label="""
208+
###### MMS TTS Language
209+
""",
210+
key="mms_tts_language",
211+
format_func=lambda lang: options[lang],
212+
options=options,
213+
)
214+
215+
216+
@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
217+
def mms_tts_language_options():
218+
import langcodes
219+
from daras_ai_v2.mms_tts import MMS_TTS_SUPPORTED_LANGUAGES
220+
221+
result = {}
222+
for lang in MMS_TTS_SUPPORTED_LANGUAGES:
223+
result[lang] = langcodes.Language.get(lang).display_name()
224+
return result
225+
226+
201227
def openai_tts_selector():
202228
enum_selector(
203229
OpenAI_TTS_Voices,

recipes/TextToSpeech.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import gooey_gui as gui
66
import requests
7+
import modal
78
from pydantic import BaseModel, Field
89

910
from bots.models import Workflow
@@ -64,6 +65,8 @@ class TextToSpeechSettings(BaseModel):
6465
openai_tts_model: OpenAI_TTS_Models.api_choices | None = None
6566
ghana_nlp_tts_language: GHANA_NLP_TTS_LANGUAGES.api_choices | None = None
6667

68+
mms_tts_language: str = "eng"
69+
6770

6871
class TextToSpeechPage(BasePage):
6972
title = "Compare AI Voice Generators"
@@ -408,6 +411,24 @@ def run(self, state: dict):
408411
audio_url = upload_file_from_bytes("ghana_gen.wav", response.content)
409412
state["audio_url"] = audio_url
410413

414+
case TextToSpeechProviders.MMS_TTS:
415+
from daras_ai_v2.mms_tts import (
416+
MMS_TTS_SUPPORTED_LANGUAGES,
417+
app as modal_app,
418+
)
419+
420+
language = state.get("mms_tts_language", "eng")
421+
if language not in MMS_TTS_SUPPORTED_LANGUAGES:
422+
raise UserError(f"Unsupported language: {language}")
423+
424+
run_mms_tts = modal.Function.lookup(modal_app.name, "run_mms_tts")
425+
with modal.enable_output():
426+
audio = run_mms_tts.remote(language=language, text=text)
427+
428+
state["audio_url"] = upload_file_from_bytes(
429+
filename="output.wav", data=audio, content_type="audio/wav"
430+
)
431+
411432
def _get_elevenlabs_voice_model(self, state: dict[str, str]):
412433
default_voice_model = next(iter(ELEVEN_LABS_MODELS))
413434
voice_model = state.get("elevenlabs_model", default_voice_model)

0 commit comments

Comments
 (0)