Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ indent-width = 4

target-version = "py312"
force-exclude = true
preview = true

[tool.ruff.lint]
select = [
Expand All @@ -96,15 +97,17 @@ select = [
"PLE", # ruff currently implements only a subset of pylint's rules
"PLW", # pylint warning
"PLR", # pylint refactor
"UP", # pyupgrade
"C", # Complexity (mccabe+) & comprehensions
"UP", # pyupgrade
"C", # Complexity (mccabe+) & comprehensions
"DOC201", # pydoclint: return value must be documented
]
ignore = [
"UP006", # See https://github.com/bokeh/bokeh/issues/13143
"UP007", # See https://github.com/bokeh/bokeh/pull/13144
"PLC0415", # Allow imports inside functions (useful for optional deps)
"PLR2004", # Allow magic values in comparisons (array indices etc.)
]
explicit-preview-rules = true

[tool.ruff.format]
# Like Black, use double quotes for strings.
Expand Down
24 changes: 20 additions & 4 deletions src/sre_agent/cli/configuration/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ def _configure_model_provider(
updates: dict[str, str],
allow_back: bool = False,
) -> str:
"""Prompt for model provider and required credentials."""
"""Prompt for model provider and required credentials.

Returns:
The selected model provider name.
"""
model_provider = _prompt_choice(
"Model provider:",
config.integrations.model_provider,
Expand All @@ -215,7 +219,11 @@ def _configure_notification_platform(
updates: dict[str, str],
allow_back: bool = False,
) -> tuple[str, str | None]:
"""Prompt for notification platform and required credentials."""
"""Prompt for notification platform and required credentials.

Returns:
A tuple of (notification_platform, slack_channel_id).
"""
notification_platform = _prompt_choice(
"Messaging/notification platform:",
config.integrations.notification_platform,
Expand Down Expand Up @@ -248,7 +256,11 @@ def _configure_code_repository_provider(
updates: dict[str, str],
allow_back: bool = False,
) -> tuple[str, str | None, str | None, str | None]:
"""Prompt for code repository provider and required credentials."""
"""Prompt for code repository provider and required credentials.

Returns:
A tuple of (code_repository_provider, github_owner, github_repo, github_ref).
"""
code_repository_provider = _prompt_choice(
"Remote code repository:",
config.integrations.code_repository_provider,
Expand Down Expand Up @@ -299,7 +311,11 @@ def _configure_deployment_platform(
updates: dict[str, str],
allow_back: bool = False,
) -> tuple[str, str]:
"""Prompt for deployment platform, logging platform, and AWS credentials."""
"""Prompt for deployment platform, logging platform, and AWS credentials.

Returns:
A tuple of (deployment_platform, logging_platform).
"""
deployment_platform = _prompt_choice(
"Which platform is your application deployed on?",
config.integrations.deployment_platform,
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/cli/presentation/banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def _print_animated_banner() -> None:


def _build_banner(colour_offset: int) -> Panel:
"""Build the banner panel with a shifted colour palette."""
"""Build the banner panel with a shifted colour palette.

Returns:
A Rich Panel containing the styled ASCII art banner.
"""
ascii_art = get_ascii_art().strip("\n")
# spellchecker:ignore-next-line
banner_text = Text(justify="center")
Expand Down
30 changes: 14 additions & 16 deletions src/sre_agent/cli/presentation/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@
import questionary.constants as questionary_constants
import questionary.styles as questionary_styles

QUESTIONARY_STYLE = questionary.Style(
[
("qmark", "fg:#7C3AED"),
("question", "fg:#e0e0e0 bold"),
("answer", "fg:#5EEAD4 bold"),
("search_success", "noinherit fg:#00FF00 bold"),
("search_none", "noinherit fg:#FF0000 bold"),
("pointer", "fg:#e0e0e0"),
("highlighted", "fg:#f2f2f2"),
("selected", "fg:#e0e0e0"),
("separator", "fg:#e0e0e0"),
("instruction", "fg:#e0e0e0"),
("text", "fg:#e0e0e0"),
("disabled", "fg:#bdbdbd italic"),
]
)
QUESTIONARY_STYLE = questionary.Style([
("qmark", "fg:#7C3AED"),
("question", "fg:#e0e0e0 bold"),
("answer", "fg:#5EEAD4 bold"),
("search_success", "noinherit fg:#00FF00 bold"),
("search_none", "noinherit fg:#FF0000 bold"),
("pointer", "fg:#e0e0e0"),
("highlighted", "fg:#f2f2f2"),
("selected", "fg:#e0e0e0"),
("separator", "fg:#e0e0e0"),
("instruction", "fg:#e0e0e0"),
("text", "fg:#e0e0e0"),
("disabled", "fg:#bdbdbd italic"),
])


def apply_questionary_style() -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,11 @@ def _wait_for_nat_gateways(ec2: Any, nat_ids: list[str], reporter: Callable[[str


def _list_internet_gateways(ec2: Any, vpc_id: str) -> list[str]:
"""List internet gateways attached to a VPC."""
"""List internet gateways attached to a VPC.

Returns:
A list of internet gateway IDs attached to the VPC.
"""
response = ec2.describe_internet_gateways(
Filters=[{"Name": "attachment.vpc-id", "Values": [vpc_id]}]
)
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/ecr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@


def ensure_repository(session: Session, name: str) -> str:
"""Ensure an ECR repository exists and return its URI."""
"""Ensure an ECR repository exists and return its URI.

Returns:
The URI of the ECR repository.
"""
ecr = session.client("ecr")
try:
response = ecr.describe_repositories(repositoryNames=[name])
Expand Down
30 changes: 25 additions & 5 deletions src/sre_agent/core/deployments/aws_ecs/ecs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def register_task_definition(
config: EcsDeploymentConfig,
reporter: Callable[[str], None],
) -> str:
"""Register the ECS task definition."""
"""Register the ECS task definition.

Returns:
The ARN of the registered task definition.
"""
cpu_architecture = _normalise_cpu_architecture(config.task_cpu_architecture)
if not config.exec_role_arn or not config.task_role_arn:
raise RuntimeError("Task roles must be created before registering the task definition.")
Expand Down Expand Up @@ -134,7 +138,11 @@ def _normalise_cpu_architecture(value: str) -> str:


def ensure_cluster(session: Session, cluster_name: str) -> str:
"""Ensure an ECS cluster exists."""
"""Ensure an ECS cluster exists.

Returns:
The ARN of the ECS cluster.
"""
ecs = session.client("ecs")
response = ecs.describe_clusters(clusters=[cluster_name])
clusters = response.get("clusters", [])
Expand All @@ -159,7 +167,11 @@ def run_task(
config: EcsDeploymentConfig,
container_overrides: list[dict[str, Any]] | None = None,
) -> str:
"""Run a one-off ECS task."""
"""Run a one-off ECS task.

Returns:
The ARN of the launched ECS task.
"""
if not config.task_definition_arn:
raise RuntimeError("Task definition is missing. Register it before running tasks.")
if not config.security_group_id or not config.private_subnet_ids:
Expand Down Expand Up @@ -209,7 +221,11 @@ def wait_for_task_completion(
timeout_seconds: int = 1800,
poll_interval_seconds: int = 5,
) -> tuple[bool, str]:
"""Wait for a task to stop and report container exit status."""
"""Wait for a task to stop and report container exit status.

Returns:
A tuple of (success, message) indicating the task outcome.
"""
ecs = session.client("ecs")
deadline = time.time() + timeout_seconds

Expand All @@ -232,7 +248,11 @@ def wait_for_task_completion(


def _task_completion_result(task: dict[str, Any]) -> tuple[bool, str]:
"""Convert ECS task details into a completion result."""
"""Convert ECS task details into a completion result.

Returns:
A tuple of (success, message) indicating the task outcome.
"""
target = _find_container(task.get("containers", []), SRE_AGENT_CONTAINER_NAME)
if target is None:
stopped_reason = str(task.get("stoppedReason", "task stopped"))
Expand Down
24 changes: 20 additions & 4 deletions src/sre_agent/core/deployments/aws_ecs/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ def ensure_roles(
secret_arns: list[str],
reporter: Callable[[str], None],
) -> tuple[str, str]:
"""Ensure execution and task roles exist."""
"""Ensure execution and task roles exist.

Returns:
A tuple of (exec_role_arn, task_role_arn).
"""
if not secret_arns:
raise RuntimeError("Secret ARNs are required before creating roles.")

Expand Down Expand Up @@ -71,7 +75,11 @@ def ensure_service_linked_role(session: Session, reporter: Callable[[str], None]


def _ensure_role(iam: Any, role_name: str, trust_policy: dict[str, Any]) -> str:
"""Create a role if needed and return its ARN."""
"""Create a role if needed and return its ARN.

Returns:
The ARN of the IAM role.
"""
try:
response = iam.get_role(RoleName=role_name)
return cast(str, response["Role"]["Arn"])
Expand Down Expand Up @@ -125,7 +133,11 @@ def _ecs_trust_policy() -> dict[str, Any]:


def _secrets_policy(secret_arns: list[str]) -> dict[str, Any]:
"""Allow read access to Secrets Manager."""
"""Allow read access to Secrets Manager.

Returns:
An IAM policy document granting read access to the given secrets.
"""
return {
"Version": "2012-10-17",
"Statement": [
Expand All @@ -139,7 +151,11 @@ def _secrets_policy(secret_arns: list[str]) -> dict[str, Any]:


def _logs_policy(region: str, account_id: str) -> dict[str, Any]:
"""Allow CloudWatch Logs queries."""
"""Allow CloudWatch Logs queries.

Returns:
An IAM policy document granting CloudWatch Logs query access.
"""
return {
"Version": "2012-10-17",
"Statement": [
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ def build_and_push_images(
image_config: ImageBuildConfig,
reporter: Callable[[str], None],
) -> str:
"""Build and push container images to ECR."""
"""Build and push container images to ECR.

Returns:
The ECS runtime CPU architecture used for the build.
"""
_require_docker()

reporter("Authenticating Docker with ECR")
Expand Down
12 changes: 10 additions & 2 deletions src/sre_agent/core/deployments/aws_ecs/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def create_basic_vpc(
project_name: str,
reporter: Callable[[str], None],
) -> NetworkSelection:
"""Create a simple VPC with one public and one private subnet."""
"""Create a simple VPC with one public and one private subnet.

Returns:
A NetworkSelection containing the VPC ID and private subnet IDs.
"""
ec2 = session.client("ec2")

reporter("Creating VPC (private networking foundation)")
Expand Down Expand Up @@ -91,7 +95,11 @@ def _tag_resource(ec2: Any, resource_id: str, name: str) -> None:


def _first_availability_zone(ec2: Any) -> str:
"""Fetch the first availability zone."""
"""Fetch the first availability zone.

Returns:
The name of the first available availability zone.
"""
response = ec2.describe_availability_zones()
zones = response.get("AvailabilityZones", [])
if not zones:
Expand Down
18 changes: 15 additions & 3 deletions src/sre_agent/core/deployments/aws_ecs/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ class SecretInfo:


def get_secret_info(session: Session, name: str) -> SecretInfo | None:
"""Fetch secret metadata by name."""
"""Fetch secret metadata by name.

Returns:
A SecretInfo instance, or None if the secret does not exist.
"""
client = session.client("secretsmanager")
try:
response = client.describe_secret(SecretId=name)
Expand All @@ -34,7 +38,11 @@ def get_secret_info(session: Session, name: str) -> SecretInfo | None:


def create_secret(session: Session, name: str, value: str) -> str:
"""Create a secret and return its ARN."""
"""Create a secret and return its ARN.

Returns:
The ARN of the created secret.
"""
client = session.client("secretsmanager")
try:
response = client.create_secret(Name=name, SecretString=value)
Expand All @@ -44,7 +52,11 @@ def create_secret(session: Session, name: str, value: str) -> str:


def restore_secret(session: Session, name: str) -> str:
"""Restore a secret that is scheduled for deletion."""
"""Restore a secret that is scheduled for deletion.

Returns:
The ARN of the restored secret.
"""
client = session.client("secretsmanager")
try:
response = client.restore_secret(SecretId=name)
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/security_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def create_security_group(
name: str,
description: str,
) -> SecurityGroupInfo:
"""Create a security group with default outbound access."""
"""Create a security group with default outbound access.

Returns:
A SecurityGroupInfo instance with the group ID, name, and description.
"""
ec2 = session.client("ec2")
try:
response = ec2.create_security_group(
Expand Down
12 changes: 10 additions & 2 deletions src/sre_agent/core/deployments/aws_ecs/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@


def create_session(config: EcsDeploymentConfig) -> boto3.session.Session:
"""Create a boto3 session."""
"""Create a boto3 session.

Returns:
A configured boto3 Session.
"""
if config.aws_profile:
return boto3.session.Session(
profile_name=config.aws_profile,
Expand All @@ -18,7 +22,11 @@ def create_session(config: EcsDeploymentConfig) -> boto3.session.Session:


def get_identity(session: boto3.session.Session) -> dict[str, str]:
"""Fetch the current AWS identity."""
"""Fetch the current AWS identity.

Returns:
A dict with keys Account, Arn, and UserId.
"""
client = session.client("sts")
try:
response = client.get_caller_identity()
Expand Down
Loading
Loading