Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/sentry/seer/autofix/autofix_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.utils import timezone
from pydantic import BaseModel
from rest_framework.exceptions import PermissionDenied
from scm.types import GetBranchProtocol, GetRepositoryProtocol

from sentry import analytics, features, quotas
from sentry.analytics.events.autofix_events import (
Expand Down Expand Up @@ -331,6 +332,47 @@ def _code_review_enabled(organization: Organization) -> bool:
return features.has("organizations:seer-autofix-code-review", organization)


def _build_base_shas_metadata(group: Group, referrer: AutofixReferrer) -> str | None:
preference = read_preference_from_sentry_db(group.project)
# Imported lazily to avoid a circular import: sentry.scm pulls in the
# github/slack integrations, which import notifications templates that
# import back into sentry.seer.autofix.
from sentry.scm import factory as scm_factory

base_shas: dict[str, dict[str, str]] = {}
for repo in preference.repositories:
if repo.repository_id is None:
continue

full_name = f"{repo.owner}/{repo.name}"
try:
scm = scm_factory.new(group.organization.id, repo.repository_id, referrer.value)
if repo.branch_name:
base_branch: str | None = repo.branch_name
elif isinstance(scm, GetRepositoryProtocol):
base_branch = scm.get_repository()["data"]["default_branch"]
else:
continue
if not base_branch:
continue
if not isinstance(scm, GetBranchProtocol):
continue
base_sha = scm.get_branch(base_branch)["data"]["sha"]
except Exception:
logger.exception(
"autofix.base_shas.resolve_failed",
extra={"repo": full_name, "group_id": group.id},
)
continue

if base_sha:
base_shas[full_name] = {"base_sha": base_sha, "base_branch": base_branch}

if not base_shas:
return None
return json.dumps(base_shas)


def trigger_autofix_agent(
group: Group,
step: AutofixStep,
Expand Down Expand Up @@ -420,6 +462,12 @@ def trigger_autofix_agent(
)
if iteration_index is not None:
prompt_metadata["iteration_index"] = str(iteration_index)

if step == AutofixStep.CODE_CHANGES and pr_iteration_enabled:
base_shas = _build_base_shas_metadata(group, referrer)
if base_shas:
prompt_metadata["base_shas"] = base_shas

artifact_key = step.value if config.artifact_schema else None
artifact_schema = config.artifact_schema

Expand Down
218 changes: 218 additions & 0 deletions tests/sentry/seer/autofix/test_autofix_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AutofixStep,
NoSeerQuotaException,
PrIterationNoPullRequestException,
_build_base_shas_metadata,
build_step_prompt,
generate_autofix_handoff_prompt,
get_iteration_for_insert_index,
Expand All @@ -28,6 +29,23 @@
from sentry.seer.models import SeerPermissionError
from sentry.sentry_apps.utils.webhooks import SeerActionType
from sentry.testutils.cases import TestCase
from sentry.utils import json


def _make_scm_mock(*, get_repository=None, get_branch=None):
"""Build an SCM mock that satisfies the runtime_checkable Get*Protocol checks.

MagicMock attributes are invisible to ``inspect.getattr_static``, which Python 3.12's
``runtime_checkable`` ``isinstance()`` uses, so the methods must be real class attributes.
"""
return type(
"FakeSCM",
(),
{
"get_repository": MagicMock(return_value=get_repository),
"get_branch": MagicMock(return_value=get_branch),
},
)()


class TestGenerateAutofixHandoffPrompt(TestCase):
Expand Down Expand Up @@ -703,6 +721,206 @@ def test_code_review_enabled_on_non_coding_step_with_flag(

assert mock_client_class.call_args.kwargs["code_review_enabled"] is True

def _make_repo_and_projectrepo(
self,
*,
owner: str = "owner",
name: str = "repo",
external_id: str = "123",
branch_name: str | None = None,
) -> None:
repository = self.create_repo(
project=self.project,
provider="integrations:github",
external_id=external_id,
name=f"{owner}/{name}",
)
self.create_seer_project_repository(
project=self.project,
repository=repository,
branch_name=branch_name,
)

@patch("sentry.scm.factory.new")
@patch("sentry.quotas.backend.record_seer_run")
@patch("sentry.quotas.backend.check_seer_quota", return_value=True)
@patch("sentry.seer.autofix.autofix_agent.broadcast_webhooks_for_organization.delay")
@patch("sentry.seer.autofix.autofix_agent.SeerAgentClient")
def test_code_changes_includes_base_shas_when_pr_iteration_enabled(
self, mock_client_class, mock_broadcast, mock_check_quota, mock_record_run, mock_scm_new
):
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.start_run.return_value = MagicMock(seer_run_state_id=123)
self._make_repo_and_projectrepo()

mock_scm = _make_scm_mock(
get_repository={"data": {"default_branch": "main"}},
get_branch={"data": {"sha": "abc123"}},
)
mock_scm_new.return_value = mock_scm

with self.feature("organizations:autofix-pr-iteration"):
trigger_autofix_agent(
group=self.group,
step=AutofixStep.CODE_CHANGES,
referrer=AutofixReferrer.UNKNOWN,
run_id=None,
)

prompt_metadata = mock_client.start_run.call_args.kwargs["prompt_metadata"]
assert json.loads(prompt_metadata["base_shas"]) == {
"owner/repo": {"base_sha": "abc123", "base_branch": "main"}
}

@patch("sentry.scm.factory.new")
@patch("sentry.quotas.backend.record_seer_run")
@patch("sentry.quotas.backend.check_seer_quota", return_value=True)
@patch("sentry.seer.autofix.autofix_agent.broadcast_webhooks_for_organization.delay")
@patch("sentry.seer.autofix.autofix_agent.SeerAgentClient")
def test_code_changes_omits_base_shas_when_pr_iteration_disabled(
self, mock_client_class, mock_broadcast, mock_check_quota, mock_record_run, mock_scm_new
):
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.start_run.return_value = MagicMock(seer_run_state_id=123)
self._make_repo_and_projectrepo()

trigger_autofix_agent(
group=self.group,
step=AutofixStep.CODE_CHANGES,
referrer=AutofixReferrer.UNKNOWN,
run_id=None,
)

prompt_metadata = mock_client.start_run.call_args.kwargs["prompt_metadata"]
assert "base_shas" not in prompt_metadata
mock_scm_new.assert_not_called()

@patch("sentry.scm.factory.new")
@patch("sentry.quotas.backend.record_seer_run")
@patch("sentry.quotas.backend.check_seer_quota", return_value=True)
@patch("sentry.seer.autofix.autofix_agent.broadcast_webhooks_for_organization.delay")
@patch("sentry.seer.autofix.autofix_agent.SeerAgentClient")
def test_non_code_changes_step_omits_base_shas(
self, mock_client_class, mock_broadcast, mock_check_quota, mock_record_run, mock_scm_new
):
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.start_run.return_value = MagicMock(seer_run_state_id=123)
self._make_repo_and_projectrepo()

with self.feature("organizations:autofix-pr-iteration"):
trigger_autofix_agent(
group=self.group,
step=AutofixStep.ROOT_CAUSE,
referrer=AutofixReferrer.UNKNOWN,
run_id=None,
)

prompt_metadata = mock_client.start_run.call_args.kwargs["prompt_metadata"]
assert "base_shas" not in prompt_metadata
mock_scm_new.assert_not_called()


class TestBuildBaseShasMetadata(TestCase):
def setUp(self) -> None:
super().setUp()
self.group = self.create_group(project=self.project)

def _make_repo_and_projectrepo(
self,
*,
owner: str = "owner",
name: str = "repo",
external_id: str = "123",
branch_name: str | None = None,
) -> None:
repository = self.create_repo(
project=self.project,
provider="integrations:github",
external_id=external_id,
name=f"{owner}/{name}",
)
self.create_seer_project_repository(
project=self.project,
repository=repository,
branch_name=branch_name,
)

def test_returns_none_without_repos(self) -> None:
assert _build_base_shas_metadata(self.group, AutofixReferrer.UNKNOWN) is None

@patch("sentry.scm.factory.new")
def test_builds_base_shas_using_default_branch(self, mock_scm_new):
self._make_repo_and_projectrepo()
mock_scm = _make_scm_mock(
get_repository={"data": {"default_branch": "main"}},
get_branch={"data": {"sha": "deadbeef"}},
)
mock_scm_new.return_value = mock_scm

result = _build_base_shas_metadata(self.group, AutofixReferrer.UNKNOWN)

assert result is not None
assert json.loads(result) == {"owner/repo": {"base_sha": "deadbeef", "base_branch": "main"}}
mock_scm.get_branch.assert_called_once_with("main")

@patch("sentry.scm.factory.new")
def test_uses_branch_name_override(self, mock_scm_new):
self._make_repo_and_projectrepo(branch_name="release/v2")
mock_scm = _make_scm_mock(get_branch={"data": {"sha": "abc"}})
mock_scm_new.return_value = mock_scm

result = _build_base_shas_metadata(self.group, AutofixReferrer.UNKNOWN)

assert result is not None
assert json.loads(result) == {
"owner/repo": {"base_sha": "abc", "base_branch": "release/v2"}
}
mock_scm.get_repository.assert_not_called()
mock_scm.get_branch.assert_called_once_with("release/v2")

@patch("sentry.seer.autofix.autofix_agent.logger")
@patch("sentry.scm.factory.new")
def test_skips_repo_when_scm_raises(self, mock_scm_new, mock_logger):
self._make_repo_and_projectrepo()
mock_scm_new.side_effect = Exception("boom")

assert _build_base_shas_metadata(self.group, AutofixReferrer.UNKNOWN) is None
mock_logger.exception.assert_called_once()

@patch("sentry.scm.factory.new")
def test_skips_repo_without_resolvable_branch(self, mock_scm_new):
self._make_repo_and_projectrepo()
mock_scm = _make_scm_mock(get_repository={"data": {"default_branch": None}})
mock_scm_new.return_value = mock_scm

assert _build_base_shas_metadata(self.group, AutofixReferrer.UNKNOWN) is None
mock_scm.get_branch.assert_not_called()

@patch("sentry.scm.factory.new")
def test_includes_only_repos_with_resolved_sha(self, mock_scm_new):
self._make_repo_and_projectrepo(name="repo-ok", external_id="1")
self._make_repo_and_projectrepo(name="repo-bad", external_id="2")

ok_scm = _make_scm_mock(
get_repository={"data": {"default_branch": "main"}},
get_branch={"data": {"sha": "sha-ok"}},
)
bad_scm = _make_scm_mock(
get_repository={"data": {"default_branch": "main"}},
get_branch={"data": {"sha": ""}},
)
mock_scm_new.side_effect = [ok_scm, bad_scm]

result = _build_base_shas_metadata(self.group, AutofixReferrer.UNKNOWN)

assert result is not None
assert json.loads(result) == {
"owner/repo-ok": {"base_sha": "sha-ok", "base_branch": "main"}
}


class TestTriggerCodingAgentHandoff(TestCase):
"""Tests for trigger_coding_agent_handoff function."""
Expand Down
Loading