Add Rev AI support #1
@@ -4,7 +4,9 @@
 from dotenv import load_dotenv
 from pathlib import Path
 import click
-from openai import AzureOpenAI
+from openai import OpenAI, AzureOpenAI
+from groq import Groq
+from rev_ai import apiclient

 # Load environment variables
 load_dotenv()
@@ -17,6 +19,21 @@
 AZURE_OPENAI_DEPLOYMENT_NAME_CHAT = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME_CHAT')
 AZURE_OPENAI_API_VERSION_CHAT = os.getenv("AZURE_OPENAI_API_VERSION_CHAT")

+# OpenAI configuration
+OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
+OPENAI_ENDPOINT = os.getenv('OPENAI_ENDPOINT')
+OPENAI_MODEL = os.getenv('OPENAI_MODEL')
+OPENAI_MODEL_CHAT = os.getenv('OPENAI_MODEL_CHAT')
Review comment: openai_model is whisper-1?

Review comment: what are the values for endpoint, could you please enter examples in .env.example
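A plausible set of .env.example entries for the OpenAI backend, assuming the stock hosted API (these values are illustrative guesses, not taken from this repo; whisper-1 is OpenAI's hosted Whisper model, and the chat model name is a placeholder):

```
# .env.example: hypothetical OpenAI values
OPENAI_API_KEY=sk-...
OPENAI_ENDPOINT=https://api.openai.com/v1
OPENAI_MODEL=whisper-1
OPENAI_MODEL_CHAT=gpt-4o
```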
+
+# Groq OpenAI configuration
+GROQ_OPENAI_API_KEY = os.getenv('GROQ_OPENAI_API_KEY')
+GROQ_OPENAI_ENDPOINT = os.getenv('GROQ_OPENAI_ENDPOINT')
Review comment: what are the values for endpoint, could you please enter examples in .env.example
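Likewise for Groq, whose API is OpenAI-compatible. The base URL below is Groq's documented OpenAI-compatible endpoint; the model names are illustrative placeholders, not confirmed choices from this repo:

```
# .env.example: hypothetical Groq values
GROQ_OPENAI_API_KEY=gsk_...
GROQ_OPENAI_ENDPOINT=https://api.groq.com/openai/v1
GROQ_MODEL=whisper-large-v3
GROQ_MODEL_CHAT=llama3-70b-8192
```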
+GROQ_MODEL = os.getenv('GROQ_MODEL')
+GROQ_MODEL_CHAT = os.getenv('GROQ_MODEL_CHAT')
+
+# Rev AI configuration
+REVAI_ACCESS_TOKEN = os.getenv('REVAI_ACCESS_TOKEN')
+
 def convert_to_mp3(input_file, output_file, quality):
     """Convert video to MP3 using ffmpeg with specified quality"""
     if quality == 'L':
@@ -31,28 +48,14 @@ def convert_to_mp3(input_file, output_file, quality):
     command = ['ffmpeg', '-i', input_file, '-vn'] + ffmpeg_options + [output_file]
     subprocess.run(command, check=True)

-def transcribe_audio(
+def transcribe_audio_azure(
     audio_file,
-    language="en",
-    prompt=None,
-    response_format="json",
-    temperature=0,
-    timestamp_granularities=None
+    language,
+    prompt,
+    response_format,
+    temperature,
+    timestamp_granularities
 ):
-    """
-    Transcribe audio using Azure OpenAI Whisper API with expanded options.
-
-    Parameters:
-    - audio_file (str): Path to the audio file to transcribe.
-    - language (str, optional): The language of the input audio. Defaults to "en".
-    - prompt (str, optional): An optional text to guide the model's style or continue a previous audio segment.
-    - response_format (str, optional): The format of the transcript output. Options are: "json", "text", "srt", "verbose_json", or "vtt". Defaults to "json".
-    - temperature (float, optional): The sampling temperature, between 0 and 1. Defaults to 0.
-    - timestamp_granularities (list, optional): List of timestamp granularities to include. Options are: "word", "segment". Defaults to None.
-
-    Returns:
-    - If response_format is "json" or "verbose_json", returns a dictionary. Otherwise, returns a string.
-    """
     url = f"{AZURE_OPENAI_ENDPOINT}/openai/deployments/{AZURE_OPENAI_DEPLOYMENT_NAME_WHISPER}/audio/translations?api-version={AZURE_OPENAI_API_VERSION_WHISPER}"

     headers = {
@@ -83,7 +86,78 @@
     else:
         raise Exception(f"Transcription failed: {response.text}")

-def process_file(input_file, language, prompt, temperature, quality, correct):
+def transcribe_audio_openai(
+    client,
+    model,
+    audio_file,
+    language,
+    prompt,
+    response_format,
+    temperature
+):
+    with open(audio_file, "rb") as file:
+        transcription = client.audio.transcriptions.create(
+            file=(audio_file, file.read()),
+            model=model,
+            prompt=prompt,
+            response_format=response_format,
+            language=language,
+            temperature=temperature
+        )
+    return transcription.text
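One caveat on transcribe_audio_openai: with the OpenAI-style SDKs, response_format values other than "json" and "verbose_json" return a plain string rather than a Transcription object, so transcription.text would raise AttributeError. A minimal guard, sketched as a hypothetical helper that is not part of this PR:

```python
def extract_text(transcription):
    # JSON formats yield an object with a .text attribute;
    # "text", "srt", and "vtt" yield a bare str.
    return transcription if isinstance(transcription, str) else transcription.text
```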
+
+def transcribe_audio_rev(audio_file, response_format):
+    client = apiclient.RevAiAPIClient(REVAI_ACCESS_TOKEN)
+    job = client.submit_job_local_file(audio_file)
+
+    if response_format in ["json", "verbose_json"]:
+        return client.get_transcript_json(job.id)
+    else:
+        return client.get_transcript_text(job.id)
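Note that Rev AI jobs are asynchronous: submit_job_local_file returns as soon as the job is queued, and the transcript endpoints error until the job finishes. A minimal polling sketch, assuming the rev_ai SDK's JobStatus enum and get_job_details call (the interval is an arbitrary choice):

```python
import time
from rev_ai.models import JobStatus

def wait_for_rev_job(client, job_id, poll_seconds=5):
    # Block until Rev AI finishes (or fails) the transcription job.
    while True:
        details = client.get_job_details(job_id)
        if details.status == JobStatus.TRANSCRIBED:
            return
        if details.status == JobStatus.FAILED:
            raise RuntimeError(f"Rev AI job {job_id} failed")
        time.sleep(poll_seconds)
```

Calling wait_for_rev_job(client, job.id) before fetching the transcript would avoid racing the job.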
+
+def transcribe_audio(
+    audio_file,
+    api="azure",
+    language="en",
+    prompt=None,
+    response_format="json",
+    temperature=0,
+    timestamp_granularities=None
+):
+    """
+    Transcribe audio using the selected transcription API with expanded options.
+
+    Parameters:
+    - audio_file (str): Path to the audio file to transcribe.
+    - api (str, optional): The API to use. Defaults to "azure".
+    - language (str, optional): The language of the input audio. Defaults to "en".
+    - prompt (str, optional): An optional text to guide the model's style or continue a previous audio segment.
+    - response_format (str, optional): The format of the transcript output. Options are: "json", "text", "srt", "verbose_json", or "vtt". Defaults to "json".
+    - temperature (float, optional): The sampling temperature, between 0 and 1. Defaults to 0.
+    - timestamp_granularities (list, optional): List of timestamp granularities to include. Options are: "word", "segment". Defaults to None.
+
+    Returns:
+    - If response_format is "json" or "verbose_json", returns a dictionary. Otherwise, returns a string.
+    """
+    if api == "azure":
+        return transcribe_audio_azure(audio_file, language, prompt, response_format, temperature, timestamp_granularities)
+    elif api == "openai":
+        client = OpenAI(
+            base_url=OPENAI_ENDPOINT,
+            api_key=OPENAI_API_KEY
+        )
+        model = OPENAI_MODEL
+        return transcribe_audio_openai(client, model, audio_file, language, prompt, response_format, temperature)
+    elif api == "groq":
+        client = Groq()
+        model = GROQ_MODEL
+        return transcribe_audio_openai(client, model, audio_file, language, prompt, response_format, temperature)
+    elif api == "rev":
+        return transcribe_audio_rev(audio_file, response_format)
+    else:
+        raise ValueError("Invalid api. Choose from 'azure', 'openai', 'groq', 'rev'.")
+
+def process_file(input_file, api, language, prompt, temperature, quality, correct):
     input_path = Path(input_file)
     mp3_file = input_path.with_suffix('.mp3')
     txt_file = input_path.with_suffix('.txt')
@@ -96,12 +170,12 @@ def process_file(input_file, language, prompt, temperature, quality, correct):
     else:
         click.echo("MP3 file already exists, skipping conversion")

-    transcription_result = transcribe_audio(str(mp3_file), language=language, prompt=prompt, temperature=temperature)
+    transcription_result = transcribe_audio(str(mp3_file), api=api, language=language, prompt=prompt, temperature=temperature)
     click.echo("Transcription completed")

     if correct:
         system_prompt = "You are a helpful assistant. Your task is to correct any spelling discrepancies in the transcribed text. Make sure that the names of the following products are spelled correctly: {user provided prompt} Only add necessary punctuation such as periods, commas, and capitalization, and use only the context provided."
-        corrected_text = generate_corrected_transcript(0.7, system_prompt, str(mp3_file))
+        corrected_text = generate_corrected_transcript(api, 0.7, system_prompt, str(mp3_file))
         click.echo("Correction completed")
     else:
         corrected_text = transcription_result
@@ -115,18 +189,9 @@ def process_file(input_file, language, prompt, temperature, quality, correct):

     mp3_file.unlink()

-def generate_corrected_transcript(temperature, system_prompt, audio_file):
-    client = AzureOpenAI(
-        api_key=AZURE_OPENAI_API_KEY,
-        azure_endpoint=AZURE_OPENAI_ENDPOINT,
-        api_version=AZURE_OPENAI_API_VERSION_CHAT
-    )
-
-    transcription = transcribe_audio(audio_file, "")
-    transcription_text = transcription.get('text', '') if isinstance(transcription, dict) else transcription
-
+def generate_corrected_transcript_openai(client, model, temperature, system_prompt, audio_file):
     response = client.chat.completions.create(
-        model=AZURE_OPENAI_DEPLOYMENT_NAME_CHAT,
+        model=model,
         temperature=temperature,
         messages=[
             {
@@ -142,26 +207,53 @@ def generate_corrected_transcript(temperature, system_prompt, audio_file):
     return response.choices[0].message.content

+
+def generate_corrected_transcript(api, temperature, system_prompt, audio_file):
+    transcription = transcribe_audio(audio_file, api, "")
+    transcription_text = transcription.get('text', '') if isinstance(transcription, dict) else transcription
+
+    if api == "azure" or api == "rev":
+        client = AzureOpenAI(
+            api_key=AZURE_OPENAI_API_KEY,
+            azure_endpoint=AZURE_OPENAI_ENDPOINT,
+            api_version=AZURE_OPENAI_API_VERSION_CHAT
+        )
+        model = AZURE_OPENAI_DEPLOYMENT_NAME_CHAT
+        return generate_corrected_transcript_openai(client, model, temperature, system_prompt, audio_file)
+    elif api == "openai":
+        client = OpenAI(
+            base_url=OPENAI_ENDPOINT,
+            api_key=OPENAI_API_KEY
+        )
+        model = OPENAI_MODEL_CHAT
+        return generate_corrected_transcript_openai(client, model, temperature, system_prompt, audio_file)
+    elif api == "groq":
+        client = Groq()
+        model = GROQ_MODEL_CHAT
+        return generate_corrected_transcript_openai(client, model, temperature, system_prompt, audio_file)
+    else:
+        raise ValueError("Invalid api. Choose from 'azure', 'openai', 'groq', 'rev'.")
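One detail worth checking: as far as I know, Groq() with no arguments reads the GROQ_API_KEY environment variable, while this PR defines GROQ_OPENAI_API_KEY instead. If so, passing the key explicitly would be safer:

```python
client = Groq(api_key=GROQ_OPENAI_API_KEY)
```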

 @click.command()
 @click.argument("input_path", type=click.Path(exists=True))
+@click.option("--api", default="azure", help="Transcription API to use: azure, openai, groq, or rev (default: azure)")
 @click.option("--language", "-l", default="en", help="Language of the audio (default: en)")
 @click.option("--prompt", "-p", help="Optional prompt to guide the model")
 @click.option("--temperature", "-t", type=float, default=0, help="Sampling temperature (default: 0)")
 @click.option("--quality", "-q", type=click.Choice(['L', 'M', 'H'], case_sensitive=False), default='M', help="Quality of the MP3 audio: 'L' for low, 'M' for medium, and 'H' for high (default: 'M')")
 @click.option("--correct", is_flag=True, help="Use LLM to correct the transcript")
-def main(input_path, language, prompt, temperature, quality, correct):
+def main(input_path, api, language, prompt, temperature, quality, correct):
     """
-    Transcribe video files using Azure OpenAI Whisper API.
+    Transcribe video files using the OpenAI Whisper API.

     INPUT_PATH is the path to the video file or directory containing video files.
     """
     input_path = Path(input_path)

     if input_path.is_file():
-        process_file(input_path, language, prompt, temperature, quality, correct)
+        process_file(input_path, api, language, prompt, temperature, quality, correct)
     elif input_path.is_dir():
         for file in input_path.glob('*.mp4'):
-            process_file(file, language, prompt, temperature, quality, correct)
+            process_file(file, api, language, prompt, temperature, quality, correct)
     else:
         click.echo(f"{input_path} is not a valid file or directory.")
Review comment: do you need this? Have you tested it?
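For reference, assuming the script is saved as transcribe.py (the filename is not shown in this diff), the new --api option would be invoked like:

```
python transcribe.py lecture.mp4 --api groq --quality H --correct
python transcribe.py ./videos --api rev
```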