Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 88 additions & 45 deletions hindsight-api-slim/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8943,57 +8943,100 @@ async def submit_async_retain(
num_sub_batches=len(sub_batches),
)

# Persist the parent row and all child rows in a single transaction.
#
# The parent row is a status aggregator with NO task_payload (workers
# skip rows where task_payload IS NULL because they're not directly
# executable). Its lifecycle is driven by child completions: when all
# children reach a terminal state, the parent gets promoted by the
# aggregator.
#
# If the parent INSERT and child INSERTs are not transactionally
# coupled, any failure between them (connection drop, timeout, schema
# cache invalidation under concurrent load) leaves a parent row with
# zero children. Workers ignore it forever (no task_payload), the
# aggregator never fires (no children to complete), and the row sits
# pending indefinitely — visible in queue-depth metrics and growing
# without bound. Wrapping parent + children in one transaction makes
# the create-batch operation atomic: either all rows are visible to
# workers, or none are.
#
# submit_task() must run AFTER the transaction commits. SyncTaskBackend
# (used in tests) executes the task synchronously, which would not see
# the still-uncommitted child row. BrokerTaskBackend / WorkerTaskBackend
# are effectively no-ops for already-populated task_payload, but we
# defer them all uniformly for clarity.
deferred_child_payloads: list[dict[str, Any]] = []

async with acquire_with_retry(backend) as conn:
await conn.execute(
f"""
INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata, status)
VALUES ($1, $2, $3, $4, $5)
""",
parent_operation_id,
bank_id,
"batch_retain",
json.dumps(parent_metadata.to_dict()),
"pending", # Will be updated by status aggregation
)
async with conn.transaction():
await conn.execute(
f"""
INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata, status)
VALUES ($1, $2, $3, $4, $5)
""",
parent_operation_id,
bank_id,
"batch_retain",
json.dumps(parent_metadata.to_dict()),
"pending", # Will be updated by status aggregation
)

logger.info(f"Created parent operation {parent_operation_id} for {len(sub_batches)} sub-batch(es)")
for i, sub_batch in enumerate(sub_batches, 1):
if len(sub_batches) > 1:
sub_batch_tokens = sum(count_tokens(item.get("content", "")) for item in sub_batch)
logger.info(
f"Submitting sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_tokens:,} tokens"
)

# Submit child operations for each sub-batch
for i, sub_batch in enumerate(sub_batches, 1):
if len(sub_batches) > 1:
sub_batch_tokens = sum(count_tokens(item.get("content", "")) for item in sub_batch)
logger.info(
f"Submitting sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_tokens:,} tokens"
)
task_payload: dict[str, Any] = {"contents": sub_batch}
if document_tags:
task_payload["document_tags"] = document_tags
if strategy:
task_payload["strategy"] = strategy
# Pass tenant_id and api_key_id through task payload
if request_context.tenant_id:
task_payload["_tenant_id"] = request_context.tenant_id
if request_context.api_key_id:
task_payload["_api_key_id"] = request_context.api_key_id

child_metadata = BatchRetainChildMetadata(
items_count=len(sub_batch),
parent_operation_id=str(parent_operation_id),
sub_batch_index=i,
total_sub_batches=len(sub_batches),
)

task_payload: dict[str, Any] = {"contents": sub_batch}
if document_tags:
task_payload["document_tags"] = document_tags
if strategy:
task_payload["strategy"] = strategy
# Pass tenant_id and api_key_id through task payload
if request_context.tenant_id:
task_payload["_tenant_id"] = request_context.tenant_id
if request_context.api_key_id:
task_payload["_api_key_id"] = request_context.api_key_id
child_operation_id = uuid.uuid4()
full_payload = {
"type": "batch_retain",
"operation_id": str(child_operation_id),
"bank_id": bank_id,
**task_payload,
}

# Create typed metadata for child operation
child_metadata = BatchRetainChildMetadata(
items_count=len(sub_batch),
parent_operation_id=str(parent_operation_id),
sub_batch_index=i,
total_sub_batches=len(sub_batches),
)
await conn.execute(
f"""
INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata, status, task_payload)
VALUES ($1, $2, $3, $4, $5, $6::jsonb)
""",
child_operation_id,
bank_id,
"retain",
json.dumps(child_metadata.to_dict(), default=_json_default),
"pending",
json.dumps(full_payload, default=_json_default),
)
deferred_child_payloads.append(full_payload)

# Create child operation with reference to parent
await self._submit_async_operation(
bank_id=bank_id,
operation_type="retain",
task_type="batch_retain",
task_payload=task_payload,
result_metadata=child_metadata.to_dict(),
dedupe_by_bank=False,
)
logger.info(f"Created parent operation {parent_operation_id} with {len(sub_batches)} child sub-batch(es)")

# Notify the task backend after commit. For BrokerTaskBackend /
# WorkerTaskBackend in production this is a no-op because task_payload
# is already populated; for SyncTaskBackend in tests this kicks off
# synchronous execution against the now-committed rows.
for full_payload in deferred_child_payloads:
await self._task_backend.submit_task(full_payload)

return {
"operation_id": str(parent_operation_id),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ async def handle_document_tracking(
is_first_batch: bool,
retain_params: dict | None = None,
document_tags: list[str] | None = None,
ops=None,
) -> None:
"""
Handle document tracking in the database (full-replace mode).
Expand All @@ -284,6 +285,11 @@ async def handle_document_tracking(
is_first_batch: Whether this is the first batch (for chunked operations)
retain_params: Optional parameters passed during retain (context, event_date, etc.)
document_tags: Optional list of tags to associate with the document
ops: Backend-specific DataAccessOps. Required by the inner
``delete_stale_observations_for_memories`` call to choose the PG
(native array) vs Oracle (junction table) read path. Defaults to
None so older callers don't break, but the PG branch is only
taken when ops is non-None — pass ``pool.ops`` from the caller.
"""
import hashlib

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ async def _process_db_batch(
is_first_batch,
retain_params,
merged_tags,
ops=pool.ops,
)
doc_tracking_done[0] = True
log_buffer.append(f"[streaming] Document {effective_doc_id} tracked (0 facts in first batch)")
Expand Down Expand Up @@ -1205,6 +1206,7 @@ async def _run_mini_batch_db_work() -> None:
is_first_batch,
retain_params,
merged_tags,
ops=pool.ops,
)
log_buffer.append(f"[streaming] Document {effective_doc_id} tracked (full content)")
doc_tracking_done[0] = True
Expand Down Expand Up @@ -1361,6 +1363,7 @@ async def _run_mini_batch_db_work() -> None:
is_first_batch,
retain_params,
merged_tags,
ops=pool.ops,
)
doc_tracking_done[0] = True
log_buffer.append(f"[streaming] Document {effective_doc_id} tracked (no facts extracted)")
Expand Down
63 changes: 63 additions & 0 deletions hindsight-api-slim/tests/test_async_batch_retain.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,3 +829,66 @@ async def failing_submit_task(_task_dict):
assert payload["type"] == "batch_retain"
assert payload["bank_id"] == bank_id
assert payload["contents"] == [{"content": "hello", "document_id": "d1"}]


@pytest.mark.asyncio
async def test_submit_async_batch_retain_rolls_back_parent_on_child_failure(
    memory_no_llm_verify, request_context, monkeypatch
):
    """A failure while creating children must not leave an orphaned parent row.

    The batch path exercised through ``submit_async_retain`` writes one parent
    row (status='pending', task_payload NULL: it only aggregates status and is
    never executed directly) followed by one child row per sub-batch. Without
    a shared transaction, anything that goes wrong part-way through the child
    loop (dropped connection, timeout, schema-cache invalidation under load)
    would strand a parent that has no children: workers skip it because its
    task_payload is NULL, the aggregator never runs because no child ever
    completes, and the row stays pending forever, inflating queue-depth
    metrics.

    The failure is injected by making the second construction of
    ``BatchRetainChildMetadata`` raise. Afterwards the bank must have no
    ``async_operations`` rows at all, which shows the parent INSERT was rolled
    back along with the children.
    """
    import hindsight_api.engine.memory_engine as me
    from hindsight_api.engine.memory_engine import count_tokens

    bank_id = f"test_parent_rollback_{uuid.uuid4().hex[:8]}"
    pool = await memory_no_llm_verify._get_pool()
    await _ensure_bank(pool, bank_id)

    # Two oversized items: the payload is expected to be split into at least
    # two sub-batches, so the child loop reaches the iteration that fails.
    filler = "The quick brown fox jumps over the lazy dog. " * 500
    contents = []
    for i in range(2):
        contents.append({"content": filler + f" item {i}", "document_id": f"doc{i}"})
    total_tokens = sum(count_tokens(item["content"]) for item in contents)
    assert total_tokens > 10_000

    original_metadata_cls = me.BatchRetainChildMetadata
    constructions = 0

    def flaky_child_metadata(*args, **kwargs):
        # Behave like the real class except on the second call, which
        # simulates a mid-batch failure after the first child was handled.
        nonlocal constructions
        constructions += 1
        if constructions != 2:
            return original_metadata_cls(*args, **kwargs)
        raise RuntimeError("Simulated child-step failure mid-batch")

    monkeypatch.setattr(me, "BatchRetainChildMetadata", flaky_child_metadata)

    with pytest.raises(RuntimeError, match="Simulated child-step failure"):
        await memory_no_llm_verify.submit_async_retain(
            bank_id=bank_id,
            contents=contents,
            request_context=request_context,
        )

    # Nothing may survive the failed submission: neither parent nor children.
    rows = await pool.fetch(
        "SELECT operation_id, operation_type, status, task_payload FROM async_operations WHERE bank_id = $1",
        bank_id,
    )
    assert rows == [], (
        f"Expected zero rows for bank_id={bank_id} after rollback, got {len(rows)}: "
        f"{[(r['operation_type'], r['status'], r['task_payload'] is not None) for r in rows]}. "
        "The parent INSERT must be transactionally coupled to the child INSERTs."
    )
66 changes: 50 additions & 16 deletions hindsight-api-slim/tests/test_async_retain_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,35 @@

@pytest.mark.asyncio
async def test_submit_async_retain_includes_document_tags_in_task_payload():
"""submit_async_retain should include document_tags in queued task payload."""
"""submit_async_retain should include document_tags in queued task payload.

submit_async_batch_retain inserts the parent + all children inline inside
a single transaction (not via _submit_async_operation), then notifies the
task backend after commit. The test verifies document_tags propagates
through to the task backend's submit_task call, which is the post-commit
notification that drives SyncTaskBackend in tests and is a no-op for
BrokerTaskBackend / WorkerTaskBackend in production.
"""
import json

engine = MemoryEngine.__new__(MemoryEngine)
engine._initialized = True
engine._authenticate_tenant = AsyncMock()
engine._operation_validator = None
engine._submit_async_operation = AsyncMock(return_value={"operation_id": "op-1"})
# Children are now inserted inline (no _submit_async_operation hop), and
# submit_task fires post-commit. Mock both so the inline path runs cleanly
# without the test needing real DB or task backend.
engine._task_backend = AsyncMock()
engine._task_backend.submit_task = AsyncMock()

# Mock the pool and connection for parent operation creation
# Mock the pool and connection for parent + child INSERTs in one transaction.
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock()
mock_conn.transaction = MagicMock()
mock_conn.transaction.return_value.__aenter__ = AsyncMock()
mock_conn.transaction.return_value.__aexit__ = AsyncMock()
# __aexit__ must return falsy or `async with` will swallow exceptions —
# AsyncMock's default return is a truthy MagicMock.
mock_conn.transaction.return_value.__aexit__ = AsyncMock(return_value=False)

mock_pool = AsyncMock()
mock_pool.acquire = AsyncMock(return_value=mock_conn)
Expand Down Expand Up @@ -62,18 +78,36 @@ async def test_submit_async_retain_includes_document_tags_in_task_payload():
# Verify authentication was called
engine._authenticate_tenant.assert_awaited_once_with(request_context)

# Verify child operation was submitted
engine._submit_async_operation.assert_awaited_once()

# Verify child operation payload contains document_tags
kwargs = engine._submit_async_operation.await_args.kwargs
assert kwargs["bank_id"] == "bank-1"
assert kwargs["operation_type"] == "retain"
assert kwargs["task_type"] == "batch_retain"
assert kwargs["task_payload"]["contents"] == contents
assert kwargs["task_payload"]["document_tags"] == document_tags
assert kwargs["task_payload"]["_tenant_id"] == "tenant-a"
assert kwargs["task_payload"]["_api_key_id"] == "key-a"
# The parent + child INSERTs both went through mock_conn.execute. There
# should be exactly two: one for the parent (no task_payload), one for the
# single child (with task_payload). The child INSERT serializes
# full_payload — which carries document_tags — to JSON.
assert mock_conn.execute.await_count == 2, (
f"Expected two INSERTs (parent + child), got {mock_conn.execute.await_count}"
)

# Verify the post-commit submit_task fires once with a payload containing
# the document_tags (this is what gets handed to SyncTaskBackend in tests,
# and what carries the work into the worker in production).
engine._task_backend.submit_task.assert_awaited_once()
full_payload = engine._task_backend.submit_task.await_args.args[0]
assert full_payload["type"] == "batch_retain"
assert full_payload["bank_id"] == "bank-1"
assert full_payload["contents"] == contents
assert full_payload["document_tags"] == document_tags
assert full_payload["_tenant_id"] == "tenant-a"
assert full_payload["_api_key_id"] == "key-a"

# Cross-check: the child INSERT's task_payload column also contains
# the document_tags (same JSON the worker poller would later read).
child_insert_args = mock_conn.execute.await_args_list[1].args
# Positional: (sql, operation_id, bank_id, operation_type, result_metadata,
# status, task_payload_json)
child_task_payload_json = child_insert_args[6]
child_task_payload = json.loads(child_task_payload_json)
assert child_task_payload["document_tags"] == document_tags
assert child_task_payload["_tenant_id"] == "tenant-a"
assert child_task_payload["_api_key_id"] == "key-a"


@pytest.mark.asyncio
Expand Down
6 changes: 6 additions & 0 deletions hindsight-api-slim/tests/test_observation_invalidation.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ async def test_upsert_document_removes_observations_from_outgoing_memories(

# Trigger the upsert path directly. ``handle_document_tracking`` is
# what the retain orchestrator calls on every document re-ingest.
# Pass ops=memory._backend.ops so the inner observation-cleanup query
# selects the PG (native array) read path instead of falling back to
# the Oracle junction-table path (which would query a non-existent
# public.observation_sources relation under PG). The orchestrator
# call sites in _streaming_retain_batch already do this via pool.ops.
async with pool.acquire() as conn:
async with conn.transaction():
await handle_document_tracking(
Expand All @@ -339,6 +344,7 @@ async def test_upsert_document_removes_observations_from_outgoing_memories(
is_first_batch=True,
retain_params=None,
document_tags=None,
ops=memory._backend.ops,
)

async with pool.acquire() as conn:
Expand Down
Loading