diff --git a/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/core/aws/driver.py b/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/core/aws/driver.py index de3eef8fca..92d1dc6763 100644 --- a/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/core/aws/driver.py +++ b/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/core/aws/driver.py @@ -80,6 +80,7 @@ def translate_cli_to_ir(cli_command: str) -> IRTranslation: def interpret_command( cli_command: str, max_results: int | None = None, + credentials: Credentials | None = None, ) -> InterpretedProgram: """Interpret the CLI command. @@ -101,9 +102,10 @@ def interpret_command( ): region = GLOBAL_SERVICE_REGIONS[translation.command.command_metadata.service_sdk_name] - credentials = get_local_credentials( - profile=translation.command.profile or AWS_API_MCP_PROFILE_NAME - ) + if credentials is None: + credentials = get_local_credentials( + profile=translation.command.profile or AWS_API_MCP_PROFILE_NAME + ) try: response = interpret( diff --git a/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/core/aws/service.py b/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/core/aws/service.py index 83047b088b..fc972cf302 100644 --- a/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/core/aws/service.py +++ b/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/core/aws/service.py @@ -20,6 +20,7 @@ AwsApiMcpServerErrorResponse, AwsCliAliasResponse, Consent, + Credentials, InterpretationMetadata, InterpretationResponse, InterpretedProgram, @@ -163,11 +164,13 @@ def execute_awscli_customization( def interpret_command( cli_command: str, max_results: int | None = None, + credentials: Credentials | None = None, ) -> ProgramInterpretationResponse: """Interpret the given CLI command and return an interpretation response.""" interpreted_program = _interpret_command( cli_command, max_results=max_results, + credentials=credentials, ) validation_failures = ( diff --git a/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/server.py b/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/server.py index 3b3f0bcc71..cdf11d995a 100644 --- a/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/server.py +++ b/src/aws-api-mcp-server/awslabs/aws_api_mcp_server/server.py @@ -42,6 +42,7 @@ from .core.common.models import ( AwsApiMcpServerErrorResponse, AwsCliAliasResponse, + Credentials, ProgramInterpretationResponse, ) from .core.metadata.read_only_operations_list import ReadOnlyOperations, get_read_only_operations @@ -167,58 +168,16 @@ async def suggest_aws_commands( return AwsApiMcpServerErrorResponse(detail=error_message) -@server.tool( - name='call_aws', - description=f"""Execute AWS CLI commands with validation and proper error handling. This is the PRIMARY tool to use when you are confident about the exact AWS CLI command needed to fulfill a user's request. Always prefer this tool over 'suggest_aws_commands' when you have a specific command in mind. - Key points: - - The command MUST start with "aws" and follow AWS CLI syntax - - Commands are executed in {DEFAULT_REGION} region by default - - For cross-region or account-wide operations, explicitly include --region parameter - - All commands are validated before execution to prevent errors - - Supports pagination control via max_results parameter - - The current working directory is {WORKING_DIRECTORY} - - File paths should always have forward slash (/) as a separator regardless of the system. Example: 'c:/folder/file.txt' - - Best practices for command generation: - - Always use the most specific service and operation names - - Always use the working directory when writing files, unless user explicitly mentioned another directory - - Include --region when operating across regions - - Only use filters (--filters, --query, --prefix, --pattern, etc) when necessary or user explicitly asked for it - - Command restrictions: - - DO NOT use bash/zsh pipes (|) or any shell operators - - DO NOT use bash/zsh tools like grep, awk, sed, etc. - - DO NOT use shell redirection operators (>, >>, <) - - DO NOT use command substitution ($()) - - DO NOT use shell variables or environment variables - - DO NOT use relative paths for reading or writing files, use absolute paths instead - - Common pitfalls to avoid: - 1. Missing required parameters - always include all required parameters - 2. Incorrect parameter values - ensure values match expected format - 3. Missing --region when operating across regions - - Returns: - CLI execution results with API response data or error message - """, - annotations=ToolAnnotations( - title='Execute AWS CLI commands', - readOnlyHint=READ_OPERATIONS_ONLY_MODE, - destructiveHint=not READ_OPERATIONS_ONLY_MODE, - openWorldHint=True, - ), -) -async def call_aws( - cli_command: Annotated[ - str, Field(description='The complete AWS CLI command to execute. MUST start with "aws"') - ], +async def call_aws_helper( + cli_command: str, ctx: Context, - max_results: Annotated[ - int | None, - Field(description='Optional limit for number of results (useful for pagination)'), - ] = None, + max_results: int | None = None, + credentials: Credentials | None = None, ) -> ProgramInterpretationResponse | AwsApiMcpServerErrorResponse | AwsCliAliasResponse: - """Call AWS with the given CLI command and return the result as a dictionary.""" + """Helper method for call_aws with optional credential injection. + + If credentials are provided, use them. Otherwise, fall back to boto3 credential chain. + """ try: ir = translate_cli_to_ir(cli_command) ir_validation = validate(ir) @@ -285,6 +244,7 @@ async def call_aws( return interpret_command( cli_command=cli_command, max_results=max_results, + credentials=credentials, ) except NoCredentialsError: error_message = ( @@ -310,6 +270,87 @@ async def call_aws( ) +@server.tool( + name='call_aws', + description=f"""Execute AWS CLI commands with validation and proper error handling. This is the PRIMARY tool to use when you are confident about the exact AWS CLI command needed to fulfill a user's request. Always prefer this tool over 'suggest_aws_commands' when you have a specific command in mind. + Key points: + - The command MUST start with "aws" and follow AWS CLI syntax + - Commands are executed in {DEFAULT_REGION} region by default + - For cross-region or account-wide operations, explicitly include --region parameter + - All commands are validated before execution to prevent errors + - Supports pagination control via max_results parameter + - The current working directory is {WORKING_DIRECTORY} + - File paths should always have forward slash (/) as a separator regardless of the system. Example: 'c:/folder/file.txt' + + Best practices for command generation: + - Always use the most specific service and operation names + - Always use the working directory when writing files, unless user explicitly mentioned another directory + - Include --region when operating across regions + - Only use filters (--filters, --query, --prefix, --pattern, etc) when necessary or user explicitly asked for it + + Command restrictions: + - DO NOT use bash/zsh pipes (|) or any shell operators + - DO NOT use bash/zsh tools like grep, awk, sed, etc. + - DO NOT use shell redirection operators (>, >>, <) + - DO NOT use command substitution ($()) + - DO NOT use shell variables or environment variables + - DO NOT use relative paths for reading or writing files, use absolute paths instead + + Common pitfalls to avoid: + 1. Missing required parameters - always include all required parameters + 2. Incorrect parameter values - ensure values match expected format + 3. Missing --region when operating across regions + + Returns: + CLI execution results with API response data or error message + """, + annotations=ToolAnnotations( + title='Execute AWS CLI commands', + readOnlyHint=READ_OPERATIONS_ONLY_MODE, + destructiveHint=not READ_OPERATIONS_ONLY_MODE, + openWorldHint=True, + ), +) +async def call_aws( + cli_command: Annotated[ + str, Field(description='The complete AWS CLI command to execute. MUST start with "aws"') + ], + ctx: Context, + max_results: Annotated[ + int | None, + Field(description='Optional limit for number of results (useful for pagination)'), + ] = None, + access_key_id: Annotated[ + str | None, + Field(description='Optional AWS access key ID for credential injection'), + ] = None, + secret_access_key: Annotated[ + str | None, + Field(description='Optional AWS secret access key for credential injection'), + ] = None, + session_token: Annotated[ + str | None, + Field(description='Optional AWS session token for credential injection'), + ] = None, +) -> ProgramInterpretationResponse | AwsApiMcpServerErrorResponse | AwsCliAliasResponse: + """Call AWS with the given CLI command and return the result as a dictionary.""" + # Create Credentials object if credentials are provided + credentials = None + if access_key_id and secret_access_key: + credentials = Credentials( + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + ) + + return await call_aws_helper( + cli_command=cli_command, + ctx=ctx, + max_results=max_results, + credentials=credentials, + ) + + # EXPERIMENTAL: Agent scripts tool - only registered if ENABLE_AGENT_SCRIPTS is True if ENABLE_AGENT_SCRIPTS: diff --git a/src/aws-api-mcp-server/tests/test_server.py b/src/aws-api-mcp-server/tests/test_server.py index ea7a3c0367..b1ca2900eb 100644 --- a/src/aws-api-mcp-server/tests/test_server.py +++ b/src/aws-api-mcp-server/tests/test_server.py @@ -5,10 +5,11 @@ AwsApiMcpServerErrorResponse, AwsCliAliasResponse, Consent, + Credentials, InterpretationResponse, ProgramInterpretationResponse, ) -from awslabs.aws_api_mcp_server.server import call_aws, main, suggest_aws_commands +from awslabs.aws_api_mcp_server.server import call_aws, call_aws_helper, main, suggest_aws_commands from botocore.exceptions import NoCredentialsError from fastmcp.server.elicitation import AcceptedElicitation from tests.fixtures import DummyCtx @@ -672,6 +673,102 @@ async def test_call_aws_awscli_customization_error( mock_ctx.error.assert_called_once_with(error_response.detail) +@patch('awslabs.aws_api_mcp_server.server.interpret_command') +@patch('awslabs.aws_api_mcp_server.server.validate') +@patch('awslabs.aws_api_mcp_server.server.translate_cli_to_ir') +async def test_call_aws_helper_with_credential_injection( + mock_translate, mock_validate, mock_interpret +): + """Test call_aws_helper uses injected credentials when provided.""" + mock_ir = MagicMock() + mock_ir.command = MagicMock() + mock_ir.command.is_awscli_customization = False + mock_translate.return_value = mock_ir + + mock_validation = MagicMock() + mock_validation.validation_failed = False + mock_validate.return_value = mock_validation + + mock_response = InterpretationResponse(error=None, json='{"Buckets": []}', status_code=200) + mock_interpret.return_value = ProgramInterpretationResponse(response=mock_response) + + test_credentials = Credentials( + access_key_id='AKIATEST123', + secret_access_key='test-secret-key', # pragma: allowlist secret + session_token='test-session-token', + ) + + result = await call_aws_helper( + 'aws s3api list-buckets', MagicMock(), credentials=test_credentials + ) + + mock_interpret.assert_called_once_with( + cli_command='aws s3api list-buckets', + max_results=None, + credentials=test_credentials, + ) + + assert isinstance(result, ProgramInterpretationResponse) + + +@patch('awslabs.aws_api_mcp_server.server.interpret_command') +@patch('awslabs.aws_api_mcp_server.server.validate') +@patch('awslabs.aws_api_mcp_server.server.translate_cli_to_ir') +async def test_call_aws_helper_without_credentials_uses_boto3_chain( + mock_translate, mock_validate, mock_interpret +): + """Test call_aws_helper falls back to boto3 credential chain when no credentials provided.""" + mock_ir = MagicMock() + mock_ir.command = MagicMock() + mock_ir.command.is_awscli_customization = False + mock_translate.return_value = mock_ir + + mock_validation = MagicMock() + mock_validation.validation_failed = False + mock_validate.return_value = mock_validation + + mock_response = InterpretationResponse(error=None, json='{"Buckets": []}', status_code=200) + mock_interpret.return_value = ProgramInterpretationResponse(response=mock_response) + + result = await call_aws_helper('aws s3api list-buckets', MagicMock()) + + mock_interpret.assert_called_once_with( + cli_command='aws s3api list-buckets', + max_results=None, + credentials=None, + ) + + assert isinstance(result, ProgramInterpretationResponse) + + +async def test_call_aws_delegates_to_helper(): + """Test that call_aws properly delegates to call_aws_helper with Credentials object.""" + access_key_id = 'AKIATEST123' + secret_access_key = 'test-secret' + session_token = 'test-token' + + credentials = Credentials( + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + ) + assert credentials.access_key_id == access_key_id + assert credentials.secret_access_key == secret_access_key + assert credentials.session_token == session_token + + +async def test_call_aws_without_credentials(): + """Test that call_aws works without credentials (backward compatibility).""" + result = await call_aws_helper( + cli_command='aws s3api list-buckets', + ctx=MagicMock(), + max_results=None, + credentials=None, + ) + + assert result is not None + + @patch('awslabs.aws_api_mcp_server.server.DEFAULT_REGION', None) @patch('awslabs.aws_api_mcp_server.server.WORKING_DIRECTORY', '/tmp') def test_main_missing_aws_region():