Commit f320101

Fix handling of example CSV in LLM annotator
stijn-uva committed Jan 8, 2025
1 parent 2122510 commit f320101
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions processors/machine_learning/annotate_text.py
@@ -11,7 +11,7 @@
 from common.lib.dmi_service_manager import DmiServiceManager, DmiServiceManagerException, DsmOutOfMemory
 from common.lib.exceptions import QueryParametersException
 from common.lib.user_input import UserInput
-from common.lib.helpers import sniff_encoding
+from common.lib.helpers import sniff_encoding, sniff_csv_dialect
 from common.config_manager import config
 
 __author__ = "Stijn Peeters"
@@ -141,7 +141,6 @@ def process(self):
 
         model = self.parameters.get("model")
         textfield = self.parameters.get("text-column")
-        labels = {l.strip(): [] for l in self.parameters.get("categories").split(",") if l.strip()}
 
         # Make output dir
         staging_area = self.dataset.get_staging_area()
@@ -172,6 +171,28 @@ def process(self):
             return self.dataset.finish_with_error(
                 "Cannot connect to DMI Service Manager. Verify that this 4CAT server has access to it.")
 
+        if self.parameters["shotstyle"] == "fewshot":
+            # do we have examples?
+            example_path = self.dataset.get_results_path().with_suffix(".importing")
+            if not example_path.exists():
+                return self.dataset.finish_with_error("Cannot open example file")
+
+            labels = {}
+            with example_path.open() as infile:
+                dialect, has_header = sniff_csv_dialect(infile)
+                reader = csv.reader(infile, dialect=dialect)
+                for row in reader:
+                    if row[0] not in labels:
+                        labels[row[0]] = []
+                    labels[row[0]].append(row[1])
+
+            example_path.unlink()
+
+        else:
+            # if we have no examples, just include an empty list
+            labels = {l.strip(): [] for l in self.parameters.get("categories").split(",") if l.strip()}
+
+
         # store labels in a file (since we don't know how much data this is)
         labels_path = staging_area.joinpath("labels.temp.json")
         with labels_path.open("w") as outfile:
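Note: the example file parsed above is expected to be a two-column CSV (label, example text); validate_query below rejects files with any other column count. A minimal, self-contained illustration of what the parsing loop produces, with invented labels and texts:

    import csv
    import io

    # invented example file: first column is the label, second an example text
    example_csv = io.StringIO(
        'positive,"Great product, loved it"\n'
        'positive,"Works exactly as advertised"\n'
        'negative,"Broke after two days"\n'
    )

    labels = {}
    for row in csv.reader(example_csv):
        if row[0] not in labels:
            labels[row[0]] = []
        labels[row[0]].append(row[1])

    print(labels)
    # {'positive': ['Great product, loved it', 'Works exactly as advertised'],
    #  'negative': ['Broke after two days']}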
@@ -288,14 +309,10 @@ def validate_query(query, request, user):
 
         # we want a very specific type of CSV file!
         encoding = sniff_encoding(file)
 
         wrapped_file = io.TextIOWrapper(file, encoding=encoding)
         try:
-            sample = wrapped_file.read(1024 * 1024)
-            wrapped_file.seek(0)
-            has_header = csv.Sniffer().has_header(sample)
-            dialect = csv.Sniffer().sniff(sample, delimiters=(",", ";", "\t"))
-
+            dialect, has_header = sniff_csv_dialect(file)
             reader = csv.reader(wrapped_file, dialect=dialect) if not has_header else csv.DictReader(wrapped_file)
             row = next(reader)
             if len(list(row)) != 2:
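Note: the inline csv.Sniffer boilerplate deleted above now lives in a shared sniff_csv_dialect helper, imported from common.lib.helpers at the top of the file. A sketch of what that helper plausibly does, reconstructed from the deleted lines (the actual implementation in common.lib.helpers may differ, for instance in how it wraps a binary file handle):

    import csv

    def sniff_csv_dialect(infile):
        # Hypothetical reconstruction based on the inline code this commit
        # removes; not the verified helper from common.lib.helpers.
        sample = infile.read(1024 * 1024)  # sniff at most the first megabyte
        infile.seek(0)  # rewind so the caller can re-read from the start
        has_header = csv.Sniffer().has_header(sample)
        dialect = csv.Sniffer().sniff(sample, delimiters=(",", ";", "\t"))
        return dialect, has_header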
@@ -326,7 +343,7 @@ def after_create(query, dataset, request):
         if query.get("shotstyle") != "fewshot":
             return
 
-        file = request.files["option-category_file"]
+        file = request.files["option-category-file"]
         file.seek(0)
         with dataset.get_results_path().with_suffix(".importing").open("wb") as outfile:
             while True:
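Note: the loop truncated at the end of this hunk copies the uploaded example file into the .importing staging path. Its body is not shown in this diff; a typical shape for such a loop would be (chunk size and loop body assumed, not taken from the commit):

    # hypothetical continuation of the truncated loop above: stream the
    # upload to disk in fixed-size chunks instead of reading it all at once
    while True:
        chunk = file.read(1024 * 1024)  # chunk size assumed
        if len(chunk) == 0:
            break
        outfile.write(chunk)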