Skip to content

Commit efc36d8

Browse files
authoredDec 10, 2024
Merge branch 'main' into jc/allow_only_invalid
2 parents 0cbf43d + 6a815fc commit efc36d8

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed
 

‎validator/main.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,22 @@ def validate(
362362
def _inference_local(self, model_input: Any) -> Any:
363363
"""Local inference method for the restrict-to-topic validator."""
364364
text = model_input["text"]
365-
candidate_topics = list(model_input["valid_topics"]) + list(model_input["invalid_topics"])
365+
366+
valid_topics = model_input["valid_topics"]
367+
invalid_topics = model_input["invalid_topics"]
368+
369+
# There's a chance that valid topics will be passed as a plain string or a set.
370+
# If that happens the '+ action might fail and the call to _classifier will
371+
# not behave as expected.
372+
if isinstance(valid_topics, str):
373+
valid_topics = [valid_topics, ]
374+
elif isinstance(valid_topics, set):
375+
valid_topics = list(valid_topics)
376+
if isinstance(invalid_topics, str):
377+
invalid_topics = [invalid_topics, ]
378+
elif isinstance(invalid_topics, set):
379+
invalid_topics = list(invalid_topics)
380+
candidate_topics = valid_topics + invalid_topics
366381

367382
result = self._classifier(text, candidate_topics)
368383
topics = result["labels"]

0 commit comments

Comments
 (0)
Please sign in to comment.