Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def translate_cli_to_ir(cli_command: str) -> IRTranslation:
def interpret_command(
cli_command: str,
max_results: int | None = None,
credentials: Credentials | None = None,
) -> InterpretedProgram:
"""Interpret the CLI command.

Expand All @@ -101,9 +102,10 @@ def interpret_command(
):
region = GLOBAL_SERVICE_REGIONS[translation.command.command_metadata.service_sdk_name]

credentials = get_local_credentials(
profile=translation.command.profile or AWS_API_MCP_PROFILE_NAME
)
if credentials is None:
credentials = get_local_credentials(
profile=translation.command.profile or AWS_API_MCP_PROFILE_NAME
)

try:
response = interpret(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AwsApiMcpServerErrorResponse,
AwsCliAliasResponse,
Consent,
Credentials,
InterpretationMetadata,
InterpretationResponse,
InterpretedProgram,
Expand Down Expand Up @@ -163,11 +164,13 @@ def execute_awscli_customization(
def interpret_command(
cli_command: str,
max_results: int | None = None,
credentials: Credentials | None = None,
) -> ProgramInterpretationResponse:
"""Interpret the given CLI command and return an interpretation response."""
interpreted_program = _interpret_command(
cli_command,
max_results=max_results,
credentials=credentials,
)

validation_failures = (
Expand Down
141 changes: 91 additions & 50 deletions src/aws-api-mcp-server/awslabs/aws_api_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .core.common.models import (
AwsApiMcpServerErrorResponse,
AwsCliAliasResponse,
Credentials,
ProgramInterpretationResponse,
)
from .core.metadata.read_only_operations_list import ReadOnlyOperations, get_read_only_operations
Expand Down Expand Up @@ -167,58 +168,16 @@ async def suggest_aws_commands(
return AwsApiMcpServerErrorResponse(detail=error_message)


@server.tool(
name='call_aws',
description=f"""Execute AWS CLI commands with validation and proper error handling. This is the PRIMARY tool to use when you are confident about the exact AWS CLI command needed to fulfill a user's request. Always prefer this tool over 'suggest_aws_commands' when you have a specific command in mind.
Key points:
- The command MUST start with "aws" and follow AWS CLI syntax
- Commands are executed in {DEFAULT_REGION} region by default
- For cross-region or account-wide operations, explicitly include --region parameter
- All commands are validated before execution to prevent errors
- Supports pagination control via max_results parameter
- The current working directory is {WORKING_DIRECTORY}
- File paths should always have forward slash (/) as a separator regardless of the system. Example: 'c:/folder/file.txt'

Best practices for command generation:
- Always use the most specific service and operation names
- Always use the working directory when writing files, unless user explicitly mentioned another directory
- Include --region when operating across regions
- Only use filters (--filters, --query, --prefix, --pattern, etc) when necessary or user explicitly asked for it

Command restrictions:
- DO NOT use bash/zsh pipes (|) or any shell operators
- DO NOT use bash/zsh tools like grep, awk, sed, etc.
- DO NOT use shell redirection operators (>, >>, <)
- DO NOT use command substitution ($())
- DO NOT use shell variables or environment variables
- DO NOT use relative paths for reading or writing files, use absolute paths instead

Common pitfalls to avoid:
1. Missing required parameters - always include all required parameters
2. Incorrect parameter values - ensure values match expected format
3. Missing --region when operating across regions

Returns:
CLI execution results with API response data or error message
""",
annotations=ToolAnnotations(
title='Execute AWS CLI commands',
readOnlyHint=READ_OPERATIONS_ONLY_MODE,
destructiveHint=not READ_OPERATIONS_ONLY_MODE,
openWorldHint=True,
),
)
async def call_aws(
cli_command: Annotated[
str, Field(description='The complete AWS CLI command to execute. MUST start with "aws"')
],
async def call_aws_helper(
cli_command: str,
ctx: Context,
max_results: Annotated[
int | None,
Field(description='Optional limit for number of results (useful for pagination)'),
] = None,
max_results: int | None = None,
credentials: Credentials | None = None,
) -> ProgramInterpretationResponse | AwsApiMcpServerErrorResponse | AwsCliAliasResponse:
"""Call AWS with the given CLI command and return the result as a dictionary."""
"""Helper method for call_aws with optional credential injection.

If credentials are provided, use them. Otherwise, fall back to boto3 credential chain.
"""
try:
ir = translate_cli_to_ir(cli_command)
ir_validation = validate(ir)
Expand Down Expand Up @@ -285,6 +244,7 @@ async def call_aws(
return interpret_command(
cli_command=cli_command,
max_results=max_results,
credentials=credentials,
)
except NoCredentialsError:
error_message = (
Expand All @@ -310,6 +270,87 @@ async def call_aws(
)


@server.tool(
name='call_aws',
description=f"""Execute AWS CLI commands with validation and proper error handling. This is the PRIMARY tool to use when you are confident about the exact AWS CLI command needed to fulfill a user's request. Always prefer this tool over 'suggest_aws_commands' when you have a specific command in mind.
Key points:
- The command MUST start with "aws" and follow AWS CLI syntax
- Commands are executed in {DEFAULT_REGION} region by default
- For cross-region or account-wide operations, explicitly include --region parameter
- All commands are validated before execution to prevent errors
- Supports pagination control via max_results parameter
- The current working directory is {WORKING_DIRECTORY}
- File paths should always have forward slash (/) as a separator regardless of the system. Example: 'c:/folder/file.txt'

Best practices for command generation:
- Always use the most specific service and operation names
- Always use the working directory when writing files, unless user explicitly mentioned another directory
- Include --region when operating across regions
- Only use filters (--filters, --query, --prefix, --pattern, etc) when necessary or user explicitly asked for it

Command restrictions:
- DO NOT use bash/zsh pipes (|) or any shell operators
- DO NOT use bash/zsh tools like grep, awk, sed, etc.
- DO NOT use shell redirection operators (>, >>, <)
- DO NOT use command substitution ($())
- DO NOT use shell variables or environment variables
- DO NOT use relative paths for reading or writing files, use absolute paths instead

Common pitfalls to avoid:
1. Missing required parameters - always include all required parameters
2. Incorrect parameter values - ensure values match expected format
3. Missing --region when operating across regions

Returns:
CLI execution results with API response data or error message
""",
annotations=ToolAnnotations(
title='Execute AWS CLI commands',
readOnlyHint=READ_OPERATIONS_ONLY_MODE,
destructiveHint=not READ_OPERATIONS_ONLY_MODE,
openWorldHint=True,
),
)
async def call_aws(
cli_command: Annotated[
str, Field(description='The complete AWS CLI command to execute. MUST start with "aws"')
],
ctx: Context,
max_results: Annotated[
int | None,
Field(description='Optional limit for number of results (useful for pagination)'),
] = None,
access_key_id: Annotated[
str | None,
Field(description='Optional AWS access key ID for credential injection'),
] = None,
secret_access_key: Annotated[
str | None,
Field(description='Optional AWS secret access key for credential injection'),
] = None,
session_token: Annotated[
str | None,
Field(description='Optional AWS session token for credential injection'),
] = None,
) -> ProgramInterpretationResponse | AwsApiMcpServerErrorResponse | AwsCliAliasResponse:
"""Call AWS with the given CLI command and return the result as a dictionary."""
# Create Credentials object if credentials are provided
credentials = None
if access_key_id and secret_access_key:
credentials = Credentials(
access_key_id=access_key_id,
secret_access_key=secret_access_key,
session_token=session_token,
)

return await call_aws_helper(
cli_command=cli_command,
ctx=ctx,
max_results=max_results,
credentials=credentials,
)


# EXPERIMENTAL: Agent scripts tool - only registered if ENABLE_AGENT_SCRIPTS is True
if ENABLE_AGENT_SCRIPTS:

Expand Down
99 changes: 98 additions & 1 deletion src/aws-api-mcp-server/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
AwsApiMcpServerErrorResponse,
AwsCliAliasResponse,
Consent,
Credentials,
InterpretationResponse,
ProgramInterpretationResponse,
)
from awslabs.aws_api_mcp_server.server import call_aws, main, suggest_aws_commands
from awslabs.aws_api_mcp_server.server import call_aws, call_aws_helper, main, suggest_aws_commands
from botocore.exceptions import NoCredentialsError
from fastmcp.server.elicitation import AcceptedElicitation
from tests.fixtures import DummyCtx
Expand Down Expand Up @@ -672,6 +673,102 @@ async def test_call_aws_awscli_customization_error(
mock_ctx.error.assert_called_once_with(error_response.detail)


@patch('awslabs.aws_api_mcp_server.server.interpret_command')
@patch('awslabs.aws_api_mcp_server.server.validate')
@patch('awslabs.aws_api_mcp_server.server.translate_cli_to_ir')
async def test_call_aws_helper_with_credential_injection(
mock_translate, mock_validate, mock_interpret
):
"""Test call_aws_helper uses injected credentials when provided."""
mock_ir = MagicMock()
mock_ir.command = MagicMock()
mock_ir.command.is_awscli_customization = False
mock_translate.return_value = mock_ir

mock_validation = MagicMock()
mock_validation.validation_failed = False
mock_validate.return_value = mock_validation

mock_response = InterpretationResponse(error=None, json='{"Buckets": []}', status_code=200)
mock_interpret.return_value = ProgramInterpretationResponse(response=mock_response)

test_credentials = Credentials(
access_key_id='AKIATEST123',
secret_access_key='test-secret-key', # pragma: allowlist secret
session_token='test-session-token',
)

result = await call_aws_helper(
'aws s3api list-buckets', MagicMock(), credentials=test_credentials
)

mock_interpret.assert_called_once_with(
cli_command='aws s3api list-buckets',
max_results=None,
credentials=test_credentials,
)

assert isinstance(result, ProgramInterpretationResponse)


@patch('awslabs.aws_api_mcp_server.server.interpret_command')
@patch('awslabs.aws_api_mcp_server.server.validate')
@patch('awslabs.aws_api_mcp_server.server.translate_cli_to_ir')
async def test_call_aws_helper_without_credentials_uses_boto3_chain(
mock_translate, mock_validate, mock_interpret
):
"""Test call_aws_helper falls back to boto3 credential chain when no credentials provided."""
mock_ir = MagicMock()
mock_ir.command = MagicMock()
mock_ir.command.is_awscli_customization = False
mock_translate.return_value = mock_ir

mock_validation = MagicMock()
mock_validation.validation_failed = False
mock_validate.return_value = mock_validation

mock_response = InterpretationResponse(error=None, json='{"Buckets": []}', status_code=200)
mock_interpret.return_value = ProgramInterpretationResponse(response=mock_response)

result = await call_aws_helper('aws s3api list-buckets', MagicMock())

mock_interpret.assert_called_once_with(
cli_command='aws s3api list-buckets',
max_results=None,
credentials=None,
)

assert isinstance(result, ProgramInterpretationResponse)


async def test_call_aws_delegates_to_helper():
"""Test that call_aws properly delegates to call_aws_helper with Credentials object."""
access_key_id = 'AKIATEST123'
secret_access_key = 'test-secret'
session_token = 'test-token'

credentials = Credentials(
access_key_id=access_key_id,
secret_access_key=secret_access_key,
session_token=session_token,
)
assert credentials.access_key_id == access_key_id
assert credentials.secret_access_key == secret_access_key
assert credentials.session_token == session_token


async def test_call_aws_without_credentials():
"""Test that call_aws works without credentials (backward compatibility)."""
result = await call_aws_helper(
cli_command='aws s3api list-buckets',
ctx=MagicMock(),
max_results=None,
credentials=None,
)

assert result is not None


@patch('awslabs.aws_api_mcp_server.server.DEFAULT_REGION', None)
@patch('awslabs.aws_api_mcp_server.server.WORKING_DIRECTORY', '/tmp')
def test_main_missing_aws_region():
Expand Down
Loading