Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 41 additions & 31 deletions ai_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
from utils.context_builder import ContextBuilder
from utils.notifier import Notifier


class AIController:
"""Controls AI operations and decision making."""

def __init__(self, session_dir: str, config_path: str):
"""Initialize the AI controller.

Args:
session_dir: Directory to store session data
config_path: Path to configuration file
Expand All @@ -35,7 +36,7 @@ def __init__(self, session_dir: str, config_path: str):
self.notifier = Notifier(session_dir, config_path)
self.report_generator = ReportGenerator(self.session_dir)
self.output_format = "all" # Default output format

# Initialize session data
self.session_data = {
"start_time": datetime.now().isoformat(),
Expand All @@ -44,16 +45,18 @@ def __init__(self, session_dir: str, config_path: str):
"findings": [],
"execution_history": [],
"ai_decisions": [],
"errors": []
"errors": [],
}

async def process_query(self, query: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
async def process_query(
self, query: str, context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Process an AI query.

Args:
query: The query to process
context: Optional context data

Returns:
Dictionary containing the response and metadata
"""
Expand All @@ -62,13 +65,13 @@ async def process_query(self, query: str, context: Optional[Dict[str, Any]] = No
raise ValueError("Query must be a non-empty string")

self.logger.info("Processing AI query: %s", query)

try:
# Prepare context
context_str = ""
if context:
context_str = json.dumps(context, indent=2)

# Create prompt
prompt = f"""You are a security expert. Please provide guidance on the following query:

Expand All @@ -85,113 +88,120 @@ async def process_query(self, query: str, context: Optional[Dict[str, Any]] = No
5. Safety considerations

Provide a detailed response:"""

# Get AI response
response = await self._get_ai_response(prompt)

# Log response
self.logger.info("AI response received")

return {
"answer": response,
# Return an empty log list as the logger doesn't expose in-memory logs
"logs": [],
"prompt": prompt
"prompt": prompt,
}

except Exception as e:
self.logger.error("Error processing query: %s", str(e))
raise

async def _get_ai_response(self, prompt: str) -> str:
"""Get response from AI model.

Args:
prompt: The prompt to send to the AI

Returns:
The AI's response
"""
try:
# TODO: Implement actual AI model call
# For now, return a placeholder response
return "This is a placeholder response. AI integration pending."

except Exception as e:
self.logger.error("Error getting AI response: %s", str(e))
raise

def _generate_report(self) -> Dict[str, str]:
"""Generate reports for the current session.

Returns:
Dictionary mapping report types to their file paths
"""
self.logger.info("Generating reports in %s format...", self.output_format)

try:
# Prepare context data
context = {
"target": {
"domain": self.session_data.get("target_domain"),
"ip": self.session_data.get("target_ip"),
"scan_time": self.session_data.get("start_time"),
"duration": self._calculate_duration()
"duration": self._calculate_duration(),
},
"modules": self.session_data["modules"],
"tools": self.session_data["tools"],
"vulnerabilities": self.session_data.get("findings", []),
"exploits": self.session_data.get("execution_history", []),
"defensive_measures": self.session_data.get("defensive_measures", []),
"recommendations": self.session_data.get("recommendations", []),
"errors": self.session_data["errors"]
"errors": self.session_data["errors"],
}

# Generate reports
report_paths = self.report_generator.generate_reports(context, self.output_format)

report_paths = self.report_generator.generate_reports(
context, self.output_format
)

self.logger.info("Reports generated successfully")
return report_paths

except Exception as e:
self.logger.error("Error generating reports: %s", str(e))
raise

def _calculate_duration(self) -> str:
"""Calculate session duration.

Returns:
Formatted duration string
"""
start_time = datetime.fromisoformat(self.session_data["start_time"])
end_time = datetime.now()
duration = end_time - start_time

hours = duration.seconds // 3600
minutes = (duration.seconds % 3600) // 60
seconds = duration.seconds % 60

return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

def setup_ai(self) -> bool:
"""Setup AI environment: check Ollama service and model availability."""
try:
if hasattr(self, 'ollama'):
if hasattr(self, "ollama"):
client = self.ollama
else:
from ai_integration import OllamaClient

client = OllamaClient()
if not client.is_available():
self.logger.error("Ollama service not available. Start with: ollama serve")
self.logger.error(
"Ollama service not available. Start with: ollama serve"
)
return False
models = client.list_models()
if not models:
self.logger.info("No models found. Pulling default model...")
if not client.pull_model(client.main_model):
self.logger.error("Failed to pull default model")
return False
self.logger.info("AI setup complete. Available models: %s", [m['name'] for m in models])
self.logger.info(
"AI setup complete. Available models: %s", [m["name"] for m in models]
)
return True
except Exception as e:
self.logger.error("Error in setup_ai: %s", e)
return False
return False
Loading