1818import httpx
1919from functools import lru_cache
2020
21+
2122class LLMEngine :
2223 def __init__ (self , config_path : Optional [Path ] = None ):
2324 self .logger = logging .getLogger (__name__ )
@@ -26,25 +27,21 @@ def __init__(self, config_path: Optional[Path] = None):
2627 self .models = self ._initialize_models ()
2728 self .current_model = self .models [0 ] # Start with preferred model
2829 self .response_cache = {}
29-
30+
3031 def _load_config (self , config_path : Optional [Path ] = None ) -> Dict :
3132 """Load LLM configuration"""
3233 if not config_path :
3334 config_path = Path .home () / ".vulnforge" / "configs" / "llm_config.json"
34-
35+
3536 default_config = {
3637 "preferred_model" : "deepseek-coder-v2:16b-lite-base-q5_K_S" ,
37- "fallback_models" : [
38- "deepseek-coder:6.7b" ,
39- "codellama:7b" ,
40- "mistral:7b"
41- ],
38+ "fallback_models" : ["deepseek-coder:6.7b" , "codellama:7b" , "mistral:7b" ],
4239 "cache_size" : 100 ,
4340 "timeout" : 180 ,
4441 "max_retries" : 3 ,
45- "retry_delay" : 2
42+ "retry_delay" : 2 ,
4643 }
47-
44+
4845 try :
4946 if config_path .exists ():
5047 with open (config_path ) as f :
@@ -53,72 +50,84 @@ def _load_config(self, config_path: Optional[Path] = None) -> Dict:
5350 except Exception as e :
5451 self .logger .error ("Error loading LLM config: %s" , e )
5552 self .config = {}
56-
53+
5754 def _initialize_models (self ) -> List [str ]:
5855 """Initialize available models"""
5956 available_models = []
60-
57+
6158 # Check preferred model first
6259 if self ._is_model_available (self .config ["preferred_model" ]):
6360 available_models .append (self .config ["preferred_model" ])
64-
61+
6562 # Check fallback models
6663 for model in self .config ["fallback_models" ]:
6764 if self ._is_model_available (model ):
6865 available_models .append (model )
69-
66+
7067 if not available_models :
7168 self .logger .error ("No models available!" )
72-
69+
7370 return available_models
74-
71+
7572 async def _is_model_available (self , model : str ) -> bool :
7673 """Check if a model is available"""
7774 try :
7875 async with httpx .AsyncClient (timeout = 5 ) as client :
7976 response = await client .get (f"{ self .base_url } /api/tags" )
8077 if response .status_code == 200 :
81- models = response .json ().get (' models' , [])
82- return any (m [' name' ] == model for m in models )
78+ models = response .json ().get (" models" , [])
79+ return any (m [" name" ] == model for m in models )
8380 except (httpx .RequestError , httpx .TimeoutException ) as e :
8481 self .logger .error ("Error checking model availability: %s" , e )
8582 return False
8683 return False
87-
84+
8885 def _pull_model (self , model : str ) -> bool :
8986 """Pull a model if not available"""
9087 try :
9188 self .logger .info ("Pulling model: %s" , model )
9289 data = {"name" : model }
93- response = requests .post (f"{ self .base_url } /api/pull" , json = data , stream = True )
94-
90+ response = requests .post (
91+ f"{ self .base_url } /api/pull" , json = data , stream = True
92+ )
93+
9594 for line in response .iter_lines ():
9695 if line :
9796 try :
98- status = json .loads (line .decode (' utf-8' ))
99- if status .get (' status' ) == ' success' :
97+ status = json .loads (line .decode (" utf-8" ))
98+ if status .get (" status" ) == " success" :
10099 return True
101100 except :
102101 continue
103102 except Exception as e :
104103 self .logger .error ("Error pulling model %s: %s" , model , e )
105104 return False
106-
105+
107106 @lru_cache (maxsize = 100 )
108- async def generate (self , prompt : str , system_prompt : Optional [str ] = None , model : Optional [str ] = None ) -> Optional [str ]:
107+ async def generate (
108+ self ,
109+ prompt : str ,
110+ system_prompt : Optional [str ] = None ,
111+ model : Optional [str ] = None ,
112+ ) -> Optional [str ]:
109113 """Generate text (wrapper for query)"""
110114 return await self .query (prompt , system_prompt = system_prompt , model = model )
111115
112- async def query (self , prompt : str , system_prompt : Optional [str ] = None ,
113- model : Optional [str ] = None , use_cache : bool = True ) -> Optional [str ]:
116+ async def query (
117+ self ,
118+ prompt : str ,
119+ system_prompt : Optional [str ] = None ,
120+ model : Optional [str ] = None ,
121+ use_cache : bool = True ,
122+ ) -> Optional [str ]:
114123 if not model :
115124 model = self .current_model
116-
125+
117126 # Check cache if enabled
118127 cache_key = f"{ model } :{ prompt } :{ system_prompt } "
119128 if use_cache and cache_key in self .response_cache :
120129 return self .response_cache [cache_key ]
121-
130+
122131 for attempt in range (self .config ["max_retries" ]):
123132 try :
124133 data = {
@@ -131,28 +140,27 @@ async def query(self, prompt: str, system_prompt: Optional[str] = None,
131140 "max_tokens" : 4096 ,
132141 "num_ctx" : 8192 ,
133142 "num_thread" : 8 ,
134- "repeat_penalty" : 1.1
135- }
143+ "repeat_penalty" : 1.1 ,
144+ },
136145 }
137-
146+
138147 if system_prompt :
139148 data ["system" ] = system_prompt
140-
149+
141150 async with httpx .AsyncClient (timeout = self .config ["timeout" ]) as client :
142151 response = await client .post (
143- f"{ self .base_url } /api/generate" ,
144- json = data
152+ f"{ self .base_url } /api/generate" , json = data
145153 )
146-
154+
147155 if response .status_code == 200 :
148- result = response .json ().get (' response' , '' ).strip ()
156+ result = response .json ().get (" response" , "" ).strip ()
149157 if use_cache :
150158 self .response_cache [cache_key ] = result
151159 return result
152-
160+
153161 except (httpx .RequestError , httpx .TimeoutException ) as e :
154162 self .logger .error ("Error querying model %s: %s" , model , e )
155-
163+
156164 # Try next model if available
157165 if model in self .models :
158166 current_index = self .models .index (model )
@@ -161,23 +169,23 @@ async def query(self, prompt: str, system_prompt: Optional[str] = None,
161169 self .logger .info ("Switching to fallback model: %s" , model )
162170 else :
163171 break
164-
172+
165173 await asyncio .sleep (self .config ["retry_delay" ])
166-
174+
167175 return None
168-
176+
169177 def clear_cache (self ):
170178 """Clear the response cache"""
171179 self .response_cache .clear ()
172180 self .query .cache_clear ()
173-
181+
174182 def get_available_models (self ) -> List [str ]:
175183 """Get list of available models"""
176184 return self .models .copy ()
177-
185+
178186 def set_preferred_model (self , model : str ) -> bool :
179187 """Set preferred model if available"""
180188 if self ._is_model_available (model ):
181189 self .current_model = model
182190 return True
183- return False
191+ return False
0 commit comments