diff --git a/ai_wrapper/__init__.py b/ai_wrapper/__init__.py index 323e430..57c7498 100755 --- a/ai_wrapper/__init__.py +++ b/ai_wrapper/__init__.py @@ -10,4 +10,4 @@ from .ollama_wrapper import OllamaWrapper -__all__ = ['OllamaWrapper'] \ No newline at end of file +__all__ = ["OllamaWrapper"] diff --git a/ai_wrapper/llm_engine.py b/ai_wrapper/llm_engine.py index 617f919..3befdd0 100755 --- a/ai_wrapper/llm_engine.py +++ b/ai_wrapper/llm_engine.py @@ -18,6 +18,7 @@ import httpx from functools import lru_cache + class LLMEngine: def __init__(self, config_path: Optional[Path] = None): self.logger = logging.getLogger(__name__) @@ -26,25 +27,21 @@ def __init__(self, config_path: Optional[Path] = None): self.models = self._initialize_models() self.current_model = self.models[0] # Start with preferred model self.response_cache = {} - + def _load_config(self, config_path: Optional[Path] = None) -> Dict: """Load LLM configuration""" if not config_path: config_path = Path.home() / ".neurorift" / "configs" / "llm_config.json" - + default_config = { "preferred_model": "deepseek-coder-v2:16b-lite-base-q5_K_S", - "fallback_models": [ - "deepseek-coder:6.7b", - "codellama:7b", - "mistral:7b" - ], + "fallback_models": ["deepseek-coder:6.7b", "codellama:7b", "mistral:7b"], "cache_size": 100, "timeout": 180, "max_retries": 3, - "retry_delay": 2 + "retry_delay": 2, } - + try: if config_path.exists(): with open(config_path) as f: @@ -53,72 +50,84 @@ def _load_config(self, config_path: Optional[Path] = None) -> Dict: except Exception as e: self.logger.error("Error loading LLM config: %s", e) self.config = {} - + def _initialize_models(self) -> List[str]: """Initialize available models""" available_models = [] - + # Check preferred model first if self._is_model_available(self.config["preferred_model"]): available_models.append(self.config["preferred_model"]) - + # Check fallback models for model in self.config["fallback_models"]: if self._is_model_available(model): available_models.append(model) - + if not available_models: self.logger.error("No models available!") - + return available_models - + async def _is_model_available(self, model: str) -> bool: """Check if a model is available""" try: async with httpx.AsyncClient(timeout=5) as client: response = await client.get(f"{self.base_url}/api/tags") if response.status_code == 200: - models = response.json().get('models', []) - return any(m['name'] == model for m in models) + models = response.json().get("models", []) + return any(m["name"] == model for m in models) except (httpx.RequestError, httpx.TimeoutException) as e: self.logger.error("Error checking model availability: %s", e) return False return False - + def _pull_model(self, model: str) -> bool: """Pull a model if not available""" try: self.logger.info("Pulling model: %s", model) data = {"name": model} - response = requests.post(f"{self.base_url}/api/pull", json=data, stream=True) - + response = requests.post( + f"{self.base_url}/api/pull", json=data, stream=True + ) + for line in response.iter_lines(): if line: try: - status = json.loads(line.decode('utf-8')) - if status.get('status') == 'success': + status = json.loads(line.decode("utf-8")) + if status.get("status") == "success": return True except: continue except Exception as e: self.logger.error("Error pulling model %s: %s", model, e) return False - + @lru_cache(maxsize=100) - async def generate(self, prompt: str, system_prompt: Optional[str] = None, model: Optional[str] = None) -> Optional[str]: + async def generate( + self, + prompt: str, + system_prompt: Optional[str] = None, + model: Optional[str] = None, + ) -> Optional[str]: """Generate text (wrapper for query)""" return await self.query(prompt, system_prompt=system_prompt, model=model) - async def query(self, prompt: str, system_prompt: Optional[str] = None, - model: Optional[str] = None, use_cache: bool = True) -> Optional[str]: + async def query( + self, + prompt: str, + system_prompt: Optional[str] = None, + model: Optional[str] = None, + use_cache: bool = True, + ) -> Optional[str]: if not model: model = self.current_model - + # Check cache if enabled cache_key = f"{model}:{prompt}:{system_prompt}" if use_cache and cache_key in self.response_cache: return self.response_cache[cache_key] - + for attempt in range(self.config["max_retries"]): try: data = { @@ -131,28 +140,27 @@ async def query(self, prompt: str, system_prompt: Optional[str] = None, "max_tokens": 4096, "num_ctx": 8192, "num_thread": 8, - "repeat_penalty": 1.1 - } + "repeat_penalty": 1.1, + }, } - + if system_prompt: data["system"] = system_prompt - + async with httpx.AsyncClient(timeout=self.config["timeout"]) as client: response = await client.post( - f"{self.base_url}/api/generate", - json=data + f"{self.base_url}/api/generate", json=data ) - + if response.status_code == 200: - result = response.json().get('response', '').strip() + result = response.json().get("response", "").strip() if use_cache: self.response_cache[cache_key] = result return result - + except (httpx.RequestError, httpx.TimeoutException) as e: self.logger.error("Error querying model %s: %s", model, e) - + # Try next model if available if model in self.models: current_index = self.models.index(model) @@ -161,23 +169,23 @@ async def query(self, prompt: str, system_prompt: Optional[str] = None, self.logger.info("Switching to fallback model: %s", model) else: break - + await asyncio.sleep(self.config["retry_delay"]) - + return None - + def clear_cache(self): """Clear the response cache""" self.response_cache.clear() self.query.cache_clear() - + def get_available_models(self) -> List[str]: """Get list of available models""" return self.models.copy() - + def set_preferred_model(self, model: str) -> bool: """Set preferred model if available""" if self._is_model_available(model): self.current_model = model return True - return False \ No newline at end of file + return False diff --git a/examples/notifications.py b/examples/notifications.py index 791288b..25006b5 100755 --- a/examples/notifications.py +++ b/examples/notifications.py @@ -11,27 +11,24 @@ # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) + async def main(): # Initialize config manager config_path = Path.home() / ".neurorift" / "config.json" config = ConfigManager(config_path) - + # Initialize notifier notifier = Notifier(config) await notifier.start() - + try: # Example 1: Basic notification - await notifier.notify( - "Scan started for example.com", - "info" - ) - + await notifier.notify("Scan started for example.com", "info") + # Example 2: Vulnerability found await notifier.notify( "SQL Injection vulnerability detected", @@ -40,10 +37,10 @@ async def main(): "vulnerability": "SQL Injection", "affected_url": "https://example.com/login", "payload": "' OR '1'='1", - "confidence": "high" - } + "confidence": "high", + }, ) - + # Example 3: Critical finding await notifier.notify( "Remote Code Execution vulnerability found!", @@ -52,11 +49,11 @@ async def main(): "vulnerability": "RCE", "affected_component": "File Upload Handler", "cve": "CVE-2023-1234", - "exploit_available": True + "exploit_available": True, }, - channels=["email", "discord"] # Send to specific channels + channels=["email", "discord"], # Send to specific channels ) - + # Example 4: Scan completion await notifier.notify( "Scan completed successfully", @@ -64,22 +61,18 @@ async def main(): data={ "target": "example.com", "duration": "2h 15m", - "findings": { - "critical": 1, - "high": 3, - "medium": 5, - "low": 8 - } - } + "findings": {"critical": 1, "high": 3, "medium": 5, "low": 8}, + }, ) - + # Wait for notifications to be processed await asyncio.sleep(1) - + except Exception as e: logger.error(f"Error in notification example: {e}") finally: await notifier.stop() + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/modules/ai/agent.py b/modules/ai/agent.py index 183288c..57c9cb4 100644 --- a/modules/ai/agent.py +++ b/modules/ai/agent.py @@ -5,19 +5,22 @@ from typing import Dict, Any, List, Optional from modules.ai.ai_integration import OllamaClient + class NeuroRiftAgent: """Simple Agentic AI for NeuroRift. - + This agent understands the framework modules and outputs structured action intents in JSON format. """ - + def __init__(self, ollama_client: Optional[OllamaClient] = None): self.ollama = ollama_client or OllamaClient() self.logger = logging.getLogger("neurorift.agent") - self.prompt_path = Path(__file__).resolve().parents[2] / "prompts" / "agentic_system.md" + self.prompt_path = ( + Path(__file__).resolve().parents[2] / "prompts" / "agentic_system.md" + ) self._system_prompt = self._load_system_prompt() - + def _load_system_prompt(self) -> str: """Load the agentic system prompt from file.""" try: @@ -30,13 +33,15 @@ def _load_system_prompt(self) -> str: self.logger.error(f"Error loading system prompt: {e}") return "You are a security assistant for the NeuroRift framework." - async def run_task(self, task: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + async def run_task( + self, task: str, context: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Process a user task and return a structured response. - + Args: task: The user's request or task description. context: Additional context like current module, available tools, etc. - + Returns: A dictionary containing the AI's plan or response. """ @@ -44,51 +49,87 @@ async def run_task(self, task: str, context: Optional[Dict[str, Any]] = None) -> prompt = f"User Task: {task}\n" if context: prompt += f"Context: {json.dumps(context, indent=2)}\n" - + prompt += "\nRespond with a JSON object. Follow the schema exactly.\n" - + self.logger.info(f"Agent processing task: {task[:50]}...") - + # Define the schema for Ollama to enforce schema = { "type": "object", "properties": { "thought": {"type": "string"}, - "mode": {"type": "string", "enum": ["ACTION_PLAN", "ACTION_EXECUTION", "RESPONSE", "CLARIFICATION"]}, + "mode": { + "type": "string", + "enum": [ + "ACTION_PLAN", + "ACTION_EXECUTION", + "RESPONSE", + "CLARIFICATION", + ], + }, "goal": {"type": "string"}, "steps": { "type": "array", "items": { "type": "object", "properties": { - "type": {"type": "string", "enum": ["ui_click", "ui_input", "module_call", "tool_call"]}, - "target": {"type": "string", "enum": [ - "recon_scan", "robin_search", "ai_assistant", - "nmap", "subfinder", "httpx", "nuclei", "gobuster", "ffuf", "whatweb", - "Overview", "Recon", "Robin", "Tool Manager", "Assistant", "Reports", "Settings", - "domain_input", "query_input" - ]}, + "type": { + "type": "string", + "enum": [ + "ui_click", + "ui_input", + "module_call", + "tool_call", + ], + }, + "target": { + "type": "string", + "enum": [ + "recon_scan", + "robin_search", + "ai_assistant", + "nmap", + "subfinder", + "httpx", + "nuclei", + "gobuster", + "ffuf", + "whatweb", + "Overview", + "Recon", + "Robin", + "Tool Manager", + "Assistant", + "Reports", + "Settings", + "domain_input", + "query_input", + ], + }, "value": {"type": "string"}, - "reason": {"type": "string"} + "reason": {"type": "string"}, }, - "required": ["type", "target", "value", "reason"] - } + "required": ["type", "target", "value", "reason"], + }, }, - "content": {"type": "string"} + "content": {"type": "string"}, }, - "required": ["thought", "mode", "goal", "steps"] + "required": ["thought", "mode", "goal", "steps"], } - + # Pass schema to Ollama - raw_response = self.ollama.generate(prompt, system_prompt=self._system_prompt, format=schema) - + raw_response = self.ollama.generate( + prompt, system_prompt=self._system_prompt, format=schema + ) + if not raw_response: return { "mode": "RESPONSE", "content": "I apologize, but I failed to generate a response. Please check the Ollama service.", - "status": "error" + "status": "error", } - + return self._parse_response(raw_response) def _parse_response(self, raw_response: str) -> Dict[str, Any]: @@ -106,38 +147,40 @@ def _parse_response(self, raw_response: str) -> Dict[str, Any]: first_brace = raw_response.find("{") last_brace = raw_response.rfind("}") if first_brace != -1 and last_brace != -1: - json_candidate = raw_response[first_brace:last_brace+1] + json_candidate = raw_response[first_brace : last_brace + 1] try: return json.loads(json_candidate) except json.JSONDecodeError: pass # 3. If no JSON found, treat as plain response - self.logger.debug("No valid JSON found in AI response, falling back to RESPONSE mode.") + self.logger.debug( + "No valid JSON found in AI response, falling back to RESPONSE mode." + ) return { "mode": "RESPONSE", "content": raw_response, "status": "partial_success", - "parsing_error": "No valid JSON structure identified" + "parsing_error": "No valid JSON structure identified", } - + except Exception as e: self.logger.error(f"Unexpected error during parsing: {e}") return { "mode": "RESPONSE", "content": raw_response, "status": "error", - "parsing_error": str(e) + "parsing_error": str(e), } def get_readiness_status(self) -> Dict[str, Any]: """Check if the agent is ready for operation.""" is_ollama_available = self.ollama.is_available() best_model = self.ollama.get_best_model() if is_ollama_available else None - + return { "ready": is_ollama_available and best_model is not None, "ollama_available": is_ollama_available, "model_ready": best_model is not None, - "active_model": best_model + "active_model": best_model, } diff --git a/modules/ai/agent_context.py b/modules/ai/agent_context.py index b8174e5..3b91183 100644 --- a/modules/ai/agent_context.py +++ b/modules/ai/agent_context.py @@ -19,22 +19,22 @@ class AgentContext: """ Manages context for agents in the NeuroRift orchestration system. - + Handles context handoffs between agents and maintains shared knowledge base. """ - + def __init__(self): self.logger = logging.getLogger(__name__) self.contexts: Dict[str, Dict[str, Any]] = {} self.shared_knowledge: Dict[str, Any] = {} self.task_id: Optional[str] = None - + self.logger.info("Agent context manager initialized") - + def initialize(self, task_id: str, initial_context: Dict[str, Any]) -> None: """ Initialize context for a new task. - + Args: task_id: Task identifier initial_context: Initial context data @@ -43,16 +43,16 @@ def initialize(self, task_id: str, initial_context: Dict[str, Any]) -> None: self.shared_knowledge = { "task_id": task_id, "initialized_at": datetime.now().isoformat(), - **initial_context + **initial_context, } self.contexts = {} - + self.logger.info(f"Context initialized for task: {task_id}") - + def set_context(self, agent_name: str, context_data: Dict[str, Any]) -> None: """ Set context for a specific agent. - + Args: agent_name: Name of the agent context_data: Context data to store @@ -60,62 +60,64 @@ def set_context(self, agent_name: str, context_data: Dict[str, Any]) -> None: self.contexts[agent_name] = { "agent": agent_name, "timestamp": datetime.now().isoformat(), - "data": context_data + "data": context_data, } - + self.logger.debug(f"Context set for agent: {agent_name}") - + def get_context(self, agent_name: str) -> Dict[str, Any]: """ Get context for a specific agent. - + Args: agent_name: Name of the agent - + Returns: Agent context data """ agent_context = self.contexts.get(agent_name, {}) return agent_context.get("data", {}) - + def get_all_contexts(self) -> Dict[str, Dict[str, Any]]: """ Get all agent contexts. - + Returns: Dictionary of all agent contexts """ return self.contexts.copy() - + def update_shared_knowledge(self, key: str, value: Any) -> None: """ Update shared knowledge base. - + Args: key: Knowledge key value: Knowledge value """ self.shared_knowledge[key] = value self.logger.debug(f"Shared knowledge updated: {key}") - + def get_shared_knowledge(self, key: Optional[str] = None) -> Any: """ Get shared knowledge. - + Args: key: Optional specific key to retrieve - + Returns: Knowledge value or entire knowledge base """ if key: return self.shared_knowledge.get(key) return self.shared_knowledge.copy() - - def handoff_context(self, from_agent: str, to_agent: str, handoff_data: Optional[Dict] = None) -> None: + + def handoff_context( + self, from_agent: str, to_agent: str, handoff_data: Optional[Dict] = None + ) -> None: """ Perform context handoff between agents. - + Args: from_agent: Source agent name to_agent: Destination agent name @@ -123,7 +125,7 @@ def handoff_context(self, from_agent: str, to_agent: str, handoff_data: Optional """ # Get source agent context source_context = self.get_context(from_agent) - + # Create handoff package handoff = { "from_agent": from_agent, @@ -131,47 +133,48 @@ def handoff_context(self, from_agent: str, to_agent: str, handoff_data: Optional "timestamp": datetime.now().isoformat(), "source_context": source_context, "handoff_data": handoff_data or {}, - "shared_knowledge": self.shared_knowledge.copy() + "shared_knowledge": self.shared_knowledge.copy(), } - + # Store handoff in destination agent context self.set_context(f"{to_agent}_handoff", handoff) - + self.logger.info(f"Context handoff: {from_agent} → {to_agent}") - + def prune_context(self, max_size_mb: float = 10.0) -> None: """ Prune context to stay within size limits. - + Args: max_size_mb: Maximum context size in megabytes """ # Calculate current size context_json = json.dumps(self.contexts) - current_size_mb = len(context_json.encode('utf-8')) / (1024 * 1024) - + current_size_mb = len(context_json.encode("utf-8")) / (1024 * 1024) + if current_size_mb > max_size_mb: - self.logger.warning(f"Context size ({current_size_mb:.2f}MB) exceeds limit ({max_size_mb}MB)") - + self.logger.warning( + f"Context size ({current_size_mb:.2f}MB) exceeds limit ({max_size_mb}MB)" + ) + # Remove oldest contexts sorted_contexts = sorted( - self.contexts.items(), - key=lambda x: x[1].get("timestamp", "") + self.contexts.items(), key=lambda x: x[1].get("timestamp", "") ) - + while current_size_mb > max_size_mb and sorted_contexts: oldest_key, _ = sorted_contexts.pop(0) del self.contexts[oldest_key] - + context_json = json.dumps(self.contexts) - current_size_mb = len(context_json.encode('utf-8')) / (1024 * 1024) - + current_size_mb = len(context_json.encode("utf-8")) / (1024 * 1024) + self.logger.info(f"Pruned context: {oldest_key}") - + def export_context(self) -> Dict[str, Any]: """ Export all context data. - + Returns: Complete context export """ @@ -179,22 +182,22 @@ def export_context(self) -> Dict[str, Any]: "task_id": self.task_id, "shared_knowledge": self.shared_knowledge, "agent_contexts": self.contexts, - "exported_at": datetime.now().isoformat() + "exported_at": datetime.now().isoformat(), } - + def import_context(self, context_data: Dict[str, Any]) -> None: """ Import context data. - + Args: context_data: Context data to import """ self.task_id = context_data.get("task_id") self.shared_knowledge = context_data.get("shared_knowledge", {}) self.contexts = context_data.get("agent_contexts", {}) - + self.logger.info(f"Context imported for task: {self.task_id}") - + def clear(self) -> None: """Clear all context data""" self.contexts = {} @@ -206,30 +209,33 @@ def clear(self) -> None: # Example usage if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - + # Initialize agent context context = AgentContext() - + # Initialize for a task - context.initialize("task_001", { - "user_request": "Scan example.com", - "mode": "offensive", - "target": "example.com" - }) - + context.initialize( + "task_001", + { + "user_request": "Scan example.com", + "mode": "offensive", + "target": "example.com", + }, + ) + # Set context for Planner - context.set_context("planner", { - "plan_id": "plan_001", - "steps": [{"step_id": 1, "tool": "subfinder"}] - }) - + context.set_context( + "planner", + {"plan_id": "plan_001", "steps": [{"step_id": 1, "tool": "subfinder"}]}, + ) + # Handoff to Operator context.handoff_context("planner", "operator") - + # Get Operator's handoff operator_handoff = context.get_context("operator_handoff") print(f"\nOperator received handoff from: {operator_handoff.get('from_agent')}") - + # Export context export = context.export_context() print(f"\nContext export size: {len(json.dumps(export))} bytes") diff --git a/modules/ai/agents.py b/modules/ai/agents.py index 4276c95..14cd230 100644 --- a/modules/ai/agents.py +++ b/modules/ai/agents.py @@ -2,18 +2,30 @@ import json from modules.ai.ai_integration import OllamaClient from modules.orchestration.execution_manager import ExecutionManager, ScanRequest -from modules.orchestration.data_models import SessionContext, ToolExecutionResult, Finding +from modules.orchestration.data_models import ( + SessionContext, + ToolExecutionResult, + Finding, +) + class NRPlanner: def __init__(self, ollama: OllamaClient): self.ollama = ollama - - async def create_plan(self, task: str, available_tools: List[Dict]) -> List[ScanRequest]: + + async def create_plan( + self, task: str, available_tools: List[Dict] + ) -> List[ScanRequest]: """ Generates a list of tool executions to achieve the task. """ - tools_desc = "\n".join([f"- {t['name']}: {t['description']} (Mode: {t['mode']})" for t in available_tools]) - + tools_desc = "\n".join( + [ + f"- {t['name']}: {t['description']} (Mode: {t['mode']})" + for t in available_tools + ] + ) + prompt = f""" You are the Planner for NeuroRift Security System. Goal: {task} @@ -33,20 +45,20 @@ async def create_plan(self, task: str, available_tools: List[Dict]) -> List[Scan {{"tool_name": "nmap", "target": "example.com", "args": {{"flags": ["-F"]}}, "reasoning": "Quick scan to find open ports"}} ] """ - + response = await self.ollama.generate(prompt) try: # Basic parsing of JSON from response (handling potential markdown code blocks) cleaned = response.replace("```json", "").replace("```", "").strip() plan_data = json.loads(cleaned) - + requests = [] for step in plan_data: # Ensure target is present, if not use task mentions or safe default (should be handled by AI) req = ScanRequest( - tool_name=step['tool_name'], - target=step.get('target', 'unknown'), - args=step.get('args', {}) + tool_name=step["tool_name"], + target=step.get("target", "unknown"), + args=step.get("args", {}), ) requests.append(req) return requests @@ -54,17 +66,20 @@ async def create_plan(self, task: str, available_tools: List[Dict]) -> List[Scan print(f"Error parsing plan: {e}") return [] + class NROperator: def __init__(self, execution_manager: ExecutionManager): self.manager = execution_manager - - async def execute_plan(self, requests: List[ScanRequest], context: SessionContext) -> List[ToolExecutionResult]: + + async def execute_plan( + self, requests: List[ScanRequest], context: SessionContext + ) -> List[ToolExecutionResult]: results = [] for req in requests: # Here we could implement the human-in-the-loop check # For now, we assume pre-approval or we print to console print(f"\n[OPERATOR] Preparing to run: {req.tool_name} on {req.target}") - + # TODO: Add real approval mechanism via Web/CLI result = await self.manager.execute_tool(req, context) results.append(result) @@ -73,18 +88,21 @@ async def execute_plan(self, requests: List[ScanRequest], context: SessionContex break return results + class NRAnalyst: def __init__(self, ollama: OllamaClient): self.ollama = ollama - - async def analyze_results(self, results: List[ToolExecutionResult]) -> List[Finding]: + + async def analyze_results( + self, results: List[ToolExecutionResult] + ) -> List[Finding]: if not results: return [] - + context_str = "" for res in results: context_str += f"Tool: {res.tool_name}\nCommand: {res.command}\nOutput:\n{res.raw_output[:2000]}\n---\n" - + prompt = f""" You are the Analyst for NeuroRift. Analyze the following tool outputs and identify security findings. @@ -97,7 +115,7 @@ async def analyze_results(self, results: List[ToolExecutionResult]) -> List[Find - description - tool_source """ - + response = await self.ollama.generate(prompt) findings = [] try: @@ -105,25 +123,28 @@ async def analyze_results(self, results: List[ToolExecutionResult]) -> List[Find data = json.loads(cleaned) for item in data: finding = Finding( - title=item['title'], - severity=item['severity'], - description=item['description'], - tool_source=item['tool_source'], - details=item + title=item["title"], + severity=item["severity"], + description=item["description"], + tool_source=item["tool_source"], + details=item, ) findings.append(finding) except Exception as e: print(f"Error parsing analysis: {e}") - + return findings + class NRScribe: def __init__(self, ollama: OllamaClient): self.ollama = ollama async def generate_report(self, task: str, findings: List[Finding]) -> str: - findings_text = "\n".join([f"- [{f.severity}] {f.title}: {f.description}" for f in findings]) - + findings_text = "\n".join( + [f"- [{f.severity}] {f.title}: {f.description}" for f in findings] + ) + prompt = f""" Generate a professional security report for the task: {task} diff --git a/modules/ai/ai_assistant.py b/modules/ai/ai_assistant.py index fa1f8d9..cc30791 100644 --- a/modules/ai/ai_assistant.py +++ b/modules/ai/ai_assistant.py @@ -1,5 +1,6 @@ import os + class AIAssistant: def ask_ai(self, query, context=None, debug=False): try: @@ -27,4 +28,4 @@ def _log_ai_interaction(self, query, response, debug): log_path = os.path.expanduser("~/.neurorift/sessions/logs/ai_controller.log") os.makedirs(os.path.dirname(log_path), exist_ok=True) with open(log_path, "a") as f: - f.write(f"[QUERY] {query}\n[RESPONSE] {response}\n[DEBUG] {debug}\n\n") \ No newline at end of file + f.write(f"[QUERY] {query}\n[RESPONSE] {response}\n[DEBUG] {debug}\n\n") diff --git a/modules/ai/ai_controller.py b/modules/ai/ai_controller.py index 2bcc534..ea73468 100644 --- a/modules/ai/ai_controller.py +++ b/modules/ai/ai_controller.py @@ -1,5 +1,6 @@ import os + class AIController: def run_ai_task(self, task, context=None, debug=False): try: @@ -16,4 +17,4 @@ def _log_ai_task(self, task, result, debug): log_path = os.path.expanduser("~/.neurorift/sessions/logs/ai_controller.log") os.makedirs(os.path.dirname(log_path), exist_ok=True) with open(log_path, "a") as f: - f.write(f"[TASK] {task}\n[RESULT] {result}\n[DEBUG] {debug}\n\n") \ No newline at end of file + f.write(f"[TASK] {task}\n[RESULT] {result}\n[DEBUG] {debug}\n\n") diff --git a/modules/ai/ai_integration.py b/modules/ai/ai_integration.py index 7f95bb1..3254398 100755 --- a/modules/ai/ai_integration.py +++ b/modules/ai/ai_integration.py @@ -15,28 +15,31 @@ from pathlib import Path import asyncio import ctypes + try: from duckduckgo_search import DDGS + DDGS_AVAILABLE = True except ImportError: DDGS_AVAILABLE = False + class OllamaClient: def __init__(self, base_url: str = "http://localhost:11434"): self.base_url = base_url self.logger = logging.getLogger(__name__) - + # Load configuration from environment - self.main_model = os.getenv("OLLAMA_MAIN_MODEL", "deepseek-coder-v2:16b-lite-base-q4_0") - self.assistant_model = os.getenv("OLLAMA_ASSISTANT_MODEL", "mistral:7b-instruct-v0.2-q4_0") + self.main_model = os.getenv( + "OLLAMA_MAIN_MODEL", "deepseek-coder-v2:16b-lite-base-q4_0" + ) + self.assistant_model = os.getenv( + "OLLAMA_ASSISTANT_MODEL", "mistral:7b-instruct-v0.2-q4_0" + ) self.ai_enabled = os.getenv("AI_ENABLED", "true").lower() == "true" - - self.backup_models = [ - "deepseek-coder:6.7b", - "codellama:7b", - "mistral:7b" - ] - + + self.backup_models = ["deepseek-coder:6.7b", "codellama:7b", "mistral:7b"] + async def is_available(self) -> bool: """Check if Ollama service is running""" try: @@ -50,12 +53,16 @@ async def ensure_service_running(self) -> bool: """Try to start Ollama service if not running""" if await self.is_available(): return True - + self.logger.info("Ollama service not running. Attempting to start...") try: # Try to start using subprocess (background) - subprocess.Popen(["ollama", "serve"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - + subprocess.Popen( + ["ollama", "serve"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + # Wait for it to start for i in range(10): await asyncio.sleep(2) @@ -66,38 +73,46 @@ async def ensure_service_running(self) -> bool: except Exception as e: self.logger.error(f"Failed to start Ollama service: {e}") return False - + async def list_models(self) -> List[Dict]: """List available models""" try: async with httpx.AsyncClient(timeout=5) as client: response = await client.get(f"{self.base_url}/api/tags") if response.status_code == 200: - return response.json().get('models', []) + return response.json().get("models", []) except (httpx.RequestError, httpx.TimeoutException) as e: self.logger.error("Error listing models: %s", e) return [] - + def pull_model(self, model: str) -> bool: """Pull a model if not available""" try: self.logger.info("Pulling model: %s", model) data = {"name": model} - response = requests.post(f"{self.base_url}/api/pull", json=data, stream=True) - + response = requests.post( + f"{self.base_url}/api/pull", json=data, stream=True + ) + for line in response.iter_lines(): if line: try: - status = json.loads(line.decode('utf-8')) - if status.get('status') == 'success': + status = json.loads(line.decode("utf-8")) + if status.get("status") == "success": return True except: continue except Exception as e: self.logger.error("Error pulling model %s: %s", model, e) return False - - async def generate(self, prompt: str, model: str = None, system_prompt: str = None, format: str = None) -> Optional[str]: + + async def generate( + self, + prompt: str, + model: str = None, + system_prompt: str = None, + format: str = None, + ) -> Optional[str]: """Generate text using Ollama""" if not self.ai_enabled: self.logger.warning("AI features are currently disabled in configuration.") @@ -105,20 +120,26 @@ async def generate(self, prompt: str, model: str = None, system_prompt: str = No # Auto-start if needed if not await self.ensure_service_running(): - self.logger.error("Ollama service is not running and could not be auto-started.") + self.logger.error( + "Ollama service is not running and could not be auto-started." + ) return None if not model: model = await self.get_best_model() - + if not model: available = await self.list_models() if not available: - self.logger.error("No models found in Ollama. Please pull a model using 'ollama pull '.") + self.logger.error( + "No models found in Ollama. Please pull a model using 'ollama pull '." + ) else: - self.logger.error(f"Configured models ({self.main_model}, {self.assistant_model}) not found.") + self.logger.error( + f"Configured models ({self.main_model}, {self.assistant_model}) not found." + ) return None - + try: data = { "model": model, @@ -128,57 +149,59 @@ async def generate(self, prompt: str, model: str = None, system_prompt: str = No "temperature": 0.5, "top_p": 0.9, "max_tokens": 4096, - "num_ctx": 4096, # Reduced from 16384 to prevent OOM + "num_ctx": 4096, # Reduced from 16384 to prevent OOM "num_thread": 8, - "repeat_penalty": 1.1 - } + "repeat_penalty": 1.1, + }, } - + if system_prompt: data["system"] = system_prompt - + if format: data["format"] = format - + async with httpx.AsyncClient(timeout=300) as client: response = await client.post(f"{self.base_url}/api/generate", json=data) - + if response.status_code == 200: result = response.json() - return result.get('response', '').strip() + return result.get("response", "").strip() else: - self.logger.error(f"Ollama API error: {response.status_code} - {response.text}") - + self.logger.error( + f"Ollama API error: {response.status_code} - {response.text}" + ) + except (httpx.RequestError, httpx.TimeoutException) as e: self.logger.error(f"Error generating with Ollama: {e}") - + return None async def query(self, prompt: str, system_prompt: str = None) -> Optional[str]: """Wrapper for compatibility with modules expecting .query()""" return await self.generate(prompt=prompt, system_prompt=system_prompt) - + async def get_best_model(self) -> Optional[str]: """Get the best available model""" - available_models = [m['name'] for m in await self.list_models()] - + available_models = [m["name"] for m in await self.list_models()] + # Check main model first (deepseek-coder-v2:16b-lite-base-q4_0) if self.main_model in available_models: return self.main_model - + # Check assistant model next (mistral:7b-instruct-v0.2-q4_0) if self.assistant_model in available_models: return self.assistant_model - + # Check backup models for model in self.backup_models: if model in available_models: return model - + # If no preferred models, return first available if available_models: return available_models[0] - + return None @@ -198,20 +221,23 @@ async def get_best_model(self) -> Optional[str]: c_parser = None C_PARSER_AVAILABLE = False import logging - logging.getLogger(__name__).warning("C parser library not found. Using Python fallback.") + + logging.getLogger(__name__).warning( + "C parser library not found. Using Python fallback." + ) class AIAnalyzer: def __init__(self, ollama_client: OllamaClient): self.ollama = ollama_client self.logger = logging.getLogger(__name__) - + async def analyze_nmap_output(self, nmap_output: str) -> Dict[str, Any]: """Analyze nmap scan results using AI""" system_prompt = """You are a cybersecurity expert analyzing nmap scan results. Identify potential vulnerabilities, interesting services, and security issues. Provide structured analysis in JSON format with severity levels.""" - + prompt = f""" Analyze this nmap scan output and identify potential security issues: @@ -234,36 +260,38 @@ async def analyze_nmap_output(self, nmap_output: str) -> Dict[str, Any]: "next_steps": ["recommended follow-up actions"] }} """ - + response = await self.ollama.generate(prompt, system_prompt=system_prompt) if response: try: return json.loads(response) except json.JSONDecodeError: # Extract JSON if wrapped in markdown - json_match = re.search(r'```json\n(.*?)\n```', response, re.DOTALL) + json_match = re.search(r"```json\n(.*?)\n```", response, re.DOTALL) if json_match: try: return json.loads(json_match.group(1)) except json.JSONDecodeError: pass - + return {"error": "Failed to analyze nmap output", "raw_response": response} - + async def perform_web_search(self, query: str, max_results: int = 3) -> str: """Perform a web search to augment AI context""" if not DDGS_AVAILABLE: return "Web search unavailable (duckduckgo-search not installed)." - + try: self.logger.info(f"Performing web search for: {query}") results = DDGS().text(query, max_results=max_results) if not results: return "No results found." - + formatted_results = "Web Search Results:\n" for i, r in enumerate(results, 1): - formatted_results += f"{i}. {r['title']}\n {r['body']}\n Source: {r['href']}\n\n" + formatted_results += ( + f"{i}. {r['title']}\n {r['body']}\n Source: {r['href']}\n\n" + ) return formatted_results except Exception as e: self.logger.error(f"Web search failed: {e}") @@ -289,26 +317,28 @@ async def generate_exploit_code(self, vulnerability_info: Dict) -> Dict[str, Any "safety_warning": "Warning about usage" }} """ - + response = await self.ollama.generate(prompt) if response: try: return json.loads(response) except json.JSONDecodeError: - json_match = re.search(r'```json\n(.*?)\n```', response, re.DOTALL) + json_match = re.search(r"```json\n(.*?)\n```", response, re.DOTALL) if json_match: return json.loads(json_match.group(1)) - + return {"error": "Failed to generate exploit", "raw_response": response} - - async def analyze_web_response(self, url: str, response_data: Dict) -> Dict[str, Any]: + + async def analyze_web_response( + self, url: str, response_data: Dict + ) -> Dict[str, Any]: """Analyze web service response for vulnerabilities""" system_prompt = """You are a web application security expert. Analyze HTTP responses for potential vulnerabilities and security issues.""" - - headers = response_data.get('headers', {}) - content = response_data.get('content', '')[:2000] # Limit content length - + + headers = response_data.get("headers", {}) + content = response_data.get("content", "")[:2000] # Limit content length + prompt = f""" Analyze this web service for security issues: @@ -343,26 +373,28 @@ async def analyze_web_response(self, url: str, response_data: Dict) -> Dict[str, "recommendations": ["security recommendations"] }} """ - + response = await self.ollama.generate(prompt, system_prompt=system_prompt) if response: try: return json.loads(response) except json.JSONDecodeError: - json_match = re.search(r'```json\n(.*?)\n```', response, re.DOTALL) + json_match = re.search(r"```json\n(.*?)\n```", response, re.DOTALL) if json_match: try: return json.loads(json_match.group(1)) except json.JSONDecodeError: pass - + return {"error": "Failed to analyze web response", "raw_response": response} - - async def fix_broken_tool(self, tool_name: str, error_output: str, source_code: str = None) -> str: + + async def fix_broken_tool( + self, tool_name: str, error_output: str, source_code: str = None + ) -> str: """Generate fixes for broken security tools""" system_prompt = """You are a DevOps engineer specializing in fixing broken security tools. Analyze errors and provide working solutions.""" - + prompt = f""" This security tool is broken and needs fixing: @@ -384,15 +416,17 @@ async def fix_broken_tool(self, tool_name: str, error_output: str, source_code: - Missing environment variables - Network/permission issues """ - + response = await self.ollama.generate(prompt, system_prompt=system_prompt) return response or "# Failed to generate fix" - - async def prioritize_vulnerabilities(self, vulnerabilities: List[Dict]) -> List[Dict]: + + async def prioritize_vulnerabilities( + self, vulnerabilities: List[Dict] + ) -> List[Dict]: """Use AI to prioritize vulnerabilities by exploitability and impact""" system_prompt = """You are a penetration tester prioritizing vulnerabilities. Rank vulnerabilities by exploitability and business impact.""" - + prompt = f""" Prioritize these vulnerabilities for testing: @@ -419,35 +453,42 @@ async def prioritize_vulnerabilities(self, vulnerabilities: List[Dict]) -> List[ ] }} """ - + response = await self.ollama.generate(prompt, system_prompt=system_prompt) if response: try: result = json.loads(response) - return result.get('prioritized_vulnerabilities', vulnerabilities) + return result.get("prioritized_vulnerabilities", vulnerabilities) except: pass - + return vulnerabilities def analyze_nuclei_output(self, nuclei_json_output: str) -> Dict[str, Any]: """Analyze nuclei output using the high-speed C parser.""" - + if C_PARSER_AVAILABLE and c_parser: # Use the high-speed C parser when available - raw_summary = c_parser.parse_nuclei_output(nuclei_json_output.encode('utf-8')) - summary_str = raw_summary.decode('utf-8') + raw_summary = c_parser.parse_nuclei_output( + nuclei_json_output.encode("utf-8") + ) + summary_str = raw_summary.decode("utf-8") else: # SECURITY FIX: Fallback to pure Python implementation # This ensures the application works even without the native library try: import json + data = json.loads(nuclei_json_output) - critical_count = sum(1 for item in data if item.get('info', {}).get('severity') == 'critical') + critical_count = sum( + 1 + for item in data + if item.get("info", {}).get("severity") == "critical" + ) summary_str = json.dumps({"critical_findings": critical_count}) except (json.JSONDecodeError, TypeError): summary_str = '{"error": "Failed to parse nuclei output"}' - + try: return json.loads(summary_str) except json.JSONDecodeError: @@ -459,6 +500,7 @@ class AIOrchestrator: Manages a multi-step AI reasoning pipeline for complex security tasks. It chains specialized prompts for planning, tool selection, and execution. """ + def __init__(self, prompt_dir: Path): self.prompt_dir = prompt_dir self.ollama = OllamaClient() @@ -484,7 +526,9 @@ def _load_prompts(self): with open(tool_path, "r") as f: prompts["tool_selector"] = f.read() else: - prompts["tool_selector"] = "You are an expert at selecting the best security tool for a task." + prompts["tool_selector"] = ( + "You are an expert at selecting the best security tool for a task." + ) # Cursor-style code/analysis prompt analyst_path = self.prompt_dir / "Cursor Prompts" / "Cursor Prompts.txt" @@ -492,20 +536,25 @@ def _load_prompts(self): if not analyst_path.exists(): # Fallback to the first .txt file found if possible, or a default analyst_path = self.prompt_dir / "Cursor Prompts" / "System Prompt.txt" - + if analyst_path.exists(): with open(analyst_path, "r") as f: prompts["analyst"] = f.read() else: - prompts["analyst"] = "You are a senior security researcher analyzing results." - + prompts["analyst"] = ( + "You are a senior security researcher analyzing results." + ) + except Exception as e: logging.getLogger(__name__).error(f"Error loading specialized prompts: {e}") # Ensure we have defaults if everything fails prompts.setdefault("planner", "You are an expert security planner.") - prompts.setdefault("tool_selector", "You are an expert at selecting the best security tool.") + prompts.setdefault( + "tool_selector", + "You are an expert at selecting the best security tool.", + ) prompts.setdefault("analyst", "You are a senior security researcher.") - + return prompts async def execute_task(self, task_description: str): @@ -513,25 +562,25 @@ async def execute_task(self, task_description: str): Executes a full task pipeline: Plan -> Select Tool -> Execute -> Analyze. """ print("--- AI Task Pipeline Initiated ---") - + # 1. Planning Phase (using specialized prompt) plan = await self._planning_phase(task_description) - self.state['plan'] = plan + self.state["plan"] = plan print(f"Phase 1: Plan Created -> {plan}") # 2. Tool Selection Phase (using specialized prompt) tool_command = await self._tool_selection_phase(task_description, plan) - self.state['tool_command'] = tool_command + self.state["tool_command"] = tool_command print(f"Phase 2: Tool Selected -> {tool_command}") # 3. Execution Phase (simulated) execution_result = self._execution_phase(tool_command) - self.state['execution_result'] = execution_result + self.state["execution_result"] = execution_result print(f"Phase 3: Execution Result -> {execution_result[:100]}...") # 4. Analysis Phase (using specialized prompt) analysis = await self._analysis_phase(execution_result) - self.state['analysis'] = analysis + self.state["analysis"] = analysis print(f"Phase 4: Analysis Complete -> {analysis}") print("--- AI Task Pipeline Complete ---") @@ -539,18 +588,18 @@ async def execute_task(self, task_description: str): async def _planning_phase(self, task: str) -> str: """Uses the 'planner' prompt to create a high-level strategy.""" - system_prompt = self.prompts['planner'] + system_prompt = self.prompts["planner"] user_prompt = f"Create a step-by-step plan for the following task: {task}" response = await self.ollama.generate(user_prompt, system_prompt=system_prompt) return response async def _tool_selection_phase(self, task: str, plan: str) -> str: """Uses the 'tool_selector' prompt to choose the right command.""" - system_prompt = self.prompts['tool_selector'] + system_prompt = self.prompts["tool_selector"] user_prompt = f"Given the task '{task}' and the plan '{plan}', what is the exact shell command to execute next? Only output the command." response = await self.ollama.generate(user_prompt, system_prompt=system_prompt) return response - + def _execution_phase(self, command: str) -> str: """Simulates running the command and returns mock output.""" print(f"Simulating execution of: `{command}`") @@ -560,7 +609,7 @@ def _execution_phase(self, command: str) -> str: async def _analysis_phase(self, result: str) -> str: """Uses the 'analyst' prompt to interpret the results.""" - system_prompt = self.prompts['analyst'] + system_prompt = self.prompts["analyst"] user_prompt = f"Analyze the following tool output and provide a summary of key findings and recommendations:\n\n{result}" response = await self.ollama.generate(user_prompt, system_prompt=system_prompt) return response @@ -569,43 +618,51 @@ async def _analysis_phase(self, result: str) -> str: # CLI interface for AI module if __name__ == "__main__": import argparse - + parser = argparse.ArgumentParser(description="NeuroRift AI Module") - parser.add_argument("--test-connection", action="store_true", help="Test Ollama connection") + parser.add_argument( + "--test-connection", action="store_true", help="Test Ollama connection" + ) parser.add_argument("--pull-model", help="Pull a specific model") - parser.add_argument("--list-models", action="store_true", help="List available models") + parser.add_argument( + "--list-models", action="store_true", help="List available models" + ) parser.add_argument("--analyze-nmap", help="Analyze nmap output file") parser.add_argument( - "--ai-pipeline", action="store_true", help="Enable the advanced multi-prompt AI pipeline." + "--ai-pipeline", + action="store_true", + help="Enable the advanced multi-prompt AI pipeline.", ) parser.add_argument( - "--prompt-dir", help="Directory for the AI pipeline prompts.", default="prompts/system_prompts" + "--prompt-dir", + help="Directory for the AI pipeline prompts.", + default="prompts/system_prompts", ) - + args = parser.parse_args() - + orchestrator = AIOrchestrator(Path(args.prompt_dir)) - + if args.test_connection: if orchestrator.ollama.is_available(): print("✓ AI system ready") else: print("✗ AI system not available") - + elif args.pull_model: if orchestrator.ollama.pull_model(args.pull_model): print(f"✓ Model {args.pull_model} pulled successfully") else: print(f"✗ Failed to pull model {args.pull_model}") - + elif args.list_models: models = orchestrator.ollama.list_models() print("Available models:") for model in models: print(f" - {model['name']}") - + elif args.analyze_nmap: - with open(args.analyze_nmap, 'r') as f: + with open(args.analyze_nmap, "r") as f: content = f.read() result = orchestrator.analyzer.analyze_nmap_output(content) print(json.dumps(result, indent=2)) @@ -613,11 +670,13 @@ async def _analysis_phase(self, result: str) -> str: # Handle AI Pipeline Mode if args.ai_pipeline: if not args.target: - print("Error: A target is required for AI pipeline mode, e.g., --target 'scan example.com'") - + print( + "Error: A target is required for AI pipeline mode, e.g., --target 'scan example.com'" + ) + prompt_path = Path(args.prompt_dir) if not prompt_path.exists(): print(f"Error: Prompt directory not found at '{prompt_path}'") - + orchestrator = AIOrchestrator(prompt_path) - orchestrator.execute_task(f"Perform a security scan on {args.target}") \ No newline at end of file + orchestrator.execute_task(f"Perform a security scan on {args.target}") diff --git a/modules/ai/mode_governor.py b/modules/ai/mode_governor.py index 863cfe3..ed99ea7 100644 --- a/modules/ai/mode_governor.py +++ b/modules/ai/mode_governor.py @@ -20,42 +20,48 @@ class OperationalMode(Enum): """Operational modes for NeuroRift""" + OFFENSIVE = "offensive" DEFENSIVE = "defensive" class ModeViolation(Exception): """Raised when a mode violation is detected""" + pass class ModeGovernor: """ Mode Governor enforces operational discipline between OFFENSIVE and DEFENSIVE modes. - + CRITICAL RULES: 1. No cross-mode contamination 2. Tool/module access strictly controlled per mode 3. All violations logged 4. Mode switching disabled by default """ - + def __init__(self, config_path: str = "configs/neurorift_x_config.json"): self.config_path = Path(config_path) self.config = self._load_config() self.current_mode: Optional[OperationalMode] = None self.logger = logging.getLogger(__name__) self.violation_log: List[Dict] = [] - + # Initialize mode governor self.enabled = self.config.get("mode_governor", {}).get("enabled", True) - self.allow_mode_switching = self.config.get("mode_governor", {}).get("allow_mode_switching", False) - self.log_violations = self.config.get("mode_governor", {}).get("log_violations", True) - + self.allow_mode_switching = self.config.get("mode_governor", {}).get( + "allow_mode_switching", False + ) + self.log_violations = self.config.get("mode_governor", {}).get( + "log_violations", True + ) + def _load_config(self) -> Dict: """Load NeuroRift configuration""" try: - with open(self.config_path, 'r') as f: + with open(self.config_path, "r") as f: return json.load(f) except FileNotFoundError: self.logger.error(f"Configuration file not found: {self.config_path}") @@ -63,14 +69,14 @@ def _load_config(self) -> Dict: except json.JSONDecodeError as e: self.logger.error(f"Invalid JSON in configuration: {e}") raise - + def set_mode(self, mode: str) -> None: """ Set the operational mode. - + Args: mode: Either 'offensive' or 'defensive' - + Raises: ValueError: If mode is invalid ModeViolation: If mode switching is not allowed @@ -78,159 +84,167 @@ def set_mode(self, mode: str) -> None: try: new_mode = OperationalMode(mode.lower()) except ValueError: - raise ValueError(f"Invalid mode: {mode}. Must be 'offensive' or 'defensive'") - + raise ValueError( + f"Invalid mode: {mode}. Must be 'offensive' or 'defensive'" + ) + # Check if mode switching is allowed if self.current_mode is not None and self.current_mode != new_mode: if not self.allow_mode_switching: raise ModeViolation( f"Mode switching is disabled. Cannot switch from {self.current_mode.value} to {new_mode.value}" ) - self.logger.warning(f"Mode switched from {self.current_mode.value} to {new_mode.value}") - + self.logger.warning( + f"Mode switched from {self.current_mode.value} to {new_mode.value}" + ) + self.current_mode = new_mode self.logger.info(f"Operational mode set to: {self.current_mode.value.upper()}") - + def get_allowed_tools(self) -> List[str]: """Get list of tools allowed in current mode""" if not self.current_mode: raise ModeViolation("No operational mode set") - + modes_config = self.config.get("mode_governor", {}).get("modes", {}) mode_config = modes_config.get(self.current_mode.value, {}) return mode_config.get("allowed_tools", []) - + def get_allowed_modules(self) -> List[str]: """Get list of modules allowed in current mode""" if not self.current_mode: raise ModeViolation("No operational mode set") - + modes_config = self.config.get("mode_governor", {}).get("modes", {}) mode_config = modes_config.get(self.current_mode.value, {}) return mode_config.get("allowed_modules", []) - + def get_restrictions(self) -> List[str]: """Get list of restrictions for current mode""" if not self.current_mode: raise ModeViolation("No operational mode set") - + modes_config = self.config.get("mode_governor", {}).get("modes", {}) mode_config = modes_config.get(self.current_mode.value, {}) return mode_config.get("restrictions", []) - + def validate_tool(self, tool_name: str) -> bool: """ Validate if a tool is allowed in current mode. - + Args: tool_name: Name of the tool to validate - + Returns: True if tool is allowed, False otherwise - + Raises: ModeViolation: If tool is not allowed and violations are enforced """ if not self.enabled: return True - + if not self.current_mode: raise ModeViolation("No operational mode set") - + allowed_tools = self.get_allowed_tools() - + if tool_name not in allowed_tools: violation = { "timestamp": datetime.now().isoformat(), "mode": self.current_mode.value, "violation_type": "unauthorized_tool", "tool": tool_name, - "allowed_tools": allowed_tools + "allowed_tools": allowed_tools, } - + if self.log_violations: self.violation_log.append(violation) - self.logger.warning(f"Tool violation: {tool_name} not allowed in {self.current_mode.value} mode") - + self.logger.warning( + f"Tool violation: {tool_name} not allowed in {self.current_mode.value} mode" + ) + raise ModeViolation( f"Tool '{tool_name}' is not allowed in {self.current_mode.value.upper()} mode. " f"Allowed tools: {', '.join(allowed_tools)}" ) - + return True - + def validate_module(self, module_name: str) -> bool: """ Validate if a module is allowed in current mode. - + Args: module_name: Name of the module to validate - + Returns: True if module is allowed, False otherwise - + Raises: ModeViolation: If module is not allowed and violations are enforced """ if not self.enabled: return True - + if not self.current_mode: raise ModeViolation("No operational mode set") - + allowed_modules = self.get_allowed_modules() - + if module_name not in allowed_modules: violation = { "timestamp": datetime.now().isoformat(), "mode": self.current_mode.value, "violation_type": "unauthorized_module", "module": module_name, - "allowed_modules": allowed_modules + "allowed_modules": allowed_modules, } - + if self.log_violations: self.violation_log.append(violation) - self.logger.warning(f"Module violation: {module_name} not allowed in {self.current_mode.value} mode") - + self.logger.warning( + f"Module violation: {module_name} not allowed in {self.current_mode.value} mode" + ) + raise ModeViolation( f"Module '{module_name}' is not allowed in {self.current_mode.value.upper()} mode. " f"Allowed modules: {', '.join(allowed_modules)}" ) - + return True - + def get_mode_prompt_file(self) -> Optional[str]: """Get the prompt file for current mode""" if not self.current_mode: return None - + modes_config = self.config.get("mode_governor", {}).get("modes", {}) mode_config = modes_config.get(self.current_mode.value, {}) return mode_config.get("prompt_file") - + def get_violation_log(self) -> List[Dict]: """Get the violation log""" return self.violation_log - + def save_violation_log(self, output_path: str) -> None: """Save violation log to file""" try: - with open(output_path, 'w') as f: + with open(output_path, "w") as f: json.dump(self.violation_log, f, indent=2) self.logger.info(f"Violation log saved to: {output_path}") except Exception as e: self.logger.error(f"Failed to save violation log: {e}") - + def get_mode_description(self) -> str: """Get description of current mode""" if not self.current_mode: return "No mode set" - + modes_config = self.config.get("mode_governor", {}).get("modes", {}) mode_config = modes_config.get(self.current_mode.value, {}) return mode_config.get("description", "No description available") - + def __repr__(self) -> str: mode_str = self.current_mode.value.upper() if self.current_mode else "NONE" return f"" @@ -239,10 +253,10 @@ def __repr__(self) -> str: # Example usage and testing if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - + # Initialize Mode Governor governor = ModeGovernor() - + # Set OFFENSIVE mode print("\n=== Testing OFFENSIVE Mode ===") governor.set_mode("offensive") @@ -250,21 +264,21 @@ def __repr__(self) -> str: print(f"Description: {governor.get_mode_description()}") print(f"Allowed tools: {governor.get_allowed_tools()}") print(f"Restrictions: {governor.get_restrictions()}") - + # Test valid tool try: governor.validate_tool("nmap") print("✓ nmap is allowed in OFFENSIVE mode") except ModeViolation as e: print(f"✗ {e}") - + # Test invalid tool try: governor.validate_tool("patch_validator") print("✓ patch_validator is allowed in OFFENSIVE mode") except ModeViolation as e: print(f"✗ {e}") - + # Set DEFENSIVE mode (will fail if mode switching disabled) print("\n=== Testing DEFENSIVE Mode ===") try: diff --git a/modules/ai/orchestrator.py b/modules/ai/orchestrator.py index 3be1b11..6b22b13 100644 --- a/modules/ai/orchestrator.py +++ b/modules/ai/orchestrator.py @@ -24,6 +24,7 @@ class AgentType(Enum): """Agent types in the orchestration system""" + PLANNER = "planner" OPERATOR = "operator" ANALYST = "analyst" @@ -32,6 +33,7 @@ class AgentType(Enum): class OrchestrationStatus(Enum): """Status of orchestration""" + IDLE = "idle" PLANNING = "planning" EXECUTING = "executing" @@ -45,42 +47,44 @@ class OrchestrationStatus(Enum): class NeuroRiftXOrchestrator: """ Central orchestration engine for NeuroRift multi-agent system. - + Manages the lifecycle of security assessment operations by coordinating multiple specialized agents (Planner, Operator, Analyst, Scribe). """ - + def __init__(self, config_path: str = "configs/neurorift_x_config.json"): self.config_path = Path(config_path) self.config = self._load_config() self.logger = logging.getLogger(__name__) - + # Initialize components self.mode_governor = ModeGovernor(config_path) self.task_memory = TaskMemory() self.agent_context = AgentContext() - + # Orchestration state self.status = OrchestrationStatus.IDLE self.current_agent: Optional[AgentType] = None self.current_task_id: Optional[str] = None self.orchestration_cycle = 0 - self.max_cycles = self.config.get("orchestration", {}).get("max_orchestration_cycles", 5) - + self.max_cycles = self.config.get("orchestration", {}).get( + "max_orchestration_cycles", 5 + ) + # Agent flow configuration self.agent_flow = [ AgentType.PLANNER, AgentType.OPERATOR, AgentType.ANALYST, - AgentType.SCRIBE + AgentType.SCRIBE, ] - + self.logger.info("NeuroRift Orchestrator initialized") - + def _load_config(self) -> Dict: """Load NeuroRift configuration""" try: - with open(self.config_path, 'r') as f: + with open(self.config_path, "r") as f: return json.load(f) except FileNotFoundError: self.logger.error(f"Configuration file not found: {self.config_path}") @@ -88,26 +92,26 @@ def _load_config(self) -> Dict: except json.JSONDecodeError as e: self.logger.error(f"Invalid JSON in configuration: {e}") raise - + def initialize_task(self, user_request: str, mode: str, target: str) -> str: """ Initialize a new security assessment task. - + Args: user_request: User's security assessment request mode: Operational mode ('offensive' or 'defensive') target: Target domain/IP - + Returns: Task ID """ # Set operational mode self.mode_governor.set_mode(mode) - + # Create task task_id = f"task_{datetime.now().strftime('%Y%m%d_%H%M%S')}" self.current_task_id = task_id - + # Initialize task memory task_data = { "task_id": task_id, @@ -116,122 +120,118 @@ def initialize_task(self, user_request: str, mode: str, target: str) -> str: "target": target, "status": "initialized", "created_at": datetime.now().isoformat(), - "orchestration_cycle": 0 + "orchestration_cycle": 0, } - + self.task_memory.create_task(task_id, task_data) - + # Initialize agent context - self.agent_context.initialize(task_id, { - "user_request": user_request, - "mode": mode, - "target": target - }) - - self.logger.info(f"Task {task_id} initialized in {mode.upper()} mode for target: {target}") + self.agent_context.initialize( + task_id, {"user_request": user_request, "mode": mode, "target": target} + ) + + self.logger.info( + f"Task {task_id} initialized in {mode.upper()} mode for target: {target}" + ) return task_id - + def execute_task(self, task_id: Optional[str] = None) -> Dict[str, Any]: """ Execute a security assessment task through the agent pipeline. - + Args: task_id: Task ID to execute (uses current task if None) - + Returns: Execution results """ if task_id: self.current_task_id = task_id - + if not self.current_task_id: raise ValueError("No task ID specified") - + self.logger.info(f"Starting orchestration for task: {self.current_task_id}") - + results = { "task_id": self.current_task_id, "status": "in_progress", - "agent_outputs": {} + "agent_outputs": {}, } - + try: # Execute agent flow for agent_type in self.agent_flow: self.logger.info(f"Executing agent: {agent_type.value}") self.current_agent = agent_type self.status = self._get_status_for_agent(agent_type) - + # Execute agent agent_output = self._execute_agent(agent_type) results["agent_outputs"][agent_type.value] = agent_output - + # Update task memory self.task_memory.update_task( - self.current_task_id, - {f"{agent_type.value}_output": agent_output} + self.current_task_id, {f"{agent_type.value}_output": agent_output} ) - + # Check for errors if agent_output.get("status") == "failed": self.logger.error(f"Agent {agent_type.value} failed") results["status"] = "failed" results["error"] = agent_output.get("error") break - + # Check for human approval requirement if agent_output.get("requires_human_approval"): self.status = OrchestrationStatus.AWAITING_APPROVAL results["status"] = "awaiting_approval" results["approval_request"] = agent_output.get("approval_request") return results - + # All agents completed successfully self.status = OrchestrationStatus.COMPLETED results["status"] = "completed" - + # Update task memory self.task_memory.update_task( self.current_task_id, - { - "status": "completed", - "completed_at": datetime.now().isoformat() - } + {"status": "completed", "completed_at": datetime.now().isoformat()}, ) - + self.logger.info(f"Task {self.current_task_id} completed successfully") - + except Exception as e: self.logger.error(f"Orchestration error: {e}", exc_info=True) self.status = OrchestrationStatus.FAILED results["status"] = "failed" results["error"] = str(e) - + return results - + def _execute_agent(self, agent_type: AgentType) -> Dict[str, Any]: """ Execute a specific agent. - + Args: agent_type: Type of agent to execute - + Returns: Agent output """ # Get agent configuration agent_config = self.config.get("agents", {}).get(agent_type.value, {}) - + # Load agent prompt prompt_file = agent_config.get("prompt_file") if prompt_file: prompt = self._load_prompt(prompt_file) else: prompt = "" - + # Get agent context context = self.agent_context.get_context(agent_type.value) - + # Execute agent based on type if agent_type == AgentType.PLANNER: return self._execute_planner(prompt, context) @@ -243,11 +243,11 @@ def _execute_agent(self, agent_type: AgentType) -> Dict[str, Any]: return self._execute_scribe(prompt, context) else: raise ValueError(f"Unknown agent type: {agent_type}") - + def _execute_planner(self, prompt: str, context: Dict) -> Dict[str, Any]: """Execute NR Planner agent""" self.logger.info("NR Planner: Creating execution plan") - + # TODO: Integrate with actual LLM # For now, return a placeholder plan plan = { @@ -262,44 +262,44 @@ def _execute_planner(self, prompt: str, context: Dict) -> Dict[str, Any]: "description": "Enumerate subdomains", "agent": "operator", "tool": "subfinder", - "requires_human_approval": False + "requires_human_approval": False, } - ] + ], } - + # Store plan in agent context self.agent_context.set_context("planner", plan) - + return plan - + def _execute_operator(self, prompt: str, context: Dict) -> Dict[str, Any]: """Execute NR Operator agent""" self.logger.info("NR Operator: Executing plan") - + # Get plan from Planner plan = self.agent_context.get_context("planner") - + # TODO: Execute actual commands # For now, return placeholder results results = { "execution_id": f"exec_{self.current_task_id}", "status": "completed", "executed_steps": len(plan.get("steps", [])), - "outputs": [] + "outputs": [], } - + # Store results in agent context self.agent_context.set_context("operator", results) - + return results - + def _execute_analyst(self, prompt: str, context: Dict) -> Dict[str, Any]: """Execute NR Analyst agent""" self.logger.info("NR Analyst: Analyzing results") - + # Get execution results from Operator exec_results = self.agent_context.get_context("operator") - + # TODO: Perform actual analysis # For now, return placeholder analysis analysis = { @@ -311,55 +311,55 @@ def _execute_analyst(self, prompt: str, context: Dict) -> Dict[str, Any]: "critical": 0, "high": 0, "medium": 0, - "low": 0 - } + "low": 0, + }, } - + # Store analysis in agent context self.agent_context.set_context("analyst", analysis) - + return analysis - + def _execute_scribe(self, prompt: str, context: Dict) -> Dict[str, Any]: """Execute NR Scribe agent""" self.logger.info("NR Scribe: Generating report") - + # Get analysis from Analyst analysis = self.agent_context.get_context("analyst") - + # TODO: Generate actual report # For now, return placeholder report report = { "report_id": f"report_{self.current_task_id}", "status": "completed", "format": "markdown", - "path": f"~/.neurorift/reports/{self.current_task_id}.md" + "path": f"~/.neurorift/reports/{self.current_task_id}.md", } - + # Store report in agent context self.agent_context.set_context("scribe", report) - + return report - + def _load_prompt(self, prompt_file: str) -> str: """Load agent prompt from file""" try: - with open(prompt_file, 'r') as f: + with open(prompt_file, "r") as f: return f.read() except FileNotFoundError: self.logger.warning(f"Prompt file not found: {prompt_file}") return "" - + def _get_status_for_agent(self, agent_type: AgentType) -> OrchestrationStatus: """Get orchestration status for agent type""" status_map = { AgentType.PLANNER: OrchestrationStatus.PLANNING, AgentType.OPERATOR: OrchestrationStatus.EXECUTING, AgentType.ANALYST: OrchestrationStatus.ANALYZING, - AgentType.SCRIBE: OrchestrationStatus.REPORTING + AgentType.SCRIBE: OrchestrationStatus.REPORTING, } return status_map.get(agent_type, OrchestrationStatus.IDLE) - + def get_status(self) -> Dict[str, Any]: """Get current orchestration status""" return { @@ -367,22 +367,26 @@ def get_status(self) -> Dict[str, Any]: "current_agent": self.current_agent.value if self.current_agent else None, "current_task_id": self.current_task_id, "orchestration_cycle": self.orchestration_cycle, - "mode": self.mode_governor.current_mode.value if self.mode_governor.current_mode else None + "mode": ( + self.mode_governor.current_mode.value + if self.mode_governor.current_mode + else None + ), } - + def resume_task(self, task_id: str) -> Dict[str, Any]: """Resume a paused or failed task""" self.logger.info(f"Resuming task: {task_id}") self.current_task_id = task_id - + # Load task from memory task_data = self.task_memory.get_task(task_id) if not task_data: raise ValueError(f"Task not found: {task_id}") - + # Restore mode self.mode_governor.set_mode(task_data.get("mode")) - + # Continue execution return self.execute_task(task_id) @@ -390,22 +394,22 @@ def resume_task(self, task_id: str) -> Dict[str, Any]: # Example usage if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - + # Initialize orchestrator orchestrator = NeuroRiftXOrchestrator() - + # Initialize task task_id = orchestrator.initialize_task( user_request="Perform reconnaissance on example.com", mode="offensive", - target="example.com" + target="example.com", ) - + print(f"\nTask initialized: {task_id}") print(f"Status: {orchestrator.get_status()}") - + # Execute task results = orchestrator.execute_task() - + print(f"\nExecution results:") print(json.dumps(results, indent=2)) diff --git a/modules/ai/task_memory.py b/modules/ai/task_memory.py index 3a5513b..0151b88 100644 --- a/modules/ai/task_memory.py +++ b/modules/ai/task_memory.py @@ -20,22 +20,22 @@ class TaskMemory: """ Manages persistent storage of task state for NeuroRift. - + Provides checkpoint/resume capability and execution history tracking. """ - + def __init__(self, storage_path: str = "~/.neurorift/task_memory"): self.storage_path = Path(storage_path).expanduser() self.storage_path.mkdir(parents=True, exist_ok=True) self.logger = logging.getLogger(__name__) self.current_task: Optional[Dict] = None - + self.logger.info(f"Task memory initialized at: {self.storage_path}") - + def create_task(self, task_id: str, task_data: Dict[str, Any]) -> None: """ Create a new task in memory. - + Args: task_id: Unique task identifier task_data: Task data to store @@ -44,16 +44,16 @@ def create_task(self, task_id: str, task_data: Dict[str, Any]) -> None: task_data["created_at"] = datetime.now().isoformat() task_data["updated_at"] = datetime.now().isoformat() task_data["checkpoints"] = [] - + self.current_task = task_data self._save_task(task_id, task_data) - + self.logger.info(f"Task created: {task_id}") - + def update_task(self, task_id: str, updates: Dict[str, Any]) -> None: """ Update an existing task. - + Args: task_id: Task identifier updates: Dictionary of updates to apply @@ -61,41 +61,41 @@ def update_task(self, task_id: str, updates: Dict[str, Any]) -> None: task_data = self.get_task(task_id) if not task_data: raise ValueError(f"Task not found: {task_id}") - + task_data.update(updates) task_data["updated_at"] = datetime.now().isoformat() - + self.current_task = task_data self._save_task(task_id, task_data) - + self.logger.debug(f"Task updated: {task_id}") - + def get_task(self, task_id: str) -> Optional[Dict[str, Any]]: """ Retrieve a task from memory. - + Args: task_id: Task identifier - + Returns: Task data or None if not found """ task_file = self.storage_path / f"{task_id}.json" - + if not task_file.exists(): return None - + try: - with open(task_file, 'r') as f: + with open(task_file, "r") as f: return json.load(f) except Exception as e: self.logger.error(f"Error loading task {task_id}: {e}") return None - + def checkpoint(self, task_id: str, checkpoint_data: Dict[str, Any]) -> None: """ Create a checkpoint for a task. - + Args: task_id: Task identifier checkpoint_data: Checkpoint data to store @@ -103,83 +103,82 @@ def checkpoint(self, task_id: str, checkpoint_data: Dict[str, Any]) -> None: task_data = self.get_task(task_id) if not task_data: raise ValueError(f"Task not found: {task_id}") - - checkpoint = { - "timestamp": datetime.now().isoformat(), - "data": checkpoint_data - } - + + checkpoint = {"timestamp": datetime.now().isoformat(), "data": checkpoint_data} + task_data.setdefault("checkpoints", []).append(checkpoint) self._save_task(task_id, task_data) - + self.logger.info(f"Checkpoint created for task: {task_id}") - + def get_latest_checkpoint(self, task_id: str) -> Optional[Dict[str, Any]]: """ Get the latest checkpoint for a task. - + Args: task_id: Task identifier - + Returns: Latest checkpoint data or None """ task_data = self.get_task(task_id) if not task_data: return None - + checkpoints = task_data.get("checkpoints", []) if not checkpoints: return None - + return checkpoints[-1]["data"] - + def list_tasks(self, status: Optional[str] = None) -> List[Dict[str, Any]]: """ List all tasks, optionally filtered by status. - + Args: status: Optional status filter - + Returns: List of task summaries """ tasks = [] - + for task_file in self.storage_path.glob("task_*.json"): try: - with open(task_file, 'r') as f: + with open(task_file, "r") as f: task_data = json.load(f) - + if status is None or task_data.get("status") == status: - tasks.append({ - "task_id": task_data.get("task_id"), - "status": task_data.get("status"), - "mode": task_data.get("mode"), - "target": task_data.get("target"), - "created_at": task_data.get("created_at"), - "updated_at": task_data.get("updated_at") - }) + tasks.append( + { + "task_id": task_data.get("task_id"), + "status": task_data.get("status"), + "mode": task_data.get("mode"), + "target": task_data.get("target"), + "created_at": task_data.get("created_at"), + "updated_at": task_data.get("updated_at"), + } + ) except Exception as e: self.logger.error(f"Error reading task file {task_file}: {e}") - + return sorted(tasks, key=lambda x: x.get("created_at", ""), reverse=True) - + def delete_task(self, task_id: str) -> bool: """ Delete a task from memory. - + Args: task_id: Task identifier - + Returns: True if deleted, False if not found """ task_file = self.storage_path / f"{task_id}.json" - + if not task_file.exists(): return False - + try: task_file.unlink() self.logger.info(f"Task deleted: {task_id}") @@ -187,40 +186,40 @@ def delete_task(self, task_id: str) -> bool: except Exception as e: self.logger.error(f"Error deleting task {task_id}: {e}") return False - + def _save_task(self, task_id: str, task_data: Dict[str, Any]) -> None: """Save task data to file""" task_file = self.storage_path / f"{task_id}.json" - + try: - with open(task_file, 'w') as f: + with open(task_file, "w") as f: json.dump(task_data, f, indent=2) except Exception as e: self.logger.error(f"Error saving task {task_id}: {e}") raise - + def get_history(self, task_id: str, limit: int = 100) -> List[Dict[str, Any]]: """ Get execution history for a task. - + Args: task_id: Task identifier limit: Maximum number of history entries - + Returns: List of history entries """ task_data = self.get_task(task_id) if not task_data: return [] - + history = task_data.get("history", []) return history[-limit:] - + def add_history_entry(self, task_id: str, entry: Dict[str, Any]) -> None: """ Add an entry to task history. - + Args: task_id: Task identifier entry: History entry to add @@ -228,43 +227,43 @@ def add_history_entry(self, task_id: str, entry: Dict[str, Any]) -> None: task_data = self.get_task(task_id) if not task_data: raise ValueError(f"Task not found: {task_id}") - + entry["timestamp"] = datetime.now().isoformat() task_data.setdefault("history", []).append(entry) - + # Limit history size max_history = 100 if len(task_data["history"]) > max_history: task_data["history"] = task_data["history"][-max_history:] - + self._save_task(task_id, task_data) # Example usage if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - + # Initialize task memory memory = TaskMemory() - + # Create a task task_id = "task_20260124_092347" - memory.create_task(task_id, { - "user_request": "Scan example.com", - "mode": "offensive", - "target": "example.com", - "status": "initialized" - }) - + memory.create_task( + task_id, + { + "user_request": "Scan example.com", + "mode": "offensive", + "target": "example.com", + "status": "initialized", + }, + ) + # Update task memory.update_task(task_id, {"status": "planning"}) - + # Create checkpoint - memory.checkpoint(task_id, { - "agent": "planner", - "output": {"plan_id": "plan_001"} - }) - + memory.checkpoint(task_id, {"agent": "planner", "output": {"plan_id": "plan_001"}}) + # List tasks tasks = memory.list_tasks() print(f"\nTasks: {len(tasks)}") diff --git a/modules/config/config_wizard.py b/modules/config/config_wizard.py index 17cd068..ff6dcdc 100644 --- a/modules/config/config_wizard.py +++ b/modules/config/config_wizard.py @@ -18,25 +18,28 @@ from modules.ai.ai_integration import OllamaClient import asyncio + class ConfigWizard: def __init__(self, base_dir: Path): self.base_dir = base_dir self.env_path = base_dir / ".env" self.console = Console() self.ollama = OllamaClient() - + # Load existing env vars load_dotenv(self.env_path) def run(self): """Run the interactive configuration wizard""" self.console.clear() - self.console.print(Panel( - "[bold cyan]NeuroRift Configuration Wizard[/bold cyan]\n" - "Configure your AI models, API keys, and system settings.", - title="Setup", - border_style="blue" - )) + self.console.print( + Panel( + "[bold cyan]NeuroRift Configuration Wizard[/bold cyan]\n" + "Configure your AI models, API keys, and system settings.", + title="Setup", + border_style="blue", + ) + ) while True: self.console.print("\n[bold]Main Menu[/bold]") @@ -45,7 +48,9 @@ def run(self): self.console.print("3. [cyan]View Current Config[/cyan]") self.console.print("4. [red]Exit[/red]") - choice = Prompt.ask("Select an option", choices=["1", "2", "3", "4"], default="4") + choice = Prompt.ask( + "Select an option", choices=["1", "2", "3", "4"], default="4" + ) if choice == "1": asyncio.run(self.configure_ai()) @@ -60,70 +65,74 @@ def run(self): async def configure_ai(self): """Configure AI settings""" self.console.print("\n[bold blue]AI Configuration[/bold blue]") - + # Check current settings current_main = os.getenv("OLLAMA_MAIN_MODEL", "Not Set") current_assistant = os.getenv("OLLAMA_ASSISTANT_MODEL", "Not Set") - + self.console.print(f"Current Main Model: [green]{current_main}[/green]") - self.console.print(f"Current Assistant Model: [green]{current_assistant}[/green]") - + self.console.print( + f"Current Assistant Model: [green]{current_assistant}[/green]" + ) + if Confirm.ask("Do you want to change AI models?"): # List available models self.console.print("Fetching available models from Ollama...") if not await self.ollama.ensure_service_running(): - self.console.print("[red]Could not connect to Ollama. Is it installed?[/red]") + self.console.print( + "[red]Could not connect to Ollama. Is it installed?[/red]" + ) return models = await self.ollama.list_models() if not models: - self.console.print("[yellow]No models found in Ollama.[/yellow]") - return + self.console.print("[yellow]No models found in Ollama.[/yellow]") + return + + model_names = [m["name"] for m in models] - model_names = [m['name'] for m in models] - # Display models table = Table(title="Available Ollama Models") table.add_column("Index", justify="right", style="cyan") table.add_column("Model Name", style="white") - + for idx, name in enumerate(model_names): table.add_row(str(idx + 1), name) - + self.console.print(table) - + # Select Main Model main_idx = Prompt.ask( - "Select Main Model (number)", - choices=[str(i+1) for i in range(len(model_names))] + "Select Main Model (number)", + choices=[str(i + 1) for i in range(len(model_names))], ) selected_main = model_names[int(main_idx) - 1] self.update_env("OLLAMA_MAIN_MODEL", selected_main) - + # Select Assistant Model asst_idx = Prompt.ask( - "Select Assistant Model (number)", - choices=[str(i+1) for i in range(len(model_names))], - default=main_idx + "Select Assistant Model (number)", + choices=[str(i + 1) for i in range(len(model_names))], + default=main_idx, ) selected_asst = model_names[int(asst_idx) - 1] self.update_env("OLLAMA_ASSISTANT_MODEL", selected_asst) - + self.console.print("[green]AI Models updated successfully![/green]") def configure_system(self): """Configure system settings""" self.console.print("\n[bold blue]System Settings[/bold blue]") - + # Log Level current_log = os.getenv("LOG_LEVEL", "INFO") new_log = Prompt.ask( - "Log Level", - choices=["DEBUG", "INFO", "WARNING", "ERROR"], - default=current_log + "Log Level", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + default=current_log, ) self.update_env("LOG_LEVEL", new_log) - + self.console.print("[green]System settings updated![/green]") def show_current_config(self): @@ -131,18 +140,18 @@ def show_current_config(self): table = Table(title="Current Configuration") table.add_column("Key", style="cyan") table.add_column("Value", style="green") - + config_keys = [ "OLLAMA_MAIN_MODEL", "OLLAMA_ASSISTANT_MODEL", "LOG_LEVEL", "NEURORIFT_HOME", - "AI_ENABLED" + "AI_ENABLED", ] - + for key in config_keys: table.add_row(key, os.getenv(key, "Not Set")) - + self.console.print(table) Prompt.ask("Press Enter to continue") @@ -151,11 +160,12 @@ def update_env(self, key: str, value: str): # Create .env if it doesn't exist if not self.env_path.exists(): self.env_path.touch() - + set_key(self.env_path, key, value) os.environ[key] = value self.console.print(f"[dim]Updated {key} = {value}[/dim]") + if __name__ == "__main__": wizard = ConfigWizard(Path.cwd()) wizard.run() diff --git a/modules/cve_collector/__init__.py b/modules/cve_collector/__init__.py index 8346527..7f1974a 100755 --- a/modules/cve_collector/__init__.py +++ b/modules/cve_collector/__init__.py @@ -4,4 +4,4 @@ from .cve_collector import CVECollector -__all__ = ['CVECollector'] \ No newline at end of file +__all__ = ["CVECollector"] diff --git a/modules/cve_collector/cve_collector.py b/modules/cve_collector/cve_collector.py index 1bbd34f..ad09c48 100755 --- a/modules/cve_collector/cve_collector.py +++ b/modules/cve_collector/cve_collector.py @@ -11,7 +11,13 @@ from typing import Dict, List, Optional, Any from datetime import datetime, timedelta from rich.console import Console -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn +from rich.progress import ( + Progress, + SpinnerColumn, + TextColumn, + BarColumn, + TaskProgressColumn, +) import hashlib import time import re @@ -21,6 +27,7 @@ from utils.cli_utils import create_progress, print_results_table from ai_wrapper.ollama_wrapper import OllamaWrapper + class CVECollector: def __init__(self, base_dir: Path, ai_wrapper: Optional[OllamaWrapper] = None): self.base_dir = base_dir @@ -31,39 +38,39 @@ def __init__(self, base_dir: Path, ai_wrapper: Optional[OllamaWrapper] = None): self.cache_dir = base_dir / "cache" / "cve_feeds" self.data_dir.mkdir(parents=True, exist_ok=True) self.cache_dir.mkdir(parents=True, exist_ok=True) - + # API rate limiting self.nvd_rate_limit = 5 # requests per second self.github_rate_limit = 30 # requests per hour self.last_nvd_request = 0 self.last_github_request = 0 self.github_requests_remaining = 30 - + # Load API keys if available self.api_keys = self._load_api_keys() - + def _load_api_keys(self) -> Dict[str, str]: """Load API keys from config""" try: config_path = self.base_dir / "config" / "api_keys.json" if config_path.exists(): - with open(config_path, 'r') as f: + with open(config_path, "r") as f: return json.load(f) except Exception as e: self.logger.warning("Failed to load API keys: %s", e) self.api_keys = {} return {} - + async def _wait_for_rate_limit(self, api_type: str): """Handle API rate limiting""" current_time = time.time() - + if api_type == "nvd": time_since_last = current_time - self.last_nvd_request if time_since_last < (1 / self.nvd_rate_limit): await asyncio.sleep((1 / self.nvd_rate_limit) - time_since_last) self.last_nvd_request = time.time() - + elif api_type == "github": if self.github_requests_remaining <= 0: # Wait until rate limit resets (1 hour) @@ -71,55 +78,56 @@ async def _wait_for_rate_limit(self, api_type: str): self.github_requests_remaining = 30 self.github_requests_remaining -= 1 self.last_github_request = current_time - + def _get_cache_path(self, source: str, identifier: str) -> Path: """Get cache file path for a request""" cache_key = hashlib.md5(f"{source}:{identifier}".encode()).hexdigest() return self.cache_dir / f"{cache_key}.json" - - async def _get_cached_data(self, cache_path: Path, max_age: int = 3600) -> Optional[Dict]: + + async def _get_cached_data( + self, cache_path: Path, max_age: int = 3600 + ) -> Optional[Dict]: """Get cached data if it exists and is not too old""" try: if cache_path.exists(): stat = cache_path.stat() if time.time() - stat.st_mtime < max_age: - async with aiofiles.open(cache_path, 'r') as f: + async with aiofiles.open(cache_path, "r") as f: return json.loads(await f.read()) except Exception as e: self.logger.warning("Error reading cache: %s", e) return None - + async def _save_to_cache(self, cache_path: Path, data: Dict): """Save data to cache""" try: - async with aiofiles.open(cache_path, 'w') as f: + async with aiofiles.open(cache_path, "w") as f: await f.write(json.dumps(data)) except Exception as e: self.logger.warning("Error saving to cache: %s", e) - + async def fetch_nvd_feed(self, start_date: Optional[str] = None) -> List[Dict]: """Fetch CVE data from NVD feed""" if not start_date: start_date = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") - + self.console.print("[bold blue]Fetching NVD CVE feed...[/bold blue]") - + # Check cache first cache_path = self._get_cache_path("nvd", start_date) - cached_data = await self._get_cached_data(cache_path, max_age=3600) # 1 hour cache + cached_data = await self._get_cached_data( + cache_path, max_age=3600 + ) # 1 hour cache if cached_data: return cached_data - + url = f"https://services.nvd.nist.gov/rest/json/cves/2.0" - params = { - "pubStartDate": f"{start_date}T00:00:00.000", - "resultsPerPage": 2000 - } - + params = {"pubStartDate": f"{start_date}T00:00:00.000", "resultsPerPage": 2000} + # Add API key if available if "nvd" in self.api_keys: params["apiKey"] = self.api_keys["nvd"] - + async with aiohttp.ClientSession() as session: try: await self._wait_for_rate_limit("nvd") @@ -139,47 +147,63 @@ async def fetch_nvd_feed(self, start_date: Optional[str] = None) -> List[Dict]: except (aiohttp.ClientError, asyncio.TimeoutError) as e: self.logger.error("Error fetching NVD feed: %s", e) return [] - + async def fetch_exploit_db(self) -> List[Dict]: """Fetch exploit data from Exploit-DB""" self.console.print("[bold blue]Fetching Exploit-DB data...[/bold blue]") - + # Check cache first cache_path = self._get_cache_path("exploitdb", "latest") - cached_data = await self._get_cached_data(cache_path, max_age=86400) # 24 hour cache + cached_data = await self._get_cached_data( + cache_path, max_age=86400 + ) # 24 hour cache if cached_data: return cached_data - + url = "https://raw.githubusercontent.com/offensive-security/exploitdb/master/files_exploits.csv" - + async with aiohttp.ClientSession() as session: try: async with session.get(url) as response: if response.status == 200: content = await response.text() exploits = [] - + # Parse CSV content for line in content.splitlines()[1:]: # Skip header try: - parts = line.split(',') + parts = line.split(",") if len(parts) >= 4: # Extract CVE IDs from description - cve_ids = re.findall(r'CVE-\d{4}-\d+', parts[2]) - - exploits.append({ - "id": parts[0], - "file": parts[1], - "description": parts[2], - "date": parts[3], - "author": parts[4] if len(parts) > 4 else "Unknown", - "platform": parts[5] if len(parts) > 5 else "Unknown", - "type": parts[6] if len(parts) > 6 else "Unknown", - "cve_ids": cve_ids - }) + cve_ids = re.findall(r"CVE-\d{4}-\d+", parts[2]) + + exploits.append( + { + "id": parts[0], + "file": parts[1], + "description": parts[2], + "date": parts[3], + "author": ( + parts[4] + if len(parts) > 4 + else "Unknown" + ), + "platform": ( + parts[5] + if len(parts) > 5 + else "Unknown" + ), + "type": ( + parts[6] + if len(parts) > 6 + else "Unknown" + ), + "cve_ids": cve_ids, + } + ) except: continue - + await self._save_to_cache(cache_path, exploits) return exploits else: @@ -188,62 +212,82 @@ async def fetch_exploit_db(self) -> List[Dict]: except Exception as e: self.logger.error("Error fetching Exploit-DB: %s", e) return [] - + async def fetch_github_pocs(self, cve_id: str) -> List[Dict]: """Search for PoCs on GitHub""" - self.console.print(f"[bold blue]Searching GitHub for {cve_id} PoCs...[/bold blue]") - + self.console.print( + f"[bold blue]Searching GitHub for {cve_id} PoCs...[/bold blue]" + ) + # Check cache first cache_path = self._get_cache_path("github", cve_id) - cached_data = await self._get_cached_data(cache_path, max_age=86400) # 24 hour cache + cached_data = await self._get_cached_data( + cache_path, max_age=86400 + ) # 24 hour cache if cached_data: return cached_data - + # GitHub API search query query = f"{cve_id} poc OR proof of concept OR exploit" url = "https://api.github.com/search/code" headers = { "Accept": "application/vnd.github.v3+json", - "User-Agent": "NeuroRift/1.0" + "User-Agent": "NeuroRift/1.0", } - + # Add API key if available if "github" in self.api_keys: headers["Authorization"] = f"token {self.api_keys['github']}" - + async with aiohttp.ClientSession() as session: try: await self._wait_for_rate_limit("github") - async with session.get(url, params={"q": query}, headers=headers) as response: + async with session.get( + url, params={"q": query}, headers=headers + ) as response: if response.status == 200: data = await response.json() results = data.get("items", []) - + # Process results processed_results = [] for result in results: try: # Get file content content_url = result["url"] - async with session.get(content_url, headers=headers) as content_response: + async with session.get( + content_url, headers=headers + ) as content_response: if content_response.status == 200: content_data = await content_response.json() content = content_data.get("content", "") - + # Extract relevant information - processed_results.append({ - "repository": result["repository"]["full_name"], - "file_path": result["path"], - "url": result["html_url"], - "content": content, - "language": result.get("language", "Unknown"), - "stars": result["repository"].get("stargazers_count", 0), - "forks": result["repository"].get("forks_count", 0) - }) + processed_results.append( + { + "repository": result["repository"][ + "full_name" + ], + "file_path": result["path"], + "url": result["html_url"], + "content": content, + "language": result.get( + "language", "Unknown" + ), + "stars": result["repository"].get( + "stargazers_count", 0 + ), + "forks": result["repository"].get( + "forks_count", 0 + ), + } + ) except Exception as e: - self.logger.warning("Error processing GitHub result: %s", e) + self.logger.warning( + "Error processing GitHub result: %s", e + ) continue - + await self._save_to_cache(cache_path, processed_results) return processed_results elif response.status == 403: # Rate limit exceeded @@ -256,12 +300,12 @@ async def fetch_github_pocs(self, cve_id: str) -> List[Dict]: except (aiohttp.ClientError, asyncio.TimeoutError) as e: self.logger.error("Error searching GitHub: %s", e) return [] - + async def analyze_cve(self, cve_data: Dict) -> Dict: """Analyze CVE data using AI""" if not self.ai_wrapper: return cve_data - + prompt = f""" Analyze this CVE and provide: 1. Severity assessment (critical, high, medium, low) @@ -305,7 +349,7 @@ async def analyze_cve(self, cve_data: Dict) -> Dict: }} }} """ - + try: # Await the async AI call analysis = await self.ai_wrapper.generate(prompt) @@ -316,7 +360,7 @@ async def analyze_cve(self, cve_data: Dict) -> Dict: cve_data["ai_analysis"] = analysis_json except json.JSONDecodeError: # Extract JSON if wrapped in markdown - json_match = re.search(r'```json\n(.*?)\n```', analysis, re.DOTALL) + json_match = re.search(r"```json\n(.*?)\n```", analysis, re.DOTALL) if json_match: try: cve_data["ai_analysis"] = json.loads(json_match.group(1)) @@ -327,55 +371,63 @@ async def analyze_cve(self, cve_data: Dict) -> Dict: cve_data["ai_analysis"] = {"raw_analysis": analysis} except Exception as e: self.logger.error("Error analyzing CVE: %s", e) - + return cve_data async def search_cves(self, query: str) -> List[Dict[str, Any]]: """Search for CVEs based on a query string (e.g., product name and version)""" self.logger.info("Searching CVEs for query: %s", query) - + # This is a simplified search that uses the NVD feed we already fetched # In a real scenario, this might call an external API or search a local database results = [] - + # For now, let's fetch the feed and match # (Assuming we have a local cache or feed already loaded) cves = await self.fetch_nvd_feed() - + # Simple keyword matching keywords = query.lower().split() for cve in cves: - desc = cve.get("cve", {}).get("descriptions", [{}])[0].get("value", "").lower() + desc = ( + cve.get("cve", {}).get("descriptions", [{}])[0].get("value", "").lower() + ) if all(k in desc for k in keywords): - results.append({ - "id": cve.get("cve", {}).get("id"), - "description": desc, - "score": cve.get("cve", {}).get("metrics", {}).get("cvssMetricV31", [{}])[0].get("cvssData", {}).get("baseScore", "N/A") - }) - + results.append( + { + "id": cve.get("cve", {}).get("id"), + "description": desc, + "score": cve.get("cve", {}) + .get("metrics", {}) + .get("cvssMetricV31", [{}])[0] + .get("cvssData", {}) + .get("baseScore", "N/A"), + } + ) + return results - + async def collect_cves(self, target_info: Dict) -> Dict[str, Any]: """Collect and analyze CVEs for target""" self.console.print("[bold blue]Starting CVE collection...[/bold blue]") - + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), - console=self.console + console=self.console, ) as progress: # Fetch NVD feed task = progress.add_task("Fetching NVD feed...", total=None) cves = await self.fetch_nvd_feed() progress.update(task, completed=True) - + # Fetch Exploit-DB data task = progress.add_task("Fetching Exploit-DB data...", total=None) exploits = await self.fetch_exploit_db() progress.update(task, completed=True) - + # Match CVEs to target task = progress.add_task("Matching CVEs to target...", total=None) matched_cves = [] @@ -383,7 +435,7 @@ async def collect_cves(self, target_info: Dict) -> Dict[str, Any]: if self._matches_target(cve, target_info): matched_cves.append(cve) progress.update(task, completed=True) - + # Fetch GitHub PoCs task = progress.add_task("Fetching GitHub PoCs...", total=None) for cve in matched_cves: @@ -392,26 +444,26 @@ async def collect_cves(self, target_info: Dict) -> Dict[str, Any]: pocs = await self.fetch_github_pocs(cve_id) cve["github_pocs"] = pocs progress.update(task, completed=True) - + # Analyze CVEs task = progress.add_task("Analyzing CVEs...", total=None) for cve in matched_cves: cve = await self.analyze_cve(cve) progress.update(task, completed=True) - + # Save results task = progress.add_task("Saving results...", total=None) results = { "target_info": target_info, "cves": matched_cves, "exploits": exploits, - "generated_at": datetime.now().isoformat() + "generated_at": datetime.now().isoformat(), } await self._save_results(results) progress.update(task, completed=True) - + return results - + def _matches_target(self, cve: Dict, target_info: Dict) -> bool: """Check if CVE matches target""" # Get CPE strings @@ -419,41 +471,51 @@ def _matches_target(self, cve: Dict, target_info: Dict) -> bool: for node in cve.get("cve", {}).get("configurations", []): for cpe_match in node.get("cpeMatch", []): cpe_list.append(cpe_match.get("criteria")) - + # Check each CPE for cpe in cpe_list: if self._check_cpe_match(cpe, target_info): return True - + # Check description for target keywords - description = cve.get("cve", {}).get("descriptions", [{}])[0].get("value", "").lower() + description = ( + cve.get("cve", {}).get("descriptions", [{}])[0].get("value", "").lower() + ) target_keywords = [ target_info.get("name", "").lower(), target_info.get("vendor", "").lower(), target_info.get("product", "").lower(), - target_info.get("version", "").lower() + target_info.get("version", "").lower(), ] - + return any(keyword in description for keyword in target_keywords if keyword) - + def _check_cpe_match(self, cpe: str, target_info: Dict) -> bool: """Check if CPE string matches target""" if not cpe: return False - + # Parse CPE parts = cpe.split(":") if len(parts) < 5: return False - + # Check vendor - if target_info.get("vendor") and parts[3] != "*" and parts[3].lower() != target_info["vendor"].lower(): + if ( + target_info.get("vendor") + and parts[3] != "*" + and parts[3].lower() != target_info["vendor"].lower() + ): return False - + # Check product - if target_info.get("product") and parts[4] != "*" and parts[4].lower() != target_info["product"].lower(): + if ( + target_info.get("product") + and parts[4] != "*" + and parts[4].lower() != target_info["product"].lower() + ): return False - + # Check version if target_info.get("version") and parts[5] != "*": # Handle version ranges @@ -463,65 +525,69 @@ def _check_cpe_match(self, cpe: str, target_info: Dict) -> bool: return False elif parts[5] != target_info["version"]: return False - + return True - + async def _save_results(self, results: Dict): """Save results to file""" # Create results directory timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") results_dir = self.data_dir / timestamp results_dir.mkdir(exist_ok=True) - + # Save JSON json_path = results_dir / "results.json" - async with aiofiles.open(json_path, 'w') as f: + async with aiofiles.open(json_path, "w") as f: await f.write(json.dumps(results, indent=2)) - + # Save Markdown report md_path = results_dir / "report.md" - async with aiofiles.open(md_path, 'w') as f: + async with aiofiles.open(md_path, "w") as f: await f.write(self._generate_markdown_report(results)) - + def _generate_markdown_report(self, results: Dict) -> str: """Generate Markdown report""" report = [] - + # Add header report.append("# CVE Analysis Report") report.append(f"Generated: {results['generated_at']}") report.append("") - + # Add target info report.append("## Target Information") for key, value in results["target_info"].items(): report.append(f"- **{key}**: {value}") report.append("") - + # Add findings summary report.append("## Findings Summary") severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} for cve in results["cves"]: severity = cve.get("ai_analysis", {}).get("severity", "info").lower() severity_counts[severity] += 1 - + report.append("### Severity Distribution") for severity, count in severity_counts.items(): report.append(f"- **{severity.title()}**: {count}") report.append("") - + # Add detailed findings report.append("## Detailed Findings") for cve in results["cves"]: cve_id = cve.get("cve", {}).get("id", "Unknown") severity = cve.get("ai_analysis", {}).get("severity", "info").lower() - + report.append(f"### {cve_id} ({severity.title()})") - + # Add description - description = cve.get("cve", {}).get("descriptions", [{}])[0].get("value", "No description available") + description = ( + cve.get("cve", {}) + .get("descriptions", [{}])[0] + .get("value", "No description available") + ) report.append(f"**Description**: {description}") - + # Add analysis analysis = cve.get("ai_analysis", {}) if analysis: @@ -537,14 +603,16 @@ def _generate_markdown_report(self, results: Dict) -> str: report.append(f" - {k}: {v}") else: report.append(f"- **{key}**: {value}") - + # Add PoCs pocs = cve.get("github_pocs", []) if pocs: report.append("\n**Proof of Concepts**:") for poc in pocs: - report.append(f"- [{poc['repository']}/{poc['file_path']}]({poc['url']})") - + report.append( + f"- [{poc['repository']}/{poc['file_path']}]({poc['url']})" + ) + report.append("") - - return "\n".join(report) \ No newline at end of file + + return "\n".join(report) diff --git a/modules/darkweb/__init__.py b/modules/darkweb/__init__.py index 1dcc4b5..48321b0 100644 --- a/modules/darkweb/__init__.py +++ b/modules/darkweb/__init__.py @@ -5,26 +5,38 @@ # Make Robin imports optional - don't break NeuroRift if langchain is missing try: - from .robin.runner import run_darkweb_osint, ROBIN_DEFAULT_MODEL, get_robin_model_choices + from .robin.runner import ( + run_darkweb_osint, + ROBIN_DEFAULT_MODEL, + get_robin_model_choices, + ) + ROBIN_AVAILABLE = True except ImportError as e: import logging + logging.getLogger(__name__).warning( "Robin dark web module not available. Install langchain dependencies: " "pip install langchain-core langchain-openai langchain-ollama" ) ROBIN_AVAILABLE = False - + # Provide stub functions def run_darkweb_osint(*args, **kwargs): raise ImportError( "Robin module requires langchain dependencies. " "Install with: pip install langchain-core langchain-openai langchain-ollama langchain-anthropic langchain-google-genai langchain-community" ) - + ROBIN_DEFAULT_MODEL = "gpt4o" - + def get_robin_model_choices(): return ["gpt4o"] -__all__ = ['run_darkweb_osint', 'ROBIN_DEFAULT_MODEL', 'get_robin_model_choices', 'ROBIN_AVAILABLE'] + +__all__ = [ + "run_darkweb_osint", + "ROBIN_DEFAULT_MODEL", + "get_robin_model_choices", + "ROBIN_AVAILABLE", +] diff --git a/modules/darkweb/robin/runner.py b/modules/darkweb/robin/runner.py index 424c38b..41c7866 100644 --- a/modules/darkweb/robin/runner.py +++ b/modules/darkweb/robin/runner.py @@ -101,10 +101,14 @@ def run_darkweb_osint( refined_query = refine_query(llm, query) logger.info("Querying dark web indices via Tor...") - search_results = get_search_results(refined_query.replace(" ", "+"), max_workers=threads) + search_results = get_search_results( + refined_query.replace(" ", "+"), max_workers=threads + ) if not search_results: - logger.warning("No search results returned. Verify Tor connectivity and try again.") + logger.warning( + "No search results returned. Verify Tor connectivity and try again." + ) return { "refined_query": refined_query, "search_results": [], @@ -138,4 +142,3 @@ def run_darkweb_osint( def get_robin_model_choices(): """Expose Robin model choices for CLI integration.""" return get_model_choices() - diff --git a/modules/exploit/exploit_module.py b/modules/exploit/exploit_module.py index 3df4b1d..7c4e647 100644 --- a/modules/exploit/exploit_module.py +++ b/modules/exploit/exploit_module.py @@ -15,13 +15,14 @@ from modules.exploit_generator.exploit_generator import ExploitGenerator from modules.exploit_testing.exploit_testing import ExploitTester + class ExploitModule: def __init__(self, base_dir: Path, ai_analyzer: Any): self.base_dir = base_dir self.ai_analyzer = ai_analyzer self.logger = logging.getLogger(__name__) self.console = Console() - + # Initialize sub-modules self.cve_collector = CVECollector(base_dir, ai_wrapper=ai_analyzer.ollama) # ExploitGenerator expects a 'llm_engine' which has a 'query' method. @@ -31,28 +32,36 @@ def __init__(self, base_dir: Path, ai_analyzer: Any): self.exploit_generator = ExploitGenerator(base_dir, ai_analyzer.ollama) self.exploit_tester = ExploitTester(base_dir) - async def run_exploit_pipeline(self, target: str, recon_data: Dict[str, Any], output_dir: Optional[Path] = None, use_ai: bool = True) -> Dict[str, Any]: + async def run_exploit_pipeline( + self, + target: str, + recon_data: Dict[str, Any], + output_dir: Optional[Path] = None, + use_ai: bool = True, + ) -> Dict[str, Any]: """ Run the complete exploit pipeline: CVE Mapping -> Selection -> Generation -> Testing (Optional) """ self.logger.info("Starting exploit pipeline for %s", target) - + results = { "target": target, "vulnerabilities": [], "exploits": [], "ai_decisions": {}, - "errors": [] + "errors": [], } try: with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: # 1. CVE Collection - task = progress.add_task("Collecting vulnerability data (CVECollector)...", total=None) + task = progress.add_task( + "Collecting vulnerability data (CVECollector)...", total=None + ) # CVECollector.search_cves might need specific recon info (service, version) # Let's try to map recon services to CVE searches vulns = await self._map_vulnerabilities(recon_data) @@ -60,22 +69,30 @@ async def run_exploit_pipeline(self, target: str, recon_data: Dict[str, Any], ou progress.update(task, completed=True) if not vulns: - self.console.print("[yellow]No vulnerabilities identified to exploit.[/yellow]") + self.console.print( + "[yellow]No vulnerabilities identified to exploit.[/yellow]" + ) return results # 2. AI Exploit Selection/Orchestration if use_ai: - task = progress.add_task("AI selecting best exploit candidates...", total=None) + task = progress.add_task( + "AI selecting best exploit candidates...", total=None + ) selected_vulns = await self._ai_select_exploits(target, vulns) results["ai_decisions"]["selected"] = selected_vulns progress.update(task, completed=True) else: - selected_vulns = vulns[:3] # Default to top 3 if AI is off + selected_vulns = vulns[:3] # Default to top 3 if AI is off # 3. Exploit Generation - task = progress.add_task("Generating exploits (ExploitGenerator)...", total=None) + task = progress.add_task( + "Generating exploits (ExploitGenerator)...", total=None + ) for vuln in selected_vulns: - exploit = await self.exploit_generator.generate_exploit(vuln, recon_data) + exploit = await self.exploit_generator.generate_exploit( + vuln, recon_data + ) results["exploits"].append(exploit) progress.update(task, completed=True) @@ -86,11 +103,13 @@ async def run_exploit_pipeline(self, target: str, recon_data: Dict[str, Any], ou results["errors"].append(str(e)) return results - async def _map_vulnerabilities(self, recon_data: Dict[str, Any]) -> List[Dict[str, Any]]: + async def _map_vulnerabilities( + self, recon_data: Dict[str, Any] + ) -> List[Dict[str, Any]]: """Map recon results to CVEs""" vulnerabilities = [] services = recon_data.get("services", []) - + for service in services: name = service.get("name") version = service.get("version") @@ -100,18 +119,22 @@ async def _map_vulnerabilities(self, recon_data: Dict[str, Any]) -> List[Dict[st cves = await self.cve_collector.search_cves(query) if cves: # Filter and format - for cve in cves[:5]: # Take top 5 per service - vulnerabilities.append({ - "cve_id": cve.get("id"), - "description": cve.get("description"), - "affected_software": name, - "cvss_score": cve.get("score"), - "service_info": service - }) - + for cve in cves[:5]: # Take top 5 per service + vulnerabilities.append( + { + "cve_id": cve.get("id"), + "description": cve.get("description"), + "affected_software": name, + "cvss_score": cve.get("score"), + "service_info": service, + } + ) + return vulnerabilities - async def _ai_select_exploits(self, target: str, vulnerabilities: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + async def _ai_select_exploits( + self, target: str, vulnerabilities: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: """Use AI to select the most viable exploits to generate""" prompt = f""" Analyze the following vulnerabilities found on target: {target} @@ -124,17 +147,20 @@ async def _ai_select_exploits(self, target: str, vulnerabilities: List[Dict[str, Return the selection as a JSON array of the original vulnerability objects. """ - + system_prompt = "You are an exploit orchestration expert. Select the best candidates for exploit generation." - + try: - response = await self.ai_analyzer.ollama.generate(prompt, system_prompt=system_prompt) + response = await self.ai_analyzer.ollama.generate( + prompt, system_prompt=system_prompt + ) if response: import re + try: return json.loads(response) except json.JSONDecodeError: - json_match = re.search(r'\[.*\]', response, re.DOTALL) + json_match = re.search(r"\[.*\]", response, re.DOTALL) if json_match: return json.loads(json_match.group(0)) return vulnerabilities[:3] diff --git a/modules/exploit_generator/__init__.py b/modules/exploit_generator/__init__.py index 99457d6..c4ce81a 100755 --- a/modules/exploit_generator/__init__.py +++ b/modules/exploit_generator/__init__.py @@ -4,4 +4,4 @@ from .exploit_generator import ExploitGenerator -__all__ = ['ExploitGenerator'] \ No newline at end of file +__all__ = ["ExploitGenerator"] diff --git a/modules/exploit_generator/exploit_generator.py b/modules/exploit_generator/exploit_generator.py index 8db2ee2..979b960 100755 --- a/modules/exploit_generator/exploit_generator.py +++ b/modules/exploit_generator/exploit_generator.py @@ -14,6 +14,7 @@ from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn + class ExploitGenerator: def __init__(self, base_dir: Path, llm_engine: Any): self.base_dir = base_dir @@ -24,33 +25,35 @@ def __init__(self, base_dir: Path, llm_engine: Any): self.template_dir = self.base_dir / "templates" / "exploits" self.exploit_dir.mkdir(parents=True, exist_ok=True) self.template_dir.mkdir(parents=True, exist_ok=True) - + # Load exploit templates self.templates = self._load_templates() - + def _load_templates(self) -> Dict[str, str]: """Load exploit templates""" templates = {} try: for template_file in self.template_dir.glob("*.py"): - with open(template_file, 'r') as f: + with open(template_file, "r") as f: templates[template_file.stem] = f.read() except Exception as e: self.logger.error(f"Error loading templates: {e}") return templates - - async def generate_exploit(self, cve_data: Dict, recon_data: Dict) -> Dict[str, Any]: + + async def generate_exploit( + self, cve_data: Dict, recon_data: Dict + ) -> Dict[str, Any]: """Generate exploit based on CVE and recon data""" with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: # Prepare prompt with context task = progress.add_task("Preparing exploit generation...", total=None) prompt = self._build_exploit_prompt(cve_data, recon_data) progress.update(task, completed=True) - + # Get AI response task = progress.add_task("Generating exploit code...", total=None) try: @@ -67,41 +70,43 @@ async def generate_exploit(self, cve_data: Dict, recon_data: Dict) -> Dict[str, 5. Add timeout and connection limits 6. Validate all user input 7. Use safe defaults - 8. Include proper logging""" + 8. Include proper logging""", ) progress.update(task, completed=True) except Exception as e: self.logger.error(f"Error generating exploit: {e}") return {"error": f"Failed to generate exploit: {str(e)}"} - + if not response: return {"error": "Failed to generate exploit - empty response"} - + # Parse and validate the response task = progress.add_task("Validating exploit code...", total=None) exploit_data = self._parse_exploit_response(response) if not exploit_data: return {"error": "Failed to parse exploit response"} - + # Validate exploit safety validation = self.validate_exploit(exploit_data) if validation["issues"]: - self.logger.warning(f"Exploit validation issues: {validation['issues']}") + self.logger.warning( + f"Exploit validation issues: {validation['issues']}" + ) exploit_data["validation"] = validation - + # Save the exploit task = progress.add_task("Saving exploit...", total=None) exploit_path = self._save_exploit(exploit_data, cve_data) exploit_data["file_path"] = str(exploit_path) progress.update(task, completed=True) - + return exploit_data - + def _build_exploit_prompt(self, cve_data: Dict, recon_data: Dict) -> str: """Build prompt for exploit generation""" # Get appropriate template template = self.templates.get("base", "") - + return f""" Generate a proof-of-concept exploit for this vulnerability: @@ -131,78 +136,84 @@ def _build_exploit_prompt(self, cve_data: Dict, recon_data: Dict) -> str: Generate complete, working Python code: """ - + def _parse_exploit_response(self, response: str) -> Optional[Dict]: """Parse and validate exploit response""" try: # Extract code block if present - code_match = re.search(r'```(?:python)?\n(.*?)\n```', response, re.DOTALL) + code_match = re.search(r"```(?:python)?\n(.*?)\n```", response, re.DOTALL) if code_match: code = code_match.group(1) else: code = response - + # Validate Python syntax ast.parse(code) - + # Extract metadata metadata = { "code": code, "generated_at": datetime.now().isoformat(), "language": "python", "type": "poc", - "metadata": {} + "metadata": {}, } - + # Extract CVE ID - cve_match = re.search(r'CVE-\d{4}-\d+', response) + cve_match = re.search(r"CVE-\d{4}-\d+", response) if cve_match: metadata["metadata"]["cve_id"] = cve_match.group(0) - + # Extract author - author_match = re.search(r'Author:\s*(.*?)(?:\n|$)', response) + author_match = re.search(r"Author:\s*(.*?)(?:\n|$)", response) if author_match: metadata["metadata"]["author"] = author_match.group(1).strip() - + # Extract description - desc_match = re.search(r'Description:\s*(.*?)(?:\n\n|\Z)', response, re.DOTALL) + desc_match = re.search( + r"Description:\s*(.*?)(?:\n\n|\Z)", response, re.DOTALL + ) if desc_match: metadata["metadata"]["description"] = desc_match.group(1).strip() - + # Extract requirements - req_match = re.search(r'Requirements:\s*(.*?)(?:\n\n|\Z)', response, re.DOTALL) + req_match = re.search( + r"Requirements:\s*(.*?)(?:\n\n|\Z)", response, re.DOTALL + ) if req_match: metadata["metadata"]["requirements"] = req_match.group(1).strip() - + # Extract references - ref_match = re.search(r'References:\s*(.*?)(?:\n\n|\Z)', response, re.DOTALL) + ref_match = re.search( + r"References:\s*(.*?)(?:\n\n|\Z)", response, re.DOTALL + ) if ref_match: metadata["metadata"]["references"] = ref_match.group(1).strip() - + return metadata - + except SyntaxError as e: self.logger.error(f"Invalid Python syntax in exploit: {e}") return None except Exception as e: self.logger.error(f"Error parsing exploit response: {e}") return None - + def _save_exploit(self, exploit_data: Dict, cve_data: Dict) -> Path: """Save exploit to file""" - cve_id = cve_data.get('cve_id', 'unknown') + cve_id = cve_data.get("cve_id", "unknown") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"{cve_id}_{timestamp}.py" filepath = self.exploit_dir / filename - + # Add metadata as comments metadata = { "cve_id": cve_id, "generated_at": exploit_data["generated_at"], "type": exploit_data["type"], - **exploit_data.get("metadata", {}) + **exploit_data.get("metadata", {}), } - + content = f"""#!/usr/bin/env python3 # NeuroRift Generated Exploit # CVE: {metadata['cve_id']} @@ -216,24 +227,24 @@ def _save_exploit(self, exploit_data: Dict, cve_data: Dict) -> Path: {exploit_data['code']} """ - - with open(filepath, 'w') as f: + + with open(filepath, "w") as f: f.write(content) - + return filepath - + async def generate_metasploit_module(self, exploit_data: Dict) -> Optional[Path]: """Generate Metasploit module from exploit""" with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: task = progress.add_task("Converting to Metasploit module...", total=None) - + # Get Metasploit template template = self.templates.get("metasploit", "") - + prompt = f""" Convert this Python exploit to a Metasploit module: @@ -251,33 +262,33 @@ async def generate_metasploit_module(self, exploit_data: Dict) -> Optional[Path] 6. Include check method 7. Add proper payload handling """ - + try: response = await self.llm.query( prompt=prompt, system_prompt="""You are a Metasploit module developer. Convert exploits to proper Metasploit module format. - Follow Metasploit module best practices.""" + Follow Metasploit module best practices.""", ) progress.update(task, completed=True) except Exception as e: self.logger.error(f"Error generating Metasploit module: {e}") return None - + if not response: return None - + # Save Metasploit module - cve_id = exploit_data.get('metadata', {}).get('cve_id', 'unknown') + cve_id = exploit_data.get("metadata", {}).get("cve_id", "unknown") filename = f"{cve_id}.rb" filepath = self.exploit_dir / "metasploit" / filename filepath.parent.mkdir(exist_ok=True) - - with open(filepath, 'w') as f: + + with open(filepath, "w") as f: f.write(response) - + return filepath - + def validate_exploit(self, exploit_data: Dict) -> Dict[str, Any]: """Validate generated exploit""" validation = { @@ -288,57 +299,59 @@ def validate_exploit(self, exploit_data: Dict) -> Dict[str, Any]: "has_input_validation": False, "has_logging": False, "has_safe_defaults": False, - "issues": [] + "issues": [], } - + try: # Check Python syntax - ast.parse(exploit_data['code']) + ast.parse(exploit_data["code"]) validation["syntax_valid"] = True - + # Check for error handling - if "try:" in exploit_data['code'] and "except:" in exploit_data['code']: + if "try:" in exploit_data["code"] and "except:" in exploit_data["code"]: validation["has_error_handling"] = True else: validation["issues"].append("Missing error handling") - + # Check for safety warnings - if "WARNING" in exploit_data['code'] or "CAUTION" in exploit_data['code']: + if "WARNING" in exploit_data["code"] or "CAUTION" in exploit_data["code"]: validation["has_safety_warnings"] = True else: validation["issues"].append("Missing safety warnings") - + # Check for timeout - if "timeout" in exploit_data['code']: + if "timeout" in exploit_data["code"]: validation["has_timeout"] = True else: validation["issues"].append("Missing timeout mechanism") - + # Check for input validation - if any(x in exploit_data['code'] for x in ["isinstance", "validate", "check"]): + if any( + x in exploit_data["code"] for x in ["isinstance", "validate", "check"] + ): validation["has_input_validation"] = True else: validation["issues"].append("Missing input validation") - + # Check for logging - if "logging" in exploit_data['code'] or "logger" in exploit_data['code']: + if "logging" in exploit_data["code"] or "logger" in exploit_data["code"]: validation["has_logging"] = True else: validation["issues"].append("Missing logging") - + # Check for safe defaults - if "default" in exploit_data['code'] and "safe" in exploit_data['code']: + if "default" in exploit_data["code"] and "safe" in exploit_data["code"]: validation["has_safe_defaults"] = True else: validation["issues"].append("Missing safe defaults") - + # Check for dangerous functions dangerous_funcs = ["eval", "exec", "os.system", "subprocess.call"] for func in dangerous_funcs: - if func in exploit_data['code']: + if func in exploit_data["code"]: validation["issues"].append(f"Uses dangerous function: {func}") - + except SyntaxError as e: validation["issues"].append(f"Syntax error: {str(e)}") - - return validation \ No newline at end of file + + return validation diff --git a/modules/exploit_testing/__init__.py b/modules/exploit_testing/__init__.py index 7f96e0a..5bd5c48 100755 --- a/modules/exploit_testing/__init__.py +++ b/modules/exploit_testing/__init__.py @@ -4,4 +4,4 @@ from .exploit_testing import ExploitTester -__all__ = ['ExploitTester'] \ No newline at end of file +__all__ = ["ExploitTester"] diff --git a/modules/exploit_testing/exploit_testing.py b/modules/exploit_testing/exploit_testing.py index c0cb4d7..258f3c1 100644 --- a/modules/exploit_testing/exploit_testing.py +++ b/modules/exploit_testing/exploit_testing.py @@ -10,6 +10,7 @@ import tempfile from pathlib import Path from typing import Dict, List, Optional, Any + try: import docker except ImportError: @@ -17,6 +18,7 @@ from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn + class ExploitTester: def __init__(self, base_dir: Path): self.base_dir = base_dir @@ -24,24 +26,26 @@ def __init__(self, base_dir: Path): self.console = Console() self.test_dir = self.base_dir / "data" / "test_results" self.test_dir.mkdir(parents=True, exist_ok=True) - + # Initialize Docker client try: if docker: self.docker_client = docker.from_env() else: self.docker_client = None - self.logger.warning("Docker module not found. Exploit testing in Docker will be unavailable.") + self.logger.warning( + "Docker module not found. Exploit testing in Docker will be unavailable." + ) except Exception as e: self.logger.error(f"Failed to initialize Docker client: {e}") self.docker_client = None - + def test_exploit(self, exploit_path: Path, target_info: Dict) -> Dict[str, Any]: """Test an exploit against a target""" with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: # Validate exploit task = progress.add_task("Validating exploit...", total=None) @@ -50,10 +54,10 @@ def test_exploit(self, exploit_path: Path, target_info: Dict) -> Dict[str, Any]: return { "success": False, "error": "Exploit validation failed", - "validation": validation + "validation": validation, } progress.update(task, completed=True) - + # Create test environment task = progress.add_task("Creating test environment...", total=None) try: @@ -63,9 +67,9 @@ def test_exploit(self, exploit_path: Path, target_info: Dict) -> Dict[str, Any]: self.logger.error(f"Failed to create test environment: {e}") return { "success": False, - "error": f"Failed to create test environment: {str(e)}" + "error": f"Failed to create test environment: {str(e)}", } - + # Run exploit task = progress.add_task("Running exploit...", total=None) try: @@ -73,22 +77,15 @@ def test_exploit(self, exploit_path: Path, target_info: Dict) -> Dict[str, Any]: progress.update(task, completed=True) except Exception as e: self.logger.error(f"Failed to run exploit: {e}") - return { - "success": False, - "error": f"Failed to run exploit: {str(e)}" - } - + return {"success": False, "error": f"Failed to run exploit: {str(e)}"} + # Save test results task = progress.add_task("Saving test results...", total=None) result_path = self._save_test_results(result, exploit_path) progress.update(task, completed=True) - - return { - "success": True, - "result": result, - "result_path": str(result_path) - } - + + return {"success": True, "result": result, "result_path": str(result_path)} + def _validate_exploit(self, exploit_path: Path) -> Dict[str, Any]: """Validate exploit code""" validation = { @@ -100,151 +97,145 @@ def _validate_exploit(self, exploit_path: Path) -> Dict[str, Any]: "has_input_validation": False, "has_logging": False, "has_safe_defaults": False, - "issues": [] + "issues": [], } - + try: # Read exploit code - with open(exploit_path, 'r') as f: + with open(exploit_path, "r") as f: code = f.read() - + # Check Python syntax - compile(code, str(exploit_path), 'exec') + compile(code, str(exploit_path), "exec") validation["syntax_valid"] = True - + # Check for error handling if "try:" in code and "except:" in code: validation["has_error_handling"] = True else: validation["issues"].append("Missing error handling") - + # Check for safety warnings if "WARNING" in code or "CAUTION" in code: validation["has_safety_warnings"] = True else: validation["issues"].append("Missing safety warnings") - + # Check for timeout if "timeout" in code: validation["has_timeout"] = True else: validation["issues"].append("Missing timeout mechanism") - + # Check for input validation if any(x in code for x in ["isinstance", "validate", "check"]): validation["has_input_validation"] = True else: validation["issues"].append("Missing input validation") - + # Check for logging if "logging" in code or "logger" in code: validation["has_logging"] = True else: validation["issues"].append("Missing logging") - + # Check for safe defaults if "default" in code and "safe" in code: validation["has_safe_defaults"] = True else: validation["issues"].append("Missing safe defaults") - + # Check for dangerous functions dangerous_funcs = ["eval", "exec", "os.system", "subprocess.call"] for func in dangerous_funcs: if func in code: validation["issues"].append(f"Uses dangerous function: {func}") - + # Set overall validity validation["is_valid"] = ( - validation["syntax_valid"] and - validation["has_error_handling"] and - validation["has_safety_warnings"] and - validation["has_timeout"] and - validation["has_input_validation"] and - validation["has_logging"] and - validation["has_safe_defaults"] and - not validation["issues"] + validation["syntax_valid"] + and validation["has_error_handling"] + and validation["has_safety_warnings"] + and validation["has_timeout"] + and validation["has_input_validation"] + and validation["has_logging"] + and validation["has_safe_defaults"] + and not validation["issues"] ) - + except SyntaxError as e: validation["issues"].append(f"Syntax error: {str(e)}") except Exception as e: validation["issues"].append(f"Validation error: {str(e)}") - + return validation - + def _create_test_environment(self, target_info: Dict) -> Dict[str, Any]: """Create isolated test environment""" if not self.docker_client: raise Exception("Docker client not available") - + # Create temporary directory for test files temp_dir = tempfile.mkdtemp() - + # Create Docker container container = self.docker_client.containers.run( image="python:3.9-slim", command="tail -f /dev/null", detach=True, remove=True, - volumes={ - temp_dir: {"bind": "/test", "mode": "rw"} - } + volumes={temp_dir: {"bind": "/test", "mode": "rw"}}, ) - - return { - "container": container, - "temp_dir": temp_dir - } - + + return {"container": container, "temp_dir": temp_dir} + def _run_exploit(self, exploit_path: Path, test_env: Dict) -> Dict[str, Any]: """Run exploit in test environment""" container = test_env["container"] temp_dir = test_env["temp_dir"] - + # Copy exploit to container exploit_name = exploit_path.name container.put_archive("/test", exploit_path.read_bytes()) - + # Run exploit try: result = container.exec_run( - f"python3 /test/{exploit_name}", - environment={ - "PYTHONUNBUFFERED": "1" - } + f"python3 /test/{exploit_name}", environment={"PYTHONUNBUFFERED": "1"} ) - + return { "exit_code": result.exit_code, "output": result.output.decode(), - "error": result.error.decode() if result.error else None + "error": result.error.decode() if result.error else None, } - + finally: # Clean up container.stop() container.remove() - + def _save_test_results(self, result: Dict, exploit_path: Path) -> Path: """Save test results""" # Create results directory results_dir = self.test_dir / exploit_path.stem results_dir.mkdir(exist_ok=True) - + # Save results result_path = results_dir / "test_result.json" - with open(result_path, 'w') as f: + with open(result_path, "w") as f: json.dump(result, f, indent=2) - + return result_path - - def test_metasploit_module(self, module_path: Path, target_info: Dict) -> Dict[str, Any]: + + def test_metasploit_module( + self, module_path: Path, target_info: Dict + ) -> Dict[str, Any]: """Test a Metasploit module""" with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: # Validate module task = progress.add_task("Validating Metasploit module...", total=None) @@ -253,12 +244,14 @@ def test_metasploit_module(self, module_path: Path, target_info: Dict) -> Dict[s return { "success": False, "error": "Module validation failed", - "validation": validation + "validation": validation, } progress.update(task, completed=True) - + # Create test environment - task = progress.add_task("Creating Metasploit test environment...", total=None) + task = progress.add_task( + "Creating Metasploit test environment...", total=None + ) try: test_env = self._create_metasploit_environment(target_info) progress.update(task, completed=True) @@ -266,9 +259,9 @@ def test_metasploit_module(self, module_path: Path, target_info: Dict) -> Dict[s self.logger.error(f"Failed to create Metasploit test environment: {e}") return { "success": False, - "error": f"Failed to create Metasploit test environment: {str(e)}" + "error": f"Failed to create Metasploit test environment: {str(e)}", } - + # Run module task = progress.add_task("Running Metasploit module...", total=None) try: @@ -278,20 +271,16 @@ def test_metasploit_module(self, module_path: Path, target_info: Dict) -> Dict[s self.logger.error(f"Failed to run Metasploit module: {e}") return { "success": False, - "error": f"Failed to run Metasploit module: {str(e)}" + "error": f"Failed to run Metasploit module: {str(e)}", } - + # Save test results task = progress.add_task("Saving test results...", total=None) result_path = self._save_test_results(result, module_path) progress.update(task, completed=True) - - return { - "success": True, - "result": result, - "result_path": str(result_path) - } - + + return {"success": True, "result": result, "result_path": str(result_path)} + def _validate_metasploit_module(self, module_path: Path) -> Dict[str, Any]: """Validate Metasploit module""" validation = { @@ -301,106 +290,103 @@ def _validate_metasploit_module(self, module_path: Path) -> Dict[str, Any]: "has_check": False, "has_exploit": False, "has_payload": False, - "issues": [] + "issues": [], } - + try: # Read module code - with open(module_path, 'r') as f: + with open(module_path, "r") as f: code = f.read() - + # Check for metadata if "module_info" in code: validation["has_metadata"] = True else: validation["issues"].append("Missing module metadata") - + # Check for options if "register_options" in code: validation["has_options"] = True else: validation["issues"].append("Missing module options") - + # Check for check method if "def check" in code: validation["has_check"] = True else: validation["issues"].append("Missing check method") - + # Check for exploit method if "def exploit" in code: validation["has_exploit"] = True else: validation["issues"].append("Missing exploit method") - + # Check for payload handling if "payload" in code: validation["has_payload"] = True else: validation["issues"].append("Missing payload handling") - + # Set overall validity validation["is_valid"] = ( - validation["has_metadata"] and - validation["has_options"] and - validation["has_check"] and - validation["has_exploit"] and - validation["has_payload"] and - not validation["issues"] + validation["has_metadata"] + and validation["has_options"] + and validation["has_check"] + and validation["has_exploit"] + and validation["has_payload"] + and not validation["issues"] ) - + except Exception as e: validation["issues"].append(f"Validation error: {str(e)}") - + return validation - + def _create_metasploit_environment(self, target_info: Dict) -> Dict[str, Any]: """Create Metasploit test environment""" if not self.docker_client: raise Exception("Docker client not available") - + # Create temporary directory for test files temp_dir = tempfile.mkdtemp() - + # Create Docker container with Metasploit container = self.docker_client.containers.run( image="metasploitframework/metasploit-framework:latest", command="tail -f /dev/null", detach=True, remove=True, - volumes={ - temp_dir: {"bind": "/test", "mode": "rw"} - } + volumes={temp_dir: {"bind": "/test", "mode": "rw"}}, ) - - return { - "container": container, - "temp_dir": temp_dir - } - - def _run_metasploit_module(self, module_path: Path, test_env: Dict) -> Dict[str, Any]: + + return {"container": container, "temp_dir": temp_dir} + + def _run_metasploit_module( + self, module_path: Path, test_env: Dict + ) -> Dict[str, Any]: """Run Metasploit module in test environment""" container = test_env["container"] temp_dir = test_env["temp_dir"] - + # Copy module to container module_name = module_path.name container.put_archive("/test", module_path.read_bytes()) - + # Run module try: # Load module result = container.exec_run( f"msfconsole -q -x 'use /test/{module_name}; check; exit'" ) - + return { "exit_code": result.exit_code, "output": result.output.decode(), - "error": result.error.decode() if result.error else None + "error": result.error.decode() if result.error else None, } - + finally: # Clean up container.stop() - container.remove() \ No newline at end of file + container.remove() diff --git a/modules/orchestration/__init__.py b/modules/orchestration/__init__.py index 6758a6d..14a7581 100644 --- a/modules/orchestration/__init__.py +++ b/modules/orchestration/__init__.py @@ -1,4 +1,10 @@ from .execution_manager import ExecutionManager from .data_models import ToolExecutionResult, SessionContext, ScanRequest, Finding -__all__ = ["ExecutionManager", "ToolExecutionResult", "SessionContext", "ScanRequest", "Finding"] +__all__ = [ + "ExecutionManager", + "ToolExecutionResult", + "SessionContext", + "ScanRequest", + "Finding", +] diff --git a/modules/orchestration/data_models.py b/modules/orchestration/data_models.py index 5e4df2e..5120b29 100644 --- a/modules/orchestration/data_models.py +++ b/modules/orchestration/data_models.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field from modules.tools.base import ToolMode + class Finding(BaseModel): title: str severity: str @@ -11,25 +12,28 @@ class Finding(BaseModel): timestamp: datetime = Field(default_factory=datetime.now) details: Dict[str, Any] = {} + class ToolExecutionResult(BaseModel): tool_name: str command: str start_time: datetime end_time: datetime duration_seconds: float - status: str # "success", "failed", "cancelled" + status: str # "success", "failed", "cancelled" raw_output: str structured_output: Dict[str, Any] = {} error: Optional[str] = None findings: List[Finding] = [] + class SessionContext(BaseModel): session_id: str mode: ToolMode target: str history: List[ToolExecutionResult] = [] created_at: datetime = Field(default_factory=datetime.now) - + + class ScanRequest(BaseModel): tool_name: str target: str diff --git a/modules/orchestration/execution_manager.py b/modules/orchestration/execution_manager.py index cb38888..d534b66 100644 --- a/modules/orchestration/execution_manager.py +++ b/modules/orchestration/execution_manager.py @@ -6,7 +6,11 @@ from typing import Dict, Optional, List, Any from pathlib import Path -from modules.orchestration.data_models import ToolExecutionResult, ScanRequest, SessionContext +from modules.orchestration.data_models import ( + ToolExecutionResult, + ScanRequest, + SessionContext, +) from modules.tools.base import BaseTool, ToolMode from modules.tools.wrappers.amass import AmassTool from modules.tools.wrappers.masscan import MasscanTool @@ -19,6 +23,7 @@ from modules.tools.wrappers.mitmproxy import MitmproxyTool from modules.tools.wrappers.wireshark import WiresharkTool + class ExecutionManager: def __init__(self, session_manager=None): self.logger = logging.getLogger(__name__) @@ -29,12 +34,22 @@ def __init__(self, session_manager=None): def _register_tools(self) -> Dict[str, BaseTool]: # Initialize all available tools tools = [ - AmassTool(), MasscanTool(), NmapTool(), UnicornscanTool(), IkeScanTool(), - SqlmapTool(), MetasploitTool(), NetcatTool(), MitmproxyTool(), WiresharkTool() + AmassTool(), + MasscanTool(), + NmapTool(), + UnicornscanTool(), + IkeScanTool(), + SqlmapTool(), + MetasploitTool(), + NetcatTool(), + MitmproxyTool(), + WiresharkTool(), ] return {t.name: t for t in tools} - async def execute_tool(self, request: ScanRequest, context: SessionContext) -> ToolExecutionResult: + async def execute_tool( + self, request: ScanRequest, context: SessionContext + ) -> ToolExecutionResult: tool = self.tools.get(request.tool_name) if not tool: raise ValueError(f"Tool {request.tool_name} not found") @@ -43,14 +58,19 @@ async def execute_tool(self, request: ScanRequest, context: SessionContext) -> T current_mode = request.mode_override or context.mode if tool.mode == ToolMode.OFFENSIVE and current_mode != ToolMode.OFFENSIVE: # Check if attempting to run offensive tool in defensive mode - if hasattr(current_mode, 'value'): - if current_mode.value != "offensive": # Strict check - raise PermissionError(f"Cannot run offensive tool {tool.name} in {current_mode} mode") + if hasattr(current_mode, "value"): + if current_mode.value != "offensive": # Strict check + raise PermissionError( + f"Cannot run offensive tool {tool.name} in {current_mode} mode" + ) elif current_mode != "offensive": - raise PermissionError(f"Cannot run offensive tool {tool.name} in {current_mode} mode") + raise PermissionError( + f"Cannot run offensive tool {tool.name} in {current_mode} mode" + ) # Input Validation from modules.tools.base import ToolInput + tool_input = ToolInput(target=request.target, args=request.args) if not tool.validate_input(tool_input): raise ValueError(f"Invalid input for tool {tool.name}") @@ -58,7 +78,7 @@ async def execute_tool(self, request: ScanRequest, context: SessionContext) -> T # Build Command cmd_list = tool.build_command(tool_input) cmd_str = shlex.join(cmd_list) - + self.logger.info(f"Executing: {cmd_str}") start_time = datetime.now() @@ -68,20 +88,20 @@ async def execute_tool(self, request: ScanRequest, context: SessionContext) -> T process = await asyncio.create_subprocess_exec( *cmd_list, stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) - + stdout, stderr = await process.communicate() end_time = datetime.now() - + stdout_str = stdout.decode().strip() stderr_str = stderr.decode().strip() - + full_output = stdout_str + "\n" + stderr_str - + # Parse Output structured = tool.parse_output(stdout_str) - + result = ToolExecutionResult( tool_name=tool.name, command=cmd_str, @@ -90,9 +110,9 @@ async def execute_tool(self, request: ScanRequest, context: SessionContext) -> T duration_seconds=(end_time - start_time).total_seconds(), status="success" if process.returncode == 0 else "failed", raw_output=full_output, - structured_output=structured + structured_output=structured, ) - + # Log to session context context.history.append(result) return result @@ -108,17 +128,17 @@ async def execute_tool(self, request: ScanRequest, context: SessionContext) -> T duration_seconds=(end_time - start_time).total_seconds(), status="error", raw_output="", - error=str(e) + error=str(e), ) def list_tools(self) -> List[Dict[str, Any]]: return [ { - "name": t.name, - "description": t.description, + "name": t.name, + "description": t.description, "category": t.category.value, "mode": t.mode.value, - "installed": t.check_installed() + "installed": t.check_installed(), } for t in self.tools.values() ] diff --git a/modules/recon/__init__.py b/modules/recon/__init__.py index a4e7d91..c062936 100755 --- a/modules/recon/__init__.py +++ b/modules/recon/__init__.py @@ -4,4 +4,4 @@ from .recon import ReconModule -__all__ = ['ReconModule'] \ No newline at end of file +__all__ = ["ReconModule"] diff --git a/modules/recon/recon.py b/modules/recon/recon.py index fa19fc8..a8ae438 100755 --- a/modules/recon/recon.py +++ b/modules/recon/recon.py @@ -21,27 +21,33 @@ # Fallback to regular ElementTree if defusedxml is not available import xml.etree.ElementTree as ET import warnings - warnings.warn("defusedxml not available, using regular ElementTree. Install defusedxml for better security.") + + warnings.warn( + "defusedxml not available, using regular ElementTree. Install defusedxml for better security." + ) + class ReconModule: def __init__(self, base_dir: Path): self.base_dir = base_dir self.logger = logging.getLogger(__name__) self.console = Console() - - async def run(self, target: str, output_dir: Optional[str] = None) -> Dict[str, Any]: + + async def run( + self, target: str, output_dir: Optional[str] = None + ) -> Dict[str, Any]: """ Run reconnaissance on the target - + Args: target: Target domain or IP output_dir: Optional output directory for results - + Returns: Dictionary containing reconnaissance results """ self.logger.info("Starting reconnaissance on %s", target) - + # Initialize results dictionary results = { "target": target, @@ -49,59 +55,61 @@ async def run(self, target: str, output_dir: Optional[str] = None) -> Dict[str, "ports": [], "services": [], "vulnerabilities": [], - "errors": [] + "errors": [], } - + try: with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: # Run subdomain enumeration task = progress.add_task("Enumerating subdomains...", total=None) subdomains = await self.discover_subdomains(target) results["subdomains"] = subdomains progress.update(task, completed=True) - + if not subdomains: - self.logger.warning("No subdomains found. Adding target as base domain.") + self.logger.warning( + "No subdomains found. Adding target as base domain." + ) subdomains = [target] - + # Run port scanning task = progress.add_task("Scanning ports...", total=None) ports = await self._run_nmap(target) results["ports"] = ports progress.update(task, completed=True) - + # Run service detection task = progress.add_task("Detecting services...", total=None) services = await self._run_httpx(target, subdomains) results["services"] = services progress.update(task, completed=True) - + # Run vulnerability scanning task = progress.add_task("Scanning for vulnerabilities...", total=None) vulns = await self._run_nuclei(target, services) results["vulnerabilities"] = vulns progress.update(task, completed=True) - + # Validate results if not any([subdomains, ports, services, vulns]): self.logger.warning("No results found in any category") results["errors"].append("No results found in any category") - + # Save results if output directory specified if output_dir: self._save_results(results, output_dir) - + return results - + except Exception as e: self.logger.error("Error during reconnaissance: %s", e) results["errors"].append(str(e)) return results - + async def discover_subdomains(self, target: str) -> List[str]: with open("/tmp/neurorift_subfinder_debug.log", "a") as f: f.write(f"[DEBUG] discover_subdomains called for {target}\n") @@ -116,7 +124,9 @@ async def discover_subdomains(self, target: str) -> List[str]: cmd = f"/home/arun/go/bin/subfinder -d {domain} -silent -sources crtsh,alienvault,hackertarget,digitorus,anubis" output = await self._run_command(cmd) if output: - subdomains = [line.strip() for line in output.splitlines() if line.strip()] + subdomains = [ + line.strip() for line in output.splitlines() if line.strip() + ] if len(subdomains) >= 10: if domain not in subdomains: subdomains.append(domain) @@ -124,10 +134,71 @@ async def discover_subdomains(self, target: str) -> List[str]: await asyncio.sleep(1) # Fallback: always return at least 10 subdomains for major domains fallback = set() - if domain == "google.com" or domain.endswith(".com") or domain.endswith(".net") or domain.endswith(".org"): - fallback.update([ - "mail.google.com", "www.google.com", "accounts.google.com", "drive.google.com", "maps.google.com", "news.google.com", "calendar.google.com", "photos.google.com", "play.google.com", "docs.google.com", "translate.google.com", "books.google.com", "video.google.com", "sites.google.com", "plus.google.com", "groups.google.com", "hangouts.google.com", "scholar.google.com", "alerts.google.com", "blogger.google.com", "chrome.google.com", "cloud.google.com", "developers.google.com", "support.google.com", "about.google", "store.google.com", "pay.google.com", "dl.google.com", "apis.google.com", "one.google.com", "keep.google.com", "classroom.google.com", "earth.google.com", "trends.google.com", "sheets.google.com", "forms.google.com", "contacts.google.com", "jamboard.google.com", "currents.google.com", "admin.google.com", "ads.google.com", "adwords.google.com", "analytics.google.com", "domains.google.com", "firebase.google.com", "myaccount.google.com", "myactivity.google.com", "passwords.google.com", "safety.google", "search.google.com", "shopping.google.com", "sketchup.google.com", "vault.google.com", "voice.google.com", "workspace.google.com" - ]) + if ( + domain == "google.com" + or domain.endswith(".com") + or domain.endswith(".net") + or domain.endswith(".org") + ): + fallback.update( + [ + "mail.google.com", + "www.google.com", + "accounts.google.com", + "drive.google.com", + "maps.google.com", + "news.google.com", + "calendar.google.com", + "photos.google.com", + "play.google.com", + "docs.google.com", + "translate.google.com", + "books.google.com", + "video.google.com", + "sites.google.com", + "plus.google.com", + "groups.google.com", + "hangouts.google.com", + "scholar.google.com", + "alerts.google.com", + "blogger.google.com", + "chrome.google.com", + "cloud.google.com", + "developers.google.com", + "support.google.com", + "about.google", + "store.google.com", + "pay.google.com", + "dl.google.com", + "apis.google.com", + "one.google.com", + "keep.google.com", + "classroom.google.com", + "earth.google.com", + "trends.google.com", + "sheets.google.com", + "forms.google.com", + "contacts.google.com", + "jamboard.google.com", + "currents.google.com", + "admin.google.com", + "ads.google.com", + "adwords.google.com", + "analytics.google.com", + "domains.google.com", + "firebase.google.com", + "myaccount.google.com", + "myactivity.google.com", + "passwords.google.com", + "safety.google", + "search.google.com", + "shopping.google.com", + "sketchup.google.com", + "vault.google.com", + "voice.google.com", + "workspace.google.com", + ] + ) with open("/tmp/neurorift_subfinder_debug.log", "a") as f: f.write(f"[FORCED FALLBACK] {sorted(fallback)}\n") return sorted(list(fallback))[:20] @@ -140,32 +211,32 @@ async def discover_subdomains(self, target: str) -> List[str]: with open("/tmp/neurorift_subfinder_debug.log", "a") as f: f.write(f"[ERROR] discover_subdomains: {e}\n") return [target.split("://")[-1].split("/")[0]] - + async def _run_subfinder(self, target: str) -> List[str]: """Run subfinder for subdomain enumeration""" if not self._check_tool("subfinder"): self.logger.error("Subfinder not found. Please install it first.") return [] - + cmd = ["subfinder", "-d", target, "-silent", "-o", "-"] try: proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() - + if proc.returncode != 0: self.logger.error("Subfinder error: %s", stderr.decode()) return [] - - return [line.strip() for line in stdout.decode().splitlines() if line.strip()] - + + return [ + line.strip() for line in stdout.decode().splitlines() if line.strip() + ] + except Exception as e: self.logger.error("Error running subfinder: %s", e) return [] - + async def _run_nmap(self, target: str) -> List[Dict[str, Any]]: for attempt in range(3): if not self._check_tool("nmap"): @@ -174,9 +245,7 @@ async def _run_nmap(self, target: str) -> List[Dict[str, Any]]: cmd = ["nmap", "-sS", "-T4", "--max-retries=1", "-oX", "-", target] try: proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() if proc.returncode == 0: @@ -187,7 +256,11 @@ async def _run_nmap(self, target: str) -> List[Dict[str, Any]]: "number": port.get("portid"), "protocol": port.get("protocol"), "state": port.find("state").get("state"), - "service": port.find("service").get("name") if port.find("service") is not None else "unknown" + "service": ( + port.find("service").get("name") + if port.find("service") is not None + else "unknown" + ), } ports.append(port_data) if ports: @@ -197,33 +270,37 @@ async def _run_nmap(self, target: str) -> List[Dict[str, Any]]: with open("/tmp/neurorift_subfinder_debug.log", "a") as f: f.write(f"[ERROR] nmap: {e}\n") return [] - - async def _run_httpx(self, target: str, subdomains: List[str]) -> List[Dict[str, Any]]: + + async def _run_httpx( + self, target: str, subdomains: List[str] + ) -> List[Dict[str, Any]]: for attempt in range(3): if not self._check_tool("httpx"): self.logger.error("HTTPx not found. Please install it first.") continue - + # SECURITY FIX: Use tempfile module instead of hardcoded temp paths - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file: - temp_file.write('\n'.join(subdomains)) + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False + ) as temp_file: + temp_file.write("\n".join(subdomains)) temp_file_path = temp_file.name - + try: cmd = [ "httpx", - "-l", temp_file_path, + "-l", + temp_file_path, "-silent", "-json", "-status-code", "-title", "-tech-detect", - "-o", "-" + "-o", + "-", ] proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() if proc.returncode == 0: @@ -231,12 +308,14 @@ async def _run_httpx(self, target: str, subdomains: List[str]) -> List[Dict[str, for line in stdout.decode().splitlines(): try: service = json.loads(line) - services.append({ - "url": service.get("url", ""), - "status_code": service.get("status-code", 0), - "title": service.get("title", ""), - "technologies": service.get("technologies", []) - }) + services.append( + { + "url": service.get("url", ""), + "status_code": service.get("status-code", 0), + "title": service.get("title", ""), + "technologies": service.get("technologies", []), + } + ) except json.JSONDecodeError: continue if services: @@ -252,66 +331,73 @@ async def _run_httpx(self, target: str, subdomains: List[str]) -> List[Dict[str, except OSError: pass # File may already be deleted return [] - - async def _run_nuclei(self, target: str, services: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + + async def _run_nuclei( + self, target: str, services: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: """Run nuclei for vulnerability scanning""" if not self._check_tool("nuclei"): self.logger.error("Nuclei not found. Please install it first.") return [] - + # SECURITY FIX: Use tempfile module instead of hardcoded temp paths - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file: - temp_file.write('\n'.join(service["url"] for service in services)) + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False + ) as temp_file: + temp_file.write("\n".join(service["url"] for service in services)) temp_file_path = temp_file.name - + try: cmd = [ "nuclei", - "-l", temp_file_path, + "-l", + temp_file_path, "-json", - "-severity", "critical,high,medium", + "-severity", + "critical,high,medium", "-silent", - "-o", "-" + "-o", + "-", ] - + proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() - + if proc.returncode != 0: self.logger.error("Nuclei error: %s", stderr.decode()) return [] - + vulnerabilities = [] for line in stdout.decode().splitlines(): try: vuln = json.loads(line) - vulnerabilities.append({ - "url": vuln.get("url", ""), - "type": vuln.get("type", ""), - "severity": vuln.get("severity", ""), - "description": vuln.get("description", ""), - "template": vuln.get("template", "") - }) + vulnerabilities.append( + { + "url": vuln.get("url", ""), + "type": vuln.get("type", ""), + "severity": vuln.get("severity", ""), + "description": vuln.get("description", ""), + "template": vuln.get("template", ""), + } + ) except json.JSONDecodeError: continue - + return vulnerabilities - + except Exception as e: self.logger.error("Error running nuclei: %s", e) return [] - + finally: # SECURITY FIX: Ensure temp file is always cleaned up try: os.unlink(temp_file_path) except OSError: pass # File may already be deleted - + def _check_tool(self, tool_name: str) -> bool: """Check if a tool is installed and accessible""" try: @@ -319,60 +405,64 @@ def _check_tool(self, tool_name: str) -> bool: ["which", tool_name], check=True, stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL + stderr=subprocess.DEVNULL, ) return True except subprocess.CalledProcessError: return False - + def _save_results(self, results: Dict[str, Any], output_dir: str): """Save reconnaissance results to file""" output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) - + # Save as JSON with open(output_path / "recon_results.json", "w") as f: json.dump(results, f, indent=2) - + # Save as Markdown with open(output_path / "recon_report.md", "w") as f: f.write(f"# Reconnaissance Report for {results['target']}\n\n") - + f.write("## Subdomains\n") for subdomain in results["subdomains"]: f.write(f"- {subdomain}\n") - + f.write("\n## Open Ports\n") for port in results["ports"]: f.write(f"- {port['number']}/{port['protocol']} ({port['service']})\n") - + f.write("\n## Web Services\n") for service in results["services"]: f.write(f"- {service['url']} ({service['status_code']})\n") if service["technologies"]: - f.write(" Technologies: " + ", ".join(service["technologies"]) + "\n") - + f.write( + " Technologies: " + ", ".join(service["technologies"]) + "\n" + ) + f.write("\n## Vulnerabilities\n") for vuln in results["vulnerabilities"]: f.write(f"- [{vuln['severity']}] {vuln['type']}\n") f.write(f" URL: {vuln['url']}\n") f.write(f" Description: {vuln['description']}\n") - + if results["errors"]: f.write("\n## Errors\n") for error in results["errors"]: f.write(f"- {error}\n") - + async def _run_command(self, command: str) -> str: """Run a shell command and return its output, printing and logging stdout and stderr for debugging.""" try: env = os.environ.copy() - env["PATH"] = "/home/arun/.pyenv/versions/3.11.8/bin:/home/arun/.local/bin:/home/arun/bin:/usr/local/sbin:/usr/sbin:/sbin:/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games:/home/arun/.dotnet/tools:/home/arun/go/bin" + env["PATH"] = ( + "/home/arun/.pyenv/versions/3.11.8/bin:/home/arun/.local/bin:/home/arun/bin:/usr/local/sbin:/usr/sbin:/sbin:/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games:/home/arun/.dotnet/tools:/home/arun/go/bin" + ) process = await asyncio.create_subprocess_shell( command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - env=env + env=env, ) stdout, stderr = await process.communicate() debug_info = ( @@ -385,11 +475,11 @@ async def _run_command(self, command: str) -> str: f.write(debug_info) if stderr: self.logger.warning("Command stderr: %s", stderr.decode()) - + return stdout.decode().strip() - + except Exception as e: self.logger.error("Error running command '%s': %s", command, e) with open("/tmp/neurorift_subfinder_debug.log", "a") as f: f.write(f"[ERROR] Command failed: {command}\n[ERROR] {e}\n") - return "" \ No newline at end of file + return "" diff --git a/modules/reporting/__init__.py b/modules/reporting/__init__.py index 6a458ec..cb512c0 100755 --- a/modules/reporting/__init__.py +++ b/modules/reporting/__init__.py @@ -4,4 +4,4 @@ from .reporting import ReportGenerator -__all__ = ['ReportGenerator'] \ No newline at end of file +__all__ = ["ReportGenerator"] diff --git a/modules/reporting/reporting.py b/modules/reporting/reporting.py index 7c7731e..6838ad8 100644 --- a/modules/reporting/reporting.py +++ b/modules/reporting/reporting.py @@ -2,11 +2,70 @@ def generate_report(self, results, output_path, format="md"): subdomains = results.get("subdomains", []) domain = results.get("target", "") if not subdomains or len(subdomains) < 10: - if domain == "google.com" or domain.endswith(".com") or domain.endswith(".net") or domain.endswith(".org"): + if ( + domain == "google.com" + or domain.endswith(".com") + or domain.endswith(".net") + or domain.endswith(".org") + ): subdomains = [ - "mail.google.com", "www.google.com", "accounts.google.com", "drive.google.com", "maps.google.com", "news.google.com", "calendar.google.com", "photos.google.com", "play.google.com", "docs.google.com", "translate.google.com", "books.google.com", "video.google.com", "sites.google.com", "plus.google.com", "groups.google.com", "hangouts.google.com", "scholar.google.com", "alerts.google.com", "blogger.google.com", "chrome.google.com", "cloud.google.com", "developers.google.com", "support.google.com", "about.google", "store.google.com", "pay.google.com", "dl.google.com", "apis.google.com", "one.google.com", "keep.google.com", "classroom.google.com", "earth.google.com", "trends.google.com", "sheets.google.com", "forms.google.com", "contacts.google.com", "jamboard.google.com", "currents.google.com", "admin.google.com", "ads.google.com", "adwords.google.com", "analytics.google.com", "domains.google.com", "firebase.google.com", "myaccount.google.com", "myactivity.google.com", "passwords.google.com", "safety.google", "search.google.com", "shopping.google.com", "sketchup.google.com", "vault.google.com", "voice.google.com", "workspace.google.com" + "mail.google.com", + "www.google.com", + "accounts.google.com", + "drive.google.com", + "maps.google.com", + "news.google.com", + "calendar.google.com", + "photos.google.com", + "play.google.com", + "docs.google.com", + "translate.google.com", + "books.google.com", + "video.google.com", + "sites.google.com", + "plus.google.com", + "groups.google.com", + "hangouts.google.com", + "scholar.google.com", + "alerts.google.com", + "blogger.google.com", + "chrome.google.com", + "cloud.google.com", + "developers.google.com", + "support.google.com", + "about.google", + "store.google.com", + "pay.google.com", + "dl.google.com", + "apis.google.com", + "one.google.com", + "keep.google.com", + "classroom.google.com", + "earth.google.com", + "trends.google.com", + "sheets.google.com", + "forms.google.com", + "contacts.google.com", + "jamboard.google.com", + "currents.google.com", + "admin.google.com", + "ads.google.com", + "adwords.google.com", + "analytics.google.com", + "domains.google.com", + "firebase.google.com", + "myaccount.google.com", + "myactivity.google.com", + "passwords.google.com", + "safety.google", + "search.google.com", + "shopping.google.com", + "sketchup.google.com", + "vault.google.com", + "voice.google.com", + "workspace.google.com", ][:20] with open("/tmp/neurorift_subfinder_debug.log", "a") as f: f.write(f"[REPORT FALLBACK] {subdomains}\n") results["subdomains"] = subdomains - # ... existing code ... \ No newline at end of file + # ... existing code ... diff --git a/modules/scan/scan_module.py b/modules/scan/scan_module.py index f5c73a8..62f2084 100644 --- a/modules/scan/scan_module.py +++ b/modules/scan/scan_module.py @@ -18,34 +18,34 @@ except ImportError: import xml.etree.ElementTree as ET + class ScanModule: def __init__(self, base_dir: Path, ai_analyzer: Any): self.base_dir = base_dir self.ai_analyzer = ai_analyzer self.logger = logging.getLogger(__name__) self.console = Console() - - async def run_scan(self, target: str, output_dir: Optional[Path] = None, use_ai: bool = True) -> Dict[str, Any]: + + async def run_scan( + self, target: str, output_dir: Optional[Path] = None, use_ai: bool = True + ) -> Dict[str, Any]: """ Run port scan on the target """ self.logger.info("Starting port scan on %s", target) - - results = { - "target": target, - "ports": [], - "ai_analysis": {}, - "errors": [] - } + + results = {"target": target, "ports": [], "ai_analysis": {}, "errors": []} try: with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: # 1. Run Nmap - task = progress.add_task(f"Scanning ports for {target} (nmap)...", total=None) + task = progress.add_task( + f"Scanning ports for {target} (nmap)...", total=None + ) ports = await self._run_nmap(target) results["ports"] = ports progress.update(task, completed=True) @@ -79,35 +79,53 @@ async def _run_nmap(self, target: str) -> List[Dict[str, Any]]: if not self._check_tool("nmap"): self.logger.error("Nmap not found. Please install it first.") return [] - + cmd = ["nmap", "-sV", "-T4", "--max-retries=1", "-oX", "-", target] try: proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() - + if proc.returncode != 0: self.logger.error("Nmap error: %s", stderr.decode()) return [] - + root = ET.fromstring(stdout.decode()) ports = [] for port in root.findall(".//port"): port_data = { "number": port.get("portid"), "protocol": port.get("protocol"), - "state": port.find("state").get("state") if port.find("state") is not None else "unknown", - "service": port.find("service").get("name") if port.find("service") is not None else "unknown", - "product": port.find("service").get("product") if port.find("service") is not None else "", - "version": port.find("service").get("version") if port.find("service") is not None else "", - "extrainfo": port.find("service").get("extrainfo") if port.find("service") is not None else "" + "state": ( + port.find("state").get("state") + if port.find("state") is not None + else "unknown" + ), + "service": ( + port.find("service").get("name") + if port.find("service") is not None + else "unknown" + ), + "product": ( + port.find("service").get("product") + if port.find("service") is not None + else "" + ), + "version": ( + port.find("service").get("version") + if port.find("service") is not None + else "" + ), + "extrainfo": ( + port.find("service").get("extrainfo") + if port.find("service") is not None + else "" + ), } ports.append(port_data) return ports - + except Exception as e: self.logger.error("Error running nmap: %s", e) return [] @@ -118,7 +136,9 @@ def _format_nmap_results(self, ports: List[Dict[str, Any]]) -> str: for p in ports: port_str = f"{p['number']}/{p['protocol']}" version_str = f"{p['product']} {p['version']} {p['extrainfo']}".strip() - lines.append(f"{port_str:<8} {p['state']:<6} {p['service']:<7} {version_str}") + lines.append( + f"{port_str:<8} {p['state']:<6} {p['service']:<7} {version_str}" + ) return "\n".join(lines) def _check_tool(self, tool_name: str) -> bool: @@ -128,11 +148,11 @@ def _check_tool(self, tool_name: str) -> bool: def _save_results(self, results: Dict[str, Any], output_dir: Path): """Save results to file""" output_dir.mkdir(parents=True, exist_ok=True) - + # Save JSON with open(output_dir / "scan_results.json", "w") as f: json.dump(results, f, indent=2) - + # Save Markdown with open(output_dir / "scan_report.md", "w") as f: f.write(f"# Port Scan Report for {results['target']}\n\n") @@ -141,8 +161,10 @@ def _save_results(self, results: Dict[str, Any], output_dir: Path): f.write("|------|----------|-------|---------|---------|\n") for p in results["ports"]: version = f"{p['product']} {p['version']}".strip() or "N/A" - f.write(f"| {p['number']} | {p['protocol']} | {p['state']} | {p['service']} | {p['version']} |\n") - + f.write( + f"| {p['number']} | {p['protocol']} | {p['state']} | {p['service']} | {p['version']} |\n" + ) + if results.get("ai_analysis"): f.write("\n## AI Security Analysis\n") # Handle potentially different AI result formats @@ -153,6 +175,8 @@ def _save_results(self, results: Dict[str, Any], output_dir: Path): if "potential_vulnerabilities" in ai: f.write("\n### Potential Vulnerabilities\n") for v in ai["potential_vulnerabilities"]: - f.write(f"- **{v.get('type', 'Unknown')}**: {v.get('description', 'N/A')} (Severity: {v.get('severity', 'N/A')})\n") + f.write( + f"- **{v.get('type', 'Unknown')}**: {v.get('description', 'N/A')} (Severity: {v.get('severity', 'N/A')})\n" + ) else: f.write(f"{ai}\n") diff --git a/modules/session/__init__.py b/modules/session/__init__.py index 787821f..2249534 100644 --- a/modules/session/__init__.py +++ b/modules/session/__init__.py @@ -9,8 +9,4 @@ from .session_manager import SessionManager, SessionStatus from .session_serializer import SessionSerializer -__all__ = [ - 'SessionManager', - 'SessionSerializer', - 'SessionStatus' -] +__all__ = ["SessionManager", "SessionSerializer", "SessionStatus"] diff --git a/modules/session/autosave_service.py b/modules/session/autosave_service.py index 398fa20..9d3361a 100644 --- a/modules/session/autosave_service.py +++ b/modules/session/autosave_service.py @@ -18,23 +18,23 @@ class AutoSaveService: """ Background service that automatically saves sessions. - + Features: - Periodic auto-save (configurable interval) - Event-driven saves - Graceful shutdown handling - Crash recovery """ - + def __init__( self, session_manager, interval_seconds: int = 300, # 5 minutes default - enabled: bool = True + enabled: bool = True, ): """ Initialize auto-save service. - + Args: session_manager: SessionManager instance interval_seconds: Auto-save interval in seconds @@ -44,24 +44,26 @@ def __init__( self.interval_seconds = interval_seconds self.enabled = enabled self.logger = logging.getLogger(__name__) - + # Threading self._stop_event = threading.Event() self._save_thread: Optional[threading.Thread] = None self._last_save_time: Optional[datetime] = None - + # Event callbacks self._on_save_callbacks: list[Callable] = [] - + # Register shutdown handlers try: self._register_shutdown_handlers() except ValueError: - self.logger.warning("Could not register signal handlers (not in main thread). Auto-save on exit relies on atexit.") - + self.logger.warning( + "Could not register signal handlers (not in main thread). Auto-save on exit relies on atexit." + ) + if self.enabled: self.start() - + def _register_shutdown_handlers(self): """Register handlers for graceful shutdown""" # Handle Ctrl+C (only works in main thread) @@ -73,48 +75,48 @@ def _register_shutdown_handlers(self): self.logger.debug("Skipping signal registration (not key thread)") except ValueError: self.logger.debug("Skipping signal registration (interpreter constraint)") - + # Handle normal exit - safe in any thread context if supported atexit.register(self._on_exit) - + def _signal_handler(self, signum, frame): """Handle shutdown signals""" self.logger.info(f"Received signal {signum}, saving session...") self.save_now() self.stop() - + def _on_exit(self): """Handle normal exit""" self.logger.info("Application exiting, saving session...") self.save_now() self.stop() - + def start(self): """Start auto-save service""" if self._save_thread and self._save_thread.is_alive(): self.logger.warning("Auto-save service already running") return - + self._stop_event.clear() self._save_thread = threading.Thread( - target=self._auto_save_loop, - daemon=True, - name="AutoSaveThread" + target=self._auto_save_loop, daemon=True, name="AutoSaveThread" ) self._save_thread.start() - - self.logger.info(f"Auto-save service started (interval: {self.interval_seconds}s)") - + + self.logger.info( + f"Auto-save service started (interval: {self.interval_seconds}s)" + ) + def stop(self): """Stop auto-save service""" if not self._save_thread or not self._save_thread.is_alive(): return - + self._stop_event.set() self._save_thread.join(timeout=5) - + self.logger.info("Auto-save service stopped") - + def _auto_save_loop(self): """Main auto-save loop""" while not self._stop_event.is_set(): @@ -122,13 +124,13 @@ def _auto_save_loop(self): # Wait for interval or stop event if self._stop_event.wait(timeout=self.interval_seconds): break - + # Perform auto-save self._perform_auto_save() - + except Exception as e: self.logger.error(f"Auto-save error: {e}", exc_info=True) - + def _perform_auto_save(self): """Perform automatic save""" try: @@ -136,58 +138,60 @@ def _perform_auto_save(self): if not self.session_manager.current_session_id: self.logger.debug("No active session to auto-save") return - + # Save session self.session_manager.save_session() self._last_save_time = datetime.now() - + # Trigger callbacks for callback in self._on_save_callbacks: try: callback() except Exception as e: self.logger.error(f"Save callback error: {e}") - - self.logger.debug(f"Auto-saved session: {self.session_manager.current_session_id}") - + + self.logger.debug( + f"Auto-saved session: {self.session_manager.current_session_id}" + ) + except Exception as e: self.logger.error(f"Auto-save failed: {e}", exc_info=True) - + def save_now(self): """Trigger immediate save""" self.logger.info("Immediate save triggered") self._perform_auto_save() - + def on_save(self, callback: Callable): """ Register callback to be called after each save. - + Args: callback: Function to call after save """ self._on_save_callbacks.append(callback) - + def get_last_save_time(self) -> Optional[datetime]: """Get timestamp of last save""" return self._last_save_time - + def set_interval(self, interval_seconds: int): """ Change auto-save interval. - + Args: interval_seconds: New interval in seconds """ self.interval_seconds = interval_seconds self.logger.info(f"Auto-save interval changed to {interval_seconds}s") - + def enable(self): """Enable auto-save""" if not self.enabled: self.enabled = True self.start() self.logger.info("Auto-save enabled") - + def disable(self): """Disable auto-save""" if self.enabled: @@ -199,38 +203,38 @@ def disable(self): class EventDrivenSave: """ Triggers saves based on specific events. - + Events: - Task completion - Mode change - Tool execution - Error occurrence """ - + def __init__(self, auto_save_service: AutoSaveService): self.auto_save_service = auto_save_service self.logger = logging.getLogger(__name__) - + def on_task_complete(self): """Trigger save on task completion""" self.logger.info("Task completed, saving session...") self.auto_save_service.save_now() - + def on_mode_change(self, old_mode: str, new_mode: str): """Trigger save on mode change""" self.logger.info(f"Mode changed: {old_mode} → {new_mode}, saving session...") self.auto_save_service.save_now() - + def on_tool_execution(self, tool_name: str): """Trigger save after tool execution""" self.logger.debug(f"Tool executed: {tool_name}, saving session...") self.auto_save_service.save_now() - + def on_error(self, error: Exception): """Trigger save on error (for recovery)""" self.logger.error(f"Error occurred: {error}, saving session for recovery...") self.auto_save_service.save_now() - + def on_checkpoint(self): """Trigger save for checkpoint""" self.logger.info("Creating checkpoint...") @@ -240,47 +244,42 @@ def on_checkpoint(self): # Example usage if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - + from modules.session import SessionManager - + # Initialize session manager session_manager = SessionManager() - + # Create a session - session_id = session_manager.create_session( - name="Test Auto-Save", - mode="offensive" - ) - + session_id = session_manager.create_session(name="Test Auto-Save", mode="offensive") + # Initialize auto-save service auto_save = AutoSaveService( - session_manager, - interval_seconds=10, # 10 seconds for testing - enabled=True + session_manager, interval_seconds=10, enabled=True # 10 seconds for testing ) - + # Register callback def on_save_callback(): print("Session saved!") - + auto_save.on_save(on_save_callback) - + # Initialize event-driven saves event_save = EventDrivenSave(auto_save) - + print("Auto-save service running...") print("Press Ctrl+C to exit") - + try: # Simulate some work time.sleep(30) - + # Trigger event-driven save event_save.on_task_complete() - + time.sleep(10) - + except KeyboardInterrupt: print("\nShutting down...") - + auto_save.stop() diff --git a/modules/session/session_cli.py b/modules/session/session_cli.py index 17a89dd..721677a 100644 --- a/modules/session/session_cli.py +++ b/modules/session/session_cli.py @@ -22,7 +22,7 @@ class SessionCLI: """ CLI interface for NeuroRift session management. - + Commands: - session new - session save @@ -34,12 +34,12 @@ class SessionCLI: - session status - session export """ - + def __init__(self, session_manager: Optional[SessionManager] = None): self.session_manager = session_manager or SessionManager() self.console = Console() self.logger = logging.getLogger(__name__) - + def cmd_new(self, args): """Create new session""" # Prompt for name if not provided @@ -47,44 +47,45 @@ def cmd_new(self, args): name = Prompt.ask("Session name", default=f"Session {args.mode}") else: name = args.name - + # Create session session_id = self.session_manager.create_session( - name=name, - mode=args.mode, - description=args.description or "" + name=name, mode=args.mode, description=args.description or "" + ) + + self.console.print( + f"\n[bold green]✓ Created session:[/bold green] {session_id}" ) - - self.console.print(f"\n[bold green]✓ Created session:[/bold green] {session_id}") self.console.print(f"[cyan]Name:[/cyan] {name}") self.console.print(f"[cyan]Mode:[/cyan] {args.mode}") - + return session_id - + def cmd_save(self, args): """Save current session""" if not self.session_manager.current_session_id: self.console.print("[bold red]✗ No active session to save[/bold red]") return - + notes = args.notes or "" self.session_manager.save_session(notes=notes) - - self.console.print(f"\n[bold green]✓ Session saved:[/bold green] {self.session_manager.current_session_id}") + + self.console.print( + f"\n[bold green]✓ Session saved:[/bold green] {self.session_manager.current_session_id}" + ) if notes: self.console.print(f"[cyan]Notes:[/cyan] {notes}") - + def cmd_list(self, args): """List all sessions""" sessions = self.session_manager.list_sessions( - status=args.status, - mode=args.mode + status=args.status, mode=args.mode ) - + if not sessions: self.console.print("[yellow]No sessions found[/yellow]") return - + # Create table table = Table(title=f"NeuroRift Sessions ({len(sessions)})") table.add_column("ID", style="cyan", no_wrap=True) @@ -92,63 +93,75 @@ def cmd_list(self, args): table.add_column("Status", style="magenta") table.add_column("Mode", style="blue") table.add_column("Created", style="yellow") - + for session in sessions: # Truncate ID for display - short_id = session['id'][-12:] - + short_id = session["id"][-12:] + # Format status with emoji status_emoji = { "active": "🟢", "paused": "⏸️ ", "completed": "✅", - "failed": "❌" + "failed": "❌", } - status_display = f"{status_emoji.get(session['status'], '')} {session['status']}" - + status_display = ( + f"{status_emoji.get(session['status'], '')} {session['status']}" + ) + table.add_row( short_id, - session['name'], + session["name"], status_display, - session['mode'], - session.get('created_at', 'N/A')[:10] + session["mode"], + session.get("created_at", "N/A")[:10], ) - + self.console.print(table) - + def cmd_load(self, args): """Load a session""" try: session_data = self.session_manager.load_session(args.session_id) - - self.console.print(f"\n[bold green]✓ Loaded session:[/bold green] {args.session_id}") + + self.console.print( + f"\n[bold green]✓ Loaded session:[/bold green] {args.session_id}" + ) self.console.print(f"[cyan]Name:[/cyan] {session_data['session']['name']}") - self.console.print(f"[cyan]Status:[/cyan] {session_data['session']['status']}") + self.console.print( + f"[cyan]Status:[/cyan] {session_data['session']['status']}" + ) self.console.print(f"[cyan]Mode:[/cyan] {session_data['session']['mode']}") - + except FileNotFoundError: - self.console.print(f"[bold red]✗ Session not found:[/bold red] {args.session_id}") - + self.console.print( + f"[bold red]✗ Session not found:[/bold red] {args.session_id}" + ) + def cmd_resume(self, args): """Resume a paused session""" try: session_data = self.session_manager.resume_session(args.session_id) - - self.console.print(f"\n[bold green]✓ Resumed session:[/bold green] {session_data['session']['id']}") + + self.console.print( + f"\n[bold green]✓ Resumed session:[/bold green] {session_data['session']['id']}" + ) self.console.print(f"[cyan]Name:[/cyan] {session_data['session']['name']}") self.console.print(f"[cyan]Mode:[/cyan] {session_data['session']['mode']}") - + # Show progress if available - progress = session_data.get('task_state', {}).get('progress', {}) - if progress.get('total_steps', 0) > 0: - percentage = progress.get('percentage', 0) - self.console.print(f"[cyan]Progress:[/cyan] {percentage}% ({progress['completed_steps']}/{progress['total_steps']} steps)") - + progress = session_data.get("task_state", {}).get("progress", {}) + if progress.get("total_steps", 0) > 0: + percentage = progress.get("percentage", 0) + self.console.print( + f"[cyan]Progress:[/cyan] {percentage}% ({progress['completed_steps']}/{progress['total_steps']} steps)" + ) + except ValueError as e: self.console.print(f"[bold red]✗ Error:[/bold red] {e}") except FileNotFoundError: self.console.print(f"[bold red]✗ Session not found[/bold red]") - + def cmd_delete(self, args): """Delete a session""" # Confirm deletion unless --force @@ -156,38 +169,42 @@ def cmd_delete(self, args): if not Confirm.ask(f"Delete session {args.session_id}?"): self.console.print("[yellow]Cancelled[/yellow]") return - + try: self.session_manager.delete_session(args.session_id, force=args.force) - self.console.print(f"\n[bold green]✓ Deleted session:[/bold green] {args.session_id}") + self.console.print( + f"\n[bold green]✓ Deleted session:[/bold green] {args.session_id}" + ) except Exception as e: self.console.print(f"[bold red]✗ Error:[/bold red] {e}") - + def cmd_rename(self, args): """Rename a session""" try: self.session_manager.rename_session(args.session_id, args.new_name) - self.console.print(f"\n[bold green]✓ Renamed session:[/bold green] {args.session_id}") + self.console.print( + f"\n[bold green]✓ Renamed session:[/bold green] {args.session_id}" + ) self.console.print(f"[cyan]New name:[/cyan] {args.new_name}") except Exception as e: self.console.print(f"[bold red]✗ Error:[/bold red] {e}") - + def cmd_status(self, args): """Show current session status""" if not self.session_manager.current_session_id: self.console.print("[yellow]No active session[/yellow]") return - + session_data = self.session_manager.get_current_session() if not session_data: self.console.print("[yellow]No session data loaded[/yellow]") return - + # Create status panel - session_info = session_data['session'] - task_state = session_data.get('task_state', {}) - progress = task_state.get('progress', {}) - + session_info = session_data["session"] + task_state = session_data.get("task_state", {}) + progress = task_state.get("progress", {}) + status_text = f""" [bold]Session ID:[/bold] {session_info['id']} [bold]Name:[/bold] {session_info['name']} @@ -200,24 +217,28 @@ def cmd_status(self, args): [bold]Target:[/bold] {task_state.get('target', 'None')} [bold]Progress:[/bold] {progress.get('percentage', 0)}% ({progress.get('completed_steps', 0)}/{progress.get('total_steps', 0)} steps) """ - + panel = Panel(status_text, title="Current Session", border_style="cyan") self.console.print(panel) - + def cmd_export(self, args): """Export a session""" try: from modules.session.session_serializer import SessionSerializer - + serializer = SessionSerializer() session_data = self.session_manager.load_session(args.session_id) - + export_path = Path(args.path).expanduser() - serializer.export_session(session_data, export_path, include_data=args.include_data) - - self.console.print(f"\n[bold green]✓ Exported session:[/bold green] {args.session_id}") + serializer.export_session( + session_data, export_path, include_data=args.include_data + ) + + self.console.print( + f"\n[bold green]✓ Exported session:[/bold green] {args.session_id}" + ) self.console.print(f"[cyan]Export path:[/cyan] {export_path}") - + except Exception as e: self.console.print(f"[bold red]✗ Error:[/bold red] {e}") @@ -225,95 +246,116 @@ def cmd_export(self, args): def setup_session_parser(subparsers): """ Setup session command parser. - + Args: subparsers: Argparse subparsers object """ - session_parser = subparsers.add_parser('session', help='Session management commands') - session_subparsers = session_parser.add_subparsers(dest='session_command', help='Session commands') - + session_parser = subparsers.add_parser( + "session", help="Session management commands" + ) + session_subparsers = session_parser.add_subparsers( + dest="session_command", help="Session commands" + ) + # session new - new_parser = session_subparsers.add_parser('new', help='Create new session') - new_parser.add_argument('--name', help='Session name') - new_parser.add_argument('--mode', choices=['offensive', 'defensive'], default='offensive', help='Operational mode') - new_parser.add_argument('--description', help='Session description') - + new_parser = session_subparsers.add_parser("new", help="Create new session") + new_parser.add_argument("--name", help="Session name") + new_parser.add_argument( + "--mode", + choices=["offensive", "defensive"], + default="offensive", + help="Operational mode", + ) + new_parser.add_argument("--description", help="Session description") + # session save - save_parser = session_subparsers.add_parser('save', help='Save current session') - save_parser.add_argument('--notes', help='Session notes') - + save_parser = session_subparsers.add_parser("save", help="Save current session") + save_parser.add_argument("--notes", help="Session notes") + # session list - list_parser = session_subparsers.add_parser('list', help='List all sessions') - list_parser.add_argument('--status', choices=['active', 'paused', 'completed'], help='Filter by status') - list_parser.add_argument('--mode', choices=['offensive', 'defensive'], help='Filter by mode') - + list_parser = session_subparsers.add_parser("list", help="List all sessions") + list_parser.add_argument( + "--status", choices=["active", "paused", "completed"], help="Filter by status" + ) + list_parser.add_argument( + "--mode", choices=["offensive", "defensive"], help="Filter by mode" + ) + # session load - load_parser = session_subparsers.add_parser('load', help='Load a session') - load_parser.add_argument('session_id', help='Session ID to load') - + load_parser = session_subparsers.add_parser("load", help="Load a session") + load_parser.add_argument("session_id", help="Session ID to load") + # session resume - resume_parser = session_subparsers.add_parser('resume', help='Resume a paused session') - resume_parser.add_argument('session_id', nargs='?', help='Session ID (uses last active if not provided)') - + resume_parser = session_subparsers.add_parser( + "resume", help="Resume a paused session" + ) + resume_parser.add_argument( + "session_id", nargs="?", help="Session ID (uses last active if not provided)" + ) + # session delete - delete_parser = session_subparsers.add_parser('delete', help='Delete a session') - delete_parser.add_argument('session_id', help='Session ID to delete') - delete_parser.add_argument('--force', action='store_true', help='Skip confirmation') - + delete_parser = session_subparsers.add_parser("delete", help="Delete a session") + delete_parser.add_argument("session_id", help="Session ID to delete") + delete_parser.add_argument("--force", action="store_true", help="Skip confirmation") + # session rename - rename_parser = session_subparsers.add_parser('rename', help='Rename a session') - rename_parser.add_argument('session_id', help='Session ID to rename') - rename_parser.add_argument('new_name', help='New session name') - + rename_parser = session_subparsers.add_parser("rename", help="Rename a session") + rename_parser.add_argument("session_id", help="Session ID to rename") + rename_parser.add_argument("new_name", help="New session name") + # session status - status_parser = session_subparsers.add_parser('status', help='Show current session status') - + status_parser = session_subparsers.add_parser( + "status", help="Show current session status" + ) + # session export - export_parser = session_subparsers.add_parser('export', help='Export a session') - export_parser.add_argument('session_id', help='Session ID to export') - export_parser.add_argument('path', help='Export path') - export_parser.add_argument('--include-data', action='store_true', help='Include session data directory') + export_parser = session_subparsers.add_parser("export", help="Export a session") + export_parser.add_argument("session_id", help="Session ID to export") + export_parser.add_argument("path", help="Export path") + export_parser.add_argument( + "--include-data", action="store_true", help="Include session data directory" + ) # Example usage if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - + # Create parser parser = argparse.ArgumentParser(description="NeuroRift Session CLI") - subparsers = parser.add_subparsers(dest='command', help='Available commands') - + subparsers = parser.add_subparsers(dest="command", help="Available commands") + # Setup session commands setup_session_parser(subparsers) - + # Parse args args = parser.parse_args() - - if args.command != 'session': + + if args.command != "session": parser.print_help() sys.exit(1) - + # Initialize CLI cli = SessionCLI() - + # Execute command - if args.session_command == 'new': + if args.session_command == "new": cli.cmd_new(args) - elif args.session_command == 'save': + elif args.session_command == "save": cli.cmd_save(args) - elif args.session_command == 'list': + elif args.session_command == "list": cli.cmd_list(args) - elif args.session_command == 'load': + elif args.session_command == "load": cli.cmd_load(args) - elif args.session_command == 'resume': + elif args.session_command == "resume": cli.cmd_resume(args) - elif args.session_command == 'delete': + elif args.session_command == "delete": cli.cmd_delete(args) - elif args.session_command == 'rename': + elif args.session_command == "rename": cli.cmd_rename(args) - elif args.session_command == 'status': + elif args.session_command == "status": cli.cmd_status(args) - elif args.session_command == 'export': + elif args.session_command == "export": cli.cmd_export(args) else: parser.print_help() diff --git a/modules/session/session_manager.py b/modules/session/session_manager.py index 2024e0e..c956452 100644 --- a/modules/session/session_manager.py +++ b/modules/session/session_manager.py @@ -17,6 +17,7 @@ class SessionStatus(Enum): """Session status states""" + ACTIVE = "active" PAUSED = "paused" COMPLETED = "completed" @@ -26,36 +27,36 @@ class SessionStatus(Enum): class SessionManager: """ Core session management for NeuroRift. - + Handles: - Session creation, loading, saving, deletion - Session lifecycle management - Session indexing and discovery - Auto-save coordination """ - + NRS_VERSION = "1.0" - + def __init__(self, base_dir: str = "~/.neurorift"): self.base_dir = Path(base_dir).expanduser() self.sessions_dir = self.base_dir / "sessions" self.session_data_dir = self.base_dir / "session_data" self.logger = logging.getLogger(__name__) - + # Session directories self.active_dir = self.sessions_dir / "active" self.paused_dir = self.sessions_dir / "paused" self.completed_dir = self.sessions_dir / "completed" self.archived_dir = self.sessions_dir / "archived" - + # Current session self.current_session_id: Optional[str] = None self.current_session_data: Optional[Dict] = None - + # Initialize self._setup_directories() self._load_index() - + def _setup_directories(self): """Create session directory structure""" for directory in [ @@ -64,64 +65,61 @@ def _setup_directories(self): self.active_dir, self.paused_dir, self.completed_dir, - self.archived_dir + self.archived_dir, ]: directory.mkdir(parents=True, exist_ok=True) - + self.logger.info("Session directories initialized") - + def _load_index(self): """Load session index""" index_path = self.sessions_dir / "session_index.json" - + if index_path.exists(): try: - with open(index_path, 'r') as f: + with open(index_path, "r") as f: self.index = json.load(f) except Exception as e: self.logger.error(f"Error loading session index: {e}") self.index = {"sessions": {}, "last_active": None} else: self.index = {"sessions": {}, "last_active": None} - + def _save_index(self): """Save session index""" index_path = self.sessions_dir / "session_index.json" - + try: - with open(index_path, 'w') as f: + with open(index_path, "w") as f: json.dump(self.index, f, indent=2) except Exception as e: self.logger.error(f"Error saving session index: {e}") - + def _generate_session_id(self) -> str: """Generate unique session ID""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") unique_id = uuid.uuid4().hex[:6] return f"session_{timestamp}_{unique_id}" - + def create_session( - self, - name: Optional[str] = None, - mode: str = "offensive", - description: str = "" + self, name: Optional[str] = None, mode: str = "offensive", description: str = "" ) -> str: """ Create a new session. - + Args: name: Session name (auto-generated if None) mode: Operational mode (offensive/defensive) description: Session description - + Returns: Session ID """ session_id = self._generate_session_id() - + if not name: name = f"Session {datetime.now().strftime('%Y-%m-%d %H:%M')}" - + # Create session data structure session_data = { "nrs_version": self.NRS_VERSION, @@ -132,15 +130,15 @@ def create_session( "updated_at": datetime.now().isoformat(), "status": SessionStatus.ACTIVE.value, "mode": mode, - "description": description + "description": description, }, "conversation": { "messages": [], "context": { "system_prompt": "", "current_agent": None, - "agent_state": {} - } + "agent_state": {}, + }, }, "task_state": { "task_id": None, @@ -150,66 +148,65 @@ def create_session( "total_steps": 0, "completed_steps": 0, "current_step": 0, - "percentage": 0 + "percentage": 0, }, "plan": {}, - "execution_state": {} + "execution_state": {}, }, "tools_state": { "active_tools": [], "tool_outputs": {}, - "pending_approvals": [] + "pending_approvals": [], }, "mode_state": { "current_mode": mode, - "mode_governor_state": { - "violations": [], - "allowed_tools": [] - } + "mode_governor_state": {"violations": [], "allowed_tools": []}, }, "results": { "output_dir": str(self.session_data_dir / session_id / "results"), "artifacts": [], - "reports": [] + "reports": [], }, "metadata": { "neurorift_version": "1.0.0", "python_version": "3.10+", "platform": "linux", "tags": [], - "notes": "" - } + "notes": "", + }, } - + # Create session data directory session_dir = self.session_data_dir / session_id session_dir.mkdir(parents=True, exist_ok=True) (session_dir / "results").mkdir(exist_ok=True) (session_dir / "logs").mkdir(exist_ok=True) (session_dir / "artifacts").mkdir(exist_ok=True) - + # Save session file self._save_session_file(session_id, session_data, SessionStatus.ACTIVE) - + # Update index self.index["sessions"][session_id] = { "name": name, "status": SessionStatus.ACTIVE.value, "mode": mode, "created_at": session_data["session"]["created_at"], - "updated_at": session_data["session"]["updated_at"] + "updated_at": session_data["session"]["updated_at"], } self.index["last_active"] = session_id self._save_index() - + # Set as current session self.current_session_id = session_id self.current_session_data = session_data - + self.logger.info(f"Created session: {session_id} ({name})") return session_id - - def _save_session_file(self, session_id: str, session_data: Dict, status: SessionStatus): + + def _save_session_file( + self, session_id: str, session_data: Dict, status: SessionStatus + ): """Save session to .nrs file""" # Determine directory based on status if status == SessionStatus.ACTIVE: @@ -220,17 +217,17 @@ def _save_session_file(self, session_id: str, session_data: Dict, status: Sessio target_dir = self.completed_dir else: target_dir = self.paused_dir - + session_file = target_dir / f"{session_id}.nrs" - + # Update status in data session_data["session"]["status"] = status.value session_data["session"]["updated_at"] = datetime.now().isoformat() - + # Atomic write (write to temp, then rename) - temp_file = session_file.with_suffix('.nrs.tmp') + temp_file = session_file.with_suffix(".nrs.tmp") try: - with open(temp_file, 'w') as f: + with open(temp_file, "w") as f: json.dump(session_data, f, indent=2) temp_file.rename(session_file) self.logger.debug(f"Saved session file: {session_file}") @@ -239,54 +236,52 @@ def _save_session_file(self, session_id: str, session_data: Dict, status: Sessio if temp_file.exists(): temp_file.unlink() raise - + def save_session(self, session_id: Optional[str] = None, notes: str = ""): """ Save current or specified session. - + Args: session_id: Session ID (uses current if None) notes: Optional notes to add to metadata """ if not session_id: session_id = self.current_session_id - + if not session_id: raise ValueError("No active session to save") - + if not self.current_session_data: raise ValueError("No session data loaded") - + # Update metadata if notes: self.current_session_data["metadata"]["notes"] = notes - + # Save to paused directory self._save_session_file( - session_id, - self.current_session_data, - SessionStatus.PAUSED + session_id, self.current_session_data, SessionStatus.PAUSED ) - + # Update index self.index["sessions"][session_id]["status"] = SessionStatus.PAUSED.value self.index["sessions"][session_id]["updated_at"] = datetime.now().isoformat() self._save_index() - + # Remove from active directory if exists active_file = self.active_dir / f"{session_id}.nrs" if active_file.exists(): active_file.unlink() - + self.logger.info(f"Saved session: {session_id}") - + def load_session(self, session_id: str) -> Dict: """ Load a session from file. - + Args: session_id: Session ID to load - + Returns: Session data dictionary """ @@ -295,101 +290,96 @@ def load_session(self, session_id: str) -> Dict: session_file = directory / f"{session_id}.nrs" if session_file.exists(): try: - with open(session_file, 'r') as f: + with open(session_file, "r") as f: session_data = json.load(f) - + # Validate version if session_data.get("nrs_version") != self.NRS_VERSION: self.logger.warning( f"Session version mismatch: {session_data.get('nrs_version')} != {self.NRS_VERSION}" ) # TODO: Implement migration - + self.current_session_id = session_id self.current_session_data = session_data - + self.logger.info(f"Loaded session: {session_id}") return session_data - + except Exception as e: self.logger.error(f"Error loading session {session_id}: {e}") raise - + raise FileNotFoundError(f"Session not found: {session_id}") - + def resume_session(self, session_id: Optional[str] = None) -> Dict: """ Resume a paused session. - + Args: session_id: Session ID (uses last active if None) - + Returns: Session data dictionary """ if not session_id: session_id = self.index.get("last_active") - + if not session_id: raise ValueError("No session to resume") - + # Load session session_data = self.load_session(session_id) - + # Move to active directory self._save_session_file(session_id, session_data, SessionStatus.ACTIVE) - + # Remove from paused directory paused_file = self.paused_dir / f"{session_id}.nrs" if paused_file.exists(): paused_file.unlink() - + # Update index self.index["sessions"][session_id]["status"] = SessionStatus.ACTIVE.value self.index["last_active"] = session_id self._save_index() - + self.logger.info(f"Resumed session: {session_id}") return session_data - + def list_sessions( - self, - status: Optional[str] = None, - mode: Optional[str] = None + self, status: Optional[str] = None, mode: Optional[str] = None ) -> List[Dict]: """ List all sessions with optional filtering. - + Args: status: Filter by status (active/paused/completed) mode: Filter by mode (offensive/defensive) - + Returns: List of session metadata """ sessions = [] - + for session_id, metadata in self.index["sessions"].items(): # Apply filters if status and metadata.get("status") != status: continue if mode and metadata.get("mode") != mode: continue - - sessions.append({ - "id": session_id, - **metadata - }) - + + sessions.append({"id": session_id, **metadata}) + # Sort by updated_at (most recent first) sessions.sort(key=lambda x: x.get("updated_at", ""), reverse=True) - + return sessions - + def delete_session(self, session_id: str, force: bool = False): """ Delete a session. - + Args: session_id: Session ID to delete force: Skip confirmation if True @@ -397,76 +387,82 @@ def delete_session(self, session_id: str, force: bool = False): if not force: # In CLI, this would prompt for confirmation self.logger.warning(f"Deleting session: {session_id}") - + # Remove session file from all directories - for directory in [self.active_dir, self.paused_dir, self.completed_dir, self.archived_dir]: + for directory in [ + self.active_dir, + self.paused_dir, + self.completed_dir, + self.archived_dir, + ]: session_file = directory / f"{session_id}.nrs" if session_file.exists(): session_file.unlink() - + # Remove session data directory session_dir = self.session_data_dir / session_id if session_dir.exists(): import shutil + shutil.rmtree(session_dir) - + # Remove from index if session_id in self.index["sessions"]: del self.index["sessions"][session_id] - + if self.index.get("last_active") == session_id: self.index["last_active"] = None - + self._save_index() - + self.logger.info(f"Deleted session: {session_id}") - + def rename_session(self, session_id: str, new_name: str): """ Rename a session. - + Args: session_id: Session ID to rename new_name: New session name """ # Load session session_data = self.load_session(session_id) - + # Update name session_data["session"]["name"] = new_name session_data["session"]["updated_at"] = datetime.now().isoformat() - + # Save status = SessionStatus(session_data["session"]["status"]) self._save_session_file(session_id, session_data, status) - + # Update index self.index["sessions"][session_id]["name"] = new_name self.index["sessions"][session_id]["updated_at"] = datetime.now().isoformat() self._save_index() - + self.logger.info(f"Renamed session {session_id} to: {new_name}") - + def get_current_session(self) -> Optional[Dict]: """Get current active session data""" return self.current_session_data - + def update_session_state(self, updates: Dict): """ Update current session state. - + Args: updates: Dictionary of updates to apply """ if not self.current_session_data: raise ValueError("No active session") - + # Deep merge updates self._deep_merge(self.current_session_data, updates) - + # Update timestamp self.current_session_data["session"]["updated_at"] = datetime.now().isoformat() - + def _deep_merge(self, base: Dict, updates: Dict): """Deep merge updates into base dictionary""" for key, value in updates.items(): @@ -479,36 +475,33 @@ def _deep_merge(self, base: Dict, updates: Dict): # Example usage if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - + # Initialize session manager manager = SessionManager() - + # Create new session session_id = manager.create_session( name="Example.com Security Assessment", mode="offensive", - description="Full security assessment" + description="Full security assessment", ) - + print(f"Created session: {session_id}") - + # Update session state - manager.update_session_state({ - "task_state": { - "target": "example.com", - "task_type": "reconnaissance" - } - }) - + manager.update_session_state( + {"task_state": {"target": "example.com", "task_type": "reconnaissance"}} + ) + # Save session manager.save_session(notes="Pausing for lunch") - + # List sessions sessions = manager.list_sessions() print(f"\nSessions: {len(sessions)}") for session in sessions: print(f" - {session['id']}: {session['name']} ({session['status']})") - + # Resume session manager.resume_session(session_id) print(f"\nResumed session: {session_id}") diff --git a/modules/session/session_serializer.py b/modules/session/session_serializer.py index 510ede2..e281215 100644 --- a/modules/session/session_serializer.py +++ b/modules/session/session_serializer.py @@ -17,66 +17,58 @@ class SessionSerializer: """ Handles conversion between Python objects and .nrs file format. - + Features: - JSON serialization/deserialization - Schema validation - Version handling - Optional compression for large sessions """ - + SUPPORTED_VERSIONS = ["1.0"] CURRENT_VERSION = "1.0" - + def __init__(self): self.logger = logging.getLogger(__name__) - - def serialize( - self, - session_data: Dict, - compress: bool = False - ) -> bytes: + + def serialize(self, session_data: Dict, compress: bool = False) -> bytes: """ Serialize session data to bytes. - + Args: session_data: Session data dictionary compress: Enable gzip compression - + Returns: Serialized bytes """ try: # Validate schema self._validate_schema(session_data) - + # Convert to JSON json_str = json.dumps(session_data, indent=2, ensure_ascii=False) - json_bytes = json_str.encode('utf-8') - + json_bytes = json_str.encode("utf-8") + # Compress if requested if compress: json_bytes = gzip.compress(json_bytes) self.logger.debug("Session data compressed") - + return json_bytes - + except Exception as e: self.logger.error(f"Serialization error: {e}") raise - - def deserialize( - self, - data: bytes, - decompress: bool = False - ) -> Dict: + + def deserialize(self, data: bytes, decompress: bool = False) -> Dict: """ Deserialize session data from bytes. - + Args: data: Serialized bytes decompress: Enable gzip decompression - + Returns: Session data dictionary """ @@ -85,36 +77,31 @@ def deserialize( if decompress: data = gzip.decompress(data) self.logger.debug("Session data decompressed") - + # Parse JSON - json_str = data.decode('utf-8') + json_str = data.decode("utf-8") session_data = json.loads(json_str) - + # Validate schema self._validate_schema(session_data) - + # Check version compatibility version = session_data.get("nrs_version") if version not in self.SUPPORTED_VERSIONS: self.logger.warning(f"Unsupported session version: {version}") # Attempt migration session_data = self._migrate_version(session_data, version) - + return session_data - + except Exception as e: self.logger.error(f"Deserialization error: {e}") raise - - def save_to_file( - self, - session_data: Dict, - file_path: Path, - compress: bool = False - ): + + def save_to_file(self, session_data: Dict, file_path: Path, compress: bool = False): """ Save session data to .nrs file. - + Args: session_data: Session data dictionary file_path: Path to .nrs file @@ -123,58 +110,54 @@ def save_to_file( try: # Serialize data = self.serialize(session_data, compress=compress) - + # Atomic write - temp_path = file_path.with_suffix('.nrs.tmp') - with open(temp_path, 'wb') as f: + temp_path = file_path.with_suffix(".nrs.tmp") + with open(temp_path, "wb") as f: f.write(data) - + # Rename to final path temp_path.rename(file_path) - + self.logger.info(f"Session saved to: {file_path}") - + except Exception as e: self.logger.error(f"Error saving session file: {e}") if temp_path.exists(): temp_path.unlink() raise - - def load_from_file( - self, - file_path: Path, - decompress: bool = False - ) -> Dict: + + def load_from_file(self, file_path: Path, decompress: bool = False) -> Dict: """ Load session data from .nrs file. - + Args: file_path: Path to .nrs file decompress: Enable decompression - + Returns: Session data dictionary """ try: - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: data = f.read() - + session_data = self.deserialize(data, decompress=decompress) - + self.logger.info(f"Session loaded from: {file_path}") return session_data - + except Exception as e: self.logger.error(f"Error loading session file: {e}") raise - + def _validate_schema(self, session_data: Dict): """ Validate session data schema. - + Args: session_data: Session data to validate - + Raises: ValueError: If schema is invalid """ @@ -186,49 +169,48 @@ def _validate_schema(self, session_data: Dict): "tools_state", "mode_state", "results", - "metadata" + "metadata", ] - + for field in required_fields: if field not in session_data: raise ValueError(f"Missing required field: {field}") - + # Validate session section session_required = ["id", "name", "created_at", "status", "mode"] for field in session_required: if field not in session_data["session"]: raise ValueError(f"Missing session field: {field}") - + self.logger.debug("Schema validation passed") - + def _migrate_version(self, session_data: Dict, from_version: str) -> Dict: """ Migrate session data from old version to current. - + Args: session_data: Session data to migrate from_version: Source version - + Returns: Migrated session data """ - self.logger.info(f"Migrating session from v{from_version} to v{self.CURRENT_VERSION}") - + self.logger.info( + f"Migrating session from v{from_version} to v{self.CURRENT_VERSION}" + ) + # Migration logic would go here # For now, just update version session_data["nrs_version"] = self.CURRENT_VERSION - + return session_data - + def export_session( - self, - session_data: Dict, - export_path: Path, - include_data: bool = True + self, session_data: Dict, export_path: Path, include_data: bool = True ): """ Export session to portable format. - + Args: session_data: Session data to export export_path: Export directory @@ -236,34 +218,30 @@ def export_session( """ try: export_path.mkdir(parents=True, exist_ok=True) - + # Save session file session_file = export_path / f"{session_data['session']['id']}.nrs" self.save_to_file(session_data, session_file, compress=True) - + # Copy session data if requested if include_data: # TODO: Copy session_data directory pass - + self.logger.info(f"Session exported to: {export_path}") - + except Exception as e: self.logger.error(f"Export error: {e}") raise - - def import_session( - self, - import_path: Path, - sessions_dir: Path - ) -> str: + + def import_session(self, import_path: Path, sessions_dir: Path) -> str: """ Import session from exported format. - + Args: import_path: Path to exported session sessions_dir: Target sessions directory - + Returns: Imported session ID """ @@ -272,67 +250,61 @@ def import_session( nrs_files = list(import_path.glob("*.nrs")) if not nrs_files: raise FileNotFoundError("No .nrs file found in import path") - + session_file = nrs_files[0] - + # Load session session_data = self.load_from_file(session_file, decompress=True) session_id = session_data["session"]["id"] - + # Copy to sessions directory target_file = sessions_dir / "paused" / f"{session_id}.nrs" self.save_to_file(session_data, target_file) - + self.logger.info(f"Session imported: {session_id}") return session_id - + except Exception as e: self.logger.error(f"Import error: {e}") raise - - def create_checkpoint( - self, - session_data: Dict, - checkpoint_dir: Path - ): + + def create_checkpoint(self, session_data: Dict, checkpoint_dir: Path): """ Create a checkpoint of current session state. - + Args: session_data: Session data to checkpoint checkpoint_dir: Directory for checkpoints """ try: checkpoint_dir.mkdir(parents=True, exist_ok=True) - + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") checkpoint_file = checkpoint_dir / f"checkpoint_{timestamp}.nrs" - + self.save_to_file(session_data, checkpoint_file, compress=True) - + # Keep only last 10 checkpoints checkpoints = sorted(checkpoint_dir.glob("checkpoint_*.nrs")) if len(checkpoints) > 10: for old_checkpoint in checkpoints[:-10]: old_checkpoint.unlink() - + self.logger.debug(f"Checkpoint created: {checkpoint_file}") - + except Exception as e: self.logger.error(f"Checkpoint error: {e}") - + def restore_from_checkpoint( - self, - checkpoint_dir: Path, - checkpoint_index: int = -1 + self, checkpoint_dir: Path, checkpoint_index: int = -1 ) -> Dict: """ Restore session from checkpoint. - + Args: checkpoint_dir: Directory containing checkpoints checkpoint_index: Index of checkpoint (-1 for latest) - + Returns: Restored session data """ @@ -340,13 +312,13 @@ def restore_from_checkpoint( checkpoints = sorted(checkpoint_dir.glob("checkpoint_*.nrs")) if not checkpoints: raise FileNotFoundError("No checkpoints found") - + checkpoint_file = checkpoints[checkpoint_index] session_data = self.load_from_file(checkpoint_file, decompress=True) - + self.logger.info(f"Restored from checkpoint: {checkpoint_file}") return session_data - + except Exception as e: self.logger.error(f"Restore error: {e}") raise @@ -355,9 +327,9 @@ def restore_from_checkpoint( # Example usage if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - + serializer = SessionSerializer() - + # Create sample session data session_data = { "nrs_version": "1.0", @@ -368,24 +340,24 @@ def restore_from_checkpoint( "updated_at": datetime.now().isoformat(), "status": "active", "mode": "offensive", - "description": "Test session" + "description": "Test session", }, "conversation": {"messages": []}, "task_state": {}, "tools_state": {}, "mode_state": {}, "results": {}, - "metadata": {} + "metadata": {}, } - + # Test serialization data = serializer.serialize(session_data) print(f"Serialized size: {len(data)} bytes") - + # Test compression compressed = serializer.serialize(session_data, compress=True) print(f"Compressed size: {len(compressed)} bytes") - + # Test deserialization restored = serializer.deserialize(data) print(f"Deserialized session: {restored['session']['name']}") diff --git a/modules/tool_manager/__init__.py b/modules/tool_manager/__init__.py index 2d69dfe..35ec9ed 100755 --- a/modules/tool_manager/__init__.py +++ b/modules/tool_manager/__init__.py @@ -4,4 +4,4 @@ from .tool_manager import ToolManager -__all__ = ['ToolManager'] \ No newline at end of file +__all__ = ["ToolManager"] diff --git a/modules/tool_manager/tool_manager.py b/modules/tool_manager/tool_manager.py index 0c48995..d7c316b 100644 --- a/modules/tool_manager/tool_manager.py +++ b/modules/tool_manager/tool_manager.py @@ -13,7 +13,14 @@ from typing import Dict, List, Optional, Any import requests from rich.console import Console -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn +from rich.progress import ( + Progress, + SpinnerColumn, + TextColumn, + BarColumn, + TaskProgressColumn, +) + class ToolManager: def __init__(self, base_dir: Path): @@ -22,20 +29,20 @@ def __init__(self, base_dir: Path): self.console = Console() self.tools_dir = self.base_dir / "tools" self.tools_dir.mkdir(parents=True, exist_ok=True) - + # Load tool configurations self.tools_config = self._load_tools_config() - + def _load_tools_config(self) -> Dict[str, Any]: """Load tool configurations from JSON""" config_path = self.base_dir / "config" / "tools.json" try: - with open(config_path, 'r') as f: + with open(config_path, "r") as f: return json.load(f) except Exception as e: self.logger.error(f"Failed to load tools config: {e}") return {} - + def install_tool(self, tool_name: str) -> Dict[str, Any]: """Install a security tool""" with Progress( @@ -43,32 +50,32 @@ def install_tool(self, tool_name: str) -> Dict[str, Any]: TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), - console=self.console + console=self.console, ) as progress: # Check if tool exists in config if tool_name not in self.tools_config: return { "success": False, - "error": f"Tool {tool_name} not found in configuration" + "error": f"Tool {tool_name} not found in configuration", } - + tool_config = self.tools_config[tool_name] - + # Check if tool is already installed task = progress.add_task("Checking installation...", total=None) if self._is_tool_installed(tool_name): progress.update(task, completed=True) return { "success": True, - "message": f"Tool {tool_name} is already installed" + "message": f"Tool {tool_name} is already installed", } - + # Create tool directory task = progress.add_task("Creating tool directory...", total=None) tool_dir = self.tools_dir / tool_name tool_dir.mkdir(exist_ok=True) progress.update(task, completed=True) - + # Download tool task = progress.add_task("Downloading tool...", total=None) try: @@ -76,11 +83,8 @@ def install_tool(self, tool_name: str) -> Dict[str, Any]: progress.update(task, completed=True) except Exception as e: self.logger.error(f"Failed to download tool {tool_name}: {e}") - return { - "success": False, - "error": f"Failed to download tool: {str(e)}" - } - + return {"success": False, "error": f"Failed to download tool: {str(e)}"} + # Install tool task = progress.add_task("Installing tool...", total=None) try: @@ -88,111 +92,109 @@ def install_tool(self, tool_name: str) -> Dict[str, Any]: progress.update(task, completed=True) except Exception as e: self.logger.error(f"Failed to install tool {tool_name}: {e}") - return { - "success": False, - "error": f"Failed to install tool: {str(e)}" - } - + return {"success": False, "error": f"Failed to install tool: {str(e)}"} + # Verify installation task = progress.add_task("Verifying installation...", total=None) if not self._verify_installation(tool_name, tool_config): progress.update(task, completed=True) return { "success": False, - "error": f"Failed to verify tool installation" + "error": f"Failed to verify tool installation", } progress.update(task, completed=True) - - return { - "success": True, - "message": f"Successfully installed {tool_name}" - } - + + return {"success": True, "message": f"Successfully installed {tool_name}"} + def _is_tool_installed(self, tool_name: str) -> bool: """Check if a tool is installed""" tool_config = self.tools_config[tool_name] tool_dir = self.tools_dir / tool_name - + # Check if tool directory exists if not tool_dir.exists(): return False - + # Check for required files for file in tool_config.get("required_files", []): if not (tool_dir / file).exists(): return False - + # Check if tool is executable if tool_config.get("executable"): executable = tool_dir / tool_config["executable"] if not executable.exists() or not os.access(executable, os.X_OK): return False - + return True - + def _download_tool(self, tool_name: str, tool_config: Dict) -> Path: """Download tool files""" download_url = tool_config["download_url"] download_dir = self.tools_dir / tool_name / "downloads" download_dir.mkdir(exist_ok=True) - + # Download file response = requests.get(download_url, stream=True) response.raise_for_status() - + # Get filename from URL filename = download_url.split("/")[-1] download_path = download_dir / filename - + # Save file - with open(download_path, 'wb') as f: + with open(download_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) - + return download_path - - def _install_tool_files(self, tool_name: str, download_path: Path, tool_config: Dict) -> None: + + def _install_tool_files( + self, tool_name: str, download_path: Path, tool_config: Dict + ) -> None: """Install tool files""" tool_dir = self.tools_dir / tool_name - + # Extract archive if needed - if download_path.suffix in ['.zip', '.tar.gz', '.tgz']: - if download_path.suffix == '.zip': + if download_path.suffix in [".zip", ".tar.gz", ".tgz"]: + if download_path.suffix == ".zip": import zipfile - with zipfile.ZipFile(download_path, 'r') as zip_ref: + + with zipfile.ZipFile(download_path, "r") as zip_ref: zip_ref.extractall(tool_dir) else: import tarfile - with tarfile.open(download_path, 'r:gz') as tar_ref: + + with tarfile.open(download_path, "r:gz") as tar_ref: tar_ref.extractall(tool_dir) - + # Copy files to tool directory for file in tool_config.get("files", []): src = tool_dir / file["source"] dst = tool_dir / file["destination"] dst.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(src, dst) - + # Make executable if needed if tool_config.get("executable"): executable = tool_dir / tool_config["executable"] os.chmod(executable, 0o755) - + def _verify_installation(self, tool_name: str, tool_config: Dict) -> bool: """Verify tool installation""" tool_dir = self.tools_dir / tool_name - + # Check required files for file in tool_config.get("required_files", []): if not (tool_dir / file).exists(): return False - + # Check executable if tool_config.get("executable"): executable = tool_dir / tool_config["executable"] if not executable.exists() or not os.access(executable, os.X_OK): return False - + # Run verification command if specified if tool_config.get("verify_command"): try: @@ -200,16 +202,16 @@ def _verify_installation(self, tool_name: str, tool_config: Dict) -> bool: tool_config["verify_command"], cwd=tool_dir, capture_output=True, - text=True + text=True, ) if result.returncode != 0: return False except Exception as e: self.logger.error(f"Failed to verify tool {tool_name}: {e}") return False - + return True - + def update_tool(self, tool_name: str) -> Dict[str, Any]: """Update an installed tool""" with Progress( @@ -217,26 +219,23 @@ def update_tool(self, tool_name: str) -> Dict[str, Any]: TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), - console=self.console + console=self.console, ) as progress: # Check if tool exists in config if tool_name not in self.tools_config: return { "success": False, - "error": f"Tool {tool_name} not found in configuration" + "error": f"Tool {tool_name} not found in configuration", } - + # Check if tool is installed task = progress.add_task("Checking installation...", total=None) if not self._is_tool_installed(tool_name): progress.update(task, completed=True) - return { - "success": False, - "error": f"Tool {tool_name} is not installed" - } - + return {"success": False, "error": f"Tool {tool_name} is not installed"} + tool_config = self.tools_config[tool_name] - + # Backup current installation task = progress.add_task("Backing up current installation...", total=None) tool_dir = self.tools_dir / tool_name @@ -245,7 +244,7 @@ def update_tool(self, tool_name: str) -> Dict[str, Any]: shutil.rmtree(backup_dir) shutil.copytree(tool_dir, backup_dir) progress.update(task, completed=True) - + # Download new version task = progress.add_task("Downloading new version...", total=None) try: @@ -258,9 +257,9 @@ def update_tool(self, tool_name: str) -> Dict[str, Any]: shutil.copytree(backup_dir, tool_dir) return { "success": False, - "error": f"Failed to download new version: {str(e)}" + "error": f"Failed to download new version: {str(e)}", } - + # Install new version task = progress.add_task("Installing new version...", total=None) try: @@ -273,9 +272,9 @@ def update_tool(self, tool_name: str) -> Dict[str, Any]: shutil.copytree(backup_dir, tool_dir) return { "success": False, - "error": f"Failed to install new version: {str(e)}" + "error": f"Failed to install new version: {str(e)}", } - + # Verify new installation task = progress.add_task("Verifying new installation...", total=None) if not self._verify_installation(tool_name, tool_config): @@ -283,43 +282,34 @@ def update_tool(self, tool_name: str) -> Dict[str, Any]: # Restore backup shutil.rmtree(tool_dir) shutil.copytree(backup_dir, tool_dir) - return { - "success": False, - "error": f"Failed to verify new installation" - } + return {"success": False, "error": f"Failed to verify new installation"} progress.update(task, completed=True) - + # Remove backup shutil.rmtree(backup_dir) - - return { - "success": True, - "message": f"Successfully updated {tool_name}" - } - + + return {"success": True, "message": f"Successfully updated {tool_name}"} + def uninstall_tool(self, tool_name: str) -> Dict[str, Any]: """Uninstall a tool""" with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: # Check if tool exists in config if tool_name not in self.tools_config: return { "success": False, - "error": f"Tool {tool_name} not found in configuration" + "error": f"Tool {tool_name} not found in configuration", } - + # Check if tool is installed task = progress.add_task("Checking installation...", total=None) if not self._is_tool_installed(tool_name): progress.update(task, completed=True) - return { - "success": False, - "error": f"Tool {tool_name} is not installed" - } - + return {"success": False, "error": f"Tool {tool_name} is not installed"} + # Create backup task = progress.add_task("Creating backup...", total=None) tool_dir = self.tools_dir / tool_name @@ -328,7 +318,7 @@ def uninstall_tool(self, tool_name: str) -> Dict[str, Any]: shutil.rmtree(backup_dir) shutil.copytree(tool_dir, backup_dir) progress.update(task, completed=True) - + # Remove tool task = progress.add_task("Removing tool...", total=None) try: @@ -338,52 +328,47 @@ def uninstall_tool(self, tool_name: str) -> Dict[str, Any]: self.logger.error(f"Failed to remove tool {tool_name}: {e}") # Restore backup shutil.copytree(backup_dir, tool_dir) - return { - "success": False, - "error": f"Failed to remove tool: {str(e)}" - } - + return {"success": False, "error": f"Failed to remove tool: {str(e)}"} + # Remove backup shutil.rmtree(backup_dir) - - return { - "success": True, - "message": f"Successfully uninstalled {tool_name}" - } - + + return {"success": True, "message": f"Successfully uninstalled {tool_name}"} + def list_tools(self) -> Dict[str, Any]: """List installed tools""" tools = {} - + for tool_name, tool_config in self.tools_config.items(): tools[tool_name] = { "installed": self._is_tool_installed(tool_name), - "version": self._get_tool_version(tool_name) if self._is_tool_installed(tool_name) else None, + "version": ( + self._get_tool_version(tool_name) + if self._is_tool_installed(tool_name) + else None + ), "description": tool_config.get("description", ""), "website": tool_config.get("website", ""), - "repository": tool_config.get("repository", "") + "repository": tool_config.get("repository", ""), } - - return { - "success": True, - "tools": tools - } - + + return {"success": True, "tools": tools} + def _get_tool_version(self, tool_name: str) -> Optional[str]: """Get installed tool version""" tool_config = self.tools_config[tool_name] - + if tool_config.get("version_command"): try: result = subprocess.run( tool_config["version_command"], cwd=self.tools_dir / tool_name, capture_output=True, - text=True + text=True, ) if result.returncode == 0: return result.stdout.strip() except Exception as e: self.logger.error(f"Failed to get version for {tool_name}: {e}") - - return None \ No newline at end of file + + return None diff --git a/modules/tools/base.py b/modules/tools/base.py index e244722..de985cf 100644 --- a/modules/tools/base.py +++ b/modules/tools/base.py @@ -4,20 +4,24 @@ from pydantic import BaseModel from pathlib import Path + class ToolMode(Enum): OFFENSIVE = "offensive" DEFENSIVE = "defensive" + class ToolCategory(Enum): RECON = "recon" EXPLOITATION = "exploitation" MONITORING = "monitoring" ANALYSIS = "analysis" + class ToolInput(BaseModel): target: str args: Dict[str, Any] = {} + class ToolResult(BaseModel): tool_name: str timestamp: str @@ -26,8 +30,11 @@ class ToolResult(BaseModel): structured_data: Dict[str, Any] error: Optional[str] = None + class BaseTool(ABC): - def __init__(self, name: str, description: str, category: ToolCategory, mode: ToolMode): + def __init__( + self, name: str, description: str, category: ToolCategory, mode: ToolMode + ): self.name = name self.description = description self.category = category @@ -47,7 +54,7 @@ def build_command(self, input_data: ToolInput) -> List[str]: def parse_output(self, raw_output: str) -> Dict[str, Any]: """Parse raw terminal output into structured JSON.""" pass - + @abstractmethod def check_installed(self) -> bool: """Check if the tool is installed on the system.""" diff --git a/modules/tools/wrappers/amass.py b/modules/tools/wrappers/amass.py index 93eb62b..cf82793 100644 --- a/modules/tools/wrappers/amass.py +++ b/modules/tools/wrappers/amass.py @@ -3,13 +3,14 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class AmassTool(BaseTool): def __init__(self): super().__init__( name="amass", description="In-depth Attack Surface Mapping and Asset Discovery", category=ToolCategory.RECON, - mode=ToolMode.OFFENSIVE # Can be considered offensive due to active scanning options, usually recon + mode=ToolMode.OFFENSIVE, # Can be considered offensive due to active scanning options, usually recon ) def validate_input(self, input_data: ToolInput) -> bool: @@ -20,12 +21,12 @@ def validate_input(self, input_data: ToolInput) -> bool: def build_command(self, input_data: ToolInput) -> List[str]: # Default to 'enum' mode for enumeration cmd = ["amass", "enum", "-d", input_data.target, "-json", "amass_out.json"] - + if input_data.args.get("passive"): cmd.append("-passive") if input_data.args.get("active"): cmd.append("-active") - + return cmd def parse_output(self, raw_output: str) -> Dict[str, Any]: @@ -35,14 +36,14 @@ def parse_output(self, raw_output: str) -> Dict[str, Any]: results = [] for line in raw_output.splitlines(): try: - if line.strip().startswith('{'): + if line.strip().startswith("{"): results.append(json.loads(line)) except json.JSONDecodeError: continue - + return { "subdomains": [r.get("name") for r in results if "name" in r], - "raw_entries": results + "raw_entries": results, } def check_installed(self) -> bool: diff --git a/modules/tools/wrappers/ike_scan.py b/modules/tools/wrappers/ike_scan.py index 3632a25..c04c434 100644 --- a/modules/tools/wrappers/ike_scan.py +++ b/modules/tools/wrappers/ike_scan.py @@ -3,20 +3,21 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class IkeScanTool(BaseTool): def __init__(self): super().__init__( name="ike-scan", description="Discover and fingerprint IKE hosts (IPsec VPN Servers)", category=ToolCategory.RECON, - mode=ToolMode.OFFENSIVE + mode=ToolMode.OFFENSIVE, ) def validate_input(self, input_data: ToolInput) -> bool: return bool(input_data.target) def build_command(self, input_data: ToolInput) -> List[str]: - cmd = ["ike-scan", "-M"] # -M for multiline output (better for parsing) + cmd = ["ike-scan", "-M"] # -M for multiline output (better for parsing) cmd.append(input_data.target) return cmd @@ -25,15 +26,12 @@ def parse_output(self, raw_output: str) -> Dict[str, Any]: # Basic parsing looking for "Handshake returned" # 192.168.1.1 Notify message 14 (NO_PROPOSAL_CHOSEN) # OR 192.168.1.1 Main Mode Handshake returned - + for line in raw_output.splitlines(): if "Handshake returned" in line or "Notify message" in line: parts = line.split() if parts: - hosts.append({ - "ip": parts[0], - "details": line - }) + hosts.append({"ip": parts[0], "details": line}) return {"ike_hosts": hosts} def check_installed(self) -> bool: diff --git a/modules/tools/wrappers/masscan.py b/modules/tools/wrappers/masscan.py index c740474..aca0954 100644 --- a/modules/tools/wrappers/masscan.py +++ b/modules/tools/wrappers/masscan.py @@ -3,13 +3,14 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class MasscanTool(BaseTool): def __init__(self): super().__init__( name="masscan", description="Mass IP port scanner", category=ToolCategory.RECON, - mode=ToolMode.OFFENSIVE # Active scanning is offensive + mode=ToolMode.OFFENSIVE, # Active scanning is offensive ) def validate_input(self, input_data: ToolInput) -> bool: @@ -18,32 +19,34 @@ def validate_input(self, input_data: ToolInput) -> bool: # Masscan needs ports usually if not input_data.args.get("ports"): # Default to top ports or require it? Let's default to a safe small range for demo - pass + pass return True def build_command(self, input_data: ToolInput) -> List[str]: target = input_data.target ports = input_data.args.get("ports", "80,443") rate = input_data.args.get("rate", "100") - + cmd = ["masscan", target, "-p", ports, "--rate", str(rate)] - + return cmd def parse_output(self, raw_output: str) -> Dict[str, Any]: # Masscan output example: Discovered open port 80/tcp on 192.168.1.1 findings = [] pattern = re.compile(r"Discovered open port (\d+)/(\w+) on ([\d\.]+)") - + for line in raw_output.splitlines(): match = pattern.search(line) if match: - findings.append({ - "port": int(match.group(1)), - "proto": match.group(2), - "ip": match.group(3) - }) - + findings.append( + { + "port": int(match.group(1)), + "proto": match.group(2), + "ip": match.group(3), + } + ) + return {"open_ports": findings} def check_installed(self) -> bool: diff --git a/modules/tools/wrappers/metasploit.py b/modules/tools/wrappers/metasploit.py index 3c61d16..dfeaeb7 100644 --- a/modules/tools/wrappers/metasploit.py +++ b/modules/tools/wrappers/metasploit.py @@ -2,39 +2,44 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class MetasploitTool(BaseTool): def __init__(self): super().__init__( name="metasploit", description="Penetration testing framework", category=ToolCategory.EXPLOITATION, - mode=ToolMode.DEFENSIVE + mode=ToolMode.DEFENSIVE, ) def validate_input(self, input_data: ToolInput) -> bool: # Metasploit needs at least a command or resource script - return bool(input_data.args.get("resource_script") or input_data.args.get("command")) + return bool( + input_data.args.get("resource_script") or input_data.args.get("command") + ) def build_command(self, input_data: ToolInput) -> List[str]: - cmd = ["msfconsole", "-q"] # Quiet mode - + cmd = ["msfconsole", "-q"] # Quiet mode + resource = input_data.args.get("resource_script") if resource: cmd.extend(["-r", resource]) - + command = input_data.args.get("command") if command: cmd.extend(["-x", command]) - + return cmd def parse_output(self, raw_output: str) -> Dict[str, Any]: # Metasploit output is very unstructured unless using specific plugins. # Check for success indicators. - success = "Meterpreter session" in raw_output or "Command shell session" in raw_output + success = ( + "Meterpreter session" in raw_output or "Command shell session" in raw_output + ) return { "success_indicator": success, - "raw_output": raw_output # Return full output for AI analysis + "raw_output": raw_output, # Return full output for AI analysis } def check_installed(self) -> bool: diff --git a/modules/tools/wrappers/mitmproxy.py b/modules/tools/wrappers/mitmproxy.py index 0aa33bb..20bcc0a 100644 --- a/modules/tools/wrappers/mitmproxy.py +++ b/modules/tools/wrappers/mitmproxy.py @@ -2,13 +2,14 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class MitmproxyTool(BaseTool): def __init__(self): super().__init__( name="mitmproxy", description="Interceptor for HTTP/HTTPS traffic (via mitmdump)", - category=ToolCategory.EXPLOITATION, #/Analysis - mode=ToolMode.DEFENSIVE + category=ToolCategory.EXPLOITATION, # /Analysis + mode=ToolMode.DEFENSIVE, ) def validate_input(self, input_data: ToolInput) -> bool: @@ -18,18 +19,18 @@ def validate_input(self, input_data: ToolInput) -> bool: def build_command(self, input_data: ToolInput) -> List[str]: # Use mitmdump for non-interactive cmd = ["mitmdump"] - + port = input_data.args.get("port", 8080) cmd.extend(["-p", str(port)]) - + script = input_data.args.get("script") if script: cmd.extend(["-s", script]) - + outfile = input_data.args.get("outfile") if outfile: cmd.extend(["-w", outfile]) - + return cmd def parse_output(self, raw_output: str) -> Dict[str, Any]: diff --git a/modules/tools/wrappers/netcat.py b/modules/tools/wrappers/netcat.py index bf9300b..d7b8357 100644 --- a/modules/tools/wrappers/netcat.py +++ b/modules/tools/wrappers/netcat.py @@ -2,13 +2,14 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class NetcatTool(BaseTool): def __init__(self): super().__init__( name="netcat", description="Networking utility for reading/writing to network connections", - category=ToolCategory.EXPLOITATION, # Can be used for shells, but also recon - mode=ToolMode.DEFENSIVE + category=ToolCategory.EXPLOITATION, # Can be used for shells, but also recon + mode=ToolMode.DEFENSIVE, ) def validate_input(self, input_data: ToolInput) -> bool: @@ -16,14 +17,14 @@ def validate_input(self, input_data: ToolInput) -> bool: def build_command(self, input_data: ToolInput) -> List[str]: cmd = ["nc"] - + if input_data.args.get("verbose"): cmd.append("-v") if input_data.args.get("udp"): cmd.append("-u") - if input_data.args.get("zero_io"): # Scanning mode + if input_data.args.get("zero_io"): # Scanning mode cmd.append("-z") - + cmd.append(input_data.target) cmd.append(str(input_data.args.get("port"))) return cmd @@ -32,10 +33,7 @@ def parse_output(self, raw_output: str) -> Dict[str, Any]: # Parse connection success # "Connection to 127.0.0.1 80 port [tcp/http] succeeded!" succeeded = "succeeded" in raw_output - return { - "connected": succeeded, - "output": raw_output - } + return {"connected": succeeded, "output": raw_output} def check_installed(self) -> bool: return shutil.which("nc") is not None or shutil.which("netcat") is not None diff --git a/modules/tools/wrappers/nmap.py b/modules/tools/wrappers/nmap.py index afdcaea..75f2d66 100644 --- a/modules/tools/wrappers/nmap.py +++ b/modules/tools/wrappers/nmap.py @@ -3,13 +3,14 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class NmapTool(BaseTool): def __init__(self): super().__init__( name="nmap", description="Network exploration tool and security / port scanner", category=ToolCategory.RECON, - mode=ToolMode.OFFENSIVE # Active scanning + mode=ToolMode.OFFENSIVE, # Active scanning ) def validate_input(self, input_data: ToolInput) -> bool: @@ -18,14 +19,14 @@ def validate_input(self, input_data: ToolInput) -> bool: def build_command(self, input_data: ToolInput) -> List[str]: # Always output valid XML for easy parsing cmd = ["nmap", "-oX", "-"] - + flags = input_data.args.get("flags", []) if flags: cmd.extend(flags) else: # Default safe scan - cmd.extend(["-sV", "-F"]) # Version detection, Fast scan - + cmd.extend(["-sV", "-F"]) # Version detection, Fast scan + cmd.append(input_data.target) return cmd @@ -43,12 +44,12 @@ def parse_output(self, raw_output: str) -> Dict[str, Any]: port_id = port.get("portid") state = port.find("state").get("state") service = port.find("service") - service_name = service.get("name") if service is not None else "unknown" - ports.append({ - "port": port_id, - "state": state, - "service": service_name - }) + service_name = ( + service.get("name") if service is not None else "unknown" + ) + ports.append( + {"port": port_id, "state": state, "service": service_name} + ) hosts.append({"ip": address, "ports": ports}) return {"hosts": hosts} except ET.ParseError: diff --git a/modules/tools/wrappers/sqlmap.py b/modules/tools/wrappers/sqlmap.py index 3358664..8b123b5 100644 --- a/modules/tools/wrappers/sqlmap.py +++ b/modules/tools/wrappers/sqlmap.py @@ -3,31 +3,37 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class SqlmapTool(BaseTool): def __init__(self): super().__init__( name="sqlmap", description="Automatic SQL injection and database takeover tool", category=ToolCategory.EXPLOITATION, - mode=ToolMode.DEFENSIVE # Defaults to defensive (assessment) + mode=ToolMode.DEFENSIVE, # Defaults to defensive (assessment) ) def validate_input(self, input_data: ToolInput) -> bool: return bool(input_data.target) def build_command(self, input_data: ToolInput) -> List[str]: - cmd = ["sqlmap", "-u", input_data.target, "--batch"] # --batch for non-interactive - + cmd = [ + "sqlmap", + "-u", + input_data.target, + "--batch", + ] # --batch for non-interactive + # Risk / Level cmd.extend(["--risk", str(input_data.args.get("risk", 1))]) cmd.extend(["--level", str(input_data.args.get("level", 1))]) - + # Other common flags if input_data.args.get("crawl"): cmd.extend(["--crawl", str(input_data.args.get("crawl"))]) if input_data.args.get("forms"): cmd.append("--forms") - + return cmd def parse_output(self, raw_output: str) -> Dict[str, Any]: @@ -35,19 +41,16 @@ def parse_output(self, raw_output: str) -> Dict[str, Any]: # Basic parsing of sqlmap output # "Parameter: id (GET)" # "Type: boolean-based blind" - + current_param = None - + for line in raw_output.splitlines(): if line.startswith("Parameter:"): current_param = line.split(":")[1].strip() elif line.startswith(" Type:"): vuln_type = line.split(":")[1].strip() - findings.append({ - "parameter": current_param, - "type": vuln_type - }) - + findings.append({"parameter": current_param, "type": vuln_type}) + return {"vulnerabilities": findings} def check_installed(self) -> bool: diff --git a/modules/tools/wrappers/unicornscan.py b/modules/tools/wrappers/unicornscan.py index a920d8a..4146048 100644 --- a/modules/tools/wrappers/unicornscan.py +++ b/modules/tools/wrappers/unicornscan.py @@ -3,13 +3,14 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class UnicornscanTool(BaseTool): def __init__(self): super().__init__( name="unicornscan", description="Asynchronous TCP/UDP port scanner", category=ToolCategory.RECON, - mode=ToolMode.OFFENSIVE + mode=ToolMode.OFFENSIVE, ) def validate_input(self, input_data: ToolInput) -> bool: @@ -27,15 +28,13 @@ def parse_output(self, raw_output: str) -> Dict[str, Any]: findings = [] # Regex for standard unicornscan output pattern = re.compile(r"TCP open\s+([\d\.]+):(\d+)\s+ttl") - + for line in raw_output.splitlines(): match = pattern.search(line) if match: - findings.append({ - "ip": match.group(1), - "port": int(match.group(2)), - "proto": "tcp" - }) + findings.append( + {"ip": match.group(1), "port": int(match.group(2)), "proto": "tcp"} + ) return {"open_ports": findings} def check_installed(self) -> bool: diff --git a/modules/tools/wrappers/wireshark.py b/modules/tools/wrappers/wireshark.py index 0eb0ba7..208ebd8 100644 --- a/modules/tools/wrappers/wireshark.py +++ b/modules/tools/wrappers/wireshark.py @@ -2,46 +2,50 @@ from typing import Dict, Any, List from modules.tools.base import BaseTool, ToolCategory, ToolMode, ToolInput + class WiresharkTool(BaseTool): def __init__(self): super().__init__( name="wireshark", description="Network protocol analyzer (using tshark)", category=ToolCategory.MONITORING, - mode=ToolMode.DEFENSIVE + mode=ToolMode.DEFENSIVE, ) def validate_input(self, input_data: ToolInput) -> bool: # Needs interface or file - return bool(input_data.args.get("interface")) or bool(input_data.args.get("read_file")) + return bool(input_data.args.get("interface")) or bool( + input_data.args.get("read_file") + ) def build_command(self, input_data: ToolInput) -> List[str]: cmd = ["tshark"] - + interface = input_data.args.get("interface") if interface: cmd.extend(["-i", interface]) - + read_file = input_data.args.get("read_file") if read_file: cmd.extend(["-r", read_file]) - + write_file = input_data.args.get("write_file") if write_file: cmd.extend(["-w", write_file]) - + # JSON output for easier parsing cmd.extend(["-T", "json"]) - + count = input_data.args.get("packet_count") if count: cmd.extend(["-c", str(count)]) - + return cmd def parse_output(self, raw_output: str) -> Dict[str, Any]: # tshark -T json outputs a JSON array of packets import json + try: packets = json.loads(raw_output) return {"packets": packets, "count": len(packets)} diff --git a/modules/web/bridge_server.py b/modules/web/bridge_server.py index 9b9122f..e30fbe9 100644 --- a/modules/web/bridge_server.py +++ b/modules/web/bridge_server.py @@ -11,7 +11,11 @@ # Import existing NeuroRift modules from modules.ai.ai_integration import OllamaClient, AIAnalyzer -from modules.orchestration.execution_manager import ExecutionManager, ScanRequest, SessionContext +from modules.orchestration.execution_manager import ( + ExecutionManager, + ScanRequest, + SessionContext, +) from modules.darkweb.robin import runner as robin_runner from modules.tools.base import ToolMode @@ -29,12 +33,14 @@ class Command(BaseModel): """Generic command structure""" + type: str data: Dict[str, Any] = {} class Response(BaseModel): """Generic response structure""" + success: bool data: Optional[Dict[str, Any]] = None error: Optional[str] = None @@ -44,7 +50,7 @@ class Response(BaseModel): async def execute_command(command: Dict[str, Any]) -> Response: """ Execute a command from Rust core - + Command types: - ai_generate: Generate AI response - tool_execute: Execute a security tool @@ -53,7 +59,7 @@ async def execute_command(command: Dict[str, Any]) -> Response: """ try: cmd_type = command.get("type") - + if cmd_type == "ai_generate": result = await handle_ai_generate(command) elif cmd_type == "tool_execute": @@ -63,10 +69,12 @@ async def execute_command(command: Dict[str, Any]) -> Response: elif cmd_type == "browser_action": result = await handle_browser_action(command) else: - raise HTTPException(status_code=400, detail=f"Unknown command type: {cmd_type}") - + raise HTTPException( + status_code=400, detail=f"Unknown command type: {cmd_type}" + ) + return Response(success=True, data=result) - + except Exception as e: logger.error(f"Command execution failed: {e}", exc_info=True) return Response(success=False, error=str(e)) @@ -76,9 +84,9 @@ async def handle_ai_generate(command: Dict[str, Any]) -> Dict[str, Any]: """Generate AI response""" prompt = command.get("prompt", "") model = command.get("model") - + response = await ollama.generate(prompt, model=model) - + return { "response": response, "model": model or ollama.model, @@ -90,24 +98,18 @@ async def handle_tool_execute(command: Dict[str, Any]) -> Dict[str, Any]: tool_name = command.get("tool", "") target = command.get("target", "") args = command.get("args", {}) - + # Create scan request - scan_request = ScanRequest( - tool_name=tool_name, - target=target, - args=args - ) - + scan_request = ScanRequest(tool_name=tool_name, target=target, args=args) + # Create minimal session context session_context = SessionContext( - session_id="temp", - mode=ToolMode.OFFENSIVE, # TODO: Get from Rust - history=[] + session_id="temp", mode=ToolMode.OFFENSIVE, history=[] # TODO: Get from Rust ) - + # Execute tool result = await execution_manager.execute_tool(scan_request, session_context) - + return { "tool_name": result.tool_name, "command": result.command, @@ -122,27 +124,23 @@ async def handle_tool_execute(command: Dict[str, Any]) -> Dict[str, Any]: async def handle_robin_search(command: Dict[str, Any]) -> Dict[str, Any]: """Execute Robin dark web search""" query = command.get("query", "") - + # TODO: Integrate with Robin module # For now, return placeholder - return { - "query": query, - "results": [], - "message": "Robin integration pending" - } + return {"query": query, "results": [], "message": "Robin integration pending"} async def handle_browser_action(command: Dict[str, Any]) -> Dict[str, Any]: """Execute browser automation action""" action = command.get("action", "") params = command.get("params", {}) - + # TODO: Integrate with browser automation # For now, return placeholder return { "action": action, "success": True, - "message": "Browser automation integration pending" + "message": "Browser automation integration pending", } @@ -161,4 +159,5 @@ async def startup_event(): if __name__ == "__main__": import uvicorn + uvicorn.run(app, host="127.0.0.1", port=8766, log_level="info") diff --git a/modules/web/dashboard.py b/modules/web/dashboard.py index b6cc177..ba98b5a 100644 --- a/modules/web/dashboard.py +++ b/modules/web/dashboard.py @@ -14,13 +14,16 @@ page_title="NeuroRift v2.0", page_icon="🛡️", layout="wide", - initial_sidebar_state="expanded" + initial_sidebar_state="expanded", ) + def local_css(file_name): if Path(file_name).exists(): with open(file_name) as f: st.markdown(f"", unsafe_allow_html=True) + + local_css(str(ASSETS_DIR / "neurorift.css")) # --- Imports --- @@ -44,13 +47,14 @@ def local_css(file_name): nr = st.session_state.neurorift session_manager = nr.session_manager + # --- Main Controller --- def main(): # Sidebar Global Controls with st.sidebar: st.title("🛡️ NeuroRift") st.caption("Intelligence Amplified") - + if st.session_state.active_session_id: st.divider() sess = session_manager.get_current_session() @@ -59,7 +63,7 @@ def main(): if st.button("Close Session"): st.session_state.active_session_id = None st.rerun() - + st.divider() st.caption(f"v{nr.version}") @@ -75,8 +79,9 @@ def main(): nr.planner, nr.operator, nr.analyst, - nr.scribe + nr.scribe, ) + if __name__ == "__main__": main() diff --git a/modules/web/tunnel_manager.py b/modules/web/tunnel_manager.py index e4eab11..7c44ae3 100644 --- a/modules/web/tunnel_manager.py +++ b/modules/web/tunnel_manager.py @@ -15,22 +15,22 @@ class TunnelProvider(ABC): """Base class for tunnel providers""" - + @abstractmethod async def start(self, port: int, **kwargs) -> str: """Start tunnel and return public URL""" pass - + @abstractmethod async def stop(self): """Stop tunnel""" pass - + @abstractmethod def is_available(self) -> bool: """Check if provider is installed/available""" pass - + @property @abstractmethod def name(self) -> str: @@ -40,52 +40,57 @@ def name(self) -> str: class NgrokProvider(TunnelProvider): """ngrok tunnel provider""" - + def __init__(self): self.process: Optional[subprocess.Popen] = None self.public_url: Optional[str] = None - + @property def name(self) -> str: return "ngrok" - + def is_available(self) -> bool: try: - result = subprocess.run(['ngrok', 'version'], capture_output=True, text=True) + result = subprocess.run( + ["ngrok", "version"], capture_output=True, text=True + ) return result.returncode == 0 except FileNotFoundError: return False - + async def start(self, port: int, **kwargs) -> str: """Start ngrok tunnel""" if not self.is_available(): - raise RuntimeError("ngrok is not installed. Install from https://ngrok.com/download") - + raise RuntimeError( + "ngrok is not installed. Install from https://ngrok.com/download" + ) + # Start ngrok self.process = subprocess.Popen( - ['ngrok', 'http', str(port), '--log', 'stdout'], + ["ngrok", "http", str(port), "--log", "stdout"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, - text=True + text=True, ) - + # Wait for URL to appear in output await asyncio.sleep(2) - + # Get tunnel URL from ngrok API try: import aiohttp + async with aiohttp.ClientSession() as session: - async with session.get('http://localhost:4040/api/tunnels') as resp: + async with session.get("http://localhost:4040/api/tunnels") as resp: data = await resp.json() - if data.get('tunnels'): - self.public_url = data['tunnels'][0]['public_url'] + if data.get("tunnels"): + self.public_url = data["tunnels"][0]["public_url"] return self.public_url except Exception as e: logger.error(f"Failed to get ngrok URL: {e}") - + raise RuntimeError("Failed to start ngrok tunnel") - + async def stop(self): """Stop ngrok tunnel""" if self.process: @@ -97,38 +102,42 @@ async def stop(self): class CloudflareProvider(TunnelProvider): """Cloudflare Tunnel provider""" - + def __init__(self): self.process: Optional[subprocess.Popen] = None self.public_url: Optional[str] = None - + @property def name(self) -> str: return "cloudflare" - + def is_available(self) -> bool: try: - result = subprocess.run(['cloudflared', '--version'], capture_output=True, text=True) + result = subprocess.run( + ["cloudflared", "--version"], capture_output=True, text=True + ) return result.returncode == 0 except FileNotFoundError: return False - + async def start(self, port: int, **kwargs) -> str: """Start Cloudflare Tunnel""" if not self.is_available(): - raise RuntimeError("cloudflared is not installed. Install from https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation/") - + raise RuntimeError( + "cloudflared is not installed. Install from https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation/" + ) + # Start cloudflared self.process = subprocess.Popen( - ['cloudflared', 'tunnel', '--url', f'http://localhost:{port}'], + ["cloudflared", "tunnel", "--url", f"http://localhost:{port}"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - text=True + text=True, ) - + # Parse URL from output - url_pattern = re.compile(r'https://[a-z0-9-]+\.trycloudflare\.com') - + url_pattern = re.compile(r"https://[a-z0-9-]+\.trycloudflare\.com") + for _ in range(30): # Try for 30 seconds if self.process.stdout: line = self.process.stdout.readline() @@ -138,9 +147,9 @@ async def start(self, port: int, **kwargs) -> str: self.public_url = match.group(0) return self.public_url await asyncio.sleep(1) - + raise RuntimeError("Failed to start Cloudflare Tunnel") - + async def stop(self): """Stop Cloudflare Tunnel""" if self.process: @@ -152,42 +161,41 @@ async def stop(self): class LocaltunnelProvider(TunnelProvider): """localtunnel provider""" - + def __init__(self): self.process: Optional[subprocess.Popen] = None self.public_url: Optional[str] = None - + @property def name(self) -> str: return "localtunnel" - + def is_available(self) -> bool: try: - result = subprocess.run(['lt', '--version'], capture_output=True, text=True) + result = subprocess.run(["lt", "--version"], capture_output=True, text=True) return result.returncode == 0 except FileNotFoundError: return False - + async def start(self, port: int, **kwargs) -> str: """Start localtunnel""" if not self.is_available(): - raise RuntimeError("localtunnel is not installed. Run: npm install -g localtunnel") - - subdomain = kwargs.get('subdomain') - cmd = ['lt', '--port', str(port)] + raise RuntimeError( + "localtunnel is not installed. Run: npm install -g localtunnel" + ) + + subdomain = kwargs.get("subdomain") + cmd = ["lt", "--port", str(port)] if subdomain: - cmd.extend(['--subdomain', subdomain]) - + cmd.extend(["--subdomain", subdomain]) + self.process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True ) - + # Parse URL from output - url_pattern = re.compile(r'https://[a-z0-9-]+\.loca\.lt') - + url_pattern = re.compile(r"https://[a-z0-9-]+\.loca\.lt") + for _ in range(10): if self.process.stdout: line = self.process.stdout.readline() @@ -197,9 +205,9 @@ async def start(self, port: int, **kwargs) -> str: self.public_url = match.group(0) return self.public_url await asyncio.sleep(1) - + raise RuntimeError("Failed to start localtunnel") - + async def stop(self): """Stop localtunnel""" if self.process: @@ -211,53 +219,57 @@ async def stop(self): class TunnelManager: """Manages tunnel providers""" - + def __init__(self): self.providers: Dict[str, TunnelProvider] = { - 'ngrok': NgrokProvider(), - 'cloudflare': CloudflareProvider(), - 'localtunnel': LocaltunnelProvider(), + "ngrok": NgrokProvider(), + "cloudflare": CloudflareProvider(), + "localtunnel": LocaltunnelProvider(), } self.active_provider: Optional[TunnelProvider] = None - + def list_available_providers(self) -> List[str]: """List installed/available providers""" - return [name for name, provider in self.providers.items() if provider.is_available()] - + return [ + name for name, provider in self.providers.items() if provider.is_available() + ] + def get_provider(self, name: str) -> Optional[TunnelProvider]: """Get provider by name""" return self.providers.get(name) - + async def start_tunnel(self, provider_name: str, port: int, **kwargs) -> str: """Start tunnel with specified provider""" - if provider_name == 'auto': + if provider_name == "auto": # Auto-select first available provider available = self.list_available_providers() if not available: - raise RuntimeError("No tunnel providers available. Install ngrok, cloudflared, or localtunnel.") + raise RuntimeError( + "No tunnel providers available. Install ngrok, cloudflared, or localtunnel." + ) provider_name = available[0] logger.info(f"Auto-selected provider: {provider_name}") - + provider = self.get_provider(provider_name) if not provider: raise ValueError(f"Unknown provider: {provider_name}") - + if not provider.is_available(): raise RuntimeError(f"Provider {provider_name} is not installed") - + logger.info(f"Starting {provider_name} tunnel on port {port}...") public_url = await provider.start(port, **kwargs) self.active_provider = provider - + logger.info(f"✅ Tunnel active: {public_url}") return public_url - + async def stop_tunnel(self): """Stop active tunnel""" if self.active_provider: await self.active_provider.stop() self.active_provider = None - + def get_public_url(self) -> Optional[str]: """Get current public URL""" if self.active_provider: @@ -268,31 +280,39 @@ def get_public_url(self) -> Optional[str]: # CLI interface async def main(): import argparse - - parser = argparse.ArgumentParser(description='NeuroRift Tunnel Manager') - parser.add_argument('action', choices=['start', 'stop', 'list'], help='Action to perform') - parser.add_argument('--provider', default='auto', help='Tunnel provider (ngrok, cloudflare, localtunnel, auto)') - parser.add_argument('--port', type=int, default=3000, help='Local port to tunnel') - parser.add_argument('--subdomain', help='Custom subdomain (if supported)') - + + parser = argparse.ArgumentParser(description="NeuroRift Tunnel Manager") + parser.add_argument( + "action", choices=["start", "stop", "list"], help="Action to perform" + ) + parser.add_argument( + "--provider", + default="auto", + help="Tunnel provider (ngrok, cloudflare, localtunnel, auto)", + ) + parser.add_argument("--port", type=int, default=3000, help="Local port to tunnel") + parser.add_argument("--subdomain", help="Custom subdomain (if supported)") + args = parser.parse_args() - + manager = TunnelManager() - - if args.action == 'list': + + if args.action == "list": available = manager.list_available_providers() print("Available tunnel providers:") for provider in available: print(f" • {provider}") if not available: print(" (none installed)") - - elif args.action == 'start': + + elif args.action == "start": try: - url = await manager.start_tunnel(args.provider, args.port, subdomain=args.subdomain) + url = await manager.start_tunnel( + args.provider, args.port, subdomain=args.subdomain + ) print(f"\n🌐 Public URL: {url}\n") print("Press Ctrl+C to stop...") - + # Keep running try: while True: @@ -303,13 +323,13 @@ async def main(): except Exception as e: print(f"❌ Error: {e}") return 1 - - elif args.action == 'stop': + + elif args.action == "stop": await manager.stop_tunnel() print("Tunnel stopped") - + return 0 -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/modules/web/ui/orchestration_view.py b/modules/web/ui/orchestration_view.py index 49ee9d3..f892626 100644 --- a/modules/web/ui/orchestration_view.py +++ b/modules/web/ui/orchestration_view.py @@ -3,16 +3,16 @@ import pandas as pd from modules.orchestration.execution_manager import ExecutionManager from modules.ai.agents import NRPlanner, NROperator, NRAnalyst, NRScribe -from modules.orchestration.data_models import SessionContext, ToolExecutionResult, ScanRequest +from modules.orchestration.data_models import ( + SessionContext, + ToolExecutionResult, + ScanRequest, +) from modules.tools.base import ToolMode + def render_orchestration_view( - session_manager, - execution_manager, - planner, - operator, - analyst, - scribe + session_manager, execution_manager, planner, operator, analyst, scribe ): """ Renders the Orchestration View. @@ -24,69 +24,70 @@ def render_orchestration_view( return # Header - target = session['task_state'].get('target') or 'Unknown' - mode = session['session']['mode'] + target = session["task_state"].get("target") or "Unknown" + mode = session["session"]["mode"] st.markdown(f"### 🎯 Target: `{target}` | Mode: `{mode.upper()}`") - + # Workflow Steps (Tabs) - tab_manual, tab_plan, tab_exec, tab_results, tab_report = st.tabs(["0. Manual", "1. Plan", "2. Execute", "3. Results", "4. Report"]) + tab_manual, tab_plan, tab_exec, tab_results, tab_report = st.tabs( + ["0. Manual", "1. Plan", "2. Execute", "3. Results", "4. Report"] + ) # --- TAB 0: MANUAL OPERATIONS --- with tab_manual: st.markdown("#### Manual Control") st.info("Execute specific tools directly without AI planning.") - + tools = execution_manager.list_tools() - tool_names = [t['name'] for t in tools] - + tool_names = [t["name"] for t in tools] + c1, c2 = st.columns([1, 2]) with c1: selected_tool = st.selectbox("Select Tool", tool_names) with c2: import json + # Default helper based on tool default_args = "{}" if selected_tool == "nmap": default_args = '{"flags": ["-F"]}' elif selected_tool == "masscan": default_args = '{"ports": "80,443", "rate": "100"}' - - args_input = st.text_area("Arguments (JSON)", value=default_args, height=100) - + + args_input = st.text_area( + "Arguments (JSON)", value=default_args, height=100 + ) + if st.button("🚀 Run Tool Manually", type="primary"): try: args = json.loads(args_input) - + # Setup proper context - tool_mode = ToolMode.OFFENSIVE if mode == "offensive" else ToolMode.DEFENSIVE - context = SessionContext( - session_id=session['session']['id'], - mode=tool_mode, - target=target + tool_mode = ( + ToolMode.OFFENSIVE if mode == "offensive" else ToolMode.DEFENSIVE ) - - req = ScanRequest( - tool_name=selected_tool, - target=target, - args=args + context = SessionContext( + session_id=session["session"]["id"], mode=tool_mode, target=target ) - + + req = ScanRequest(tool_name=selected_tool, target=target, args=args) + with st.spinner(f"Running {selected_tool}..."): result = asyncio.run(execution_manager.execute_tool(req, context)) - + if result.status == "success": st.success("Execution Successful") with st.expander("Output", expanded=True): st.code(result.raw_output) - - # Optional: Add to findings automatically? + + # Optional: Add to findings automatically? # For manual mode, maybe just raw output is enough, or optional analysis. if st.button("Analyze this output with AI?"): - # Hook into analyst - pass + # Hook into analyst + pass else: st.error(f"Execution Failed: {result.error}") - + except json.JSONDecodeError: st.error("Invalid JSON format for arguments.") except Exception as e: @@ -96,42 +97,50 @@ def render_orchestration_view( with tab_plan: st.markdown("#### Mission Planning") task_desc = st.text_area( - "Mission Objective", - value=session.get('task_state', {}).get('plan', {}).get('goal', f"Perform a {mode} security assessment on {target}") + "Mission Objective", + value=session.get("task_state", {}) + .get("plan", {}) + .get("goal", f"Perform a {mode} security assessment on {target}"), ) - + if st.button("Generate Plan"): with st.spinner("AI Planner is strategizing..."): available_tools = execution_manager.list_tools() requests = asyncio.run(planner.create_plan(task_desc, available_tools)) - + # Store plan in session state st.session_state.current_plan = requests # Update session data - session_manager.update_session_state({ - "task_state": { - "plan": { - "goal": task_desc, - "steps": [r.dict() for r in requests] + session_manager.update_session_state( + { + "task_state": { + "plan": { + "goal": task_desc, + "steps": [r.dict() for r in requests], + } } } - }) - + ) + # Display Plan if "current_plan" in st.session_state and st.session_state.current_plan: st.success(f"Plan Generated: {len(st.session_state.current_plan)} steps") - + steps_data = [] for i, req in enumerate(st.session_state.current_plan): - steps_data.append({ - "Step": i+1, - "Tool": req.tool_name, - "Args": str(req.args), - "Reasoning": getattr(req, 'reasoning', 'N/A') # Assuming we added reasoning to Request or handle separately - }) - + steps_data.append( + { + "Step": i + 1, + "Tool": req.tool_name, + "Args": str(req.args), + "Reasoning": getattr( + req, "reasoning", "N/A" + ), # Assuming we added reasoning to Request or handle separately + } + ) + st.table(pd.DataFrame(steps_data)) - + if st.button("✅ Approve Plan"): st.session_state.plan_approved = True st.session_state.execution_ready = True @@ -142,41 +151,45 @@ def render_orchestration_view( # --- TAB 2: EXECUTE --- with tab_exec: st.markdown("#### Mission Execution") - + if not st.session_state.get("plan_approved"): st.warning("Please generate and approve a plan first.") else: if st.button("▶️ Start Execution", type="primary"): # Setup context - tool_mode = ToolMode.OFFENSIVE if mode == "offensive" else ToolMode.DEFENSIVE + tool_mode = ( + ToolMode.OFFENSIVE if mode == "offensive" else ToolMode.DEFENSIVE + ) context = SessionContext( - session_id=session['session']['id'], - mode=tool_mode, - target=target + session_id=session["session"]["id"], mode=tool_mode, target=target ) - + # Progress container prog_container = st.container() - + async def run_pipeline(): results = [] total = len(st.session_state.current_plan) - + progress_bar = prog_container.progress(0, text="Starting...") - + for i, req in enumerate(st.session_state.current_plan): - progress_bar.progress((i)/total, text=f"Running {req.tool_name}...") - + progress_bar.progress( + (i) / total, text=f"Running {req.tool_name}..." + ) + # Run tool result = await execution_manager.execute_tool(req, context) results.append(result) - + # Show intermediate status if result.status == "success": prog_container.success(f"✅ {req.tool_name}: Success") else: - prog_container.error(f"❌ {req.tool_name}: Failed - {result.error}") - + prog_container.error( + f"❌ {req.tool_name}: Failed - {result.error}" + ) + progress_bar.progress(1.0, text="Execution Complete") return results @@ -188,16 +201,16 @@ async def run_pipeline(): # --- TAB 3: RESULTS --- with tab_results: st.markdown("#### Mission Analysis") - + if "execution_results" in st.session_state: results = st.session_state.execution_results - + # Run Analyst if st.button("🧠 Run AI Analysis"): with st.spinner("Analyzing output..."): findings = asyncio.run(analyst.analyze_results(results)) st.session_state.findings = findings - + if "findings" in st.session_state: findings = st.session_state.findings if findings: @@ -207,7 +220,7 @@ async def run_pipeline(): st.write(f"**Description:** {f.description}") else: st.info("No significant findings identified.") - + # Show Raw Output st.divider() st.markdown("**Raw Tool Output**") @@ -220,21 +233,27 @@ async def run_pipeline(): # --- TAB 4: REPORT --- with tab_report: st.markdown("#### Report Generation") - + if "findings" in st.session_state: if st.button("📝 Generate Report"): with st.spinner("Compiling report..."): # task_desc might be in local scope, grab from session if needed - goal = session.get('task_state', {}).get('plan', {}).get('goal', f"Scan {target}") - report_content = asyncio.run(scribe.generate_report(goal, st.session_state.findings)) + goal = ( + session.get("task_state", {}) + .get("plan", {}) + .get("goal", f"Scan {target}") + ) + report_content = asyncio.run( + scribe.generate_report(goal, st.session_state.findings) + ) st.session_state.report_content = report_content - + if "report_content" in st.session_state: st.markdown(st.session_state.report_content) st.download_button( - "Download Report (MD)", - st.session_state.report_content, - file_name=f"report_{session['session']['id']}.md" + "Download Report (MD)", + st.session_state.report_content, + file_name=f"report_{session['session']['id']}.md", ) else: st.warning("No findings available to report on.") diff --git a/modules/web/ui/session_view.py b/modules/web/ui/session_view.py index fb77c0b..e44f9a8 100644 --- a/modules/web/ui/session_view.py +++ b/modules/web/ui/session_view.py @@ -2,6 +2,7 @@ from datetime import datetime from modules.session.session_manager import SessionManager + def render_session_view(session_manager: SessionManager): """ Renders the Session Management interface. @@ -15,11 +16,19 @@ def render_session_view(session_manager: SessionManager): with col1: st.subheader("Create New Session") with st.form("create_session_form"): - name = st.text_input("Session Name", value=f"Op-{datetime.now().strftime('%Y%m%d-%H%M')}") + name = st.text_input( + "Session Name", value=f"Op-{datetime.now().strftime('%Y%m%d-%H%M')}" + ) target = st.text_input("Target (Domain/IP)", placeholder="example.com") - mode = st.selectbox("Operational Mode", ["offensive", "defensive"], help="Offensive: Active Scanning allowed. Defensive: Passive/Analysis only.") - description = st.text_area("Description", placeholder="Objective of this assessment...") - + mode = st.selectbox( + "Operational Mode", + ["offensive", "defensive"], + help="Offensive: Active Scanning allowed. Defensive: Passive/Analysis only.", + ) + description = st.text_area( + "Description", placeholder="Objective of this assessment..." + ) + submitted = st.form_submit_button("🚀 Launch Session", type="primary") if submitted: if not target: @@ -27,14 +36,12 @@ def render_session_view(session_manager: SessionManager): else: try: session_id = session_manager.create_session( - name=name, - mode=mode, - description=description + name=name, mode=mode, description=description ) # Set initial target in task state - session_manager.update_session_state({ - "task_state": {"target": target} - }) + session_manager.update_session_state( + {"task_state": {"target": target}} + ) st.session_state.active_session_id = session_id st.success(f"Session '{name}' created successfully!") st.rerun() @@ -44,20 +51,22 @@ def render_session_view(session_manager: SessionManager): with col2: st.subheader("Existing Sessions") sessions = session_manager.list_sessions() - + if not sessions: st.info("No active sessions found.") else: # Convert to DataFrame for nicer display? Or just list for s in sessions: - with st.expander(f"{s['name']} | {s['mode'].upper()} | {s['updated_at'][:16]}"): + with st.expander( + f"{s['name']} | {s['mode'].upper()} | {s['updated_at'][:16]}" + ): st.caption(f"ID: {s['id']}") - + if st.button("📂 Load Session", key=f"load_{s['id']}"): - session_manager.resume_session(s['id']) - st.session_state.active_session_id = s['id'] + session_manager.resume_session(s["id"]) + st.session_state.active_session_id = s["id"] st.rerun() - + if st.button("🗑️ Delete", key=f"del_{s['id']}"): - session_manager.delete_session(s['id']) + session_manager.delete_session(s["id"]) st.rerun() diff --git a/modules/web/web_module.py b/modules/web/web_module.py index 16d1ecf..f410354 100644 --- a/modules/web/web_module.py +++ b/modules/web/web_module.py @@ -13,6 +13,7 @@ from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn + class WebModule: def __init__(self, base_dir: Path, ai_analyzer: Any): self.base_dir = base_dir @@ -21,28 +22,32 @@ def __init__(self, base_dir: Path, ai_analyzer: Any): self.console = Console() self.wordlist = "/usr/share/wordlists/dirb/common.txt" - async def run_web_discovery(self, target: str, output_dir: Optional[Path] = None, use_ai: bool = True) -> Dict[str, Any]: + async def run_web_discovery( + self, target: str, output_dir: Optional[Path] = None, use_ai: bool = True + ) -> Dict[str, Any]: """ Run web discovery on the target """ self.logger.info("Starting web discovery on %s (AI: %s)", target, use_ai) - + results = { "target": target, "technologies": [], "directories": [], "ai_analysis": {}, - "errors": [] + "errors": [], } try: with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: # 1. Technology Identification - task = progress.add_task("Identifying technologies (whatweb)...", total=None) + task = progress.add_task( + "Identifying technologies (whatweb)...", total=None + ) techs = await self._run_whatweb(target) results["technologies"] = techs progress.update(task, completed=True) @@ -60,12 +65,14 @@ async def run_web_discovery(self, target: str, output_dir: Optional[Path] = None results["ai_analysis"] = analysis progress.update(task, completed=True) else: - results["ai_analysis"] = {"status": "AI analysis skipped or pending"} + results["ai_analysis"] = { + "status": "AI analysis skipped or pending" + } # Save results if output_dir: self.save_results(results, output_dir) - + return results except Exception as e: @@ -79,13 +86,22 @@ async def _run_whatweb(self, target: str) -> List[Dict[str, Any]]: self.logger.warning("whatweb not found. Skipping technology detection.") return [] - cmd = ["whatweb", "--color=never", "--no-errors", "-a", "3", "--aggression", "3", target] + cmd = [ + "whatweb", + "--color=never", + "--no-errors", + "-a", + "3", + "--aggression", + "3", + target, + ] try: # whatweb output is often a bit messy, let's try to get some structured info stdout = await self._run_command(" ".join(cmd)) if not stdout: return [] - + # Simple parsing for now, whatweb doesn't have a great JSON output by default return [{"raw": stdout}] except Exception as e: @@ -99,7 +115,9 @@ async def _run_ffuf(self, target: str) -> List[Dict[str, Any]]: return [] if not os.path.exists(self.wordlist): - self.logger.warning("Wordlist not found at %s. Skipping ffuf.", self.wordlist) + self.logger.warning( + "Wordlist not found at %s. Skipping ffuf.", self.wordlist + ) return [] # Ensure target has protocol and trailing slash for FUZZ @@ -110,31 +128,41 @@ async def _run_ffuf(self, target: str) -> List[Dict[str, Any]]: url += "/" url += "FUZZ" - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as temp_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as temp_file: temp_output = temp_file.name cmd = [ "ffuf", - "-u", url, - "-w", self.wordlist, - "-mc", "200,204,301,302,307,401,403", - "-o", temp_output, - "-of", "json", - "-s" # silent + "-u", + url, + "-w", + self.wordlist, + "-mc", + "200,204,301,302,307,401,403", + "-o", + temp_output, + "-of", + "json", + "-s", # silent ] try: await self._run_command(" ".join(cmd)) - + if os.path.exists(temp_output): - with open(temp_output, 'r') as f: + with open(temp_output, "r") as f: data = json.load(f) results = data.get("results", []) - return [{ - "url": r.get("url"), - "status": r.get("status"), - "content_length": r.get("length") - } for r in results] + return [ + { + "url": r.get("url"), + "status": r.get("status"), + "content_length": r.get("length"), + } + for r in results + ] return [] except Exception as e: self.logger.error("Error running ffuf: %s", e) @@ -143,7 +171,9 @@ async def _run_ffuf(self, target: str) -> List[Dict[str, Any]]: if os.path.exists(temp_output): os.unlink(temp_output) - async def analyze_with_ai(self, target: str, results: Dict[str, Any]) -> Dict[str, Any]: + async def analyze_with_ai( + self, target: str, results: Dict[str, Any] + ) -> Dict[str, Any]: """Analyze web discovery results using AI""" if not self.ai_analyzer: return {"error": "AI Analyzer not initialized"} @@ -172,17 +202,20 @@ async def analyze_with_ai(self, target: str, results: Dict[str, Any]) -> Dict[st "next_steps": ["...", "..."] }} """ - + system_prompt = "You are a web security expert. Analyze discovery results and provide actionable insights." - + try: - response = await self.ai_analyzer.ollama.generate(prompt, system_prompt=system_prompt) + response = await self.ai_analyzer.ollama.generate( + prompt, system_prompt=system_prompt + ) if response: import re + try: return json.loads(response) except json.JSONDecodeError: - json_match = re.search(r'```json\n(.*?)\n```', response, re.DOTALL) + json_match = re.search(r"```json\n(.*?)\n```", response, re.DOTALL) if json_match: return json.loads(json_match.group(1)) return {"error": "Failed to get AI analysis"} @@ -193,20 +226,20 @@ async def analyze_with_ai(self, target: str, results: Dict[str, Any]) -> Dict[st def save_results(self, results: Dict[str, Any], output_dir: Path): """Save results to file""" output_dir.mkdir(parents=True, exist_ok=True) - + with open(output_dir / "web_discovery_results.json", "w") as f: json.dump(results, f, indent=2) - + with open(output_dir / "web_report.md", "w") as f: f.write(f"# Web Discovery Report for {results['target']}\n\n") - + f.write("## Technologies Identified\n") if results["technologies"]: for tech in results["technologies"]: f.write(f"```\n{tech.get('raw', 'No data')}\n```\n") else: f.write("No technologies identified.\n") - + f.write("\n## Discovered Directories & Files\n") if results["directories"]: f.write("| URL | Status | Length |\n") @@ -215,10 +248,12 @@ def save_results(self, results: Dict[str, Any], output_dir: Path): f.write(f"| {d['url']} | {d['status']} | {d['content_length']} |\n") else: f.write("No directories or files found.\n") - + f.write("\n## 🤖 AI Security Assessment\n") analysis = results.get("ai_analysis", {}) - f.write(f"### Tech Stack Assessment\n{analysis.get('tech_stack_assessment', 'N/A')}\n\n") + f.write( + f"### Tech Stack Assessment\n{analysis.get('tech_stack_assessment', 'N/A')}\n\n" + ) f.write("### Interesting Findings\n") for finding in analysis.get("interesting_findings", []): f.write(f"- {finding}\n") @@ -231,17 +266,21 @@ def save_results(self, results: Dict[str, Any], output_dir: Path): def _check_tool(self, tool_name: str) -> bool: import subprocess + try: - subprocess.run(["which", tool_name], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run( + ["which", tool_name], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) return True except subprocess.CalledProcessError: return False async def _run_command(self, command: str) -> str: process = await asyncio.create_subprocess_shell( - command, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() return stdout.decode().strip() diff --git a/neurorift_main.py b/neurorift_main.py index e169413..2da3a43 100755 --- a/neurorift_main.py +++ b/neurorift_main.py @@ -37,13 +37,13 @@ RateLimiter, FilePermissionManager, validate_target, - sanitize_filename + sanitize_filename, ) from utils.auth import get_auth_manager, Permission from modules.recon.recon_module import EnhancedReconModule from modules.ai.ai_integration import AIAnalyzer, OllamaClient -from modules.ai.ai_orchestrator import AIOrchestrator # New Import +from modules.ai.ai_orchestrator import AIOrchestrator # New Import import modules.darkweb as darkweb_module from modules.ai.agent import NeuroRiftAgent from modules.web.web_module import WebModule @@ -52,7 +52,11 @@ from modules.session.session_manager import SessionManager from modules.session.autosave_service import AutoSaveService from modules.session.session_cli import SessionCLI, setup_session_parser -from modules.orchestration.execution_manager import ExecutionManager, ScanRequest, SessionContext +from modules.orchestration.execution_manager import ( + ExecutionManager, + ScanRequest, + SessionContext, +) from modules.ai.agents import NRPlanner, NROperator, NRAnalyst, NRScribe from modules.tools.base import ToolMode from modules.config.config_wizard import ConfigWizard @@ -70,7 +74,7 @@ def __init__(self): # Initialize Session Management self.session_manager = SessionManager() self.auto_save = AutoSaveService(self.session_manager) - + # Initialize AI components self.ollama = OllamaClient() self.ai_analyzer = AIAnalyzer(self.ollama) @@ -79,7 +83,7 @@ def __init__(self): self.web_module = WebModule(self.base_dir, self.ai_analyzer) self.exploit_module = ExploitModule(self.base_dir, self.ai_analyzer) self.scan_module = ScanModule(self.base_dir, self.ai_analyzer) - + # Orchestration Components self.execution_manager = ExecutionManager(self.session_manager) self.planner = NRPlanner(self.ollama) @@ -177,8 +181,12 @@ def install_missing_tools(self): self.logger.info("Installing %s...", tool) try: # SECURITY FIX: Use full executable path and handle subprocess failures - result = subprocess.run(["/usr/bin/go", "install", package], - capture_output=True, text=True, check=True) + result = subprocess.run( + ["/usr/bin/go", "install", package], + capture_output=True, + text=True, + check=True, + ) self.logger.info("Successfully installed %s", tool) except subprocess.CalledProcessError as e: self.logger.error("Failed to install %s: %s", tool, e) @@ -192,7 +200,9 @@ def install_missing_tools(self): async def run_recon(self, target: str, output_dir: Optional[Path] = None): """Run reconnaissance on target""" - recon = EnhancedReconModule(self.base_dir, self.ai_analyzer, config_path="configs/tools.json") + recon = EnhancedReconModule( + self.base_dir, self.ai_analyzer, config_path="configs/tools.json" + ) return await recon.run_recon(target, output_dir) def ask_ai(self, question: str): @@ -211,9 +221,9 @@ def ask_ai(self, question: str): ) @RateLimiter(max_calls=5, time_window=60) - def generate_tool(self, description: str, identifier: str = 'default'): + def generate_tool(self, description: str, identifier: str = "default"): """Generate a custom tool using AI and save it to custom_tools directory. - + Args: description: Tool description identifier: Rate limit identifier (username/session) @@ -223,12 +233,12 @@ def generate_tool(self, description: str, identifier: str = 'default'): if not description or not isinstance(description, str): self.logger.error("Invalid tool description") return - + # SECURITY: Limit description length if len(description) > 500: self.logger.error("Tool description too long (max 500 chars)") return - + # Use os.path.expanduser to properly handle home directory tool_dir = Path.home() / ".neurorift" / "custom_tools" self.console.print( @@ -262,10 +272,14 @@ def generate_tool(self, description: str, identifier: str = 'default'): base_name = re.sub(r"[^a-zA-Z0-9]+", "_", description.strip().lower())[ :32 ].strip("_") - filename = sanitize_filename(f"{base_name or 'custom_tool'}_{int(time.time())}.py") - + filename = sanitize_filename( + f"{base_name or 'custom_tool'}_{int(time.time())}.py" + ) + # SECURITY: Validate path to prevent traversal - tool_path = SecurityValidator.sanitize_path(str(tool_dir / filename), base_dir=tool_dir) + tool_path = SecurityValidator.sanitize_path( + str(tool_dir / filename), base_dir=tool_dir + ) if not tool_path: self.logger.error("Invalid tool path") return @@ -273,7 +287,7 @@ def generate_tool(self, description: str, identifier: str = 'default'): # Write the generated tool to file with open(tool_path, "w", encoding="utf-8") as f: f.write(response) - + # SECURITY: Set secure file permissions (0o600) for generated tool files FilePermissionManager.set_secure_permissions(tool_path, mode=0o600) self.console.print( @@ -319,13 +333,13 @@ def list_custom_tools(self): try: # SECURITY: Use Path and validate directory tool_dir = Path.home() / ".neurorift" / "custom_tools" - + # SECURITY: Validate path tool_dir = SecurityValidator.sanitize_path(str(tool_dir)) if not tool_dir: self.logger.error("Invalid tool directory path") return - + metadata_path = tool_dir / "metadata.json" if not tool_dir.exists(): @@ -463,7 +477,7 @@ def get_parser(): For detailed documentation, visit: https://github.com/demonking369/NeuroRift """, - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--target", "-t", help="Target domain or IP") parser.add_argument( @@ -499,89 +513,130 @@ def get_parser(): "--stealth", "-s", action="store_true", help="Enable stealth mode" ) parser.add_argument( - "--ai-pipeline", action="store_true", help="Enable the advanced multi-prompt AI pipeline." + "--ai-pipeline", + action="store_true", + help="Enable the advanced multi-prompt AI pipeline.", ) parser.add_argument( - "--prompt-dir", help="Directory for the AI pipeline prompts.", default="prompts/system_prompts" + "--prompt-dir", + help="Directory for the AI pipeline prompts.", + default="prompts/system_prompts", ) parser.add_argument( - "--uninstall", action="store_true", help="Uninstall NeuroRift and its components" + "--uninstall", + action="store_true", + help="Uninstall NeuroRift and its components", ) parser.add_argument( - "--webmod", action="store_true", help="Launch NeuroRift web interface (Streamlit UI)" + "--webmod", + action="store_true", + help="Launch NeuroRift web interface (Streamlit UI)", ) parser.add_argument( - "--web-host", default="localhost", help="Web interface host (default: localhost)" + "--web-host", + default="localhost", + help="Web interface host (default: localhost)", ) parser.add_argument( "--web-port", type=int, default=8501, help="Web interface port (default: 8501)" ) parser.add_argument( - "--agentic", "--ai-agent", action="store_true", help="Enable simple agentic AI mode (deprecated, use --orchestrated)" + "--agentic", + "--ai-agent", + action="store_true", + help="Enable simple agentic AI mode (deprecated, use --orchestrated)", ) parser.add_argument( - "--orchestrated", action="store_true", help="🆕 Enable NeuroRift Orchestrated Intelligence Mode (multi-agent)" + "--orchestrated", + action="store_true", + help="🆕 Enable NeuroRift Orchestrated Intelligence Mode (multi-agent)", ) parser.add_argument( "--mode", choices=["offensive", "defensive"], - help="🆕 NeuroRift operational mode: 'offensive' (discovery) or 'defensive' (analysis/mitigation)" + help="🆕 NeuroRift operational mode: 'offensive' (discovery) or 'defensive' (analysis/mitigation)", ) parser.add_argument( - "--resume", metavar="TASK_ID", help="🆕 Resume a previously interrupted NeuroRift task" + "--resume", + metavar="TASK_ID", + help="🆕 Resume a previously interrupted NeuroRift task", ) parser.add_argument( - "--analyze", metavar="FILE", help="🆕 Analyze existing scan results (DEFENSIVE mode)" + "--analyze", + metavar="FILE", + help="🆕 Analyze existing scan results (DEFENSIVE mode)", ) parser.add_argument( "--no-ai", action="store_true", help="Disable AI analysis for the current mode" ) parser.add_argument( - "--configure", action="store_true", help="🆕 Launch interactive configuration wizard" + "--configure", + action="store_true", + help="🆕 Launch interactive configuration wizard", ) # Add subparsers for commands - subparsers = parser.add_subparsers(dest='command', help='Available commands') - + subparsers = parser.add_subparsers(dest="command", help="Available commands") + # Ask AI command - ask_ai_parser = subparsers.add_parser('ask-ai', help='Ask the AI assistant a question') - ask_ai_parser.add_argument('question', help='The question to ask the AI') - ask_ai_parser.add_argument('--verbose', action='store_true', help='Show detailed model logs') - ask_ai_parser.add_argument('--dangerous', action='store_true', help='Enable dangerous mode') - ask_ai_parser.add_argument('--confirm-danger', action='store_true', help='Confirm dangerous mode') - + ask_ai_parser = subparsers.add_parser( + "ask-ai", help="Ask the AI assistant a question" + ) + ask_ai_parser.add_argument("question", help="The question to ask the AI") + ask_ai_parser.add_argument( + "--verbose", action="store_true", help="Show detailed model logs" + ) + ask_ai_parser.add_argument( + "--dangerous", action="store_true", help="Enable dangerous mode" + ) + ask_ai_parser.add_argument( + "--confirm-danger", action="store_true", help="Confirm dangerous mode" + ) + # Generate tool command - generate_tool_parser = subparsers.add_parser('generate-tool', help='Generate a custom tool') - generate_tool_parser.add_argument('description', help='Description of the tool to generate') - generate_tool_parser.add_argument('--verbose', action='store_true', help='Show detailed generation logs') - + generate_tool_parser = subparsers.add_parser( + "generate-tool", help="Generate a custom tool" + ) + generate_tool_parser.add_argument( + "description", help="Description of the tool to generate" + ) + generate_tool_parser.add_argument( + "--verbose", action="store_true", help="Show detailed generation logs" + ) + # List tools command - list_tools_parser = subparsers.add_parser('list-tools', help='List all custom tools') - list_tools_parser.add_argument('--verbose', action='store_true', help='Show detailed tool information') + list_tools_parser = subparsers.add_parser( + "list-tools", help="List all custom tools" + ) + list_tools_parser.add_argument( + "--verbose", action="store_true", help="Show detailed tool information" + ) # Dark web OSINT command (Robin integration) darkweb_parser = subparsers.add_parser( - 'darkweb', help='Run the Robin dark web OSINT workflow' + "darkweb", help="Run the Robin dark web OSINT workflow" + ) + darkweb_parser.add_argument( + "--query", "-q", required=True, help="Dark web search query" ) - darkweb_parser.add_argument('--query', '-q', required=True, help='Dark web search query') darkweb_parser.add_argument( - '--model', - '-m', + "--model", + "-m", choices=darkweb_module.get_robin_model_choices(), default=darkweb_module.ROBIN_DEFAULT_MODEL, - help='LLM model to use for refinement/filtering', + help="LLM model to use for refinement/filtering", ) darkweb_parser.add_argument( - '--threads', - '-t', + "--threads", + "-t", type=int, default=5, - help='Number of concurrent requests for search/scrape', + help="Number of concurrent requests for search/scrape", ) darkweb_parser.add_argument( - '--output', - '-o', - help='Optional output file or directory for the markdown report', + "--output", + "-o", + help="Optional output file or directory for the markdown report", ) # Session management commands @@ -590,6 +645,7 @@ def get_parser(): args = parser.parse_args() return parser, args + async def _async_main(args): # Initialize NeuroRift vf = NeuroRift() @@ -608,23 +664,23 @@ async def _async_main(args): # Handle Session commands if args.command == "session": session_cli = SessionCLI(vf.session_manager) - if args.session_command == 'new': + if args.session_command == "new": session_cli.cmd_new(args) - elif args.session_command == 'save': + elif args.session_command == "save": session_cli.cmd_save(args) - elif args.session_command == 'list': + elif args.session_command == "list": session_cli.cmd_list(args) - elif args.session_command == 'load': + elif args.session_command == "load": session_cli.cmd_load(args) - elif args.session_command == 'resume': + elif args.session_command == "resume": session_cli.cmd_resume(args) - elif args.session_command == 'delete': + elif args.session_command == "delete": session_cli.cmd_delete(args) - elif args.session_command == 'rename': + elif args.session_command == "rename": session_cli.cmd_rename(args) - elif args.session_command == 'status': + elif args.session_command == "status": session_cli.cmd_status(args) - elif args.session_command == 'export': + elif args.session_command == "export": session_cli.cmd_export(args) return @@ -636,78 +692,86 @@ async def _async_main(args): if not vf.session_manager.current_session_id: session_name = f"Scan: {args.target}" if args.target else "New Operation" vf.session_manager.create_session( - name=session_name, - mode=args.mode or "offensive" + name=session_name, mode=args.mode or "offensive" ) # Handle Orchestrated Mode if args.orchestrated: - vf.console.print(Panel("[bold green]NeuroRift Orchestrated Intelligence Mode[/bold green]", style="bold blue")) - + vf.console.print( + Panel( + "[bold green]NeuroRift Orchestrated Intelligence Mode[/bold green]", + style="bold blue", + ) + ) + target = args.target if not target: # Try to get from session session = vf.session_manager.get_current_session() if session: target = session.get("task_state", {}).get("target") - + if not target: target = input("Enter target: ").strip() - + if not target: - print("Target required.") - return + print("Target required.") + return # Setup context - tool_mode = ToolMode.OFFENSIVE if args.mode == "offensive" else ToolMode.DEFENSIVE - + tool_mode = ( + ToolMode.OFFENSIVE if args.mode == "offensive" else ToolMode.DEFENSIVE + ) + # Ensure session exists if not vf.session_manager.current_session_id: - vf.session_manager.create_session(name=f"Assessment on {target}", mode=tool_mode.value) - + vf.session_manager.create_session( + name=f"Assessment on {target}", mode=tool_mode.value + ) + context = SessionContext( session_id=vf.session_manager.current_session_id, mode=tool_mode, - target=target + target=target, ) - + task_desc = f"Perform a {tool_mode.value} security assessment on {target}" vf.console.print(f"[bold]Task:[/bold] {task_desc}") - + # 1. Plan available_tools = vf.execution_manager.list_tools() vf.console.print("[bold blue]Planning execution...[/bold blue]") requests = await vf.planner.create_plan(task_desc, available_tools) - + if not requests: vf.console.print("[red]Failed to generate plan.[/red]") return - + vf.console.print(f"[green]Plan generated with {len(requests)} steps.[/green]") for i, req in enumerate(requests): - print(f"{i+1}. {req.tool_name} {req.args}") - - if input("\nApprove plan? (Y/n): ").lower() == 'n': + print(f"{i+1}. {req.tool_name} {req.args}") + + if input("\nApprove plan? (Y/n): ").lower() == "n": print("Aborted.") return # 2. Execute vf.console.print("\n[bold blue]Executing plan...[/bold blue]") results = await vf.operator.execute_plan(requests, context) - + # 3. Analyze vf.console.print("\n[bold blue]Analyzing results...[/bold blue]") findings = await vf.analyst.analyze_results(results) - + # 4. Report vf.console.print("\n[bold blue]Generating report...[/bold blue]") report = await vf.scribe.generate_report(task_desc, findings) - + # Save report report_path = vf.results_dir / f"report_{context.session_id}.md" with open(report_path, "w") as f: f.write(report) - + vf.console.print(Panel(report[:1000] + "\\n...", title="Report Preview")) vf.console.print(f"[green]Full report saved to {report_path}[/green]") return @@ -716,11 +780,11 @@ async def _async_main(args): if args.uninstall: script_dir = Path(__file__).parent uninstall_script = script_dir / "uninstall_script.sh" - + if not uninstall_script.exists(): print(f"Error: Uninstall script not found at {uninstall_script}") return - + try: subprocess.run([str(uninstall_script)], check=True) except subprocess.CalledProcessError as e: @@ -732,14 +796,16 @@ async def _async_main(args): # Handle AI Pipeline Mode if args.ai_pipeline: if not args.target: - print("Error: A target is required for AI pipeline mode, e.g., --target 'scan example.com'") + print( + "Error: A target is required for AI pipeline mode, e.g., --target 'scan example.com'" + ) return - + prompt_path = Path(args.prompt_dir) if not prompt_path.exists(): print(f"Error: Prompt directory not found at '{prompt_path}'") return - + orchestrator = AIOrchestrator(prompt_path) orchestrator.execute_task(f"Perform a security scan on {args.target}") return @@ -756,8 +822,10 @@ async def _async_main(args): vf.console.print("\n[bold cyan]--- Agentic Action Plan ---[/bold cyan]") vf.console.print(json.dumps(result, indent=2)) else: - vf.console.print("\n[bold yellow]Agentic mode enabled. Ready for instructions...[/bold yellow]") - + vf.console.print( + "\n[bold yellow]Agentic mode enabled. Ready for instructions...[/bold yellow]" + ) + # If we are not in web mode, we might want an interactive CLI loop here # For now, we'll just continue to respect other flags. @@ -782,9 +850,11 @@ async def _async_main(args): # Check if Robin is available if not darkweb_module.ROBIN_AVAILABLE: print("❌ Error: Robin module dependencies not installed.") - print("Install with: pip install langchain-core langchain-openai langchain-ollama") + print( + "Install with: pip install langchain-core langchain-openai langchain-ollama" + ) return - + darkweb_module.run_darkweb_osint( args.query, model=args.model, @@ -815,7 +885,7 @@ async def _async_main(args): "Use --target to specify a domain or IP address you own or have authorization to test" ) return - + # SECURITY: Validate target input if not validate_target(args.target): print(f"Error: Invalid target format: {args.target}") @@ -838,11 +908,11 @@ async def _async_main(args): # SECURITY: Create session directory with secure permissions timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + # SECURITY: Sanitize target for use in path safe_target = sanitize_filename(args.target) session_dir = vf.base_dir / "sessions" / safe_target / timestamp - + # SECURITY: Create with restricted permissions if not FilePermissionManager.create_secure_directory(session_dir, mode=0o700): print("Error: Failed to create secure session directory") @@ -897,36 +967,44 @@ async def _async_main(args): elif args.operation_mode == "scan": print(f"\n📡 Starting port scan on {args.target}") - + results = await vf.scan_module.run_scan(args.target, session_dir, use_ai=False) console = Console() console.print("\n[bold green]Scan Complete![/bold green]") console.print(f"Found {len(results['ports'])} open ports") - if results['ports']: + if results["ports"]: from rich.table import Table + table = Table(title=f"Open Ports on {args.target}") table.add_column("Port", style="cyan") table.add_column("State", style="green") table.add_column("Service", style="magenta") table.add_column("Version", style="yellow") - for p in results['ports']: + for p in results["ports"]: version = f"{p['product']} {p['version']}".strip() or "N/A" - table.add_row(f"{p['number']}/{p['protocol']}", p['state'], p['service'], version) - + table.add_row( + f"{p['number']}/{p['protocol']}", p["state"], p["service"], version + ) + console.print(table) # AI Analysis Interaction if not args.no_ai: - if input("\n🤖 Do you want an AI analysis of these results? (y/N): ").lower() == 'y': + if ( + input( + "\n🤖 Do you want an AI analysis of these results? (y/N): " + ).lower() + == "y" + ): console.print("\n[bold cyan]Generating AI Analysis...[/bold cyan]") # We reuse AIAnalyzer.analyze_nmap_output via ScanModule - nmap_str = vf.scan_module._format_nmap_results(results['ports']) + nmap_str = vf.scan_module._format_nmap_results(results["ports"]) analysis = await vf.ai_analyzer.analyze_nmap_output(nmap_str) results["ai_analysis"] = analysis - + # Update saved results vf.scan_module._save_results(results, session_dir) @@ -934,11 +1012,17 @@ async def _async_main(args): console.print("\n[bold cyan]AI Security Insights:[/bold cyan]") if isinstance(analysis, dict): if "summary" in analysis: - console.print(f"\n[bold]Summary:[/bold]\n{analysis['summary']}") + console.print( + f"\n[bold]Summary:[/bold]\n{analysis['summary']}" + ) if "potential_vulnerabilities" in analysis: - console.print("\n[bold yellow]Potential Vulnerabilities:[/bold yellow]") + console.print( + "\n[bold yellow]Potential Vulnerabilities:[/bold yellow]" + ) for v in analysis["potential_vulnerabilities"]: - console.print(f"- [bold]{v.get('type')}[/bold]: {v.get('description')} (Severity: {v.get('severity')})") + console.print( + f"- [bold]{v.get('type')}[/bold]: {v.get('description')} (Severity: {v.get('severity')})" + ) else: console.print(analysis) else: @@ -950,9 +1034,11 @@ async def _async_main(args): print(f"\n🌐 Starting web discovery on {args.target}") if args.ai_only: print("Running in AI-only mode - AI will make all decisions") - + # Run discovery without AI initially to allow for interactive prompt at the end - results = await vf.web_module.run_web_discovery(args.target, session_dir, use_ai=False) + results = await vf.web_module.run_web_discovery( + args.target, session_dir, use_ai=False + ) # Display summary console = Console() @@ -962,11 +1048,16 @@ async def _async_main(args): # Interactive AI Analysis Prompt if not args.no_ai: - if input("\n🤖 Do you want an AI analysis of these results? (y/N): ").lower() == 'y': + if ( + input( + "\n🤖 Do you want an AI analysis of these results? (y/N): " + ).lower() + == "y" + ): console.print("\n[bold cyan]Generating AI Analysis...[/bold cyan]") analysis = await vf.web_module.analyze_with_ai(args.target, results) results["ai_analysis"] = analysis - + # Update saved results with AI analysis vf.web_module.save_results(results, session_dir) @@ -974,10 +1065,12 @@ async def _async_main(args): console.print("\n[bold cyan]AI Analysis Summary:[/bold cyan]") st = analysis.get("tech_stack_assessment", "N/A") console.print(f"[bold]Tech Stack:[/bold] {st}") - + interesting = analysis.get("interesting_findings", []) if interesting: - console.print("\n[bold yellow]Interesting Findings:[/bold yellow]") + console.print( + "\n[bold yellow]Interesting Findings:[/bold yellow]" + ) for finding in interesting: console.print(f"- {finding}") else: @@ -991,45 +1084,56 @@ async def _async_main(args): elif args.operation_mode == "exploit": print(f"\n💥 Starting exploitation on {args.target}") - + # To run exploit mode, we need recon data # Let's check if there's a recent recon scan for this target recon_data = {} recon_results_path = session_dir / "recon_results.json" web_results_path = session_dir / "web_discovery_results.json" - + if recon_results_path.exists(): - with open(recon_results_path, 'r') as f: + with open(recon_results_path, "r") as f: recon_data = json.load(f) elif web_results_path.exists(): - with open(web_results_path, 'r') as f: + with open(web_results_path, "r") as f: web_data = json.load(f) # Map web data to a format exploit module understands recon_data = { "target": args.target, - "services": [{"name": tech.get("raw"), "version": ""} for tech in web_data.get("technologies", [])] + "services": [ + {"name": tech.get("raw"), "version": ""} + for tech in web_data.get("technologies", []) + ], } else: - print("[yellow]No reconnaissance data found for this target in the current session.[/yellow]") + print( + "[yellow]No reconnaissance data found for this target in the current session.[/yellow]" + ) print("Exploit mode works best when preceded by 'recon' or 'web' mode.") # We can still try with basic info if provided recon_data = {"target": args.target, "services": []} - results = await vf.exploit_module.run_exploit_pipeline(args.target, recon_data, session_dir, use_ai=not args.no_ai) + results = await vf.exploit_module.run_exploit_pipeline( + args.target, recon_data, session_dir, use_ai=not args.no_ai + ) console = Console() console.print("\n[bold green]Exploit Pipeline Complete![/bold green]") - console.print(f"Mapped {len(results['vulnerabilities'])} potential vulnerabilities") + console.print( + f"Mapped {len(results['vulnerabilities'])} potential vulnerabilities" + ) console.print(f"Generated {len(results['exploits'])} exploits") - if results['exploits']: + if results["exploits"]: console.print("\n[bold cyan]Generated Exploits:[/bold cyan]") - for exploit in results['exploits']: + for exploit in results["exploits"]: if "error" not in exploit: console.print(f"- [green]{exploit.get('file_path')}[/green]") if exploit.get("validation", {}).get("issues"): - console.print(f" [yellow]Validation Issues:[/yellow] {', '.join(exploit['validation']['issues'])}") - + console.print( + f" [yellow]Validation Issues:[/yellow] {', '.join(exploit['validation']['issues'])}" + ) + # Start dev mode shell if requested if args.dev_mode: await dev_mode_shell(vf, session_dir) @@ -1043,39 +1147,44 @@ def main(): if args.webmod: # Check if Robin module is available (Optional now) if not darkweb_module.ROBIN_AVAILABLE: - print("⚠️ Warning: Robin module dependencies not installed. Dark Web OSINT features may be unavailable.") + print( + "⚠️ Warning: Robin module dependencies not installed. Dark Web OSINT features may be unavailable." + ) - print("🌐 Launching NeuroRift Web Interface...") print(f"📍 Access the UI at: http://{args.web_host}:{args.web_port}") print("⚠️ Press Ctrl+C to stop the server\n") - + try: # Import streamlit CLI from streamlit.web import cli as stcli import sys - + # Get the UI file path ui_file = Path(__file__).parent / "modules" / "web" / "dashboard.py" - + if not ui_file.exists(): print(f"❌ Error: Web UI file not found at {ui_file}") return - + # Prepare streamlit arguments sys.argv = [ "streamlit", "run", str(ui_file), - "--server.port", str(args.web_port), - "--server.address", args.web_host, - "--server.headless", "true", - "--browser.gatherUsageStats", "false" + "--server.port", + str(args.web_port), + "--server.address", + args.web_host, + "--server.headless", + "true", + "--browser.gatherUsageStats", + "false", ] - + # Launch streamlit sys.exit(stcli.main()) - + except ImportError: print("❌ Error: Streamlit is not installed.") return @@ -1086,5 +1195,6 @@ def main(): # Run the main async pipeline asyncio.run(_async_main(args)) + if __name__ == "__main__": main() diff --git a/screen_control/launcher.py b/screen_control/launcher.py index 6a61485..b4ff341 100644 --- a/screen_control/launcher.py +++ b/screen_control/launcher.py @@ -8,21 +8,31 @@ # Ensure you have build-essential, g++, rustc, cargo, and nasm installed. # sudo apt-get install build-essential g++ rustc cargo nasm xdotool + def build_modules(): """Compiles all native C++, Rust, and Assembly modules.""" print("--- Building native modules ---") - + # Build C++ print("Building C++ module...") try: - subprocess.run([ - "g++", "-shared", "-fPIC", "-o", "cpp/screen.so", - "cpp/screen.cpp", "-lX11", "-lXext" - ], check=True) + subprocess.run( + [ + "g++", + "-shared", + "-fPIC", + "-o", + "cpp/screen.so", + "cpp/screen.cpp", + "-lX11", + "-lXext", + ], + check=True, + ) print("✓ C++ module built successfully") except (subprocess.CalledProcessError, FileNotFoundError) as e: print(f"✗ C++ build failed: {e}") - + # Build Rust print("Building Rust module...") try: @@ -36,66 +46,76 @@ def build_modules(): print("Building Assembly module...") try: # SECURITY FIX: Use full executable paths and proper argument lists - subprocess.run([ - "/usr/bin/nasm", "-f", "elf64", "assembly/hook.asm", - "-o", "assembly/hook.o" - ], check=True) - subprocess.run([ - "/usr/bin/ld", "-shared", "-o", "assembly/hook.so", "assembly/hook.o" - ], check=True) + subprocess.run( + [ + "/usr/bin/nasm", + "-f", + "elf64", + "assembly/hook.asm", + "-o", + "assembly/hook.o", + ], + check=True, + ) + subprocess.run( + ["/usr/bin/ld", "-shared", "-o", "assembly/hook.so", "assembly/hook.o"], + check=True, + ) print("✓ Assembly module built successfully") except (subprocess.CalledProcessError, FileNotFoundError) as e: print(f"✗ Assembly build failed: {e}") - + print("--- Build complete ---") + def run_demo(): """Runs a demonstration of the screen control system.""" print("--- Screen Control Demo ---") - + # Initialize screen control with fallback support try: screen = ScreenControl(".") print("✓ Screen control initialized") - + # Test basic functionality print("Testing mouse movement...") screen.move_mouse(100, 100) - + print("Testing text input...") screen.type_text("Hello from NeuroRift!") - + print("Testing scroll...") screen.scroll(1) - + print("Testing offset calculation...") result = screen.calculate_offset(100, 50) print(f"Offset calculation result: {result}") - + # Test JSON command execution commands = [ {"type": "move_mouse", "x": 200, "y": 200}, {"type": "click", "button": 1}, {"type": "type", "text": "AI-controlled screen interaction"}, {"type": "wait", "seconds": 1}, - {"type": "scroll", "direction": -1} + {"type": "scroll", "direction": -1}, ] - + print("Testing command sequence...") screen.run_sequence(commands) - + print("✓ Demo completed successfully") - + except Exception as e: print(f"✗ Demo failed: {e}") print("Running in fallback mode with basic functionality") + if __name__ == "__main__": # Try to build modules, but continue if it fails try: build_modules() except Exception as e: print(f"Build failed, continuing with fallback: {e}") - + # Run demo - run_demo() \ No newline at end of file + run_demo() diff --git a/scripts/analyze_results.py b/scripts/analyze_results.py index 0d0bcd9..7bc4cc7 100755 --- a/scripts/analyze_results.py +++ b/scripts/analyze_results.py @@ -15,89 +15,94 @@ # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) + class ResultAnalyzer: def __init__(self, base_dir: Path): self.base_dir = base_dir self.logger = logging.getLogger(__name__) - + # Initialize components with proper config path config_path = base_dir / "config" / "config.json" self.config = ConfigManager(config_path) self.llm = LLMEngine() self.context = ContextBuilder(base_dir) self.notifier = Notifier(self.config) - + async def analyze_scan_results(self, results_file: Path): """Analyze scan results using AI""" try: # Load scan results with open(results_file) as f: results = json.load(f) - + # Add scan results to context self.context.add_scan_result("recon", results) - + # Build analysis prompt prompt = self.context.build_prompt( "recon", { "target_info": json.dumps(results["target"], indent=2), - "context": self.context._format_context() - } + "context": self.context._format_context(), + }, ) - + # Get AI analysis analysis = self.llm.query( prompt=prompt, - system_prompt="You are a cybersecurity expert analyzing reconnaissance data." + system_prompt="You are a cybersecurity expert analyzing reconnaissance data.", ) - + if not analysis: logger.error("Failed to get AI analysis") return - + # Parse analysis try: analysis_data = json.loads(analysis) - + # Check for critical findings if analysis_data.get("critical_findings"): await self.notifier.notify( "Critical findings detected in scan", "critical", - data=analysis_data["critical_findings"] + data=analysis_data["critical_findings"], ) - + # Save analysis - output_path = self.base_dir / "data" / "analysis" / f"analysis_{Path(results_file).stem}.json" + output_path = ( + self.base_dir + / "data" + / "analysis" + / f"analysis_{Path(results_file).stem}.json" + ) output_path.parent.mkdir(parents=True, exist_ok=True) - - with open(output_path, 'w') as f: + + with open(output_path, "w") as f: json.dump(analysis_data, f, indent=2) - + logger.info(f"Analysis saved to {output_path}") - + # Generate exploit if critical vulnerability found if analysis_data.get("critical_findings"): await self.generate_exploits(analysis_data) - + except json.JSONDecodeError as e: logger.error(f"Failed to parse AI analysis: {e}") - + except Exception as e: logger.error(f"Error analyzing results: {e}") - + async def generate_exploits(self, analysis_data: Dict): """Generate exploits for critical findings""" from modules.exploit_generator.exploit_generator import ExploitGenerator - + generator = ExploitGenerator(self.base_dir, self.llm) - + for finding in analysis_data.get("critical_findings", []): try: # Generate exploit @@ -106,24 +111,26 @@ async def generate_exploits(self, analysis_data: Dict): "cve_id": finding.get("cve_id", "Unknown"), "description": finding.get("description", ""), "affected_software": finding.get("affected_component", ""), - "cvss_score": finding.get("severity", "Unknown") + "cvss_score": finding.get("severity", "Unknown"), }, recon_data={ "ip": finding.get("location", ""), "port": finding.get("port", ""), "service": finding.get("service", ""), - "version": finding.get("version", "") - } + "version": finding.get("version", ""), + }, ) - + if exploit_data and "error" not in exploit_data: # Validate exploit validation = generator.validate_exploit(exploit_data) - + if validation["syntax_valid"] and validation["has_error_handling"]: # Generate Metasploit module - metasploit_module = generator.generate_metasploit_module(exploit_data) - + metasploit_module = generator.generate_metasploit_module( + exploit_data + ) + # Notify about exploit generation await self.notifier.notify( f"Exploit generated for {finding.get('type', 'Unknown')} vulnerability", @@ -131,29 +138,36 @@ async def generate_exploits(self, analysis_data: Dict): data={ "finding": finding, "exploit_path": exploit_data.get("file_path"), - "metasploit_module": str(metasploit_module) if metasploit_module else None, - "validation": validation - } + "metasploit_module": ( + str(metasploit_module) + if metasploit_module + else None + ), + "validation": validation, + }, ) - + except Exception as e: logger.error(f"Error generating exploit for finding: {e}") - + + async def main(): base_dir = Path(__file__).parent.parent analyzer = ResultAnalyzer(base_dir) - + # Find latest scan results results_dir = base_dir / "data" / "scan_results" if not results_dir.exists(): logger.error("No scan results found") return - - latest_result = max(results_dir.glob("recon_*.json"), key=lambda p: p.stat().st_mtime) - + + latest_result = max( + results_dir.glob("recon_*.json"), key=lambda p: p.stat().st_mtime + ) + # Start notifier await analyzer.notifier.start() - + try: # Analyze results await analyzer.analyze_scan_results(latest_result) @@ -161,6 +175,8 @@ async def main(): # Stop notifier await analyzer.notifier.stop() + if __name__ == "__main__": import asyncio - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/scripts/performance_analysis.py b/scripts/performance_analysis.py index 715490a..4f13db6 100644 --- a/scripts/performance_analysis.py +++ b/scripts/performance_analysis.py @@ -23,159 +23,164 @@ console = Console() + class PerformanceAnalyzer: def __init__(self): self.console = Console() self.results = {} - + def analyze_ai_reasoning(self) -> Dict[str, Any]: """Analyze AI reasoning capabilities - planning vs guessing""" console.print("[bold blue]🔍 Analyzing AI Reasoning Capabilities[/bold blue]") - + analysis = { "planning_quality": {}, "tool_selection_accuracy": {}, "execution_success_rate": {}, - "analysis_depth": {} + "analysis_depth": {}, } - + # Test cases for AI reasoning test_tasks = [ "Perform a port scan on example.com", "Find subdomains of test.com", "Generate a web vulnerability scanner", - "Analyze the security of a web application" + "Analyze the security of a web application", ] - + for task in test_tasks: console.print(f"Testing AI reasoning for: {task}") - + # Measure planning time and quality start_time = time.time() # TODO: Implement actual AI testing planning_time = time.time() - start_time - + analysis["planning_quality"][task] = { "time": planning_time, "complexity_score": self._assess_planning_complexity(task), - "specificity_score": self._assess_planning_specificity(task) + "specificity_score": self._assess_planning_specificity(task), } - + return analysis - + def audit_native_modules(self) -> Dict[str, Any]: """Audit C++, Rust, and Assembly code for safety and efficiency""" console.print("[bold blue]🛡️ Auditing Native Modules[/bold blue]") - + audit_results = { "cpp_safety": {}, "rust_safety": {}, "assembly_safety": {}, "performance_benchmarks": {}, - "security_vulnerabilities": [] + "security_vulnerabilities": [], } - + # C++ Module Analysis cpp_issues = self._audit_cpp_module() audit_results["cpp_safety"] = cpp_issues - + # Rust Module Analysis rust_issues = self._audit_rust_module() audit_results["rust_safety"] = rust_issues - + # Assembly Module Analysis asm_issues = self._audit_assembly_module() audit_results["assembly_safety"] = asm_issues - + # Performance Benchmarks perf_results = self._benchmark_native_modules() audit_results["performance_benchmarks"] = perf_results - + return audit_results - + def analyze_tool_generation(self) -> Dict[str, Any]: """Analyze custom tool generation quality and capabilities""" console.print("[bold blue]🔧 Analyzing Custom Tool Generation[/bold blue]") - + analysis = { "generation_success_rate": 0, "tool_types_generated": {}, "code_quality_metrics": {}, "security_analysis": {}, - "execution_success_rate": 0 + "execution_success_rate": 0, } - + # Test tool generation requests test_requests = [ "Create a port scanner", "Generate a web crawler", "Build a DNS enumeration tool", "Create a vulnerability scanner", - "Generate a password cracker" + "Generate a password cracker", ] - + successful_generations = 0 tool_types = {} - + for request in test_requests: console.print(f"Testing tool generation: {request}") - + # TODO: Implement actual tool generation testing generation_result = self._test_tool_generation(request) - + if generation_result["success"]: successful_generations += 1 tool_type = generation_result["type"] tool_types[tool_type] = tool_types.get(tool_type, 0) + 1 - - analysis["generation_success_rate"] = successful_generations / len(test_requests) + + analysis["generation_success_rate"] = successful_generations / len( + test_requests + ) analysis["tool_types_generated"] = tool_types - + return analysis - + def benchmark_recon_speed(self) -> Dict[str, Any]: """Benchmark reconnaissance speed and measure the "3x faster" claim""" console.print("[bold blue]⚡ Benchmarking Reconnaissance Speed[/bold blue]") - + benchmarks = { "subdomain_discovery": {}, "port_scanning": {}, "vulnerability_scanning": {}, "overall_performance": {}, - "native_vs_python": {} + "native_vs_python": {}, } - + test_targets = ["example.com", "test.com", "demo.com"] - + for target in test_targets: console.print(f"Benchmarking recon for: {target}") - + # Test with native modules native_times = self._benchmark_native_recon(target) - + # Test with Python fallback python_times = self._benchmark_python_recon(target) - + # Calculate speedup speedup = {} for operation in native_times: if operation in python_times: - speedup[operation] = python_times[operation] / native_times[operation] - + speedup[operation] = ( + python_times[operation] / native_times[operation] + ) + benchmarks["native_vs_python"][target] = speedup - + return benchmarks - + def _assess_planning_complexity(self, task: str) -> float: """Assess the complexity of AI-generated plans""" # TODO: Implement complexity scoring return 0.75 # Placeholder - + def _assess_planning_specificity(self, task: str) -> float: """Assess the specificity of AI-generated plans""" # TODO: Implement specificity scoring return 0.80 # Placeholder - + def _audit_cpp_module(self) -> Dict[str, Any]: """Audit C++ module for safety issues""" issues = { @@ -183,9 +188,9 @@ def _audit_cpp_module(self) -> Dict[str, Any]: "null_pointer_derefs": [], "buffer_overflows": [], "race_conditions": [], - "security_score": 0.85 + "security_score": 0.85, } - + # Analyze screen.cpp cpp_code = """ #include @@ -203,65 +208,61 @@ def _audit_cpp_module(self) -> Dict[str, Any]: } } """ - + # Check for potential issues if "XOpenDisplay(NULL)" in cpp_code: issues["null_pointer_derefs"].append("Potential NULL display handling") - + if "XCloseDisplay" in cpp_code: issues["memory_leaks"].append("Display cleanup looks good") - + return issues - + def _audit_rust_module(self) -> Dict[str, Any]: """Audit Rust module for safety issues""" issues = { "memory_safety": [], "thread_safety": [], "unsafe_blocks": [], - "security_score": 0.95 + "security_score": 0.95, } - + # Rust is generally safer issues["memory_safety"].append("Rust provides memory safety guarantees") issues["unsafe_blocks"].append("Minimal unsafe code usage") - + return issues - + def _audit_assembly_module(self) -> Dict[str, Any]: """Audit Assembly module for safety issues""" - issues = { - "register_usage": [], - "stack_management": [], - "security_score": 0.70 - } - + issues = {"register_usage": [], "stack_management": [], "security_score": 0.70} + # Assembly is inherently less safe issues["register_usage"].append("Simple register operations") issues["stack_management"].append("No stack manipulation - safe") - + return issues - + def _benchmark_native_modules(self) -> Dict[str, float]: """Benchmark native module performance""" benchmarks = {} - + # Test C++ mouse movement start_time = time.time() for _ in range(1000): # TODO: Call actual C++ function pass benchmarks["cpp_mouse_movement"] = time.time() - start_time - + # Test Rust text input start_time = time.time() for _ in range(1000): # TODO: Call actual Rust function pass benchmarks["rust_text_input"] = time.time() - start_time - + return benchmarks - + def _test_tool_generation(self, request: str) -> Dict[str, Any]: """Test tool generation for a specific request""" # TODO: Implement actual tool generation testing @@ -269,59 +270,63 @@ def _test_tool_generation(self, request: str) -> Dict[str, Any]: "success": True, "type": "python_script", "execution_success": True, - "security_score": 0.80 + "security_score": 0.80, } - + def _benchmark_native_recon(self, target: str) -> Dict[str, float]: """Benchmark reconnaissance with native modules""" times = {} - + # Simulate native module performance times["subdomain_discovery"] = 2.5 times["port_scanning"] = 1.8 times["vulnerability_scanning"] = 3.2 - + return times - + def _benchmark_python_recon(self, target: str) -> Dict[str, float]: """Benchmark reconnaissance with Python fallback""" times = {} - + # Simulate Python fallback performance times["subdomain_discovery"] = 7.5 times["port_scanning"] = 5.4 times["vulnerability_scanning"] = 9.6 - + return times - + def generate_report(self) -> str: """Generate comprehensive performance report""" console.print("[bold green]📊 Generating Performance Report[/bold green]") - + # Run all analyses ai_analysis = self.analyze_ai_reasoning() native_audit = self.audit_native_modules() tool_analysis = self.analyze_tool_generation() speed_benchmarks = self.benchmark_recon_speed() - + # Compile results report = { "ai_reasoning_analysis": ai_analysis, "native_module_audit": native_audit, "tool_generation_analysis": tool_analysis, "speed_benchmarks": speed_benchmarks, - "summary": self._generate_summary(ai_analysis, native_audit, tool_analysis, speed_benchmarks) + "summary": self._generate_summary( + ai_analysis, native_audit, tool_analysis, speed_benchmarks + ), } - + # Save report report_path = Path("performance_report.json") - with open(report_path, 'w') as f: + with open(report_path, "w") as f: json.dump(report, f, indent=2) - + console.print(f"[green]Report saved to: {report_path}[/green]") return str(report_path) - - def _generate_summary(self, ai_analysis, native_audit, tool_analysis, speed_benchmarks): + + def _generate_summary( + self, ai_analysis, native_audit, tool_analysis, speed_benchmarks + ): """Generate executive summary of all analyses""" summary = { "ai_reasoning_quality": "Good planning capabilities with room for improvement", @@ -332,22 +337,23 @@ def _generate_summary(self, ai_analysis, native_audit, tool_analysis, speed_benc "Enhance AI planning with more specific prompts", "Add more comprehensive error handling to native modules", "Improve tool generation validation", - "Implement more detailed performance monitoring" - ] + "Implement more detailed performance monitoring", + ], } - + return summary + def main(): """Main function to run comprehensive performance analysis""" analyzer = PerformanceAnalyzer() - + console.print("[bold yellow]🚀 NeuroRift Performance Analysis[/bold yellow]") console.print("=" * 60) - + # Run analysis report_path = analyzer.generate_report() - + console.print("\n[bold green]✅ Analysis Complete![/bold green]") console.print(f"📄 Full report: {report_path}") console.print("\n[bold blue]Key Findings:[/bold blue]") @@ -356,9 +362,6 @@ def main(): console.print("• Tool generation has good success rate") console.print("• Performance improvements are measurable") + if __name__ == "__main__": main() - - - - diff --git a/scripts/recon_scan.py b/scripts/recon_scan.py index 97fe6e9..a3e2c0a 100755 --- a/scripts/recon_scan.py +++ b/scripts/recon_scan.py @@ -16,124 +16,130 @@ # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) + class ReconScanner: def __init__(self, config_path: Path): self.config = self._load_config(config_path) self.results = { "scan_time": datetime.now().isoformat(), "target": self.config["target"], - "findings": [] + "findings": [], } - + def _load_config(self, config_path: Path) -> Dict: """Load scan configuration""" with open(config_path) as f: return json.load(f) - + async def run_nmap_scan(self): """Run Nmap scan""" try: - cmd = ["nmap"] + self.config["tools"]["nmap"]["flags"] + [self.config["target"]["domain"]] + cmd = ( + ["nmap"] + + self.config["tools"]["nmap"]["flags"] + + [self.config["target"]["domain"]] + ) process = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() - + if process.returncode == 0: - self.results["findings"].append({ - "tool": "nmap", - "output": stdout.decode(), - "timestamp": datetime.now().isoformat() - }) + self.results["findings"].append( + { + "tool": "nmap", + "output": stdout.decode(), + "timestamp": datetime.now().isoformat(), + } + ) else: logger.error("Nmap scan failed: %s", stderr.decode()) - + except Exception as e: logger.error("Error running Nmap scan: %s", e) - + async def run_subfinder(self): """Run Subfinder for subdomain enumeration""" try: cmd = ["subfinder", "-d", self.config["target"]["domain"], "-silent"] process = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() - + if process.returncode == 0: subdomains = stdout.decode().splitlines() - self.results["findings"].append({ - "tool": "subfinder", - "subdomains": subdomains, - "timestamp": datetime.now().isoformat() - }) + self.results["findings"].append( + { + "tool": "subfinder", + "subdomains": subdomains, + "timestamp": datetime.now().isoformat(), + } + ) else: logger.error("Subfinder failed: %s", stderr.decode()) - + except Exception as e: logger.error("Error running Subfinder: %s", e) - + async def run_httpx(self, urls: List[str]): """Run httpx for HTTP probing""" try: # Write URLs to temporary file temp_file = Path("temp_urls.txt") temp_file.write_text("\n".join(urls)) - + cmd = ["httpx", "-l", str(temp_file), "-silent", "-status-code", "-title"] process = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() - + if process.returncode == 0: - self.results["findings"].append({ - "tool": "httpx", - "output": stdout.decode(), - "timestamp": datetime.now().isoformat() - }) + self.results["findings"].append( + { + "tool": "httpx", + "output": stdout.decode(), + "timestamp": datetime.now().isoformat(), + } + ) else: logger.error("httpx failed: %s", stderr.decode()) - + # Clean up temp file temp_file.unlink() - + except Exception as e: logger.error("Error running httpx: %s", e) - + async def run_nuclei(self, urls: List[str]): """Run Nuclei for vulnerability scanning""" # SECURITY FIX: Use tempfile module instead of hardcoded temp paths - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file: - temp_file.write('\n'.join(urls)) + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False + ) as temp_file: + temp_file.write("\n".join(urls)) temp_file_path = temp_file.name - + try: cmd = ["nuclei", "-l", temp_file_path, "-severity", "critical,high,medium"] process = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() - + if process.returncode == 0: - self.results["findings"].append({ - "tool": "nuclei", - "output": stdout.decode(), - "timestamp": datetime.now().isoformat() - }) + self.results["findings"].append( + { + "tool": "nuclei", + "output": stdout.decode(), + "timestamp": datetime.now().isoformat(), + } + ) else: logger.error("Nuclei failed: %s", stderr.decode()) except Exception as e: @@ -144,47 +150,52 @@ async def run_nuclei(self, urls: List[str]): os.unlink(temp_file_path) except OSError: pass # File may already be deleted - + def save_results(self, output_path: Path): """Save scan results to file""" - with open(output_path, 'w') as f: + with open(output_path, "w") as f: json.dump(self.results, f, indent=2) - + + async def main(): # Initialize scanner scanner = ReconScanner(Path("configs/scan_config.json")) - + # Run scans logger.info("Starting reconnaissance scan...") - + # Run Nmap scan logger.info("Running Nmap scan...") await scanner.run_nmap_scan() - + # Run Subfinder logger.info("Running Subfinder...") await scanner.run_subfinder() - + # Get all discovered URLs urls = [scanner.config["target"]["url"]] for finding in scanner.results["findings"]: if finding["tool"] == "subfinder": urls.extend([f"https://{subdomain}" for subdomain in finding["subdomains"]]) - + # Run httpx logger.info("Running httpx...") await scanner.run_httpx(urls) - + # Run Nuclei logger.info("Running Nuclei...") await scanner.run_nuclei(urls) - + # Save results - output_path = Path("data/scan_results") / f"recon_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + output_path = ( + Path("data/scan_results") + / f"recon_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + ) output_path.parent.mkdir(parents=True, exist_ok=True) scanner.save_results(output_path) - + logger.info("Scan completed. Results saved to %s", output_path) + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/setup.py b/setup.py index bb0bf16..a75ab5e 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,16 @@ setup( name="neurorift", version="1.0.0", - packages=find_packages(include=['modules', 'modules.*', 'utils', 'utils.*', 'ai_wrapper', 'ai_wrapper.*']), + packages=find_packages( + include=[ + "modules", + "modules.*", + "utils", + "utils.*", + "ai_wrapper", + "ai_wrapper.*", + ] + ), py_modules=[ # Top-level modules used by the CLI "neurorift_main", @@ -13,9 +22,9 @@ "ai_controller", ], package_data={ - 'modules': ['**/*.json', '**/*.md'], - 'prompts': ['**/*.md', '**/*.txt'], - 'configs': ['*.json'], + "modules": ["**/*.json", "**/*.md"], + "prompts": ["**/*.md", "**/*.txt"], + "configs": ["*.json"], }, include_package_data=True, install_requires=[ @@ -76,4 +85,4 @@ ], python_requires=">=3.10", keywords="security penetration-testing bug-bounty ai multi-agent reconnaissance vulnerability-scanning", -) \ No newline at end of file +) diff --git a/tests/test_ai_features.py b/tests/test_ai_features.py index d8e6925..cfa1bf5 100644 --- a/tests/test_ai_features.py +++ b/tests/test_ai_features.py @@ -8,6 +8,7 @@ from ai_controller import AIController from utils.report_generator import ReportGenerator + class TestAIFeatures: """Test suite for AI features.""" @@ -16,24 +17,25 @@ def setup(self, tmp_path): """Setup test environment.""" self.test_dir = tmp_path / "test_session" self.test_dir.mkdir(parents=True) - + # Create test config self.config_path = self.test_dir / "test_config.json" with open(self.config_path, "w") as f: - json.dump({ - "ai": { - "model": "deepseek-coder", - "temperature": 0.7, - "max_tokens": 1000 + json.dump( + { + "ai": { + "model": "deepseek-coder", + "temperature": 0.7, + "max_tokens": 1000, + }, + "notifications": {"enabled": False}, }, - "notifications": { - "enabled": False - } - }, f) - + f, + ) + # Initialize AI controller self.ai_controller = AIController(str(self.test_dir), str(self.config_path)) - + # Initialize report generator self.report_generator = ReportGenerator(str(self.test_dir)) @@ -43,20 +45,20 @@ async def test_ai_query_processing(self): # Test basic query query = "What should I do if port 8080 is open?" response = await self.ai_controller.process_query(query) - + assert "answer" in response assert "logs" in response assert "prompt" in response assert isinstance(response["answer"], str) - + # Test query with context context = { "open_ports": [8080], "services": {"8080": "http-alt"}, - "previous_findings": ["Potential web server detected"] + "previous_findings": ["Potential web server detected"], } response = await self.ai_controller.process_query(query, context) - + assert "answer" in response assert "logs" in response assert "prompt" in response @@ -70,7 +72,7 @@ def test_report_generation(self): "domain": "example.com", "ip": "93.184.216.34", "scan_time": datetime.now().isoformat(), - "duration": "00:05:23" + "duration": "00:05:23", }, "modules": [ { @@ -78,7 +80,7 @@ def test_report_generation(self): "status": "success", "start_time": "2024-03-20T10:00:00", "end_time": "2024-03-20T10:05:00", - "findings": ["Open port 80", "Open port 443"] + "findings": ["Open port 80", "Open port 443"], } ], "tools": { @@ -86,7 +88,7 @@ def test_report_generation(self): "status": "success", "purpose": "Port scanning", "config": "-sS -sV", - "installed": True + "installed": True, } }, "vulnerabilities": [ @@ -99,9 +101,9 @@ def test_report_generation(self): "references": [ { "title": "CVE-2024-1234", - "url": "https://nvd.nist.gov/vuln/detail/CVE-2024-1234" + "url": "https://nvd.nist.gov/vuln/detail/CVE-2024-1234", } - ] + ], } ], "exploits": [ @@ -109,54 +111,54 @@ def test_report_generation(self): "name": "Test Exploit", "status": "success", "command": "test_command", - "output": "Test output" + "output": "Test output", } ], "defensive_measures": [ { "name": "WAF Detection", "description": "Web Application Firewall detected", - "details": "Cloudflare WAF" + "details": "Cloudflare WAF", } ], "recommendations": [ { "title": "Test Recommendation", "description": "Test recommendation description", - "steps": ["Step 1", "Step 2"] + "steps": ["Step 1", "Step 2"], } ], "errors": [ { "timestamp": datetime.now().isoformat(), "message": "Test error", - "context": "Test error context" + "context": "Test error context", } - ] + ], } - + # Generate reports report_paths = self.report_generator.generate_reports(context) - + # Verify report files exist assert "html" in report_paths assert "markdown" in report_paths assert "json" in report_paths - + # Verify HTML report html_path = Path(report_paths["html"]) assert html_path.exists() html_content = html_path.read_text() assert "NeuroRift Report" in html_content assert "example.com" in html_content - + # Verify Markdown report md_path = Path(report_paths["markdown"]) assert md_path.exists() md_content = md_path.read_text() assert "# NeuroRift Report" in md_content assert "example.com" in md_content - + # Verify JSON report json_path = Path(report_paths["json"]) assert json_path.exists() @@ -168,22 +170,20 @@ def test_session_data_management(self): # Add test data self.ai_controller.session_data["target_domain"] = "example.com" self.ai_controller.session_data["target_ip"] = "93.184.216.34" - self.ai_controller.session_data["modules"].append({ - "name": "test_module", - "status": "success" - }) - self.ai_controller.session_data["findings"].append({ - "type": "vulnerability", - "severity": "high" - }) - + self.ai_controller.session_data["modules"].append( + {"name": "test_module", "status": "success"} + ) + self.ai_controller.session_data["findings"].append( + {"type": "vulnerability", "severity": "high"} + ) + # Generate report report_paths = self.ai_controller._generate_report() - + # Verify report contains session data json_path = Path(report_paths["json"]) json_content = json.loads(json_path.read_text()) - + assert json_content["target"]["domain"] == "example.com" assert json_content["target"]["ip"] == "93.184.216.34" assert len(json_content["modules"]) > 0 @@ -194,15 +194,16 @@ def test_error_handling(self): # Test invalid query with pytest.raises(Exception): asyncio.run(self.ai_controller.process_query(None)) - + # Test invalid context with pytest.raises(Exception): self.report_generator.generate_reports(None) - + # Test invalid session data self.ai_controller.session_data = None with pytest.raises(Exception): self.ai_controller._generate_report() + if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/test_security.py b/tests/test_security.py index 2da1b61..87d97a4 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -15,66 +15,76 @@ FilePermissionManager, TokenGenerator, validate_target, - sanitize_filename + sanitize_filename, ) from utils.auth import AuthManager, Role, Permission, User, Session class TestSecurityValidator: """Test security validation functions""" - + def test_validate_domain_valid(self): """Test valid domain validation""" assert SecurityValidator.validate_domain("example.com") assert SecurityValidator.validate_domain("sub.example.com") assert SecurityValidator.validate_domain("test.co.uk") - + def test_validate_domain_invalid(self): """Test invalid domain validation""" assert not SecurityValidator.validate_domain("") assert not SecurityValidator.validate_domain("invalid domain") assert not SecurityValidator.validate_domain("../etc/passwd") assert not SecurityValidator.validate_domain(None) - + def test_validate_ip_valid(self): """Test valid IP validation""" assert SecurityValidator.validate_ip("192.168.1.1") assert SecurityValidator.validate_ip("10.0.0.1") assert SecurityValidator.validate_ip("8.8.8.8") - + def test_validate_ip_invalid(self): """Test invalid IP validation""" assert not SecurityValidator.validate_ip("256.1.1.1") assert not SecurityValidator.validate_ip("invalid") assert not SecurityValidator.validate_ip("") assert not SecurityValidator.validate_ip(None) - + def test_sanitize_path_valid(self): """Test path sanitization with valid paths""" with tempfile.TemporaryDirectory() as tmpdir: base_dir = Path(tmpdir) test_file = base_dir / "test.txt" - + result = SecurityValidator.sanitize_path(str(test_file), base_dir=base_dir) assert result is not None assert result == test_file.resolve() - + def test_sanitize_path_traversal(self): """Test path traversal prevention""" with tempfile.TemporaryDirectory() as tmpdir: base_dir = Path(tmpdir) - + # Test various path traversal attempts - assert SecurityValidator.sanitize_path("../etc/passwd", base_dir=base_dir) is None - assert SecurityValidator.sanitize_path("../../root", base_dir=base_dir) is None - assert SecurityValidator.sanitize_path("/etc/passwd", base_dir=base_dir) is None - + assert ( + SecurityValidator.sanitize_path("../etc/passwd", base_dir=base_dir) + is None + ) + assert ( + SecurityValidator.sanitize_path("../../root", base_dir=base_dir) is None + ) + assert ( + SecurityValidator.sanitize_path("/etc/passwd", base_dir=base_dir) + is None + ) + def test_sanitize_command_arg_valid(self): """Test command argument sanitization with valid args""" assert SecurityValidator.sanitize_command_arg("example.com") == "example.com" assert SecurityValidator.sanitize_command_arg("192.168.1.1") == "192.168.1.1" - assert SecurityValidator.sanitize_command_arg("test-file.txt") == "test-file.txt" - + assert ( + SecurityValidator.sanitize_command_arg("test-file.txt") == "test-file.txt" + ) + def test_sanitize_command_arg_injection(self): """Test command injection prevention""" assert SecurityValidator.sanitize_command_arg("test; rm -rf /") is None @@ -82,24 +92,24 @@ def test_sanitize_command_arg_injection(self): assert SecurityValidator.sanitize_command_arg("test && malicious") is None assert SecurityValidator.sanitize_command_arg("$(whoami)") is None assert SecurityValidator.sanitize_command_arg("`whoami`") is None - + def test_sanitize_log_input(self): """Test log input sanitization""" # Test newline removal result = SecurityValidator.sanitize_log_input("test\ninjection") assert "\n" not in result assert "\r" not in result - + # Test ANSI escape code removal result = SecurityValidator.sanitize_log_input("\x1b[31mred text\x1b[0m") assert "\x1b" not in result - + def test_sanitize_html(self): """Test HTML sanitization""" result = SecurityValidator.sanitize_html("") assert "