Skip to content

Commit

Permalink
Merge pull request #190 from jhj0517/fix/translation-long-input
Browse files (browse the repository at this point in the history)
Add `max_length` parameter
  • Loading branch information
jhj0517 authored Jul 3, 2024
2 parents 8da8748 + 184dab0 commit 34da350
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
6 changes: 4 additions & 2 deletions app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, args):
print(f"Device \"{self.whisper_inf.device}\" is detected")
self.nllb_inf = NLLBInference(
model_dir=self.args.nllb_model_dir,
output_dir=self.args.output_dir
output_dir=os.path.join(self.args.output_dir, "translations")
)
self.deepl_api = DeepLAPI(
output_dir=self.args.output_dir
Expand Down Expand Up @@ -375,6 +375,8 @@ def launch(self):
choices=self.nllb_inf.available_source_langs)
dd_nllb_targetlang = gr.Dropdown(label="Target Language",
choices=self.nllb_inf.available_target_langs)
with gr.Row():
nb_max_length = gr.Number(label="Max Length Per Line", value=200, precision=0)
with gr.Row():
cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
interactive=True)
Expand All @@ -388,7 +390,7 @@ def launch(self):
md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")

btn_run.click(fn=self.nllb_inf.translate_file,
inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang, cb_timestamp],
inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang, nb_max_length, cb_timestamp],
outputs=[tb_indicator, files_subtitles])

btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
Expand Down
8 changes: 6 additions & 2 deletions modules/translation/nllb_inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,9 +21,13 @@ def __init__(self,
self.pipeline = None

def translate(self,
text: str
text: str,
max_length: int
):
result = self.pipeline(text)
result = self.pipeline(
text,
max_length=max_length
)
return result[0]['translation_text']

def update_model(self,
Expand Down
11 changes: 7 additions & 4 deletions modules/translation/translation_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,7 +24,8 @@ def __init__(self,

@abstractmethod
def translate(self,
text: str
text: str,
max_length: int
):
pass

Expand All @@ -42,6 +43,7 @@ def translate_file(self,
model_size: str,
src_lang: str,
tgt_lang: str,
max_length: int,
add_timestamp: bool,
progress=gr.Progress()) -> list:
"""
Expand All @@ -57,6 +59,8 @@ def translate_file(self,
Source language of the file to translate from gr.Dropdown()
tgt_lang: str
Target language of the file to translate from gr.Dropdown()
max_length: int
Max length per line to translate
add_timestamp: bool
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
progress: gr.Progress
Expand Down Expand Up @@ -84,7 +88,7 @@ def translate_file(self,
total_progress = len(parsed_dicts)
for index, dic in enumerate(parsed_dicts):
progress(index / total_progress, desc="Translating..")
translated_text = self.translate(dic["sentence"])
translated_text = self.translate(dic["sentence"], max_length=max_length)
dic["sentence"] = translated_text
subtitle = get_serialized_srt(parsed_dicts)

Expand All @@ -99,7 +103,7 @@ def translate_file(self,
total_progress = len(parsed_dicts)
for index, dic in enumerate(parsed_dicts):
progress(index / total_progress, desc="Translating..")
translated_text = self.translate(dic["sentence"])
translated_text = self.translate(dic["sentence"], max_length=max_length)
dic["sentence"] = translated_text
subtitle = get_serialized_vtt(parsed_dicts)

Expand All @@ -124,7 +128,6 @@ def translate_file(self,
print(f"Error: {str(e)}")
finally:
self.release_cuda_memory()
self.remove_input_files([fileobj.name for fileobj in fileobjs])

@staticmethod
def get_device():
Expand Down

0 comments on commit 34da350

Please sign in to comment.