diff --git a/pyproject.toml b/pyproject.toml index 63fb1c08..6bcbfde2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ indent-width = 4 target-version = "py312" force-exclude = true +preview = true [tool.ruff.lint] select = [ @@ -96,8 +97,9 @@ select = [ "PLE", # ruff currently implements only a subset of pylint's rules "PLW", # pylint warning "PLR", # pylint refactor - "UP", # pyupgrade - "C", # Complexity (mccabe+) & comprehensions + "UP", # pyupgrade + "C", # Complexity (mccabe+) & comprehensions + "DOC201", # pydoclint: return value must be documented ] ignore = [ "UP006", # See https://github.com/bokeh/bokeh/issues/13143 @@ -105,6 +107,7 @@ ignore = [ "PLC0415", # Allow imports inside functions (useful for optional deps) "PLR2004", # Allow magic values in comparisons (array indices etc.) ] +explicit-preview-rules = true [tool.ruff.format] # Like Black, use double quotes for strings. diff --git a/src/sre_agent/cli/configuration/wizard.py b/src/sre_agent/cli/configuration/wizard.py index 66546319..e74757a3 100644 --- a/src/sre_agent/cli/configuration/wizard.py +++ b/src/sre_agent/cli/configuration/wizard.py @@ -189,7 +189,11 @@ def _configure_model_provider( updates: dict[str, str], allow_back: bool = False, ) -> str: - """Prompt for model provider and required credentials.""" + """Prompt for model provider and required credentials. + + Returns: + The selected model provider name. + """ model_provider = _prompt_choice( "Model provider:", config.integrations.model_provider, @@ -215,7 +219,11 @@ def _configure_notification_platform( updates: dict[str, str], allow_back: bool = False, ) -> tuple[str, str | None]: - """Prompt for notification platform and required credentials.""" + """Prompt for notification platform and required credentials. + + Returns: + A tuple of (notification_platform, slack_channel_id). + """ notification_platform = _prompt_choice( "Messaging/notification platform:", config.integrations.notification_platform, @@ -248,7 +256,11 @@ def _configure_code_repository_provider( updates: dict[str, str], allow_back: bool = False, ) -> tuple[str, str | None, str | None, str | None]: - """Prompt for code repository provider and required credentials.""" + """Prompt for code repository provider and required credentials. + + Returns: + A tuple of (code_repository_provider, github_owner, github_repo, github_ref). + """ code_repository_provider = _prompt_choice( "Remote code repository:", config.integrations.code_repository_provider, @@ -299,7 +311,11 @@ def _configure_deployment_platform( updates: dict[str, str], allow_back: bool = False, ) -> tuple[str, str]: - """Prompt for deployment platform, logging platform, and AWS credentials.""" + """Prompt for deployment platform, logging platform, and AWS credentials. + + Returns: + A tuple of (deployment_platform, logging_platform). + """ deployment_platform = _prompt_choice( "Which platform is your application deployed on?", config.integrations.deployment_platform, diff --git a/src/sre_agent/cli/presentation/banner.py b/src/sre_agent/cli/presentation/banner.py index 9c93f518..8a1970fb 100644 --- a/src/sre_agent/cli/presentation/banner.py +++ b/src/sre_agent/cli/presentation/banner.py @@ -39,7 +39,11 @@ def _print_animated_banner() -> None: def _build_banner(colour_offset: int) -> Panel: - """Build the banner panel with a shifted colour palette.""" + """Build the banner panel with a shifted colour palette. + + Returns: + A Rich Panel containing the styled ASCII art banner. + """ ascii_art = get_ascii_art().strip("\n") # spellchecker:ignore-next-line banner_text = Text(justify="center") diff --git a/src/sre_agent/cli/presentation/styles.py b/src/sre_agent/cli/presentation/styles.py index 53e24bdc..22731144 100644 --- a/src/sre_agent/cli/presentation/styles.py +++ b/src/sre_agent/cli/presentation/styles.py @@ -4,22 +4,20 @@ import questionary.constants as questionary_constants import questionary.styles as questionary_styles -QUESTIONARY_STYLE = questionary.Style( - [ - ("qmark", "fg:#7C3AED"), - ("question", "fg:#e0e0e0 bold"), - ("answer", "fg:#5EEAD4 bold"), - ("search_success", "noinherit fg:#00FF00 bold"), - ("search_none", "noinherit fg:#FF0000 bold"), - ("pointer", "fg:#e0e0e0"), - ("highlighted", "fg:#f2f2f2"), - ("selected", "fg:#e0e0e0"), - ("separator", "fg:#e0e0e0"), - ("instruction", "fg:#e0e0e0"), - ("text", "fg:#e0e0e0"), - ("disabled", "fg:#bdbdbd italic"), - ] -) +QUESTIONARY_STYLE = questionary.Style([ + ("qmark", "fg:#7C3AED"), + ("question", "fg:#e0e0e0 bold"), + ("answer", "fg:#5EEAD4 bold"), + ("search_success", "noinherit fg:#00FF00 bold"), + ("search_none", "noinherit fg:#FF0000 bold"), + ("pointer", "fg:#e0e0e0"), + ("highlighted", "fg:#f2f2f2"), + ("selected", "fg:#e0e0e0"), + ("separator", "fg:#e0e0e0"), + ("instruction", "fg:#e0e0e0"), + ("text", "fg:#e0e0e0"), + ("disabled", "fg:#bdbdbd italic"), +]) def apply_questionary_style() -> None: diff --git a/src/sre_agent/core/deployments/aws_ecs/cleanup.py b/src/sre_agent/core/deployments/aws_ecs/cleanup.py index 0f82683b..3ad3bfc2 100644 --- a/src/sre_agent/core/deployments/aws_ecs/cleanup.py +++ b/src/sre_agent/core/deployments/aws_ecs/cleanup.py @@ -278,7 +278,11 @@ def _wait_for_nat_gateways(ec2: Any, nat_ids: list[str], reporter: Callable[[str def _list_internet_gateways(ec2: Any, vpc_id: str) -> list[str]: - """List internet gateways attached to a VPC.""" + """List internet gateways attached to a VPC. + + Returns: + A list of internet gateway IDs attached to the VPC. + """ response = ec2.describe_internet_gateways( Filters=[{"Name": "attachment.vpc-id", "Values": [vpc_id]}] ) diff --git a/src/sre_agent/core/deployments/aws_ecs/ecr.py b/src/sre_agent/core/deployments/aws_ecs/ecr.py index b5bf1cdd..73d052c2 100644 --- a/src/sre_agent/core/deployments/aws_ecs/ecr.py +++ b/src/sre_agent/core/deployments/aws_ecs/ecr.py @@ -7,7 +7,11 @@ def ensure_repository(session: Session, name: str) -> str: - """Ensure an ECR repository exists and return its URI.""" + """Ensure an ECR repository exists and return its URI. + + Returns: + The URI of the ECR repository. + """ ecr = session.client("ecr") try: response = ecr.describe_repositories(repositoryNames=[name]) diff --git a/src/sre_agent/core/deployments/aws_ecs/ecs_tasks.py b/src/sre_agent/core/deployments/aws_ecs/ecs_tasks.py index 95dcf617..8f03185e 100644 --- a/src/sre_agent/core/deployments/aws_ecs/ecs_tasks.py +++ b/src/sre_agent/core/deployments/aws_ecs/ecs_tasks.py @@ -29,7 +29,11 @@ def register_task_definition( config: EcsDeploymentConfig, reporter: Callable[[str], None], ) -> str: - """Register the ECS task definition.""" + """Register the ECS task definition. + + Returns: + The ARN of the registered task definition. + """ cpu_architecture = _normalise_cpu_architecture(config.task_cpu_architecture) if not config.exec_role_arn or not config.task_role_arn: raise RuntimeError("Task roles must be created before registering the task definition.") @@ -134,7 +138,11 @@ def _normalise_cpu_architecture(value: str) -> str: def ensure_cluster(session: Session, cluster_name: str) -> str: - """Ensure an ECS cluster exists.""" + """Ensure an ECS cluster exists. + + Returns: + The ARN of the ECS cluster. + """ ecs = session.client("ecs") response = ecs.describe_clusters(clusters=[cluster_name]) clusters = response.get("clusters", []) @@ -159,7 +167,11 @@ def run_task( config: EcsDeploymentConfig, container_overrides: list[dict[str, Any]] | None = None, ) -> str: - """Run a one-off ECS task.""" + """Run a one-off ECS task. + + Returns: + The ARN of the launched ECS task. + """ if not config.task_definition_arn: raise RuntimeError("Task definition is missing. Register it before running tasks.") if not config.security_group_id or not config.private_subnet_ids: @@ -209,7 +221,11 @@ def wait_for_task_completion( timeout_seconds: int = 1800, poll_interval_seconds: int = 5, ) -> tuple[bool, str]: - """Wait for a task to stop and report container exit status.""" + """Wait for a task to stop and report container exit status. + + Returns: + A tuple of (success, message) indicating the task outcome. + """ ecs = session.client("ecs") deadline = time.time() + timeout_seconds @@ -232,7 +248,11 @@ def wait_for_task_completion( def _task_completion_result(task: dict[str, Any]) -> tuple[bool, str]: - """Convert ECS task details into a completion result.""" + """Convert ECS task details into a completion result. + + Returns: + A tuple of (success, message) indicating the task outcome. + """ target = _find_container(task.get("containers", []), SRE_AGENT_CONTAINER_NAME) if target is None: stopped_reason = str(task.get("stoppedReason", "task stopped")) diff --git a/src/sre_agent/core/deployments/aws_ecs/iam.py b/src/sre_agent/core/deployments/aws_ecs/iam.py index f44222ff..4d499905 100644 --- a/src/sre_agent/core/deployments/aws_ecs/iam.py +++ b/src/sre_agent/core/deployments/aws_ecs/iam.py @@ -15,7 +15,11 @@ def ensure_roles( secret_arns: list[str], reporter: Callable[[str], None], ) -> tuple[str, str]: - """Ensure execution and task roles exist.""" + """Ensure execution and task roles exist. + + Returns: + A tuple of (exec_role_arn, task_role_arn). + """ if not secret_arns: raise RuntimeError("Secret ARNs are required before creating roles.") @@ -71,7 +75,11 @@ def ensure_service_linked_role(session: Session, reporter: Callable[[str], None] def _ensure_role(iam: Any, role_name: str, trust_policy: dict[str, Any]) -> str: - """Create a role if needed and return its ARN.""" + """Create a role if needed and return its ARN. + + Returns: + The ARN of the IAM role. + """ try: response = iam.get_role(RoleName=role_name) return cast(str, response["Role"]["Arn"]) @@ -125,7 +133,11 @@ def _ecs_trust_policy() -> dict[str, Any]: def _secrets_policy(secret_arns: list[str]) -> dict[str, Any]: - """Allow read access to Secrets Manager.""" + """Allow read access to Secrets Manager. + + Returns: + An IAM policy document granting read access to the given secrets. + """ return { "Version": "2012-10-17", "Statement": [ @@ -139,7 +151,11 @@ def _secrets_policy(secret_arns: list[str]) -> dict[str, Any]: def _logs_policy(region: str, account_id: str) -> dict[str, Any]: - """Allow CloudWatch Logs queries.""" + """Allow CloudWatch Logs queries. + + Returns: + An IAM policy document granting CloudWatch Logs query access. + """ return { "Version": "2012-10-17", "Statement": [ diff --git a/src/sre_agent/core/deployments/aws_ecs/images.py b/src/sre_agent/core/deployments/aws_ecs/images.py index d9165f88..a5ff0894 100644 --- a/src/sre_agent/core/deployments/aws_ecs/images.py +++ b/src/sre_agent/core/deployments/aws_ecs/images.py @@ -28,7 +28,11 @@ def build_and_push_images( image_config: ImageBuildConfig, reporter: Callable[[str], None], ) -> str: - """Build and push container images to ECR.""" + """Build and push container images to ECR. + + Returns: + The ECS runtime CPU architecture used for the build. + """ _require_docker() reporter("Authenticating Docker with ECR") diff --git a/src/sre_agent/core/deployments/aws_ecs/network.py b/src/sre_agent/core/deployments/aws_ecs/network.py index 51146b14..74bbdda4 100644 --- a/src/sre_agent/core/deployments/aws_ecs/network.py +++ b/src/sre_agent/core/deployments/aws_ecs/network.py @@ -13,7 +13,11 @@ def create_basic_vpc( project_name: str, reporter: Callable[[str], None], ) -> NetworkSelection: - """Create a simple VPC with one public and one private subnet.""" + """Create a simple VPC with one public and one private subnet. + + Returns: + A NetworkSelection containing the VPC ID and private subnet IDs. + """ ec2 = session.client("ec2") reporter("Creating VPC (private networking foundation)") @@ -91,7 +95,11 @@ def _tag_resource(ec2: Any, resource_id: str, name: str) -> None: def _first_availability_zone(ec2: Any) -> str: - """Fetch the first availability zone.""" + """Fetch the first availability zone. + + Returns: + The name of the first available availability zone. + """ response = ec2.describe_availability_zones() zones = response.get("AvailabilityZones", []) if not zones: diff --git a/src/sre_agent/core/deployments/aws_ecs/secrets.py b/src/sre_agent/core/deployments/aws_ecs/secrets.py index c6dd29d1..261ea408 100644 --- a/src/sre_agent/core/deployments/aws_ecs/secrets.py +++ b/src/sre_agent/core/deployments/aws_ecs/secrets.py @@ -16,7 +16,11 @@ class SecretInfo: def get_secret_info(session: Session, name: str) -> SecretInfo | None: - """Fetch secret metadata by name.""" + """Fetch secret metadata by name. + + Returns: + A SecretInfo instance, or None if the secret does not exist. + """ client = session.client("secretsmanager") try: response = client.describe_secret(SecretId=name) @@ -34,7 +38,11 @@ def get_secret_info(session: Session, name: str) -> SecretInfo | None: def create_secret(session: Session, name: str, value: str) -> str: - """Create a secret and return its ARN.""" + """Create a secret and return its ARN. + + Returns: + The ARN of the created secret. + """ client = session.client("secretsmanager") try: response = client.create_secret(Name=name, SecretString=value) @@ -44,7 +52,11 @@ def create_secret(session: Session, name: str, value: str) -> str: def restore_secret(session: Session, name: str) -> str: - """Restore a secret that is scheduled for deletion.""" + """Restore a secret that is scheduled for deletion. + + Returns: + The ARN of the restored secret. + """ client = session.client("secretsmanager") try: response = client.restore_secret(SecretId=name) diff --git a/src/sre_agent/core/deployments/aws_ecs/security_groups.py b/src/sre_agent/core/deployments/aws_ecs/security_groups.py index 367cb90f..3850eb77 100644 --- a/src/sre_agent/core/deployments/aws_ecs/security_groups.py +++ b/src/sre_agent/core/deployments/aws_ecs/security_groups.py @@ -12,7 +12,11 @@ def create_security_group( name: str, description: str, ) -> SecurityGroupInfo: - """Create a security group with default outbound access.""" + """Create a security group with default outbound access. + + Returns: + A SecurityGroupInfo instance with the group ID, name, and description. + """ ec2 = session.client("ec2") try: response = ec2.create_security_group( diff --git a/src/sre_agent/core/deployments/aws_ecs/session.py b/src/sre_agent/core/deployments/aws_ecs/session.py index 50e1a600..4a4db1fe 100644 --- a/src/sre_agent/core/deployments/aws_ecs/session.py +++ b/src/sre_agent/core/deployments/aws_ecs/session.py @@ -7,7 +7,11 @@ def create_session(config: EcsDeploymentConfig) -> boto3.session.Session: - """Create a boto3 session.""" + """Create a boto3 session. + + Returns: + A configured boto3 Session. + """ if config.aws_profile: return boto3.session.Session( profile_name=config.aws_profile, @@ -18,7 +22,11 @@ def create_session(config: EcsDeploymentConfig) -> boto3.session.Session: def get_identity(session: boto3.session.Session) -> dict[str, str]: - """Fetch the current AWS identity.""" + """Fetch the current AWS identity. + + Returns: + A dict with keys Account, Arn, and UserId. + """ client = session.client("sts") try: response = client.get_caller_identity() diff --git a/src/sre_agent/core/deployments/aws_ecs/status.py b/src/sre_agent/core/deployments/aws_ecs/status.py index 79d0ed8c..09f32b85 100644 --- a/src/sre_agent/core/deployments/aws_ecs/status.py +++ b/src/sre_agent/core/deployments/aws_ecs/status.py @@ -7,7 +7,11 @@ def check_deployment(session: Session, config: EcsDeploymentConfig) -> dict[str, str]: - """Check whether deployment resources exist.""" + """Check whether deployment resources exist. + + Returns: + A dict mapping resource names to their status strings. + """ results: dict[str, str] = {} results["VPC"] = _check_vpc(session, config.vpc_id) diff --git a/src/sre_agent/core/prompts.py b/src/sre_agent/core/prompts.py index f5c8564b..3fca1145 100644 --- a/src/sre_agent/core/prompts.py +++ b/src/sre_agent/core/prompts.py @@ -8,7 +8,11 @@ def _load_prompt(filename: str) -> str: - """Load a prompt from a text file.""" + """Load a prompt from a text file. + + Returns: + The prompt text with leading and trailing whitespace stripped. + """ return (PROMPTS_DIR / filename).read_text(encoding="utf-8").strip() @@ -22,7 +26,11 @@ def build_diagnosis_prompt( service_name: str, time_range_minutes: int = 10, ) -> str: - """Build a diagnosis prompt for the agent.""" + """Build a diagnosis prompt for the agent. + + Returns: + The formatted diagnosis prompt string. + """ prompt = DIAGNOSIS_PROMPT_TEMPLATE.format( log_group=log_group, time_range_minutes=time_range_minutes, diff --git a/src/sre_agent/core/settings.py b/src/sre_agent/core/settings.py index f3f47c5a..d2d2cfc4 100644 --- a/src/sre_agent/core/settings.py +++ b/src/sre_agent/core/settings.py @@ -74,6 +74,9 @@ def get_settings() -> AgentSettings: The sub-configs are automatically populated from the environment thanks to pydantic-settings. + + Returns: + The populated AgentSettings instance. """ # We use type: ignore[call-arg] because mypy doesn't know BaseSettings # will populate these fields from the environment variables. diff --git a/src/sre_agent/core/tools/cloudwatch.py b/src/sre_agent/core/tools/cloudwatch.py index a0eef3aa..8cc8f27e 100644 --- a/src/sre_agent/core/tools/cloudwatch.py +++ b/src/sre_agent/core/tools/cloudwatch.py @@ -77,7 +77,11 @@ async def query_errors( raise RuntimeError(f"Unexpected error querying logs: {e}") from e def _parse_events(self, events: list[dict[str, Any]]) -> list[LogEntry]: - """Parse filter_log_events entries into LogEntry objects.""" + """Parse filter_log_events entries into LogEntry objects. + + Returns: + A list of LogEntry objects sorted by timestamp descending. + """ entries = [] for event in events: timestamp_ms = event.get("timestamp") @@ -97,7 +101,11 @@ def _parse_events(self, events: list[dict[str, Any]]) -> list[LogEntry]: def create_cloudwatch_toolset(config: AgentSettings) -> FunctionToolset: - """Create a FunctionToolset with CloudWatch tools for pydantic-ai.""" + """Create a FunctionToolset with CloudWatch tools for pydantic-ai. + + Returns: + A FunctionToolset containing the CloudWatch search tool. + """ toolset = FunctionToolset() cw_logging = CloudWatchLogging(region=config.aws.region) diff --git a/src/sre_agent/core/tools/github.py b/src/sre_agent/core/tools/github.py index 0c8c7a67..32d113e9 100644 --- a/src/sre_agent/core/tools/github.py +++ b/src/sre_agent/core/tools/github.py @@ -13,6 +13,9 @@ def create_github_mcp_toolset(config: AgentSettings) -> MCPServerStreamableHTTP: """Create GitHub MCP server toolset for pydantic-ai. Connects to an external GitHub MCP server via Streamable HTTP. + + Returns: + An MCPServerStreamableHTTP instance configured for GitHub. """ if not config.github.mcp_url: logger.warning("GITHUB_MCP_URL not set, GitHub tools will be unavailable") diff --git a/src/sre_agent/core/tools/slack.py b/src/sre_agent/core/tools/slack.py index 14abcf7f..4454598b 100644 --- a/src/sre_agent/core/tools/slack.py +++ b/src/sre_agent/core/tools/slack.py @@ -17,6 +17,9 @@ def create_slack_mcp_toolset(config: AgentSettings) -> FilteredToolset: """Create Slack MCP server toolset for pydantic-ai. Connects to an external Slack MCP server via SSE. + + Returns: + A FilteredToolset exposing only the allowed Slack tools. """ if not config.slack.mcp_url: logger.warning("SLACK_MCP_URL not set, Slack tools will be unavailable") diff --git a/src/sre_agent/eval/diagnosis_quality/mocks/cloudwatch.py b/src/sre_agent/eval/diagnosis_quality/mocks/cloudwatch.py index 8b3267d5..86a2b263 100644 --- a/src/sre_agent/eval/diagnosis_quality/mocks/cloudwatch.py +++ b/src/sre_agent/eval/diagnosis_quality/mocks/cloudwatch.py @@ -14,7 +14,11 @@ async def search_error_logs( service_name: str, time_range_minutes: int, ) -> LogQueryResult: - """Mock CloudWatch log lookup using case fixtures.""" + """Mock CloudWatch log lookup using case fixtures. + + Returns: + A LogQueryResult populated from case fixtures. + """ with opik.start_as_current_span( name="search_error_logs", type="tool", @@ -41,7 +45,11 @@ async def search_error_logs( def _normalise_messages(runtime: MockToolRuntime) -> list[str]: - """Convert multiline fixture entries into non-empty log messages.""" + """Convert multiline fixture entries into non-empty log messages. + + Returns: + A list of non-empty log message strings. + """ messages: list[str] = [] for entry in runtime.case.mock_cloudwatch_entries: message = "\n".join(line.rstrip("\n") for line in entry.message).strip() diff --git a/src/sre_agent/eval/diagnosis_quality/mocks/slack.py b/src/sre_agent/eval/diagnosis_quality/mocks/slack.py index 70da1a42..0c9aab73 100644 --- a/src/sre_agent/eval/diagnosis_quality/mocks/slack.py +++ b/src/sre_agent/eval/diagnosis_quality/mocks/slack.py @@ -12,7 +12,11 @@ async def conversations_add_message( payload: str, thread_ts: str | None, ) -> dict[str, Any]: - """Mock Slack conversations_add_message.""" + """Mock Slack conversations_add_message. + + Returns: + A mock Slack API response dict. + """ span_input: dict[str, Any] = {"channel_id": channel_id, "payload": payload} if thread_ts is not None: span_input["thread_ts"] = thread_ts diff --git a/src/sre_agent/eval/diagnosis_quality/mocks/toolset.py b/src/sre_agent/eval/diagnosis_quality/mocks/toolset.py index 014f76ec..21382627 100644 --- a/src/sre_agent/eval/diagnosis_quality/mocks/toolset.py +++ b/src/sre_agent/eval/diagnosis_quality/mocks/toolset.py @@ -11,7 +11,11 @@ def build_mock_toolset(runtime: MockToolRuntime) -> FunctionToolset: - """Build mocked Slack and CloudWatch toolset.""" + """Build mocked Slack and CloudWatch toolset. + + Returns: + A FunctionToolset with mocked Slack and CloudWatch tools. + """ toolset = FunctionToolset() @toolset.tool @@ -20,7 +24,11 @@ async def conversations_add_message( payload: str, thread_ts: str | None = None, ) -> dict[str, Any]: - """Mock Slack message posting.""" + """Mock Slack message posting. + + Returns: + A mock Slack API response dict. + """ return await slack_mocks.conversations_add_message( channel_id, payload, @@ -33,7 +41,11 @@ async def search_error_logs( service_name: str, time_range_minutes: int = 10, ) -> LogQueryResult: - """Mock CloudWatch error search.""" + """Mock CloudWatch error search. + + Returns: + A LogQueryResult populated from case fixtures. + """ return await cloudwatch_mocks.search_error_logs( runtime, log_group, diff --git a/src/sre_agent/eval/tool_call/mocks/cloudwatch.py b/src/sre_agent/eval/tool_call/mocks/cloudwatch.py index d5a78a7e..0706f2a0 100644 --- a/src/sre_agent/eval/tool_call/mocks/cloudwatch.py +++ b/src/sre_agent/eval/tool_call/mocks/cloudwatch.py @@ -14,7 +14,11 @@ async def search_error_logs( service_name: str, time_range_minutes: int, ) -> LogQueryResult: - """Mock CloudWatch log lookup using case fixtures.""" + """Mock CloudWatch log lookup using case fixtures. + + Returns: + A LogQueryResult populated from case fixtures. + """ with opik.start_as_current_span( name="search_error_logs", type="tool", @@ -41,7 +45,11 @@ async def search_error_logs( def _normalise_messages(runtime: MockToolRuntime) -> list[str]: - """Convert multiline fixture entries into non-empty log messages.""" + """Convert multiline fixture entries into non-empty log messages. + + Returns: + A list of non-empty log message strings. + """ messages: list[str] = [] for entry in runtime.case.mock_cloudwatch_entries: message = "\n".join(line.rstrip("\n") for line in entry.message).strip() diff --git a/src/sre_agent/eval/tool_call/mocks/slack.py b/src/sre_agent/eval/tool_call/mocks/slack.py index c967c414..25dcc73f 100644 --- a/src/sre_agent/eval/tool_call/mocks/slack.py +++ b/src/sre_agent/eval/tool_call/mocks/slack.py @@ -12,7 +12,11 @@ async def conversations_add_message( payload: str, thread_ts: str | None, ) -> dict[str, Any]: - """Mock Slack conversations_add_message.""" + """Mock Slack conversations_add_message. + + Returns: + A mock Slack API response dict. + """ span_input: dict[str, Any] = {"channel_id": channel_id, "payload": payload} if thread_ts is not None: span_input["thread_ts"] = thread_ts diff --git a/src/sre_agent/eval/tool_call/mocks/toolset.py b/src/sre_agent/eval/tool_call/mocks/toolset.py index 28642cc7..f0efd754 100644 --- a/src/sre_agent/eval/tool_call/mocks/toolset.py +++ b/src/sre_agent/eval/tool_call/mocks/toolset.py @@ -11,7 +11,11 @@ def build_mock_toolset(runtime: MockToolRuntime) -> FunctionToolset: - """Build mocked Slack and CloudWatch toolset.""" + """Build mocked Slack and CloudWatch toolset. + + Returns: + A FunctionToolset with mocked Slack and CloudWatch tools. + """ toolset = FunctionToolset() @toolset.tool @@ -20,7 +24,11 @@ async def conversations_add_message( payload: str, thread_ts: str | None = None, ) -> dict[str, Any]: - """Mock Slack message posting.""" + """Mock Slack message posting. + + Returns: + A mock Slack API response dict. + """ return await slack_mocks.conversations_add_message( channel_id, payload, @@ -33,7 +41,11 @@ async def search_error_logs( service_name: str, time_range_minutes: int = 10, ) -> LogQueryResult: - """Mock CloudWatch error search.""" + """Mock CloudWatch error search. + + Returns: + A LogQueryResult populated from case fixtures. + """ return await cloudwatch_mocks.search_error_logs( runtime, log_group, diff --git a/src/sre_agent/run.py b/src/sre_agent/run.py index 585da872..52508c0c 100644 --- a/src/sre_agent/run.py +++ b/src/sre_agent/run.py @@ -19,7 +19,11 @@ def _load_request_from_args_or_env() -> tuple[str, str, int]: - """Load diagnosis inputs from CLI args or environment.""" + """Load diagnosis inputs from CLI args or environment. + + Returns: + A tuple of (log_group, service_name, time_range_minutes). + """ if len(sys.argv) >= 3: log_group = sys.argv[1] service_name = sys.argv[2]