Skip to content

Commit f8e4d18

Browse files
style: format code with Black
This commit fixes the style issues introduced in 1fa8576 according to the output from Black. Details: None
1 parent 1fa8576 commit f8e4d18

11 files changed

Lines changed: 1316 additions & 888 deletions

File tree

ai_integration.py

Lines changed: 169 additions & 112 deletions
Large diffs are not rendered by default.

ai_wrapper/llm_engine.py

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import httpx
1919
from functools import lru_cache
2020

21+
2122
class LLMEngine:
2223
def __init__(self, config_path: Optional[Path] = None):
2324
self.logger = logging.getLogger(__name__)
@@ -26,25 +27,21 @@ def __init__(self, config_path: Optional[Path] = None):
2627
self.models = self._initialize_models()
2728
self.current_model = self.models[0] # Start with preferred model
2829
self.response_cache = {}
29-
30+
3031
def _load_config(self, config_path: Optional[Path] = None) -> Dict:
3132
"""Load LLM configuration"""
3233
if not config_path:
3334
config_path = Path.home() / ".vulnforge" / "configs" / "llm_config.json"
34-
35+
3536
default_config = {
3637
"preferred_model": "deepseek-coder-v2:16b-lite-base-q5_K_S",
37-
"fallback_models": [
38-
"deepseek-coder:6.7b",
39-
"codellama:7b",
40-
"mistral:7b"
41-
],
38+
"fallback_models": ["deepseek-coder:6.7b", "codellama:7b", "mistral:7b"],
4239
"cache_size": 100,
4340
"timeout": 180,
4441
"max_retries": 3,
45-
"retry_delay": 2
42+
"retry_delay": 2,
4643
}
47-
44+
4845
try:
4946
if config_path.exists():
5047
with open(config_path) as f:
@@ -53,72 +50,84 @@ def _load_config(self, config_path: Optional[Path] = None) -> Dict:
5350
except Exception as e:
5451
self.logger.error("Error loading LLM config: %s", e)
5552
self.config = {}
56-
53+
5754
def _initialize_models(self) -> List[str]:
5855
"""Initialize available models"""
5956
available_models = []
60-
57+
6158
# Check preferred model first
6259
if self._is_model_available(self.config["preferred_model"]):
6360
available_models.append(self.config["preferred_model"])
64-
61+
6562
# Check fallback models
6663
for model in self.config["fallback_models"]:
6764
if self._is_model_available(model):
6865
available_models.append(model)
69-
66+
7067
if not available_models:
7168
self.logger.error("No models available!")
72-
69+
7370
return available_models
74-
71+
7572
async def _is_model_available(self, model: str) -> bool:
7673
"""Check if a model is available"""
7774
try:
7875
async with httpx.AsyncClient(timeout=5) as client:
7976
response = await client.get(f"{self.base_url}/api/tags")
8077
if response.status_code == 200:
81-
models = response.json().get('models', [])
82-
return any(m['name'] == model for m in models)
78+
models = response.json().get("models", [])
79+
return any(m["name"] == model for m in models)
8380
except (httpx.RequestError, httpx.TimeoutException) as e:
8481
self.logger.error("Error checking model availability: %s", e)
8582
return False
8683
return False
87-
84+
8885
def _pull_model(self, model: str) -> bool:
8986
"""Pull a model if not available"""
9087
try:
9188
self.logger.info("Pulling model: %s", model)
9289
data = {"name": model}
93-
response = requests.post(f"{self.base_url}/api/pull", json=data, stream=True)
94-
90+
response = requests.post(
91+
f"{self.base_url}/api/pull", json=data, stream=True
92+
)
93+
9594
for line in response.iter_lines():
9695
if line:
9796
try:
98-
status = json.loads(line.decode('utf-8'))
99-
if status.get('status') == 'success':
97+
status = json.loads(line.decode("utf-8"))
98+
if status.get("status") == "success":
10099
return True
101100
except:
102101
continue
103102
except Exception as e:
104103
self.logger.error("Error pulling model %s: %s", model, e)
105104
return False
106-
105+
107106
@lru_cache(maxsize=100)
108-
async def generate(self, prompt: str, system_prompt: Optional[str] = None, model: Optional[str] = None) -> Optional[str]:
107+
async def generate(
108+
self,
109+
prompt: str,
110+
system_prompt: Optional[str] = None,
111+
model: Optional[str] = None,
112+
) -> Optional[str]:
109113
"""Generate text (wrapper for query)"""
110114
return await self.query(prompt, system_prompt=system_prompt, model=model)
111115

112-
async def query(self, prompt: str, system_prompt: Optional[str] = None,
113-
model: Optional[str] = None, use_cache: bool = True) -> Optional[str]:
116+
async def query(
117+
self,
118+
prompt: str,
119+
system_prompt: Optional[str] = None,
120+
model: Optional[str] = None,
121+
use_cache: bool = True,
122+
) -> Optional[str]:
114123
if not model:
115124
model = self.current_model
116-
125+
117126
# Check cache if enabled
118127
cache_key = f"{model}:{prompt}:{system_prompt}"
119128
if use_cache and cache_key in self.response_cache:
120129
return self.response_cache[cache_key]
121-
130+
122131
for attempt in range(self.config["max_retries"]):
123132
try:
124133
data = {
@@ -131,28 +140,27 @@ async def query(self, prompt: str, system_prompt: Optional[str] = None,
131140
"max_tokens": 4096,
132141
"num_ctx": 8192,
133142
"num_thread": 8,
134-
"repeat_penalty": 1.1
135-
}
143+
"repeat_penalty": 1.1,
144+
},
136145
}
137-
146+
138147
if system_prompt:
139148
data["system"] = system_prompt
140-
149+
141150
async with httpx.AsyncClient(timeout=self.config["timeout"]) as client:
142151
response = await client.post(
143-
f"{self.base_url}/api/generate",
144-
json=data
152+
f"{self.base_url}/api/generate", json=data
145153
)
146-
154+
147155
if response.status_code == 200:
148-
result = response.json().get('response', '').strip()
156+
result = response.json().get("response", "").strip()
149157
if use_cache:
150158
self.response_cache[cache_key] = result
151159
return result
152-
160+
153161
except (httpx.RequestError, httpx.TimeoutException) as e:
154162
self.logger.error("Error querying model %s: %s", model, e)
155-
163+
156164
# Try next model if available
157165
if model in self.models:
158166
current_index = self.models.index(model)
@@ -161,23 +169,23 @@ async def query(self, prompt: str, system_prompt: Optional[str] = None,
161169
self.logger.info("Switching to fallback model: %s", model)
162170
else:
163171
break
164-
172+
165173
await asyncio.sleep(self.config["retry_delay"])
166-
174+
167175
return None
168-
176+
169177
def clear_cache(self):
170178
"""Clear the response cache"""
171179
self.response_cache.clear()
172180
self.query.cache_clear()
173-
181+
174182
def get_available_models(self) -> List[str]:
175183
"""Get list of available models"""
176184
return self.models.copy()
177-
185+
178186
def set_preferred_model(self, model: str) -> bool:
179187
"""Set preferred model if available"""
180188
if self._is_model_available(model):
181189
self.current_model = model
182190
return True
183-
return False
191+
return False

0 commit comments

Comments
 (0)