Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ai_wrapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@

from .ollama_wrapper import OllamaWrapper

__all__ = ['OllamaWrapper']
__all__ = ["OllamaWrapper"]
96 changes: 52 additions & 44 deletions ai_wrapper/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import httpx
from functools import lru_cache


class LLMEngine:
def __init__(self, config_path: Optional[Path] = None):
self.logger = logging.getLogger(__name__)
Expand All @@ -26,25 +27,21 @@ def __init__(self, config_path: Optional[Path] = None):
self.models = self._initialize_models()
self.current_model = self.models[0] # Start with preferred model
self.response_cache = {}

def _load_config(self, config_path: Optional[Path] = None) -> Dict:
"""Load LLM configuration"""
if not config_path:
config_path = Path.home() / ".neurorift" / "configs" / "llm_config.json"

default_config = {
"preferred_model": "deepseek-coder-v2:16b-lite-base-q5_K_S",
"fallback_models": [
"deepseek-coder:6.7b",
"codellama:7b",
"mistral:7b"
],
"fallback_models": ["deepseek-coder:6.7b", "codellama:7b", "mistral:7b"],
"cache_size": 100,
"timeout": 180,
"max_retries": 3,
"retry_delay": 2
"retry_delay": 2,
}

try:
if config_path.exists():
with open(config_path) as f:
Expand All @@ -53,72 +50,84 @@ def _load_config(self, config_path: Optional[Path] = None) -> Dict:
except Exception as e:
self.logger.error("Error loading LLM config: %s", e)
self.config = {}

def _initialize_models(self) -> List[str]:
"""Initialize available models"""
available_models = []

# Check preferred model first
if self._is_model_available(self.config["preferred_model"]):
available_models.append(self.config["preferred_model"])

# Check fallback models
for model in self.config["fallback_models"]:
if self._is_model_available(model):
available_models.append(model)

if not available_models:
self.logger.error("No models available!")

return available_models

async def _is_model_available(self, model: str) -> bool:
"""Check if a model is available"""
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get(f"{self.base_url}/api/tags")
if response.status_code == 200:
models = response.json().get('models', [])
return any(m['name'] == model for m in models)
models = response.json().get("models", [])
return any(m["name"] == model for m in models)
except (httpx.RequestError, httpx.TimeoutException) as e:
self.logger.error("Error checking model availability: %s", e)
return False
return False

def _pull_model(self, model: str) -> bool:
"""Pull a model if not available"""
try:
self.logger.info("Pulling model: %s", model)
data = {"name": model}
response = requests.post(f"{self.base_url}/api/pull", json=data, stream=True)

response = requests.post(
f"{self.base_url}/api/pull", json=data, stream=True
)

for line in response.iter_lines():
if line:
try:
status = json.loads(line.decode('utf-8'))
if status.get('status') == 'success':
status = json.loads(line.decode("utf-8"))
if status.get("status") == "success":
return True
except:
continue
except Exception as e:
self.logger.error("Error pulling model %s: %s", model, e)
return False

@lru_cache(maxsize=100)
async def generate(self, prompt: str, system_prompt: Optional[str] = None, model: Optional[str] = None) -> Optional[str]:
async def generate(
self,
prompt: str,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
) -> Optional[str]:
"""Generate text (wrapper for query)"""
return await self.query(prompt, system_prompt=system_prompt, model=model)

async def query(self, prompt: str, system_prompt: Optional[str] = None,
model: Optional[str] = None, use_cache: bool = True) -> Optional[str]:
async def query(
self,
prompt: str,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
use_cache: bool = True,
) -> Optional[str]:
if not model:
model = self.current_model

# Check cache if enabled
cache_key = f"{model}:{prompt}:{system_prompt}"
if use_cache and cache_key in self.response_cache:
return self.response_cache[cache_key]

for attempt in range(self.config["max_retries"]):
try:
data = {
Expand All @@ -131,28 +140,27 @@ async def query(self, prompt: str, system_prompt: Optional[str] = None,
"max_tokens": 4096,
"num_ctx": 8192,
"num_thread": 8,
"repeat_penalty": 1.1
}
"repeat_penalty": 1.1,
},
}

if system_prompt:
data["system"] = system_prompt

async with httpx.AsyncClient(timeout=self.config["timeout"]) as client:
response = await client.post(
f"{self.base_url}/api/generate",
json=data
f"{self.base_url}/api/generate", json=data
)

if response.status_code == 200:
result = response.json().get('response', '').strip()
result = response.json().get("response", "").strip()
if use_cache:
self.response_cache[cache_key] = result
return result

except (httpx.RequestError, httpx.TimeoutException) as e:
self.logger.error("Error querying model %s: %s", model, e)

# Try next model if available
if model in self.models:
current_index = self.models.index(model)
Expand All @@ -161,23 +169,23 @@ async def query(self, prompt: str, system_prompt: Optional[str] = None,
self.logger.info("Switching to fallback model: %s", model)
else:
break

await asyncio.sleep(self.config["retry_delay"])

return None

def clear_cache(self):
"""Clear the response cache"""
self.response_cache.clear()
self.query.cache_clear()

def get_available_models(self) -> List[str]:
"""Get list of available models"""
return self.models.copy()

def set_preferred_model(self, model: str) -> bool:
"""Set preferred model if available"""
if self._is_model_available(model):
self.current_model = model
return True
return False
return False
43 changes: 18 additions & 25 deletions examples/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,24 @@

# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


async def main():
# Initialize config manager
config_path = Path.home() / ".neurorift" / "config.json"
config = ConfigManager(config_path)

# Initialize notifier
notifier = Notifier(config)
await notifier.start()

try:
# Example 1: Basic notification
await notifier.notify(
"Scan started for example.com",
"info"
)

await notifier.notify("Scan started for example.com", "info")

# Example 2: Vulnerability found
await notifier.notify(
"SQL Injection vulnerability detected",
Expand All @@ -40,10 +37,10 @@ async def main():
"vulnerability": "SQL Injection",
"affected_url": "https://example.com/login",
"payload": "' OR '1'='1",
"confidence": "high"
}
"confidence": "high",
},
)

# Example 3: Critical finding
await notifier.notify(
"Remote Code Execution vulnerability found!",
Expand All @@ -52,34 +49,30 @@ async def main():
"vulnerability": "RCE",
"affected_component": "File Upload Handler",
"cve": "CVE-2023-1234",
"exploit_available": True
"exploit_available": True,
},
channels=["email", "discord"] # Send to specific channels
channels=["email", "discord"], # Send to specific channels
)

# Example 4: Scan completion
await notifier.notify(
"Scan completed successfully",
"info",
data={
"target": "example.com",
"duration": "2h 15m",
"findings": {
"critical": 1,
"high": 3,
"medium": 5,
"low": 8
}
}
"findings": {"critical": 1, "high": 3, "medium": 5, "low": 8},
},
)

# Wait for notifications to be processed
await asyncio.sleep(1)

except Exception as e:
logger.error(f"Error in notification example: {e}")
finally:
await notifier.stop()


if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
Loading
Loading