diff --git a/src/cloudwatch-appsignals-mcp-server/README.md b/src/cloudwatch-appsignals-mcp-server/README.md
index b3761df1cf..8581a3b588 100644
--- a/src/cloudwatch-appsignals-mcp-server/README.md
+++ b/src/cloudwatch-appsignals-mcp-server/README.md
@@ -588,10 +588,35 @@ analyze_canary_failures(canary_name="webapp-erorrpagecanary")
 
 ## Recommended Workflows
 
-### 🎯 Primary Audit Workflow (Most Common)
+### 🎯 Primary Service Audit Workflow (Interactive Batch Processing)
+
+#### **For Large Service Lists (>10 services) - Interactive Batch Mode:**
 1. **Start with `audit_services()`** - Use wildcard patterns for automatic service discovery
-2. **Review findings summary** - Let user choose which issues to investigate further
-3. **Deep dive with `auditors="all"`** - For selected services needing root cause analysis
+   ```
+   audit_services(service_targets='[{"Type":"service","Data":{"Service":{"Type":"Service","Name":"*"}}}]')
+   ```
+2. **System automatically starts batch processing** - Processes the first batch of 10 services
+3. **🚨 CRITICAL: When findings are discovered in a batch:**
+   - **STOP processing immediately**
+   - **Present complete audit findings** to the user in a clear summary
+   - **ALWAYS ASK USER TO CHOOSE:**
+     - **Option A:** Investigate a specific finding with `auditors="all"`
+     - **Option B:** Continue processing with `continue_audit_batch(session_id)`
+   - **WAIT for user decision** - Never auto-continue when findings exist
+4. **✅ When a batch has NO findings (all services healthy):**
+   - **Auto-continue** to the next batch with `continue_audit_batch(session_id)`
+5. **Repeat** the batch processing cycle with a user choice at each step when findings exist
+6. **When all services are processed** - Summarize audit results from all batches
+
+#### **For Small Service Lists (≤10 services) - Direct Processing:**
+1. **Start with `audit_services()`** - Processes all services immediately
+2. **Present all audit results** showing a summary of findings
+3. **🚨 IF FINDINGS EXIST:** Ask the user which specific finding to investigate
+4. **WAIT for user decision** before performing targeted root cause analysis
+5. **Targeted investigation** - Use `auditors="all"` for the user-selected finding only
+
+#### **Available Batch Management Tools:**
+- **`continue_audit_batch(session_id)`** - Continue to the next batch in an active session
 
 ### 🔍 SLO Investigation Workflow
 1. **Use `get_slo()`** - Understand SLO configuration and thresholds
diff --git a/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/batch_processing_utils.py b/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/batch_processing_utils.py
new file mode 100644
index 0000000000..3e33d67cb7
--- /dev/null
+++ b/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/batch_processing_utils.py
@@ -0,0 +1,230 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
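To make the interactive batching flow from the README hunk above concrete, here is a minimal editorial sketch that drives the session utilities introduced in `batch_processing_utils.py` (the new file that begins below) against a stubbed Application Signals client. The import path and function signatures are taken from this patch; the stub client, service names, and time window are illustrative assumptions, and the sketch presumes the package from this PR is importable locally. It is not part of the patch itself.

```python
# Hypothetical driver for the batch-audit session utilities added in this patch.
# The MagicMock client and svc-* target names are illustrative only.
from datetime import datetime, timezone
from unittest.mock import MagicMock

from awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils import (
    create_batch_session,
    format_batch_result,
    get_batch_session,
    process_next_batch,
)

# Stub client: every batch reports zero findings (all services healthy).
client = MagicMock()
client.list_audit_findings.return_value = {'AuditFindings': []}

targets = [
    {'Type': 'service', 'Data': {'Service': {'Type': 'Service', 'Name': f'svc-{i}'}}}
    for i in range(12)
]
now = datetime.now(timezone.utc).timestamp()
input_obj = {'StartTime': now - 3600, 'EndTime': now, 'Auditors': 'slo,operation_metric'}

# 12 targets with batch_size=5 produce three batches of 5, 5, and 2 targets.
session_id = create_batch_session(
    targets=targets, input_obj=input_obj, region='us-east-1', banner='demo', batch_size=5
)

session = get_batch_session(session_id)
while session and session['status'] != 'completed':
    result = process_next_batch(session_id, client)
    print(format_batch_result(result, session))
    session = get_batch_session(session_id)
```

Each healthy, non-final batch prints a `continue_audit_batch('<session-id>')` hint, mirroring what the MCP tool returns to the model; a batch with findings would instead dump the findings as JSON so the caller can stop and ask the user how to proceed.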
+
+"""Utilities for interactive batch processing of audit operations."""
+
+import json
+import uuid
+from datetime import datetime, timezone
+from loguru import logger
+from typing import Any, Dict, List, Optional
+
+
+# Global storage for batch sessions (in production, this would be in a database)
+_batch_sessions: Dict[str, Dict[str, Any]] = {}
+
+# Configuration for session cache limits
+MAX_BATCH_SESSIONS = 1
+
+
+def create_batch_session(
+    targets: List[Dict[str, Any]],
+    input_obj: Dict[str, Any],
+    region: str,
+    banner: str,
+    batch_size: int = 10,
+    auto_complete: Optional[bool] = None,
+) -> str:
+    """Create a new batch processing session.
+
+    Args:
+        targets: List of all targets to process
+        input_obj: Base input object for API calls
+        region: AWS region
+        banner: Banner text for display
+        batch_size: Number of targets per batch
+        auto_complete: If True, process all batches automatically. If False, use interactive mode.
+            If None, auto-decide based on target count.
+
+    Returns:
+        Session ID for tracking the batch processing
+    """
+    session_id = str(uuid.uuid4())
+
+    # Auto-decide batch processing mode if not specified
+    if auto_complete is None:
+        auto_complete = len(targets) <= batch_size  # Auto-complete for small lists
+
+    # Create batches
+    batches = []
+    for i in range(0, len(targets), batch_size):
+        batch = targets[i : i + batch_size]
+        batches.append(batch)
+
+    now = datetime.now(timezone.utc).isoformat()
+    session = {
+        'session_id': session_id,
+        'created_at': now,
+        'last_activity': now,
+        'targets': targets,
+        'input_obj': input_obj,
+        'batches': batches,
+        'current_batch_index': 0,
+        'processed_batches': [],
+        'failed_batches': [],
+        'all_findings': [],
+        'auto_complete': auto_complete,
+        'status': 'created',
+    }
+
+    _batch_sessions[session_id] = session
+    logger.info(f'Created batch session {session_id} with {len(batches)} batches')
+
+    # Clean up excess sessions if we exceed the limit
+    _cleanup_excess_sessions()
+
+    return session_id
+
+
+def get_batch_session(session_id: str) -> Optional[Dict[str, Any]]:
+    """Get batch session by ID."""
+    return _batch_sessions.get(session_id)
+
+
+def update_batch_session_activity(session_id: str) -> None:
+    """Update last activity timestamp for session."""
+    if session_id in _batch_sessions:
+        _batch_sessions[session_id]['last_activity'] = datetime.now(timezone.utc).isoformat()
+
+
+def _create_batch_metadata(
+    session: Dict[str, Any], current_batch: List[Dict[str, Any]]
+) -> Dict[str, Any]:
+    """Create common batch metadata."""
+    current_index = session['current_batch_index']
+    return {
+        'batch_index': current_index + 1,
+        'total_batches': len(session['batches']),
+        'targets_in_batch': len(current_batch),
+        'targets': current_batch,
+    }
+
+
+def _build_api_input(
+    session: Dict[str, Any], current_batch: List[Dict[str, Any]]
+) -> Dict[str, Any]:
+    """Build API input object for the current batch."""
+    batch_input = {
+        'StartTime': datetime.fromtimestamp(session['input_obj']['StartTime'], tz=timezone.utc),
+        'EndTime': datetime.fromtimestamp(session['input_obj']['EndTime'], tz=timezone.utc),
+        'AuditTargets': current_batch,
+    }
+    if 'Auditors' in session['input_obj']:
+        batch_input['Auditors'] = session['input_obj']['Auditors']
+    return batch_input
+
+
+def _update_session_after_batch(session: Dict[str, Any], batch_result: Dict[str, Any]) -> None:
+    """Update session state after processing a batch."""
+    session['current_batch_index'] += 1
+
+    if batch_result['status'] == 'success':
+        session['processed_batches'].append(batch_result)
+        session['all_findings'].extend(batch_result.get('findings', []))
+    else:
+        session['failed_batches'].append(batch_result)
+
+    # Update overall session status
+    session['status'] = (
+        'completed' if session['current_batch_index'] >= len(session['batches']) else 'in_progress'
+    )
+
+
+def process_next_batch(session_id: str, appsignals_client) -> Dict[str, Any]:
+    """Process the next batch in the session.
+
+    Returns:
+        Dictionary with batch results and session status
+    """
+    session = get_batch_session(session_id)
+    if not session:
+        return {'error': 'Session not found or expired'}
+
+    update_batch_session_activity(session_id)
+
+    current_index = session['current_batch_index']
+    batches = session['batches']
+
+    if current_index >= len(batches):
+        return {'error': 'No more batches to process', 'status': 'completed'}
+
+    current_batch = batches[current_index]
+    batch_metadata = _create_batch_metadata(session, current_batch)
+
+    try:
+        # Build and execute API call
+        batch_input = _build_api_input(session, current_batch)
+        response = appsignals_client.list_audit_findings(**batch_input)
+
+        # Create success result
+        batch_findings = response.get('AuditFindings', [])
+        batch_result = {
+            **batch_metadata,
+            'findings_count': len(batch_findings),
+            'findings': batch_findings,
+            'status': 'success',
+        }
+
+        # Only update session state on success
+        _update_session_after_batch(session, batch_result)
+
+    except Exception as e:
+        # Create error result but DON'T update session state
+        # This allows the same batch to be retried on next call
+        batch_result = {**batch_metadata, 'error': str(e), 'status': 'failed'}
+        logger.warning(
+            f'Batch {batch_metadata["batch_index"]} failed: {str(e)}. Will retry on next call.'
+        )
+
+    return batch_result
+
+
+def _cleanup_excess_sessions() -> None:
+    """Remove oldest sessions if we exceed MAX_BATCH_SESSIONS limit."""
+    global _batch_sessions
+
+    if len(_batch_sessions) <= MAX_BATCH_SESSIONS:
+        return
+
+    # Sort sessions by creation time (oldest first)
+    sessions_by_age = sorted(_batch_sessions.items(), key=lambda x: x[1].get('created_at', ''))
+
+    # Remove oldest sessions until we're under the limit
+    excess_count = len(_batch_sessions) - MAX_BATCH_SESSIONS
+    for i in range(excess_count):
+        session_id, _ = sessions_by_age[i]
+        del _batch_sessions[session_id]
+
+    logger.info(f'Cleaned up {excess_count} excess batch sessions')
+
+
+def format_batch_result(batch_result: Dict[str, Any], session: Dict[str, Any]) -> str:
+    """Format batch processing result for user display with essential information only."""
+    batch_index = batch_result['batch_index']
+    total_batches = batch_result['total_batches']
+
+    if batch_result.get('error'):
+        return f'❌ Batch {batch_index}/{total_batches} failed: {batch_result["error"]}'
+
+    findings_count = len(batch_result.get('findings', []))
+
+    if findings_count == 0:
+        status = f'✅ Batch {batch_index}/{total_batches}: {batch_result["targets_in_batch"]} services healthy'
+        if batch_index < total_batches:
+            status += f" | Continue: continue_audit_batch('{session['session_id']}')"
+        return status
+
+    # Keep full JSON for MCP observation when findings exist
+    findings_json = json.dumps(batch_result['findings'], indent=2, default=str)
+    return f'⚠️ Batch {batch_index}/{total_batches}: {findings_count} findings\n```\n{findings_json}\n```'
diff --git a/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/batch_tools.py b/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/batch_tools.py
new file mode 100644
index 0000000000..2bd2c72f5e
--- /dev/null
+++ b/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/batch_tools.py
@@ -0,0 +1,64 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Batch processing tools for interactive audit workflows."""
+
+from .aws_clients import appsignals_client
+from .batch_processing_utils import (
+    format_batch_result,
+    get_batch_session,
+    process_next_batch,
+)
+from loguru import logger
+from pydantic import Field
+
+
+async def continue_audit_batch(
+    batch_session_id: str = Field(
+        ..., description='Session ID from previous batch processing to continue'
+    ),
+) -> str:
+    """Continue processing the next batch in an active audit session.
+
+    **INTERACTIVE BATCH PROCESSING TOOL**
+    Use this tool to continue processing the next batch of targets in an ongoing audit session.
+ + **WHEN TO USE:** + - **When there are no findings from the last batch** - Services appear healthy, continue to next batch + - **When customer wants to continue processing next batch** + + **RETURNS:** + - Results from the next batch with progress information + - Full JSON findings for MCP observation and service name extraction + - Continuation instructions if more batches remain + - Error message if session is invalid or expired + """ + try: + batch_result = process_next_batch(batch_session_id, appsignals_client) + session = get_batch_session(batch_session_id) + + if batch_result.get('error'): + return f'Error: {batch_result["error"]}' + + if not session: + return 'Error: Session not found or expired' + + # Format and return batch result + formatted_result = format_batch_result(batch_result, session) + + return formatted_result + + except Exception as e: + logger.error(f'Error in continue_audit_batch: {e}', exc_info=True) + return f'Error: {str(e)}' diff --git a/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/server.py b/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/server.py index f650c7e7d1..6cf2ce09ec 100644 --- a/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/server.py +++ b/src/cloudwatch-appsignals-mcp-server/awslabs/cloudwatch_appsignals_mcp_server/server.py @@ -33,6 +33,15 @@ s3_client, synthetics_client, ) +from .batch_processing_utils import ( + create_batch_session, + format_batch_result, + get_batch_session, + process_next_batch, +) +from .batch_tools import ( + continue_audit_batch, +) from .canary_utils import ( analyze_canary_logs_with_time_window, analyze_har_file, @@ -166,6 +175,11 @@ async def audit_services( ) -> str: """PRIMARY SERVICE AUDIT TOOL - The #1 tool for comprehensive AWS service health auditing and monitoring. + **🚨 CRITICAL: When findings are discovered, ALWAYS ASK USER TO CHOOSE:** + - **Option A: Investigate specific finding** (deep dive with auditors="all") + - **Option B: Continue processing** (use continue_audit_batch()) + **DO NOT auto-continue when findings exist - WAIT for user decision** + **IMPORTANT: For operation-specific auditing, use audit_service_operations() as the PRIMARY tool instead.** **USE THIS FIRST FOR ALL SERVICE-LEVEL AUDITING TASKS** @@ -263,11 +277,16 @@ async def audit_services( **TYPICAL SERVICE AUDIT WORKFLOWS:** 1. **Basic Service Audit** (most common): - Call `audit_services()` with service targets - automatically discovers services when using wildcard patterns + - For large service lists (>5 services), uses interactive batch processing + - Returns results of each batch + - If services in batch look normal, continue with `continue_audit_batch()` + - If audit returns findings/issues, ASK USER TO CHOOSE: investigate or continue - Uses default fast auditors (slo,operation_metric) for quick health overview - Supports wildcard patterns like `*` or `*payment*` for automatic service discovery 2. **Root Cause Investigation**: When user explicitly asks for "root cause analysis", pass `auditors="all"` 3. **Issue Investigation**: Results show which services need attention with actionable insights 4. **Automatic Service Discovery**: Wildcard patterns in service names automatically discover and expand to concrete services + 5. 
**Interactive Batch Processing**: For large target lists, process in batches and interactively ask user to proceed **AUDIT RESULTS INCLUDE:** - **Prioritized findings** by severity (critical, warning, info) @@ -279,19 +298,34 @@ async def audit_services( **IMPORTANT: This tool provides comprehensive service audit coverage and should be your first choice for any service auditing task.** **RECOMMENDED WORKFLOW - PRESENT FINDINGS FIRST:** - When the audit returns multiple findings or issues, follow this workflow: + **For Large Service Lists (>5 services):** + 1. **audit_services() automatically starts interactive batch processing** + 2. **System processes first batch and shows full JSON findings** for MCP observation + 3. **🚨 IF BATCH HAS FINDINGS/ISSUES - MANDATORY WORKFLOW:** + 1. **STOP processing immediately** + 2. **Present complete audit findings** in clear summary format + 3. **ASK USER TO CHOOSE:** Investigate specific finding OR continue processing + 4. **WAIT for user decision** - Do not auto-continue + 5. **If investigating:** Follow-up call with `auditors="all"` for user-selected finding only + 6. **If continuing:** Use `continue_audit_batch()` to process next batch + 4. **✅ If batch looks healthy (NO findings):** + - Use `continue_audit_batch(session_id)` to process next batch immediately + 5. **When all services processed, conclude the audit results in all batches** + + **For Small Service Lists (≤5 services):** 1. **Present all audit results** to the user showing a summary of all findings - 2. **Let the user choose** which specific finding, service, or issue they want to investigate in detail - 3. **Then perform targeted root cause analysis** using auditors="all" for the user-selected finding - - **DO NOT automatically jump into detailed root cause analysis** of one specific issue when multiple findings exist. - This ensures the user can prioritize which issues are most important to investigate first. + 2. **🚨 IF FINDINGS EXIST:** Ask user to choose which specific finding to investigate + 3. **WAIT for user decision** before performing targeted root cause analysis + 4. **Then perform targeted root cause analysis** using auditors="all" for user-selected finding **Example workflow:** - - First call: `audit_services()` with default auditors for overview - - Present findings summary to user - - User selects specific service/issue to investigate - - Follow-up call: `audit_services()` with `auditors="all"` for selected service only + 1. First call: `audit_services()` with default auditors for overview → Returns batch session with first batch results and findings + 2. **🚨 IF FINDINGS DISCOVERED:** Present findings summary to user and ask for choice + 3. **WAIT for user decision:** User selects specific service/issue to investigate OR user chooses to continue processing + 4. **If investigating:** Follow-up call: `audit_services()` with `auditors="all"` for selected service only + 5. **If continuing:** `continue_audit_batch(session_id)` → Process next batch + 6. **Repeat:** Continue batch processing cycle with user choice at each step when findings exist + 7. 
**Conclude:** When all services processed, summarize audit results from all batches """ start_time_perf = timer() logger.debug('Starting audit_services (PRIMARY SERVICE AUDIT TOOL)') @@ -393,12 +427,43 @@ async def audit_services( if auditors_list: input_obj['Auditors'] = auditors_list - # Execute audit API using shared utility - result = await execute_audit_api(input_obj, region, banner) + # Interactive Batch Processing Logic + if len(normalized_targets) > BATCH_SIZE_THRESHOLD: + # Create interactive batch session for large target lists + session_id = create_batch_session( + targets=normalized_targets, + input_obj=input_obj, + region=region, + banner=banner, + batch_size=BATCH_SIZE_THRESHOLD, + auto_complete=None, # Auto-decide based on target count + ) - elapsed = timer() - start_time_perf - logger.debug(f'audit_services completed in {elapsed:.3f}s (region={region})') - return result + # Process first batch + batch_result = process_next_batch(session_id, appsignals_client) + session = get_batch_session(session_id) + + if batch_result.get('error'): + return f'Error processing first batch: {batch_result["error"]}' + + # Format and return interactive batch result + if session is not None: + formatted_result = banner + format_batch_result(batch_result, session) + else: + formatted_result = banner + 'Error: Session not found for batch processing' + + elapsed = timer() - start_time_perf + logger.debug( + f'audit_services (interactive batch mode) completed first batch in {elapsed:.3f}s (region={region})' + ) + return formatted_result + else: + # Execute audit API using shared utility for small target lists + result = await execute_audit_api(input_obj, region, banner) + + elapsed = timer() - start_time_perf + logger.debug(f'audit_services completed in {elapsed:.3f}s (region={region})') + return result except Exception as e: logger.error(f'Unexpected error in audit_services: {e}', exc_info=True) @@ -1338,6 +1403,9 @@ async def analyze_canary_failures(canary_name: str, region: str = AWS_REGION) -> mcp.tool()(list_slis) mcp.tool()(analyze_canary_failures) +# Register batch processing tools +mcp.tool()(continue_audit_batch) + def main(): """Run the MCP server.""" diff --git a/src/cloudwatch-appsignals-mcp-server/tests/test_batch_processing.py b/src/cloudwatch-appsignals-mcp-server/tests/test_batch_processing.py new file mode 100644 index 0000000000..342d462428 --- /dev/null +++ b/src/cloudwatch-appsignals-mcp-server/tests/test_batch_processing.py @@ -0,0 +1,249 @@ +"""Tests for batch processing functionality with proper mocking.""" + +import json +import pytest +from awslabs.cloudwatch_appsignals_mcp_server.server import audit_services +from unittest.mock import MagicMock, patch + + +@pytest.fixture +def mock_appsignals_client(): + """Create a properly mocked appsignals client.""" + mock_client = MagicMock() + mock_client.list_audit_findings.return_value = {'AuditFindings': []} + return mock_client + + +@pytest.mark.asyncio +async def test_audit_services_batch_processing_success(mock_appsignals_client): + """Test audit_services triggers batch processing for large target lists.""" + # Create 12 targets to exceed AUDIT_SERVICE_BATCH_SIZE_THRESHOLD (5) + service_targets = json.dumps( + [ + { + 'Type': 'service', + 'Data': { + 'Service': { + 'Type': 'Service', + 'Name': f'test-service-{i}', + 'Environment': 'eks:test', + } + }, + } + for i in range(12) + ] + ) + + with ( + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.service_audit_utils.validate_and_enrich_service_targets' + ) as mock_validate, + 
patch( + 'awslabs.cloudwatch_appsignals_mcp_server.server.appsignals_client', + mock_appsignals_client, + ), + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.aws_clients.appsignals_client', + mock_appsignals_client, + ), + ): + mock_validate.return_value = json.loads(service_targets) + + result = await audit_services( + service_targets=service_targets, start_time=None, end_time=None, auditors='slo' + ) + + # Verify batch processing was triggered + assert '📦 Batching: Processing 12 targets in batches of 5' in result + + +@pytest.mark.asyncio +async def test_audit_services_batch_processing_error(mock_appsignals_client): + """Test audit_services batch processing error handling.""" + # Set up client mock to raise an exception + mock_appsignals_client.list_audit_findings.side_effect = Exception( + 'Failed to process batch due to API error' + ) + + # Create 11 targets to exceed threshold + service_targets = json.dumps( + [ + { + 'Type': 'service', + 'Data': { + 'Service': { + 'Type': 'Service', + 'Name': f'test-service-{i}', + 'Environment': 'eks:test', + } + }, + } + for i in range(11) + ] + ) + + with ( + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.service_audit_utils.validate_and_enrich_service_targets' + ) as mock_validate, + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.server.appsignals_client', + mock_appsignals_client, + ), + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.aws_clients.appsignals_client', + mock_appsignals_client, + ), + ): + mock_validate.return_value = json.loads(service_targets) + + result = await audit_services( + service_targets=service_targets, start_time=None, end_time=None, auditors='slo' + ) + + assert 'Error processing first batch: Failed to process batch due to API error' in result + + +@pytest.mark.asyncio +async def test_audit_services_batch_processing_session_not_found(mock_appsignals_client): + """Test audit_services batch processing when session is not found.""" + # Create 11 targets to exceed threshold + service_targets = json.dumps( + [ + { + 'Type': 'service', + 'Data': { + 'Service': { + 'Type': 'Service', + 'Name': f'test-service-{i}', + 'Environment': 'eks:test', + } + }, + } + for i in range(11) + ] + ) + + with ( + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.service_audit_utils.validate_and_enrich_service_targets' + ) as mock_validate, + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.get_batch_session' + ) as mock_get_session, + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.server.appsignals_client', + mock_appsignals_client, + ), + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.aws_clients.appsignals_client', + mock_appsignals_client, + ), + ): + mock_validate.return_value = json.loads(service_targets) + mock_get_session.return_value = None # Session not found + + result = await audit_services( + service_targets=service_targets, start_time=None, end_time=None, auditors='slo' + ) + + assert 'Session not found' in result + + +@pytest.mark.asyncio +async def test_audit_services_session_not_found_error(mock_appsignals_client): + """Test audit_services when get_batch_session returns None (line 456).""" + service_targets = json.dumps( + [ + { + 'Type': 'service', + 'Data': { + 'Service': { + 'Type': 'Service', + 'Name': f'test-service-{i}', + 'Environment': 'eks:test', + } + }, + } + for i in range(11) # Exceed threshold + ] + ) + + with ( + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.service_audit_utils.validate_and_enrich_service_targets' + ) as mock_validate, + patch( + 
'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.get_batch_session', + return_value=None, + ), + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.server.appsignals_client', + mock_appsignals_client, + ), + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.aws_clients.appsignals_client', + mock_appsignals_client, + ), + ): + mock_validate.return_value = json.loads(service_targets) + + result = await audit_services( + service_targets=service_targets, start_time=None, end_time=None, auditors='slo' + ) + + assert 'Error processing first batch: Session not found or expired' in result + + +@pytest.mark.asyncio +async def test_audit_services_small_target_list_no_batching(mock_appsignals_client): + """Test audit_services with small target list - non-batch path (lines 465-469).""" + service_targets = json.dumps( + [ + { + 'Type': 'service', + 'Data': { + 'Service': { + 'Type': 'Service', + 'Name': 'test-service-1', + 'Environment': 'eks:test', + } + }, + }, + { + 'Type': 'service', + 'Data': { + 'Service': { + 'Type': 'Service', + 'Name': 'test-service-2', + 'Environment': 'eks:test', + } + }, + }, + ] + ) # Only 2 targets, below threshold + + mock_appsignals_client.list_audit_findings.return_value = {'AuditFindings': []} + + with ( + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.service_audit_utils.validate_and_enrich_service_targets' + ) as mock_validate, + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.server.appsignals_client', + mock_appsignals_client, + ), + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.aws_clients.appsignals_client', + mock_appsignals_client, + ), + ): + mock_validate.return_value = json.loads(service_targets) + + result = await audit_services( + service_targets=service_targets, start_time=None, end_time=None, auditors='slo' + ) + + assert '[MCP-SERVICE] Application Signals Service Audit' in result + assert '📦 Batching:' not in result # No batching message + mock_appsignals_client.list_audit_findings.assert_called_once() diff --git a/src/cloudwatch-appsignals-mcp-server/tests/test_batch_processing_utils.py b/src/cloudwatch-appsignals-mcp-server/tests/test_batch_processing_utils.py new file mode 100644 index 0000000000..d5e61381dc --- /dev/null +++ b/src/cloudwatch-appsignals-mcp-server/tests/test_batch_processing_utils.py @@ -0,0 +1,1320 @@ +"""Tests for batch processing utilities.""" + +import pytest +import uuid +from awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils import ( + _batch_sessions, + _build_api_input, + _cleanup_excess_sessions, + _create_batch_metadata, + _update_session_after_batch, + create_batch_session, + format_batch_result, + get_batch_session, + process_next_batch, + update_batch_session_activity, +) +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + + +@pytest.fixture +def sample_targets(): + """Sample targets for testing.""" + return [ + {'Type': 'service', 'Data': {'Service': {'Type': 'Service', 'Name': f'service-{i}'}}} + for i in range(1, 16) # 15 services + ] + + +@pytest.fixture +def sample_input_obj(): + """Sample input object for testing.""" + return { + 'StartTime': datetime.now(timezone.utc).timestamp(), + 'EndTime': datetime.now(timezone.utc).timestamp(), + 'Auditors': 'slo,operation_metric', + } + + +@pytest.fixture +def mock_appsignals_client(): + """Mock Application Signals client.""" + client = MagicMock() + client.list_audit_findings.return_value = { + 'AuditFindings': [ + { + 'FindingId': 'finding-1', + 'Severity': 'CRITICAL', + 'Title': 'High error 
rate detected', + 'Description': 'Service experiencing elevated error rates', + } + ] + } + return client + + +class TestCreateBatchSession: + """Test create_batch_session function.""" + + def test_create_batch_session_basic(self, sample_targets, sample_input_obj): + """Test basic batch session creation.""" + session_id = create_batch_session( + targets=sample_targets[:5], # Small list + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=10, + ) + + assert isinstance(session_id, str) + assert len(session_id) == 36 # UUID length + + session = get_batch_session(session_id) + assert session is not None + assert session['session_id'] == session_id + assert len(session['targets']) == 5 + assert len(session['batches']) == 1 # All targets fit in one batch + assert session['current_batch_index'] == 0 + assert session['auto_complete'] is True # Small list auto-completes + assert session['status'] == 'created' + + def test_create_batch_session_large_list(self, sample_targets, sample_input_obj): + """Test batch session creation with large target list.""" + session_id = create_batch_session( + targets=sample_targets, # 15 services + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=5, + ) + + session = get_batch_session(session_id) + assert session is not None + assert len(session['batches']) == 3 # 15 targets / 5 batch_size = 3 batches + assert session['auto_complete'] is False # Large list uses interactive mode + assert len(session['batches'][0]) == 5 + assert len(session['batches'][1]) == 5 + assert len(session['batches'][2]) == 5 + + def test_create_batch_session_explicit_auto_complete(self, sample_targets, sample_input_obj): + """Test batch session creation with explicit auto_complete setting.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=10, + auto_complete=False, # Force interactive mode + ) + + session = get_batch_session(session_id) + assert session is not None + assert session['auto_complete'] is False + + def test_create_batch_session_without_auditors(self, sample_targets): + """Test batch session creation without auditors in input.""" + input_obj = { + 'StartTime': datetime.now(timezone.utc).timestamp(), + 'EndTime': datetime.now(timezone.utc).timestamp(), + } + + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + assert 'Auditors' not in session['input_obj'] + + def test_create_batch_session_uneven_batches(self, sample_input_obj): + """Test batch session creation with uneven batch sizes.""" + targets = [ + {'Type': 'service', 'Data': {'Service': {'Type': 'Service', 'Name': f'service-{i}'}}} + for i in range(1, 8) # 7 services + ] + + session_id = create_batch_session( + targets=targets, + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=3, + ) + + session = get_batch_session(session_id) + assert session is not None + assert len(session['batches']) == 3 # 7 targets / 3 batch_size = 3 batches + assert len(session['batches'][0]) == 3 + assert len(session['batches'][1]) == 3 + assert len(session['batches'][2]) == 1 # Remainder + + +class TestGetBatchSession: + """Test get_batch_session function.""" + + def test_get_batch_session_exists(self, sample_targets, sample_input_obj): + """Test getting an existing batch 
session.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + assert session['session_id'] == session_id + + def test_get_batch_session_not_exists(self): + """Test getting a non-existent batch session.""" + fake_session_id = str(uuid.uuid4()) + session = get_batch_session(fake_session_id) + assert session is None + + +class TestUpdateBatchSessionActivity: + """Test update_batch_session_activity function.""" + + def test_update_batch_session_activity_exists(self, sample_targets, sample_input_obj): + """Test updating activity for existing session.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + original_activity = session['last_activity'] + + # Small delay to ensure timestamp difference + import time + + time.sleep(0.01) + + update_batch_session_activity(session_id) + + session = get_batch_session(session_id) + assert session is not None + updated_activity = session['last_activity'] + assert updated_activity != original_activity + + def test_update_batch_session_activity_not_exists(self): + """Test updating activity for non-existent session.""" + fake_session_id = str(uuid.uuid4()) + # Should not raise an exception + update_batch_session_activity(fake_session_id) + + +class TestCreateBatchMetadata: + """Test _create_batch_metadata function.""" + + def test_create_batch_metadata(self, sample_targets, sample_input_obj): + """Test creating batch metadata.""" + session_id = create_batch_session( + targets=sample_targets, + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=5, + ) + + session = get_batch_session(session_id) + assert session is not None + current_batch = session['batches'][0] + + metadata = _create_batch_metadata(session, current_batch) + + assert metadata['batch_index'] == 1 # 1-based indexing + assert metadata['total_batches'] == 3 + assert metadata['targets_in_batch'] == 5 + assert metadata['targets'] == current_batch + + +class TestBuildApiInput: + """Test _build_api_input function.""" + + def test_build_api_input_with_auditors(self, sample_targets, sample_input_obj): + """Test building API input with auditors.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + current_batch = session['batches'][0] + + api_input = _build_api_input(session, current_batch) + + assert 'StartTime' in api_input + assert 'EndTime' in api_input + assert 'AuditTargets' in api_input + assert 'Auditors' in api_input + assert api_input['AuditTargets'] == current_batch + assert api_input['Auditors'] == 'slo,operation_metric' + assert isinstance(api_input['StartTime'], datetime) + assert isinstance(api_input['EndTime'], datetime) + + def test_build_api_input_without_auditors(self, sample_targets): + """Test building API input without auditors.""" + input_obj = { + 'StartTime': datetime.now(timezone.utc).timestamp(), + 'EndTime': datetime.now(timezone.utc).timestamp(), + } + + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + 
assert session is not None + current_batch = session['batches'][0] + + api_input = _build_api_input(session, current_batch) + + assert 'StartTime' in api_input + assert 'EndTime' in api_input + assert 'AuditTargets' in api_input + assert 'Auditors' not in api_input + + +class TestUpdateSessionAfterBatch: + """Test _update_session_after_batch function.""" + + def test_update_session_after_batch_success(self, sample_targets, sample_input_obj): + """Test updating session after successful batch processing.""" + session_id = create_batch_session( + targets=sample_targets, + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=5, + ) + + session = get_batch_session(session_id) + assert session is not None + original_index = session['current_batch_index'] + + batch_result = { + 'batch_index': 1, + 'total_batches': 3, + 'targets_in_batch': 5, + 'findings_count': 2, + 'findings': [{'finding': 'test'}], + 'status': 'success', + } + + _update_session_after_batch(session, batch_result) + + assert session['current_batch_index'] == original_index + 1 + assert len(session['processed_batches']) == 1 + assert session['processed_batches'][0] == batch_result + assert len(session['all_findings']) == 1 + assert session['status'] == 'in_progress' + + def test_update_session_after_batch_failure(self, sample_targets, sample_input_obj): + """Test updating session after failed batch processing.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + + batch_result = { + 'batch_index': 1, + 'total_batches': 1, + 'targets_in_batch': 3, + 'error': 'API error', + 'status': 'failed', + } + + _update_session_after_batch(session, batch_result) + + assert session['current_batch_index'] == 1 + assert len(session['failed_batches']) == 1 + assert session['failed_batches'][0] == batch_result + assert len(session['all_findings']) == 0 + assert session['status'] == 'completed' # Single batch completed + + def test_update_session_after_batch_completion(self, sample_targets, sample_input_obj): + """Test session status when all batches are processed.""" + session_id = create_batch_session( + targets=sample_targets[:5], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=5, + ) + + session = get_batch_session(session_id) + assert session is not None + + batch_result = { + 'batch_index': 1, + 'total_batches': 1, + 'targets_in_batch': 5, + 'findings_count': 0, + 'findings': [], + 'status': 'success', + } + + _update_session_after_batch(session, batch_result) + + assert session['status'] == 'completed' + + +class TestProcessNextBatch: + """Test process_next_batch function.""" + + def test_process_next_batch_success( + self, sample_targets, sample_input_obj, mock_appsignals_client + ): + """Test successful batch processing.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + result = process_next_batch(session_id, mock_appsignals_client) + + assert result['status'] == 'success' + assert result['batch_index'] == 1 + assert result['total_batches'] == 1 + assert result['targets_in_batch'] == 3 + assert result['findings_count'] == 1 + assert len(result['findings']) == 1 + + # Verify session was updated + session = get_batch_session(session_id) + assert session is not None + assert 
session['current_batch_index'] == 1 + assert session['status'] == 'completed' + + def test_process_next_batch_api_error(self, sample_targets, sample_input_obj): + """Test batch processing with API error - batch index should NOT advance for retry.""" + mock_client = MagicMock() + mock_client.list_audit_findings.side_effect = Exception('API error') + + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + result = process_next_batch(session_id, mock_client) + + assert result['status'] == 'failed' + assert result['error'] == 'API error' + assert result['batch_index'] == 1 + + # Verify session was NOT updated (for retry capability) + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 0 # Should NOT advance on failure + assert len(session['failed_batches']) == 0 # Should NOT be added to failed_batches + assert len(session['processed_batches']) == 0 + assert session['status'] == 'created' # Should remain in original status + + def test_process_next_batch_session_not_found(self, mock_appsignals_client): + """Test batch processing with non-existent session.""" + fake_session_id = str(uuid.uuid4()) + + result = process_next_batch(fake_session_id, mock_appsignals_client) + + assert 'error' in result + assert 'Session not found or expired' in result['error'] + + def test_process_next_batch_no_more_batches( + self, sample_targets, sample_input_obj, mock_appsignals_client + ): + """Test batch processing when no more batches remain.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + # Process the first (and only) batch + process_next_batch(session_id, mock_appsignals_client) + + # Try to process again + result = process_next_batch(session_id, mock_appsignals_client) + + assert 'error' in result + assert 'No more batches to process' in result['error'] + assert result['status'] == 'completed' + + def test_process_next_batch_multiple_batches( + self, sample_targets, sample_input_obj, mock_appsignals_client + ): + """Test processing multiple batches sequentially.""" + session_id = create_batch_session( + targets=sample_targets[:10], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=3, + ) + + # Process first batch + result1 = process_next_batch(session_id, mock_appsignals_client) + assert result1['batch_index'] == 1 + assert result1['total_batches'] == 4 # 10 targets / 3 batch_size = 4 batches + + # Process second batch + result2 = process_next_batch(session_id, mock_appsignals_client) + assert result2['batch_index'] == 2 + assert result2['total_batches'] == 4 + + # Verify session state + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 2 + assert session['status'] == 'in_progress' + + def test_process_next_batch_retry_after_failure(self, sample_targets, sample_input_obj): + """Test that failed batches can be retried by calling process_next_batch again.""" + # Mock client that fails first, then succeeds + mock_client = MagicMock() + mock_client.list_audit_findings.side_effect = [ + Exception('Network timeout'), # First call fails + {'AuditFindings': [{'FindingId': 'finding-1'}]}, # Second call succeeds + ] + + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Retry Test', + ) + 
+ # First attempt - should fail + result1 = process_next_batch(session_id, mock_client) + assert result1['status'] == 'failed' + assert result1['error'] == 'Network timeout' + assert result1['batch_index'] == 1 + + # Verify session state unchanged (ready for retry) + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 0 # Still at first batch + assert session['status'] == 'created' # Status unchanged + assert len(session['failed_batches']) == 0 # No failed batches recorded + assert len(session['processed_batches']) == 0 + + # Second attempt - should succeed (retry same batch) + result2 = process_next_batch(session_id, mock_client) + assert result2['status'] == 'success' + assert result2['batch_index'] == 1 # Same batch index + assert result2['findings_count'] == 1 + + # Verify session state updated after success + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 1 # Now advanced + assert session['status'] == 'completed' # Single batch completed + assert len(session['processed_batches']) == 1 # Success recorded + assert len(session['failed_batches']) == 0 # No failures recorded + + def test_process_next_batch_multiple_retry_attempts(self, sample_targets, sample_input_obj): + """Test multiple retry attempts for the same batch.""" + # Mock client that fails multiple times, then succeeds + mock_client = MagicMock() + mock_client.list_audit_findings.side_effect = [ + Exception('Connection timeout'), # Attempt 1: fail + Exception('Rate limit exceeded'), # Attempt 2: fail + Exception('Service unavailable'), # Attempt 3: fail + {'AuditFindings': []}, # Attempt 4: succeed + ] + + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Multiple Retry Test', + ) + + # Attempt 1: Connection timeout + result1 = process_next_batch(session_id, mock_client) + assert result1['status'] == 'failed' + assert result1['error'] == 'Connection timeout' + assert result1['batch_index'] == 1 + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 0 # No advancement + + # Attempt 2: Rate limit + result2 = process_next_batch(session_id, mock_client) + assert result2['status'] == 'failed' + assert result2['error'] == 'Rate limit exceeded' + assert result2['batch_index'] == 1 # Same batch + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 0 # Still no advancement + + # Attempt 3: Service unavailable + result3 = process_next_batch(session_id, mock_client) + assert result3['status'] == 'failed' + assert result3['error'] == 'Service unavailable' + assert result3['batch_index'] == 1 # Same batch + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 0 # Still no advancement + + # Attempt 4: Success + result4 = process_next_batch(session_id, mock_client) + assert result4['status'] == 'success' + assert result4['batch_index'] == 1 # Same batch + assert result4['findings_count'] == 0 + + # Verify final session state + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 1 # Finally advanced + assert session['status'] == 'completed' + assert len(session['processed_batches']) == 1 + assert len(session['failed_batches']) == 0 # No failures recorded in session + + def 
test_process_next_batch_retry_in_multi_batch_session( + self, sample_targets, sample_input_obj + ): + """Test retry behavior in a multi-batch session.""" + # Mock client: batch 1 succeeds, batch 2 fails then succeeds, batch 3 succeeds + mock_client = MagicMock() + mock_client.list_audit_findings.side_effect = [ + {'AuditFindings': [{'FindingId': 'batch1-finding'}]}, # Batch 1: success + Exception('Temporary failure'), # Batch 2: fail + {'AuditFindings': [{'FindingId': 'batch2-finding'}]}, # Batch 2 retry: success + {'AuditFindings': []}, # Batch 3: success + ] + + session_id = create_batch_session( + targets=sample_targets[:9], # 9 targets = 3 batches of 3 + input_obj=sample_input_obj, + region='us-east-1', + banner='Multi-batch Retry Test', + batch_size=3, + ) + + # Process batch 1 (success) + result1 = process_next_batch(session_id, mock_client) + assert result1['status'] == 'success' + assert result1['batch_index'] == 1 + assert result1['findings_count'] == 1 + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 1 # Advanced to batch 2 + + # Process batch 2 (failure) + result2 = process_next_batch(session_id, mock_client) + assert result2['status'] == 'failed' + assert result2['error'] == 'Temporary failure' + assert result2['batch_index'] == 2 + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 1 # Still at batch 2 (no advancement) + + # Retry batch 2 (success) + result3 = process_next_batch(session_id, mock_client) + assert result3['status'] == 'success' + assert result3['batch_index'] == 2 # Same batch index + assert result3['findings_count'] == 1 + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 2 # Now advanced to batch 3 + + # Process batch 3 (success) + result4 = process_next_batch(session_id, mock_client) + assert result4['status'] == 'success' + assert result4['batch_index'] == 3 + assert result4['findings_count'] == 0 + + # Verify final session state + session = get_batch_session(session_id) + assert session is not None + assert session['status'] == 'completed' + assert len(session['processed_batches']) == 3 # All batches successful + assert len(session['failed_batches']) == 0 # No failures recorded + assert len(session['all_findings']) == 2 # Findings from batch 1 and 2 + + +class TestFormatBatchResult: + """Test format_batch_result function.""" + + def test_format_batch_result_error(self, sample_targets, sample_input_obj): + """Test formatting batch result with error.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + batch_result = { + 'batch_index': 1, + 'total_batches': 1, + 'targets_in_batch': 3, + 'error': 'API connection failed', + 'status': 'failed', + } + + formatted = format_batch_result(batch_result, session) + + assert '❌ Batch 1/1 failed: API connection failed' == formatted + + def test_format_batch_result_healthy_with_continuation(self, sample_targets, sample_input_obj): + """Test formatting healthy batch result with continuation instruction.""" + session_id = create_batch_session( + targets=sample_targets[:10], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + batch_size=5, + ) + + session = get_batch_session(session_id) + assert session is not None + batch_result = { + 
'batch_index': 1, + 'total_batches': 2, + 'targets_in_batch': 5, + 'findings_count': 0, + 'findings': [], + 'status': 'success', + } + + formatted = format_batch_result(batch_result, session) + + expected = ( + f"✅ Batch 1/2: 5 services healthy | Continue: continue_audit_batch('{session_id}')" + ) + assert formatted == expected + + def test_format_batch_result_healthy_final_batch(self, sample_targets, sample_input_obj): + """Test formatting healthy final batch result.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + batch_result = { + 'batch_index': 1, + 'total_batches': 1, + 'targets_in_batch': 3, + 'findings_count': 0, + 'findings': [], + 'status': 'success', + } + + formatted = format_batch_result(batch_result, session) + + assert formatted == '✅ Batch 1/1: 3 services healthy' + + def test_format_batch_result_with_findings(self, sample_targets, sample_input_obj): + """Test formatting batch result with findings.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + findings = [ + { + 'FindingId': 'finding-1', + 'Severity': 'CRITICAL', + 'Title': 'High error rate', + 'Description': 'Service experiencing errors', + }, + { + 'FindingId': 'finding-2', + 'Severity': 'WARNING', + 'Title': 'Elevated latency', + 'Description': 'Response times are high', + }, + ] + + batch_result = { + 'batch_index': 1, + 'total_batches': 1, + 'targets_in_batch': 3, + 'findings_count': 2, + 'findings': findings, + 'status': 'success', + } + + formatted = format_batch_result(batch_result, session) + + assert formatted.startswith('⚠️ Batch 1/1: 2 findings') + assert '```' in formatted # JSON formatting + assert 'finding-1' in formatted + assert 'CRITICAL' in formatted + + def test_format_batch_result_missing_findings_key(self, sample_targets, sample_input_obj): + """Test formatting batch result when findings key is missing.""" + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Test Banner', + ) + + session = get_batch_session(session_id) + assert session is not None + batch_result = { + 'batch_index': 1, + 'total_batches': 1, + 'targets_in_batch': 3, + 'findings_count': 0, + 'status': 'success', + # 'findings' key is missing + } + + formatted = format_batch_result(batch_result, session) + + assert formatted == '✅ Batch 1/1: 3 services healthy' + + +class TestIntegration: + """Integration tests for batch processing workflow.""" + + def test_full_batch_processing_workflow( + self, sample_targets, sample_input_obj, mock_appsignals_client + ): + """Test complete batch processing workflow.""" + # Create session with multiple batches + session_id = create_batch_session( + targets=sample_targets[:7], # 7 targets + input_obj=sample_input_obj, + region='us-east-1', + banner='Integration Test', + batch_size=3, # Will create 3 batches: [3, 3, 1] + ) + + session = get_batch_session(session_id) + assert session is not None + assert len(session['batches']) == 3 + assert session['status'] == 'created' + + # Process first batch + result1 = process_next_batch(session_id, mock_appsignals_client) + assert result1['status'] == 'success' + assert result1['batch_index'] == 1 + + session = get_batch_session(session_id) + 
assert session is not None + assert session['status'] == 'in_progress' + assert len(session['processed_batches']) == 1 + + # Process second batch + result2 = process_next_batch(session_id, mock_appsignals_client) + assert result2['status'] == 'success' + assert result2['batch_index'] == 2 + + # Process final batch + result3 = process_next_batch(session_id, mock_appsignals_client) + assert result3['status'] == 'success' + assert result3['batch_index'] == 3 + assert result3['targets_in_batch'] == 1 # Final batch has 1 target + + session = get_batch_session(session_id) + assert session is not None + assert session['status'] == 'completed' + assert len(session['processed_batches']) == 3 + assert len(session['all_findings']) == 3 # 1 finding per batch + + # Try to process again (should fail) + result4 = process_next_batch(session_id, mock_appsignals_client) + assert 'error' in result4 + assert 'No more batches to process' in result4['error'] + + def test_batch_processing_with_mixed_results_and_retry(self, sample_targets, sample_input_obj): + """Test batch processing with mixed success and failure results, including retry behavior.""" + # Mock client: batch 1 succeeds, batch 2 fails then succeeds on retry, batch 3 succeeds + mock_client = MagicMock() + mock_client.list_audit_findings.side_effect = [ + {'AuditFindings': [{'FindingId': 'finding-1'}]}, # Batch 1: success + Exception('Network error'), # Batch 2: failure + {'AuditFindings': [{'FindingId': 'finding-2'}]}, # Batch 2 retry: success + {'AuditFindings': []}, # Batch 3: success with no findings + ] + + session_id = create_batch_session( + targets=sample_targets[:9], + input_obj=sample_input_obj, + region='us-east-1', + banner='Mixed Results Test', + batch_size=3, + ) + + # Process first batch (success) + result1 = process_next_batch(session_id, mock_client) + assert result1['status'] == 'success' + assert result1['findings_count'] == 1 + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 1 # Advanced to batch 2 + + # Process second batch (failure - should NOT advance) + result2 = process_next_batch(session_id, mock_client) + assert result2['status'] == 'failed' + assert result2['error'] == 'Network error' + assert result2['batch_index'] == 2 + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 1 # Should NOT advance on failure + assert len(session['failed_batches']) == 0 # Should NOT be recorded as failed + + # Retry second batch (success - should advance) + result3 = process_next_batch(session_id, mock_client) + assert result3['status'] == 'success' + assert result3['batch_index'] == 2 # Same batch index + assert result3['findings_count'] == 1 + + session = get_batch_session(session_id) + assert session is not None + assert session['current_batch_index'] == 2 # Now advanced to batch 3 + + # Process third batch (success, no findings) + result4 = process_next_batch(session_id, mock_client) + assert result4['status'] == 'success' + assert result4['batch_index'] == 3 + assert result4['findings_count'] == 0 + + # Verify final session state + session = get_batch_session(session_id) + assert session is not None + assert session['status'] == 'completed' + assert len(session['processed_batches']) == 3 # All 3 batches successful (after retry) + assert len(session['failed_batches']) == 0 # No failed batches recorded (retry succeeded) + assert len(session['all_findings']) == 2 # Findings from batch 1 and 2 + + 
@patch('awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger') + def test_logging_during_batch_processing(self, mock_logger, sample_targets, sample_input_obj): + """Test that appropriate logging occurs during batch processing.""" + mock_logger.reset_mock() + + session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Logging Test', + ) + + # Verify session creation was logged (may also have cleanup logs) + session_creation_calls = [ + call + for call in mock_logger.info.call_args_list + if f'Created batch session {session_id}' in str(call) + ] + assert len(session_creation_calls) == 1 + assert f'Created batch session {session_id} with 1 batches' in str( + session_creation_calls[0] + ) + + +class TestCleanupExcessSessions: + """Test _cleanup_excess_sessions function.""" + + def setup_method(self): + """Clear batch sessions before each test.""" + global _batch_sessions + _batch_sessions.clear() + + def teardown_method(self): + """Clear batch sessions after each test.""" + global _batch_sessions + _batch_sessions.clear() + + def test_cleanup_excess_sessions_no_cleanup_needed(self): + """Test cleanup when sessions are within limit.""" + global _batch_sessions + + # Create sessions within limit (MAX_BATCH_SESSIONS = 1) + session_data = { + 'session_id': 'test-session-1', + 'created_at': '2024-01-01T10:00:00+00:00', + 'status': 'created', + } + _batch_sessions['test-session-1'] = session_data + + # Should not remove any sessions + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify no logging occurred (no cleanup needed) + mock_logger.info.assert_not_called() + + # Verify session still exists + assert len(_batch_sessions) == 1 + assert 'test-session-1' in _batch_sessions + + def test_cleanup_excess_sessions_at_limit(self): + """Test cleanup when sessions are exactly at limit.""" + global _batch_sessions + + # Create exactly MAX_BATCH_SESSIONS (1) sessions + session_data = { + 'session_id': 'test-session-1', + 'created_at': '2024-01-01T10:00:00+00:00', + 'status': 'created', + } + _batch_sessions['test-session-1'] = session_data + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify no logging occurred (at limit, no cleanup needed) + mock_logger.info.assert_not_called() + + # Verify session still exists + assert len(_batch_sessions) == 1 + assert 'test-session-1' in _batch_sessions + + def test_cleanup_excess_sessions_cleanup_needed(self): + """Test cleanup when sessions exceed limit.""" + global _batch_sessions + + # Create sessions exceeding limit (MAX_BATCH_SESSIONS = 1) + # Older session (should be removed) + older_session = { + 'session_id': 'old-session', + 'created_at': '2024-01-01T09:00:00+00:00', + 'status': 'created', + } + # Newer session (should be kept) + newer_session = { + 'session_id': 'new-session', + 'created_at': '2024-01-01T11:00:00+00:00', + 'status': 'created', + } + + _batch_sessions['old-session'] = older_session + _batch_sessions['new-session'] = newer_session + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify cleanup was logged + mock_logger.info.assert_called_once_with('Cleaned up 1 excess batch sessions') + + # Verify only newer session remains + assert len(_batch_sessions) == 1 + assert 
'new-session' in _batch_sessions + assert 'old-session' not in _batch_sessions + + def test_cleanup_excess_sessions_multiple_excess(self): + """Test cleanup when multiple sessions exceed limit.""" + global _batch_sessions + + # Create multiple sessions exceeding limit (MAX_BATCH_SESSIONS = 1) + sessions = [ + ('oldest-session', '2024-01-01T08:00:00+00:00'), + ('old-session', '2024-01-01T09:00:00+00:00'), + ('newer-session', '2024-01-01T10:00:00+00:00'), + ('newest-session', '2024-01-01T11:00:00+00:00'), + ] + + for session_id, created_at in sessions: + _batch_sessions[session_id] = { + 'session_id': session_id, + 'created_at': created_at, + 'status': 'created', + } + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify cleanup was logged (4 sessions - 1 limit = 3 excess) + mock_logger.info.assert_called_once_with('Cleaned up 3 excess batch sessions') + + # Verify only newest session remains + assert len(_batch_sessions) == 1 + assert 'newest-session' in _batch_sessions + assert 'oldest-session' not in _batch_sessions + assert 'old-session' not in _batch_sessions + assert 'newer-session' not in _batch_sessions + + def test_cleanup_excess_sessions_missing_created_at(self): + """Test cleanup when some sessions are missing created_at field.""" + global _batch_sessions + + # Create sessions with and without created_at + session_with_time = { + 'session_id': 'session-with-time', + 'created_at': '2024-01-01T10:00:00+00:00', + 'status': 'created', + } + session_without_time = { + 'session_id': 'session-without-time', + 'status': 'created', + # Missing 'created_at' field + } + + _batch_sessions['session-with-time'] = session_with_time + _batch_sessions['session-without-time'] = session_without_time + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify cleanup was logged (2 sessions - 1 limit = 1 excess) + mock_logger.info.assert_called_once_with('Cleaned up 1 excess batch sessions') + + # Verify one session remains (the one with created_at should be kept as it sorts later) + assert len(_batch_sessions) == 1 + # Session without created_at gets empty string and sorts first (gets removed) + assert 'session-with-time' in _batch_sessions + assert 'session-without-time' not in _batch_sessions + + def test_cleanup_excess_sessions_empty_created_at(self): + """Test cleanup when sessions have empty created_at field.""" + global _batch_sessions + + # Create sessions with empty and valid created_at + session_with_empty_time = { + 'session_id': 'session-empty-time', + 'created_at': '', + 'status': 'created', + } + session_with_time = { + 'session_id': 'session-with-time', + 'created_at': '2024-01-01T10:00:00+00:00', + 'status': 'created', + } + + _batch_sessions['session-empty-time'] = session_with_empty_time + _batch_sessions['session-with-time'] = session_with_time + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify cleanup was logged + mock_logger.info.assert_called_once_with('Cleaned up 1 excess batch sessions') + + # Verify session with valid timestamp remains (empty string sorts first) + assert len(_batch_sessions) == 1 + assert 'session-with-time' in _batch_sessions + assert 'session-empty-time' not in _batch_sessions + + def test_cleanup_excess_sessions_same_timestamps(self): + """Test cleanup when sessions 
have identical timestamps.""" + global _batch_sessions + + # Create sessions with identical timestamps + same_time = '2024-01-01T10:00:00+00:00' + session1 = { + 'session_id': 'session-1', + 'created_at': same_time, + 'status': 'created', + } + session2 = { + 'session_id': 'session-2', + 'created_at': same_time, + 'status': 'created', + } + + _batch_sessions['session-1'] = session1 + _batch_sessions['session-2'] = session2 + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify cleanup was logged + mock_logger.info.assert_called_once_with('Cleaned up 1 excess batch sessions') + + # Verify one session remains (deterministic based on dict iteration order) + assert len(_batch_sessions) == 1 + # One of the sessions should remain + remaining_sessions = list(_batch_sessions.keys()) + assert len(remaining_sessions) == 1 + assert remaining_sessions[0] in ['session-1', 'session-2'] + + def test_cleanup_excess_sessions_empty_sessions_dict(self): + """Test cleanup when sessions dictionary is empty.""" + global _batch_sessions + + # Ensure sessions dict is empty + _batch_sessions.clear() + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify no logging occurred (no sessions to clean up) + mock_logger.info.assert_not_called() + + # Verify sessions dict remains empty + assert len(_batch_sessions) == 0 + + def test_cleanup_excess_sessions_preserves_session_data(self): + """Test that cleanup preserves all data in remaining sessions.""" + global _batch_sessions + + # Create sessions with rich data + older_session = { + 'session_id': 'old-session', + 'created_at': '2024-01-01T09:00:00+00:00', + 'last_activity': '2024-01-01T09:30:00+00:00', + 'targets': [{'service': 'old-service'}], + 'batches': [['batch1'], ['batch2']], + 'current_batch_index': 1, + 'status': 'in_progress', + 'processed_batches': [{'batch': 'data'}], + 'all_findings': [{'finding': 'old'}], + } + newer_session = { + 'session_id': 'new-session', + 'created_at': '2024-01-01T11:00:00+00:00', + 'last_activity': '2024-01-01T11:30:00+00:00', + 'targets': [{'service': 'new-service'}], + 'batches': [['batch3'], ['batch4']], + 'current_batch_index': 0, + 'status': 'created', + 'processed_batches': [], + 'all_findings': [{'finding': 'new'}], + } + + _batch_sessions['old-session'] = older_session + _batch_sessions['new-session'] = newer_session + + _cleanup_excess_sessions() + + # Verify newer session remains with all data intact + assert len(_batch_sessions) == 1 + assert 'new-session' in _batch_sessions + + remaining_session = _batch_sessions['new-session'] + assert remaining_session['session_id'] == 'new-session' + assert remaining_session['created_at'] == '2024-01-01T11:00:00+00:00' + assert remaining_session['last_activity'] == '2024-01-01T11:30:00+00:00' + assert remaining_session['targets'] == [{'service': 'new-service'}] + assert remaining_session['batches'] == [['batch3'], ['batch4']] + assert remaining_session['current_batch_index'] == 0 + assert remaining_session['status'] == 'created' + assert remaining_session['processed_batches'] == [] + assert remaining_session['all_findings'] == [{'finding': 'new'}] + + @patch('awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.MAX_BATCH_SESSIONS', 3) + def test_cleanup_excess_sessions_different_limit(self): + """Test cleanup with different MAX_BATCH_SESSIONS limit.""" + global _batch_sessions + + # 
Create 5 sessions when limit is 3 + sessions = [ + ('session-1', '2024-01-01T08:00:00+00:00'), + ('session-2', '2024-01-01T09:00:00+00:00'), + ('session-3', '2024-01-01T10:00:00+00:00'), + ('session-4', '2024-01-01T11:00:00+00:00'), + ('session-5', '2024-01-01T12:00:00+00:00'), + ] + + for session_id, created_at in sessions: + _batch_sessions[session_id] = { + 'session_id': session_id, + 'created_at': created_at, + 'status': 'created', + } + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + _cleanup_excess_sessions() + + # Verify cleanup was logged (5 sessions - 3 limit = 2 excess) + mock_logger.info.assert_called_once_with('Cleaned up 2 excess batch sessions') + + # Verify only 3 newest sessions remain + assert len(_batch_sessions) == 3 + assert 'session-3' in _batch_sessions + assert 'session-4' in _batch_sessions + assert 'session-5' in _batch_sessions + assert 'session-1' not in _batch_sessions + assert 'session-2' not in _batch_sessions + + def test_cleanup_excess_sessions_integration_with_create_batch_session( + self, sample_targets, sample_input_obj + ): + """Test that cleanup is called during session creation and works correctly.""" + global _batch_sessions + + # Clear sessions and create one session manually (to exceed limit when next is created) + _batch_sessions.clear() + existing_session = { + 'session_id': 'existing-session', + 'created_at': '2024-01-01T09:00:00+00:00', + 'status': 'created', + } + _batch_sessions['existing-session'] = existing_session + + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_processing_utils.logger' + ) as mock_logger: + # Create new session (should trigger cleanup) + new_session_id = create_batch_session( + targets=sample_targets[:3], + input_obj=sample_input_obj, + region='us-east-1', + banner='Integration Test', + ) + + # Verify cleanup was logged (2 sessions - 1 limit = 1 excess) + cleanup_calls = [ + call for call in mock_logger.info.call_args_list if 'Cleaned up' in str(call) + ] + assert len(cleanup_calls) == 1 + assert 'Cleaned up 1 excess batch sessions' in str(cleanup_calls[0]) + + # Verify only the new session remains (it has a later timestamp) + assert len(_batch_sessions) == 1 + assert new_session_id in _batch_sessions + assert 'existing-session' not in _batch_sessions diff --git a/src/cloudwatch-appsignals-mcp-server/tests/test_batch_tools.py b/src/cloudwatch-appsignals-mcp-server/tests/test_batch_tools.py new file mode 100644 index 0000000000..c4181a82f8 --- /dev/null +++ b/src/cloudwatch-appsignals-mcp-server/tests/test_batch_tools.py @@ -0,0 +1,468 @@ +"""Tests for batch tools.""" + +import pytest +from awslabs.cloudwatch_appsignals_mcp_server.batch_tools import ( + continue_audit_batch, +) +from unittest.mock import patch + + +@pytest.fixture +def mock_batch_processing_utils(): + """Mock batch processing utilities.""" + with ( + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_tools.process_next_batch' + ) as mock_process, + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_tools.get_batch_session' + ) as mock_get_session, + patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_tools.format_batch_result' + ) as mock_format, + ): + yield { + 'process_next_batch': mock_process, + 'get_batch_session': mock_get_session, + 'format_batch_result': mock_format, + } + + +@pytest.fixture +def mock_appsignals_client(): + """Mock Application Signals client.""" + with patch( + 'awslabs.cloudwatch_appsignals_mcp_server.batch_tools.appsignals_client' + ) 
as mock_client: + yield mock_client + + +class TestContinueAuditBatch: + """Test continue_audit_batch function.""" + + @pytest.mark.asyncio + async def test_continue_audit_batch_success( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test successful batch continuation.""" + session_id = 'test-session-123' + + # Mock successful batch processing + batch_result = { + 'batch_index': 2, + 'total_batches': 3, + 'targets_in_batch': 5, + 'findings_count': 1, + 'findings': [{'FindingId': 'finding-1', 'Severity': 'WARNING'}], + 'status': 'success', + } + + session = { + 'session_id': session_id, + 'status': 'in_progress', + 'current_batch_index': 2, + } + + mock_batch_processing_utils['process_next_batch'].return_value = batch_result + mock_batch_processing_utils['get_batch_session'].return_value = session + mock_batch_processing_utils[ + 'format_batch_result' + ].return_value = ( + f"⚠️ Batch 2/3: 1 findings | Continue: continue_audit_batch('{session_id}')" + ) + + result = await continue_audit_batch(session_id) + + # Verify the mocks were called correctly + mock_batch_processing_utils['process_next_batch'].assert_called_once_with( + session_id, mock_appsignals_client + ) + mock_batch_processing_utils['get_batch_session'].assert_called_once_with(session_id) + mock_batch_processing_utils['format_batch_result'].assert_called_once_with( + batch_result, session + ) + + # Verify the result + assert f"Continue: continue_audit_batch('{session_id}')" in result + assert '⚠️ Batch 2/3: 1 findings' in result + + @pytest.mark.asyncio + async def test_continue_audit_batch_error_in_processing( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test batch continuation with processing error.""" + session_id = 'test-session-123' + + # Mock error in batch processing + batch_result = { + 'error': 'No more batches to process', + 'status': 'completed', + } + + mock_batch_processing_utils['process_next_batch'].return_value = batch_result + + result = await continue_audit_batch(session_id) + + assert result == 'Error: No more batches to process' + mock_batch_processing_utils['process_next_batch'].assert_called_once_with( + session_id, mock_appsignals_client + ) + + @pytest.mark.asyncio + async def test_continue_audit_batch_session_not_found( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test batch continuation when session is not found.""" + session_id = 'nonexistent-session' + + # Mock successful processing but no session found + batch_result = { + 'batch_index': 1, + 'total_batches': 1, + 'status': 'success', + } + + mock_batch_processing_utils['process_next_batch'].return_value = batch_result + mock_batch_processing_utils['get_batch_session'].return_value = None + + result = await continue_audit_batch(session_id) + + assert result == 'Error: Session not found or expired' + + @pytest.mark.asyncio + async def test_continue_audit_batch_healthy_services( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test batch continuation with healthy services.""" + session_id = 'test-session-123' + + # Mock healthy batch result + batch_result = { + 'batch_index': 1, + 'total_batches': 2, + 'targets_in_batch': 5, + 'findings_count': 0, + 'findings': [], + 'status': 'success', + } + + session = { + 'session_id': session_id, + 'status': 'in_progress', + 'current_batch_index': 1, + } + + mock_batch_processing_utils['process_next_batch'].return_value = batch_result + mock_batch_processing_utils['get_batch_session'].return_value = session + 
mock_batch_processing_utils[ + 'format_batch_result' + ].return_value = ( + f"✅ Batch 1/2: 5 services healthy | Continue: continue_audit_batch('{session_id}')" + ) + + result = await continue_audit_batch(session_id) + + assert '✅ Batch 1/2: 5 services healthy' in result + assert f"Continue: continue_audit_batch('{session_id}')" in result + + @pytest.mark.asyncio + async def test_continue_audit_batch_final_batch( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test batch continuation for final batch.""" + session_id = 'test-session-123' + + # Mock final batch result + batch_result = { + 'batch_index': 3, + 'total_batches': 3, + 'targets_in_batch': 2, + 'findings_count': 0, + 'findings': [], + 'status': 'success', + } + + session = { + 'session_id': session_id, + 'status': 'completed', + 'current_batch_index': 3, + } + + mock_batch_processing_utils['process_next_batch'].return_value = batch_result + mock_batch_processing_utils['get_batch_session'].return_value = session + mock_batch_processing_utils[ + 'format_batch_result' + ].return_value = '✅ Batch 3/3: 2 services healthy' + + result = await continue_audit_batch(session_id) + + assert '✅ Batch 3/3: 2 services healthy' in result + # Should not contain continuation instruction for final batch + assert 'Continue:' not in result + + @pytest.mark.asyncio + async def test_continue_audit_batch_with_findings_json( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test batch continuation with findings that include JSON output.""" + session_id = 'test-session-123' + + # Mock batch result with findings + findings = [ + { + 'FindingId': 'finding-1', + 'Severity': 'CRITICAL', + 'Title': 'High error rate detected', + 'Description': 'Service experiencing elevated error rates', + 'ServiceName': 'payment-service', + }, + { + 'FindingId': 'finding-2', + 'Severity': 'WARNING', + 'Title': 'Elevated latency', + 'Description': 'Response times are higher than normal', + 'ServiceName': 'user-service', + }, + ] + + batch_result = { + 'batch_index': 1, + 'total_batches': 1, + 'targets_in_batch': 3, + 'findings_count': 2, + 'findings': findings, + 'status': 'success', + } + + session = { + 'session_id': session_id, + 'status': 'completed', + 'current_batch_index': 1, + } + + # Mock format_batch_result to return JSON findings + formatted_findings = f'⚠️ Batch 1/1: 2 findings\n```\n{findings}\n```' + + mock_batch_processing_utils['process_next_batch'].return_value = batch_result + mock_batch_processing_utils['get_batch_session'].return_value = session + mock_batch_processing_utils['format_batch_result'].return_value = formatted_findings + + result = await continue_audit_batch(session_id) + + assert '⚠️ Batch 1/1: 2 findings' in result + assert '```' in result # JSON formatting markers + assert 'CRITICAL' in result + assert 'payment-service' in result + + @pytest.mark.asyncio + async def test_continue_audit_batch_exception_handling( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test batch continuation with unexpected exception.""" + session_id = 'test-session-123' + + # Mock exception during processing + mock_batch_processing_utils['process_next_batch'].side_effect = Exception( + 'Unexpected error' + ) + + result = await continue_audit_batch(session_id) + + assert result == 'Error: Unexpected error' + + @pytest.mark.asyncio + async def test_continue_audit_batch_invalid_session_id( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test batch continuation with invalid session ID 
format.""" + session_id = 'invalid-session-format' + + # Mock error for invalid session + batch_result = { + 'error': 'Session not found or expired', + } + + mock_batch_processing_utils['process_next_batch'].return_value = batch_result + + result = await continue_audit_batch(session_id) + + assert result == 'Error: Session not found or expired' + + @pytest.mark.asyncio + async def test_continue_audit_batch_api_failure( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test batch continuation when API call fails.""" + session_id = 'test-session-123' + + # Mock API failure - when batch_result has an error, the function returns early + batch_result = { + 'batch_index': 1, + 'total_batches': 2, + 'targets_in_batch': 5, + 'error': 'AWS API throttling error', + 'status': 'failed', + } + + mock_batch_processing_utils['process_next_batch'].return_value = batch_result + + result = await continue_audit_batch(session_id) + + # When batch_result has an error, the function returns early with just the error message + assert result == 'Error: AWS API throttling error' + + +class TestIntegration: + """Integration tests for batch tools workflow.""" + + @pytest.mark.asyncio + async def test_full_batch_continuation_workflow( + self, mock_batch_processing_utils, mock_appsignals_client + ): + """Test complete batch continuation workflow.""" + session_id = 'integration-test-session' + + # Simulate processing multiple batches + batch_results = [ + { + 'batch_index': 1, + 'total_batches': 3, + 'targets_in_batch': 5, + 'findings_count': 0, + 'findings': [], + 'status': 'success', + }, + { + 'batch_index': 2, + 'total_batches': 3, + 'targets_in_batch': 5, + 'findings_count': 2, + 'findings': [ + {'FindingId': 'finding-1', 'Severity': 'WARNING'}, + {'FindingId': 'finding-2', 'Severity': 'CRITICAL'}, + ], + 'status': 'success', + }, + { + 'batch_index': 3, + 'total_batches': 3, + 'targets_in_batch': 3, + 'findings_count': 0, + 'findings': [], + 'status': 'success', + }, + ] + + sessions = [ + {'session_id': session_id, 'status': 'in_progress', 'current_batch_index': 1}, + {'session_id': session_id, 'status': 'in_progress', 'current_batch_index': 2}, + {'session_id': session_id, 'status': 'completed', 'current_batch_index': 3}, + ] + + formatted_results = [ + f"✅ Batch 1/3: 5 services healthy | Continue: continue_audit_batch('{session_id}')", + f"⚠️ Batch 2/3: 2 findings | Continue: continue_audit_batch('{session_id}')", + '✅ Batch 3/3: 3 services healthy', + ] + + # Configure mocks for sequential calls + mock_batch_processing_utils['process_next_batch'].side_effect = batch_results + mock_batch_processing_utils['get_batch_session'].side_effect = sessions + mock_batch_processing_utils['format_batch_result'].side_effect = formatted_results + + # Process first batch (healthy) + result1 = await continue_audit_batch(session_id) + assert '✅ Batch 1/3: 5 services healthy' in result1 + assert f"Continue: continue_audit_batch('{session_id}')" in result1 + + # Process second batch (with findings) + result2 = await continue_audit_batch(session_id) + assert '⚠️ Batch 2/3: 2 findings' in result2 + assert f"Continue: continue_audit_batch('{session_id}')" in result2 + + # Process final batch (healthy, no continuation) + result3 = await continue_audit_batch(session_id) + assert '✅ Batch 3/3: 3 services healthy' in result3 + assert 'Continue:' not in result3 + + # Verify all calls were made + assert mock_batch_processing_utils['process_next_batch'].call_count == 3 + assert 
mock_batch_processing_utils['get_batch_session'].call_count == 3
+        assert mock_batch_processing_utils['format_batch_result'].call_count == 3
+
+    @pytest.mark.asyncio
+    async def test_error_handling_in_workflow(
+        self, mock_batch_processing_utils, mock_appsignals_client
+    ):
+        """Test error handling throughout the workflow."""
+        session_id = 'error-test-session'
+
+        # Test continue_audit_batch error handling
+        mock_batch_processing_utils['process_next_batch'].side_effect = Exception(
+            'Processing error'
+        )
+
+        continue_result = await continue_audit_batch(session_id)
+        assert continue_result == 'Error: Processing error'
+
+
+class TestParameterValidation:
+    """Test parameter validation and edge cases."""
+
+    @pytest.mark.asyncio
+    async def test_continue_audit_batch_empty_session_id(
+        self, mock_batch_processing_utils, mock_appsignals_client
+    ):
+        """Test batch continuation with an empty session ID."""
+        session_id = ''
+
+        # Mock error for empty session ID
+        batch_result = {
+            'error': 'Session not found or expired',
+        }
+
+        mock_batch_processing_utils['process_next_batch'].return_value = batch_result
+
+        result = await continue_audit_batch(session_id)
+
+        assert result == 'Error: Session not found or expired'
+        mock_batch_processing_utils['process_next_batch'].assert_called_once_with(
+            session_id, mock_appsignals_client
+        )
+
+    @pytest.mark.asyncio
+    async def test_continue_audit_batch_none_session_id(
+        self, mock_batch_processing_utils, mock_appsignals_client
+    ):
+        """Test batch continuation when no session ID is available.
+
+        `continue_audit_batch` takes a string, so a missing (None) session ID is
+        passed as an empty string; the tool should return an error message rather
+        than raise.
+        """
+        # Mock error for a session that cannot be resolved
+        batch_result = {
+            'error': 'Session not found or expired',
+        }
+
+        mock_batch_processing_utils['process_next_batch'].return_value = batch_result
+
+        # Should not raise; the empty string stands in for a missing session ID
+        result = await continue_audit_batch('')
+
+        assert result == 'Error: Session not found or expired'
+
+    @pytest.mark.asyncio
+    async def test_continue_audit_batch_very_long_session_id(
+        self, mock_batch_processing_utils, mock_appsignals_client
+    ):
+        """Test batch continuation with a very long session ID."""
+        session_id = 'a' * 1000  # Very long session ID
+
+        # Mock error for invalid session
+        batch_result = {
+            'error': 'Session not found or expired',
+        }
+
+        mock_batch_processing_utils['process_next_batch'].return_value = batch_result
+
+        result = await continue_audit_batch(session_id)
+
+        assert result == 'Error: Session not found or expired'