Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get reformulator working #7

Open
wants to merge 10 commits into
base: public
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ This collection of scripts is the culmination of my efforts to contributes the A

## Tools

### Illustrator
### Illustrator
Creates custom mnemonic images for your cards using AI image generation. It:
- Analyzes card content to identify key concepts
- Generates creative visual memory hooks
Expand Down Expand Up @@ -387,14 +387,24 @@ Dataset files (like `explainer_dataset.txt`, `reformulator_dataset.txt`, etc.) a
Click to read more
</summary>

First, create an _API_KEYS/_ directory and place your API key in a separate file.

Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it.

#### Reformulator

Next... create a database? it expects a sqlite db in databases/reformulator/reformulator?

Next... something about adding a field called `AnkiReformulator` to notes you want to change?
* Do you have to create a special note type for this to work?

The Reformulator can be run from the command line:

```bash
python reformulator.py \
--query "(rated:2:1 OR rated:2:2) -is:suspended" \
--dataset_path "data/reformulator_dataset.txt" \
--string_formatting "data/string_formatting.py" \
--dataset_path "examples/reformulator_dataset.txt" \
--string_formatting "examples/string_formatting.py" \
--ntfy_url "ntfy.sh/YOUR_TOPIC" \
--main_field_index 0 \
--llm "openai/gpt-4" \
Expand Down
56 changes: 25 additions & 31 deletions reformulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
d = datetime.datetime.today()
today = f"{d.day:02d}_{d.month:02d}_{d.year:04d}"

whi("Loading api keys")
load_api_keys()


Expand Down Expand Up @@ -184,7 +185,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
[print(line) for line in traceback.format_tb(exc_traceback)]
print(str(exc_value))
print(str(exc_type))
print("\n--verbose was used so opening debug console at the "
print("\n--debug was used so opening debug console at the "
"appropriate frame. Press 'c' to continue to the frame "
"of this print.")
pdb.post_mortem(exc_traceback)
Expand All @@ -203,13 +204,14 @@ def handle_exception(exc_type, exc_value, exc_traceback):
print(json.dumps(db_content, ensure_ascii=False, indent=4))
return
else:
sync_anki()
# sync_anki()
assert query is not None, "Must specify --query"
assert dataset_path is not None, "Must specify --dataset_path"
litellm.set_verbose = verbose

# arg sanity check and storing
assert "note:" in query, "You have to specify a notetype in the query"
# TODO: Is this needed? The example in the readme doesn't set it
# assert "note:" in query, f"You have to specify a notetype in the query ({query})"
assert mode in ["reformulate", "reset"], "Invalid value for 'mode'"
assert isinstance(exclude_done, bool), "exclude_done must be a boolean"
assert isinstance(exclude_version, bool), "exclude_version must be a boolean"
Expand All @@ -225,7 +227,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
main_field_index = int(main_field_index)
assert main_field_index >= 0, "invalid field_index"
self.mode = mode
if string_formatting is not None:
if string_formatting:
red(f"Loading specific string formatting from {string_formatting}")
cloze_input_parser, cloze_output_parser = load_formatting_funcs(
path=string_formatting,
Expand Down Expand Up @@ -264,14 +266,14 @@ def handle_exception(exc_type, exc_value, exc_traceback):
query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\""

# load db just in case
self.db_content = self.load_db()
if not self.db_content:
red(
"Empty database. If you have already ran anki_reformulator "
"before then something went wrong!"
)
else:
self.compute_cost(self.db_content)

# TODO: How is the user supposed to create the database in the first place?
# self.db_content = self.load_db()
# if not self.db_content:
# red("Empty database. If you have already ran anki_reformulator "
# "before then something went wrong!")
# else:
# self.compute_cost(self.db_content)

# load dataset
dataset = load_dataset(dataset_path)
Expand All @@ -286,9 +288,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
nids = anki(action="findNotes",
query="tag:AnkiReformulator::RESETTING")
if nids:
red(
f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}"
)
red(f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}")
nids = anki(action="findNotes", query="tag:AnkiReformulator::DOING")
if nids:
red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}")
Expand All @@ -298,13 +298,10 @@ def handle_exception(exc_type, exc_value, exc_traceback):
assert nids, f"No notes found for the query '{query}'"

# find the model field names
fields = anki(
action="notesInfo",
notes=[int(nids[0])]
)[0]["fields"]
assert (
"AnkiReformulator" in fields.keys()
), "The notetype to edit must have a field called 'AnkiReformulator'"
fields = anki(action="notesInfo",
notes=[int(nids[0])])[0]["fields"]
# assert "AnkiReformulator" in fields.keys(), \
# "The notetype to edit must have a field called 'AnkiReformulator'"
self.field_name = list(fields.keys())[0]

if self.exclude_media:
Expand All @@ -328,9 +325,8 @@ def handle_exception(exc_type, exc_value, exc_traceback):
self.notes = self.notes.loc[nids]
assert not self.notes.empty, "Empty notes df"

assert (
len(set(self.notes["modelName"].tolist())) == 1
), "Contains more than 1 note type"
assert len(set(self.notes["modelName"].tolist())) == 1, \
"Contains more than 1 note type"

# check absence of image and sounds in the main field
# as well incorrect tags
Expand Down Expand Up @@ -358,11 +354,9 @@ def handle_exception(exc_type, exc_value, exc_traceback):
else:
assert not tag.lower().startswith("ankireformulator")


# check if too many tokens
tkn_sum = sum([tkn_len(d["content"]) for d in self.dataset])
tkn_sum += sum(
[
tkn_len(
replace_media(
content=note["fields"][self.field_name]["value"],
Expand All @@ -371,7 +365,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
)[0]
)
for _, note in self.notes.iterrows()
])
)
if tkn_sum > tkn_warn_limit:
raise Exception(
f"Found {tkn_sum} tokens to process, which is "
Expand Down Expand Up @@ -983,7 +977,7 @@ def load_db(self) -> Dict:
All log dictionaries from the database, or False if database not found
"""
if not (REFORMULATOR_DIR / "reformulator.db").exists():
red("db not found: '$REFORMULATOR_DIR/reformulator.db'")
red(f"db not found: '{REFORMULATOR_DIR}/reformulator.db'")
return False
conn = sqlite3.connect(str((REFORMULATOR_DIR / "reformulator.db").absolute()))
cursor = conn.cursor()
Expand All @@ -1000,10 +994,10 @@ def load_db(self) -> Dict:
try:
args, kwargs = fire.Fire(lambda *args, **kwargs: [args, kwargs])
if "help" in kwargs:
print(help(AnkiReformulator))
print(help(AnkiReformulator), file=sys.stderr)
else:
whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'")
AnkiReformulator(*args, **kwargs)
sync_anki()
except Exception:
sync_anki()
raise
1 change: 1 addition & 0 deletions utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def load_api_keys() -> Dict:
Path("API_KEYS").mkdir(exist_ok=True)
if not list(Path("API_KEYS").iterdir()):
shared.red("## No API_KEYS found in API_KEYS")
raise Exception("Need to write API KEYS to API_KEYS/")
api_keys = {}
for apifile in Path("API_KEYS").iterdir():
keyname = f"{apifile.stem.upper()}_API_KEY"
Expand Down
2 changes: 1 addition & 1 deletion utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ def create_loggers(local_file: Union[str, PosixPath], colors: List[str]):
out = []
for col in colors:
log = coloured_logger(col)
setattr(shared, "col", log)
setattr(shared, col, log)
out.append(log)
return out
1 change: 0 additions & 1 deletion utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,3 @@ def __setattr__(self, name: str, value) -> None:
raise TypeError(f'SharedModule forbids the creation of unexpected attribute "{name}"')

shared = SharedModule()