Skip to content

Commit 6a815fc

Browse files
authored
Merge pull request #21 from guardrails-ai/jc/patch_inference_local_types
Safety check for when _inference_local gets called with sets or strings.
2 parents 7e818d0 + 9fa3f54 commit 6a815fc

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

Diff for: validator/main.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,22 @@ def validate(
357357
def _inference_local(self, model_input: Any) -> Any:
358358
"""Local inference method for the restrict-to-topic validator."""
359359
text = model_input["text"]
360-
candidate_topics = model_input["valid_topics"] + model_input["invalid_topics"]
360+
361+
valid_topics = model_input["valid_topics"]
362+
invalid_topics = model_input["invalid_topics"]
363+
364+
# There's a chance that valid topics will be passed as a plain string or a set.
365+
# If that happens the '+ action might fail and the call to _classifier will
366+
# not behave as expected.
367+
if isinstance(valid_topics, str):
368+
valid_topics = [valid_topics, ]
369+
elif isinstance(valid_topics, set):
370+
valid_topics = list(valid_topics)
371+
if isinstance(invalid_topics, str):
372+
invalid_topics = [invalid_topics, ]
373+
elif isinstance(invalid_topics, set):
374+
invalid_topics = list(invalid_topics)
375+
candidate_topics = valid_topics + invalid_topics
361376

362377
result = self._classifier(text, candidate_topics)
363378
topics = result["labels"]

0 commit comments

Comments
 (0)