Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions app/server/handlers/redaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,17 @@ async def _get_doc_result(
doc = await store.get_result_doc(doc_id)
if not doc:
errors = getattr(final_task_result, "errors", [])
# If we get here without any recorded errors, the pipeline
# finished cleanly but the result key is no longer present
# in Redis -- almost always because it expired or was
# evicted between completion and this poll. Make that
# explicit so the caller knows to resubmit rather than
# treat this as a silent pipeline bug.
error_message = (
"Redaction job completed but document "
"is missing and no specific errors were recorded."
"Redaction completed, but the redacted document is no "
"longer available in the result store. The result has "
"likely expired or been evicted from cache; please "
"resubmit the redaction request."
)
if errors:
error_message = str(errors)
Expand Down
48 changes: 43 additions & 5 deletions app/server/tasks/finalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..config import config
from ..db import DocumentStatus
from ..generated.models import OutputFormat, RedactionTarget
from .callback import CallbackTaskResult
from .callback import CallbackTaskResult, get_result_sync
from .metrics import (
celery_counters,
record_task_failure,
Expand Down Expand Up @@ -68,16 +68,54 @@ def finalize(
"""Finalize the redaction process."""
format_result = callback_result.formatted

celery_counters.record_job(bool(format_result.errors))
# Start from whatever the upstream pipeline reported.
final_errors: list[ProcessingError] = list(format_result.errors)

# If the pipeline thinks it succeeded, verify that the redacted document
# is actually retrievable from the result store. The chain can be
# "successful" while the result doc is absent (e.g. a Redis write that
# silently failed earlier, or the key was evicted under memory pressure
# before finalize ran). Surfacing this here keeps the experiments DB and
# the poll API in agreement: both will report ERROR with a clear cause
# instead of disagreeing about whether the document exists.
if not final_errors:
try:
doc = get_result_sync(
format_result.jurisdiction_id,
format_result.case_id,
format_result.document_id,
)
except Exception as e:
logger.exception("Failed to verify presence of redaction result in store")
final_errors.append(
ProcessingError.from_exception("finalize.verify_result", e)
)
else:
if doc is None:
final_errors.append(
ProcessingError(
message=(
"Redaction pipeline reported success but the "
"redacted document was not found in the result "
"store. The result may have failed to persist "
"or has been evicted/expired from cache before "
"finalize ran."
),
task="finalize.verify_result",
exception="MissingResultDocument",
)
)

celery_counters.record_job(bool(final_errors))

if config.experiments.enabled:
with config.experiments.store.driver.sync_session() as session:
status = DocumentStatus(
jurisdiction_id=format_result.jurisdiction_id,
case_id=format_result.case_id,
document_id=format_result.document_id,
status="ERROR" if format_result.errors else "COMPLETE",
error=format_errors(format_result.errors),
status="ERROR" if final_errors else "COMPLETE",
error=format_errors(final_errors),
)
session.add(status)
session.commit()
Expand Down Expand Up @@ -116,7 +154,7 @@ def finalize(
jurisdiction_id=format_result.jurisdiction_id,
case_id=format_result.case_id,
document_id=format_result.document_id,
errors=format_result.errors,
errors=final_errors,
next_task_id=str(next_task) if next_task else None,
)

Expand Down
62 changes: 56 additions & 6 deletions app/server/tasks/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ class FormatTaskResult(BaseModel):
errors: list[ProcessingError] = []


class ResultStoreWriteError(Exception):
    """Signal that saving the formatted document to the result store failed.

    The format task raises this with ``raise ... from <original error>`` so
    the real failure stays reachable through ``__cause__``. Having a
    dedicated type lets the task tell a *persistence* failure apart from any
    other error raised while formatting, and report a more precise
    ProcessingError once its retries have run out.
    """


register_type(FormatTask)
register_type(FormatTaskResult)

Expand Down Expand Up @@ -98,19 +107,60 @@ def format(
else:
document = format_document(params, redact_result)

save_result_sync(
redact_result.jurisdiction_id,
redact_result.case_id,
redact_result.document_id,
document,
)
try:
save_result_sync(
redact_result.jurisdiction_id,
redact_result.case_id,
redact_result.document_id,
document,
)
except Exception as save_exc:
# Persisting the formatted result to the result store is the step
# that determines whether downstream consumers can actually find
# the document, so we tag this failure separately. Wrapping the
# original exception keeps the traceback intact while making the
# final ProcessingError unambiguous.
raise ResultStoreWriteError(
"Failed to persist redacted document to the result store"
) from save_exc

return FormatTaskResult(
jurisdiction_id=redact_result.jurisdiction_id,
case_id=redact_result.case_id,
document_id=redact_result.document_id,
errors=redact_result.errors,
)
except ResultStoreWriteError as e:
# Transient Redis hiccups are common, so we still retry. After
# exhausting retries, surface a ProcessingError tagged with the
# specific subsystem that failed (rather than a generic "format"
# error) so operators can tell write failures apart from
# rendering/encoding bugs.
cause = e.__cause__
# `__cause__` is typed as ``BaseException | None``; in practice the
# ``raise ... from save_exc`` site always chains an ``Exception``,
# but narrow defensively so mypy and ``ProcessingError.from_exception``
# agree, and fall back to ``e`` if anything ever chains a
# ``BaseException`` (e.g. ``SystemExit``).
underlying: Exception = cause if isinstance(cause, Exception) else e
if format.request.retries < format.max_retries:
logger.warning(
f"Failed to save format result: {underlying}, will be retried."
)
raise format.retry(exc=underlying) from e
logger.error(
f"Exhausted retries saving format result for {redact_result.document_id}"
)
logger.exception(underlying)
return FormatTaskResult(
jurisdiction_id=redact_result.jurisdiction_id,
case_id=redact_result.case_id,
document_id=redact_result.document_id,
errors=[
*redact_result.errors,
ProcessingError.from_exception("format.save_result", underlying),
],
)
except Exception as e:
if format.request.retries < format.max_retries:
logger.warning(f"Format task failed: {e}, will be retried.")
Expand Down
19 changes: 19 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@
"""


@pytest.fixture(scope="session", autouse=True)
def _init_celery_counters() -> None:
    """Wire up the custom Celery metrics counters once per test session.

    A running worker normally performs this setup from its
    ``worker_process_init`` signal handler (in ``app.server.tasks.queue``).
    The test suite runs tasks eagerly with ``task.s(...).apply()`` and never
    boots a worker process, so that signal never fires. Without this fixture
    ``celery_counters`` would lack ``task_complete_counter``, and any task
    reaching ``on_success`` / ``on_failure`` would raise ``AttributeError``
    from inside Celery's tracer.

    Being session-scoped and autouse, the fixture runs a single time and
    applies to every test without being requested explicitly.
    """
    # Deferred import: resolve the metrics module only when the fixture runs.
    import app.server.tasks.metrics as task_metrics

    task_metrics.celery_counters.init()


@pytest.fixture
def logger() -> logging.Logger:
"""Logging for tests."""
Expand Down
5 changes: 0 additions & 5 deletions tests/unit/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
ProcessingError,
callback,
)
from app.server.tasks.metrics import celery_counters


def test_callback_no_callback_no_error():
Expand All @@ -31,7 +30,6 @@ def test_callback_no_callback_no_error():
),
)

celery_counters.init()
cb = CallbackTask(callback_url=None)

result = callback.s(fmt_result, cb).apply()
Expand All @@ -53,7 +51,6 @@ def test_callback_no_callback_with_error():
document=None,
)

celery_counters.init()
cb = CallbackTask(callback_url=None)

result = callback.s(fmt_result, cb).apply()
Expand Down Expand Up @@ -107,7 +104,6 @@ def test_callback_with_callback_no_error(fake_redis_store: FakeRedis):
],
)

celery_counters.init()
cb = CallbackTask(callback_url="http://callback.test.local")

result = callback.s(fmt_result, cb).apply()
Expand Down Expand Up @@ -156,7 +152,6 @@ def test_callback_with_callback_with_error(fake_redis_store: FakeRedis):
],
)

celery_counters.init()
cb = CallbackTask(callback_url="http://callback.test.local")

result = callback.s(fmt_result, cb).apply()
Expand Down
93 changes: 92 additions & 1 deletion tests/unit/test_case_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import pytest

from app.server.case_helper import summarize_state
from app.server.handlers.redaction import _get_doc_result
from app.server.tasks import ProcessingError


def _result(state: str, name: str, result_value=None) -> MagicMock:
Expand Down Expand Up @@ -104,3 +108,90 @@ def test_summarize_single_success():
assert summary.simple_state == "SUCCESS"
assert summary.dominant_task_name == "fetch"
assert summary.result is tasks[0]


# --- _get_doc_result: SUCCESS-but-missing-doc messaging -----------------------
#
# These tests cover the handler branch where Celery reports SUCCESS but
# `get_result_doc` returns None. The user-facing error message should
# clearly indicate that the result has likely expired or been evicted,
# rather than the older "no specific errors were recorded" wording which
# read like a silent pipeline bug.


@pytest.mark.asyncio
async def test_get_doc_result_missing_doc_reports_expiry_message():
    """A SUCCESS task with no stored document and no recorded errors must
    produce the expiry/eviction message that asks the caller to resubmit.
    """
    # Drop `errors` so the handler's getattr(..., "errors", []) default is
    # what supplies the (empty) error list.
    task_payload = MagicMock()
    del task_payload.errors

    # `name` is assigned after construction: passed to the MagicMock
    # constructor it would name the mock itself, not set an attribute.
    async_result = MagicMock(state="SUCCESS", result=task_payload)
    async_result.name = "finalize"

    # The result store reports that the document is gone.
    store = MagicMock(get_result_doc=AsyncMock(return_value=None))

    def fake_get_result(task_id):
        # Stand-in for the handler's `get_result`; every id maps to the
        # synthetic AsyncResult built above.
        return async_result

    with pytest.MonkeyPatch.context() as patcher:
        patcher.setattr(
            "app.server.handlers.redaction.get_result",
            fake_get_result,
        )
        result = await _get_doc_result(
            store=store,
            jurisdiction_id="jur1",
            case_id="case1",
            doc_id="doc1",
            task_ids=["finalize-task-id"],
            masked_subjects=[],
        )

    body = result.root
    assert body.status == "ERROR"
    message = body.error.lower()
    assert "expired" in message or "evicted" in message
    assert "resubmit" in message


@pytest.mark.asyncio
async def test_get_doc_result_missing_doc_prefers_recorded_errors():
    """Errors recorded on the dominant task must be reported in place of
    the generic expiry-style fallback message.
    """
    recorded_error = ProcessingError(
        message="Boom",
        task="format.save_result",
        exception="RuntimeError",
    )
    task_payload = MagicMock(errors=[recorded_error])

    # `name` is assigned after construction: passed to the MagicMock
    # constructor it would name the mock itself, not set an attribute.
    async_result = MagicMock(state="SUCCESS", result=task_payload)
    async_result.name = "finalize"

    # The result store reports that the document is gone.
    store = MagicMock(get_result_doc=AsyncMock(return_value=None))

    def fake_get_result(task_id):
        # Stand-in for the handler's `get_result`; every id maps to the
        # synthetic AsyncResult built above.
        return async_result

    with pytest.MonkeyPatch.context() as patcher:
        patcher.setattr(
            "app.server.handlers.redaction.get_result",
            fake_get_result,
        )
        result = await _get_doc_result(
            store=store,
            jurisdiction_id="jur1",
            case_id="case1",
            doc_id="doc1",
            task_ids=["finalize-task-id"],
            masked_subjects=[],
        )

    body = result.root
    assert body.status == "ERROR"
    # The recorded error wins, so the expiry wording must be absent.
    error_text = body.error
    assert "expired" not in error_text.lower()
    assert "format.save_result" in error_text
Loading
Loading