@@ -27,25 +27,46 @@
 
 
 class EmbeddingTopicDetector:
-    def __init__(self, embedding_model: str, embedding_engine: str, examples: Dict[str, List[str]], threshold: float, top_k: int):
+    def __init__(
+        self,
+        embedding_model: str,
+        embedding_engine: str,
+        examples: Dict[str, List[str]],
+        threshold: float,
+        top_k: int,
+    ):
         self.threshold = threshold
         self.top_k = top_k
         self.model = init_embedding_model(embedding_model, embedding_engine)
         self.embeddings = {
             cat: [np.array(e) for e in self.model.encode(queries)]
-            for cat, queries in examples.items() if queries
+            for cat, queries in examples.items()
+            if queries
         }
 
     async def detect(self, query: str) -> Dict:
         query_emb = np.array((await self.model.encode_async([query]))[0])
 
         sims = sorted(
-            [(cat, float(np.dot(query_emb, emb) / (np.linalg.norm(query_emb) * np.linalg.norm(emb) or 1)))
-             for cat, embs in self.embeddings.items() for emb in embs],
-            key=lambda x: x[1], reverse=True
+            [
+                (
+                    cat,
+                    float(
+                        np.dot(query_emb, emb)
+                        / (np.linalg.norm(query_emb) * np.linalg.norm(emb) or 1)
+                    ),
+                )
+                for cat, embs in self.embeddings.items()
+                for emb in embs
+            ],
+            key=lambda x: x[1],
+            reverse=True,
         )[: self.top_k]
 
-        scores = {cat: float(np.mean([s for c, s in sims if c == cat]) or 0.0) for cat in self.embeddings}
+        scores = {
+            cat: float(np.mean([s for c, s in sims if c == cat]) or 0.0)
+            for cat in self.embeddings
+        }
         max_score = max(scores.values(), default=0.0)
 
         return {
@@ -62,22 +83,34 @@ async def _check(context: Optional[dict], llm_task_manager, message_key: str) ->
 
     if cache_key not in _detector_cache:
         _detector_cache[cache_key] = EmbeddingTopicDetector(
-            config["embedding_model"], config["embedding_engine"],
-            config["examples"], config.get("threshold", 0.75), config.get("top_k", 3)
+            config["embedding_model"],
+            config["embedding_engine"],
+            config["examples"],
+            config.get("threshold", 0.75),
+            config.get("top_k", 3),
         )
 
     query = context.get(message_key) if context else None
     if not query:
-        return {"on_topic": True, "confidence": 0.0, "top_category": None, "category_scores": {}}
+        return {
+            "on_topic": True,
+            "confidence": 0.0,
+            "top_category": None,
+            "category_scores": {},
+        }
 
     return await _detector_cache[cache_key].detect(query)
 
 
 @action(is_system_action=True)
-async def embedding_topic_check(context: Optional[dict] = None, llm_task_manager=None) -> dict:
+async def embedding_topic_check(
+    context: Optional[dict] = None, llm_task_manager=None
+) -> dict:
     return await _check(context, llm_task_manager, "user_message")
 
 
 @action(is_system_action=True)
-async def embedding_topic_check_output(context: Optional[dict] = None, llm_task_manager=None) -> dict:
+async def embedding_topic_check_output(
+    context: Optional[dict] = None, llm_task_manager=None
+) -> dict:
     return await _check(context, llm_task_manager, "bot_message")
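For context, a minimal usage sketch of the detector this commit reformats (not part of the diff itself). The import path, the all-MiniLM-L6-v2/SentenceTransformers model-engine pair, and the example category map are all assumptions for illustration; the result keys are inferred from the fallback dict returned by _check.

import asyncio

# Hypothetical import path -- adjust to wherever this module lives in your tree.
from embedding_topic_detector import EmbeddingTopicDetector

# Assumed model/engine pair; any backend accepted by init_embedding_model
# should work the same way.
detector = EmbeddingTopicDetector(
    embedding_model="all-MiniLM-L6-v2",
    embedding_engine="SentenceTransformers",
    examples={
        "billing": ["How do I update my card?", "Why was I charged twice?"],
        "support": ["The app crashes on startup.", "I cannot log in."],
    },
    threshold=0.75,
    top_k=3,
)

# detect() is async: it embeds the query, ranks cosine similarity against
# every example embedding, keeps the top_k hits, and averages them per
# category. The keys below mirror the fallback return seen in _check.
result = asyncio.run(detector.detect("My payment failed, can you help?"))
print(result["top_category"], result["confidence"])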