Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ indent-width = 4

target-version = "py312"
force-exclude = true
preview = true
explicit-preview-rules = true

[tool.ruff.lint]
select = [
Expand All @@ -96,8 +98,9 @@ select = [
"PLE", # ruff currently implements only a subset of pylint's rules
"PLW", # pylint warning
"PLR", # pylint refactor
"UP", # pyupgrade
"C", # Complexity (mccabe+) & comprehensions
"UP", # pyupgrade
"C", # Complexity (mccabe+) & comprehensions
"DOC201", # pydoclint: return value must be documented
]
ignore = [
"UP006", # See https://github.com/bokeh/bokeh/issues/13143
Expand Down
24 changes: 20 additions & 4 deletions src/sre_agent/cli/configuration/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ def _configure_model_provider(
updates: dict[str, str],
allow_back: bool = False,
) -> str:
"""Prompt for model provider and required credentials."""
"""Prompt for model provider and required credentials.

Returns:
The selected model provider name.
"""
model_provider = _prompt_choice(
"Model provider:",
config.integrations.model_provider,
Expand All @@ -215,7 +219,11 @@ def _configure_notification_platform(
updates: dict[str, str],
allow_back: bool = False,
) -> tuple[str, str | None]:
"""Prompt for notification platform and required credentials."""
"""Prompt for notification platform and required credentials.

Returns:
A tuple of (notification_platform, slack_channel_id).
"""
notification_platform = _prompt_choice(
"Messaging/notification platform:",
config.integrations.notification_platform,
Expand Down Expand Up @@ -248,7 +256,11 @@ def _configure_code_repository_provider(
updates: dict[str, str],
allow_back: bool = False,
) -> tuple[str, str | None, str | None, str | None]:
"""Prompt for code repository provider and required credentials."""
"""Prompt for code repository provider and required credentials.

Returns:
A tuple of (code_repository_provider, github_owner, github_repo, github_ref).
"""
code_repository_provider = _prompt_choice(
"Remote code repository:",
config.integrations.code_repository_provider,
Expand Down Expand Up @@ -299,7 +311,11 @@ def _configure_deployment_platform(
updates: dict[str, str],
allow_back: bool = False,
) -> tuple[str, str]:
"""Prompt for deployment platform, logging platform, and AWS credentials."""
"""Prompt for deployment platform, logging platform, and AWS credentials.

Returns:
A tuple of (deployment_platform, logging_platform).
"""
deployment_platform = _prompt_choice(
"Which platform is your application deployed on?",
config.integrations.deployment_platform,
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/cli/presentation/banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def _print_animated_banner() -> None:


def _build_banner(colour_offset: int) -> Panel:
"""Build the banner panel with a shifted colour palette."""
"""Build the banner panel with a shifted colour palette.

Returns:
A Rich Panel containing the styled ASCII art banner.
"""
ascii_art = get_ascii_art().strip("\n")
# spellchecker:ignore-next-line
banner_text = Text(justify="center")
Expand Down
30 changes: 14 additions & 16 deletions src/sre_agent/cli/presentation/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@
import questionary.constants as questionary_constants
import questionary.styles as questionary_styles

QUESTIONARY_STYLE = questionary.Style(
[
("qmark", "fg:#7C3AED"),
("question", "fg:#e0e0e0 bold"),
("answer", "fg:#5EEAD4 bold"),
("search_success", "noinherit fg:#00FF00 bold"),
("search_none", "noinherit fg:#FF0000 bold"),
("pointer", "fg:#e0e0e0"),
("highlighted", "fg:#f2f2f2"),
("selected", "fg:#e0e0e0"),
("separator", "fg:#e0e0e0"),
("instruction", "fg:#e0e0e0"),
("text", "fg:#e0e0e0"),
("disabled", "fg:#bdbdbd italic"),
]
)
QUESTIONARY_STYLE = questionary.Style([
("qmark", "fg:#7C3AED"),
("question", "fg:#e0e0e0 bold"),
("answer", "fg:#5EEAD4 bold"),
("search_success", "noinherit fg:#00FF00 bold"),
("search_none", "noinherit fg:#FF0000 bold"),
("pointer", "fg:#e0e0e0"),
("highlighted", "fg:#f2f2f2"),
("selected", "fg:#e0e0e0"),
("separator", "fg:#e0e0e0"),
("instruction", "fg:#e0e0e0"),
("text", "fg:#e0e0e0"),
("disabled", "fg:#bdbdbd italic"),
])


def apply_questionary_style() -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,11 @@ def _wait_for_nat_gateways(ec2: Any, nat_ids: list[str], reporter: Callable[[str


def _list_internet_gateways(ec2: Any, vpc_id: str) -> list[str]:
"""List internet gateways attached to a VPC."""
"""List internet gateways attached to a VPC.

Returns:
A list of internet gateway IDs attached to the VPC.
"""
response = ec2.describe_internet_gateways(
Filters=[{"Name": "attachment.vpc-id", "Values": [vpc_id]}]
)
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/ecr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@


def ensure_repository(session: Session, name: str) -> str:
"""Ensure an ECR repository exists and return its URI."""
"""Ensure an ECR repository exists and return its URI.

Returns:
The URI of the ECR repository.
"""
ecr = session.client("ecr")
try:
response = ecr.describe_repositories(repositoryNames=[name])
Expand Down
30 changes: 25 additions & 5 deletions src/sre_agent/core/deployments/aws_ecs/ecs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def register_task_definition(
config: EcsDeploymentConfig,
reporter: Callable[[str], None],
) -> str:
"""Register the ECS task definition."""
"""Register the ECS task definition.

Returns:
The ARN of the registered task definition.
"""
cpu_architecture = _normalise_cpu_architecture(config.task_cpu_architecture)
if not config.exec_role_arn or not config.task_role_arn:
raise RuntimeError("Task roles must be created before registering the task definition.")
Expand Down Expand Up @@ -134,7 +138,11 @@ def _normalise_cpu_architecture(value: str) -> str:


def ensure_cluster(session: Session, cluster_name: str) -> str:
"""Ensure an ECS cluster exists."""
"""Ensure an ECS cluster exists.

Returns:
The ARN of the ECS cluster.
"""
ecs = session.client("ecs")
response = ecs.describe_clusters(clusters=[cluster_name])
clusters = response.get("clusters", [])
Expand All @@ -159,7 +167,11 @@ def run_task(
config: EcsDeploymentConfig,
container_overrides: list[dict[str, Any]] | None = None,
) -> str:
"""Run a one-off ECS task."""
"""Run a one-off ECS task.

Returns:
The ARN of the launched ECS task.
"""
if not config.task_definition_arn:
raise RuntimeError("Task definition is missing. Register it before running tasks.")
if not config.security_group_id or not config.private_subnet_ids:
Expand Down Expand Up @@ -209,7 +221,11 @@ def wait_for_task_completion(
timeout_seconds: int = 1800,
poll_interval_seconds: int = 5,
) -> tuple[bool, str]:
"""Wait for a task to stop and report container exit status."""
"""Wait for a task to stop and report container exit status.

Returns:
A tuple of (success, message) indicating the task outcome.
"""
ecs = session.client("ecs")
deadline = time.time() + timeout_seconds

Expand All @@ -232,7 +248,11 @@ def wait_for_task_completion(


def _task_completion_result(task: dict[str, Any]) -> tuple[bool, str]:
"""Convert ECS task details into a completion result."""
"""Convert ECS task details into a completion result.

Returns:
A tuple of (success, message) indicating the task outcome.
"""
target = _find_container(task.get("containers", []), SRE_AGENT_CONTAINER_NAME)
if target is None:
stopped_reason = str(task.get("stoppedReason", "task stopped"))
Expand Down
24 changes: 20 additions & 4 deletions src/sre_agent/core/deployments/aws_ecs/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ def ensure_roles(
secret_arns: list[str],
reporter: Callable[[str], None],
) -> tuple[str, str]:
"""Ensure execution and task roles exist."""
"""Ensure execution and task roles exist.

Returns:
A tuple of (exec_role_arn, task_role_arn).
"""
if not secret_arns:
raise RuntimeError("Secret ARNs are required before creating roles.")

Expand Down Expand Up @@ -71,7 +75,11 @@ def ensure_service_linked_role(session: Session, reporter: Callable[[str], None]


def _ensure_role(iam: Any, role_name: str, trust_policy: dict[str, Any]) -> str:
"""Create a role if needed and return its ARN."""
"""Create a role if needed and return its ARN.

Returns:
The ARN of the IAM role.
"""
try:
response = iam.get_role(RoleName=role_name)
return cast(str, response["Role"]["Arn"])
Expand Down Expand Up @@ -125,7 +133,11 @@ def _ecs_trust_policy() -> dict[str, Any]:


def _secrets_policy(secret_arns: list[str]) -> dict[str, Any]:
"""Allow read access to Secrets Manager."""
"""Allow read access to Secrets Manager.

Returns:
An IAM policy document granting read access to the given secrets.
"""
return {
"Version": "2012-10-17",
"Statement": [
Expand All @@ -139,7 +151,11 @@ def _secrets_policy(secret_arns: list[str]) -> dict[str, Any]:


def _logs_policy(region: str, account_id: str) -> dict[str, Any]:
"""Allow CloudWatch Logs queries."""
"""Allow CloudWatch Logs queries.

Returns:
An IAM policy document granting CloudWatch Logs query access.
"""
return {
"Version": "2012-10-17",
"Statement": [
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ def build_and_push_images(
image_config: ImageBuildConfig,
reporter: Callable[[str], None],
) -> str:
"""Build and push container images to ECR."""
"""Build and push container images to ECR.

Returns:
The ECS runtime CPU architecture used for the build.
"""
_require_docker()

reporter("Authenticating Docker with ECR")
Expand Down
12 changes: 10 additions & 2 deletions src/sre_agent/core/deployments/aws_ecs/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def create_basic_vpc(
project_name: str,
reporter: Callable[[str], None],
) -> NetworkSelection:
"""Create a simple VPC with one public and one private subnet."""
"""Create a simple VPC with one public and one private subnet.

Returns:
A NetworkSelection containing the VPC ID and private subnet IDs.
"""
ec2 = session.client("ec2")

reporter("Creating VPC (private networking foundation)")
Expand Down Expand Up @@ -91,7 +95,11 @@ def _tag_resource(ec2: Any, resource_id: str, name: str) -> None:


def _first_availability_zone(ec2: Any) -> str:
"""Fetch the first availability zone."""
"""Fetch the first availability zone.

Returns:
The name of the first available availability zone.
"""
response = ec2.describe_availability_zones()
zones = response.get("AvailabilityZones", [])
if not zones:
Expand Down
18 changes: 15 additions & 3 deletions src/sre_agent/core/deployments/aws_ecs/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ class SecretInfo:


def get_secret_info(session: Session, name: str) -> SecretInfo | None:
"""Fetch secret metadata by name."""
"""Fetch secret metadata by name.

Returns:
A SecretInfo instance, or None if the secret does not exist.
"""
client = session.client("secretsmanager")
try:
response = client.describe_secret(SecretId=name)
Expand All @@ -34,7 +38,11 @@ def get_secret_info(session: Session, name: str) -> SecretInfo | None:


def create_secret(session: Session, name: str, value: str) -> str:
"""Create a secret and return its ARN."""
"""Create a secret and return its ARN.

Returns:
The ARN of the created secret.
"""
client = session.client("secretsmanager")
try:
response = client.create_secret(Name=name, SecretString=value)
Expand All @@ -44,7 +52,11 @@ def create_secret(session: Session, name: str, value: str) -> str:


def restore_secret(session: Session, name: str) -> str:
"""Restore a secret that is scheduled for deletion."""
"""Restore a secret that is scheduled for deletion.

Returns:
The ARN of the restored secret.
"""
client = session.client("secretsmanager")
try:
response = client.restore_secret(SecretId=name)
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/security_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def create_security_group(
name: str,
description: str,
) -> SecurityGroupInfo:
"""Create a security group with default outbound access."""
"""Create a security group with default outbound access.

Returns:
A SecurityGroupInfo instance with the group ID, name, and description.
"""
ec2 = session.client("ec2")
try:
response = ec2.create_security_group(
Expand Down
12 changes: 10 additions & 2 deletions src/sre_agent/core/deployments/aws_ecs/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@


def create_session(config: EcsDeploymentConfig) -> boto3.session.Session:
"""Create a boto3 session."""
"""Create a boto3 session.

Returns:
A configured boto3 Session.
"""
if config.aws_profile:
return boto3.session.Session(
profile_name=config.aws_profile,
Expand All @@ -18,7 +22,11 @@ def create_session(config: EcsDeploymentConfig) -> boto3.session.Session:


def get_identity(session: boto3.session.Session) -> dict[str, str]:
"""Fetch the current AWS identity."""
"""Fetch the current AWS identity.

Returns:
A dict with keys Account, Arn, and UserId.
"""
client = session.client("sts")
try:
response = client.get_caller_identity()
Expand Down
6 changes: 5 additions & 1 deletion src/sre_agent/core/deployments/aws_ecs/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@


def check_deployment(session: Session, config: EcsDeploymentConfig) -> dict[str, str]:
"""Check whether deployment resources exist."""
"""Check whether deployment resources exist.

Returns:
A dict mapping resource names to their status strings.
"""
results: dict[str, str] = {}

results["VPC"] = _check_vpc(session, config.vpc_id)
Expand Down
Loading
Loading