diff --git a/hindsight-api-slim/hindsight_api/engine/memory_engine.py b/hindsight-api-slim/hindsight_api/engine/memory_engine.py index 3c414787a..649fef103 100644 --- a/hindsight-api-slim/hindsight_api/engine/memory_engine.py +++ b/hindsight-api-slim/hindsight_api/engine/memory_engine.py @@ -8943,57 +8943,100 @@ async def submit_async_retain( num_sub_batches=len(sub_batches), ) + # Persist the parent row and all child rows in a single transaction. + # + # The parent row is a status aggregator with NO task_payload (workers + # skip rows where task_payload IS NULL because they're not directly + # executable). Its lifecycle is driven by child completions: when all + # children reach a terminal state, the parent gets promoted by the + # aggregator. + # + # If the parent INSERT and child INSERTs are not transactionally + # coupled, any failure between them (connection drop, timeout, schema + # cache invalidation under concurrent load) leaves a parent row with + # zero children. Workers ignore it forever (no task_payload), the + # aggregator never fires (no children to complete), and the row sits + # pending indefinitely — visible in queue-depth metrics and growing + # without bound. Wrapping parent + children in one transaction makes + # the create-batch operation atomic: either all rows are visible to + # workers, or none are. + # + # submit_task() must run AFTER the transaction commits. SyncTaskBackend + # (used in tests) executes the task synchronously, which would not see + # the still-uncommitted child row. BrokerTaskBackend / WorkerTaskBackend + # are effectively no-ops for already-populated task_payload, but we + # defer them all uniformly for clarity. + deferred_child_payloads: list[dict[str, Any]] = [] + async with acquire_with_retry(backend) as conn: - await conn.execute( - f""" - INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata, status) - VALUES ($1, $2, $3, $4, $5) - """, - parent_operation_id, - bank_id, - "batch_retain", - json.dumps(parent_metadata.to_dict()), - "pending", # Will be updated by status aggregation - ) + async with conn.transaction(): + await conn.execute( + f""" + INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata, status) + VALUES ($1, $2, $3, $4, $5) + """, + parent_operation_id, + bank_id, + "batch_retain", + json.dumps(parent_metadata.to_dict()), + "pending", # Will be updated by status aggregation + ) - logger.info(f"Created parent operation {parent_operation_id} for {len(sub_batches)} sub-batch(es)") + for i, sub_batch in enumerate(sub_batches, 1): + if len(sub_batches) > 1: + sub_batch_tokens = sum(count_tokens(item.get("content", "")) for item in sub_batch) + logger.info( + f"Submitting sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_tokens:,} tokens" + ) - # Submit child operations for each sub-batch - for i, sub_batch in enumerate(sub_batches, 1): - if len(sub_batches) > 1: - sub_batch_tokens = sum(count_tokens(item.get("content", "")) for item in sub_batch) - logger.info( - f"Submitting sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_tokens:,} tokens" - ) + task_payload: dict[str, Any] = {"contents": sub_batch} + if document_tags: + task_payload["document_tags"] = document_tags + if strategy: + task_payload["strategy"] = strategy + # Pass tenant_id and api_key_id through task payload + if request_context.tenant_id: + task_payload["_tenant_id"] = request_context.tenant_id + if request_context.api_key_id: + task_payload["_api_key_id"] = request_context.api_key_id + + child_metadata = BatchRetainChildMetadata( + items_count=len(sub_batch), + parent_operation_id=str(parent_operation_id), + sub_batch_index=i, + total_sub_batches=len(sub_batches), + ) - task_payload: dict[str, Any] = {"contents": sub_batch} - if document_tags: - task_payload["document_tags"] = document_tags - if strategy: - task_payload["strategy"] = strategy - # Pass tenant_id and api_key_id through task payload - if request_context.tenant_id: - task_payload["_tenant_id"] = request_context.tenant_id - if request_context.api_key_id: - task_payload["_api_key_id"] = request_context.api_key_id + child_operation_id = uuid.uuid4() + full_payload = { + "type": "batch_retain", + "operation_id": str(child_operation_id), + "bank_id": bank_id, + **task_payload, + } - # Create typed metadata for child operation - child_metadata = BatchRetainChildMetadata( - items_count=len(sub_batch), - parent_operation_id=str(parent_operation_id), - sub_batch_index=i, - total_sub_batches=len(sub_batches), - ) + await conn.execute( + f""" + INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata, status, task_payload) + VALUES ($1, $2, $3, $4, $5, $6::jsonb) + """, + child_operation_id, + bank_id, + "retain", + json.dumps(child_metadata.to_dict(), default=_json_default), + "pending", + json.dumps(full_payload, default=_json_default), + ) + deferred_child_payloads.append(full_payload) - # Create child operation with reference to parent - await self._submit_async_operation( - bank_id=bank_id, - operation_type="retain", - task_type="batch_retain", - task_payload=task_payload, - result_metadata=child_metadata.to_dict(), - dedupe_by_bank=False, - ) + logger.info(f"Created parent operation {parent_operation_id} with {len(sub_batches)} child sub-batch(es)") + + # Notify the task backend after commit. For BrokerTaskBackend / + # WorkerTaskBackend in production this is a no-op because task_payload + # is already populated; for SyncTaskBackend in tests this kicks off + # synchronous execution against the now-committed rows. + for full_payload in deferred_child_payloads: + await self._task_backend.submit_task(full_payload) return { "operation_id": str(parent_operation_id), diff --git a/hindsight-api-slim/hindsight_api/engine/retain/fact_storage.py b/hindsight-api-slim/hindsight_api/engine/retain/fact_storage.py index 251410101..dd4661c8c 100644 --- a/hindsight-api-slim/hindsight_api/engine/retain/fact_storage.py +++ b/hindsight-api-slim/hindsight_api/engine/retain/fact_storage.py @@ -269,6 +269,7 @@ async def handle_document_tracking( is_first_batch: bool, retain_params: dict | None = None, document_tags: list[str] | None = None, + ops=None, ) -> None: """ Handle document tracking in the database (full-replace mode). @@ -284,6 +285,11 @@ async def handle_document_tracking( is_first_batch: Whether this is the first batch (for chunked operations) retain_params: Optional parameters passed during retain (context, event_date, etc.) document_tags: Optional list of tags to associate with the document + ops: Backend-specific DataAccessOps. Required by the inner + ``delete_stale_observations_for_memories`` call to choose the PG + (native array) vs Oracle (junction table) read path. Defaults to + None so older callers don't break, but the PG branch is only + taken when ops is non-None — pass ``pool.ops`` from the caller. """ import hashlib diff --git a/hindsight-api-slim/hindsight_api/engine/retain/orchestrator.py b/hindsight-api-slim/hindsight_api/engine/retain/orchestrator.py index 512b22ce3..a60b7694d 100644 --- a/hindsight-api-slim/hindsight_api/engine/retain/orchestrator.py +++ b/hindsight-api-slim/hindsight_api/engine/retain/orchestrator.py @@ -1103,6 +1103,7 @@ async def _process_db_batch( is_first_batch, retain_params, merged_tags, + ops=pool.ops, ) doc_tracking_done[0] = True log_buffer.append(f"[streaming] Document {effective_doc_id} tracked (0 facts in first batch)") @@ -1205,6 +1206,7 @@ async def _run_mini_batch_db_work() -> None: is_first_batch, retain_params, merged_tags, + ops=pool.ops, ) log_buffer.append(f"[streaming] Document {effective_doc_id} tracked (full content)") doc_tracking_done[0] = True @@ -1361,6 +1363,7 @@ async def _run_mini_batch_db_work() -> None: is_first_batch, retain_params, merged_tags, + ops=pool.ops, ) doc_tracking_done[0] = True log_buffer.append(f"[streaming] Document {effective_doc_id} tracked (no facts extracted)") diff --git a/hindsight-api-slim/tests/test_async_batch_retain.py b/hindsight-api-slim/tests/test_async_batch_retain.py index b5cabd8f3..5442acff9 100644 --- a/hindsight-api-slim/tests/test_async_batch_retain.py +++ b/hindsight-api-slim/tests/test_async_batch_retain.py @@ -829,3 +829,66 @@ async def failing_submit_task(_task_dict): assert payload["type"] == "batch_retain" assert payload["bank_id"] == bank_id assert payload["contents"] == [{"content": "hello", "document_id": "d1"}] + + +@pytest.mark.asyncio +async def test_submit_async_batch_retain_rolls_back_parent_on_child_failure( + memory_no_llm_verify, request_context, monkeypatch +): + """Regression for orphaned-parent rows. + + submit_async_batch_retain inserts a parent row (status='pending', + task_payload=NULL — it's a status aggregator, not directly executable) and + then loops to insert one child row per sub-batch. If the parent INSERT and + the child INSERTs were not transactionally coupled, any failure during the + child loop (connection drop, timeout, schema-cache invalidation under + concurrent load) would leave a parent with zero children. Workers ignore + such rows forever (task_payload IS NULL filter), the status aggregator + never fires (no children to complete), and the row sits pending + indefinitely — visible in queue-depth metrics and growing without bound. + + This test simulates a child-step failure by raising on the second + BatchRetainChildMetadata construction. After the failure we expect zero + async_operations rows for the bank: the parent INSERT must roll back + together with the children. + """ + import hindsight_api.engine.memory_engine as me + from hindsight_api.engine.memory_engine import count_tokens + + bank_id = f"test_parent_rollback_{uuid.uuid4().hex[:8]}" + pool = await memory_no_llm_verify._get_pool() + await _ensure_bank(pool, bank_id) + + # Force at least 2 sub-batches so the child loop runs more than once + # (matches the existing large-batch fixture's sizing). + large_content = "The quick brown fox jumps over the lazy dog. " * 500 + contents = [{"content": large_content + f" item {i}", "document_id": f"doc{i}"} for i in range(2)] + assert sum(count_tokens(item["content"]) for item in contents) > 10_000 + + real_class = me.BatchRetainChildMetadata + call_count = {"n": 0} + + def failing_child_metadata(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 2: + raise RuntimeError("Simulated child-step failure mid-batch") + return real_class(*args, **kwargs) + + monkeypatch.setattr(me, "BatchRetainChildMetadata", failing_child_metadata) + + with pytest.raises(RuntimeError, match="Simulated child-step failure"): + await memory_no_llm_verify.submit_async_retain( + bank_id=bank_id, + contents=contents, + request_context=request_context, + ) + + rows = await pool.fetch( + "SELECT operation_id, operation_type, status, task_payload FROM async_operations WHERE bank_id = $1", + bank_id, + ) + assert rows == [], ( + f"Expected zero rows for bank_id={bank_id} after rollback, got {len(rows)}: " + f"{[(r['operation_type'], r['status'], r['task_payload'] is not None) for r in rows]}. " + "The parent INSERT must be transactionally coupled to the child INSERTs." + ) diff --git a/hindsight-api-slim/tests/test_async_retain_tags.py b/hindsight-api-slim/tests/test_async_retain_tags.py index 2f3b1a1bd..df82d4f45 100644 --- a/hindsight-api-slim/tests/test_async_retain_tags.py +++ b/hindsight-api-slim/tests/test_async_retain_tags.py @@ -10,19 +10,35 @@ @pytest.mark.asyncio async def test_submit_async_retain_includes_document_tags_in_task_payload(): - """submit_async_retain should include document_tags in queued task payload.""" + """submit_async_retain should include document_tags in queued task payload. + + submit_async_batch_retain inserts the parent + all children inline inside + a single transaction (not via _submit_async_operation), then notifies the + task backend after commit. The test verifies document_tags propagates + through to the task backend's submit_task call, which is the post-commit + notification that drives SyncTaskBackend in tests and is a no-op for + BrokerTaskBackend / WorkerTaskBackend in production. + """ + import json + engine = MemoryEngine.__new__(MemoryEngine) engine._initialized = True engine._authenticate_tenant = AsyncMock() engine._operation_validator = None - engine._submit_async_operation = AsyncMock(return_value={"operation_id": "op-1"}) + # Children are now inserted inline (no _submit_async_operation hop), and + # submit_task fires post-commit. Mock both so the inline path runs cleanly + # without the test needing real DB or task backend. + engine._task_backend = AsyncMock() + engine._task_backend.submit_task = AsyncMock() - # Mock the pool and connection for parent operation creation + # Mock the pool and connection for parent + child INSERTs in one transaction. mock_conn = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.transaction = MagicMock() mock_conn.transaction.return_value.__aenter__ = AsyncMock() - mock_conn.transaction.return_value.__aexit__ = AsyncMock() + # __aexit__ must return falsy or `async with` will swallow exceptions — + # AsyncMock's default return is a truthy MagicMock. + mock_conn.transaction.return_value.__aexit__ = AsyncMock(return_value=False) mock_pool = AsyncMock() mock_pool.acquire = AsyncMock(return_value=mock_conn) @@ -62,18 +78,36 @@ async def test_submit_async_retain_includes_document_tags_in_task_payload(): # Verify authentication was called engine._authenticate_tenant.assert_awaited_once_with(request_context) - # Verify child operation was submitted - engine._submit_async_operation.assert_awaited_once() - - # Verify child operation payload contains document_tags - kwargs = engine._submit_async_operation.await_args.kwargs - assert kwargs["bank_id"] == "bank-1" - assert kwargs["operation_type"] == "retain" - assert kwargs["task_type"] == "batch_retain" - assert kwargs["task_payload"]["contents"] == contents - assert kwargs["task_payload"]["document_tags"] == document_tags - assert kwargs["task_payload"]["_tenant_id"] == "tenant-a" - assert kwargs["task_payload"]["_api_key_id"] == "key-a" + # The parent + child INSERTs both went through mock_conn.execute. There + # should be exactly two: one for the parent (no task_payload), one for the + # single child (with task_payload). The child INSERT serializes + # full_payload — which carries document_tags — to JSON. + assert mock_conn.execute.await_count == 2, ( + f"Expected two INSERTs (parent + child), got {mock_conn.execute.await_count}" + ) + + # Verify the post-commit submit_task fires once with a payload containing + # the document_tags (this is what gets handed to SyncTaskBackend in tests, + # and what carries the work into the worker in production). + engine._task_backend.submit_task.assert_awaited_once() + full_payload = engine._task_backend.submit_task.await_args.args[0] + assert full_payload["type"] == "batch_retain" + assert full_payload["bank_id"] == "bank-1" + assert full_payload["contents"] == contents + assert full_payload["document_tags"] == document_tags + assert full_payload["_tenant_id"] == "tenant-a" + assert full_payload["_api_key_id"] == "key-a" + + # Cross-check: the child INSERT's task_payload column also contains + # the document_tags (same JSON the worker poller would later read). + child_insert_args = mock_conn.execute.await_args_list[1].args + # Positional: (sql, operation_id, bank_id, operation_type, result_metadata, + # status, task_payload_json) + child_task_payload_json = child_insert_args[6] + child_task_payload = json.loads(child_task_payload_json) + assert child_task_payload["document_tags"] == document_tags + assert child_task_payload["_tenant_id"] == "tenant-a" + assert child_task_payload["_api_key_id"] == "key-a" @pytest.mark.asyncio diff --git a/hindsight-api-slim/tests/test_observation_invalidation.py b/hindsight-api-slim/tests/test_observation_invalidation.py index 41efb2052..68a8ac25a 100644 --- a/hindsight-api-slim/tests/test_observation_invalidation.py +++ b/hindsight-api-slim/tests/test_observation_invalidation.py @@ -329,6 +329,11 @@ async def test_upsert_document_removes_observations_from_outgoing_memories( # Trigger the upsert path directly. ``handle_document_tracking`` is # what the retain orchestrator calls on every document re-ingest. + # Pass ops=memory._backend.ops so the inner observation-cleanup query + # selects the PG (native array) read path instead of falling back to + # the Oracle junction-table path (which would query a non-existent + # public.observation_sources relation under PG). The orchestrator + # call sites in _streaming_retain_batch already do this via pool.ops. async with pool.acquire() as conn: async with conn.transaction(): await handle_document_tracking( @@ -339,6 +344,7 @@ async def test_upsert_document_removes_observations_from_outgoing_memories( is_first_batch=True, retain_params=None, document_tags=None, + ops=memory._backend.ops, ) async with pool.acquire() as conn: