diff --git a/youte_talk/cli.py b/youte_talk/cli.py index 1f66b27..e6106c4 100644 --- a/youte_talk/cli.py +++ b/youte_talk/cli.py @@ -2,55 +2,108 @@ Copyright: Digital Observatory 2023 Author: Mat Bettinson """ -import click -from sys import exit -import subprocess +from __future__ import annotations + import os +import subprocess import warnings +from sys import exit + +import click + warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") warnings.filterwarnings("ignore", message=".*FP16 is not supported on CPU.*") -warnings.filterwarnings("ignore", message=".*Performing inference on CPU when CUDA is available.*") +warnings.filterwarnings( + "ignore", message=".*Performing inference on CPU when CUDA is available.*" +) + def transcribe(*args): from youte_talk.transcriber import Transcriber + trans = Transcriber(*args) trans.transcribe() - print("Done!") + click.echo("Done!") + def is_ffmpeg_available(): try: # The "which" command is used to find out if a system command exists. # It's used here to check if ffmpeg is available. # This works on Unix/Linux/Mac systems. On Windows, it's "where". - command = 'where' if os.name == 'nt' else 'which' - subprocess.check_output([command, 'ffmpeg']) + command = "where" if os.name == "nt" else "which" + subprocess.check_output([command, "ffmpeg"]) return True except (subprocess.CalledProcessError, FileNotFoundError): return False -whispermodels = ['tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large'] + +whispermodels = [ + "tiny", + "tiny.en", + "base", + "base.en", + "small", + "small.en", + "medium", + "medium.en", + "large", +] + @click.command() -@click.argument('url') -@click.option('--output', default=None, help="File name without an extension (extension determined by format), defaults to videoId") -@click.option("--model", type=click.Choice(whispermodels, case_sensitive=True), default="small", help="Whisper model to use (defaults to small)") -@click.option('--format', type=click.Choice(['text', 'srt', 'csv', 'json'], case_sensitive=False), default="text", help="Format of output file: text, srt or json") -@click.option("--saveaudio", is_flag=True, show_default=True, default=False, help="Save the audio stream to a webm file") -@click.option("--cpu", is_flag=True, show_default=True, default=False, help="Use the CPU instead of the GPU") +@click.argument("url") +@click.option( + "--output", + default=None, + help="File name without an extension (extension determined by format), defaults to videoId", +) +@click.option( + "--model", + type=click.Choice(whispermodels, case_sensitive=True), + default="small", + help="Whisper model to use", + show_default=True, +) +@click.option( + "--format", + type=click.Choice(["text", "srt", "csv", "json"], case_sensitive=False), + default="text", + help="Format of output file: text, srt or json", + show_default=True, +) +@click.option( + "--saveaudio", + is_flag=True, + default=False, + help="Save the audio stream to a webm file", +) +@click.option( + "--cpu", + is_flag=True, + default=False, + help="Use the CPU instead of the GPU", +) def main(url: str, output: str, model: str, format: str, saveaudio: bool, cpu: bool): """YouTalk: A command-line YouTube video transcriber using Whisper""" - if not url.startswith('https://youtube.com/watch?v=') and not url.startswith('https://www.youtube.com/watch?v='): - print("Youtube video URLs look like https://youtube.com/watch?v=...") - exit(1) + if not url.startswith("https://youtube.com/watch?v=") and not url.startswith( + "https://www.youtube.com/watch?v=" + ): + raise click.BadArgumentUsage( + "YouTube video URLs must look like https://youtube.com/watch?v=..." + ) if not is_ffmpeg_available(): - print("\nYoute_Talk requires ffmpeg but it is not available!\n") - print(" - See https://ffmpeg.org/download.html for installation instructions") - print(" - On Windows, you may need to add ffmpeg to your PATH") - print(" - On Linux, you may need to install ffmpeg with your package manager") - exit(1) + click.echo( + "\nYoute_Talk requires ffmpeg but it is not available!\n" + " - See https://ffmpeg.org/download.html for installation instructions\n" + " - On Windows, you may need to add ffmpeg to your PATH\n" + " - On Linux, you may need to install ffmpeg with your package manager\n" + ) + raise click.Abort transcribe(url, output, model, format, saveaudio, cpu) + if __name__ == "__main__": main() diff --git a/youte_talk/scraper.py b/youte_talk/scraper.py index a2b0b58..06dc977 100644 --- a/youte_talk/scraper.py +++ b/youte_talk/scraper.py @@ -2,12 +2,16 @@ Copyright: Digital Observatory 2023 Author: Mat Bettinson """ -import requests -from bs4 import BeautifulSoup +from __future__ import annotations + import json import re from dataclasses import dataclass +import requests +from bs4 import BeautifulSoup + + @dataclass class VideoDetails: videoId: str @@ -19,6 +23,7 @@ class VideoDetails: author: str viewcount: int + def get_VideoDetails(url: str) -> VideoDetails | None: # Send a HTTP request to the given URL res = requests.get(url) @@ -26,37 +31,36 @@ def get_VideoDetails(url: str) -> VideoDetails | None: res.raise_for_status() # Parse HTML response - soup = BeautifulSoup(res.text, 'html.parser') - + soup = BeautifulSoup(res.text, "html.parser") + # Find body tag and then find next sibling script tag body = soup.body assert body is not None - script = body.find('script') + script = body.find("script") # If script tag is not found, return None if script is None: - print('Script tag not found') + print("Script tag not found") return None # Extract JavaScript object using regex - match = re.search(r'\{.*\}', script.text) - + match = re.search(r"\{.*\}", script.text) + # If object is not found, return None if match is None: - print('JavaScript object not found') + print("JavaScript object not found") return None # Decode the JavaScript object into a Python dictionary obj = json.loads(match.group(0)) - vd = obj['videoDetails'] + vd = obj["videoDetails"] return VideoDetails( - videoId=vd['videoId'], - title=vd['title'], - shortDescription=vd['shortDescription'], - lengthSeconds=int(vd['lengthSeconds']), - keywords=vd['keywords'] if 'keywords' in vd else None, - channelId=vd['channelId'], - author=vd['author'], - viewcount=int(vd['viewCount']) + videoId=vd["videoId"], + title=vd["title"], + shortDescription=vd["shortDescription"], + lengthSeconds=int(vd["lengthSeconds"]), + keywords=vd["keywords"] if "keywords" in vd else None, + channelId=vd["channelId"], + author=vd["author"], + viewcount=int(vd["viewCount"]), ) - diff --git a/youte_talk/transcriber.py b/youte_talk/transcriber.py index d819aed..c1ec29c 100644 --- a/youte_talk/transcriber.py +++ b/youte_talk/transcriber.py @@ -2,27 +2,40 @@ Copyright: Digital Observatory 2023 Author: Mat Bettinson """ +from __future__ import annotations + +from datetime import timedelta +from os import path, remove, rename from time import time + +import youtube_dl + from youte_talk.scraper import get_VideoDetails -from os import remove, rename, path from youte_talk.whisper import WhisperTranscribe -import youtube_dl -from datetime import timedelta + class Transcriber: - def __init__(self, url: str, output: str | None = None, model: str = 'small', fformat: str = 'small', saveaudio: bool = False, cpu: bool = False): + def __init__( + self, + url: str, + output: str | None = None, + model: str = "small", + fformat: str = "small", + saveaudio: bool = False, + cpu: bool = False, + ): """Set up a Transcription job""" - print('url:', url) + print("url:", url) self.url = url - self.output = url.split('=')[-1] if output is None else output + self.output = url.split("=")[-1] if output is None else output self.model = model self.format = fformat self.saveaudio = saveaudio - self.fileext = 'txt' if fformat == 'text' else fformat + self.fileext = "txt" if fformat == "text" else fformat self.cpu = cpu def get_elapsed(self, start: float, end: float) -> str: - return str(timedelta(seconds=round(end-start))) + return str(timedelta(seconds=round(end - start))) def transcribe(self): """Execute the transcription job, saving output as appropriate""" @@ -32,12 +45,11 @@ def transcribe(self): print("Failed to get video details") return print("Got metadata in", self.get_elapsed(starttime, time())) - prompt = videodetails.title + ': ' + videodetails.shortDescription # Will feed to Whisper API - ydl_opts = { - 'format': 'worstaudio[ext=webm]', - 'outtmpl': '%(id)s.%(ext)s' - } - sourcefilename = videodetails.videoId + '.webm' + prompt = ( + videodetails.title + ": " + videodetails.shortDescription + ) # Will feed to Whisper API + ydl_opts = {"format": "worstaudio[ext=webm]", "outtmpl": "%(id)s.%(ext)s"} + sourcefilename = videodetails.videoId + ".webm" if not path.isfile(sourcefilename): starttime = time() with youtube_dl.YoutubeDL(ydl_opts) as ydl: @@ -48,25 +60,24 @@ def transcribe(self): print("Transcribing...") transcription = WhisperTranscribe(sourcefilename, prompt, self.model, self.cpu) print("Whisper transcribed in", self.get_elapsed(starttime, time())) - outputfilename = self.output + '.' + self.fileext + outputfilename = self.output + "." + self.fileext outputdata = None match self.format: - case 'text': + case "text": outputdata = transcription.text - case 'srt': + case "srt": outputdata = transcription.srt - case 'json': + case "json": outputdata = transcription.json - case 'csv': + case "csv": outputdata = transcription.csv assert outputdata is not None - with open(outputfilename, 'w') as f: + with open(outputfilename, "w") as f: f.write(outputdata) # clean up if self.saveaudio: - outfile = self.output + '.webm' + outfile = self.output + ".webm" if sourcefilename != outfile: rename(sourcefilename, outfile) else: remove(sourcefilename) - \ No newline at end of file diff --git a/youte_talk/whisper.py b/youte_talk/whisper.py index b0e5c86..b26aeb8 100644 --- a/youte_talk/whisper.py +++ b/youte_talk/whisper.py @@ -2,15 +2,19 @@ Copyright: Digital Observatory 2023 Author: Mat Bettinson """ -import whisper -from dataclasses import dataclass -from typing import TypedDict -import json +from __future__ import annotations + import csv import io +import json +from dataclasses import dataclass from datetime import timedelta +from typing import TypedDict + +import whisper from torch.cuda import is_available + # How hard was that OpenAI? Types. It's not hard. class WhisperSegment(TypedDict): id: int @@ -24,6 +28,7 @@ class WhisperSegment(TypedDict): compression_ratio: float no_speech_prob: float + @dataclass class Segment: id: int @@ -31,100 +36,112 @@ class Segment: end: int text: str + class WhisperTranscribe: - def __init__(self, sourcefile: str, prompt: str | None = None, modelname: str = "small", cpu: bool = False): + def __init__( + self, + sourcefile: str, + prompt: str | None = None, + modelname: str = "small", + cpu: bool = False, + ): self.segmentlist = [] if cpu == False and is_available() == False: print("CUDA is not available. Falling back to CPU.") cpu = True - model = whisper.load_model(modelname, device='cpu' if cpu else 'cuda') + model = whisper.load_model(modelname, device="cpu" if cpu else "cuda") # This doesn't return what OpenAI says it does result = model.transcribe(sourcefile, initial_prompt=prompt) # It returns a dict, and we will store the verbose segments structure - self.segmentlist: list[WhisperSegment] = result['segments'] # type: ignore + self.segmentlist: list[WhisperSegment] = result["segments"] # type: ignore @property def text(self) -> str: - return ''.join([segment['text'] for segment in self.segmentlist]) - + return "".join([segment["text"] for segment in self.segmentlist]) + @property def json(self) -> str: objlist = [] for segment in self.segmentlist: - objlist.append({ - 'id': segment['id'], - 'start': round(segment['start'], 2), - 'end': round(segment['end'], 2), - 'text': segment['text'].strip() - }) + objlist.append( + { + "id": segment["id"], + "start": round(segment["start"], 2), + "end": round(segment["end"], 2), + "text": segment["text"].strip(), + } + ) return json.dumps(objlist, indent=2) - + @property def segments(self) -> list[Segment]: short_segments = [] for segment in self.segmentlist: - short_segments.append(Segment( - id=segment['id'], - start=round(segment['start'], 2), - end=round(segment['end'], 2), - text=segment['text'].strip() - )) + short_segments.append( + Segment( + id=segment["id"], + start=round(segment["start"], 2), + end=round(segment["end"], 2), + text=segment["text"].strip(), + ) + ) return short_segments - + @property def csv(self) -> str | None: try: output = io.StringIO() - writer = csv.DictWriter(output, fieldnames=['id', 'start', 'end', 'text'], quoting=csv.QUOTE_ALL) + writer = csv.DictWriter( + output, fieldnames=["id", "start", "end", "text"], quoting=csv.QUOTE_ALL + ) writer.writeheader() for segment in self.segmentlist: seg = { - 'id': segment['id'], - 'start': round(segment['start'], 2), - 'end': round(segment['end'], 2), - 'text': segment['text'].strip() + "id": segment["id"], + "start": round(segment["start"], 2), + "end": round(segment["end"], 2), + "text": segment["text"].strip(), } # encode special characters in the 'text' field - seg['text'] = seg['text'].encode('unicode_escape').decode('utf-8') + seg["text"] = seg["text"].encode("unicode_escape").decode("utf-8") writer.writerow(seg) return output.getvalue() except Exception as e: print("An error occurred while creating the CSV string.") print(str(e)) return None - + @property def srt(self) -> str: srt_str = "" for i, segment in enumerate(self.segmentlist): # Convert start and end times to hh:mm:ss,sss format - start_time = timedelta(seconds=segment['start']) - end_time = timedelta(seconds=segment['end']) + start_time = timedelta(seconds=segment["start"]) + end_time = timedelta(seconds=segment["end"]) # Convert timedelta to appropriate string format str_start_time = str(start_time).zfill(8) - if '.' in str_start_time: - str_start_time = str_start_time.replace('.', ',') - if len(str_start_time.split(',')[1]) < 3: - str_start_time = str_start_time + '0'*(3-len(str_start_time.split(',')[1])) + if "." in str_start_time: + str_start_time = str_start_time.replace(".", ",") + if len(str_start_time.split(",")[1]) < 3: + str_start_time = str_start_time + "0" * ( + 3 - len(str_start_time.split(",")[1]) + ) else: - str_start_time = str_start_time + ',000' + str_start_time = str_start_time + ",000" str_end_time = str(end_time).zfill(8) - if '.' in str_end_time: - str_end_time = str_end_time.replace('.', ',') - if len(str_end_time.split(',')[1]) < 3: - str_end_time = str_end_time + '0'*(3-len(str_end_time.split(',')[1])) + if "." in str_end_time: + str_end_time = str_end_time.replace(".", ",") + if len(str_end_time.split(",")[1]) < 3: + str_end_time = str_end_time + "0" * ( + 3 - len(str_end_time.split(",")[1]) + ) else: - str_end_time = str_end_time + ',000' + str_end_time = str_end_time + ",000" # Add index, time range, and text to the string srt_str += f"{i+1}\n{str_start_time} --> {str_end_time}\n{segment['text'].strip()}\n\n" return srt_str - - - - - \ No newline at end of file