diff --git a/src/seer/automation/autofix/components/root_cause/models.py b/src/seer/automation/autofix/components/root_cause/models.py index 2bc5251c5..e191bdf2a 100644 --- a/src/seer/automation/autofix/components/root_cause/models.py +++ b/src/seer/automation/autofix/components/root_cause/models.py @@ -38,34 +38,13 @@ class RootCauseRelevantContext(BaseModel): description: str snippet: Optional[RootCauseRelevantCodeSnippet] - class RootCauseAnalysisRelevantContext(BaseModel): snippets: list[RootCauseRelevantContext] - -class UnitTestSnippetPrompt(BaseModel): - file_path: str - code_snippet: str - description: str - - @field_validator("code_snippet") - @classmethod - def clean_code_snippet(cls, v: str) -> str: - return remove_code_backticks(v) - - -class UnitTestSnippet(BaseModel): - file_path: str - snippet: str - description: str - - class RootCauseAnalysisItem(BaseModel): id: int = -1 title: str description: str - # unit_test: UnitTestSnippet | None = None - # reproduction: str | None = None code_context: Optional[list[RootCauseRelevantContext]] = None def to_markdown_string(self) -> str: @@ -91,8 +70,6 @@ def to_markdown_string(self) -> str: class RootCauseAnalysisItemPrompt(BaseModel): title: str description: str - # reproduction_instructions: str | None = None - # unit_test: UnitTestSnippetPrompt | None = None relevant_code: Optional[RootCauseAnalysisRelevantContext] @classmethod @@ -100,16 +77,6 @@ def from_model(cls, model: RootCauseAnalysisItem): return cls( title=model.title, description=model.description, - # reproduction_instructions=model.reproduction, - # unit_test=( - # UnitTestSnippetPrompt( - # file_path=model.unit_test.file_path, - # code_snippet=model.unit_test.snippet, - # description=model.unit_test.description, - # ) - # if model.unit_test - # else None - # ), relevant_code=( RootCauseAnalysisRelevantContext( snippets=[ @@ -131,16 +98,6 @@ def to_model(self): return RootCauseAnalysisItem.model_validate( { **self.model_dump(), - # "reproduction": self.reproduction_instructions, - # "unit_test": ( - # { - # "file_path": self.unit_test.file_path, - # "snippet": self.unit_test.code_snippet, - # "description": self.unit_test.description, - # } - # if self.unit_test - # else None - # ), "code_context": ( self.relevant_code.model_dump()["snippets"] if self.relevant_code else None ), diff --git a/tests/automation/autofix/components/test_root_cause_models.py b/tests/automation/autofix/components/test_root_cause_models.py new file mode 100644 index 000000000..76f56d805 --- /dev/null +++ b/tests/automation/autofix/components/test_root_cause_models.py @@ -0,0 +1,127 @@ +import pytest +from pydantic import ValidationError + +from seer.automation.autofix.components.root_cause.models import ( + RootCauseAnalysisItem, + RootCauseAnalysisItemPrompt, + RootCauseAnalysisRelevantContext, + RootCauseRelevantCodeSnippet, + RootCauseRelevantContext, +) + + +class TestRootCauseModels: + def test_basic_model_validation(self): + """Test that basic model validation works with minimal required fields.""" + item = RootCauseAnalysisItem( + title="Test Title", + description="Test Description", + ) + assert item.title == "Test Title" + assert item.description == "Test Description" + assert item.code_context is None + assert item.id == -1 # Default value + + def test_full_model_validation(self): + """Test that model validation works with all fields provided.""" + code_context = [ + RootCauseRelevantContext( + id=1, + title="Context Title", + description="Context Description", + snippet=RootCauseRelevantCodeSnippet( + file_path="test.py", + snippet="def test(): pass", + ), + ) + ] + + item = RootCauseAnalysisItem( + id=0, + title="Test Title", + description="Test Description", + code_context=code_context, + ) + + assert item.id == 0 + assert item.title == "Test Title" + assert item.description == "Test Description" + assert len(item.code_context) == 1 + assert item.code_context[0].id == 1 + assert item.code_context[0].title == "Context Title" + + def test_model_transformation(self): + """Test that transformation between Prompt and Item models works correctly.""" + # Create a prompt model + relevant_code = RootCauseAnalysisRelevantContext( + snippets=[ + RootCauseRelevantContext( + id=1, + title="Context Title", + description="Context Description", + snippet=RootCauseRelevantCodeSnippet( + file_path="test.py", + snippet="def test(): pass", + ), + ) + ] + ) + + prompt = RootCauseAnalysisItemPrompt( + title="Test Title", + description="Test Description", + relevant_code=relevant_code, + ) + + # Transform to item model + item = prompt.to_model() + + assert item.title == "Test Title" + assert item.description == "Test Description" + assert len(item.code_context) == 1 + assert item.code_context[0].id == 1 + assert item.code_context[0].title == "Context Title" + + def test_model_missing_required_fields(self): + """Test that model validation fails when required fields are missing.""" + with pytest.raises(ValidationError) as exc_info: + RootCauseAnalysisItem(title="Test Title") + assert "description" in str(exc_info.value) + + with pytest.raises(ValidationError) as exc_info: + RootCauseAnalysisItem(description="Test Description") + assert "title" in str(exc_info.value) + + def test_to_markdown_string(self): + """Test that markdown string generation works correctly.""" + code_context = [ + RootCauseRelevantContext( + id=1, + title="Context Title", + description="Context Description", + snippet=RootCauseRelevantCodeSnippet( + file_path="test.py", + snippet="def test(): pass", + repo_name="test/repo", + ), + ) + ] + + item = RootCauseAnalysisItem( + id=0, + title="Test Title", + description="Test Description", + code_context=code_context, + ) + + markdown = item.to_markdown_string() + + assert "# Test Title" in markdown + assert "## Description" in markdown + assert "Test Description" in markdown + assert "## Relevant Code Context" in markdown + assert "### Context Title" in markdown + assert "Context Description" in markdown + assert "**File:** test.py" in markdown + assert "**Repository:** test/repo" in markdown + assert "def test(): pass" in markdown \ No newline at end of file