Skip to content

Commit

Permalink
Minor updates for timestamp accuracy. (#42)
Browse files Browse the repository at this point in the history
* Update default temperature value for agents.

* Add test suites for temperature clamping.

* Improve timestamp accuracy.

* Increase default fee-limit.
  • Loading branch information
zh-plus authored Jun 13, 2024
1 parent 0c21adc commit 988fd07
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ name: CI

on:
push:
branches: [ "master" ]
branches: [ "master", "dev" ]
pull_request:
branches: [ "master" ]

Expand Down
15 changes: 8 additions & 7 deletions openlrc/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
# All rights reserved.
import abc
import re
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Type, Union

from openlrc.chatbot import route_chatbot
from openlrc.chatbot import route_chatbot, GPTBot, ClaudeBot
from openlrc.context import TranslationContext, TranslateInfo
from openlrc.logger import logger
from openlrc.prompter import BaseTranslatePrompter, ContextReviewPrompter, POTENTIAL_PREFIX_COMBOS, \
ProofreaderPrompter, PROOFREAD_PREFIX


class Agent(abc.ABC):
    """
    Base class for all agents.

    Subclasses override ``TEMPERATURE`` to tune the sampling temperature of
    their chatbot and call :meth:`_initialize_chatbot` to build the bot.
    """

    # Default sampling temperature forwarded to the chatbot; subclasses override.
    # (A stale duplicate assignment of 0.5 was removed — the later value, 1, won anyway.)
    TEMPERATURE = 1

    def _initialize_chatbot(self, chatbot_model: str, fee_limit: float, proxy: str,
                            base_url_config: Optional[dict]):
        """
        Resolve ``chatbot_model`` to a concrete chatbot class and instantiate it.

        :param chatbot_model: Model identifier understood by ``route_chatbot``.
        :param fee_limit: Fee cap forwarded to the chatbot (presumably USD — TODO confirm).
        :param proxy: Optional proxy URL forwarded to the chatbot client.
        :param base_url_config: Optional base-URL overrides for the API client.
        :return: A configured ``ClaudeBot`` or ``GPTBot`` instance.
        """
        chatbot_cls: Union[Type[ClaudeBot], Type[GPTBot]]
        chatbot_cls, model_name = route_chatbot(chatbot_model)
        # retry=3 is fixed here; the temperature comes from the class attribute.
        return chatbot_cls(model=model_name, fee_limit=fee_limit, proxy=proxy, retry=3,
                           temperature=self.TEMPERATURE, base_url_config=base_url_config)
Expand All @@ -28,10 +29,10 @@ class ChunkedTranslatorAgent(Agent):
Translate the well-defined chunked text to the target language and send it to the chatbot for further processing.
"""

TEMPERATURE = 0.9
TEMPERATURE = 1.0

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.chatbot_model = chatbot_model
Expand Down Expand Up @@ -108,7 +109,7 @@ class ContextReviewerAgent(Agent):
TEMPERATURE = 0.8

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.src_lang = src_lang
Expand Down Expand Up @@ -139,7 +140,7 @@ class ProofreaderAgent(Agent):
TEMPERATURE = 0.8

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.src_lang = src_lang
Expand Down
18 changes: 8 additions & 10 deletions openlrc/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def route_chatbot(model):
class ChatBot:
pricing = None

def __init__(self, pricing, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.2):
def __init__(self, pricing, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.25):
self.pricing = pricing
self._model = None

Expand Down Expand Up @@ -172,6 +172,9 @@ class GPTBot(ChatBot):
def __init__(self, model='gpt-3.5-turbo-0125', temperature=1, top_p=1, retry=8, max_async=16, json_mode=False,
fee_limit=0.05, proxy=None, base_url_config=None):

# clamp temperature to 0-2
temperature = max(0, min(2, temperature))

super().__init__(self.pricing, temperature, top_p, retry, max_async, fee_limit)

self.async_client = AsyncGPTClient(
Expand All @@ -181,12 +184,7 @@ def __init__(self, model='gpt-3.5-turbo-0125', temperature=1, top_p=1, retry=8,
)

self.model = model
self.temperature = temperature
self.top_p = top_p
self.retry = retry
self.max_async = max_async
self.json_mode = json_mode
self.fee_limit = fee_limit

def __exit__(self, exc_type, exc_val, exc_tb):
self.async_client.close()
Expand Down Expand Up @@ -252,9 +250,12 @@ class ClaudeBot(ChatBot):
'claude-3-haiku-20240307': (0.25, 1.25)
}

def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.2,
def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.25,
proxy=None, base_url_config=None):

# clamp temperature to 0-1
temperature = max(0, min(1, temperature))

super().__init__(self.pricing, temperature, top_p, retry, max_async, fee_limit)

self.async_client = AsyncAnthropic(
Expand All @@ -266,9 +267,6 @@ def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, ret
)

self.model = model
self.retry = retry
self.max_async = max_async
self.fee_limit = fee_limit

def update_fee(self, response: Message):
prompt_price, completion_price = all_pricing[self.model]
Expand Down
2 changes: 1 addition & 1 deletion openlrc/openlrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class LRCer:
"""

def __init__(self, whisper_model='large-v3', compute_type='float16', chatbot_model: str = 'gpt-3.5-turbo',
fee_limit=0.2, consumer_thread=4, asr_options=None, vad_options=None, preprocess_options=None,
fee_limit=0.25, consumer_thread=4, asr_options=None, vad_options=None, preprocess_options=None,
proxy=None, base_url_config=None, glossary: Union[dict, str, Path] = None, retry_model=None):
self.chatbot_model = chatbot_model
self.fee_limit = fee_limit
Expand Down
42 changes: 37 additions & 5 deletions openlrc/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from openlrc.logger import logger
from openlrc.subtitle import Subtitle
from openlrc.utils import extend_filename
from openlrc.utils import extend_filename, format_timestamp

# Thresholds for different languages
CUT_LONG_THRESHOLD = {
Expand Down Expand Up @@ -81,18 +81,42 @@ def merge_short(self, threshold=1.2):
merged_element = None
new_elements.append(element)
else:
merged_element = self._merge_elements(merged_element, element)
if not merged_element:
merged_element = element
continue

# Merge to previous element if closer to pre-element and gap > 3s
previous_gap = merged_element.start - new_elements[-1].start
next_gap = element.start - merged_element.end
if previous_gap <= next_gap and previous_gap <= 3:
previous_element = new_elements.pop()
merged_element.text = previous_element.text + merged_element.text
merged_element.start = previous_element.start
new_elements.append(merged_element)
merged_element = element
elif next_gap <= previous_gap and next_gap <= 3:
merged_element.text += element.text
merged_element.end = element.end
new_elements.append(merged_element)
merged_element = None
else:
new_elements.append(merged_element)
merged_element = element

self.subtitle.segments = new_elements

def _finalize_merge(self, new_elements, merged_element, element):
if merged_element.duration < 1.5:
if element.start - merged_element.end < merged_element.start - new_elements[-1].end:
previous_gap = merged_element.start - new_elements[-1].end
next_gap = element.start - merged_element.end
if previous_gap <= next_gap and previous_gap <= 3:
new_elements[-1].text += merged_element.text
new_elements[-1].end = merged_element.end
elif next_gap <= previous_gap and next_gap <= 3:
element.text = merged_element.text + element.text
element.start = merged_element.start
else:
new_elements[-1].text += merged_element.text
new_elements[-1].end = merged_element.end
new_elements.append(merged_element)
else:
new_elements.append(merged_element)

Expand Down Expand Up @@ -204,10 +228,18 @@ def perform_all(self, steps: Optional[List[str]] = None, extend_time=False):
if extend_time:
self.extend_time()

# Finally check to notify users
self.check()

def save(self, output_name: Optional[str] = None, update_name=False):
    """
    Write the optimized subtitle to a file and log the destination path.

    :param output_name: Explicit output path; when omitted, the original
        filename is extended with an ``_optimized`` suffix.
    :param update_name: Forwarded to ``Subtitle.save``.
    """
    if output_name:
        optimized_name = output_name
    else:
        optimized_name = extend_filename(self.filename, '_optimized')
    self.subtitle.save(optimized_name, update_name=update_name)
    logger.info(f'Optimized json file saved to {optimized_name}')

def check(self):
    """Warn about every subtitle segment whose duration is 10 seconds or more."""
    overlong = [seg for seg in self.subtitle.segments if seg.duration >= 10]
    for element in overlong:
        logger.warning(f'Duration of text "{element.text}" at {format_timestamp(element.start)} exceeds 10')
45 changes: 33 additions & 12 deletions openlrc/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,19 @@ def is_punct(char):
latter_words = seg_entry.words[len(former_words):]

if not latter_words:
# Directly split using the hard-mid
former_words = seg_entry.words[:len(seg_entry.words) // 2]
latter_words = seg_entry.words[len(seg_entry.words) // 2:]
# Directly split using the largest word-word gap
gaps = [-1]
for k in range(len(seg_entry.words) - 1):
gaps.append(seg_entry.words[k + 1].start - seg_entry.words[k].end)
max_gap = max(gaps)
split_idx = gaps.index(max_gap) # TODO: Multiple largest or Multiple long gap

if max_gap >= 2: # Split using the max gap
former_words = seg_entry.words[:split_idx]
latter_words = seg_entry.words[split_idx:]
else: # Split using hard-mid
former_words = seg_entry.words[:len(seg_entry.words) // 2]
latter_words = seg_entry.words[len(seg_entry.words) // 2:]

former = seg_from_words(seg_entry, seg_entry.id, former_words, seg_entry.tokens[:len(former_words)])
latter = seg_from_words(seg_entry, seg_entry.id + 1, latter_words, seg_entry.tokens[len(former_words):])
Expand Down Expand Up @@ -162,16 +172,27 @@ def is_punct(char):
entry = seg_from_words(segment, id_cnt, split_words,
segment.tokens[word_start: word_start + len(split_words)])

# Check if the sentence is too long in words
if len(split) < (45 if lang in self.continuous_scripted else 90) or len(entry.words) == 1:
# split if duration > 10s
if entry.end - entry.start > 10:
segmented_entries = mid_split(entry)
def recursive_segment(entry):
if len(entry.text) < (45 if lang in self.continuous_scripted else 90) or len(entry.words) == 1:
if entry.end - entry.start > 10:
# split if duration > 10s
segmented_entries = mid_split(entry)
further_segmented = []
for segment in segmented_entries:
further_segmented.extend(recursive_segment(segment))
else:
return [entry]
else:
segmented_entries = [entry]
else:
# Split them in the middle
segmented_entries = mid_split(entry)
# Split them in the middle
segmented_entries = mid_split(entry)
further_segmented = []
for segment in segmented_entries:
further_segmented.extend(recursive_segment(segment))

return further_segmented

# Check if the sentence is too long in words
segmented_entries = recursive_segment(entry)

sentences.extend(segmented_entries)
id_cnt += len(segmented_entries)
Expand Down
2 changes: 1 addition & 1 deletion openlrc/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def translate(self, texts: Union[str, List[str]], src_lang: str, target_lang: st
class LLMTranslator(Translator):
CHUNK_SIZE = 30

def __init__(self, chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, chunk_size: int = CHUNK_SIZE,
def __init__(self, chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, chunk_size: int = CHUNK_SIZE,
intercept_line: Optional[int] = None, proxy: Optional[str] = None,
base_url_config: Optional[dict] = None,
retry_model: Optional[str] = None):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,16 @@ def test_route_chatbot_error(self):
chatbot_model = 'openai: invalid_model_name'
with self.assertRaises(ValueError):
route_chatbot(chatbot_model + 'error')

def test_temperature_clamp(self):
    """Out-of-range temperatures are clamped: GPT to [0, 2], Claude to [0, 1]."""
    expectations = [
        (GPTBot(temperature=10, top_p=1, retry=8, max_async=16), 2),
        (GPTBot(temperature=-1, top_p=1, retry=8, max_async=16), 0),
        (ClaudeBot(temperature=2, top_p=1, retry=8, max_async=16), 1),
        (ClaudeBot(temperature=-1, top_p=1, retry=8, max_async=16), 0),
    ]
    for bot, expected in expectations:
        self.assertEqual(bot.temperature, expected)

# TODO: Retry_bot testing

0 comments on commit 988fd07

Please sign in to comment.