diff --git a/README.md b/README.md
index 81bec5b..22f5171 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ This collection of scripts is the culmination of my efforts to contributes the A
 ## Tools
 
-### Illustrator 
+### Illustrator
 Creates custom mnemonic images for your cards using AI image generation. It:
 - Analyzes card content to identify key concepts
 - Generates creative visual memory hooks
@@ -387,14 +387,40 @@ Dataset files (like `explainer_dataset.txt`, `reformulator_dataset.txt`, etc.) a
 Click to read more
 </summary>
 
+First, ensure that your API keys are set in your env variables.
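+
+For example, to use an OpenAI model you could set (hypothetical key value):
+
+```bash
+export OPENAI_API_KEY="sk-..."
+```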
+
+Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it.
+
+
 #### Reformulator
+
+The Reformulator expects the notes you modify to have a specific field present so that it can save the old versions and add logging. Modify the note type you want to reformulate by adding an `AnkiReformulator` field to it.
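+
+If you pass `--string_formatting`, the file at that path should provide the two parsing hooks used to clean notes before and after the LLM call. A minimal sketch, assuming the function names mirror the defaults in `utils/cloze_utils.py`:
+
+```python
+def cloze_input_parser(cloze: str) -> str:
+    # normalize non-breaking spaces before the note is sent to the LLM
+    return cloze.replace("\xa0", " ")
+
+
+def cloze_output_parser(cloze: str) -> str:
+    # convert the LLM's newlines back to Anki's <br>
+    return cloze.strip().replace("\n", "<br>")
+```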
 
 The Reformulator can be run from the command line:
 ```bash
 python reformulator.py \
-    --query "(rated:2:1 OR rated:2:2) -is:suspended" \
-    --dataset_path "data/reformulator_dataset.txt" \
-    --string_formatting "data/string_formatting.py" \
+    --query "note:Cloze (rated:2:1 OR rated:2:2) -is:suspended" \
+    --dataset_path "examples/reformulator_dataset.txt" \
+    --string_formatting "examples/string_formatting.py" \
     --ntfy_url "ntfy.sh/YOUR_TOPIC" \
     --main_field_index 0 \
     --llm "openai/gpt-4" \
diff --git a/reformulator.py b/reformulator.py
index 6bd554c..25ae566 100644
--- a/reformulator.py
+++ b/reformulator.py
@@ -31,7 +31,7 @@ import litellm
 
 from utils.misc import load_formatting_funcs, replace_media
-from utils.llm import load_api_keys, llm_price, tkn_len, chat, model_name_matcher
+from utils.llm import llm_price, tkn_len, chat, model_name_matcher
 from utils.anki import anki, sync_anki, addtags, removetags, updatenote
 from utils.logger import create_loggers
 from utils.datasets import load_dataset, semantic_prompt_filtering
@@ -51,8 +51,6 @@
 d = datetime.datetime.today()
 today = f"{d.day:02d}_{d.month:02d}_{d.year:04d}"
 
-load_api_keys()
-
 # status string
 STAT_CHANGED_CONT = "Content has been changed"
 
@@ -184,7 +182,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
     [print(line) for line in traceback.format_tb(exc_traceback)]
     print(str(exc_value))
     print(str(exc_type))
-    print("\n--verbose was used so opening debug console at the "
+    print("\n--debug was used so opening debug console at the "
          "appropriate frame. Press 'c' to continue to the frame "
          "of this print.")
    pdb.post_mortem(exc_traceback)
@@ -209,7 +207,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
        litellm.set_verbose = verbose
 
        # arg sanity check and storing
-        assert "note:" in query, "You have to specify a notetype in the query"
+        assert "note:" in query, f"You have to specify a notetype in the query ({query})"
        assert mode in ["reformulate", "reset"], "Invalid value for 'mode'"
        assert isinstance(exclude_done, bool), "exclude_done must be a boolean"
        assert isinstance(exclude_version, bool), "exclude_version must be a boolean"
@@ -224,8 +222,13 @@ def handle_exception(exc_type, exc_value, exc_traceback):
        parallel = int(parallel)
        main_field_index = int(main_field_index)
        assert main_field_index >= 0, "invalid field_index"
+        self.base_query = query
+        self.dataset_path = dataset_path
        self.mode = mode
-        if string_formatting is not None:
+        self.exclude_done = exclude_done
+        self.exclude_version = exclude_version
+
+        if string_formatting:
            red(f"Loading specific string formatting from {string_formatting}")
            cloze_input_parser, cloze_output_parser = load_formatting_funcs(
                path=string_formatting,
@@ -256,29 +259,38 @@ def handle_exception(exc_type, exc_value, exc_traceback):
        else:
            raise Exception(f"{llm} not found in llm_price")
        self.verbose = verbose
-        if mode == "reformulate":
-            if exclude_done:
+
+    def reformulate(self):
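+        """Run one full pass: build the query, load the db and dataset,
+        sanity-check the matched notes, then apply the changes."""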
+        query = self.base_query
+        if self.mode == "reformulate":
+            if self.exclude_done:
                query += " -AnkiReformulator::Done::*"
-            if exclude_version:
+            if self.exclude_version:
                query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\""
 
-        # load db just in case
+        # load db just in case, and create one if it doesn't already exist
        self.db_content = self.load_db()
        if not self.db_content:
-            red(
-                "Empty database. If you have already ran anki_reformulator "
-                "before then something went wrong!"
-            )
-        else:
-            self.compute_cost(self.db_content)
+            red("Empty database. If you have already run anki_reformulator "
+                "before then something went wrong!")
+            whi("Creating an empty database")
+            self.save_to_db({})
+            self.db_content = self.load_db()
+            assert self.db_content, "Could not create database"
+
+        whi("Computing estimated costs")
+        self.compute_cost(self.db_content)
 
        # load dataset
-        dataset = load_dataset(dataset_path)
-        # check that each note is valid but exclude the system prompt
-        for id, d in enumerate(dataset):
-            if id != 0:
-                dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"]
+        whi("Loading dataset")
+        dataset = load_dataset(self.dataset_path)
+        # check that each note is valid but exclude the system prompt, which is
+        # the first entry
+        for id, d in enumerate(dataset[1:], start=1):
+            dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"]
 
        assert len(dataset) % 2 == 1, "Even number of examples in dataset"
        self.dataset = dataset
@@ -286,26 +296,25 @@ def handle_exception(exc_type, exc_value, exc_traceback):
 
        nids = anki(action="findNotes", query="tag:AnkiReformulator::RESETTING")
        if nids:
-            red(
-                f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}"
-            )
+            red(f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}")
 
        nids = anki(action="findNotes", query="tag:AnkiReformulator::DOING")
        if nids:
            red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}")
 
-        # find notes ids for the first time
+        # find notes ids for the specific note type
        nids = anki(action="findNotes", query=query)
        assert nids, f"No notes found for the query '{query}'"
 
-        # find the model field names
-        fields = anki(
-            action="notesInfo",
-            notes=[int(nids[0])]
-        )[0]["fields"]
-        assert (
-            "AnkiReformulator" in fields.keys()
-        ), "The notetype to edit must have a field called 'AnkiReformulator'"
-        self.field_name = list(fields.keys())[0]
+        # find the field names for this note type
+        fields = anki(action="notesInfo",
+                      notes=[int(nids[0])])[0]["fields"]
+        assert "AnkiReformulator" in fields.keys(), \
+            "The notetype to edit must have a field called 'AnkiReformulator'"
+        try:
+            self.field_name = list(fields.keys())[self.field_index]
+        except IndexError:
+            raise AssertionError(f"main_field_index {self.field_index} is invalid. "
+                                 f"Note only has {len(fields.keys())} fields!")
 
        if self.exclude_media:
            # now find notes ids after excluding the img in the important field
@@ -316,7 +325,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
            query += f' -{self.field_name}:"*http://*"'
            query += f' -{self.field_name}:"*https://*"'
 
-        whi(f"Query to find note: {query}")
+        whi(f"Query to find note: '{query}'")
        nids = anki(action="findNotes", query=query)
        assert nids, f"No notes found for the query '{query}'"
        whi(f"Found {len(nids)} notes")
@@ -326,14 +335,13 @@ def handle_exception(exc_type, exc_value, exc_traceback):
            anki(action="notesInfo", notes=nids)
        ).set_index("noteId")
        self.notes = self.notes.loc[nids]
-        assert not self.notes.empty, "Empty notes df"
+        assert not self.notes.empty, "Empty notes"
 
-        assert (
-            len(set(self.notes["modelName"].tolist())) == 1
-        ), "Contains more than 1 note type"
+        assert len(set(self.notes["modelName"].tolist())) == 1, \
+            "Contains more than 1 note type"
 
        # check absence of image and sounds in the main field
-        # as well incorrect tags
+        # as well as incorrect tags
        for nid, note in self.notes.iterrows():
            if self.exclude_media:
                _, media = replace_media(
@@ -358,38 +366,24 @@ def handle_exception(exc_type, exc_value, exc_traceback):
            else:
                assert not tag.lower().startswith("ankireformulator")
 
+        # check if required tokens are higher than our limits
+        tkn_sum = sum(tkn_len(d["content"]) for d in self.dataset)
+        tkn_sum += sum(tkn_len(replace_media(content=note["fields"][self.field_name]["value"],
+                                             media=None,
+                                             mode="remove_media")[0])
+                       for _, note in self.notes.iterrows())
+        assert tkn_sum <= tkn_warn_limit, (f"Found {tkn_sum} tokens to process, which is "
+                                           f"higher than the limit of {tkn_warn_limit}")
-        # check if too many tokens
-        tkn_sum = sum([tkn_len(d["content"]) for d in self.dataset])
-        tkn_sum += sum(
-            [
-                tkn_len(
-                    replace_media(
-                        content=note["fields"][self.field_name]["value"],
-                        media=None,
-                        mode="remove_media",
-                    )[0]
-                )
-                for _, note in self.notes.iterrows()
-            ])
-        if tkn_sum > tkn_warn_limit:
-            raise Exception(
-                f"Found {tkn_sum} tokens to process, which is "
-                f"higher than the limit of {tkn_warn_limit}"
-            )
-
-        if len(self.notes) > n_note_limit:
-            raise Exception(
-                f"Found {len(self.notes)} notes to process "
-                f"which is higher than the limit of {n_note_limit}"
-            )
+        assert len(self.notes) <= n_note_limit, (f"Found {len(self.notes)} notes to process "
+                                                 f"which is higher than the limit of {n_note_limit}")
 
        if self.mode == "reformulate":
-            func = self.reformulate
+            func = self.reformulate_note
        elif self.mode == "reset":
-            func = self.reset
+            func = self.reset_note
        else:
-            raise ValueError(self.mode)
+            raise ValueError(f"Unknown mode {self.mode}")
 
        def error_wrapped_func(*args, **kwargs):
            """Wrapper that catches exceptions and marks failed notes with appropriate tags."""
@@ -397,7 +391,7 @@ def error_wrapped_func(*args, **kwargs):
                return func(*args, **kwargs)
            except Exception as err:
                addtags(nid=note.name, tags="AnkiReformulator::FAILED")
-                red(f"Error when running self.{self.mode}: '{err}'")
+                red(f"Error when running self.{func.__name__}: '{err}'")
                return str(err)
 
        # getting all the new values in parallel and using caching
@@ -413,11 +407,9 @@ def error_wrapped_func(*args, **kwargs):
            )
        )
 
-        failed_runs = [
-            self.notes.iloc[i_nv]
-            for i_nv in range(len(new_values))
-            if isinstance(new_values[i_nv], str)
-        ]
+        failed_runs = [self.notes.iloc[i_nv]
+                       for i_nv in range(len(new_values))
+                       if isinstance(new_values[i_nv], str)]
        if failed_runs:
            red(f"Found {len(failed_runs)} failed notes")
            failed_run_index = pd.DataFrame(failed_runs).index
@@ -427,6 +419,7 @@ def error_wrapped_func(*args, **kwargs):
        assert len(new_values) == len(self.notes)
 
        # applying the changes
+        whi("Applying changes")
        for values in tqdm(new_values, desc="Applying changes to anki"):
            if self.mode == "reformulate":
                self.apply_reformulate(values)
@@ -435,8 +428,10 @@ def error_wrapped_func(*args, **kwargs):
            else:
                raise ValueError(self.mode)
 
+        whi("Clearing unused tags")
        anki(action="clearUnusedTags")
 
+        # TODO: Why add and then remove them?
        # add and remove the tag TODO to make it easier to re add by the user
        # as it was cleared by calling 'clearUnusedTags'
        nid, note = next(self.notes.iterrows())
@@ -445,7 +440,7 @@ def error_wrapped_func(*args, **kwargs):
 
        sync_anki()
 
-        # display again the total cost at the end
+        # display the total cost again at the end
        db = self.load_db()
        assert db, "Empty database at the end of the run. Something went wrong?"
        self.compute_cost(db)
@@ -456,11 +451,11 @@ def compute_cost(self, db_content: List[Dict]) -> None:
        This is used to know if something went wrong.
        """
        n_db = len(db_content)
-        red(f"Number of entries in databases/reformulator.db: {n_db}")
+        red(f"Number of entries in databases/reformulator/reformulator.db: {n_db}")
        dol_costs = []
        dol_missing = 0
        for dic in db_content:
-            if dic["mode"] != "reformulate":
+            if self.mode != "reformulate":
                continue
            try:
                dol = float(dic["dollar_price"])
@@ -482,16 +477,16 @@ def compute_cost(self, db_content: List[Dict]) -> None:
        elif dol_costs:
            self._cost_so_far = dol_total
 
-    def reformulate(self, nid: int, note: pd.Series) -> Dict:
+    def reformulate_note(self, nid: int, note: pd.Series) -> Dict:
        """Generate a reformulated version of a note's content using an LLM.
-        
+
        Parameters
        ----------
        nid : int
            Note ID from Anki
        note : pd.Series
            Row from the notes DataFrame containing the note data
-        
+
        Returns
        -------
        Dict
@@ -512,7 +507,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
        # reformulate the content
        content = note["fields"][self.field_name]["value"]
        log["note_field_content"] = content
-        formattedcontent = self.cloze_input_parser(content) if iscloze(content) else content
+        formattedcontent = self.cloze_input_parser(content)
        log["note_field_formattedcontent"] = formattedcontent
 
        # if the card is in the dataset, just take the dataset value directly
@@ -535,19 +530,21 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
                elif d["role"] == "user":
                    newcontent = self.dataset[i + 1]["content"]
                else:
-                    raise ValueError(
-                        f"Unexpected role of message in dataset: {d}")
+                    raise ValueError(f"Unexpected role of message in dataset: {d}")
                skip_llm = True
                break
 
        fc, media = replace_media(
            content=formattedcontent,
            media=None,
-            mode="remove_media",
-        )
+            mode="remove_media")
        log["media"] = media
 
-        if not skip_llm:
+        if skip_llm:
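+            # the note matched a dataset example, so reuse that answer and skip the LLM call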
+            log["llm_answer"] = {"Skipped": True}
+            log["dollar_price"] = 0
+        else:
            dataset = copy.deepcopy(self.dataset)
            curr_mess = [{"role": "user", "content": fc}]
            dataset = semantic_prompt_filtering(
@@ -559,8 +555,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
                embedding_model=self.embedding_model,
                whi=whi,
                yel=yel,
-                red=red,
-            )
+                red=red)
            dataset += curr_mess
 
            assert dataset[0]["role"] == "system", "First message is not from system!"
@@ -603,16 +598,13 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
            )
        else:
            log["dollar_price"] = "?"
-        else:
-            log["llm_answer"] = {"Skipped": True}
-            log["dollar_price"] = 0
 
        log["note_field_newcontent"] = newcontent
-        formattednewcontent = self.cloze_output_parser(newcontent) if iscloze(newcontent) else newcontent
+        formattednewcontent = self.cloze_output_parser(newcontent)
        log["note_field_formattednewcontent"] = formattednewcontent
        log["status"] = STAT_OK_REFORM
 
-        if iscloze(content + newcontent + formattednewcontent):
+        if iscloze(content) and iscloze(newcontent + formattednewcontent):
            # check that no cloze were lost
            for cl in getclozes(content):
                cl = cl.split("::")[0] + "::"
@@ -628,7 +620,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
 
    def apply_reformulate(self, log: Dict) -> None:
        """Apply reformulation changes to an Anki note and update its metadata.
-        
+
        Parameters
        ----------
        log : Dict
@@ -651,7 +643,8 @@ def apply_reformulate(self, log: Dict) -> None:
 
        new_minilog = rtoml.dumps(minilog, pretty=True)
        new_minilog = new_minilog.strip().replace("\n", "<br>")
-        previous_minilog = note["fields"]["AnkiReformulator"]["value"].strip()
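+        # tolerate notes where the AnkiReformulator field is empty or missing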
+        previous_minilog = note["fields"].get("AnkiReformulator", {}).get("value", "").strip()
        if previous_minilog:
            new_minilog += "<!--SEPARATOR-->"
            new_minilog += "<br><br><details><summary>Older minilog</summary>"
@@ -681,6 +673,7 @@ def apply_reformulate(self, log: Dict) -> None:
            nid,
            fields={
                self.field_name: log["note_field_formattednewcontent"],
+                # TODO: Might be nice to not require this
                "AnkiReformulator": new_minilog,
            },
        )
@@ -696,16 +689,16 @@ def apply_reformulate(self, log: Dict) -> None:
 
        # remove DOING tag
        removetags(nid, "AnkiReformulator::DOING")
 
-    def reset(self, nid: int, note: pd.Series) -> Dict:
+    def reset_note(self, nid: int, note: pd.Series) -> Dict:
        """Reset a note back to its state before reformulation.
-        
+
        Parameters
        ----------
        nid : int
            Note ID from Anki
        note : pd.Series
            Row from the notes DataFrame containing the note data
-        
+
        Returns
        -------
        Dict
@@ -736,18 +729,14 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
        ]
 
        if not entries:
-            red(
-                f"Entry not found for note {nid}. Looking for the content of "
-                "the field AnkiReformulator"
-            )
+            red(f"Entry not found for note {nid}. Looking for the content of "
+                "the field AnkiReformulator")
            logfield = note["fields"]["AnkiReformulator"]["value"]
            logfield = logfield.split(
                "<!--SEPARATOR-->")[0]  # keep most recent
            if not logfield.strip():
-                raise Exception(
-                    f"Note {nid} was not found in the db and its "
-                    "AnkiReformulator field was empty."
-                )
+                raise Exception(f"Note {nid} was not found in the db and its "
+                                "AnkiReformulator field was empty.")
 
            # replace the [[c1::cloze]] by {{c1::cloze}}
            logfield = logfield.replace("]]", "}}")
@@ -757,7 +746,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
 
        # parse old content
        buffer = []
-        for i, line in enumerate(logfield.split("<br>")):
+        for line in logfield.split("<br>"):
            if buffer:
                try:
                    _ = rtoml.loads("".join(buffer + [line]))
@@ -776,10 +765,12 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
 
        # parse new content at the time
        buffer = []
-        for i, line in enumerate(logfield.split("<br>")):
+        for line in logfield.split("<br>"):
            if buffer:
                try:
-                    _ = rtoml.loads("".join(buffer + [line]))
+                    # TODO: What are you trying to do here? Just check that adding the line keeps valid toml?
+                    # If so, you should catch the specific exception that the load function raises on error
+                    rtoml.loads("".join(buffer + [line]))
                    buffer.append(line)
                    continue
                except Exception:
@@ -879,7 +870,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
 
    def apply_reset(self, log: Dict) -> None:
        """Apply reset changes to an Anki note and update its metadata.
-        
+
        Parameters
        ----------
        log : Dict
@@ -933,10 +924,8 @@ def apply_reset(self, log: Dict) -> None:
 
        # remove TO_RESET tag if present
        removetags(nid, "AnkiReformulator::TO_RESET")
-        # remove Done tag
        removetags(nid, "AnkiReformulator::Done")
-        # remove DOING tag
        removetags(nid, "AnkiReformulator::RESETTING")
@@ -946,12 +935,12 @@ def apply_reset(self, log: Dict) -> None:
 
    def save_to_db(self, dictionnary: Dict) -> bool:
        """Save a log dictionary to the SQLite database.
-        
+
        Parameters
        ----------
        dictionnary : Dict
            Log dictionary to save
-        
+
        Returns
        -------
        bool
@@ -976,34 +965,35 @@ def save_to_db(self, dictionnary: Dict) -> bool:
 
    def load_db(self) -> Dict:
        """Load all log dictionaries from the SQLite database.
-        
+
        Returns
        -------
        Dict
            All log dictionaries from the database, or False if database not found
        """
        if not (REFORMULATOR_DIR / "reformulator.db").exists():
-            red("db not found: '$REFORMULATOR_DIR/reformulator.db'")
+            red(f"db not found: '{REFORMULATOR_DIR}/reformulator.db'")
            return False
        conn = sqlite3.connect(str((REFORMULATOR_DIR / "reformulator.db").absolute()))
        cursor = conn.cursor()
        cursor.execute("SELECT data FROM dictionaries")
        rows = cursor.fetchall()
-        dictionaries = []
-        for row in rows:
-            dictionary = json.loads(zlib.decompress(row[0]))
-            dictionaries.append(dictionary)
-        return dictionaries
+        # TODO: Why do you compress? This just makes it more difficult to debug
+        return [json.loads(zlib.decompress(row[0])) for row in rows]
 
 
 if __name__ == "__main__":
    try:
        args, kwargs = fire.Fire(lambda *args, **kwargs: [args, kwargs])
        if "help" in kwargs:
-            print(help(AnkiReformulator))
+            print(help(AnkiReformulator), file=sys.stderr)
        else:
            whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'")
-            AnkiReformulator(*args, **kwargs)
-    except Exception:
-        sync_anki()
+            r = AnkiReformulator(*args, **kwargs)
+            r.reformulate()
+            sync_anki()
+    except AssertionError as e:
+        red(e)
+    except Exception as e:
+        red(e)
        raise
diff --git a/utils/cloze_utils.py b/utils/cloze_utils.py
index e9a1a6c..52782c3 100644
--- a/utils/cloze_utils.py
+++ b/utils/cloze_utils.py
@@ -18,14 +18,16 @@ def iscloze(text: str) -> bool:
 
 def getclozes(text: str) -> List[str]:
    "return the cloze found in the text. Should only be called on cloze notes"
-    assert iscloze(text)
+    assert iscloze(text), f"Text '{text}' does not contain a cloze"
    return re.findall(CLOZE_REGEX, text)
 
 
 def cloze_input_parser(cloze: str) -> str:
-    """edits the cloze from anki before sending it to the LLM. This is useful
-    if you use weird formatting that mess with LLMs"""
-    assert iscloze(cloze), f"Invalid cloze: {cloze}"
+    """edit the cloze from anki before sending it to the LLM. This is useful
+    if you use weird formatting that messes with LLMs.
+    If the note content is not a cloze, then return it unmodified."""
+    if not iscloze(cloze):
+        return cloze
 
    cloze = cloze.replace("\xa0", " ")
@@ -37,7 +39,6 @@ def cloze_input_parser(cloze: str) -> str:
    # make spaces consitent
    cloze = cloze.replace("  ", " ")
 
-    # misc
    cloze = cloze.replace("&gt;", ">")
    cloze = cloze.replace("≥", ">=")
@@ -57,9 +58,12 @@ def cloze_output_parser(cloze: str) -> str:
    cloze = cloze.strip()
 
    # make sure all newlines are consistent for now
+    # TODO: You mean <br/>?
    cloze = cloze.replace("</br>", "<br>")
+    cloze = cloze.replace("<br/>", "<br>")
    cloze = cloze.replace("\r", "<br>")
-    cloze = cloze.replace("<br>", "\n")
+    # TODO: Not needed
+    # cloze = cloze.replace("<br>", "\n")
 
    # make sure all spaces are consistent
    cloze = cloze.replace("  ", " ")
@@ -68,4 +72,3 @@ def cloze_output_parser(cloze: str) -> str:
    cloze = cloze.replace("\n", "<br>")
 
    return cloze
-
diff --git a/utils/datasets.py b/utils/datasets.py
index df46e84..6ba4278 100644
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -1,3 +1,4 @@
+import collections
 import json
 import pandas as pd
 import litellm
@@ -51,9 +52,10 @@ def load_dataset(
 
    Returns
    -------
-    Dict
+    List
        List of message dictionaries with 'role' and 'content' keys,
-        validated according to check_dataset() rules
+        validated according to check_dataset() rules.
+        First message is the system message.
 
    Raises
    ------
@@ -366,25 +368,21 @@ def semantic_prompt_filtering(
    if len(output_pr) != len(prompt_messages):
        red(f"Tokens of the kept prompts after {cnt} iterations: {tkns} (of all prompts: {all_tkns} tokens) Number of prompts: {len(output_pr)}/{len(prompt_messages)}")
 
-    # check no duplicate prompts
+    # TODO: This complains about duplicates. It looks like for some reason the
+    # last one is added as assistant AND user, but we only compare the content.
    contents = [pm["content"] for pm in output_pr]
-    dupli = [dp for dp in contents if contents.count(dp) > 1]
+    dupli = [k for k, v in collections.Counter(contents).items() if v > 1]
    if dupli:
        raise Exception(f"{len(dupli)} duplicate prompts found in memory.py: {dupli}")
 
-    # remove unwanted keys
-    for i, d in enumerate(output_pr):
-        keys = [k for k in d.keys()]
-        for k in keys:
-            if k not in ["content", "role"]:
-                del d[k]
-        output_pr[i] = d
+    # Keep only the content and the role keys for each prompt
+    new_output = [{k: v for k, v in pk.items() if k in {"content", "role"}} for pk in output_pr]
 
-    assert curr_mess not in output_pr
-    assert output_pr, "No prompt were selected!"
-    check_dataset(output_pr, **check_args)
+    assert curr_mess not in new_output
+    assert new_output, "No prompts were selected!"
+    check_dataset(new_output, **check_args)
 
-    return output_pr
+    return new_output
 
 
 def format_anchor_key(key: str) -> str:
diff --git a/utils/llm.py b/utils/llm.py
index 5188fe3..b399a4c 100644
--- a/utils/llm.py
+++ b/utils/llm.py
@@ -13,48 +13,27 @@
 litellm.drop_params = True
 
-def load_api_keys() -> Dict:
-    """Load API keys from files in the API_KEYS directory.
-
-    Creates API_KEYS directory if it doesn't exist.
-    Each file in API_KEYS/ should contain a single API key.
-    The filename (without extension) becomes part of the environment variable name.
-
-    Returns
-    -------
-    Dict
-        Dictionary mapping environment variable names to API key values
-    """
-    Path("API_KEYS").mkdir(exist_ok=True)
-    if not list(Path("API_KEYS").iterdir()):
-        shared.red("## No API_KEYS found in API_KEYS")
-    api_keys = {}
-    for apifile in Path("API_KEYS").iterdir():
-        keyname = f"{apifile.stem.upper()}_API_KEY"
-        key = apifile.read_text().strip()
-        os.environ[keyname] = key
-        api_keys[keyname] = key
-    return api_keys
-
 llm_price = {}
 for k, v in litellm.model_cost.items():
    llm_price[k] = v
 
-embedding_models = [
-    "openai/text-embedding-3-large",
-    "openai/text-embedding-3-small",
-    "mistral/mistral-embed",
-    ]
+embedding_models = ["openai/text-embedding-3-large",
+                    "openai/text-embedding-3-small",
+                    "mistral/mistral-embed"]
 
 # steps : price
-sd_price = {
-    "15": 0.001,
-    "30": 0.002,
-    "50": 0.004,
-    "100": 0.007,
-    "150": "0.01",
-}
+sd_price = {"15": 0.001,
+            "30": 0.002,
+            "50": 0.004,
+            "100": 0.007,
+            "150": 0.01}
+
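+# note: a single gpt-3.5-turbo tokenizer is used for every model, so token
+# counts are only approximate for non-OpenAI models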
+tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+llm_cache = Memory(".cache", verbose=0)
+
 
 def llm_cost_compute(
    input_cost: int,
@@ -79,9 +56,6 @@ def llm_cost_compute(
    return input_cost * price["input_cost_per_token"] + output_cost * price["output_cost_per_token"]
 
 
-tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-
 def tkn_len(message: Union[str, List[Union[str, Dict]], Dict]):
    if isinstance(message, str):
        return len(tokenizer.encode(dedent(message)))
@@ -90,7 +64,6 @@ def tkn_len(message: Union[str, List[Union[str, Dict]], Dict]):
    elif isinstance(message, list):
        return sum([tkn_len(subel) for subel in message])
 
-llm_cache = Memory(".cache", verbose=0)
 
 @llm_cache.cache
 def chat(
@@ -112,6 +85,7 @@ def chat(
    assert all(a["finish_reason"] == "stop" for a in answer["choices"]), f"Found bad finish_reason: '{answer}'"
    return answer
 
+
 def wrapped_model_name_matcher(model: str) -> str:
    "find the best match for a modelname (wrapped to make some check)"
    # find the currently set api keys to avoid matching models from
@@ -147,10 +121,10 @@ def wrapped_model_name_matcher(model: str) -> str:
        return match[0]
    else:
        print(f"Couldn't match the modelname {model} to any known model. "
-              "Continuing but this will probably crash DocToolsLLM further "
-              "down the code.")
+              "Continuing but this will probably crash further down the code.")
        return model
 
+
 def model_name_matcher(model: str) -> str:
    """find the best match for a modelname (wrapper that checks if the
    matched model has a known cost and print the matched name)"""
diff --git a/utils/logger.py b/utils/logger.py
index d2f3db3..c6ff952 100644
--- a/utils/logger.py
+++ b/utils/logger.py
@@ -85,6 +85,6 @@ def create_loggers(local_file: Union[str, PosixPath], colors: List[str]):
    out = []
    for col in colors:
        log = coloured_logger(col)
-        setattr(shared, "col", log)
+        setattr(shared, col, log)
        out.append(log)
    return out
diff --git a/utils/shared.py b/utils/shared.py
index 757dd9a..6fafacd 100644
--- a/utils/shared.py
+++ b/utils/shared.py
@@ -39,4 +39,3 @@ def __setattr__(self, name: str, value) -> None:
        raise TypeError(f'SharedModule forbids the creation of unexpected attribute "{name}"')
 
 shared = SharedModule()
-