Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Rev AI support #1

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ dependencies = [
"click",
"requests",
"python-dotenv",
"openai"
"openai",
"groq",
"rev_ai"
]

[project.scripts]
Expand Down
170 changes: 131 additions & 39 deletions script.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from dotenv import load_dotenv
from pathlib import Path
import click
from openai import AzureOpenAI
from openai import OpenAI, AzureOpenAI
from groq import Groq
from rev_ai import apiclient

# Load environment variables
load_dotenv()
Expand All @@ -17,6 +19,21 @@
AZURE_OPENAI_DEPLOYMENT_NAME_CHAT = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME_CHAT')
AZURE_OPENAI_API_VERSION_CHAT = os.getenv("AZURE_OPENAI_API_VERSION_CHAT")

# OpenAI configuration
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
OPENAI_ENDPOINT = os.getenv('OPENAI_ENDPOINT')
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need this? Have you tested it?

OPENAI_MODEL = os.getenv('OPENAI_MODEL')
OPENAI_MODEL_CHAT = os.getenv('OPENAI_MODEL_CHAT')
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

openai_model is whisper-1?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are the values for endpoint, could you please enter examples in .env.example


# Groq OpenAI configuration
GROQ_OPENAI_API_KEY = os.getenv('GROQ_OPENAI_API_KEY')
GROQ_OPENAI_ENDPOINT = os.getenv('GROQ_OPENAI_ENDPOINT')
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are the values for endpoint, could you please enter examples in .env.example

GROQ_MODEL = os.getenv('GROQ_MODEL')
GROQ_MODEL_CHAT = os.getenv('GROQ_MODEL_CHAT')

# Rev AI configuration
REVAI_ACCESS_TOKEN = os.getenv('REVAI_ACCESS_TOKEN')

def convert_to_mp3(input_file, output_file, quality):
"""Convert video to MP3 using ffmpeg with specified quality"""
if quality == 'L':
Expand All @@ -31,28 +48,14 @@ def convert_to_mp3(input_file, output_file, quality):
command = ['ffmpeg', '-i', input_file, '-vn'] + ffmpeg_options + [output_file]
subprocess.run(command, check=True)

def transcribe_audio(
def transcribe_audio_azure(
audio_file,
language="en",
prompt=None,
response_format="json",
temperature=0,
timestamp_granularities=None
language,
prompt,
response_format,
temperature,
timestamp_granularities
):
"""
Transcribe audio using Azure OpenAI Whisper API with expanded options.

Parameters:
- audio_file (str): Path to the audio file to transcribe.
- language (str, optional): The language of the input audio. Defaults to "en".
- prompt (str, optional): An optional text to guide the model's style or continue a previous audio segment.
- response_format (str, optional): The format of the transcript output. Options are: "json", "text", "srt", "verbose_json", or "vtt". Defaults to "json".
- temperature (float, optional): The sampling temperature, between 0 and 1. Defaults to 0.
- timestamp_granularities (list, optional): List of timestamp granularities to include. Options are: "word", "segment". Defaults to None.

Returns:
- If response_format is "json" or "verbose_json", returns a dictionary. Otherwise, returns a string.
"""
url = f"{AZURE_OPENAI_ENDPOINT}/openai/deployments/{AZURE_OPENAI_DEPLOYMENT_NAME_WHISPER}/audio/translations?api-version={AZURE_OPENAI_API_VERSION_WHISPER}"

headers = {
Expand Down Expand Up @@ -83,7 +86,78 @@ def transcribe_audio(
else:
raise Exception(f"Transcription failed: {response.text}")

def process_file(input_file, language, prompt, temperature, quality, correct):
def transcribe_audio_openai(
    client,
    model,
    audio_file,
    language,
    prompt,
    response_format,
    temperature
):
    """
    Transcribe an audio file through an OpenAI-compatible client (OpenAI or Groq).

    Parameters:
    - client: An OpenAI-compatible SDK client exposing `audio.transcriptions.create`.
    - model (str): Model/deployment identifier to transcribe with.
    - audio_file (str): Path to the audio file on disk.
    - language (str): Language code of the input audio.
    - prompt (str or None): Optional text to guide the model's style.
    - response_format (str): Requested transcript format (e.g. "json", "text").
    - temperature (float): Sampling temperature between 0 and 1.

    Returns:
    - str: The `.text` attribute of the transcription response.
    """
    # Load the audio bytes up front; the file name is passed alongside the
    # payload so the API can infer the format from its extension.
    with open(audio_file, "rb") as handle:
        audio_bytes = handle.read()

    result = client.audio.transcriptions.create(
        model=model,
        file=(audio_file, audio_bytes),
        language=language,
        prompt=prompt,
        temperature=temperature,
        response_format=response_format,
    )
    return result.text

def transcribe_audio_rev(audio_file, response_format):
    """
    Transcribe an audio file with the Rev AI asynchronous API.

    Parameters:
    - audio_file (str): Path to the audio file to submit.
    - response_format (str): If "json" or "verbose_json", the structured
      transcript is returned; any other value returns plain text.

    Returns:
    - dict or str: Transcript JSON or transcript text, per `response_format`.

    Raises:
    - Exception: If the Rev AI job finishes in a failed state.
    """
    import time  # local import: only this backend needs a polling delay

    client = apiclient.RevAiAPIClient(REVAI_ACCESS_TOKEN)
    job = client.submit_job_local_file(audio_file)

    # BUG FIX: Rev AI jobs are asynchronous. The original fetched the
    # transcript immediately after submission, which fails while the job is
    # still queued/in progress. Poll job details until it leaves IN_PROGRESS.
    while True:
        details = client.get_job_details(job.id)
        status = details.status.name if hasattr(details.status, "name") else str(details.status)
        if status.upper() != "IN_PROGRESS":
            break
        time.sleep(5)

    if status.upper() == "FAILED":
        raise Exception(
            f"Rev AI transcription failed: {getattr(details, 'failure_detail', 'unknown error')}"
        )

    if response_format in ["json", "verbose_json"]:
        return client.get_transcript_json(job.id)
    else:
        return client.get_transcript_text(job.id)

def transcribe_audio(
    audio_file,
    api="azure",
    language="en",
    prompt=None,
    response_format="json",
    temperature=0,
    timestamp_granularities=None
):
    """
    Transcribe audio using the selected provider's speech-to-text API.

    Parameters:
    - audio_file (str): Path to the audio file to transcribe.
    - api (str, optional): Provider to use: "azure", "openai", "groq", or "rev". Defaults to "azure".
    - language (str, optional): The language of the input audio. Defaults to "en".
    - prompt (str, optional): An optional text to guide the model's style or continue a previous audio segment.
    - response_format (str, optional): The format of the transcript output. Options are: "json", "text", "srt", "verbose_json", or "vtt". Defaults to "json".
    - temperature (float, optional): The sampling temperature, between 0 and 1. Defaults to 0.
    - timestamp_granularities (list, optional): List of timestamp granularities to include. Options are: "word", "segment". Only forwarded to the Azure backend. Defaults to None.

    Returns:
    - If response_format is "json" or "verbose_json", returns a dictionary. Otherwise, returns a string.

    Raises:
    - ValueError: If `api` is not one of the supported providers.
    """
    # BUG FIX: the original called each backend but discarded its result
    # (no `return`), so this function always returned None and callers
    # (e.g. process_file) received no transcript. Each branch now returns.
    if api == "azure":
        return transcribe_audio_azure(audio_file, language, prompt, response_format, temperature, timestamp_granularities)
    elif api == "openai":
        client = OpenAI(
            base_url=OPENAI_ENDPOINT,
            api_key=OPENAI_API_KEY
        )
        return transcribe_audio_openai(client, OPENAI_MODEL, audio_file, language, prompt, response_format, temperature)
    elif api == "groq":
        # The Groq SDK reads GROQ_API_KEY from the environment by default.
        client = Groq()
        return transcribe_audio_openai(client, GROQ_MODEL, audio_file, language, prompt, response_format, temperature)
    elif api == "rev":
        return transcribe_audio_rev(audio_file, response_format)
    else:
        raise ValueError("Invalid api. Choose from 'azure', 'openai', 'groq', 'rev'.")

def process_file(input_file, api, language, prompt, temperature, quality, correct):
input_path = Path(input_file)
mp3_file = input_path.with_suffix('.mp3')
txt_file = input_path.with_suffix('.txt')
Expand All @@ -96,12 +170,12 @@ def process_file(input_file, language, prompt, temperature, quality, correct):
else:
click.echo("MP3 file already exists, skipping conversion")

transcription_result = transcribe_audio(str(mp3_file), language=language, prompt=prompt, temperature=temperature)
transcription_result = transcribe_audio(str(mp3_file), api=api, language=language, prompt=prompt, temperature=temperature)
click.echo("Transcription completed")

if correct:
system_prompt = "You are a helpful assistant. Your task is to correct any spelling discrepancies in the transcribed text. Make sure that the names of the following products are spelled correctly: {user provided prompt} Only add necessary punctuation such as periods, commas, and capitalization, and use only the context provided."
corrected_text = generate_corrected_transcript(0.7, system_prompt, str(mp3_file))
corrected_text = generate_corrected_transcript(api, 0.7, system_prompt, str(mp3_file))
click.echo("Correction completed")
else:
corrected_text = transcription_result
Expand All @@ -115,18 +189,9 @@ def process_file(input_file, language, prompt, temperature, quality, correct):

mp3_file.unlink()

def generate_corrected_transcript(temperature, system_prompt, audio_file):
client = AzureOpenAI(
api_key=AZURE_OPENAI_API_KEY,
azure_endpoint=AZURE_OPENAI_ENDPOINT,
api_version=AZURE_OPENAI_API_VERSION_CHAT
)

transcription = transcribe_audio(audio_file, "")
transcription_text = transcription.get('text', '') if isinstance(transcription, dict) else transcription

def generate_corrected_transcript_openai(client, model, temperature, system_prompt, audio_file):
response = client.chat.completions.create(
model=AZURE_OPENAI_DEPLOYMENT_NAME_CHAT,
model=model,
temperature=temperature,
messages=[
{
Expand All @@ -142,26 +207,53 @@ def generate_corrected_transcript(temperature, system_prompt, audio_file):
return response.choices[0].message.content


def generate_corrected_transcript(api, temperature, system_prompt, audio_file):
    """
    Transcribe an audio file and run an LLM pass to correct the transcript.

    Parameters:
    - api (str): Provider used for transcription: "azure", "openai", "groq", or "rev".
    - temperature (float): Sampling temperature for the correction chat model.
    - system_prompt (str): System prompt instructing the model how to correct the text.
    - audio_file (str): Path to the audio file to transcribe and correct.

    Returns:
    - str: The corrected transcript produced by the chat model.

    Raises:
    - ValueError: If `api` is not one of the supported providers.
    """
    transcription = transcribe_audio(audio_file, api, "")
    # NOTE(review): `transcription_text` is computed but never used — the chat
    # helper is handed the raw audio path instead. Presumably the helper
    # re-derives the text; confirm and eliminate this duplicate transcription.
    transcription_text = transcription.get('text', '') if isinstance(transcription, dict) else transcription

    # BUG FIX: every branch dropped the helper's result (no `return`), so this
    # function always returned None and callers wrote an empty/None transcript.
    # The branches now only select client+model; one shared call returns.
    if api in ("azure", "rev"):
        # Rev AI has no chat model of its own, so its transcripts are
        # corrected via the Azure chat deployment.
        client = AzureOpenAI(
            api_key=AZURE_OPENAI_API_KEY,
            azure_endpoint=AZURE_OPENAI_ENDPOINT,
            api_version=AZURE_OPENAI_API_VERSION_CHAT
        )
        model = AZURE_OPENAI_DEPLOYMENT_NAME_CHAT
    elif api == "openai":
        client = OpenAI(
            base_url=OPENAI_ENDPOINT,
            api_key=OPENAI_API_KEY
        )
        model = OPENAI_MODEL_CHAT
    elif api == "groq":
        client = Groq()
        model = GROQ_MODEL_CHAT
    else:
        raise ValueError("Invalid api. Choose from 'azure', 'openai', 'groq', 'rev'.")

    return generate_corrected_transcript_openai(client, model, temperature, system_prompt, audio_file)

@click.command()
@click.argument("input_path", type=click.Path(exists=True))
@click.option("--api", default="azure", help="The API (default: azure)")
@click.option("--language", "-l", default="en", help="Language of the audio (default: en)")
@click.option("--prompt", "-p", help="Optional prompt to guide the model")
@click.option("--temperature", "-t", type=float, default=0, help="Sampling temperature (default: 0)")
@click.option("--quality", "-q", type=click.Choice(['L', 'M', 'H'], case_sensitive=False), default='M', help="Quality of the MP3 audio: 'L' for low, 'M' for medium, and 'H' for high (default: 'M')")
@click.option("--correct", is_flag=True, help="Use LLM to correct the transcript")
def main(input_path, api, language, prompt, temperature, quality, correct):
    """
    Transcribe video files using the OpenAI Whisper API.

    INPUT_PATH is the path to the video file or directory containing video files.
    """
    # Guard-clause style: handle each path kind and bail out early.
    target = Path(input_path)

    if target.is_file():
        process_file(target, api, language, prompt, temperature, quality, correct)
        return

    if target.is_dir():
        # Only .mp4 files in the directory are picked up (non-recursive).
        for video in target.glob('*.mp4'):
            process_file(video, api, language, prompt, temperature, quality, correct)
        return

    click.echo(f"{target} is not a valid file or directory.")

Expand Down