-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmal.py
More file actions
91 lines (76 loc) · 3.53 KB
/
mal.py
File metadata and controls
91 lines (76 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# mal.py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor
import re, time
# Select the compute device once at import time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Domain-specific terms (campus degree names) that must survive translation
# verbatim: each Malayalam form is swapped for its English form before the
# model runs, and restored afterwards.
PRE_TRANSLATE_MAP = {
    "ബി.ടെക്ക്": "B.Tech",
    "എം.സിഎ": "MCA",
    "എം.ടെക്": "M.Tech",
}

# Inverse lookup used after translation (English form -> Malayalam form).
POST_TRANSLATE_MAP = {english: malayalam for malayalam, english in PRE_TRANSLATE_MAP.items()}
class MalTranslator:
    """Malayalam <-> English translator wrapping AI4Bharat IndicTrans2 distilled models.

    The direction is fixed at construction time ("ml-en" or "en-ml").
    Translation protects domain terms via PRE_TRANSLATE_MAP before the model
    runs and restores them via POST_TRANSLATE_MAP afterwards.
    """

    def __init__(self, direction="ml-en"):
        """Load the model, tokenizer and IndicProcessor for the given direction.

        Args:
            direction: "ml-en" (Malayalam -> English) or "en-ml" (English -> Malayalam).

        Raises:
            ValueError: if `direction` is not one of the two supported values.
        """
        if direction == "ml-en":
            self.model_name = "ai4bharat/indictrans2-indic-en-dist-200M"
            self.src_lang, self.tgt_lang = "mal_Mlym", "eng_Latn"
        elif direction == "en-ml":
            self.model_name = "ai4bharat/indictrans2-en-indic-dist-200M"
            self.src_lang, self.tgt_lang = "eng_Latn", "mal_Mlym"
        else:
            raise ValueError("Direction must be 'ml-en' or 'en-ml'.")
        print(f"🔄 Loading model: {self.model_name} on {DEVICE} ...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            # fp16 halves GPU memory; CPU kernels generally require fp32.
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
        ).to(DEVICE)
        # Compile for faster inference (first call pays a one-time warm-up cost).
        if DEVICE == "cuda":
            self.model = torch.compile(self.model)
        self.ip = IndicProcessor(inference=True)
        print("✅ Model & tokenizer loaded successfully.\n")

    def _apply_pre_map(self, text):
        """Replace mapped Malayalam terms (plus any trailing Malayalam characters)
        with their English equivalents so the model treats them as fixed tokens."""
        for term, replacement in PRE_TRANSLATE_MAP.items():
            # [\u0D00-\u0D7F]* also consumes agglutinated case suffixes,
            # e.g. "ബി.ടെക്കിന്റെ" collapses to "B.Tech".
            text = re.sub(rf"{re.escape(term)}[\u0D00-\u0D7F]*", replacement, text)
        return text

    def _apply_post_map(self, text):
        """Strip tag-like artifacts and map protected English terms back to Malayalam.

        NOTE(review): this runs for both directions, so "ml-en" output will
        contain Malayalam script for mapped terms — confirm that is intended.
        """
        text = re.sub(r"<.*?>", "", text)  # remove artifacts
        for term, replacement in POST_TRANSLATE_MAP.items():
            text = re.sub(re.escape(term), replacement, text)
        return text.strip()

    def translate(self, sentences, debug=False, max_new_tokens=30):
        """Translate a batch of sentences.

        Args:
            sentences: list of input strings; may be empty.
            debug: when True, print each input/output pair plus timing.
            max_new_tokens: per-sentence generation budget. The default (30)
                preserves the original fast-path behaviour but can truncate
                long inputs — raise it for longer text.

        Returns:
            List of translated strings aligned with `sentences`
            (empty list for empty input).
        """
        # Guard: the tokenizer raises on an empty batch; short-circuit instead.
        if not sentences:
            return []
        sentences_pre = [self._apply_pre_map(s) for s in sentences]
        batch = self.ip.preprocess_batch(sentences_pre, src_lang=self.src_lang, tgt_lang=self.tgt_lang)
        inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
        t0 = time.time()
        with torch.no_grad():
            gen_output = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=1,      # greedy decoding for speed
                do_sample=False,  # deterministic output
                use_cache=True
            )
        t_gen = (time.time() - t0) * 1000
        decoded = self.tokenizer.batch_decode(gen_output, skip_special_tokens=True)
        out_final = self.ip.postprocess_batch(decoded, lang=self.tgt_lang)
        out_final = [self._apply_post_map(s) for s in out_final]
        if debug:
            for inp, out in zip(sentences, out_final):
                print(f"\n🔹 Input: {inp}")
                print(f"🔹 Output: {out}")
            print(f"⚡ Translation took {t_gen:.0f} ms for {len(sentences)} sentence(s)")
        return out_final
if __name__ == "__main__":
    # Smoke test: Malayalam -> English over three domain-specific queries.
    mt = MalTranslator("ml-en")
    samples = [
        "എനിക്ക് ബി.ടെക്കിന്റെ അഡ്മിഷൻ തീയതികൾ അറിയണം.",
        "എം.സിഎ ഫീസ് എന്താണ്?",
        "എം.ടെക് പ്രവേശനം എപ്പോഴാണ് തുടങ്ങുന്നത്?",
    ]
    translations = mt.translate(samples, debug=True)
    print("\nML → EN:", translations)