File tree 1 file changed +16
-1
lines changed
1 file changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -357,7 +357,22 @@ def validate(
357
357
def _inference_local (self , model_input : Any ) -> Any :
358
358
"""Local inference method for the restrict-to-topic validator."""
359
359
text = model_input ["text" ]
360
- candidate_topics = model_input ["valid_topics" ] + model_input ["invalid_topics" ]
360
+
361
+ valid_topics = model_input ["valid_topics" ]
362
+ invalid_topics = model_input ["invalid_topics" ]
363
+
364
+ # There's a chance that valid topics will be passed as a plain string or a set.
365
+ # If that happens the '+ action might fail and the call to _classifier will
366
+ # not behave as expected.
367
+ if isinstance (valid_topics , str ):
368
+ valid_topics = [valid_topics , ]
369
+ elif isinstance (valid_topics , set ):
370
+ valid_topics = list (valid_topics )
371
+ if isinstance (invalid_topics , str ):
372
+ invalid_topics = [invalid_topics , ]
373
+ elif isinstance (invalid_topics , set ):
374
+ invalid_topics = list (invalid_topics )
375
+ candidate_topics = valid_topics + invalid_topics
361
376
362
377
result = self ._classifier (text , candidate_topics )
363
378
topics = result ["labels" ]
You can’t perform that action at this time.
0 commit comments