diff --git a/README.md b/README.md
index 81bec5b..22f5171 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ This collection of scripts is the culmination of my efforts to contribute the A
## Tools
-### Illustrator
+### Illustrator
Creates custom mnemonic images for your cards using AI image generation. It:
- Analyzes card content to identify key concepts
- Generates creative visual memory hooks
@@ -387,14 +387,21 @@ Dataset files (like `explainer_dataset.txt`, `reformulator_dataset.txt`, etc.) a
<summary>Click to read more</summary>
+First, ensure that your API keys are set in your environment variables.
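+
+For example, on Linux/macOS you can export the key before running the scripts. This is just an illustration assuming you use OpenAI; set whichever variable your provider/litellm expects:
+
+```bash
+export OPENAI_API_KEY="sk-..."  # hypothetical placeholder, use your real key
+```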
+
+Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it.
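+
+With Anki open, you can check that AnkiConnect is reachable (it listens on localhost:8765 by default):
+
+```bash
+# quick sanity check; should print {"result": 6, "error": null}
+curl -s localhost:8765 -X POST -d '{"action": "version", "version": 6}'
+```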
+
+
#### Reformulator
+
+The Reformulator expects the notes you modify to have a specific field present so that it can save the old versions and log its changes. Add an `AnkiReformulator` field to the note type you want to reformulate.
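+
+You can add the field in Anki itself (Tools > Manage Note Types > Fields > Add) or, as a sketch, through AnkiConnect's `modelFieldAdd` action (recent AnkiConnect versions; "Cloze" is only an example note type name):
+
+```bash
+# assumes your note type is named "Cloze"; adjust modelName accordingly
+curl -s localhost:8765 -X POST -d '{"action": "modelFieldAdd", "version": 6, "params": {"modelName": "Cloze", "fieldName": "AnkiReformulator", "index": 1}}'
+```
+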
The Reformulator can be run from the command line:
```bash
python reformulator.py \
- --query "(rated:2:1 OR rated:2:2) -is:suspended" \
- --dataset_path "data/reformulator_dataset.txt" \
- --string_formatting "data/string_formatting.py" \
+ --query "note:Cloze (rated:2:1 OR rated:2:2) -is:suspended" \
+ --dataset_path "examples/reformulator_dataset.txt" \
+ --string_formatting "examples/string_formatting.py" \
--ntfy_url "ntfy.sh/YOUR_TOPIC" \
--main_field_index 0 \
--llm "openai/gpt-4" \
diff --git a/reformulator.py b/reformulator.py
index 6bd554c..25ae566 100644
--- a/reformulator.py
+++ b/reformulator.py
@@ -31,7 +31,7 @@
import litellm
from utils.misc import load_formatting_funcs, replace_media
-from utils.llm import load_api_keys, llm_price, tkn_len, chat, model_name_matcher
+from utils.llm import llm_price, tkn_len, chat, model_name_matcher
from utils.anki import anki, sync_anki, addtags, removetags, updatenote
from utils.logger import create_loggers
from utils.datasets import load_dataset, semantic_prompt_filtering
@@ -51,8 +51,6 @@
d = datetime.datetime.today()
today = f"{d.day:02d}_{d.month:02d}_{d.year:04d}"
-load_api_keys()
-
# status string
STAT_CHANGED_CONT = "Content has been changed"
@@ -184,7 +182,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
[print(line) for line in traceback.format_tb(exc_traceback)]
print(str(exc_value))
print(str(exc_type))
- print("\n--verbose was used so opening debug console at the "
+ print("\n--debug was used so opening debug console at the "
"appropriate frame. Press 'c' to continue to the frame "
"of this print.")
pdb.post_mortem(exc_traceback)
@@ -209,7 +207,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
litellm.set_verbose = verbose
# arg sanity check and storing
- assert "note:" in query, "You have to specify a notetype in the query"
+ assert "note:" in query, f"You have to specify a notetype in the query ({query})"
assert mode in ["reformulate", "reset"], "Invalid value for 'mode'"
assert isinstance(exclude_done, bool), "exclude_done must be a boolean"
assert isinstance(exclude_version, bool), "exclude_version must be a boolean"
@@ -224,8 +222,13 @@ def handle_exception(exc_type, exc_value, exc_traceback):
parallel = int(parallel)
main_field_index = int(main_field_index)
assert main_field_index >= 0, "invalid field_index"
+ self.base_query = query
+ self.dataset_path = dataset_path
self.mode = mode
- if string_formatting is not None:
+ self.exclude_done = exclude_done
+ self.exclude_version = exclude_version
+
+ if string_formatting:
red(f"Loading specific string formatting from {string_formatting}")
cloze_input_parser, cloze_output_parser = load_formatting_funcs(
path=string_formatting,
@@ -256,29 +259,36 @@ def handle_exception(exc_type, exc_value, exc_traceback):
else:
raise Exception(f"{llm} not found in llm_price")
self.verbose = verbose
- if mode == "reformulate":
- if exclude_done:
+
+ def reformulate(self):
+ query = self.base_query
+ if self.mode == "reformulate":
+ if self.exclude_done:
query += " -AnkiReformulator::Done::*"
- if exclude_version:
+ if self.exclude_version:
query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\""
- # load db just in case
+ # load db just in case, and create one if it doesn't already exist
self.db_content = self.load_db()
if not self.db_content:
- red(
- "Empty database. If you have already ran anki_reformulator "
- "before then something went wrong!"
- )
- else:
- self.compute_cost(self.db_content)
+ red("Empty database. If you have already ran anki_reformulator "
+ "before then something went wrong!")
+ whi("Creating a empty database")
+ self.save_to_db({})
+ self.db_content = self.load_db()
+ assert self.db_content, "Could not create database"
+
+ whi("Computing estimated costs")
+ self.compute_cost(self.db_content)
# load dataset
- dataset = load_dataset(dataset_path)
- # check that each note is valid but exclude the system prompt
- for id, d in enumerate(dataset):
- if id != 0:
- dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"]
+ whi("Loading dataset")
+ dataset = load_dataset(self.dataset_path)
+ # check that each note is valid but exclude the system prompt, which is
+ # the first entry
+        for id, d in enumerate(dataset[1:], start=1):
+ dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"]
assert len(dataset) % 2 == 1, "Even number of examples in dataset"
self.dataset = dataset
@@ -286,26 +296,25 @@ def handle_exception(exc_type, exc_value, exc_traceback):
nids = anki(action="findNotes",
query="tag:AnkiReformulator::RESETTING")
if nids:
- red(
- f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}"
- )
+ red(f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}")
nids = anki(action="findNotes", query="tag:AnkiReformulator::DOING")
if nids:
red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}")
- # find notes ids for the first time
+        # find note ids for the specific note type
nids = anki(action="findNotes", query=query)
assert nids, f"No notes found for the query '{query}'"
- # find the model field names
- fields = anki(
- action="notesInfo",
- notes=[int(nids[0])]
- )[0]["fields"]
- assert (
- "AnkiReformulator" in fields.keys()
- ), "The notetype to edit must have a field called 'AnkiReformulator'"
- self.field_name = list(fields.keys())[0]
+ # find the field names for this note type
+ fields = anki(action="notesInfo",
+ notes=[int(nids[0])])[0]["fields"]
+ assert "AnkiReformulator" in fields.keys(), \
+ "The notetype to edit must have a field called 'AnkiReformulator'"
+ try:
+ self.field_name = list(fields.keys())[self.field_index]
+ except IndexError:
+ raise AssertionError(f"main_field_index {self.field_index} is invalid. "
+ f"Note only has {len(fields.keys())} fields!")
if self.exclude_media:
# now find notes ids after excluding the img in the important field
@@ -316,7 +325,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
query += f' -{self.field_name}:"*http://*"'
query += f' -{self.field_name}:"*https://*"'
- whi(f"Query to find note: {query}")
+ whi(f"Query to find note: '{query}'")
nids = anki(action="findNotes", query=query)
assert nids, f"No notes found for the query '{query}'"
whi(f"Found {len(nids)} notes")
@@ -326,14 +335,13 @@ def handle_exception(exc_type, exc_value, exc_traceback):
anki(action="notesInfo", notes=nids)
).set_index("noteId")
self.notes = self.notes.loc[nids]
- assert not self.notes.empty, "Empty notes df"
+ assert not self.notes.empty, "Empty notes"
- assert (
- len(set(self.notes["modelName"].tolist())) == 1
- ), "Contains more than 1 note type"
+ assert len(set(self.notes["modelName"].tolist())) == 1, \
+ "Contains more than 1 note type"
# check absence of image and sounds in the main field
- # as well incorrect tags
+ # as well as incorrect tags
for nid, note in self.notes.iterrows():
if self.exclude_media:
_, media = replace_media(
@@ -358,38 +366,24 @@ def handle_exception(exc_type, exc_value, exc_traceback):
else:
assert not tag.lower().startswith("ankireformulator")
+ # check if required tokens are higher than our limits
+ tkn_sum = sum(tkn_len(d["content"]) for d in self.dataset)
+ tkn_sum += sum(tkn_len(replace_media(content=note["fields"][self.field_name]["value"],
+ media=None,
+ mode="remove_media")[0])
+ for _, note in self.notes.iterrows())
+ assert tkn_sum <= tkn_warn_limit, (f"Found {tkn_sum} tokens to process, which is "
+ f"higher than the limit of {tkn_warn_limit}")
- # check if too many tokens
- tkn_sum = sum([tkn_len(d["content"]) for d in self.dataset])
- tkn_sum += sum(
- [
- tkn_len(
- replace_media(
- content=note["fields"][self.field_name]["value"],
- media=None,
- mode="remove_media",
- )[0]
- )
- for _, note in self.notes.iterrows()
- ])
- if tkn_sum > tkn_warn_limit:
- raise Exception(
- f"Found {tkn_sum} tokens to process, which is "
- f"higher than the limit of {tkn_warn_limit}"
- )
-
- if len(self.notes) > n_note_limit:
- raise Exception(
- f"Found {len(self.notes)} notes to process "
- f"which is higher than the limit of {n_note_limit}"
- )
+ assert len(self.notes) <= n_note_limit, (f"Found {len(self.notes)} notes to process "
+ f"which is higher than the limit of {n_note_limit}")
if self.mode == "reformulate":
- func = self.reformulate
+ func = self.reformulate_note
elif self.mode == "reset":
- func = self.reset
+ func = self.reset_note
else:
- raise ValueError(self.mode)
+ raise ValueError(f"Unknown mode {self.mode}")
def error_wrapped_func(*args, **kwargs):
"""Wrapper that catches exceptions and marks failed notes with appropriate tags."""
@@ -397,7 +391,7 @@ def error_wrapped_func(*args, **kwargs):
return func(*args, **kwargs)
except Exception as err:
addtags(nid=note.name, tags="AnkiReformulator::FAILED")
- red(f"Error when running self.{self.mode}: '{err}'")
+ red(f"Error when running self.{func.__name__}: '{err}'")
return str(err)
# getting all the new values in parallel and using caching
@@ -413,11 +407,9 @@ def error_wrapped_func(*args, **kwargs):
)
)
- failed_runs = [
- self.notes.iloc[i_nv]
- for i_nv in range(len(new_values))
- if isinstance(new_values[i_nv], str)
- ]
+ failed_runs = [self.notes.iloc[i_nv]
+ for i_nv in range(len(new_values))
+ if isinstance(new_values[i_nv], str)]
if failed_runs:
red(f"Found {len(failed_runs)} failed notes")
failed_run_index = pd.DataFrame(failed_runs).index
@@ -427,6 +419,7 @@ def error_wrapped_func(*args, **kwargs):
assert len(new_values) == len(self.notes)
# applying the changes
+ whi("Applying changes")
for values in tqdm(new_values, desc="Applying changes to anki"):
if self.mode == "reformulate":
self.apply_reformulate(values)
@@ -435,8 +428,10 @@ def error_wrapped_func(*args, **kwargs):
else:
raise ValueError(self.mode)
+ whi("Clearing unused tags")
anki(action="clearUnusedTags")
+        # TODO: Why add and then remove them?
# add and remove the tag TODO to make it easier to re add by the user
# as it was cleared by calling 'clearUnusedTags'
nid, note = next(self.notes.iterrows())
@@ -445,7 +440,7 @@ def error_wrapped_func(*args, **kwargs):
sync_anki()
- # display again the total cost at the end
+ # display the total cost again at the end
db = self.load_db()
assert db, "Empty database at the end of the run. Something went wrong?"
self.compute_cost(db)
@@ -456,11 +451,11 @@ def compute_cost(self, db_content: List[Dict]) -> None:
This is used to know if something went wrong.
"""
n_db = len(db_content)
- red(f"Number of entries in databases/reformulator.db: {n_db}")
+ red(f"Number of entries in databases/reformulator/reformulator.db: {n_db}")
dol_costs = []
dol_missing = 0
for dic in db_content:
- if dic["mode"] != "reformulate":
+            if dic.get("mode") != "reformulate":
continue
try:
dol = float(dic["dollar_price"])
@@ -482,16 +477,16 @@ def compute_cost(self, db_content: List[Dict]) -> None:
elif dol_costs:
self._cost_so_far = dol_total
- def reformulate(self, nid: int, note: pd.Series) -> Dict:
+ def reformulate_note(self, nid: int, note: pd.Series) -> Dict:
"""Generate a reformulated version of a note's content using an LLM.
-
+
Parameters
----------
nid : int
Note ID from Anki
note : pd.Series
Row from the notes DataFrame containing the note data
-
+
Returns
-------
Dict
@@ -512,7 +507,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
# reformulate the content
content = note["fields"][self.field_name]["value"]
log["note_field_content"] = content
- formattedcontent = self.cloze_input_parser(content) if iscloze(content) else content
+ formattedcontent = self.cloze_input_parser(content)
log["note_field_formattedcontent"] = formattedcontent
# if the card is in the dataset, just take the dataset value directly
@@ -535,19 +530,20 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
elif d["role"] == "user":
newcontent = self.dataset[i + 1]["content"]
else:
- raise ValueError(
- f"Unexpected role of message in dataset: {d}")
+ raise ValueError(f"Unexpected role of message in dataset: {d}")
skip_llm = True
break
fc, media = replace_media(
content=formattedcontent,
media=None,
- mode="remove_media",
- )
+ mode="remove_media")
log["media"] = media
- if not skip_llm:
+ if skip_llm:
+ log["llm_answer"] = {"Skipped": True}
+ log["dollar_price"] = 0
+ else:
dataset = copy.deepcopy(self.dataset)
curr_mess = [{"role": "user", "content": fc}]
dataset = semantic_prompt_filtering(
@@ -559,8 +555,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
embedding_model=self.embedding_model,
whi=whi,
yel=yel,
- red=red,
- )
+ red=red)
dataset += curr_mess
assert dataset[0]["role"] == "system", "First message is not from system!"
@@ -603,16 +598,13 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
)
else:
log["dollar_price"] = "?"
- else:
- log["llm_answer"] = {"Skipped": True}
- log["dollar_price"] = 0
log["note_field_newcontent"] = newcontent
- formattednewcontent = self.cloze_output_parser(newcontent) if iscloze(newcontent) else newcontent
+ formattednewcontent = self.cloze_output_parser(newcontent)
log["note_field_formattednewcontent"] = formattednewcontent
log["status"] = STAT_OK_REFORM
- if iscloze(content + newcontent + formattednewcontent):
+        if iscloze(content) and iscloze(newcontent + formattednewcontent):
# check that no cloze were lost
for cl in getclozes(content):
cl = cl.split("::")[0] + "::"
@@ -628,7 +620,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
def apply_reformulate(self, log: Dict) -> None:
"""Apply reformulation changes to an Anki note and update its metadata.
-
+
Parameters
----------
log : Dict
@@ -651,7 +643,7 @@ def apply_reformulate(self, log: Dict) -> None:
new_minilog = rtoml.dumps(minilog, pretty=True)
         new_minilog = new_minilog.strip().replace("\n", "<br>")
- previous_minilog = note["fields"]["AnkiReformulator"]["value"].strip()
+ previous_minilog = note["fields"].get("AnkiReformulator", {}).get("value", "").strip()
if previous_minilog:
new_minilog += ""
new_minilog += "
Older minilog
"
@@ -681,6 +673,7 @@ def apply_reformulate(self, log: Dict) -> None:
nid,
fields={
self.field_name: log["note_field_formattednewcontent"],
+ # TODO: Might be nice to not require this
"AnkiReformulator": new_minilog,
},
)
@@ -696,16 +689,16 @@ def apply_reformulate(self, log: Dict) -> None:
# remove DOING tag
removetags(nid, "AnkiReformulator::DOING")
- def reset(self, nid: int, note: pd.Series) -> Dict:
+ def reset_note(self, nid: int, note: pd.Series) -> Dict:
"""Reset a note back to its state before reformulation.
-
+
Parameters
----------
nid : int
Note ID from Anki
note : pd.Series
Row from the notes DataFrame containing the note data
-
+
Returns
-------
Dict
@@ -736,18 +729,14 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
]
if not entries:
- red(
- f"Entry not found for note {nid}. Looking for the content of "
- "the field AnkiReformulator"
- )
+ red(f"Entry not found for note {nid}. Looking for the content of "
+ "the field AnkiReformulator")
logfield = note["fields"]["AnkiReformulator"]["value"]
             logfield = logfield.split("<details>")[0]  # keep most recent
if not logfield.strip():
- raise Exception(
- f"Note {nid} was not found in the db and its "
- "AnkiReformulator field was empty."
- )
+ raise Exception(f"Note {nid} was not found in the db and its "
+ "AnkiReformulator field was empty.")
# replace the [[c1::cloze]] by {{c1::cloze}}
logfield = logfield.replace("]]", "}}")
@@ -757,7 +746,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
# parse old content
buffer = []
- for i, line in enumerate(logfield.split("
")):
+ for line in logfield.split("
"):
if buffer:
try:
_ = rtoml.loads("".join(buffer + [line]))
@@ -776,10 +765,12 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
# parse new content at the time
buffer = []
- for i, line in enumerate(logfield.split("
")):
+ for line in logfield.split("
"):
if buffer:
try:
- _ = rtoml.loads("".join(buffer + [line]))
+                    # TODO: What are you trying to do here? Just check that adding the line keeps the TOML valid?
+ # If so, you should catch the specific exception that the load function raises on error
+ rtoml.loads("".join(buffer + [line]))
buffer.append(line)
continue
except Exception:
@@ -879,7 +870,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
def apply_reset(self, log: Dict) -> None:
"""Apply reset changes to an Anki note and update its metadata.
-
+
Parameters
----------
log : Dict
@@ -933,10 +924,8 @@ def apply_reset(self, log: Dict) -> None:
# remove TO_RESET tag if present
removetags(nid, "AnkiReformulator::TO_RESET")
-
# remove Done tag
removetags(nid, "AnkiReformulator::Done")
-
# remove DOING tag
removetags(nid, "AnkiReformulator::RESETTING")
@@ -946,12 +935,12 @@ def apply_reset(self, log: Dict) -> None:
def save_to_db(self, dictionnary: Dict) -> bool:
"""Save a log dictionary to the SQLite database.
-
+
Parameters
----------
dictionnary : Dict
Log dictionary to save
-
+
Returns
-------
bool
@@ -976,34 +965,35 @@ def save_to_db(self, dictionnary: Dict) -> bool:
def load_db(self) -> Dict:
"""Load all log dictionaries from the SQLite database.
-
+
Returns
-------
Dict
All log dictionaries from the database, or False if database not found
"""
if not (REFORMULATOR_DIR / "reformulator.db").exists():
- red("db not found: '$REFORMULATOR_DIR/reformulator.db'")
+ red(f"db not found: '{REFORMULATOR_DIR}/reformulator.db'")
return False
conn = sqlite3.connect(str((REFORMULATOR_DIR / "reformulator.db").absolute()))
cursor = conn.cursor()
cursor.execute("SELECT data FROM dictionaries")
rows = cursor.fetchall()
- dictionaries = []
- for row in rows:
- dictionary = json.loads(zlib.decompress(row[0]))
- dictionaries.append(dictionary)
- return dictionaries
+ # TODO: Why do you compress? This just makes it more difficult to debug
+ return [json.loads(zlib.decompress(row[0])) for row in rows]
if __name__ == "__main__":
try:
args, kwargs = fire.Fire(lambda *args, **kwargs: [args, kwargs])
if "help" in kwargs:
- print(help(AnkiReformulator))
+ print(help(AnkiReformulator), file=sys.stderr)
else:
whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'")
- AnkiReformulator(*args, **kwargs)
- except Exception:
- sync_anki()
+ r = AnkiReformulator(*args, **kwargs)
+ r.reformulate()
+ sync_anki()
+ except AssertionError as e:
+ red(e)
+ except Exception as e:
+ red(e)
raise
diff --git a/utils/cloze_utils.py b/utils/cloze_utils.py
index e9a1a6c..52782c3 100644
--- a/utils/cloze_utils.py
+++ b/utils/cloze_utils.py
@@ -18,14 +18,16 @@ def iscloze(text: str) -> bool:
def getclozes(text: str) -> List[str]:
"return the cloze found in the text. Should only be called on cloze notes"
- assert iscloze(text)
+ assert iscloze(text), f"Text '{text}' does not contain a cloze"
return re.findall(CLOZE_REGEX, text)
def cloze_input_parser(cloze: str) -> str:
- """edits the cloze from anki before sending it to the LLM. This is useful
- if you use weird formatting that mess with LLMs"""
- assert iscloze(cloze), f"Invalid cloze: {cloze}"
+ """edit the cloze from anki before sending it to the LLM. This is useful
+ if you use weird formatting that mess with LLMs.
+ If the note content is not a cloze, then return it unmodified."""
+ if not iscloze(cloze):
+ return cloze
cloze = cloze.replace("\xa0", " ")
@@ -37,7 +39,6 @@ def cloze_input_parser(cloze: str) -> str:
     # make spaces consistent
     cloze = cloze.replace("&nbsp;", " ")
-
# misc
cloze = cloze.replace(">", ">")
cloze = cloze.replace("≥", ">=")
@@ -57,9 +58,12 @@ def cloze_output_parser(cloze: str) -> str:
cloze = cloze.strip()
# make sure all newlines are consistent for now
+    # TODO: You mean <br/>?
     cloze = cloze.replace("</br>", "<br>")
+    cloze = cloze.replace("<br/>", "<br>")
     cloze = cloze.replace("\r", "<br>")
-    cloze = cloze.replace("<br>", "\n")
+    # TODO: Not needed
+    # cloze = cloze.replace("<br>", "\n")
# make sure all spaces are consistent
cloze = cloze.replace(" ", " ")
@@ -68,4 +72,3 @@ def cloze_output_parser(cloze: str) -> str:
cloze = cloze.replace("\n", "
")
return cloze
-
diff --git a/utils/datasets.py b/utils/datasets.py
index df46e84..6ba4278 100644
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -1,3 +1,4 @@
+import collections
import json
import pandas as pd
import litellm
@@ -51,9 +52,10 @@ def load_dataset(
Returns
-------
- Dict
+ List
List of message dictionaries with 'role' and 'content' keys,
- validated according to check_dataset() rules
+ validated according to check_dataset() rules.
+ First message is the system message.
Raises
------
@@ -366,25 +368,21 @@ def semantic_prompt_filtering(
if len(output_pr) != len(prompt_messages):
red(f"Tokens of the kept prompts after {cnt} iterations: {tkns} (of all prompts: {all_tkns} tokens) Number of prompts: {len(output_pr)}/{len(prompt_messages)}")
- # check no duplicate prompts
+ # TODO: This complains about duplicates. It looks like for some reason the
+ # last one is added as assistant AND user, but we only compare the content.
contents = [pm["content"] for pm in output_pr]
- dupli = [dp for dp in contents if contents.count(dp) > 1]
+    dupli = [k for k, v in collections.Counter(contents).items() if v > 1]
if dupli:
raise Exception(f"{len(dupli)} duplicate prompts found in memory.py: {dupli}")
- # remove unwanted keys
- for i, d in enumerate(output_pr):
- keys = [k for k in d.keys()]
- for k in keys:
- if k not in ["content", "role"]:
- del d[k]
- output_pr[i] = d
+ # Keep only the content and the role keys for each prompt
+ new_output = [{k: v for k, v in pk.items() if k in {"content", "role"}} for pk in output_pr]
- assert curr_mess not in output_pr
- assert output_pr, "No prompt were selected!"
- check_dataset(output_pr, **check_args)
+ assert curr_mess not in new_output
+    assert new_output, "No prompts were selected!"
+ check_dataset(new_output, **check_args)
- return output_pr
+ return new_output
def format_anchor_key(key: str) -> str:
diff --git a/utils/llm.py b/utils/llm.py
index 5188fe3..b399a4c 100644
--- a/utils/llm.py
+++ b/utils/llm.py
@@ -13,48 +13,25 @@
litellm.drop_params = True
-def load_api_keys() -> Dict:
- """Load API keys from files in the API_KEYS directory.
-
- Creates API_KEYS directory if it doesn't exist.
- Each file in API_KEYS/ should contain a single API key.
- The filename (without extension) becomes part of the environment variable name.
-
- Returns
- -------
- Dict
- Dictionary mapping environment variable names to API key values
- """
- Path("API_KEYS").mkdir(exist_ok=True)
- if not list(Path("API_KEYS").iterdir()):
- shared.red("## No API_KEYS found in API_KEYS")
- api_keys = {}
- for apifile in Path("API_KEYS").iterdir():
- keyname = f"{apifile.stem.upper()}_API_KEY"
- key = apifile.read_text().strip()
- os.environ[keyname] = key
- api_keys[keyname] = key
- return api_keys
-
llm_price = {}
for k, v in litellm.model_cost.items():
llm_price[k] = v
-embedding_models = [
- "openai/text-embedding-3-large",
- "openai/text-embedding-3-small",
- "mistral/mistral-embed",
- ]
+embedding_models = ["openai/text-embedding-3-large",
+ "openai/text-embedding-3-small",
+ "mistral/mistral-embed"]
# steps : price
-sd_price = {
- "15": 0.001,
- "30": 0.002,
- "50": 0.004,
- "100": 0.007,
- "150": "0.01",
-}
+sd_price = {"15": 0.001,
+ "30": 0.002,
+ "50": 0.004,
+ "100": 0.007,
+ "150": 0.01}
+
+tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+llm_cache = Memory(".cache", verbose=0)
+
def llm_cost_compute(
input_cost: int,
@@ -79,9 +56,6 @@ def llm_cost_compute(
return input_cost * price["input_cost_per_token"] + output_cost * price["output_cost_per_token"]
-tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-
def tkn_len(message: Union[str, List[Union[str, Dict]], Dict]):
if isinstance(message, str):
return len(tokenizer.encode(dedent(message)))
@@ -90,7 +64,6 @@ def tkn_len(message: Union[str, List[Union[str, Dict]], Dict]):
elif isinstance(message, list):
return sum([tkn_len(subel) for subel in message])
-llm_cache = Memory(".cache", verbose=0)
@llm_cache.cache
def chat(
@@ -112,6 +85,7 @@ def chat(
assert all(a["finish_reason"] == "stop" for a in answer["choices"]), f"Found bad finish_reason: '{answer}'"
return answer
+
def wrapped_model_name_matcher(model: str) -> str:
"find the best match for a modelname (wrapped to make some check)"
# find the currently set api keys to avoid matching models from
@@ -147,10 +121,10 @@ def wrapped_model_name_matcher(model: str) -> str:
return match[0]
else:
print(f"Couldn't match the modelname {model} to any known model. "
- "Continuing but this will probably crash DocToolsLLM further "
- "down the code.")
+ "Continuing but this will probably crash further down the code.")
return model
+
def model_name_matcher(model: str) -> str:
"""find the best match for a modelname (wrapper that checks if the matched
model has a known cost and print the matched name)"""
diff --git a/utils/logger.py b/utils/logger.py
index d2f3db3..c6ff952 100644
--- a/utils/logger.py
+++ b/utils/logger.py
@@ -85,6 +85,6 @@ def create_loggers(local_file: Union[str, PosixPath], colors: List[str]):
out = []
for col in colors:
log = coloured_logger(col)
- setattr(shared, "col", log)
+ setattr(shared, col, log)
out.append(log)
return out
diff --git a/utils/shared.py b/utils/shared.py
index 757dd9a..6fafacd 100644
--- a/utils/shared.py
+++ b/utils/shared.py
@@ -39,4 +39,3 @@ def __setattr__(self, name: str, value) -> None:
raise TypeError(f'SharedModule forbids the creation of unexpected attribute "{name}"')
shared = SharedModule()
-