Skip to content

Commit c611354

Browse files
committed
feat: added langextract audit and guardrails
1 parent 49dca81 commit c611354

File tree

10 files changed

+1132
-96
lines changed

10 files changed

+1132
-96
lines changed

.env.example

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,25 @@ EXTRACTION_CACHE_BACKEND=redis
4141
TASK_TIME_LIMIT=3600
4242
TASK_SOFT_TIME_LIMIT=3300
4343
RESULT_EXPIRES=86400
44+
45+
# ── Audit logging (langextract-audit) ────────────────────────────────────────
46+
# Enable structured audit logging for every LLM inference call
47+
AUDIT_ENABLED=false
48+
# Sink type: logging (stdlib), jsonfile (NDJSON file), otel (OpenTelemetry)
49+
AUDIT_SINK=logging
50+
# Path for the NDJSON audit file (only used when AUDIT_SINK=jsonfile)
51+
AUDIT_LOG_PATH=audit.jsonl
52+
# Max length of the truncated prompt/response samples stored in audit records (unset = disabled)
53+
# AUDIT_SAMPLE_LENGTH=200
54+
55+
# ── Guardrails / output validation (langextract-guardrails) ──────────────────
56+
# Enable LLM output validation with automatic retry & corrective prompting
57+
GUARDRAILS_ENABLED=false
58+
# Maximum retry attempts when validation fails (default 3)
GUARDRAILS_MAX_RETRIES=3
# Presumably limits concurrent guardrails validation calls — see GUARDRAILS_MAX_CONCURRENCY in app/core/config.py (unset = no limit)
# GUARDRAILS_MAX_CONCURRENCY=4
60+
# Include invalid output in correction prompt (set false to save tokens)
61+
GUARDRAILS_INCLUDE_OUTPUT_IN_CORRECTION=true
62+
# Truncate original prompt in correction prompts (unset = no limit)
63+
# GUARDRAILS_MAX_CORRECTION_PROMPT_LENGTH=2000
64+
# Truncate invalid output in correction prompts (unset = no limit)
65+
# GUARDRAILS_MAX_CORRECTION_OUTPUT_LENGTH=1000

app/core/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ class Settings(BaseSettings):
9797
EXTRACTION_CACHE_TTL: int = 86400 # seconds (24 h)
9898
EXTRACTION_CACHE_BACKEND: str = "redis" # redis | disk | none
9999

100+
# ── Audit logging ───────────────────────────────────────────────
101+
AUDIT_ENABLED: bool = False
102+
AUDIT_SINK: str = "logging" # logging | jsonfile | otel
103+
AUDIT_LOG_PATH: str = "audit.jsonl"
104+
AUDIT_SAMPLE_LENGTH: int | None = None
105+
106+
# ── Guardrails (output validation) ──────────────────────────────
107+
GUARDRAILS_ENABLED: bool = False
108+
GUARDRAILS_MAX_RETRIES: int = 3
109+
GUARDRAILS_MAX_CONCURRENCY: int | None = None
110+
GUARDRAILS_INCLUDE_OUTPUT_IN_CORRECTION: bool = True
111+
GUARDRAILS_MAX_CORRECTION_PROMPT_LENGTH: int | None = None
112+
GUARDRAILS_MAX_CORRECTION_OUTPUT_LENGTH: int | None = None
113+
100114
@field_validator("CORS_ORIGINS", mode="before")
101115
@classmethod
102116
def _parse_cors(cls, v: str | list[str]) -> list[str]:

app/schemas/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
from app.schemas.enums import TaskState
1414
from app.schemas.health import CeleryHealthResponse, HealthResponse
1515
from app.schemas.requests import (
16+
AuditConfig,
1617
BatchExtractionRequest,
1718
ExtractionConfig,
1819
ExtractionRequest,
20+
GuardrailsConfig,
1921
Provider,
2022
)
2123
from app.schemas.responses import (
@@ -31,6 +33,7 @@
3133
)
3234

3335
__all__ = [
36+
"AuditConfig",
3437
"BatchExtractionRequest",
3538
"BatchTaskSubmitResponse",
3639
"CeleryHealthResponse",
@@ -39,6 +42,7 @@
3942
"ExtractionMetadata",
4043
"ExtractionRequest",
4144
"ExtractionResult",
45+
"GuardrailsConfig",
4246
"HealthResponse",
4347
"Provider",
4448
"TaskRevokeResponse",

app/schemas/requests.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,98 @@
7979
# ── Extraction configuration model ─────────────────────────
8080

8181

82+
class GuardrailsConfig(BaseModel):
    """Request-scoped guardrails settings.

    Drives LLM output validation, retry, and corrective prompting through
    ``langextract-guardrails``. Supplying this inside ``extraction_config``
    takes precedence over the global ``GUARDRAILS_*`` settings for the
    request in question. Fields left as ``None`` defer to those globals.
    """

    # Master switch; None defers to the global GUARDRAILS_ENABLED setting.
    enabled: bool | None = Field(
        default=None,
        description="Enable output validation with retry. When ``None``, falls back to the global ``GUARDRAILS_ENABLED`` setting.",
    )
    # Providing a schema implicitly creates a JsonSchemaValidator.
    json_schema: dict[str, Any] | None = Field(
        default=None,
        description="JSON Schema dict to validate LLM output against. When set, a ``JsonSchemaValidator`` is created automatically.",
    )
    # Providing a pattern implicitly creates a RegexValidator.
    regex_pattern: str | None = Field(
        default=None,
        description="Regex pattern the LLM output must match. When set, a ``RegexValidator`` is created.",
    )
    # Only meaningful alongside regex_pattern; surfaces in correction text.
    regex_description: str | None = Field(
        default=None,
        description="Human-readable description of the regex pattern, used in corrective error messages.",
    )
    # Bounded to [0, 10]; None defers to GUARDRAILS_MAX_RETRIES.
    max_retries: int | None = Field(
        default=None,
        ge=0,
        le=10,
        description="Maximum retry attempts on validation failure. Overrides ``GUARDRAILS_MAX_RETRIES``.",
    )
    # False enables error-only correction prompts (cheaper in tokens).
    include_output_in_correction: bool | None = Field(
        default=None,
        description="Include the invalid output in the correction prompt. Set ``False`` for error-only mode to save tokens.",
    )
    # Strict mode rejects properties not declared in the schema.
    json_schema_strict: bool = Field(
        default=True,
        description="When ``True``, additional properties not in the schema cause validation failure.",
    )
class AuditConfig(BaseModel):
    """Request-scoped audit-logging settings.

    Drives structured audit logging through ``langextract-audit``.
    Supplying this inside ``extraction_config`` takes precedence over the
    global ``AUDIT_*`` settings for the request in question. Fields left
    as ``None`` defer to those globals.
    """

    # Master switch; None defers to the global AUDIT_ENABLED setting.
    enabled: bool | None = Field(
        default=None,
        description="Enable audit logging. When ``None``, falls back to the global ``AUDIT_ENABLED`` setting.",
    )
    # Non-negative truncation length; None defers to AUDIT_SAMPLE_LENGTH.
    sample_length: int | None = Field(
        default=None,
        ge=0,
        description="Store truncated prompt/response samples in audit records for debugging. Overrides ``AUDIT_SAMPLE_LENGTH``.",
    )
82174
class ExtractionConfig(BaseModel):
83175
"""Typed extraction configuration overrides.
84176
@@ -158,14 +250,38 @@ class ExtractionConfig(BaseModel):
158250
"prompt-only extraction."
159251
),
160252
)
253+
guardrails: GuardrailsConfig | None = Field(
254+
default=None,
255+
description=(
256+
"Output validation and retry configuration via "
257+
"langextract-guardrails. When unset, falls back "
258+
"to the global ``GUARDRAILS_*`` settings."
259+
),
260+
)
261+
audit: AuditConfig | None = Field(
262+
default=None,
263+
description=(
264+
"Audit logging configuration via "
265+
"langextract-audit. When unset, falls back "
266+
"to the global ``AUDIT_*`` settings."
267+
),
268+
)
161269

162270
def to_flat_dict(self) -> dict[str, Any]:
163271
"""Return a dict with only non-None values.
164272
273+
Nested models (``guardrails``, ``audit``) are serialized
274+
to plain dicts so the result is JSON-serializable and
275+
safe for Celery task arguments.
276+
165277
Returns:
166278
Flat dict suitable for ``run_extraction``.
167279
"""
168-
return {k: v for k, v in self.model_dump().items() if v is not None}
280+
data: dict[str, Any] = {}
281+
for k, v in self.model_dump().items():
282+
if v is not None:
283+
data[k] = v
284+
return data
169285

170286

171287
# ── Request models ──────────────────────────────────────────

app/services/extractor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
ExtractionCache,
4545
build_cache_key,
4646
)
47+
from app.services.model_wrappers import apply_model_wrappers
4748
from app.services.provider_manager import ProviderManager
4849
from app.services.providers import is_openai_model, resolve_api_key
4950
from app.services.structured_output import (
@@ -349,6 +350,13 @@ def run_extraction(
349350
response_format=response_format,
350351
)
351352

353+
# ── Step 3b: Apply guardrails & audit wrappers ──────────
354+
cached_model = apply_model_wrappers(
355+
cached_model,
356+
provider,
357+
extraction_config,
358+
)
359+
352360
extract_kwargs: dict[str, Any] = {
353361
"text_or_documents": text_input,
354362
"prompt_description": prompt_description,
@@ -625,6 +633,13 @@ async def async_run_extraction(
625633
response_format=response_format_async,
626634
)
627635

636+
# ── Step 3b: Apply guardrails & audit wrappers (async) ───
637+
cached_model = apply_model_wrappers(
638+
cached_model,
639+
provider,
640+
extraction_config,
641+
)
642+
628643
extract_kwargs: dict[str, Any] = {
629644
"text_or_documents": text_input,
630645
"prompt_description": prompt_description,

0 commit comments

Comments
 (0)