File tree 1 file changed +16
-1
lines changed
1 file changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -362,7 +362,22 @@ def validate(
362
362
def _inference_local (self , model_input : Any ) -> Any :
363
363
"""Local inference method for the restrict-to-topic validator."""
364
364
text = model_input ["text" ]
365
- candidate_topics = list (model_input ["valid_topics" ]) + list (model_input ["invalid_topics" ])
365
+
366
+ valid_topics = model_input ["valid_topics" ]
367
+ invalid_topics = model_input ["invalid_topics" ]
368
+
369
+ # There's a chance that valid topics will be passed as a plain string or a set.
370
+ # If that happens the '+ action might fail and the call to _classifier will
371
+ # not behave as expected.
372
+ if isinstance (valid_topics , str ):
373
+ valid_topics = [valid_topics , ]
374
+ elif isinstance (valid_topics , set ):
375
+ valid_topics = list (valid_topics )
376
+ if isinstance (invalid_topics , str ):
377
+ invalid_topics = [invalid_topics , ]
378
+ elif isinstance (invalid_topics , set ):
379
+ invalid_topics = list (invalid_topics )
380
+ candidate_topics = valid_topics + invalid_topics
366
381
367
382
result = self ._classifier (text , candidate_topics )
368
383
topics = result ["labels" ]
You can’t perform that action at this time.
0 commit comments