Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions langextract/prompt_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from langextract.core import tokenizer as tokenizer_lib

__all__ = [
"IssueKind",
"PromptValidationLevel",
"ValidationIssue",
"ValidationReport",
Expand All @@ -49,8 +50,8 @@ class PromptValidationLevel(enum.Enum):
ERROR = "error"


class _IssueKind(enum.Enum):
"""Internal categorization of alignment issues."""
class IssueKind(enum.Enum):
"""Categorization of alignment issues."""

FAILED = "failed" # alignment_status is None
NON_EXACT = "non_exact" # MATCH_FUZZY or MATCH_LESSER
Expand All @@ -65,7 +66,7 @@ class ValidationIssue:
extraction_class: str
extraction_text_preview: str
alignment_status: data.AlignmentStatus | None
issue_kind: _IssueKind
issue_kind: IssueKind
char_interval: tuple[int, int] | None = None
token_interval: tuple[int, int] | None = None

Expand All @@ -92,12 +93,12 @@ class ValidationReport:
@property
def has_failed(self) -> bool:
"""Returns True if any extraction failed to align."""
return any(i.issue_kind is _IssueKind.FAILED for i in self.issues)
return any(i.issue_kind is IssueKind.FAILED for i in self.issues)

@property
def has_non_exact(self) -> bool:
"""Returns True if any extraction has non-exact alignment."""
return any(i.issue_kind is _IssueKind.NON_EXACT for i in self.issues)
return any(i.issue_kind is IssueKind.NON_EXACT for i in self.issues)


class PromptAlignmentError(RuntimeError):
Expand Down Expand Up @@ -174,7 +175,7 @@ def validate_prompt_alignment(
extraction_class=klass,
extraction_text_preview=_preview(text),
alignment_status=None,
issue_kind=_IssueKind.FAILED,
issue_kind=IssueKind.FAILED,
char_interval=None,
token_interval=None,
)
Expand All @@ -200,7 +201,7 @@ def validate_prompt_alignment(
extraction_class=klass,
extraction_text_preview=_preview(text),
alignment_status=status,
issue_kind=_IssueKind.NON_EXACT,
issue_kind=IssueKind.NON_EXACT,
char_interval=char_interval_tuple,
token_interval=token_interval_tuple,
)
Expand Down Expand Up @@ -229,7 +230,7 @@ def handle_alignment_report(
return

for issue in report.issues:
if issue.issue_kind is _IssueKind.NON_EXACT:
if issue.issue_kind is IssueKind.NON_EXACT:
logging.warning(
"Prompt alignment: non-exact match: %s", issue.short_msg()
)
Expand All @@ -239,9 +240,9 @@ def handle_alignment_report(
)

if level is PromptValidationLevel.ERROR:
failed = [i for i in report.issues if i.issue_kind is _IssueKind.FAILED]
failed = [i for i in report.issues if i.issue_kind is IssueKind.FAILED]
non_exact = [
i for i in report.issues if i.issue_kind is _IssueKind.NON_EXACT
i for i in report.issues if i.issue_kind is IssueKind.NON_EXACT
]

if failed:
Expand Down
Loading