diff --git a/README.md b/README.md
index 81bec5b..22f5171 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ This collection of scripts is the culmination of my efforts to contributes the A
 
 ## Tools
 
-### Illustrator 
+### Illustrator
 Creates custom mnemonic images for your cards using AI image generation. It:
 - Analyzes card content to identify key concepts
 - Generates creative visual memory hooks
@@ -387,14 +387,21 @@ Dataset files (like `explainer_dataset.txt`, `reformulator_dataset.txt`, etc.) a
 Click to read more
 </summary>
 
+First, ensure that your API keys are set in your environment variables.
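+
+For example, for the OpenAI models used below (a sketch; substitute the variable your provider expects and a real key):
+
+```bash
+export OPENAI_API_KEY="sk-..."
+```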
+
+Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it.
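+
+With Anki running, you can check that AnkiConnect is reachable (by default it listens on `localhost:8765`):
+
+```bash
+curl -s localhost:8765 -d '{"action": "version", "version": 6}'
+# expected: {"result": 6, "error": null}
+```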
+
+
 #### Reformulator
+
+The Reformulator expects the notes it modifies to have a specific field where it can save old versions and log its changes. Add an `AnkiReformulator` field to the note type you want to reformulate.
+
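+You can verify that the field is present with AnkiConnect's `modelFieldNames` action (shown for the `Cloze` note type used in the example below):
+
+```bash
+curl -s localhost:8765 -d '{"action": "modelFieldNames", "version": 6, "params": {"modelName": "Cloze"}}'
+```
+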
 The Reformulator can be run from the command line:
 
 ```bash
 python reformulator.py \
-    --query "(rated:2:1 OR rated:2:2) -is:suspended" \
-    --dataset_path "data/reformulator_dataset.txt" \
-    --string_formatting "data/string_formatting.py" \
+    --query "note:Cloze (rated:2:1 OR rated:2:2) -is:suspended" \
+    --dataset_path "examples/reformulator_dataset.txt" \
+    --string_formatting "examples/string_formatting.py" \
     --ntfy_url "ntfy.sh/YOUR_TOPIC" \
     --main_field_index 0 \
     --llm "openai/gpt-4" \
diff --git a/reformulator.py b/reformulator.py
index 6bd554c..25ae566 100644
--- a/reformulator.py
+++ b/reformulator.py
@@ -31,7 +31,7 @@
 import litellm
 
 from utils.misc import load_formatting_funcs, replace_media
-from utils.llm import load_api_keys, llm_price, tkn_len, chat, model_name_matcher
+from utils.llm import llm_price, tkn_len, chat, model_name_matcher
 from utils.anki import anki, sync_anki, addtags, removetags, updatenote
 from utils.logger import create_loggers
 from utils.datasets import load_dataset, semantic_prompt_filtering
@@ -51,8 +51,6 @@
 d = datetime.datetime.today()
 today = f"{d.day:02d}_{d.month:02d}_{d.year:04d}"
 
-load_api_keys()
-
 
 # status string
 STAT_CHANGED_CONT = "Content has been changed"
@@ -184,7 +182,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
                     [print(line) for line in traceback.format_tb(exc_traceback)]
                     print(str(exc_value))
                     print(str(exc_type))
-                    print("\n--verbose was used so opening debug console at the "
+                    print("\n--debug was used so opening debug console at the "
                       "appropriate frame. Press 'c' to continue to the frame "
                       "of this print.")
                     pdb.post_mortem(exc_traceback)
@@ -209,7 +207,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
         litellm.set_verbose = verbose
 
         # arg sanity check and storing
-        assert "note:" in query, "You have to specify a notetype in the query"
+        assert "note:" in query, f"You have to specify a notetype in the query ({query})"
         assert mode in ["reformulate", "reset"], "Invalid value for 'mode'"
         assert isinstance(exclude_done, bool), "exclude_done must be a boolean"
         assert isinstance(exclude_version, bool), "exclude_version must be a boolean"
@@ -224,8 +222,13 @@ def handle_exception(exc_type, exc_value, exc_traceback):
             parallel = int(parallel)
         main_field_index = int(main_field_index)
         assert main_field_index >= 0, "invalid field_index"
+        self.base_query = query
+        self.dataset_path = dataset_path
         self.mode = mode
-        if string_formatting is not None:
+        self.exclude_done = exclude_done
+        self.exclude_version = exclude_version
+
+        if string_formatting:
             red(f"Loading specific string formatting from {string_formatting}")
             cloze_input_parser, cloze_output_parser = load_formatting_funcs(
                     path=string_formatting,
@@ -256,29 +259,36 @@ def handle_exception(exc_type, exc_value, exc_traceback):
         else:
             raise Exception(f"{llm} not found in llm_price")
         self.verbose = verbose
-        if mode == "reformulate":
-            if exclude_done:
+
+    def reformulate(self):
+        query = self.base_query
+        if self.mode == "reformulate":
+            if self.exclude_done:
                 query += " -AnkiReformulator::Done::*"
 
-            if exclude_version:
+            if self.exclude_version:
                 query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\""
 
-        # load db just in case
+        # load db just in case, and create one if it doesn't already exist
         self.db_content = self.load_db()
         if not self.db_content:
-            red(
-                "Empty database. If you have already ran anki_reformulator "
-                "before then something went wrong!"
-            )
-        else:
-            self.compute_cost(self.db_content)
+            red("Empty database. If you have already ran anki_reformulator "
+                "before then something went wrong!")
+            whi("Creating a empty database")
+            self.save_to_db({})
+            self.db_content = self.load_db()
+            assert self.db_content, "Could not create database"
+
+        whi("Computing estimated costs")
+        self.compute_cost(self.db_content)
 
         # load dataset
-        dataset = load_dataset(dataset_path)
-        # check that each note is valid but exclude the system prompt
-        for id, d in enumerate(dataset):
-            if id != 0:
-                dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"]
+        whi("Loading dataset")
+        dataset = load_dataset(self.dataset_path)
+        # check that each note is valid but exclude the system prompt, which is
+        # the first entry
+        for id, d in enumerate(dataset[1:], start=1):
+            dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"]
         assert len(dataset) % 2 == 1, "Even number of examples in dataset"
         self.dataset = dataset
 
@@ -286,26 +296,25 @@ def handle_exception(exc_type, exc_value, exc_traceback):
         nids = anki(action="findNotes",
                     query="tag:AnkiReformulator::RESETTING")
         if nids:
-            red(
-                f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}"
-            )
+            red(f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}")
         nids = anki(action="findNotes", query="tag:AnkiReformulator::DOING")
         if nids:
             red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}")
 
-        # find notes ids for the first time
+        # find note ids for the specific note type
         nids = anki(action="findNotes", query=query)
         assert nids, f"No notes found for the query '{query}'"
 
-        # find the model field names
-        fields = anki(
-            action="notesInfo",
-            notes=[int(nids[0])]
-            )[0]["fields"]
-        assert (
-            "AnkiReformulator" in fields.keys()
-        ), "The notetype to edit must have a field called 'AnkiReformulator'"
-        self.field_name = list(fields.keys())[0]
+        # find the field names for this note type
+        fields = anki(action="notesInfo",
+                      notes=[int(nids[0])])[0]["fields"]
+        assert "AnkiReformulator" in fields.keys(), \
+                "The notetype to edit must have a field called 'AnkiReformulator'"
+        try:
+            self.field_name = list(fields.keys())[self.field_index]
+        except IndexError:
+            raise AssertionError(f"main_field_index {self.field_index} is invalid. "
+                                 f"Note only has {len(fields.keys())} fields!")
 
         if self.exclude_media:
             # now find notes ids after excluding the img in the important field
@@ -316,7 +325,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
             query += f' -{self.field_name}:"*http://*"'
             query += f' -{self.field_name}:"*https://*"'
 
-        whi(f"Query to find note: {query}")
+        whi(f"Query to find note: '{query}'")
         nids = anki(action="findNotes", query=query)
         assert nids, f"No notes found for the query '{query}'"
         whi(f"Found {len(nids)} notes")
@@ -326,14 +335,13 @@ def handle_exception(exc_type, exc_value, exc_traceback):
             anki(action="notesInfo", notes=nids)
         ).set_index("noteId")
         self.notes = self.notes.loc[nids]
-        assert not self.notes.empty, "Empty notes df"
+        assert not self.notes.empty, "Empty notes"
 
-        assert (
-            len(set(self.notes["modelName"].tolist())) == 1
-        ), "Contains more than 1 note type"
+        assert len(set(self.notes["modelName"].tolist())) == 1, \
+                "Contains more than 1 note type"
 
         # check absence of image and sounds in the main field
-        # as well incorrect tags
+        # as well as incorrect tags
         for nid, note in self.notes.iterrows():
             if self.exclude_media:
                 _, media = replace_media(
@@ -358,38 +366,24 @@ def handle_exception(exc_type, exc_value, exc_traceback):
                 else:
                     assert not tag.lower().startswith("ankireformulator")
 
+        # check if required tokens are higher than our limits
+        tkn_sum = sum(tkn_len(d["content"]) for d in self.dataset)
+        tkn_sum += sum(tkn_len(replace_media(content=note["fields"][self.field_name]["value"],
+                                             media=None,
+                                             mode="remove_media")[0])
+                       for _, note in self.notes.iterrows())
+        assert tkn_sum <= tkn_warn_limit, (f"Found {tkn_sum} tokens to process, which is "
+                                           f"higher than the limit of {tkn_warn_limit}")
 
-        # check if too many tokens
-        tkn_sum = sum([tkn_len(d["content"]) for d in self.dataset])
-        tkn_sum += sum(
-            [
-                tkn_len(
-                    replace_media(
-                        content=note["fields"][self.field_name]["value"],
-                        media=None,
-                        mode="remove_media",
-                    )[0]
-                )
-                for _, note in self.notes.iterrows()
-            ])
-        if tkn_sum > tkn_warn_limit:
-            raise Exception(
-                f"Found {tkn_sum} tokens to process, which is "
-                f"higher than the limit of {tkn_warn_limit}"
-            )
-
-        if len(self.notes) > n_note_limit:
-            raise Exception(
-                f"Found {len(self.notes)} notes to process "
-                f"which is higher than the limit of {n_note_limit}"
-            )
+        assert len(self.notes) <= n_note_limit, (f"Found {len(self.notes)} notes to process "
+                                                 f"which is higher than the limit of {n_note_limit}")
 
         if self.mode == "reformulate":
-            func = self.reformulate
+            func = self.reformulate_note
         elif self.mode == "reset":
-            func = self.reset
+            func = self.reset_note
         else:
-            raise ValueError(self.mode)
+            raise ValueError(f"Unknown mode {self.mode}")
 
         def error_wrapped_func(*args, **kwargs):
             """Wrapper that catches exceptions and marks failed notes with appropriate tags."""
@@ -397,7 +391,7 @@ def error_wrapped_func(*args, **kwargs):
                 return func(*args, **kwargs)
             except Exception as err:
                 addtags(nid=note.name, tags="AnkiReformulator::FAILED")
-                red(f"Error when running self.{self.mode}: '{err}'")
+                red(f"Error when running self.{func.__name__}: '{err}'")
                 return str(err)
 
         # getting all the new values in parallel and using caching
@@ -413,11 +407,9 @@ def error_wrapped_func(*args, **kwargs):
             )
         )
 
-        failed_runs = [
-            self.notes.iloc[i_nv]
-            for i_nv in range(len(new_values))
-            if isinstance(new_values[i_nv], str)
-        ]
+        failed_runs = [self.notes.iloc[i_nv]
+                       for i_nv in range(len(new_values))
+                       if isinstance(new_values[i_nv], str)]
         if failed_runs:
             red(f"Found {len(failed_runs)} failed notes")
             failed_run_index = pd.DataFrame(failed_runs).index
@@ -427,6 +419,7 @@ def error_wrapped_func(*args, **kwargs):
             assert len(new_values) == len(self.notes)
 
         # applying the changes
+        whi("Applying changes")
         for values in tqdm(new_values, desc="Applying changes to anki"):
             if self.mode == "reformulate":
                 self.apply_reformulate(values)
@@ -435,8 +428,10 @@ def error_wrapped_func(*args, **kwargs):
             else:
                 raise ValueError(self.mode)
 
+        whi("Clearing unused tags")
         anki(action="clearUnusedTags")
 
+        # TODO: Why add and then remove them?
         # add and remove the tag TODO to make it easier to re add by the user
         # as it was cleared by calling 'clearUnusedTags'
         nid, note = next(self.notes.iterrows())
@@ -445,7 +440,7 @@ def error_wrapped_func(*args, **kwargs):
 
         sync_anki()
 
-        # display again the total cost at the end
+        # display the total cost again at the end
         db = self.load_db()
         assert db, "Empty database at the end of the run. Something went wrong?"
         self.compute_cost(db)
@@ -456,11 +451,11 @@ def compute_cost(self, db_content: List[Dict]) -> None:
         This is used to know if something went wrong.
         """
         n_db = len(db_content)
-        red(f"Number of entries in databases/reformulator.db: {n_db}")
+        red(f"Number of entries in databases/reformulator/reformulator.db: {n_db}")
         dol_costs = []
         dol_missing = 0
         for dic in db_content:
-            if dic["mode"] != "reformulate":
+            if dic.get("mode") != "reformulate":
                 continue
             try:
                 dol = float(dic["dollar_price"])
@@ -482,16 +477,16 @@ def compute_cost(self, db_content: List[Dict]) -> None:
         elif dol_costs:
             self._cost_so_far = dol_total
 
-    def reformulate(self, nid: int, note: pd.Series) -> Dict:
+    def reformulate_note(self, nid: int, note: pd.Series) -> Dict:
         """Generate a reformulated version of a note's content using an LLM.
-        
+
         Parameters
         ----------
         nid : int
             Note ID from Anki
         note : pd.Series
             Row from the notes DataFrame containing the note data
-            
+
         Returns
         -------
         Dict
@@ -512,7 +507,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
         # reformulate the content
         content = note["fields"][self.field_name]["value"]
         log["note_field_content"] = content
-        formattedcontent = self.cloze_input_parser(content) if iscloze(content) else content
+        formattedcontent = self.cloze_input_parser(content)
         log["note_field_formattedcontent"] = formattedcontent
 
         # if the card is in the dataset, just take the dataset value directly
@@ -535,19 +530,20 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
                 elif d["role"] == "user":
                     newcontent = self.dataset[i + 1]["content"]
                 else:
-                    raise ValueError(
-                        f"Unexpected role of message in dataset: {d}")
+                    raise ValueError(f"Unexpected role of message in dataset: {d}")
                 skip_llm = True
                 break
 
         fc, media = replace_media(
             content=formattedcontent,
             media=None,
-            mode="remove_media",
-        )
+            mode="remove_media")
         log["media"] = media
 
-        if not skip_llm:
+        if skip_llm:
+            log["llm_answer"] = {"Skipped": True}
+            log["dollar_price"] = 0
+        else:
             dataset = copy.deepcopy(self.dataset)
             curr_mess = [{"role": "user", "content": fc}]
             dataset = semantic_prompt_filtering(
@@ -559,8 +555,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
                 embedding_model=self.embedding_model,
                 whi=whi,
                 yel=yel,
-                red=red,
-            )
+                red=red)
             dataset += curr_mess
 
             assert dataset[0]["role"] == "system", "First message is not from system!"
@@ -603,16 +598,13 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
                 )
             else:
                 log["dollar_price"] = "?"
-        else:
-            log["llm_answer"] = {"Skipped": True}
-            log["dollar_price"] = 0
 
         log["note_field_newcontent"] = newcontent
-        formattednewcontent = self.cloze_output_parser(newcontent) if iscloze(newcontent) else newcontent
+        formattednewcontent = self.cloze_output_parser(newcontent)
         log["note_field_formattednewcontent"] = formattednewcontent
         log["status"] = STAT_OK_REFORM
 
-        if iscloze(content + newcontent + formattednewcontent):
+        if iscloze(content) and iscloze(newcontent + formattednewcontent):
             # check that no cloze were lost
             for cl in getclozes(content):
                 cl = cl.split("::")[0] + "::"
@@ -628,7 +620,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
 
     def apply_reformulate(self, log: Dict) -> None:
         """Apply reformulation changes to an Anki note and update its metadata.
-        
+
         Parameters
         ----------
         log : Dict
@@ -651,7 +643,7 @@ def apply_reformulate(self, log: Dict) -> None:
 
         new_minilog = rtoml.dumps(minilog, pretty=True)
         new_minilog = new_minilog.strip().replace("\n", "<br>")
-        previous_minilog = note["fields"]["AnkiReformulator"]["value"].strip()
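+        # tolerate notes whose AnkiReformulator field is missing or empty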
+        previous_minilog = note["fields"].get("AnkiReformulator", {}).get("value", "").strip()
         if previous_minilog:
             new_minilog += "<!--SEPARATOR-->"
             new_minilog += "<br><br><details><summary>Older minilog</summary>"
@@ -681,6 +673,7 @@ def apply_reformulate(self, log: Dict) -> None:
             nid,
             fields={
                 self.field_name: log["note_field_formattednewcontent"],
+                # TODO: Might be nice to not require this
                 "AnkiReformulator": new_minilog,
             },
         )
@@ -696,16 +689,16 @@ def apply_reformulate(self, log: Dict) -> None:
         # remove DOING tag
         removetags(nid, "AnkiReformulator::DOING")
 
-    def reset(self, nid: int, note: pd.Series) -> Dict:
+    def reset_note(self, nid: int, note: pd.Series) -> Dict:
         """Reset a note back to its state before reformulation.
-        
+
         Parameters
         ----------
         nid : int
             Note ID from Anki
         note : pd.Series
             Row from the notes DataFrame containing the note data
-            
+
         Returns
         -------
         Dict
@@ -736,18 +729,14 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
         ]
 
         if not entries:
-            red(
-                f"Entry not found for note {nid}. Looking for the content of "
-                "the field AnkiReformulator"
-            )
+            red(f"Entry not found for note {nid}. Looking for the content of "
+                "the field AnkiReformulator")
             logfield = note["fields"]["AnkiReformulator"]["value"]
             logfield = logfield.split(
                 "<!--SEPARATOR-->")[0]  # keep most recent
             if not logfield.strip():
-                raise Exception(
-                    f"Note {nid} was not found in the db and its "
-                    "AnkiReformulator field was empty."
-                )
+                raise Exception(f"Note {nid} was not found in the db and its "
+                                "AnkiReformulator field was empty.")
 
             # replace the [[c1::cloze]] by {{c1::cloze}}
             logfield = logfield.replace("]]", "}}")
@@ -757,7 +746,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
 
             # parse old content
             buffer = []
-            for i, line in enumerate(logfield.split("<br>")):
+            for line in logfield.split("<br>"):
                 if buffer:
                     try:
                         _ = rtoml.loads("".join(buffer + [line]))
@@ -776,10 +765,12 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
 
             # parse new content at the time
             buffer = []
-            for i, line in enumerate(logfield.split("<br>")):
+            for line in logfield.split("<br>"):
                 if buffer:
                     try:
-                        _ = rtoml.loads("".join(buffer + [line]))
+                        # TODO: What are you trying to do here? Just check that adding the line keeps valid toml?
+                        # If so, you should catch the specific exception that the load function raises on error
+                        rtoml.loads("".join(buffer + [line]))
                         buffer.append(line)
                         continue
                     except Exception:
@@ -879,7 +870,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict:
 
     def apply_reset(self, log: Dict) -> None:
         """Apply reset changes to an Anki note and update its metadata.
-        
+
         Parameters
         ----------
         log : Dict
@@ -933,10 +924,8 @@ def apply_reset(self, log: Dict) -> None:
 
         # remove TO_RESET tag if present
         removetags(nid, "AnkiReformulator::TO_RESET")
-
         # remove Done tag
         removetags(nid, "AnkiReformulator::Done")
-
         # remove DOING tag
         removetags(nid, "AnkiReformulator::RESETTING")
 
@@ -946,12 +935,12 @@ def apply_reset(self, log: Dict) -> None:
 
     def save_to_db(self, dictionnary: Dict) -> bool:
         """Save a log dictionary to the SQLite database.
-        
+
         Parameters
         ----------
         dictionnary : Dict
             Log dictionary to save
-            
+
         Returns
         -------
         bool
@@ -976,34 +965,35 @@ def save_to_db(self, dictionnary: Dict) -> bool:
 
     def load_db(self) -> Dict:
         """Load all log dictionaries from the SQLite database.
-        
+
         Returns
         -------
         Dict
             All log dictionaries from the database, or False if database not found
         """
         if not (REFORMULATOR_DIR / "reformulator.db").exists():
-            red("db not found: '$REFORMULATOR_DIR/reformulator.db'")
+            red(f"db not found: '{REFORMULATOR_DIR}/reformulator.db'")
             return False
         conn = sqlite3.connect(str((REFORMULATOR_DIR / "reformulator.db").absolute()))
         cursor = conn.cursor()
         cursor.execute("SELECT data FROM dictionaries")
         rows = cursor.fetchall()
-        dictionaries = []
-        for row in rows:
-            dictionary = json.loads(zlib.decompress(row[0]))
-            dictionaries.append(dictionary)
-        return dictionaries
+        # TODO: Why do you compress? This just makes it more difficult to debug
+        return [json.loads(zlib.decompress(row[0])) for row in rows]
 
 
 if __name__ == "__main__":
     try:
         args, kwargs = fire.Fire(lambda *args, **kwargs: [args, kwargs])
         if "help" in kwargs:
-            print(help(AnkiReformulator))
+            help(AnkiReformulator)  # help() prints the doc itself; print(help(...)) would just print None
         else:
             whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'")
-            AnkiReformulator(*args, **kwargs)
-    except Exception:
-        sync_anki()
+            r = AnkiReformulator(*args, **kwargs)
+            r.reformulate()
+            sync_anki()
+    except AssertionError as e:
+        red(e)
+    except Exception as e:
+        red(e)
         raise
diff --git a/utils/cloze_utils.py b/utils/cloze_utils.py
index e9a1a6c..52782c3 100644
--- a/utils/cloze_utils.py
+++ b/utils/cloze_utils.py
@@ -18,14 +18,16 @@ def iscloze(text: str) -> bool:
 
 def getclozes(text: str) -> List[str]:
     "return the cloze found in the text. Should only be called on cloze notes"
-    assert iscloze(text)
+    assert iscloze(text), f"Text '{text}' does not contain a cloze"
     return re.findall(CLOZE_REGEX, text)
 
 
 def cloze_input_parser(cloze: str) -> str:
-    """edits the cloze from anki before sending it to the LLM. This is useful
-    if you use weird formatting that mess with LLMs"""
-    assert iscloze(cloze), f"Invalid cloze: {cloze}"
+    """edit the cloze from anki before sending it to the LLM. This is useful
+    if you use weird formatting that mess with LLMs.
+    If the note content is not a cloze, then return it unmodified."""
+    if not iscloze(cloze):
+        return cloze
 
     cloze = cloze.replace("\xa0", " ")
 
@@ -37,7 +39,6 @@ def cloze_input_parser(cloze: str) -> str:
     # make spaces consitent
     cloze = cloze.replace("&nbsp;", " ")
 
-
     # misc
     cloze = cloze.replace("&gt;", ">")
     cloze = cloze.replace("&ge;", ">=")
@@ -57,9 +58,12 @@ def cloze_output_parser(cloze: str) -> str:
     cloze = cloze.strip()
 
     # make sure all newlines are consistent for now
+    # TODO: You mean <br/>?
     cloze = cloze.replace("</br>", "<br>")
+    cloze = cloze.replace("<br/>", "<br>")
     cloze = cloze.replace("\r", "<br>")
-    cloze = cloze.replace("<br>", "\n")
+    # TODO: Not needed
+    # cloze = cloze.replace("<br>", "\n")
 
     # make sure all spaces are consistent
     cloze = cloze.replace("&nbsp;", " ")
@@ -68,4 +72,3 @@ def cloze_output_parser(cloze: str) -> str:
     cloze = cloze.replace("\n", "<br>")
 
     return cloze
-
diff --git a/utils/datasets.py b/utils/datasets.py
index df46e84..6ba4278 100644
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -1,3 +1,4 @@
+import collections
 import json
 import pandas as pd
 import litellm
@@ -51,9 +52,10 @@ def load_dataset(
 
     Returns
     -------
-    Dict
+    List
         List of message dictionaries with 'role' and 'content' keys,
-        validated according to check_dataset() rules
+        validated according to check_dataset() rules.
+        First message is the system message.
 
     Raises
     ------
@@ -366,25 +368,21 @@ def semantic_prompt_filtering(
     if len(output_pr) != len(prompt_messages):
         red(f"Tokens of the kept prompts after {cnt} iterations: {tkns} (of all prompts: {all_tkns} tokens) Number of prompts: {len(output_pr)}/{len(prompt_messages)}")
 
-    # check no duplicate prompts
+    # TODO: This complains about duplicates. It looks like for some reason the
+    # last one is added as assistant AND user, but we only compare the content.
     contents = [pm["content"] for pm in output_pr]
-    dupli = [dp for dp in contents if contents.count(dp) > 1]
+    dupli = [k for k, v in collections.Counter(contents).items() if v > 1]
     if dupli:
         raise Exception(f"{len(dupli)} duplicate prompts found in memory.py: {dupli}")
 
-    # remove unwanted keys
-    for i, d in enumerate(output_pr):
-        keys = [k for k in d.keys()]
-        for k in keys:
-            if k not in ["content", "role"]:
-                del d[k]
-        output_pr[i] = d
+    # Keep only the content and the role keys for each prompt
+    new_output = [{k: v for k, v in pk.items() if k in {"content", "role"}} for pk in output_pr]
 
-    assert curr_mess not in output_pr
-    assert output_pr, "No prompt were selected!"
-    check_dataset(output_pr, **check_args)
+    assert curr_mess not in new_output
+    assert new_output, "No prompt were selected!"
+    check_dataset(new_output, **check_args)
 
-    return output_pr
+    return new_output
 
 
 def format_anchor_key(key: str) -> str:
diff --git a/utils/llm.py b/utils/llm.py
index 5188fe3..b399a4c 100644
--- a/utils/llm.py
+++ b/utils/llm.py
@@ -13,48 +13,25 @@
 
 litellm.drop_params = True
 
-def load_api_keys() -> Dict:
-    """Load API keys from files in the API_KEYS directory.
-
-    Creates API_KEYS directory if it doesn't exist.
-    Each file in API_KEYS/ should contain a single API key.
-    The filename (without extension) becomes part of the environment variable name.
-
-    Returns
-    -------
-    Dict
-        Dictionary mapping environment variable names to API key values
-    """
-    Path("API_KEYS").mkdir(exist_ok=True)
-    if not list(Path("API_KEYS").iterdir()):
-        shared.red("## No API_KEYS found in API_KEYS")
-    api_keys = {}
-    for apifile in Path("API_KEYS").iterdir():
-        keyname = f"{apifile.stem.upper()}_API_KEY"
-        key = apifile.read_text().strip()
-        os.environ[keyname] = key
-        api_keys[keyname] = key
-    return api_keys
-
 
 llm_price = {}
 for k, v in litellm.model_cost.items():
     llm_price[k] = v
 
-embedding_models = [
-        "openai/text-embedding-3-large",
-        "openai/text-embedding-3-small",
-        "mistral/mistral-embed",
-        ]
+embedding_models = ["openai/text-embedding-3-large",
+                    "openai/text-embedding-3-small",
+                    "mistral/mistral-embed"]
 
 # steps : price
-sd_price = {
-    "15": 0.001,
-    "30": 0.002,
-    "50": 0.004,
-    "100": 0.007,
-    "150": "0.01",
-}
+sd_price = {"15": 0.001,
+            "30": 0.002,
+            "50": 0.004,
+            "100": 0.007,
+            "150": 0.01}
+
+tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+llm_cache = Memory(".cache", verbose=0)
+
 
 def llm_cost_compute(
         input_cost: int,
@@ -79,9 +56,6 @@ def llm_cost_compute(
     return input_cost * price["input_cost_per_token"] + output_cost * price["output_cost_per_token"]
 
 
-tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-
 def tkn_len(message: Union[str, List[Union[str, Dict]], Dict]):
     if isinstance(message, str):
         return len(tokenizer.encode(dedent(message)))
@@ -90,7 +64,6 @@ def tkn_len(message: Union[str, List[Union[str, Dict]], Dict]):
     elif isinstance(message, list):
         return sum([tkn_len(subel) for subel in message])
 
-llm_cache = Memory(".cache", verbose=0)
 
 @llm_cache.cache
 def chat(
@@ -112,6 +85,7 @@ def chat(
         assert all(a["finish_reason"] == "stop" for a in answer["choices"]), f"Found bad finish_reason: '{answer}'"
     return answer
 
+
 def wrapped_model_name_matcher(model: str) -> str:
     "find the best match for a modelname (wrapped to make some check)"
     # find the currently set api keys to avoid matching models from
@@ -147,10 +121,10 @@ def wrapped_model_name_matcher(model: str) -> str:
         return match[0]
     else:
         print(f"Couldn't match the modelname {model} to any known model. "
-            "Continuing but this will probably crash DocToolsLLM further "
-            "down the code.")
+              "Continuing but this will probably crash further down the code.")
         return model
 
+
 def model_name_matcher(model: str) -> str:
     """find the best match for a modelname (wrapper that checks if the matched
     model has a known cost and print the matched name)"""
diff --git a/utils/logger.py b/utils/logger.py
index d2f3db3..c6ff952 100644
--- a/utils/logger.py
+++ b/utils/logger.py
@@ -85,6 +85,6 @@ def create_loggers(local_file: Union[str, PosixPath], colors: List[str]):
     out = []
     for col in colors:
         log = coloured_logger(col)
-        setattr(shared, "col", log)
+        setattr(shared, col, log)
         out.append(log)
     return out
diff --git a/utils/shared.py b/utils/shared.py
index 757dd9a..6fafacd 100644
--- a/utils/shared.py
+++ b/utils/shared.py
@@ -39,4 +39,3 @@ def __setattr__(self, name: str, value) -> None:
             raise TypeError(f'SharedModule forbids the creation of unexpected attribute "{name}"')
 
 shared = SharedModule()
-