forked from VsAthul/reflection_agent
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm.py
More file actions
74 lines (57 loc) · 2.68 KB
/
llm.py
File metadata and controls
74 lines (57 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
llm.py
──────
LLM configuration for the FAQ Reflection Agent.
Exports two ready-to-use model instances:
- groq_llm : Groq-hosted LLaMA 3.3-70b (answer generation)
- ollama_llm : Local Ollama LLaMA 3.1-8b (answer validation)
The instances exported here are raw chat models; node files wrap them with
`.with_structured_output()` where needed, so node authors can apply
their own output schemas.
"""
import os
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
# Load variables from .env at import time
load_dotenv()
# ──────────────────────────────────────────────
# Groq — Answer Generation LLM
# ──────────────────────────────────────────────
def get_groq_llm() -> ChatGroq:
    """
    Build the ChatGroq client used for FAQ answer generation.

    Model  : llama-3.3-70b-versatile
    Purpose: Generate detailed, accurate answers to FAQ questions.

    The API key is read from the ``GROQ_API_KEY`` environment variable;
    surrounding whitespace is removed before the key is validated and used.

    Returns:
        A ChatGroq instance created with temperature=0.3 and max_tokens=1024.

    Raises:
        EnvironmentError: If ``GROQ_API_KEY`` is unset, empty, or contains
            only whitespace.
    """
    # Strip first so a blank or padded value (e.g. `GROQ_API_KEY=" "` in .env)
    # is rejected here instead of being forwarded as a bogus credential.
    api_key = (os.getenv("GROQ_API_KEY") or "").strip()
    if not api_key:
        raise EnvironmentError(
            "GROQ_API_KEY is not set. Please add it to your .env file."
        )

    return ChatGroq(
        model="llama-3.3-70b-versatile",
        api_key=api_key,
        temperature=0.3,  # Low temp → more factual, deterministic answers
        max_tokens=1024,
    )
# ──────────────────────────────────────────────
# Ollama — Answer Validation LLM
# ──────────────────────────────────────────────
def get_ollama_llm() -> ChatOllama:
    """
    Build the ChatOllama client used for answer validation.

    Model  : llama3.1:8b (must be pulled locally via `ollama pull llama3.1:8b`)
    Purpose: Validate whether the Groq-generated answer is accurate and relevant.

    The server address is read from the ``OLLAMA_BASE_URL`` environment
    variable; when that variable is unset, empty, or whitespace-only,
    ``http://localhost:11434`` is used instead.

    Returns:
        A ChatOllama instance created with temperature=0.1.
    """
    default_base_url = "http://localhost:11434"
    # os.getenv applies its default only when the variable is absent, so a
    # blank entry such as `OLLAMA_BASE_URL=` in .env would otherwise yield "".
    base_url = (os.getenv("OLLAMA_BASE_URL") or "").strip() or default_base_url

    return ChatOllama(
        model="llama3.1:8b",
        base_url=base_url,
        temperature=0.1,  # Very low temp → consistent, strict validation
    )
# ──────────────────────────────────────────────
# Convenience singletons (imported by nodes)
# ──────────────────────────────────────────────
# Both clients are created once, when this module is first imported; every
# importer then receives these same objects.
# NOTE: get_groq_llm() raises EnvironmentError when GROQ_API_KEY is not set,
# so importing this module fails in that case.
groq_llm: ChatGroq = get_groq_llm()  # answer generation
ollama_llm: ChatOllama = get_ollama_llm()  # answer validation