
Commit 1c8c918

🎨 reformatted actions.py
1 parent 46fab03

File tree

  • nemoguardrails/library/embedding_topic_detector

1 file changed (+44, -11 lines)

nemoguardrails/library/embedding_topic_detector/actions.py

Lines changed: 44 additions & 11 deletions
@@ -27,25 +27,46 @@
 
 
 class EmbeddingTopicDetector:
-    def __init__(self, embedding_model: str, embedding_engine: str, examples: Dict[str, List[str]], threshold: float, top_k: int):
+    def __init__(
+        self,
+        embedding_model: str,
+        embedding_engine: str,
+        examples: Dict[str, List[str]],
+        threshold: float,
+        top_k: int,
+    ):
         self.threshold = threshold
         self.top_k = top_k
         self.model = init_embedding_model(embedding_model, embedding_engine)
         self.embeddings = {
             cat: [np.array(e) for e in self.model.encode(queries)]
-            for cat, queries in examples.items() if queries
+            for cat, queries in examples.items()
+            if queries
         }
 
     async def detect(self, query: str) -> Dict:
         query_emb = np.array((await self.model.encode_async([query]))[0])
 
         sims = sorted(
-            [(cat, float(np.dot(query_emb, emb) / (np.linalg.norm(query_emb) * np.linalg.norm(emb) or 1)))
-             for cat, embs in self.embeddings.items() for emb in embs],
-            key=lambda x: x[1], reverse=True
+            [
+                (
+                    cat,
+                    float(
+                        np.dot(query_emb, emb)
+                        / (np.linalg.norm(query_emb) * np.linalg.norm(emb) or 1)
+                    ),
+                )
+                for cat, embs in self.embeddings.items()
+                for emb in embs
+            ],
+            key=lambda x: x[1],
+            reverse=True,
         )[: self.top_k]
 
-        scores = {cat: float(np.mean([s for c, s in sims if c == cat]) or 0.0) for cat in self.embeddings}
+        scores = {
+            cat: float(np.mean([s for c, s in sims if c == cat]) or 0.0)
+            for cat in self.embeddings
+        }
         max_score = max(scores.values(), default=0.0)
 
         return {
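
For reference, the scoring this hunk reformats is plain cosine similarity between the query embedding and each category's example embeddings, with an `or 1` guard against zero norms. A minimal standalone sketch with toy numpy vectors (the category names and values below are illustrative, not taken from the library):

import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    # Same zero-norm fallback as the `or 1` expression in the diff above.
    denom = np.linalg.norm(a) * np.linalg.norm(b) or 1
    return float(np.dot(a, b) / denom)

# Toy category embeddings: one example vector per category.
embeddings = {
    "finance": [np.array([1.0, 0.0])],
    "cooking": [np.array([0.0, 1.0])],
}
query_emb = np.array([0.9, 0.1])

# Rank all (category, similarity) pairs, highest first, and keep the top k.
sims = sorted(
    [
        (cat, cosine(query_emb, emb))
        for cat, embs in embeddings.items()
        for emb in embs
    ],
    key=lambda x: x[1],
    reverse=True,
)[:3]
print(sims)  # [('finance', 0.9939...), ('cooking', 0.1104...)]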
@@ -62,22 +83,34 @@ async def _check(context: Optional[dict], llm_task_manager, message_key: str) ->
 
     if cache_key not in _detector_cache:
         _detector_cache[cache_key] = EmbeddingTopicDetector(
-            config["embedding_model"], config["embedding_engine"],
-            config["examples"], config.get("threshold", 0.75), config.get("top_k", 3)
+            config["embedding_model"],
+            config["embedding_engine"],
+            config["examples"],
+            config.get("threshold", 0.75),
+            config.get("top_k", 3),
         )
 
     query = context.get(message_key) if context else None
     if not query:
-        return {"on_topic": True, "confidence": 0.0, "top_category": None, "category_scores": {}}
+        return {
+            "on_topic": True,
+            "confidence": 0.0,
+            "top_category": None,
+            "category_scores": {},
+        }
 
     return await _detector_cache[cache_key].detect(query)
 
 
 @action(is_system_action=True)
-async def embedding_topic_check(context: Optional[dict] = None, llm_task_manager=None) -> dict:
+async def embedding_topic_check(
+    context: Optional[dict] = None, llm_task_manager=None
+) -> dict:
     return await _check(context, llm_task_manager, "user_message")
 
 
 @action(is_system_action=True)
-async def embedding_topic_check_output(context: Optional[dict] = None, llm_task_manager=None) -> dict:
+async def embedding_topic_check_output(
+    context: Optional[dict] = None, llm_task_manager=None
+) -> dict:
     return await _check(context, llm_task_manager, "bot_message")
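
The two actions only differ in which context key they read: `embedding_topic_check` reads `user_message`, `embedding_topic_check_output` reads `bot_message`, and both delegate to a cached `EmbeddingTopicDetector` via `_check`. A hedged sketch of using the class directly; the model and engine strings are placeholder assumptions, not defaults confirmed by this commit:

import asyncio

async def main():
    # Hypothetical direct use of the class from this file. The model/engine
    # names are assumptions: anything init_embedding_model accepts would do.
    detector = EmbeddingTopicDetector(
        embedding_model="all-MiniLM-L6-v2",
        embedding_engine="SentenceTransformers",
        examples={
            "banking": ["How do I open an account?", "What is my balance?"],
            "weather": ["Will it rain tomorrow?"],
        },
        threshold=0.75,
        top_k=3,
    )
    result = await detector.detect("Can I get a loan?")
    # Expected shape, matching the empty-query fallback in _check:
    # {"on_topic": ..., "confidence": ..., "top_category": ..., "category_scores": {...}}
    print(result)

asyncio.run(main())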
