From a0ea38323d25555a2ec7ac93aae291841557fefd Mon Sep 17 00:00:00 2001 From: Tanuj Nayak Date: Sun, 9 Nov 2025 21:59:32 -0800 Subject: [PATCH 1/3] [ENH]: Execute task with no backfill or incremental --- chromadb/test/distributed/test_task_api.py | 40 +- go/pkg/sysdb/coordinator/coordinator.go | 10 + go/pkg/sysdb/coordinator/create_task_test.go | 19 +- .../heap_client_integration_test.go | 10 +- go/pkg/sysdb/coordinator/model/collection.go | 4 + go/pkg/sysdb/coordinator/table_catalog.go | 67 ++ go/pkg/sysdb/coordinator/task.go | 23 +- go/pkg/sysdb/grpc/collection_service.go | 127 +++ idl/chromadb/proto/coordinator.proto | 24 +- idl/chromadb/proto/heapservice.proto | 1 - .../src/impls/service_based_frontend.rs | 2 +- rust/segment/src/types.rs | 29 + rust/sysdb/src/bin/chroma-task-manager.rs | 11 +- rust/sysdb/src/sysdb.rs | 110 ++- rust/sysdb/src/test_sysdb.rs | 2 +- rust/types/src/flush.rs | 90 ++- rust/types/src/task.rs | 87 +- .../src/execution/functions/statistics.rs | 186 ++--- .../src/execution/operators/execute_task.rs | 105 +-- .../operators/finish_attached_function.rs | 148 ++++ .../operators/get_attached_function.rs | 151 ++++ .../operators/get_collection_and_segments.rs | 9 + rust/worker/src/execution/operators/mod.rs | 4 +- .../src/execution/operators/register.rs | 22 +- .../orchestration/apply_logs_orchestrator.rs | 35 +- .../attached_function_orchestrator.rs | 756 ++++++++++++++++++ .../src/execution/orchestration/compact.rs | 373 +++++++-- .../orchestration/log_fetch_orchestrator.rs | 18 +- .../worker/src/execution/orchestration/mod.rs | 1 + .../orchestration/register_orchestrator.rs | 200 ++++- 30 files changed, 2318 insertions(+), 346 deletions(-) create mode 100644 rust/worker/src/execution/operators/finish_attached_function.rs create mode 100644 rust/worker/src/execution/operators/get_attached_function.rs create mode 100644 rust/worker/src/execution/orchestration/attached_function_orchestrator.rs diff --git a/chromadb/test/distributed/test_task_api.py b/chromadb/test/distributed/test_task_api.py index e221a281953..91930d2fc3e 100644 --- a/chromadb/test/distributed/test_task_api.py +++ b/chromadb/test/distributed/test_task_api.py @@ -9,9 +9,13 @@ from chromadb.api.client import Client as ClientCreator from chromadb.config import System from chromadb.errors import ChromaError, NotFoundError +from chromadb.test.utils.wait_for_version_increase import ( + get_collection_version, + wait_for_version_increase, +) -def test_function_attach_and_detach(basic_http_client: System) -> None: +def test_count_function_attach_and_detach(basic_http_client: System) -> None: """Test creating and removing a function with the record_counter operator""" client = ClientCreator.from_system(basic_http_client) client.reset() @@ -22,21 +26,6 @@ def test_function_attach_and_detach(basic_http_client: System) -> None: metadata={"description": "Sample documents for task processing"}, ) - # Add initial documents - collection.add( - ids=["doc1", "doc2", "doc3"], - documents=[ - "The quick brown fox jumps over the lazy dog", - "Machine learning is a subset of artificial intelligence", - "Python is a popular programming language", - ], - metadatas=[{"source": "proverb"}, {"source": "tech"}, {"source": "tech"}], - ) - - # Verify collection has documents - assert collection.count() == 3 - # TODO(tanujnay112): Verify the output collection has the correct count - # Create a task that counts records in the collection attached_fn = collection.attach_function( name="count_my_docs", @@ -47,19 +36,22 @@ def 
test_function_attach_and_detach(basic_http_client: System) -> None: # Verify task creation succeeded assert attached_fn is not None + initial_version = get_collection_version(client, collection.name) - # Add more documents + # Add documents collection.add( - ids=["doc4", "doc5"], - documents=[ - "Chroma is a vector database", - "Tasks automate data processing", - ], + ids=["doc_{}".format(i) for i in range(0, 300)], + documents=["test document"] * 300, ) # Verify documents were added - assert collection.count() == 5 - # TODO(tanujnay112): Verify the output collection has the correct count + assert collection.count() == 300 + + wait_for_version_increase(client, collection.name, initial_version) + + result = client.get_collection("my_documents_counts").get("function_output") + assert result["metadatas"] is not None + assert result["metadatas"][0]["total_count"] == 300 # Remove the task success = attached_fn.detach( diff --git a/go/pkg/sysdb/coordinator/coordinator.go b/go/pkg/sysdb/coordinator/coordinator.go index ccb8f3f9a2a..f343cfe38f9 100644 --- a/go/pkg/sysdb/coordinator/coordinator.go +++ b/go/pkg/sysdb/coordinator/coordinator.go @@ -13,6 +13,7 @@ import ( "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel" s3metastore "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/s3" "github.com/chroma-core/chroma/go/pkg/types" + "github.com/google/uuid" "github.com/pingcap/log" "go.uber.org/zap" ) @@ -286,6 +287,15 @@ func (s *Coordinator) FlushCollectionCompaction(ctx context.Context, flushCollec return s.catalog.FlushCollectionCompaction(ctx, flushCollectionCompaction) } +func (s *Coordinator) FlushCollectionCompactionsAndAttachedFunction( + ctx context.Context, + collectionCompactions []*model.FlushCollectionCompaction, + attachedFunctionID uuid.UUID, + completionOffset int64, +) (*model.ExtendedFlushCollectionInfo, error) { + return s.catalog.FlushCollectionCompactionsAndAttachedFunction(ctx, collectionCompactions, attachedFunctionID, completionOffset) +} + func (s *Coordinator) ListCollectionsToGc(ctx context.Context, cutoffTimeSecs *uint64, limit *uint64, tenantID *string, minVersionsIfAlive *uint64) ([]*model.CollectionToGc, error) { return s.catalog.ListCollectionsToGc(ctx, cutoffTimeSecs, limit, tenantID, minVersionsIfAlive) } diff --git a/go/pkg/sysdb/coordinator/create_task_test.go b/go/pkg/sysdb/coordinator/create_task_test.go index 58d601b0ce5..7f99e70b75c 100644 --- a/go/pkg/sysdb/coordinator/create_task_test.go +++ b/go/pkg/sysdb/coordinator/create_task_test.go @@ -191,12 +191,6 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation() { []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() - // Check output collection doesn't exist - suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() - suite.mockCollectionDb.On("GetCollections", - []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). 
-		Return([]*dbmodel.CollectionAndMetadata{}, nil).Once()
-
 	// Insert attached function with lowest_live_nonce = NULL
 	suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once()
 	suite.mockAttachedFunctionDb.On("Insert", mock.MatchedBy(func(attachedFunction *dbmodel.AttachedFunction) bool {
@@ -225,7 +219,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation() {
 
 	suite.NoError(err)
 	suite.NotNil(response)
-	suite.NotEmpty(response.Id)
+	suite.NotEmpty(response.AttachedFunction.Id)
 
 	// Verify all mocks were called as expected
 	suite.mockMetaDomain.AssertExpectations(suite.T())
@@ -317,7 +311,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Alrea
 	// Assertions
 	suite.NoError(err)
 	suite.NotNil(response)
-	suite.Equal(existingAttachedFunctionID.String(), response.Id)
+	suite.Equal(existingAttachedFunctionID.String(), response.AttachedFunction.Id)
 
 	// Verify no writes occurred (no Insert, no heap Push)
 	// Note: Transaction IS called for idempotency check, but no writes happen inside it
@@ -390,11 +384,6 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() {
 		[]string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false).
 		Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once()
 
-	suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once()
-	suite.mockCollectionDb.On("GetCollections",
-		[]string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false).
-		Return([]*dbmodel.CollectionAndMetadata{}, nil).Once()
-
 	suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once()
 	suite.mockAttachedFunctionDb.On("Insert", mock.Anything).Return(nil).Once()
 
@@ -408,7 +397,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() {
 	response1, err1 := suite.coordinator.AttachFunction(ctx, request)
 	suite.NoError(err1)
 	suite.NotNil(response1)
-	suite.NotEmpty(response1.Id)
+	suite.NotEmpty(response1.AttachedFunction.Id)
 
 	// ========== GetAttachedFunctionByName: Should Return ErrAttachedFunctionNotReady ==========
 
@@ -453,7 +442,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() {
 	response2, err2 := suite.coordinator.AttachFunction(ctx, request)
 	suite.NoError(err2)
 	suite.NotNil(response2)
-	suite.Equal(incompleteAttachedFunctionID.String(), response2.Id)
+	suite.Equal(incompleteAttachedFunctionID.String(), response2.AttachedFunction.Id)
 
 	// Verify transaction was called in both attempts (idempotency check happens in transaction)
 	suite.mockTxImpl.AssertNumberOfCalls(suite.T(), "Transaction", 2) // First attempt + recovery attempt
diff --git a/go/pkg/sysdb/coordinator/heap_client_integration_test.go b/go/pkg/sysdb/coordinator/heap_client_integration_test.go
index 5cd8c6ef23b..d656f7be10a 100644
--- a/go/pkg/sysdb/coordinator/heap_client_integration_test.go
+++ b/go/pkg/sysdb/coordinator/heap_client_integration_test.go
@@ -176,7 +176,7 @@ func (suite *HeapClientIntegrationTestSuite) TestAttachFunctionPushesScheduleToH
 	})
 	suite.NoError(err, "Should attach function successfully")
 	suite.NotNil(response)
-	suite.NotEmpty(response.Id, "Attached function ID should be returned")
+	suite.NotEmpty(response.AttachedFunction.Id, "Attached function ID should be returned")
 
 	// Get updated heap summary
 	updatedSummary, err := suite.heapClient.Summary(ctx,
&coordinatorpb.HeapSummaryRequest{}) @@ -376,12 +376,12 @@ func (suite *HeapClientIntegrationTestSuite) TestPartialTaskCleanup_ThenRecreate }) suite.NoError(err, "Task should still exist after cleanup") suite.NotNil(getResp) - suite.Equal(taskResp.Id, getResp.AttachedFunction.Id) + suite.Equal(taskResp.AttachedFunction.Id, getResp.AttachedFunction.Id) suite.T().Logf("Task still exists after cleanup: %s", getResp.AttachedFunction.Id) // STEP 4: Delete the task _, err = suite.sysdbClient.DetachFunction(ctx, &coordinatorpb.DetachFunctionRequest{ - AttachedFunctionId: taskResp.Id, + AttachedFunctionId: taskResp.AttachedFunction.Id, DeleteOutput: true, }) suite.NoError(err, "Should delete task") @@ -398,8 +398,8 @@ func (suite *HeapClientIntegrationTestSuite) TestPartialTaskCleanup_ThenRecreate }) suite.NoError(err, "Should be able to recreate task after deletion") suite.NotNil(taskResp2) - suite.NotEqual(taskResp.Id, taskResp2.Id, "New task should have different ID") - suite.T().Logf("Successfully recreated task: %s", taskResp2.Id) + suite.NotEqual(taskResp.AttachedFunction.Id, taskResp2.AttachedFunction.Id, "New task should have different ID") + suite.T().Logf("Successfully recreated task: %s", taskResp2.AttachedFunction.Id) } func TestHeapClientIntegrationSuite(t *testing.T) { diff --git a/go/pkg/sysdb/coordinator/model/collection.go b/go/pkg/sysdb/coordinator/model/collection.go index 16730104762..ac5d3a81e4e 100644 --- a/go/pkg/sysdb/coordinator/model/collection.go +++ b/go/pkg/sysdb/coordinator/model/collection.go @@ -102,6 +102,10 @@ type FlushCollectionInfo struct { AttachedFunctionCompletionOffset *int64 } +type ExtendedFlushCollectionInfo struct { + Collections []*FlushCollectionInfo +} + func FilterCollection(collection *Collection, collectionID types.UniqueID, collectionName *string) bool { if collectionID != types.NilUniqueID() && collectionID != collection.ID { return false diff --git a/go/pkg/sysdb/coordinator/table_catalog.go b/go/pkg/sysdb/coordinator/table_catalog.go index 6e48096e003..ab714d0c2d9 100644 --- a/go/pkg/sysdb/coordinator/table_catalog.go +++ b/go/pkg/sysdb/coordinator/table_catalog.go @@ -1736,6 +1736,73 @@ func (tc *Catalog) FlushCollectionCompaction(ctx context.Context, flushCollectio return flushCollectionInfo, nil } +// FlushCollectionCompactionsAndAttachedFunction atomically updates multiple collection compaction data +// and attached function completion offset in a single transaction. 
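+// All collection flushes and the offset update share one transaction; assuming
+// the usual rollback-on-error semantics of txImpl.Transaction, a failure in any
+// step means the completion offset never advances past a flush that did not commit.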
+func (tc *Catalog) FlushCollectionCompactionsAndAttachedFunction(
+	ctx context.Context,
+	collectionCompactions []*model.FlushCollectionCompaction,
+	attachedFunctionID uuid.UUID,
+	completionOffset int64,
+) (*model.ExtendedFlushCollectionInfo, error) {
+	if !tc.versionFileEnabled {
+		// Attached-function-based compactions are only supported with versioned collections
+		log.Error("FlushCollectionCompactionsAndAttachedFunction is only supported for versioned collections")
+		return nil, errors.New("attached-function-based compaction requires versioned collections")
+	}
+
+	if len(collectionCompactions) == 0 {
+		return nil, errors.New("at least one collection compaction is required")
+	}
+
+	flushInfos := make([]*model.FlushCollectionInfo, 0, len(collectionCompactions))
+
+	err := tc.txImpl.Transaction(ctx, func(txCtx context.Context) error {
+		var err error
+		// Get the transaction from context to pass to FlushCollectionCompactionForVersionedCollection
+		tx := dbcore.GetDB(txCtx)
+
+		// Handle all collection compactions
+		for _, collectionCompaction := range collectionCompactions {
+			log.Info("FlushCollectionCompactionsAndAttachedFunction", zap.String("collection_id", collectionCompaction.ID.String()))
+			flushInfo, err := tc.FlushCollectionCompactionForVersionedCollection(txCtx, collectionCompaction, tx)
+			if err != nil {
+				return err
+			}
+			flushInfos = append(flushInfos, flushInfo)
+		}
+
+		err = tc.metaDomain.AttachedFunctionDb(txCtx).Update(&dbmodel.AttachedFunction{
+			ID:               attachedFunctionID,
+			CompletionOffset: completionOffset,
+		})
+		if err != nil {
+			return err
+		}
+
+		return nil
+	})
+
+	if err != nil {
+		return nil, err
+	}
+
+	// Populate attached function fields with the completion offset that was just committed
+	for _, flushInfo := range flushInfos {
+		flushInfo.AttachedFunctionCompletionOffset = &completionOffset
+	}
+
+	// Log with first collection ID (typically the output collection)
+	log.Info("FlushCollectionCompactionsAndAttachedFunction",
+		zap.String("first_collection_id", collectionCompactions[0].ID.String()),
+		zap.Int("collection_count", len(collectionCompactions)),
+		zap.String("attached_function_id", attachedFunctionID.String()),
+		zap.Int64("completion_offset", completionOffset))
+
+	return &model.ExtendedFlushCollectionInfo{
+		Collections: flushInfos,
+	}, nil
+}
+
 func (tc *Catalog) validateVersionFile(versionFile *coordinatorpb.CollectionVersionFile, collectionID string, version int64) error {
 	if versionFile.GetCollectionInfoImmutable().GetCollectionId() != collectionID {
 		log.Error("collection id mismatch", zap.String("collection_id", collectionID), zap.String("version_file_collection_id", versionFile.GetCollectionInfoImmutable().GetCollectionId()))
diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go
index f992830a4f9..1bd63a4af19 100644
--- a/go/pkg/sysdb/coordinator/task.go
+++ b/go/pkg/sysdb/coordinator/task.go
@@ -76,7 +76,7 @@ func (s *Coordinator) validateAttachedFunctionMatchesRequest(ctx context.Context
 	return nil
 }
 
-// AttachFunction creates a new attached function in the database
+// AttachFunction creates the attached function record in a single transaction; the output collection itself is created later by FinishCreateAttachedFunction
 func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.AttachFunctionRequest) (*coordinatorpb.AttachFunctionResponse, error) {
 	log := log.With(zap.String("method", "AttachFunction"))
 
@@ -143,18 +143,6 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att
 			return common.ErrCollectionNotFound
 		}
 
-		// Check if output
collection already exists - outputCollectionName := req.OutputCollectionName - existingOutputCollections, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections(nil, &outputCollectionName, req.TenantId, req.Database, nil, nil, false) - if err != nil { - log.Error("AttachFunction: failed to check output collection", zap.Error(err)) - return err - } - if len(existingOutputCollections) > 0 { - log.Error("AttachFunction: output collection already exists") - return common.ErrCollectionUniqueConstraintViolation - } - // Serialize params var paramsJSON string if req.Params != nil { @@ -168,6 +156,7 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att paramsJSON = "{}" } + // Create attached function now := time.Now() attachedFunction := &dbmodel.AttachedFunction{ ID: attachedFunctionID, @@ -176,6 +165,7 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att DatabaseID: databases[0].ID, InputCollectionID: req.InputCollectionId, OutputCollectionName: req.OutputCollectionName, + OutputCollectionID: nil, FunctionID: function.ID, FunctionParams: paramsJSON, CompletionOffset: 0, @@ -196,6 +186,7 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att log.Debug("AttachFunction: attached function created with is_ready=false", zap.String("attached_function_id", attachedFunctionID.String()), + zap.String("output_collection_name", req.OutputCollectionName), zap.String("name", req.Name)) return nil }) @@ -205,7 +196,9 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att } return &coordinatorpb.AttachFunctionResponse{ - Id: attachedFunctionID.String(), + AttachedFunction: &coordinatorpb.AttachedFunction{ + Id: attachedFunctionID.String(), + }, }, nil } @@ -581,7 +574,7 @@ func (s *Coordinator) FinishCreateAttachedFunction(ctx context.Context, req *coo _, _, err = s.catalog.CreateCollectionAndSegments(txCtx, collection, segments, 0) if err != nil { - log.Error("FinishCreateAttachedFunction: failed to create collection", zap.Error(err)) + log.Error("FinishCreateAttachedFunction: failed to create output collection", zap.Error(err)) return err } diff --git a/go/pkg/sysdb/grpc/collection_service.go b/go/pkg/sysdb/grpc/collection_service.go index e3017b0d5e1..11dade5ec13 100644 --- a/go/pkg/sysdb/grpc/collection_service.go +++ b/go/pkg/sysdb/grpc/collection_service.go @@ -3,8 +3,10 @@ package grpc import ( "context" "encoding/json" + "math" "github.com/chroma-core/chroma/go/pkg/grpcutils" + "github.com/google/uuid" "github.com/chroma-core/chroma/go/pkg/common" "github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb" @@ -570,6 +572,131 @@ func (s *Server) FlushCollectionCompaction(ctx context.Context, req *coordinator return res, nil } +func (s *Server) FlushCollectionCompactionAndAttachedFunction(ctx context.Context, req *coordinatorpb.FlushCollectionCompactionAndAttachedFunctionRequest) (*coordinatorpb.FlushCollectionCompactionAndAttachedFunctionResponse, error) { + // Parse the repeated flush compaction requests + flushReqs := req.GetFlushCompactions() + if len(flushReqs) == 0 { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. flush_compactions is empty") + return nil, grpcutils.BuildInternalGrpcError("at least one flush_compaction is required") + } + if len(flushReqs) > 2 { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. 
too many flush_compactions", zap.Int("count", len(flushReqs))) + return nil, grpcutils.BuildInternalGrpcError("expected 1 or 2 flush_compactions") + } + + // Parse attached function update info + attachedFunctionUpdate := req.GetAttachedFunctionUpdate() + if attachedFunctionUpdate == nil { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. attached_function_update is nil") + return nil, grpcutils.BuildInternalGrpcError("attached_function_update is required") + } + + attachedFunctionID, err := uuid.Parse(attachedFunctionUpdate.Id) + if err != nil { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. error parsing attached_function_id", zap.Error(err), zap.String("attached_function_id", attachedFunctionUpdate.Id)) + return nil, grpcutils.BuildInternalGrpcError("invalid attached_function_id: " + err.Error()) + } + + // Validate completion_offset fits in int64 before storing in database + if attachedFunctionUpdate.CompletionOffset > uint64(math.MaxInt64) { + log.Error("FlushCollectionCompactionAndAttachedFunction: completion_offset too large", + zap.Uint64("completion_offset", attachedFunctionUpdate.CompletionOffset)) + return nil, grpcutils.BuildInternalGrpcError("completion_offset too large") + } + completionOffsetSigned := int64(attachedFunctionUpdate.CompletionOffset) + + // Parse all flush requests into a slice + collectionCompactions := make([]*model.FlushCollectionCompaction, 0, len(flushReqs)) + + for _, flushReq := range flushReqs { + collectionID, err := types.ToUniqueID(&flushReq.CollectionId) + err = grpcutils.BuildErrorForUUID(collectionID, "collection", err) + if err != nil { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. error parsing collection id", zap.Error(err), zap.String("collection_id", flushReq.CollectionId)) + return nil, grpcutils.BuildInternalGrpcError(err.Error()) + } + + segmentCompactionInfo := make([]*model.FlushSegmentCompaction, 0, len(flushReq.SegmentCompactionInfo)) + for _, flushSegmentCompaction := range flushReq.SegmentCompactionInfo { + segmentID, err := types.ToUniqueID(&flushSegmentCompaction.SegmentId) + err = grpcutils.BuildErrorForUUID(segmentID, "segment", err) + if err != nil { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. 
error parsing segment id", zap.Error(err), zap.String("collection_id", flushReq.CollectionId)) + return nil, grpcutils.BuildInternalGrpcError(err.Error()) + } + filePaths := make(map[string][]string) + for key, filePath := range flushSegmentCompaction.FilePaths { + filePaths[key] = filePath.Paths + } + segmentCompactionInfo = append(segmentCompactionInfo, &model.FlushSegmentCompaction{ + ID: segmentID, + FilePaths: filePaths, + }) + } + + collectionCompactions = append(collectionCompactions, &model.FlushCollectionCompaction{ + ID: collectionID, + TenantID: flushReq.TenantId, + LogPosition: flushReq.LogPosition, + CurrentCollectionVersion: flushReq.CollectionVersion, + FlushSegmentCompactions: segmentCompactionInfo, + TotalRecordsPostCompaction: flushReq.TotalRecordsPostCompaction, + SizeBytesPostCompaction: flushReq.SizeBytesPostCompaction, + }) + } + + // Call the Extended coordinator function to handle all collections + extendedFlushInfo, err := s.coordinator.FlushCollectionCompactionsAndAttachedFunction( + ctx, + collectionCompactions, + attachedFunctionID, + completionOffsetSigned, + ) + if err != nil { + log.Error("FlushCollectionCompactionAndAttachedFunction failed", zap.Error(err), zap.String("attached_function_id", attachedFunctionUpdate.Id)) + if err == common.ErrCollectionSoftDeleted { + return nil, grpcutils.BuildFailedPreconditionGrpcError(err.Error()) + } + if err == common.ErrAttachedFunctionNotFound { + return nil, grpcutils.BuildNotFoundGrpcError(err.Error()) + } + return nil, grpcutils.BuildInternalGrpcError(err.Error()) + } + + // Build response with repeated collections + res := &coordinatorpb.FlushCollectionCompactionAndAttachedFunctionResponse{ + Collections: make([]*coordinatorpb.CollectionCompactionInfo, 0, len(extendedFlushInfo.Collections)), + } + + for _, flushInfo := range extendedFlushInfo.Collections { + res.Collections = append(res.Collections, &coordinatorpb.CollectionCompactionInfo{ + CollectionId: flushInfo.ID, + CollectionVersion: flushInfo.CollectionVersion, + LastCompactionTime: flushInfo.TenantLastCompactionTime, + }) + } + + // Populate attached function state with authoritative values from database (use first collection) + if len(extendedFlushInfo.Collections) > 0 { + firstFlushInfo := extendedFlushInfo.Collections[0] + attachedFunctionState := &coordinatorpb.AttachedFunctionState{} + + if firstFlushInfo.AttachedFunctionCompletionOffset != nil { + // Validate completion_offset is non-negative before converting to uint64 + if *firstFlushInfo.AttachedFunctionCompletionOffset < 0 { + log.Error("FlushCollectionCompactionAndAttachedFunction: invalid completion_offset", + zap.Int64("completion_offset", *firstFlushInfo.AttachedFunctionCompletionOffset)) + return nil, grpcutils.BuildInternalGrpcError("attached function has invalid completion_offset") + } + attachedFunctionState.CompletionOffset = uint64(*firstFlushInfo.AttachedFunctionCompletionOffset) + } + + res.AttachedFunctionState = attachedFunctionState + } + + return res, nil +} + func (s *Server) ListCollectionsToGc(ctx context.Context, req *coordinatorpb.ListCollectionsToGcRequest) (*coordinatorpb.ListCollectionsToGcResponse, error) { absoluteCutoffTimeSecs := (*uint64)(nil) if req.CutoffTime != nil { diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index bc5d8098e78..768df8221e1 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -327,6 +327,27 @@ message AttachedFunctionUpdateInfo { uint64 completion_offset = 2; } 
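+// A compactor typically sends the output-collection flush (and optionally the
+// input-collection flush) together with the function's new completion offset,
+// so both commit atomically or neither does. Illustrative request shape, with
+// placeholder values rather than real IDs:
+//
+//   flush_compactions: [ { collection_id: "<output collection uuid>", ... } ]
+//   attached_function_update: { id: "<attached function uuid>", completion_offset: 300 }
+//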
+// Combined request to flush collection compaction and update attached function atomically in a single transaction +message FlushCollectionCompactionAndAttachedFunctionRequest { + repeated FlushCollectionCompactionRequest flush_compactions = 1; + AttachedFunctionUpdateInfo attached_function_update = 2; +} + +message CollectionCompactionInfo { + string collection_id = 1; + int32 collection_version = 2; + int64 last_compaction_time = 3; +} + +message AttachedFunctionState { + uint64 completion_offset = 1; +} + +message FlushCollectionCompactionAndAttachedFunctionResponse { + repeated CollectionCompactionInfo collections = 1; + AttachedFunctionState attached_function_state = 2; +} + // Used for serializing contents in collection version history file. message CollectionVersionFile { CollectionInfoImmutable collection_info_immutable = 1; @@ -544,7 +565,7 @@ message AttachFunctionRequest { } message AttachFunctionResponse { - string id = 1; + AttachedFunction attached_function = 1; } message GetAttachedFunctionByNameRequest { @@ -692,4 +713,5 @@ service SysDB { rpc GetFunctions(GetFunctionsRequest) returns (GetFunctionsResponse) {} rpc GetSoftDeletedAttachedFunctions(GetSoftDeletedAttachedFunctionsRequest) returns (GetSoftDeletedAttachedFunctionsResponse) {} rpc FinishAttachedFunctionDeletion(FinishAttachedFunctionDeletionRequest) returns (FinishAttachedFunctionDeletionResponse) {} + rpc FlushCollectionCompactionAndAttachedFunction(FlushCollectionCompactionAndAttachedFunctionRequest) returns (FlushCollectionCompactionAndAttachedFunctionResponse) {} } diff --git a/idl/chromadb/proto/heapservice.proto b/idl/chromadb/proto/heapservice.proto index 5802f34d03a..cd3d7f163f2 100644 --- a/idl/chromadb/proto/heapservice.proto +++ b/idl/chromadb/proto/heapservice.proto @@ -4,7 +4,6 @@ package chroma; option go_package = "github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"; -import "chromadb/proto/chroma.proto"; import "google/protobuf/timestamp.proto"; // A task that can be scheduled and triggered in the heap. diff --git a/rust/frontend/src/impls/service_based_frontend.rs b/rust/frontend/src/impls/service_based_frontend.rs index 7ac351b251d..c2d4a4b60fe 100644 --- a/rust/frontend/src/impls/service_based_frontend.rs +++ b/rust/frontend/src/impls/service_based_frontend.rs @@ -1964,7 +1964,7 @@ impl ServiceBasedFrontend { // Stub method for backfill - will be implemented later async fn start_backfill(&self, _attached_function_id: chroma_types::AttachedFunctionUuid) { tracing::info!("start_backfill stub called - not yet implemented"); - // TODO: Implement backfill logic + // TODO(tanujnay112): Implement backfill logic } pub async fn detach_function( diff --git a/rust/segment/src/types.rs b/rust/segment/src/types.rs index 04e684e44e3..8979a884e63 100644 --- a/rust/segment/src/types.rs +++ b/rust/segment/src/types.rs @@ -497,6 +497,35 @@ impl MaterializeLogsResult { index: 0, } } + + /// Create a MaterializeLogsResult from log records for testing purposes. + /// Each log record is treated as a new insertion with a unique offset_id. + /// + /// # Note + /// This is primarily intended for testing and should not be used in production code. + /// Use the `materialize_logs` function instead for proper log materialization. 
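+    ///
+    /// A minimal usage sketch (illustrative only; assumes a `log_record: LogRecord`
+    /// built by the caller):
+    ///
+    /// ```ignore
+    /// let logs = Chunk::new(vec![log_record].into());
+    /// let result = MaterializeLogsResult::from_logs_for_test(logs)?;
+    /// assert_eq!(result.len(), 1);
+    /// ```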
+ #[doc(hidden)] + pub fn from_logs_for_test(logs: Chunk) -> Result { + let mut materialized = Vec::new(); + for (index, (log_record, _)) in logs.iter().enumerate() { + let offset_id = (index + 1) as u32; + let mut mat_record = + MaterializedLogRecord::from_log_record(offset_id, index, log_record)?; + + // Override the operation for delete records + if log_record.record.operation == Operation::Delete { + mat_record.final_operation = MaterializedLogOperation::DeleteExisting; + mat_record.final_document_at_log_index = None; + mat_record.final_embedding_at_log_index = None; + } + + materialized.push(mat_record); + } + Ok(Self { + logs, + materialized: Chunk::new(materialized.into()), + }) + } } // IntoIterator is implemented for &'a MaterializeLogsResult rather than MaterializeLogsResult because the iterator needs to hand out values with a lifetime of 'a. diff --git a/rust/sysdb/src/bin/chroma-task-manager.rs b/rust/sysdb/src/bin/chroma-task-manager.rs index 62a3987f160..1f40a91e283 100644 --- a/rust/sysdb/src/bin/chroma-task-manager.rs +++ b/rust/sysdb/src/bin/chroma-task-manager.rs @@ -126,7 +126,11 @@ async fn main() -> Result<(), Box> { }; let response = client.attach_function(request).await?; - println!("Attached Function created: {}", response.into_inner().id); + let attached_function = response + .into_inner() + .attached_function + .ok_or("Server did not return attached function")?; + println!("Attached Function created: {}", attached_function.id); } Command::GetAttachedFunction { input_collection_id, @@ -138,7 +142,10 @@ async fn main() -> Result<(), Box> { }; let response = client.get_attached_function_by_name(request).await?; - let attached_function = response.into_inner().attached_function.unwrap(); + let attached_function = response + .into_inner() + .attached_function + .ok_or("Server did not return attached function")?; println!("Attached Function ID: {:?}", attached_function.id); println!("Name: {:?}", attached_function.name); diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index 0e80dad80f8..d0ef29bb696 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -8,7 +8,7 @@ use chroma_error::{ChromaError, ErrorCodes, TonicError, TonicMissingFieldError}; use chroma_types::chroma_proto::sys_db_client::SysDbClient; use chroma_types::chroma_proto::VersionListForCollection; use chroma_types::{ - chroma_proto, chroma_proto::CollectionVersionInfo, CollectionAndSegments, + chroma_proto, chroma_proto::CollectionVersionInfo, CollectionAndSegments, CollectionFlushInfo, CollectionMetadataUpdate, CountCollectionsError, CreateCollectionError, CreateDatabaseError, CreateDatabaseResponse, CreateTenantError, CreateTenantResponse, Database, DeleteCollectionError, DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionByCrnError, @@ -21,11 +21,12 @@ use chroma_types::{ UpdateTenantResponse, }; use chroma_types::{ - AttachedFunctionUuid, BatchGetCollectionSoftDeleteStatusError, + AttachedFunctionUpdateInfo, AttachedFunctionUuid, BatchGetCollectionSoftDeleteStatusError, BatchGetCollectionVersionFilePathsError, Collection, CollectionConversionError, CollectionUuid, CountForksError, DatabaseUuid, FinishCreateAttachedFunctionError, FinishDatabaseDeletionError, - FlushCompactionResponse, FlushCompactionResponseConversionError, ForkCollectionError, Schema, - SchemaError, Segment, SegmentConversionError, SegmentScope, Tenant, + FlushCompactionAndAttachedFunctionResponse, FlushCompactionResponse, + FlushCompactionResponseConversionError, ForkCollectionError, Schema, 
SchemaError, Segment, + SegmentConversionError, SegmentScope, Tenant, }; use prost_types; use std::collections::HashMap; @@ -627,6 +628,22 @@ impl SysDb { } } + #[allow(clippy::too_many_arguments)] + pub async fn flush_compaction_and_attached_function( + &mut self, + collections: Vec, + attached_function_update: AttachedFunctionUpdateInfo, + ) -> Result { + match self { + SysDb::Grpc(grpc) => { + grpc.flush_compaction_and_attached_function(collections, attached_function_update) + .await + } + SysDb::Sqlite(_) => todo!(), + SysDb::Test(_) => todo!(), + } + } + pub async fn list_collection_versions( &mut self, collection_id: CollectionUuid, @@ -1635,6 +1652,81 @@ impl GrpcSysDb { } } + async fn flush_compaction_and_attached_function( + &mut self, + collections: Vec, + attached_function_update: AttachedFunctionUpdateInfo, + ) -> Result { + // Process all collections into flush compaction requests + let mut flush_compactions = Vec::with_capacity(collections.len()); + + for collection in collections { + let segment_compaction_info = collection + .segment_flush_info + .iter() + .map(|segment_flush_info| segment_flush_info.try_into()) + .collect::, + SegmentFlushInfoConversionError, + >>()?; + + let schema_str = collection.schema.and_then(|s| { + serde_json::to_string(&s).ok().or_else(|| { + tracing::error!( + "Failed to serialize schema for flush_compaction_and_attached_function" + ); + None + }) + }); + + flush_compactions.push(chroma_proto::FlushCollectionCompactionRequest { + tenant_id: collection.tenant_id, + collection_id: collection.collection_id.0.to_string(), + log_position: collection.log_position, + collection_version: collection.collection_version, + segment_compaction_info, + total_records_post_compaction: collection.total_records_post_compaction, + size_bytes_post_compaction: collection.size_bytes_post_compaction, + schema_str, + }); + } + + let attached_function_update_proto = Some(chroma_proto::AttachedFunctionUpdateInfo { + id: attached_function_update.attached_function_id.0.to_string(), + completion_offset: attached_function_update.completion_offset, + }); + + let req = chroma_proto::FlushCollectionCompactionAndAttachedFunctionRequest { + flush_compactions, + attached_function_update: attached_function_update_proto, + }; + + let res = self + .client + .flush_collection_compaction_and_attached_function(req) + .await; + match res { + Ok(res) => { + let res = res.into_inner(); + let res = match res.try_into() { + Ok(res) => res, + Err(e) => { + return Err( + FlushCompactionError::FlushCompactionResponseConversionError(e), + ); + } + }; + Ok(res) + } + Err(e) => { + if e.code() == Code::FailedPrecondition { + return Err(FlushCompactionError::FailedToFlushCompaction(e)); + } + Err(FlushCompactionError::FailedToFlushCompaction(e)) + } + } + } + async fn mark_version_for_deletion( &mut self, epoch_id: i64, @@ -1736,10 +1828,15 @@ impl GrpcSysDb { let response = self.client.attach_function(req).await?.into_inner(); // Parse the returned attached_function_id - this should always succeed since the server generated it // If this fails, it indicates a serious server bug or protocol corruption + let attached_function = response.attached_function.ok_or_else(|| { + tracing::error!("Server did not return attached function in response"); + AttachFunctionError::ServerReturnedInvalidData + })?; + let attached_function_id = chroma_types::AttachedFunctionUuid( - uuid::Uuid::parse_str(&response.id).map_err(|e| { + uuid::Uuid::parse_str(&attached_function.id).map_err(|e| { tracing::error!( - 
attached_function_id = %response.id, + attached_function_id = %attached_function.id, error = %e, "Server returned invalid attached_function_id UUID - attached function was created but response is corrupt" ); @@ -1832,7 +1929,6 @@ impl GrpcSysDb { + std::time::Duration::from_micros(attached_function.created_at), updated_at: std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_micros(attached_function.updated_at), - is_ready: attached_function.is_ready, }) } diff --git a/rust/sysdb/src/test_sysdb.rs b/rust/sysdb/src/test_sysdb.rs index 0b0202a34d6..8f9cb12a7f5 100644 --- a/rust/sysdb/src/test_sysdb.rs +++ b/rust/sysdb/src/test_sysdb.rs @@ -716,7 +716,7 @@ fn attached_function_to_proto( created_at: system_time_to_micros(attached_function.created_at), updated_at: system_time_to_micros(attached_function.updated_at), function_id: attached_function.function_id.to_string(), - is_ready: attached_function.is_ready, + is_ready: false, // Default value since Rust struct doesn't track this field } } diff --git a/rust/types/src/flush.rs b/rust/types/src/flush.rs index af191e11060..0020823c599 100644 --- a/rust/types/src/flush.rs +++ b/rust/types/src/flush.rs @@ -1,11 +1,14 @@ -use super::{AttachedFunctionUuid, CollectionUuid, ConversionError}; +use super::{AttachedFunctionUuid, CollectionUuid, ConversionError, Schema}; use crate::{ - chroma_proto::{FilePaths, FlushSegmentCompactionInfo}, + chroma_proto::{ + FilePaths, FlushCollectionCompactionAndAttachedFunctionResponse, FlushSegmentCompactionInfo, + }, SegmentUuid, }; use chroma_error::{ChromaError, ErrorCodes}; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use thiserror::Error; +use uuid::Uuid; #[derive(Debug, Clone)] pub struct SegmentFlushInfo { @@ -13,6 +16,18 @@ pub struct SegmentFlushInfo { pub file_paths: HashMap>, } +#[derive(Debug, Clone)] +pub struct CollectionFlushInfo { + pub tenant_id: String, + pub collection_id: CollectionUuid, + pub log_position: i64, + pub collection_version: i32, + pub segment_flush_info: Arc<[SegmentFlushInfo]>, + pub total_records_post_compaction: u64, + pub size_bytes_post_compaction: u64, + pub schema: Option, +} + #[derive(Debug, Clone)] pub struct AttachedFunctionUpdateInfo { pub attached_function_id: AttachedFunctionUuid, @@ -122,6 +137,13 @@ pub struct FlushCompactionResponse { pub last_compaction_time: i64, } +#[derive(Debug)] +pub struct FlushCompactionAndAttachedFunctionResponse { + pub collections: Vec, + // Completion offset updated during register + pub completion_offset: u64, +} + impl FlushCompactionResponse { pub fn new( collection_id: CollectionUuid, @@ -136,6 +158,63 @@ impl FlushCompactionResponse { } } +impl TryFrom for FlushCompactionResponse { + type Error = FlushCompactionResponseConversionError; + + fn try_from( + value: FlushCollectionCompactionAndAttachedFunctionResponse, + ) -> Result { + // Use first collection for backward compatibility + let first_collection = value + .collections + .first() + .ok_or(FlushCompactionResponseConversionError::MissingCollections)?; + let id = Uuid::parse_str(&first_collection.collection_id) + .map_err(|_| FlushCompactionResponseConversionError::InvalidUuid)?; + Ok(FlushCompactionResponse { + collection_id: CollectionUuid(id), + collection_version: first_collection.collection_version, + last_compaction_time: first_collection.last_compaction_time, + }) + } +} + +impl TryFrom + for FlushCompactionAndAttachedFunctionResponse +{ + type Error = FlushCompactionResponseConversionError; + + fn try_from( + value: 
FlushCollectionCompactionAndAttachedFunctionResponse, + ) -> Result { + // Parse all collections from the repeated field + let mut collections = Vec::with_capacity(value.collections.len()); + for collection in value.collections { + let id = Uuid::parse_str(&collection.collection_id) + .map_err(|_| FlushCompactionResponseConversionError::InvalidUuid)?; + collections.push(FlushCompactionResponse { + collection_id: CollectionUuid(id), + collection_version: collection.collection_version, + last_compaction_time: collection.last_compaction_time, + }); + } + + // Extract completion_offset from attached_function_state + // Note: next_nonce and next_run are no longer used by the client + // They were already set by PrepareAttachedFunction via advance_attached_function() + let completion_offset = value + .attached_function_state + .as_ref() + .map(|state| state.completion_offset) + .unwrap_or(0); + + Ok(FlushCompactionAndAttachedFunctionResponse { + collections, + completion_offset, + }) + } +} + #[derive(Error, Debug)] pub enum FlushCompactionResponseConversionError { #[error(transparent)] @@ -146,6 +225,8 @@ pub enum FlushCompactionResponseConversionError { InvalidAttachedFunctionNonce, #[error("Invalid timestamp format")] InvalidTimestamp, + #[error("Missing collections in response")] + MissingCollections, } impl ChromaError for FlushCompactionResponseConversionError { @@ -156,6 +237,9 @@ impl ChromaError for FlushCompactionResponseConversionError { ErrorCodes::InvalidArgument } FlushCompactionResponseConversionError::InvalidTimestamp => ErrorCodes::InvalidArgument, + FlushCompactionResponseConversionError::MissingCollections => { + ErrorCodes::InvalidArgument + } FlushCompactionResponseConversionError::DecodeError(e) => e.code(), } } diff --git a/rust/types/src/task.rs b/rust/types/src/task.rs index 7bb64525580..2422e495844 100644 --- a/rust/types/src/task.rs +++ b/rust/types/src/task.rs @@ -72,6 +72,89 @@ pub struct AttachedFunction { /// Timestamp when the attached function was last updated #[serde(default = "default_systemtime")] pub updated_at: SystemTime, - /// Whether the attached function is ready (has completed initialization/backfill) - pub is_ready: bool, + // is_ready is a column in the database, but not in the struct because + // it is not meant to be used in rust code. If it is false, rust code + // should never even see it. 
+} + +#[derive(Debug, thiserror::Error)] +pub enum AttachedFunctionConversionError { + #[error("Invalid UUID: {0}")] + InvalidUuid(String), + #[error("Attached function params aren't supported yet")] + ParamsNotSupported, +} + +impl TryFrom for AttachedFunction { + type Error = AttachedFunctionConversionError; + + fn try_from( + attached_function: crate::chroma_proto::AttachedFunction, + ) -> Result { + // Parse attached_function_id + let attached_function_id = attached_function + .id + .parse::() + .map_err(|_| { + AttachedFunctionConversionError::InvalidUuid("attached_function_id".to_string()) + })?; + + // Parse function_id + let function_id = attached_function + .function_id + .parse::() + .map_err(|_| AttachedFunctionConversionError::InvalidUuid("function_id".to_string()))?; + + // Parse input_collection_id + let input_collection_id = attached_function + .input_collection_id + .parse::() + .map_err(|_| { + AttachedFunctionConversionError::InvalidUuid("input_collection_id".to_string()) + })?; + + // Parse output_collection_id if available + let output_collection_id = attached_function + .output_collection_id + .map(|id| id.parse::()) + .transpose() + .map_err(|_| { + AttachedFunctionConversionError::InvalidUuid("output_collection_id".to_string()) + })?; + + // Parse params if available - only allow empty JSON "{}" or empty struct for now. + // TODO(tanujnay112): Process params when we allow them + let params = if let Some(params_struct) = &attached_function.params { + if !params_struct.fields.is_empty() { + return Err(AttachedFunctionConversionError::ParamsNotSupported); + } + Some("{}".to_string()) + } else { + None + }; + + // Parse timestamps + let created_at = std::time::SystemTime::UNIX_EPOCH + + std::time::Duration::from_micros(attached_function.created_at); + let updated_at = std::time::SystemTime::UNIX_EPOCH + + std::time::Duration::from_micros(attached_function.updated_at); + + Ok(AttachedFunction { + id: attached_function_id, + name: attached_function.name, + function_id, + input_collection_id, + output_collection_name: attached_function.output_collection_name, + output_collection_id, + params, + tenant_id: attached_function.tenant_id, + database_id: attached_function.database_id, + last_run: None, // Not available in proto + completion_offset: attached_function.completion_offset, + min_records_for_invocation: attached_function.min_records_for_invocation, + is_deleted: false, // Not available in proto, would need to be fetched separately + created_at, + updated_at, + }) + } } diff --git a/rust/worker/src/execution/functions/statistics.rs b/rust/worker/src/execution/functions/statistics.rs index 5d5a2ea4bdb..2a306e35fdb 100644 --- a/rust/worker/src/execution/functions/statistics.rs +++ b/rust/worker/src/execution/functions/statistics.rs @@ -11,8 +11,10 @@ use std::hash::{Hash, Hasher}; use async_trait::async_trait; use chroma_error::ChromaError; use chroma_segment::blockfile_record::RecordSegmentReader; +use chroma_segment::types::HydratedMaterializedLogRecord; use chroma_types::{ - Chunk, LogRecord, MetadataValue, Operation, OperationRecord, UpdateMetadataValue, + Chunk, LogRecord, MaterializedLogOperation, MetadataValue, Operation, OperationRecord, + UpdateMetadataValue, }; use futures::StreamExt; @@ -25,7 +27,7 @@ pub trait StatisticsFunctionFactory: std::fmt::Debug + Send + Sync { /// Accumulate statistics. Must be an associative and commutative over a sequence of `observe` calls. 
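/// For example, merging per-batch results must not depend on order: observing
/// records {a, b} and then {c} must yield the same output as observing {c} and
/// then {a, b}. The `CounterFunction` below satisfies this with a saturating
/// add per observed record.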
pub trait StatisticsFunction: std::fmt::Debug + Send { - fn observe(&mut self, log_record: &LogRecord); + fn observe(&mut self, hydrated_record: &HydratedMaterializedLogRecord<'_, '_>); fn output(&self) -> UpdateMetadataValue; } @@ -44,7 +46,7 @@ pub struct CounterFunction { } impl StatisticsFunction for CounterFunction { - fn observe(&mut self, _: &LogRecord) { + fn observe(&mut self, _: &HydratedMaterializedLogRecord<'_, '_>) { self.acc = self.acc.saturating_add(1); } @@ -173,28 +175,26 @@ pub struct StatisticsFunctionExecutor(pub Box); impl AttachedFunctionExecutor for StatisticsFunctionExecutor { async fn execute( &self, - input_records: Chunk, + input_records: Chunk>, output_reader: Option<&RecordSegmentReader<'_>>, ) -> Result, Box> { let mut counts: HashMap>> = HashMap::default(); - for (log_record, _) in input_records.iter() { - if matches!(log_record.record.operation, Operation::Delete) { + for (hydrated_record, _index) in input_records.iter() { + // Skip delete operations - they should not be counted in statistics + if hydrated_record.get_operation() == MaterializedLogOperation::DeleteExisting { continue; } - if let Some(update_metadata) = log_record.record.metadata.as_ref() { - for (key, update_value) in update_metadata.iter() { - let value: Option = update_value.try_into().ok(); - if let Some(value) = value { - let inner_map = counts.entry(key.clone()).or_default(); - for stats_value in StatisticsValue::from_metadata_value(&value) { - inner_map - .entry(stats_value) - .or_insert_with(|| self.0.create()) - .observe(log_record); - } - } + // Use merged_metadata to get the metadata from the hydrated record + let metadata = hydrated_record.merged_metadata(); + for (key, value) in metadata.iter() { + let inner_map = counts.entry(key.clone()).or_default(); + for stats_value in StatisticsValue::from_metadata_value(value) { + inner_map + .entry(stats_value) + .or_insert_with(|| self.0.create()) + .observe(hydrated_record); } } } @@ -261,12 +261,17 @@ impl AttachedFunctionExecutor for StatisticsFunctionExecutor { mod tests { use std::collections::HashMap; - use chroma_segment::{blockfile_record::RecordSegmentReader, test::TestDistributedSegment}; + use chroma_segment::{ + blockfile_record::RecordSegmentReader, test::TestDistributedSegment, + types::MaterializeLogsResult, + }; use chroma_types::{ Chunk, LogRecord, Operation, OperationRecord, SparseVector, UpdateMetadata, UpdateMetadataValue, }; + use crate::execution::orchestration::compact; + use super::*; fn build_record(id: &str, metadata: HashMap) -> LogRecord { @@ -282,7 +287,7 @@ mod tests { log_offset: 0, record: OperationRecord { id: id.to_string(), - embedding: None, + embedding: Some(vec![0.0]), encoding: None, metadata: Some(metadata), document: None, @@ -291,6 +296,21 @@ mod tests { } } + async fn hydrate_records<'a>( + materialized: &'a MaterializeLogsResult, + record_reader: Option<&'a RecordSegmentReader<'a>>, + ) -> Vec> { + let mut hydrated_records = Vec::new(); + for borrowed_record in materialized.iter() { + let hydrated = borrowed_record + .hydrate(record_reader) + .await + .expect("hydration should succeed"); + hydrated_records.push(hydrated); + } + hydrated_records + } + fn extract_metadata_tuple(metadata: &UpdateMetadata) -> (i64, String, String, String) { let count = match metadata.get("count") { Some(UpdateMetadataValue::Int(value)) => *value, @@ -460,7 +480,11 @@ mod tests { ]), ); - let input = Chunk::new(vec![record_one, record_two].into()); + let logs = Chunk::new(vec![record_one, record_two].into()); + let 
materialized = MaterializeLogsResult::from_logs_for_test(logs) + .expect("materialization should succeed"); + let hydrated = hydrate_records(&materialized, None).await; + let input = Chunk::new(std::sync::Arc::from(hydrated)); let output = executor .execute(input, None) @@ -558,7 +582,12 @@ mod tests { )]), ); - let input = Chunk::new(vec![record_one, record_two].into()); + let logs = Chunk::new(vec![record_one, record_two].into()); + let materialized = MaterializeLogsResult::from_logs_for_test(logs) + .expect("materialization should succeed"); + let hydrated = hydrate_records(&materialized, None).await; + let input = Chunk::new(std::sync::Arc::from(hydrated)); + let output = executor .execute(input, None) .await @@ -597,7 +626,12 @@ mod tests { HashMap::from([("bool_key".to_string(), UpdateMetadataValue::Bool(false))]), ); - let input = Chunk::new(vec![upsert_record, delete_record].into()); + let logs = Chunk::new(vec![upsert_record, delete_record].into()); + let materialized = MaterializeLogsResult::from_logs_for_test(logs) + .expect("materialization should succeed"); + let hydrated = hydrate_records(&materialized, None).await; + let input = Chunk::new(std::sync::Arc::from(hydrated)); + let output = executor .execute(input, None) .await @@ -634,7 +668,12 @@ mod tests { )]), ); - let input = Chunk::new(vec![record].into()); + let logs = Chunk::new(vec![record].into()); + let materialized = MaterializeLogsResult::from_logs_for_test(logs) + .expect("materialization should succeed"); + let hydrated = hydrate_records(&materialized, None).await; + let input = Chunk::new(std::sync::Arc::from(hydrated)); + let output = executor .execute(input, None) .await @@ -652,7 +691,11 @@ mod tests { HashMap::from([("skip".to_string(), UpdateMetadataValue::None)]), ); - let input = Chunk::new(vec![record].into()); + let logs = Chunk::new(vec![record].into()); + let materialized = MaterializeLogsResult::from_logs_for_test(logs) + .expect("materialization should succeed"); + let hydrated = hydrate_records(&materialized, None).await; + let input = Chunk::new(std::sync::Arc::from(hydrated)); let output = executor .execute(input, None) @@ -685,13 +728,17 @@ mod tests { .await .expect("record segment reader creation succeeds"); - let input = Chunk::new( + let logs = Chunk::new( vec![build_record( "input-1", HashMap::from([("fresh_key".to_string(), UpdateMetadataValue::Int(1))]), )] .into(), ); + let materialized = MaterializeLogsResult::from_logs_for_test(logs) + .expect("materialization should succeed"); + let hydrated = hydrate_records(&materialized, Some(&record_reader)).await; + let input = Chunk::new(std::sync::Arc::from(hydrated)); let output = executor .execute(input, Some(&record_reader)) @@ -736,7 +783,11 @@ mod tests { .await .expect("record segment reader creation succeeds"); - let empty_input: Chunk = Chunk::new(Vec::::new().into()); + let empty_logs: Chunk = Chunk::new(Vec::::new().into()); + let materialized = MaterializeLogsResult::from_logs_for_test(empty_logs) + .expect("materialization should succeed"); + let hydrated = hydrate_records(&materialized, Some(&record_reader)).await; + let empty_input = Chunk::new(std::sync::Arc::from(hydrated)); let output = executor .execute(empty_input, Some(&record_reader)) @@ -749,19 +800,16 @@ mod tests { } // TODO(tanujnay112): Reenable this after function compaction is brought back - /* #[tokio::test] async fn test_k8s_integration_statistics_function() { use crate::config::RootConfig; - use crate::execution::orchestration::CompactOrchestrator; use 
chroma_config::{registry::Registry, Configurable}; use chroma_log::in_memory_log::{InMemoryLog, InternalLogRecord}; use chroma_log::Log; use chroma_segment::test::TestDistributedSegment; use chroma_sysdb::SysDb; - use chroma_system::{Dispatcher, Orchestrator, System}; + use chroma_system::{Dispatcher, System}; use chroma_types::{CollectionUuid, Operation, OperationRecord, UpdateMetadataValue}; - use s3heap_service::client::{GrpcHeapService, GrpcHeapServiceConfig}; use std::collections::HashMap; // Setup test environment @@ -788,14 +836,6 @@ mod tests { .expect("Should connect to grpc sysdb"); let mut sysdb = SysDb::Grpc(grpc_sysdb); - // Connect to Grpc Heap Service (requires Tilt running) - let heap_service = GrpcHeapService::try_from_config( - &(GrpcHeapServiceConfig::default(), system.clone()), - ®istry, - ) - .await - .expect("Should connect to grpc heap service"); - let test_segments = TestDistributedSegment::new().await; let mut in_memory_log = InMemoryLog::new(); @@ -886,8 +926,9 @@ mod tests { } let log = Log::InMemory(in_memory_log); - let attached_function_name = "test_statistics"; - let output_collection_name = format!("test_stats_output_{}", uuid::Uuid::new_v4()); + let test_run_id = uuid::Uuid::new_v4(); + let attached_function_name = format!("test_statistics_{}", test_run_id); + let output_collection_name = format!("test_stats_output_{}", test_run_id); // Create statistics attached function via sysdb let attached_function_id = sysdb @@ -895,7 +936,7 @@ mod tests { attached_function_name.to_string(), "statistics".to_string(), collection_id, - output_collection_name, + output_collection_name.clone(), serde_json::Value::Null, tenant.clone(), db.clone(), @@ -903,39 +944,13 @@ mod tests { ) .await .expect("Attached function creation should succeed"); - - // Initial compaction - let compact_orchestrator = CompactOrchestrator::new( - collection_id, - false, - 50, - 1000, - 50, - log.clone(), - sysdb.clone(), - test_segments.blockfile_provider.clone(), - test_segments.hnsw_provider.clone(), - test_segments.spann_provider.clone(), - dispatcher_handle.clone(), - None, - ); - - let result = compact_orchestrator.run(system.clone()).await; - assert!( - result.is_ok(), - "Initial compaction should succeed: {:?}", - result.err() - ); - - // Get nonce for attached function run - let attached_function = sysdb - .get_attached_function_by_name(collection_id, attached_function_name.to_string()) + sysdb + .finish_create_attached_function(attached_function_id) .await - .expect("Attached function should be found"); - let execution_nonce = attached_function.lowest_live_nonce.unwrap(); + .expect("Attached function creation finish should succeed"); - // Run statistics function - let compact_orchestrator = CompactOrchestrator::new_for_attached_function( + Box::pin(compact::compact( + system.clone(), collection_id, false, 50, @@ -943,33 +958,25 @@ mod tests { 50, log.clone(), sysdb.clone(), - heap_service, test_segments.blockfile_provider.clone(), test_segments.hnsw_provider.clone(), test_segments.spann_provider.clone(), - dispatcher_handle, + dispatcher_handle.clone(), None, - attached_function_id, - execution_nonce, - ); - - let result = compact_orchestrator.run(system).await; - assert!( - result.is_ok(), - "Statistics function execution should succeed: {:?}", - result.err() - ); + )) + .await + .expect("Compaction should succeed"); // Verify statistics were generated let updated_attached_function = sysdb - .get_attached_function_by_name(collection_id, attached_function_name.to_string()) + 
.get_attached_function_by_name(collection_id, attached_function_name.clone()) .await .expect("Attached function should be found"); - // Note: completion_offset is 13, but all 15 records (0-14) were processed + // Note: completion_offset is 14, all 15 records (0-14) were processed assert_eq!( - updated_attached_function.completion_offset, 13, - "Completion offset should be 13" + updated_attached_function.completion_offset, 14, + "Completion offset should be 14 (last processed record)" ); let output_collection_id = updated_attached_function.output_collection_id.unwrap(); @@ -1073,5 +1080,4 @@ mod tests { stats_by_key_value.len() ); } - */ } diff --git a/rust/worker/src/execution/operators/execute_task.rs b/rust/worker/src/execution/operators/execute_task.rs index c4a5171ef91..d5d2955a3a7 100644 --- a/rust/worker/src/execution/operators/execute_task.rs +++ b/rust/worker/src/execution/operators/execute_task.rs @@ -3,6 +3,7 @@ use chroma_blockstore::provider::BlockfileProvider; use chroma_error::ChromaError; use chroma_log::Log; use chroma_segment::blockfile_record::{RecordSegmentReader, RecordSegmentReaderCreationError}; +use chroma_segment::types::HydratedMaterializedLogRecord; use chroma_system::{Operator, OperatorType}; use chroma_types::{ Chunk, CollectionUuid, LogRecord, Operation, OperationRecord, Segment, UpdateMetadataValue, @@ -10,8 +11,10 @@ use chroma_types::{ }; use std::sync::Arc; use thiserror::Error; +use uuid::Uuid; use crate::execution::functions::{CounterFunctionFactory, StatisticsFunctionExecutor}; +use crate::execution::operators::materialize_logs::MaterializeLogOutput; /// Trait for attached function executors that process input records and produce output records. /// Implementors can read from the output collection to maintain state across executions. @@ -20,14 +23,14 @@ pub trait AttachedFunctionExecutor: Send + Sync + std::fmt::Debug { /// Execute the attached function logic on input records. 
/// /// # Arguments - /// * `input_records` - The log records to process + /// * `input_records` - The hydrated materialized log records to process /// * `output_reader` - Optional reader for the output collection's compacted data /// /// # Returns /// The output records to be written to the output collection async fn execute( &self, - input_records: Chunk<LogRecord>, + input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>, output_reader: Option<&RecordSegmentReader<'_>>, ) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>>; } @@ -41,13 +44,14 @@ pub struct CountAttachedFunction; impl AttachedFunctionExecutor for CountAttachedFunction { async fn execute( &self, - input_records: Chunk<LogRecord>, + input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>, _output_reader: Option<&RecordSegmentReader<'_>>, ) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>> { let records_count = input_records.len() as i64; - let new_total_count = records_count; + tracing::debug!("new_total_count is {}", new_total_count); + // Create output record with updated count let mut metadata = std::collections::HashMap::new(); metadata.insert( @@ -55,21 +59,19 @@ impl AttachedFunctionExecutor for CountAttachedFunction { UpdateMetadataValue::Int(new_total_count), ); - let operation_record = OperationRecord { - id: "attached_function_result".to_string(), - embedding: Some(vec![0.0]), - encoding: None, - metadata: Some(metadata), - document: None, - operation: Operation::Upsert, - }; - - let log_record = LogRecord { - log_offset: 0, // Will be set by caller - record: operation_record, + let output_record = LogRecord { + log_offset: 0, + record: OperationRecord { + id: "function_output".to_string(), + embedding: Some(vec![0.0]), + encoding: None, + metadata: Some(metadata), + document: Some(format!("Processed {} records", records_count)), + operation: Operation::Upsert, + }, }; - Ok(Chunk::new(Arc::new([log_record]))) + Ok(Chunk::new(std::sync::Arc::from(vec![output_record]))) } } @@ -84,12 +86,11 @@ pub struct ExecuteAttachedFunctionOperator { impl ExecuteAttachedFunctionOperator { /// Create a new ExecuteAttachedFunctionOperator from an AttachedFunction. /// The executor is selected based on the function_id in the attached function.
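+    ///
+    /// A minimal call sketch (assuming `FUNCTION_RECORD_COUNTER_ID` and a `Log` handle
+    /// are in scope, as they are in this module):
+    /// ```ignore
+    /// let operator = ExecuteAttachedFunctionOperator::from_attached_function(
+    ///     FUNCTION_RECORD_COUNTER_ID,
+    ///     log_client.clone(),
+    /// )?;
+    /// ```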
- #[allow(dead_code)] pub(crate) fn from_attached_function( - attached_function: &chroma_types::AttachedFunction, + function_id: Uuid, log_client: Log, ) -> Result<Self, ExecuteAttachedFunctionError> { - let executor: Arc<dyn AttachedFunctionExecutor> = match attached_function.function_id { + let executor: Arc<dyn AttachedFunctionExecutor> = match function_id { // For the record counter, use CountAttachedFunction FUNCTION_RECORD_COUNTER_ID => Arc::new(CountAttachedFunction), // For statistics, use StatisticsFunctionExecutor with CounterFunctionFactory { Arc::new(StatisticsFunctionExecutor(Box::new(CounterFunctionFactory))) } _ => { - tracing::error!( - "Unknown function_id UUID: {}", - attached_function.function_id - ); + tracing::error!("Unknown function_id UUID: {}", function_id); return Err(ExecuteAttachedFunctionError::InvalidUuid(format!( "Unknown function_id UUID: {}", - attached_function.function_id + function_id ))); } }; @@ -118,8 +116,8 @@ impl ExecuteAttachedFunctionOperator { /// Input for the ExecuteAttachedFunction operator #[derive(Debug)] pub struct ExecuteAttachedFunctionInput { - /// The fetched log records to process - pub log_records: Chunk<LogRecord>, + /// The materialized log outputs to process + pub materialized_logs: Arc<Vec<MaterializeLogOutput>>, /// The tenant ID pub tenant_id: String, /// The output collection ID where results are written @@ -188,13 +186,11 @@ impl Operator input: &ExecuteAttachedFunctionInput, ) -> Result<ExecuteAttachedFunctionOutput, ExecuteAttachedFunctionError> { tracing::info!( - "[ExecuteAttachedFunction]: Processing {} records for output collection {}", - input.log_records.len(), + "[ExecuteAttachedFunction]: Processing {} materialized log outputs for output collection {}", + input.materialized_logs.len(), input.output_collection_id ); - let records_count = input.log_records.len() as u64; - // Create record segment reader from the output collection's record segment let record_segment_reader = match Box::pin(RecordSegmentReader::from_segment( &input.output_record_segment, @@ -211,33 +207,40 @@ impl Operator Err(e) => return Err((*e).into()), }; + // Process all materialized logs and hydrate the records + let mut all_hydrated_records = Vec::new(); + let mut total_records_processed = 0u64; + + for materialized_log in input.materialized_logs.iter() { + // Use the iterator to process each materialized record + for borrowed_record in materialized_log.result.iter() { + // Hydrate the record using the same pattern as materialize_logs operator + let hydrated_record = borrowed_record + .hydrate(record_segment_reader.as_ref()) + .await + .map_err(|e| ExecuteAttachedFunctionError::SegmentRead(Box::new(e)))?; + + all_hydrated_records.push(hydrated_record); + } + + total_records_processed += materialized_log.result.len() as u64; + } + // Execute the attached function using the provided executor let output_records = self .attached_function_executor - .execute(input.log_records.clone(), record_segment_reader.as_ref()) + .execute( + Chunk::new(std::sync::Arc::from(all_hydrated_records)), + record_segment_reader.as_ref(), + ) .await .map_err(ExecuteAttachedFunctionError::SegmentRead)?; - // Update log offsets for output records - // Convert u64 completion_offset to i64 for LogRecord (which uses i64) - let base_offset: i64 = input.completion_offset.try_into().map_err(|_| { - ExecuteAttachedFunctionError::LogOffsetOverflowUnsignedToSigned( - input.completion_offset, - 0, - ) - })?; - let output_records_with_offsets: Vec<LogRecord> = output_records .iter() - .enumerate() - .map(|(i, (log_record, _))| { - let i_i64 = i64::try_from(i) - .map_err(|_|
ExecuteAttachedFunctionError::LogOffsetOverflow(base_offset, i))?; - let offset = base_offset.checked_add(i_i64).ok_or_else(|| { - ExecuteAttachedFunctionError::LogOffsetOverflow(base_offset, i) - })?; + .map(|(log_record, _)| { Ok(LogRecord { - log_offset: offset, + log_offset: -1, // Placeholder: these offsets are never read downstream. record: log_record.record.clone(), }) }) @@ -250,8 +253,8 @@ impl Operator // Return the output records to be partitioned Ok(ExecuteAttachedFunctionOutput { - records_processed: records_count, - output_records: Chunk::new(Arc::from(output_records_with_offsets)), + records_processed: total_records_processed, + output_records: Chunk::new(std::sync::Arc::from(output_records_with_offsets)), }) } } diff --git a/rust/worker/src/execution/operators/finish_attached_function.rs b/rust/worker/src/execution/operators/finish_attached_function.rs new file mode 100644 index 00000000000..bdb2ce5cea7 --- /dev/null +++ b/rust/worker/src/execution/operators/finish_attached_function.rs @@ -0,0 +1,148 @@ +use async_trait::async_trait; +use chroma_error::{ChromaError, ErrorCodes}; +use chroma_log::Log; +use chroma_sysdb::SysDb; +use chroma_system::Operator; +use chroma_types::{AttachedFunctionUpdateInfo, AttachedFunctionUuid, CollectionFlushInfo}; +use thiserror::Error; +use tonic; + +/// The finish attached function operator is responsible for: +/// 1. Registering collection compaction results for all collections +/// 2. Updating the attached function's completion offset in the same transaction +#[derive(Debug)] +pub struct FinishAttachedFunctionOperator {} + +impl FinishAttachedFunctionOperator { + /// Create a new finish attached function operator. + pub fn new() -> Box<Self> { + Box::new(FinishAttachedFunctionOperator {}) + } +} + +#[derive(Debug)] +/// The input for the finish attached function operator. +/// This input is used to complete the attached function workflow by: +/// - Flushing collection compaction data to sysdb for all collections +/// - Updating the attached function's completion offset in the same transaction +pub struct FinishAttachedFunctionInput { + pub collections: Vec<CollectionFlushInfo>, + pub attached_function_id: AttachedFunctionUuid, + pub completion_offset: u64, + + pub sysdb: SysDb, + pub log: Log, +} + +impl FinishAttachedFunctionInput { + /// Create a new finish attached function input.
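+    ///
+    /// A construction sketch (argument values assumed to come from the register
+    /// orchestrator, which owns the flush results and offsets):
+    /// ```ignore
+    /// let input = FinishAttachedFunctionInput::new(
+    ///     collections,          // Vec<CollectionFlushInfo>, one entry per flushed collection
+    ///     attached_function_id, // AttachedFunctionUuid to advance
+    ///     completion_offset,    // new completion offset recorded with the flush
+    ///     sysdb.clone(),
+    ///     log.clone(),
+    /// );
+    /// ```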
+ pub fn new( + collections: Vec<CollectionFlushInfo>, + attached_function_id: AttachedFunctionUuid, + completion_offset: u64, + + sysdb: SysDb, + log: Log, + ) -> Self { + FinishAttachedFunctionInput { + collections, + attached_function_id, + completion_offset, + sysdb, + log, + } + } +} + +#[derive(Debug)] +pub struct FinishAttachedFunctionOutput { + pub collection_flush_results: Vec<chroma_types::FlushCompactionResponse>, + pub completion_offset: u64, +} + +#[derive(Error, Debug)] +pub enum FinishAttachedFunctionError { + #[error("Failed to flush collection compaction: {0}")] + FlushFailed(#[from] chroma_sysdb::FlushCompactionError), + #[error("Invalid attached function ID: {0}")] + InvalidFunctionId(String), +} + +impl ChromaError for FinishAttachedFunctionError { + fn code(&self) -> ErrorCodes { + match self { + FinishAttachedFunctionError::FlushFailed(e) => e.code(), + FinishAttachedFunctionError::InvalidFunctionId(_) => ErrorCodes::InvalidArgument, + } + } +} + +#[async_trait] +impl Operator<FinishAttachedFunctionInput, FinishAttachedFunctionOutput> + for FinishAttachedFunctionOperator +{ + type Error = FinishAttachedFunctionError; + + fn get_name(&self) -> &'static str { + "FinishAttachedFunctionOperator" + } + + async fn run( + &self, + input: &FinishAttachedFunctionInput, + ) -> Result<FinishAttachedFunctionOutput, FinishAttachedFunctionError> { + let mut sysdb = input.sysdb.clone(); + + // Create the attached function update info + let attached_function_update = AttachedFunctionUpdateInfo { + attached_function_id: input.attached_function_id, + completion_offset: input.completion_offset, + }; + + // Flush all collection compaction results and update attached function in one RPC + let flush_result = sysdb + .flush_compaction_and_attached_function( + input.collections.clone(), + attached_function_update, + ) + .await + .map_err(FinishAttachedFunctionError::FlushFailed)?; + + // Convert the collection results to FlushCompactionResponse + let collection_flush_results: Vec<chroma_types::FlushCompactionResponse> = flush_result + .collections + .into_iter() + .map(|collection| chroma_types::FlushCompactionResponse { + collection_id: collection.collection_id, + collection_version: collection.collection_version, + last_compaction_time: collection.last_compaction_time, + }) + .collect(); + + // TODO(tanujnay112): Can optimize the below to not happen on the output collection.
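+        // Rationale for the ordering: if an offset update below fails after a successful
+        // flush, the log service merely lags sysdb and the update is retried on the next
+        // run; flushing after updating offsets could instead lose compacted data.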
+ + // Update log offsets for all collections to ensure consistency + // This must be done after the flush to ensure the log position in sysdb is always >= the log service's position + let mut log = input.log.clone(); + for collection in &input.collections { + log.update_collection_log_offset( + &collection.tenant_id, + collection.collection_id, + collection.log_position, + ) + .await + .map_err(|e| { + FinishAttachedFunctionError::FlushFailed( + chroma_sysdb::FlushCompactionError::FailedToFlushCompaction( + tonic::Status::internal(format!("Failed to update log offset: {}", e)), + ), + ) + })?; + } + + Ok(FinishAttachedFunctionOutput { + collection_flush_results, + completion_offset: flush_result.completion_offset, + }) + } +} diff --git a/rust/worker/src/execution/operators/get_attached_function.rs b/rust/worker/src/execution/operators/get_attached_function.rs new file mode 100644 index 00000000000..e9bf1930c64 --- /dev/null +++ b/rust/worker/src/execution/operators/get_attached_function.rs @@ -0,0 +1,151 @@ +use async_trait::async_trait; +use chroma_error::ChromaError; +use chroma_sysdb::sysdb::SysDb; +use chroma_system::{Operator, OperatorType}; +use chroma_types::{ + AttachedFunction, AttachedFunctionConversionError, CollectionUuid, ListAttachedFunctionsError, +}; +use thiserror::Error; + +/// The `GetAttachedFunctionOperator` lists attached functions for a collection and selects the first one. +/// If no functions are found, it returns an empty result (not an error) to allow the orchestrator +/// to handle the case gracefully. +#[derive(Clone, Debug)] +pub struct GetAttachedFunctionOperator { + pub sysdb: SysDb, + pub collection_id: CollectionUuid, +} + +impl GetAttachedFunctionOperator { + pub fn new(sysdb: SysDb, collection_id: CollectionUuid) -> Self { + Self { + sysdb, + collection_id, + } + } +} + +#[derive(Debug)] +pub struct GetAttachedFunctionInput { + pub collection_id: CollectionUuid, +} + +#[derive(Debug)] +pub struct GetAttachedFunctionOutput { + pub attached_function: Option<AttachedFunction>, +} + +#[derive(Debug, Error)] +pub enum GetAttachedFunctionOperatorError { + #[error("Failed to list attached functions: {0}")] + ListFunctions(#[from] ListAttachedFunctionsError), + #[error("Failed to convert attached function proto")] + ConversionError(#[from] AttachedFunctionConversionError), + #[error("No attached function found")] + NoAttachedFunctionFound, +} + +#[derive(Debug, Error)] +pub enum GetAttachedFunctionError { + #[error("Failed to list attached functions: {0}")] + ListFunctions(#[from] ListAttachedFunctionsError), + #[error("Failed to convert attached function proto")] + ConversionError(#[from] AttachedFunctionConversionError), +} + +impl ChromaError for GetAttachedFunctionError { + fn code(&self) -> chroma_error::ErrorCodes { + match self { + GetAttachedFunctionError::ListFunctions(e) => e.code(), + GetAttachedFunctionError::ConversionError(_) => chroma_error::ErrorCodes::Internal, + } + } + + fn should_trace_error(&self) -> bool { + match self { + GetAttachedFunctionError::ListFunctions(e) => e.should_trace_error(), + GetAttachedFunctionError::ConversionError(_) => true, + } + } +} + +impl ChromaError for GetAttachedFunctionOperatorError { + fn code(&self) -> chroma_error::ErrorCodes { + match self { + GetAttachedFunctionOperatorError::ListFunctions(e) => e.code(), + GetAttachedFunctionOperatorError::ConversionError(_) => { + chroma_error::ErrorCodes::Internal + } + GetAttachedFunctionOperatorError::NoAttachedFunctionFound => { + chroma_error::ErrorCodes::NotFound + } + } + } + + fn
should_trace_error(&self) -> bool { + match self { + GetAttachedFunctionOperatorError::ListFunctions(e) => e.should_trace_error(), + GetAttachedFunctionOperatorError::ConversionError(_) => true, + GetAttachedFunctionOperatorError::NoAttachedFunctionFound => false, + } + } +} + +#[async_trait] +impl Operator for GetAttachedFunctionOperator { + type Error = GetAttachedFunctionOperatorError; + + fn get_type(&self) -> OperatorType { + OperatorType::IO + } + + async fn run( + &self, + input: &GetAttachedFunctionInput, + ) -> Result { + tracing::trace!( + "[{}]: Collection ID {}", + self.get_name(), + input.collection_id.0 + ); + + let attached_functions = self + .sysdb + .clone() + .list_attached_functions(input.collection_id) + .await?; + + if attached_functions.is_empty() { + tracing::info!( + "[{}]: No attached functions found for collection {}", + self.get_name(), + input.collection_id.0 + ); + return Ok(GetAttachedFunctionOutput { + attached_function: None, + }); + } + + // Take the first attached function from the list + let attached_function_proto = attached_functions + .into_iter() + .next() + .ok_or(GetAttachedFunctionOperatorError::NoAttachedFunctionFound)?; + + // Convert proto to AttachedFunction type using TryFrom from task.rs + let attached_function: AttachedFunction = attached_function_proto + .try_into() + .map_err(GetAttachedFunctionOperatorError::ConversionError)?; + + tracing::info!( + "[{}]: Found attached function '{}' for collection {}", + self.get_name(), + attached_function.name, + input.collection_id.0 + ); + + Ok(GetAttachedFunctionOutput { + attached_function: Some(attached_function), + }) + } +} diff --git a/rust/worker/src/execution/operators/get_collection_and_segments.rs b/rust/worker/src/execution/operators/get_collection_and_segments.rs index 8a93a4adec6..2efb46696ae 100644 --- a/rust/worker/src/execution/operators/get_collection_and_segments.rs +++ b/rust/worker/src/execution/operators/get_collection_and_segments.rs @@ -22,6 +22,15 @@ pub struct GetCollectionAndSegmentsOperator { pub collection_id: CollectionUuid, } +impl GetCollectionAndSegmentsOperator { + pub fn new(sysdb: SysDb, collection_id: CollectionUuid) -> Self { + Self { + sysdb, + collection_id, + } + } +} + type GetCollectionAndSegmentsInput = (); pub type GetCollectionAndSegmentsOutput = CollectionAndSegments; diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index e6464e1d3bf..c829d03e356 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -3,7 +3,10 @@ pub mod commit_segment_writer; pub mod count_records; pub mod execute_task; pub mod fetch_log; +pub mod finish_attached_function; pub mod flush_segment_writer; +pub mod get_attached_function; +pub mod get_collection_and_segments; pub mod materialize_logs; pub(super) mod register; pub mod spann_bf_pl; @@ -11,7 +14,6 @@ pub(super) mod spann_centers_search; pub(super) mod spann_fetch_pl; pub mod filter; -pub mod get_collection_and_segments; pub mod idf; pub mod knn_hnsw; pub mod knn_log; diff --git a/rust/worker/src/execution/operators/register.rs b/rust/worker/src/execution/operators/register.rs index 8ca41ab01b5..d6a6febb333 100644 --- a/rust/worker/src/execution/operators/register.rs +++ b/rust/worker/src/execution/operators/register.rs @@ -9,6 +9,8 @@ use chroma_types::{CollectionUuid, FlushCompactionResponse, SegmentFlushInfo}; use std::sync::Arc; use thiserror::Error; +use 
crate::execution::operators::finish_attached_function::FinishAttachedFunctionError; + /// The register operator is responsible for flushing compaction data to the sysdb /// as well as updating the log offset in the log service. #[derive(Debug)] @@ -91,23 +93,27 @@ pub struct RegisterOutput { #[derive(Error, Debug)] pub enum RegisterError { #[error("Flush compaction error: {0}")] - FlushCompactionError(#[from] FlushCompactionError), + FlushCompaction(#[from] FlushCompactionError), #[error("Update log offset error: {0}")] - UpdateLogOffsetError(#[from] Box), + UpdateLogOffset(#[from] Box), + #[error("Finish attached function error: {0}")] + FinishAttachedFunction(#[from] FinishAttachedFunctionError), } impl ChromaError for RegisterError { fn code(&self) -> ErrorCodes { match self { - RegisterError::FlushCompactionError(e) => e.code(), - RegisterError::UpdateLogOffsetError(e) => e.code(), + RegisterError::FlushCompaction(e) => e.code(), + RegisterError::UpdateLogOffset(e) => e.code(), + RegisterError::FinishAttachedFunction(e) => e.code(), } } fn should_trace_error(&self) -> bool { match self { - RegisterError::FlushCompactionError(e) => e.should_trace_error(), - RegisterError::UpdateLogOffsetError(e) => e.should_trace_error(), + RegisterError::FlushCompaction(e) => e.should_trace_error(), + RegisterError::UpdateLogOffset(e) => e.should_trace_error(), + RegisterError::FinishAttachedFunction(e) => e.should_trace_error(), } } } @@ -141,7 +147,7 @@ impl Operator for RegisterOperator { // the we may lose data in compaction. let sysdb_registration_result = match result { Ok(response) => response, - Err(error) => return Err(RegisterError::FlushCompactionError(error)), + Err(error) => return Err(RegisterError::FlushCompaction(error)), }; let result = log @@ -152,7 +158,7 @@ impl Operator for RegisterOperator { Ok(_) => Ok(RegisterOutput { _sysdb_registration_result: sysdb_registration_result, }), - Err(error) => Err(RegisterError::UpdateLogOffsetError(error)), + Err(error) => Err(RegisterError::UpdateLogOffset(error)), } } } diff --git a/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs b/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs index 2bcc13a50a8..39c40659035 100644 --- a/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; @@ -60,7 +60,7 @@ pub struct ApplyLogsOrchestrator { segment_spans: HashMap, // Store the materialized outputs from LogFetchOrchestrator - materialized_log_data: Option>, + materialized_log_data: Option>>, metrics: CompactionMetrics, } @@ -181,7 +181,7 @@ impl ApplyLogsOrchestratorResponse { impl ApplyLogsOrchestrator { pub fn new( context: &CompactionContext, - materialized_log_data: Option>, + materialized_log_data: Option>>, ) -> Self { ApplyLogsOrchestrator { context: context.clone(), @@ -206,7 +206,7 @@ impl ApplyLogsOrchestrator { let mut tasks_to_run = Vec::new(); self.num_materialized_logs += materialized_logs.len() as u64; - let writers = self.context.get_segment_writers()?; + let writers = self.context.get_output_segment_writers()?; { self.num_uncompleted_tasks_by_segment @@ -255,7 +255,7 @@ impl ApplyLogsOrchestrator { materialized_logs.clone(), writers.record_reader.clone(), self.context - .get_collection_info()? + .get_output_collection_info()? 
.collection .schema .clone(), @@ -356,7 +356,7 @@ impl ApplyLogsOrchestrator { .add(self.num_materialized_logs, &[]); self.state = ExecutionState::Register; - let collection_info = match self.context.get_collection_info() { + let collection_info = match self.context.get_output_collection_info() { Ok(collection_info) => collection_info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -461,7 +461,7 @@ impl Orchestrator for ApplyLogsOrchestrator { } }; - for materialized_output in materialized_outputs { + for materialized_output in materialized_outputs.iter() { if materialized_output.result.is_empty() { self.terminate_with_result( Err(ApplyLogsOrchestratorError::InvariantViolation( @@ -477,7 +477,7 @@ impl Orchestrator for ApplyLogsOrchestrator { // Create tasks for each materialized output let result = self - .create_apply_log_to_segment_writer_tasks(materialized_output.result, ctx) + .create_apply_log_to_segment_writer_tasks(materialized_output.result.clone(), ctx) .await; let mut new_tasks = match result { @@ -525,15 +525,10 @@ impl Handler collection_info, - None => { - let err = ApplyLogsOrchestratorError::InvariantViolation( - "Collection info should have been set", - ); - self.terminate_with_result(Err(err), ctx).await; - return; + let collection_info = match self.context.get_output_collection_info_mut() { + Ok(info) => info, + Err(err) => { + return self.terminate_with_result(Err(err.into()), ctx).await; } }; @@ -587,7 +582,9 @@ impl Handler writer, None => return, @@ -617,7 +614,7 @@ impl Handler info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; diff --git a/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs b/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs new file mode 100644 index 00000000000..a51547b5876 --- /dev/null +++ b/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs @@ -0,0 +1,756 @@ +use std::{ + cell::OnceCell, + sync::{atomic::AtomicU32, Arc}, +}; + +use async_trait::async_trait; +use chroma_error::{ChromaError, ErrorCodes}; +use chroma_segment::{ + blockfile_metadata::{MetadataSegmentError, MetadataSegmentWriter}, + blockfile_record::{RecordSegmentWriter, RecordSegmentWriterCreationError}, + distributed_hnsw::{DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentWriter}, + distributed_spann::SpannSegmentWriterError, + types::VectorSegmentWriter, +}; +use chroma_system::{ + wrap, ChannelError, ComponentContext, ComponentHandle, Dispatcher, Handler, Orchestrator, + OrchestratorContext, PanicError, TaskError, TaskMessage, TaskResult, +}; +use chroma_types::{ + AttachedFunctionUuid, Chunk, CollectionAndSegments, CollectionUuid, JobId, LogRecord, + SegmentType, +}; +use thiserror::Error; +use tokio::sync::oneshot::{error::RecvError, Sender}; +use tracing::Span; +use uuid::Uuid; + +use crate::execution::{ + operators::{ + execute_task::{ + ExecuteAttachedFunctionError, ExecuteAttachedFunctionInput, + ExecuteAttachedFunctionOperator, ExecuteAttachedFunctionOutput, + }, + get_attached_function::{ + GetAttachedFunctionInput, GetAttachedFunctionOperator, + GetAttachedFunctionOperatorError, GetAttachedFunctionOutput, + }, + get_collection_and_segments::{ + GetCollectionAndSegmentsError, GetCollectionAndSegmentsOperator, + }, + materialize_logs::{ + MaterializeLogInput, MaterializeLogOperator, MaterializeLogOperatorError, + MaterializeLogOutput, + }, + }, + orchestration::compact::{CompactionContextError, ExecutionState}, +}; + +use 
super::compact::{CollectionCompactInfo, CompactWriters, CompactionContext}; +use chroma_types::AdvanceAttachedFunctionError; + +#[derive(Debug, Clone)] +pub struct FunctionContext { + pub attached_function_id: AttachedFunctionUuid, + pub function_id: Uuid, + pub updated_completion_offset: u64, +} + +#[derive(Debug)] +pub struct AttachedFunctionOrchestrator { + input_collection_info: CollectionCompactInfo, + output_context: CompactionContext, + result_channel: Option< + Sender>, + >, + + // Store the materialized outputs from DataFetchOrchestrator + materialized_log_data: Arc>, + + // Function context + function_context: OnceCell, + + // Execution state + state: ExecutionState, + + orchestrator_context: OrchestratorContext, + + dispatcher: ComponentHandle, +} + +#[derive(Error, Debug)] +pub enum AttachedFunctionOrchestratorError { + #[error("Operation aborted because resources exhausted")] + Aborted, + #[error("Failed to get attached function: {0}")] + GetAttachedFunction(#[from] GetAttachedFunctionOperatorError), + #[error("Failed to get collection and segments: {0}")] + GetCollectionAndSegments(#[from] GetCollectionAndSegmentsError), + #[error("No attached function found")] + NoAttachedFunction, + #[error("Failed to execute attached function: {0}")] + ExecuteAttachedFunction(#[from] ExecuteAttachedFunctionError), + #[error("Failed to advance attached function: {0}")] + AdvanceAttachedFunction(#[from] AdvanceAttachedFunctionError), + #[error("Function context not set")] + FunctionContextNotSet, + #[error("Invariant violation: {0}")] + InvariantViolation(String), + #[error("Failed to materialize log: {0}")] + MaterializeLog(#[from] MaterializeLogOperatorError), + #[error("Compaction context error: {0}")] + CompactionContext(#[from] CompactionContextError), + #[error("Output collection ID not set")] + OutputCollectionIdNotSet, + #[error("Channel error: {0}")] + Channel(#[from] ChannelError), + #[error("Could not count current segment: {0}")] + CountError(Box), + #[error("Receiver error: {0}")] + RecvError(#[from] RecvError), + #[error("Panic error: {0}")] + PanicError(#[from] PanicError), + #[error("Error creating metadata writer: {0}")] + MetadataSegment(#[from] MetadataSegmentError), + #[error("Error creating record segment writer: {0}")] + RecordSegmentWriter(#[from] RecordSegmentWriterCreationError), + #[error("Error creating hnsw writer: {0}")] + HnswSegment(#[from] DistributedHNSWSegmentFromSegmentError), + #[error("Error creating spann writer: {0}")] + SpannSegment(#[from] SpannSegmentWriterError), +} + +impl ChromaError for AttachedFunctionOrchestratorError { + fn code(&self) -> ErrorCodes { + match self { + AttachedFunctionOrchestratorError::Aborted => ErrorCodes::Aborted, + AttachedFunctionOrchestratorError::GetAttachedFunction(e) => e.code(), + AttachedFunctionOrchestratorError::GetCollectionAndSegments(e) => e.code(), + AttachedFunctionOrchestratorError::NoAttachedFunction => ErrorCodes::NotFound, + AttachedFunctionOrchestratorError::ExecuteAttachedFunction(e) => e.code(), + AttachedFunctionOrchestratorError::AdvanceAttachedFunction(e) => e.code(), + AttachedFunctionOrchestratorError::MaterializeLog(e) => e.code(), + AttachedFunctionOrchestratorError::FunctionContextNotSet => ErrorCodes::Internal, + AttachedFunctionOrchestratorError::InvariantViolation(_) => ErrorCodes::Internal, + AttachedFunctionOrchestratorError::CompactionContext(e) => e.code(), + AttachedFunctionOrchestratorError::OutputCollectionIdNotSet => ErrorCodes::Internal, + 
AttachedFunctionOrchestratorError::Channel(e) => e.code(), + AttachedFunctionOrchestratorError::RecvError(_) => ErrorCodes::Internal, + AttachedFunctionOrchestratorError::CountError(e) => e.code(), + AttachedFunctionOrchestratorError::PanicError(e) => e.code(), + AttachedFunctionOrchestratorError::MetadataSegment(e) => e.code(), + AttachedFunctionOrchestratorError::RecordSegmentWriter(e) => e.code(), + AttachedFunctionOrchestratorError::HnswSegment(e) => e.code(), + AttachedFunctionOrchestratorError::SpannSegment(e) => e.code(), + } + } + + fn should_trace_error(&self) -> bool { + match self { + AttachedFunctionOrchestratorError::Aborted => true, + AttachedFunctionOrchestratorError::GetAttachedFunction(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::GetCollectionAndSegments(e) => { + e.should_trace_error() + } + AttachedFunctionOrchestratorError::NoAttachedFunction => false, + AttachedFunctionOrchestratorError::ExecuteAttachedFunction(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::AdvanceAttachedFunction(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::MaterializeLog(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::FunctionContextNotSet => true, + AttachedFunctionOrchestratorError::InvariantViolation(_) => true, + AttachedFunctionOrchestratorError::CompactionContext(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::OutputCollectionIdNotSet => true, + AttachedFunctionOrchestratorError::Channel(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::RecvError(_) => true, + AttachedFunctionOrchestratorError::CountError(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::PanicError(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::MetadataSegment(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::RecordSegmentWriter(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::HnswSegment(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::SpannSegment(e) => e.should_trace_error(), + } + } +} + +impl From> for AttachedFunctionOrchestratorError +where + E: Into, +{ + fn from(value: TaskError) -> Self { + match value { + TaskError::Aborted => AttachedFunctionOrchestratorError::Aborted, + TaskError::Panic(e) => e.into(), + TaskError::TaskFailed(e) => e.into(), + } + } +} + +#[derive(Debug)] +pub enum AttachedFunctionOrchestratorResponse { + /// No attached function was found, so nothing was executed + NoAttachedFunction { job_id: JobId }, + /// Success - attached function was executed successfully + Success { + job_id: JobId, + materialized_output: Vec, + output_collection_info: CollectionCompactInfo, + attached_function_id: AttachedFunctionUuid, + completion_offset: u64, + }, +} + +impl AttachedFunctionOrchestrator { + pub fn new( + input_collection_info: CollectionCompactInfo, + output_context: CompactionContext, + dispatcher: ComponentHandle, + data_fetch_records: Arc>, + ) -> Self { + let orchestrator_context = OrchestratorContext::new(dispatcher.clone()); + + AttachedFunctionOrchestrator { + input_collection_info, + output_context, + result_channel: None, + materialized_log_data: data_fetch_records, + function_context: OnceCell::new(), + state: ExecutionState::MaterializeApplyCommitFlush, + orchestrator_context, + dispatcher, + } + } + + /// Get the input collection info, following the same pattern as CompactionContext + pub fn get_input_collection_info(&self) -> &CollectionCompactInfo { + &self.input_collection_info + } + 
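+    // High-level flow of one run, driven by the handlers below:
+    //   GetAttachedFunction -> GetCollectionAndSegments (output collection)
+    //   -> ExecuteAttachedFunction -> MaterializeLog (over the function's output records)
+    // Each handler either schedules the next task or terminates with an error.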
+ /// Get the output collection info if it has been set + pub fn get_output_collection_info( &self, ) -> Result<&CollectionCompactInfo, AttachedFunctionOrchestratorError> { + self.output_context + .get_output_collection_info() + .map_err(AttachedFunctionOrchestratorError::CompactionContext) + } + + /// Get the output collection ID if it has been set + pub fn get_output_collection_id( + &self, + ) -> Result<CollectionUuid, AttachedFunctionOrchestratorError> { + self.output_context + .get_output_collection_info() + .map(|info| info.collection_id) + .map_err(AttachedFunctionOrchestratorError::CompactionContext) + } + + /// Set the output collection info + pub fn set_output_collection_info( + &mut self, + collection_info: CollectionCompactInfo, + ) -> Result<(), CollectionCompactInfo> { + self.output_context + .output_collection_info + .set(collection_info) + } + + /// Get the function context if it has been set + pub fn get_function_context(&self) -> Option<&FunctionContext> { + self.function_context.get() + } + + /// Set the function context + pub fn set_function_context( + &self, + function_context: FunctionContext, + ) -> Result<(), FunctionContext> { + self.function_context.set(function_context) + } + + async fn finish_no_attached_function(&mut self, ctx: &ComponentContext<Self>) { + let collection_info = self.get_input_collection_info(); + let job_id = collection_info.collection_id.into(); + self.terminate_with_result( + Ok(AttachedFunctionOrchestratorResponse::NoAttachedFunction { job_id }), + ctx, + ) + .await; + } + + async fn finish_success( + &mut self, + materialized_output: Vec<MaterializeLogOutput>, + ctx: &ComponentContext<Self>, + ) { + let collection_info = self.get_input_collection_info(); + + // Get output collection info - should always exist in success case + let output_collection_info = match self.get_output_collection_info() { + Ok(info) => info.clone(), + Err(e) => { + self.terminate_with_result(Err(e), ctx).await; + return; + } + }; + + // Get attached function ID - should always exist in success case + let attached_function = match self.get_function_context() { + Some(func) => func, + None => { + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::FunctionContextNotSet), + ctx, + ) + .await; + return; + } + }; + let attached_function_id = attached_function.attached_function_id; + + // Get the completion offset from the input collection's pulled log offset + let completion_offset = collection_info.pulled_log_offset as u64; + + tracing::info!( + "Attached function finished successfully with {} materialized outputs", + materialized_output.len() + ); + + let job_id = collection_info.collection_id.into(); + self.terminate_with_result( + Ok(AttachedFunctionOrchestratorResponse::Success { + job_id, + materialized_output, + output_collection_info, + attached_function_id, + completion_offset, + }), + ctx, + ) + .await; + } + + async fn materialize_log( + &mut self, + partitions: Vec<Chunk<LogRecord>>, + ctx: &ComponentContext<Self>, + ) { + self.state = ExecutionState::MaterializeApplyCommitFlush; + + // NOTE: We allow writers to be uninitialized for the case when the materialized logs are empty + let record_reader = self + .output_context + .get_output_segment_writers() + .ok() + .and_then(|writers| writers.record_reader); + + let next_max_offset_id = Arc::new( + record_reader + .as_ref() + .map(|reader| AtomicU32::new(reader.get_max_offset_id() + 1)) + .unwrap_or_default(), + ); + + if let Some(rr) = record_reader.as_ref() { + let count = match rr.count().await { + Ok(count) => count as u64, + Err(err) => { + return self + .terminate_with_result(
Err(AttachedFunctionOrchestratorError::CountError(err)), + ctx, + ) + .await; + } + }; + + let collection_info = match self.output_context.get_output_collection_info_mut() { + Ok(info) => info, + Err(err) => { + return self.terminate_with_result(Err(err.into()), ctx).await; + } + }; + collection_info.collection.total_records_post_compaction = count; + } + + for partition in partitions.iter() { + let operator = MaterializeLogOperator::new(); + let input = MaterializeLogInput::new( + partition.clone(), + record_reader.clone(), + next_max_offset_id.clone(), + ); + let task = wrap( + operator, + input, + ctx.receiver(), + self.output_context + .orchestrator_context + .task_cancellation_token + .clone(), + ); + self.send(task, ctx, Some(Span::current())).await; + } + } +} + +#[async_trait] +impl Orchestrator for AttachedFunctionOrchestrator { + type Output = AttachedFunctionOrchestratorResponse; + type Error = AttachedFunctionOrchestratorError; + + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() + } + + fn context(&self) -> &OrchestratorContext { + &self.orchestrator_context + } + + async fn initial_tasks( + &mut self, + ctx: &ComponentContext, + ) -> Vec<(TaskMessage, Option)> { + // Start by getting the attached function for this collection + let collection_info = self.get_input_collection_info(); + let operator = Box::new(GetAttachedFunctionOperator::new( + self.output_context.sysdb.clone(), + collection_info.collection_id, + )); + let input = GetAttachedFunctionInput { + collection_id: collection_info.collection_id, + }; + let task = wrap( + operator, + input, + ctx.receiver(), + self.context().task_cancellation_token.clone(), + ); + vec![(task, Some(Span::current()))] + } + + fn set_result_channel( + &mut self, + sender: Sender< + Result, + >, + ) { + self.result_channel = Some(sender) + } + + fn take_result_channel( + &mut self, + ) -> Option< + Sender>, + > { + self.result_channel.take() + } +} + +#[async_trait] +impl Handler> + for AttachedFunctionOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let message = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(message) => message, + None => return, + }; + + self.finish_success(vec![message], ctx).await; + } +} + +#[async_trait] +impl Handler> + for AttachedFunctionOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let message = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(message) => message, + None => return, + }; + + match message.attached_function { + Some(attached_function) => { + tracing::info!( + "[AttachedFunctionOrchestrator]: Found attached function '{}' for collection", + attached_function.name + ); + + // TODO(tanujnay112): Handle error + let _ = self.function_context.set(FunctionContext { + attached_function_id: attached_function.id, + function_id: attached_function.function_id, + updated_completion_offset: attached_function.completion_offset, + }); + + // Get the output collection ID from the attached function + let output_collection_id = match attached_function.output_collection_id { + Some(id) => id, + None => { + tracing::error!( + "[AttachedFunctionOrchestrator]: Output collection ID not set for attached function '{}'", + attached_function.name + ); + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::OutputCollectionIdNotSet), + ctx, + ) + .await; + return; + } + }; + + // Next step: get 
the output collection segments using the existing GetCollectionAndSegmentsOperator + let operator = Box::new(GetCollectionAndSegmentsOperator::new( + self.output_context.sysdb.clone(), + output_collection_id, + )); + let input = (); + let task = wrap( + operator, + input, + ctx.receiver(), + self.context().task_cancellation_token.clone(), + ); + let res = self.dispatcher().send(task, None).await; + self.ok_or_terminate(res, ctx).await; + } + None => { + tracing::info!("[AttachedFunctionOrchestrator]: No attached function found"); + self.finish_no_attached_function(ctx).await; + } + } + } +} + +#[async_trait] +impl Handler> + for AttachedFunctionOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let message = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(message) => message, + None => return, + }; + + tracing::debug!( + "[AttachedFunctionOrchestrator]: Found output collection segments - metadata: {:?}, record: {:?}, vector: {:?}", + message.metadata_segment.id, + message.record_segment.id, + message.vector_segment.id + ); + + // Create segment writers for the output collection + let collection = &message.collection; + let dimension = match collection.dimension { + Some(dim) => dim as usize, + None => { + // Output collection is not initialized, cannot create writers + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::InvariantViolation( + "Output collection dimension is not set".to_string(), + )), + ctx, + ) + .await; + return; + } + }; + + let record_writer = match self + .ok_or_terminate( + RecordSegmentWriter::from_segment( + &collection.tenant, + &collection.database_id, + &message.record_segment, + &self.output_context.blockfile_provider, + ) + .await, + ctx, + ) + .await + { + Some(writer) => writer, + None => return, + }; + + let metadata_writer = match self + .ok_or_terminate( + MetadataSegmentWriter::from_segment( + &collection.tenant, + &collection.database_id, + &message.metadata_segment, + &self.output_context.blockfile_provider, + ) + .await, + ctx, + ) + .await + { + Some(writer) => writer, + None => return, + }; + + let (hnsw_index_uuid, vector_writer) = match message.vector_segment.r#type { + SegmentType::Spann => match self + .ok_or_terminate( + self.output_context + .spann_provider + .write(collection, &message.vector_segment, dimension) + .await, + ctx, + ) + .await + { + Some(writer) => (writer.hnsw_index_uuid(), VectorSegmentWriter::Spann(writer)), + None => return, + }, + _ => match self + .ok_or_terminate( + DistributedHNSWSegmentWriter::from_segment( + collection, + &message.vector_segment, + dimension, + self.output_context.hnsw_provider.clone(), + ) + .await + .map_err(|err| *err), + ctx, + ) + .await + { + Some(writer) => (writer.index_uuid(), VectorSegmentWriter::Hnsw(writer)), + None => return, + }, + }; + + let writers = CompactWriters { + record_reader: None, // Output collection doesn't need a reader + metadata_writer, + record_writer, + vector_writer, + }; + + // Store the output collection info with writers + let output_collection_info = CollectionCompactInfo { + collection_id: message.collection.collection_id, + collection: message.collection.clone(), + writers: Some(writers), + pulled_log_offset: message.collection.log_position, + hnsw_index_uuid: Some(hnsw_index_uuid), + schema: message.collection.schema.clone(), + }; + + if self + .set_output_collection_info(output_collection_info) + .is_err() + { + self.terminate_with_result( + 
Err(AttachedFunctionOrchestratorError::InvariantViolation( + "Failed to set output collection info".to_string(), + )), + ctx, + ) + .await; + return; + } + + let function_context = self.function_context.get(); + + let attached_function = match function_context { + Some(func) => func, + None => { + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::NoAttachedFunction), + ctx, + ) + .await; + return; + } + }; + + let function_id = attached_function.function_id; + // Execute the attached function + let operator = match ExecuteAttachedFunctionOperator::from_attached_function( + function_id, + self.output_context.log.clone(), + ) { + Ok(op) => Box::new(op), + Err(e) => { + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::ExecuteAttachedFunction( + e, + )), + ctx, + ) + .await; + return; + } + }; + + // Get the input collection info to access pulled_log_offset + let collection_info = self.get_input_collection_info(); + + let input = ExecuteAttachedFunctionInput { + materialized_logs: Arc::clone(&self.materialized_log_data), // Use the actual materialized logs from data fetch + tenant_id: "default".to_string(), // TODO: Get actual tenant ID + output_collection_id: message.collection.collection_id, + completion_offset: collection_info.pulled_log_offset as u64, // Use the completion offset from input collection + output_record_segment: message.record_segment.clone(), + blockfile_provider: self.output_context.blockfile_provider.clone(), + }; + + let task = wrap( + operator, + input, + ctx.receiver(), + self.context().task_cancellation_token.clone(), + ); + let res = self.dispatcher().send(task, None).await; + self.ok_or_terminate(res, ctx).await; + } +} + +#[async_trait] +impl Handler> + for AttachedFunctionOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let message = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(message) => message, + None => return, + }; + + tracing::info!( + "[AttachedFunctionOrchestrator]: Attached function executed successfully, processed {} records", + message.records_processed + ); + self.materialize_log(vec![message.output_records], ctx) + .await; + } +} diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index 3cdb64b07b0..83af95f09a3 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -1,4 +1,4 @@ -use std::cell::OnceCell; +use std::{cell::OnceCell, sync::Arc}; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; @@ -14,20 +14,25 @@ use chroma_sysdb::SysDb; use chroma_system::{ ComponentHandle, Dispatcher, Orchestrator, OrchestratorContext, PanicError, System, TaskError, }; -use chroma_types::{Collection, CollectionUuid, JobId, Schema, SegmentFlushInfo, SegmentUuid}; +use chroma_types::{Collection, CollectionUuid, JobId, Schema, SegmentUuid}; use opentelemetry::metrics::Counter; use thiserror::Error; use super::apply_logs_orchestrator::{ApplyLogsOrchestrator, ApplyLogsOrchestratorError}; +use super::attached_function_orchestrator::{ + AttachedFunctionOrchestrator, AttachedFunctionOrchestratorError, + AttachedFunctionOrchestratorResponse, +}; use super::log_fetch_orchestrator::{ LogFetchOrchestrator, LogFetchOrchestratorResponse, RequireCompactionOffsetRepair, Success, }; -use super::register_orchestrator::RegisterOrchestrator; +use 
super::register_orchestrator::{CollectionRegisterInfo, RegisterOrchestrator}; use crate::execution::{ operators::materialize_logs::MaterializeLogOutput, orchestration::{ apply_logs_orchestrator::ApplyLogsOrchestratorResponse, + attached_function_orchestrator::FunctionContext, log_fetch_orchestrator::LogFetchOrchestratorError, register_orchestrator::{RegisterOrchestratorError, RegisterOrchestratorResponse}, }, @@ -99,7 +104,9 @@ pub struct CollectionCompactInfo { #[derive(Debug)] pub struct CompactionContext { - pub collection_info: OnceCell, + pub input_collection_info: OnceCell, + pub output_collection_info: OnceCell, + pub attached_function_context: OnceCell, pub log: Log, pub sysdb: SysDb, pub blockfile_provider: BlockfileProvider, @@ -119,7 +126,9 @@ impl Clone for CompactionContext { fn clone(&self) -> Self { let orchestrator_context = OrchestratorContext::new(self.dispatcher.clone()); Self { - collection_info: self.collection_info.clone(), + input_collection_info: self.input_collection_info.clone(), + output_collection_info: self.output_collection_info.clone(), + attached_function_context: self.attached_function_context.clone(), log: self.log.clone(), sysdb: self.sysdb.clone(), blockfile_provider: self.blockfile_provider.clone(), @@ -143,6 +152,8 @@ pub enum CompactionError { Aborted, #[error("Error applying data to segment writers: {0}")] ApplyDataError(#[from] ApplyLogsOrchestratorError), + #[error("Error executing attached function: {0}")] + AttachedFunction(#[from] AttachedFunctionOrchestratorError), #[error("Error fetching compaction context: {0}")] CompactionContextError(#[from] CompactionContextError), #[error("Error fetching logs: {0}")] @@ -172,7 +183,13 @@ impl ChromaError for CompactionError { fn code(&self) -> ErrorCodes { match self { CompactionError::Aborted => ErrorCodes::Aborted, - _ => ErrorCodes::Internal, + CompactionError::ApplyDataError(e) => e.code(), + CompactionError::AttachedFunction(e) => e.code(), + CompactionError::CompactionContextError(e) => e.code(), + CompactionError::DataFetchError(e) => e.code(), + CompactionError::RegisterError(e) => e.code(), + CompactionError::PanicError(e) => e.code(), + CompactionError::InvariantViolation(_) => ErrorCodes::Internal, } } @@ -180,6 +197,7 @@ impl ChromaError for CompactionError { match self { Self::Aborted => true, Self::ApplyDataError(e) => e.should_trace_error(), + Self::AttachedFunction(e) => e.should_trace_error(), Self::CompactionContextError(e) => e.should_trace_error(), Self::DataFetchError(e) => e.should_trace_error(), Self::PanicError(e) => e.should_trace_error(), @@ -225,7 +243,9 @@ impl CompactionContext { ) -> Self { let orchestrator_context = OrchestratorContext::new(dispatcher.clone()); CompactionContext { - collection_info: OnceCell::new(), + input_collection_info: OnceCell::new(), + output_collection_info: OnceCell::new(), + attached_function_context: OnceCell::new(), is_rebuild, fetch_log_batch_size, max_compaction_size, @@ -247,35 +267,67 @@ impl CompactionContext { self.poison_offset = Some(offset); } - pub fn get_segment_writers(&self) -> Result { - self.get_collection_info()?.writers.clone().ok_or( - CompactionContextError::InvariantViolation("Segment writers should have been set"), + pub fn get_output_segment_writers(&self) -> Result { + self.get_output_collection_info()?.writers.clone().ok_or( + CompactionContextError::InvariantViolation( + "Output segment writers should have been set", + ), + ) + } + + pub fn get_input_segment_writers(&self) -> Result { + 
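+        // Writers for the source (input) collection; the corresponding output-collection
+        // accessor is defined above.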
self.get_input_collection_info()?.writers.clone().ok_or( + CompactionContextError::InvariantViolation( + "Input segment writers should have been set", + ), ) } - pub fn get_collection_info(&self) -> Result<&CollectionCompactInfo, CompactionContextError> { - self.collection_info + pub fn get_input_collection_info( + &self, + ) -> Result<&CollectionCompactInfo, CompactionContextError> { + self.input_collection_info .get() .ok_or(CompactionContextError::InvariantViolation( "Collection info should have been set", )) } - pub fn get_collection_info_mut( + pub fn get_output_collection_info( + &self, + ) -> Result<&CollectionCompactInfo, CompactionContextError> { + self.output_collection_info + .get() + .ok_or(CompactionContextError::InvariantViolation( + "Collection info should have been set", + )) + } + + pub fn get_input_collection_info_mut( &mut self, ) -> Result<&mut CollectionCompactInfo, CompactionContextError> { - self.collection_info + self.input_collection_info .get_mut() .ok_or(CompactionContextError::InvariantViolation( "Collection info mut should have been set", )) } - pub fn get_segment_writer_by_id( + pub fn get_output_collection_info_mut( + &mut self, + ) -> Result<&mut CollectionCompactInfo, CompactionContextError> { + self.output_collection_info + .get_mut() + .ok_or(CompactionContextError::InvariantViolation( + "Collection info mut should have been set", + )) + } + + pub fn get_output_segment_writer_by_id( &self, segment_id: SegmentUuid, ) -> Result, CompactionContextError> { - let writers = self.get_segment_writers()?; + let writers = self.get_output_segment_writers()?; if writers.metadata_writer.id == segment_id { return Ok(ChromaSegmentWriter::MetadataSegment( @@ -330,7 +382,7 @@ impl CompactionContext { let materialized = success.materialized; let collection_info = success.collection_info; - self.collection_info + self.input_collection_info .set(collection_info.clone()) .map_err(|_| { CompactionContextError::InvariantViolation("Collection info already set") @@ -347,10 +399,10 @@ impl CompactionContext { pub(crate) async fn run_apply_logs( &mut self, - log_fetch_records: Vec, + log_fetch_records: Arc>, system: System, ) -> Result { - let collection_info = self.get_collection_info()?; + let collection_info = self.get_input_collection_info()?; if log_fetch_records.is_empty() { return Ok(ApplyLogsOrchestratorResponse::new_with_empty_results( collection_info.collection_id.into(), @@ -358,8 +410,14 @@ impl CompactionContext { )); } + if self.get_output_collection_info().is_err() { + return Err(ApplyLogsOrchestratorError::InvariantViolation( + "Output collection info should have been set before running apply logs", + )); + } + // INVARIANT: Every element of log_fetch_records should be non-empty - for mat_logs in &log_fetch_records { + for mat_logs in log_fetch_records.iter() { if mat_logs.result.is_empty() { return Err(ApplyLogsOrchestratorError::InvariantViolation( "Every element of log_fetch_records should be non-empty", @@ -379,9 +437,7 @@ impl CompactionContext { } }; - let collection_info = self.collection_info.get_mut().ok_or( - ApplyLogsOrchestratorError::InvariantViolation("Collection info should have been set"), - )?; + let collection_info = self.get_output_collection_info_mut()?; collection_info.schema = apply_logs_response.schema.clone(); collection_info.collection.total_records_post_compaction = apply_logs_response.total_records_post_compaction; @@ -389,18 +445,122 @@ impl CompactionContext { Ok(apply_logs_response) } + // Should be invoked on output collection context 
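+    // The orchestrator below runs against a clone of this context; on success (or when
+    // no function is attached) the response's output collection info is copied back into
+    // this context's `output_collection_info`.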
+ pub(crate) async fn run_attached_function( + &mut self, + data_fetch_records: Arc<Vec<MaterializeLogOutput>>, + system: System, + ) -> Result<AttachedFunctionOrchestratorResponse, AttachedFunctionOrchestratorError> { + let input_collection_info = self.get_input_collection_info()?.clone(); + let input_collection_info_clone = input_collection_info.clone(); + let attached_function_orchestrator = AttachedFunctionOrchestrator::new( + input_collection_info, + self.clone(), + self.dispatcher.clone(), + data_fetch_records, + ); + + let attached_function_response = + match Box::pin(attached_function_orchestrator.run(system)).await { + Ok(response) => response, + Err(e) => { + if e.should_trace_error() { + tracing::error!("Attached function phase failed: {e}"); + } + return Err(e); + } + }; + + // Set the output collection info based on the response + match &attached_function_response { + AttachedFunctionOrchestratorResponse::NoAttachedFunction { .. } => { + self.output_collection_info + .set(input_collection_info_clone) + .map_err(|_| { + AttachedFunctionOrchestratorError::InvariantViolation( + "Collection info should not have been already set".to_string(), + ) + })?; + } + AttachedFunctionOrchestratorResponse::Success { + output_collection_info, + .. + } => { + self.output_collection_info + .set(output_collection_info.clone()) + .map_err(|_| { + AttachedFunctionOrchestratorError::InvariantViolation( + "Collection info should not have been already set".to_string(), + ) + })?; + } + } + + Ok(attached_function_response) + } + + async fn run_attached_function_workflow( + &mut self, + log_fetch_records: Arc<Vec<MaterializeLogOutput>>, + system: System, + ) -> Result<Option<(FunctionContext, CollectionRegisterInfo)>, CompactionError> { + let attached_function_result = + Box::pin(self.run_attached_function(log_fetch_records, system.clone())).await?; + + match attached_function_result { + AttachedFunctionOrchestratorResponse::NoAttachedFunction { ..
} => Ok(None), + AttachedFunctionOrchestratorResponse::Success { + job_id: _, + output_collection_info, + materialized_output, + attached_function_id, + completion_offset, + } => { + // Update self to use the output collection for apply_logs + self.output_collection_info = OnceCell::from(output_collection_info.clone()); + + // Apply materialized output to output collection + let apply_logs_response = self + .run_apply_logs(Arc::new(materialized_output), system.clone()) + .await?; + + let function_context = FunctionContext { + attached_function_id, + function_id: attached_function_id.0, + updated_completion_offset: completion_offset, + }; + + let collection_register_info = CollectionRegisterInfo { + collection_info: output_collection_info, + flush_results: apply_logs_response.flush_results, + collection_logical_size_bytes: apply_logs_response + .collection_logical_size_bytes, + }; + + Ok(Some((function_context, collection_register_info))) + } + } + } + pub(crate) async fn run_register( &mut self, - flush_results: Vec<SegmentFlushInfo>, - collection_logical_size_bytes: u64, + collection_register_infos: Vec<CollectionRegisterInfo>, + function_register_info: Option<FunctionContext>, system: System, ) -> Result<RegisterOrchestratorResponse, RegisterOrchestratorError> { let dispatcher = self.dispatcher.clone(); + + if collection_register_infos.is_empty() || collection_register_infos.len() > 2 { + return Err(RegisterOrchestratorError::InvariantViolation( + "Invalid number of collection register infos", + )); + } + let register_orchestrator = RegisterOrchestrator::new( self, dispatcher, - flush_results, - collection_logical_size_bytes, + collection_register_infos, + function_register_info, ); match register_orchestrator.run(system).await { @@ -433,15 +593,90 @@ impl CompactionContext { } }; - let apply_logs_response = self - .run_apply_logs(log_fetch_records, system.clone()) - .await?; + // Wrap in Arc to avoid cloning large MaterializeLogOutput data + let log_fetch_records = Arc::new(log_fetch_records); + let log_fetch_records_clone = log_fetch_records.clone(); + + let input_collection_info = + self.input_collection_info + .get() + .ok_or(CompactionError::InvariantViolation( + "Input collection info should not be None", + ))?; + + // Clone first - both clones will have empty output_collection_info + let mut self_clone_fn = self.clone(); + let mut self_clone_compact = self.clone(); + let system_clone_fn = system.clone(); + let system_clone_compact = system.clone(); + + // Set output_collection_info on self and the apply_logs clone + // The attached function orchestrator will set its own separate output_collection_info + self.output_collection_info + .set(input_collection_info.clone()) + .map_err(|_| { + CompactionError::InvariantViolation( + "Collection info should not have been already set", + ) + })?; + + // 1. Attached function execution + apply output to output collection + // 2.
Apply input logs to input collection + // Box the futures to avoid stack overflow with large state machines + let fn_future = async move { + Box::pin( + self_clone_fn + .run_attached_function_workflow(log_fetch_records_clone, system_clone_fn), + ) + .await + }; + + let compact_future = Box::pin(async move { + self_clone_compact + .output_collection_info + .set(input_collection_info.clone()) + .map_err(|_| { + CompactionError::InvariantViolation( + "Collection info should not have been already set", + ) + })?; + let apply_logs_response = self_clone_compact + .run_apply_logs(log_fetch_records, system_clone_compact) + .await?; + + // Build CollectionRegisterInfo from the updated context + let collection_info = self_clone_compact + .get_output_collection_info() + .map_err(CompactionError::CompactionContextError)? + .clone(); + Ok::<CollectionRegisterInfo, CompactionError>(CollectionRegisterInfo { + collection_info, + flush_results: apply_logs_response.flush_results, + collection_logical_size_bytes: apply_logs_response.collection_logical_size_bytes, + }) + }); + + let (fn_result, compact_result) = tokio::join!(fn_future, compact_future); + + let fn_result = fn_result?; + let compact_result = compact_result?; + + // Collect results + let mut attached_function_context = None; + let mut results: Vec<CollectionRegisterInfo> = Vec::new(); + + if let Some((function_context, collection_register_info)) = fn_result { + attached_function_context = Some(function_context); + results.push(collection_register_info); + } + // Otherwise there was no attached function + + // Process input collection result // Invariant: flush_results is empty => collection_logical_size_bytes == collection_info.collection.size_bytes_post_compaction - if apply_logs_response.flush_results.is_empty() - && apply_logs_response.collection_logical_size_bytes - != self - .get_collection_info()?
+ if compact_result.flush_results.is_empty() + && compact_result.collection_logical_size_bytes + != compact_result + .collection_info .collection .size_bytes_post_compaction { @@ -450,12 +685,10 @@ impl CompactionContext { )); } - let _ = Box::pin(self.run_register( - apply_logs_response.flush_results, - apply_logs_response.collection_logical_size_bytes, - system.clone(), - )) - .await?; + results.push(compact_result); + + let _ = + Box::pin(self.run_register(results, attached_function_context, system.clone())).await?; Ok(CompactionResponse::Success { job_id: collection_id.into(), @@ -463,7 +696,17 @@ impl CompactionContext { } pub(crate) async fn cleanup(self) { - if let Some(collection_info) = self.collection_info.get() { + if let Some(collection_info) = self.input_collection_info.get() { + if let Some(hnsw_index_uuid) = collection_info.hnsw_index_uuid { + let _ = HnswIndexProvider::purge_one_id( + self.hnsw_provider.temporary_storage_path.as_path(), + hnsw_index_uuid, + ) + .await; + } + } + + if let Some(collection_info) = self.output_collection_info.get() { if let Some(hnsw_index_uuid) = collection_info.hnsw_index_uuid { let _ = HnswIndexProvider::purge_one_id( self.hnsw_provider.temporary_storage_path.as_path(), @@ -521,9 +764,7 @@ pub async fn compact( compaction_context.set_poison_offset(poison_offset); } - let result = compaction_context - .run_compaction(collection_id, system) - .await; + let result = Box::pin(compaction_context.run_compaction(collection_id, system)).await; Box::pin(compaction_context.cleanup()).await; result } @@ -535,6 +776,7 @@ mod tests { }; use std::collections::HashMap; use std::path::{Path, PathBuf}; + use std::sync::Arc; use tokio::fs; use chroma_blockstore::arrow::config::{BlockManagerConfig, TEST_MAX_BLOCK_SIZE_BYTES}; @@ -567,6 +809,7 @@ mod tests { }; use super::{compact, CompactionContext, CompactionResponse, LogFetchOrchestratorResponse}; + use crate::execution::orchestration::register_orchestrator::CollectionRegisterInfo; async fn get_all_records( system: &System, @@ -2013,8 +2256,23 @@ mod tests { assert_eq!(old_records, new_records); } - #[tokio::test(flavor = "multi_thread")] - async fn test_concurrent_compactions() { + #[test] + fn test_concurrent_compactions() { + // Deep async call chains create large state machines that exceed default 2MB stack + // Use larger stack to accommodate the nested futures + std::thread::Builder::new() + .stack_size(8 * 1024 * 1024) // 8 MB stack + .spawn(|| { + tokio::runtime::Runtime::new() + .unwrap() + .block_on(test_concurrent_compactions_impl()) + }) + .unwrap() + .join() + .unwrap(); + } + + async fn test_concurrent_compactions_impl() { // This test simulates the scenario where: // 1. Compaction 1 starts its log_fetch_orchestrator // 2. 
Compaction 2 starts and finishes everything @@ -2263,17 +2521,24 @@ mod tests { compaction_1_log_records.len() ); let compaction_1_apply_response = compaction_context_1 - .run_apply_logs(compaction_1_log_records, system.clone()) + .run_apply_logs(Arc::new(compaction_1_log_records), system.clone()) .await .expect("Apply should have succeeded."); - let _register_result = Box::pin(compaction_context_1.run_register( - compaction_1_apply_response.flush_results, - compaction_1_apply_response.collection_logical_size_bytes, - system.clone(), - )) - .await - .expect_err("Register should have failed."); + let register_info = vec![CollectionRegisterInfo { + collection_info: compaction_context_1 + .get_input_collection_info() + .unwrap() + .clone(), + flush_results: compaction_1_apply_response.flush_results, + collection_logical_size_bytes: compaction_1_apply_response + .collection_logical_size_bytes, + }]; + + let _register_result = + Box::pin(compaction_context_1.run_register(register_info, None, system.clone())) + .await + .expect_err("Register should have failed."); // Verify that the collection was successfully compacted (by whichever succeeded) let collection_after_compaction = sysdb diff --git a/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs b/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs index 306180410b7..888c4471db3 100644 --- a/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs @@ -313,7 +313,7 @@ impl LogFetchOrchestrator { // NOTE: We allow writers to be uninitialized for the case when the materialized logs are empty let record_reader = self .context - .get_segment_writers() + .get_input_segment_writers() .ok() .and_then(|writers| writers.record_reader); @@ -334,7 +334,7 @@ impl LogFetchOrchestrator { } }; - let collection_info = match self.context.get_collection_info_mut() { + let collection_info = match self.context.get_input_collection_info_mut() { Ok(info) => info, Err(err) => { return self.terminate_with_result(Err(err.into()), ctx).await; @@ -447,7 +447,7 @@ impl Handler info, None => { self.terminate_with_result( @@ -644,7 +644,7 @@ impl Handler> for LogFetchOrchestrator tracing::info!("Pulled Records: {}", output.len()); match output.iter().last() { Some((rec, _)) => { - let collection_info = match self.context.get_collection_info_mut() { + let collection_info = match self.context.get_input_collection_info_mut() { Ok(info) => info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -659,7 +659,7 @@ impl Handler> for LogFetchOrchestrator } None => { tracing::warn!("No logs were pulled from the log service, this can happen when the log compaction offset is behind the sysdb."); - let collection_info = match self.context.get_collection_info() { + let collection_info = match self.context.get_input_collection_info() { Ok(info) => info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -699,7 +699,7 @@ impl Handler> }; tracing::info!("Sourced Records: {}", output.len()); // Each record should correspond to a log - let collection_info = match self.context.get_collection_info_mut() { + let collection_info = match self.context.get_input_collection_info_mut() { Ok(info) => info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -708,7 +708,7 @@ impl Handler> }; collection_info.collection.total_records_post_compaction = output.len() as u64; - let collection_info = match
self.context.get_collection_info() { + let collection_info = match self.context.get_input_collection_info() { Ok(info) => info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -766,7 +766,7 @@ impl Handler> } self.num_uncompleted_materialization_tasks -= 1; if self.num_uncompleted_materialization_tasks == 0 { - let collection_info = match self.context.collection_info.take() { + let collection_info = match self.context.input_collection_info.take() { Some(info) => info, None => { self.terminate_with_result( diff --git a/rust/worker/src/execution/orchestration/mod.rs b/rust/worker/src/execution/orchestration/mod.rs index f00e8b7105a..6db9bebd9f3 100644 --- a/rust/worker/src/execution/orchestration/mod.rs +++ b/rust/worker/src/execution/orchestration/mod.rs @@ -1,4 +1,5 @@ pub mod apply_logs_orchestrator; +pub mod attached_function_orchestrator; pub(crate) mod compact; pub(crate) mod count; pub mod get; diff --git a/rust/worker/src/execution/orchestration/register_orchestrator.rs b/rust/worker/src/execution/orchestration/register_orchestrator.rs index 324c7943ecc..c18234836c9 100644 --- a/rust/worker/src/execution/orchestration/register_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/register_orchestrator.rs @@ -11,9 +11,15 @@ use tokio::sync::oneshot::error::RecvError; use tokio::sync::oneshot::Sender; use tracing::Span; +use crate::execution::operators::finish_attached_function::{ + FinishAttachedFunctionError, FinishAttachedFunctionInput, FinishAttachedFunctionOperator, + FinishAttachedFunctionOutput, +}; use crate::execution::operators::register::{ RegisterError, RegisterInput, RegisterOperator, RegisterOutput, }; +use crate::execution::orchestration::attached_function_orchestrator::FunctionContext; +use crate::execution::orchestration::compact::CollectionCompactInfo; use crate::execution::orchestration::compact::CompactionContextError; use super::compact::{CompactionContext, ExecutionState}; @@ -24,8 +30,34 @@ pub struct RegisterOrchestrator { dispatcher: ComponentHandle, result_channel: Option<Sender<Result<RegisterOrchestratorResponse, RegisterOrchestratorError>>>, _state: ExecutionState, - flush_results: Vec<SegmentFlushInfo>, - collection_logical_size_bytes: u64, + // Attached function fields + collection_register_infos: Vec<CollectionRegisterInfo>, + function_context: Option<FunctionContext>, +} + +#[derive(Debug)] +pub struct CollectionRegisterInfo { + pub collection_info: CollectionCompactInfo, + pub flush_results: Vec<SegmentFlushInfo>, + pub collection_logical_size_bytes: u64, +} + +impl From<&CollectionRegisterInfo> for chroma_types::CollectionFlushInfo { + fn from(info: &CollectionRegisterInfo) -> Self { + chroma_types::CollectionFlushInfo { + tenant_id: info.collection_info.collection.tenant.clone(), + collection_id: info.collection_info.collection_id, + log_position: info.collection_info.pulled_log_offset, + collection_version: info.collection_info.collection.version, + segment_flush_info: info.flush_results.clone().into(), + total_records_post_compaction: info + .collection_info + .collection + .total_records_post_compaction, + size_bytes_post_compaction: info.collection_logical_size_bytes, + schema: info.collection_info.schema.clone(), + } + } } #[derive(Debug)] @@ -91,20 +123,26 @@ where } } +impl From<FinishAttachedFunctionError> for RegisterOrchestratorError { + fn from(value: FinishAttachedFunctionError) -> Self { + RegisterOrchestratorError::Register(value.into()) + } +} + impl RegisterOrchestrator { pub fn new( context: &CompactionContext, dispatcher: ComponentHandle, - flush_results: Vec<SegmentFlushInfo>, - collection_logical_size_bytes: u64, + collection_register_infos: Vec<CollectionRegisterInfo>, + function_context: Option<FunctionContext>, ) -> Self {
RegisterOrchestrator { context: context.clone(), dispatcher, result_channel: None, _state: ExecutionState::Register, - flush_results, - collection_logical_size_bytes, + collection_register_infos, + function_context, } } } @@ -134,38 +172,89 @@ impl Orchestrator for RegisterOrchestrator { &mut self, ctx: &ComponentContext, ) -> Vec<(TaskMessage, Option<Span>)> { - // Check if collection is set before proceeding - let collection_info = match self.context.get_collection_info() { - Ok(collection_info) => collection_info, - Err(e) => { - self.terminate_with_result(Err(e.into()), ctx).await; - return vec![]; - } - }; + // Check if we have attached function context + let collection_flush_infos = self + .collection_register_infos + .iter() + .map(|info| info.into()) + .collect(); + if let Some(function_context) = &self.function_context { + vec![( + wrap( + FinishAttachedFunctionOperator::new(), + FinishAttachedFunctionInput::new( + collection_flush_infos, + function_context.attached_function_id, + function_context.updated_completion_offset, + self.context.sysdb.clone(), + self.context.log.clone(), + ), + ctx.receiver(), + self.context + .orchestrator_context + .task_cancellation_token + .clone(), + ), + Some(Span::current()), + )] + } else { + // Use regular RegisterOperator for normal compaction + // INVARIANT: We should have exactly one collection register info + let output_collection_register_info = match self.collection_register_infos.first() { + Some(info) => info, + None => { + self.terminate_with_result( + Err(RegisterOrchestratorError::InvariantViolation( + "No collection register info found", + )), + ctx, + ) + .await; + return vec![]; + } + }; - vec![( - wrap( - RegisterOperator::new(), - RegisterInput::new( - collection_info.collection.tenant.clone(), - collection_info.collection_id, - collection_info.pulled_log_offset, - collection_info.collection.version, - self.flush_results.clone().into(), - collection_info.collection.total_records_post_compaction, - self.collection_logical_size_bytes, - self.context.sysdb.clone(), - self.context.log.clone(), - collection_info.schema.clone(), + vec![( + wrap( + RegisterOperator::new(), + RegisterInput::new( + output_collection_register_info + .collection_info + .collection + .tenant + .clone(), + output_collection_register_info + .collection_info + .collection_id, + output_collection_register_info + .collection_info + .pulled_log_offset, + output_collection_register_info + .collection_info + .collection + .version, + output_collection_register_info.flush_results.clone().into(), + output_collection_register_info + .collection_info + .collection + .total_records_post_compaction, + output_collection_register_info.collection_logical_size_bytes, + self.context.sysdb.clone(), + self.context.log.clone(), + output_collection_register_info + .collection_info + .schema + .clone(), + ), + ctx.receiver(), + self.context + .orchestrator_context + .task_cancellation_token + .clone(), ), - ctx.receiver(), - self.context - .orchestrator_context - .task_cancellation_token - .clone(), - ), - Some(Span::current()), - )] + Some(Span::current()), + )] + } } } @@ -178,7 +267,7 @@ impl Handler<TaskResult<RegisterOutput, RegisterError>> for RegisterOrchestrator message: TaskResult<RegisterOutput, RegisterError>, ctx: &ComponentContext, ) { - let collection_info = match self.context.get_collection_info() { + let collection_info = match self.context.get_input_collection_info() { Ok(collection_info) => collection_info, Err(e) => { self.terminate_with_result(Err(e.into()), ctx).await; @@ -196,3 +285,40 @@ impl Handler<TaskResult<RegisterOutput, RegisterError>> for RegisterOrchestrator .await; }
} + +#[async_trait] +impl Handler<TaskResult<FinishAttachedFunctionOutput, FinishAttachedFunctionError>> + for RegisterOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult<FinishAttachedFunctionOutput, FinishAttachedFunctionError>, + ctx: &ComponentContext, + ) { + let collection_info = match self.context.get_input_collection_info() { + Ok(collection_info) => collection_info, + Err(e) => { + self.terminate_with_result(Err(e.into()), ctx).await; + return; + } + }; + + self.terminate_with_result( + message + .into_inner() + .map_err(|e| match e { + TaskError::TaskFailed(inner_error) => { + RegisterOrchestratorError::Register(inner_error.into()) + } + other_error => other_error.into(), + }) + .map(|_| RegisterOrchestratorResponse { + job_id: collection_info.collection_id.into(), + }), + ctx, + ) + .await; + } +} From da946db163ab03169b7394c4b2400caca4c8d0b9 Mon Sep 17 00:00:00 2001 From: Tanuj Nayak Date: Thu, 20 Nov 2025 16:58:53 -0800 Subject: [PATCH 2/3] sicheng comments --- rust/sysdb/src/sysdb.rs | 1 - rust/types/src/flush.rs | 21 -- .../src/execution/functions/statistics.rs | 3 +- .../src/execution/operators/execute_task.rs | 2 - .../orchestration/apply_logs_orchestrator.rs | 14 +- .../attached_function_orchestrator.rs | 14 +- .../src/execution/orchestration/compact.rs | 233 +++++++----------- .../orchestration/log_fetch_orchestrator.rs | 30 +-- .../orchestration/register_orchestrator.rs | 4 +- 9 files changed, 119 insertions(+), 203 deletions(-) diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index d0ef29bb696..64c368ac4af 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -628,7 +628,6 @@ impl SysDb { } } - #[allow(clippy::too_many_arguments)] pub async fn flush_compaction_and_attached_function( &mut self, collections: Vec<CollectionFlushInfo>, diff --git a/rust/types/src/flush.rs b/rust/types/src/flush.rs index 0020823c599..13e21ccc5c8 100644 --- a/rust/types/src/flush.rs +++ b/rust/types/src/flush.rs @@ -158,27 +158,6 @@ impl FlushCompactionResponse { } } -impl TryFrom<FlushCollectionCompactionAndAttachedFunctionResponse> for FlushCompactionResponse { - type Error = FlushCompactionResponseConversionError; - - fn try_from( - value: FlushCollectionCompactionAndAttachedFunctionResponse, - ) -> Result<Self, Self::Error> { - // Use first collection for backward compatibility - let first_collection = value - .collections - .first() - .ok_or(FlushCompactionResponseConversionError::MissingCollections)?; - let id = Uuid::parse_str(&first_collection.collection_id) - .map_err(|_| FlushCompactionResponseConversionError::InvalidUuid)?; - Ok(FlushCompactionResponse { - collection_id: CollectionUuid(id), - collection_version: first_collection.collection_version, - last_compaction_time: first_collection.last_compaction_time, - }) - } -} - impl TryFrom<FlushCollectionCompactionAndAttachedFunctionResponse> for FlushCompactionAndAttachedFunctionResponse { diff --git a/rust/worker/src/execution/functions/statistics.rs b/rust/worker/src/execution/functions/statistics.rs index 2a306e35fdb..da88b82477d 100644 --- a/rust/worker/src/execution/functions/statistics.rs +++ b/rust/worker/src/execution/functions/statistics.rs @@ -181,7 +181,8 @@ impl AttachedFunctionExecutor for StatisticsFunctionExecutor { let mut counts: HashMap>> = HashMap::default(); for (hydrated_record, _index) in input_records.iter() { - // Skip delete operations - they should not be counted in statistics + // This is only applicable for non-incremental statistics. + // TODO(tanujnay112): Change this when we make incremental statistics work.
if hydrated_record.get_operation() == MaterializedLogOperation::DeleteExisting { continue; } diff --git a/rust/worker/src/execution/operators/execute_task.rs b/rust/worker/src/execution/operators/execute_task.rs index d5d2955a3a7..4458e3d6fef 100644 --- a/rust/worker/src/execution/operators/execute_task.rs +++ b/rust/worker/src/execution/operators/execute_task.rs @@ -50,8 +50,6 @@ impl AttachedFunctionExecutor for CountAttachedFunction { let records_count = input_records.len() as i64; let new_total_count = records_count; - println!("new_total_count is {}", new_total_count); - // Create output record with updated count let mut metadata = std::collections::HashMap::new(); metadata.insert( diff --git a/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs b/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs index 39c40659035..27f2c8b23d2 100644 --- a/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs @@ -206,7 +206,7 @@ impl ApplyLogsOrchestrator { let mut tasks_to_run = Vec::new(); self.num_materialized_logs += materialized_logs.len() as u64; - let writers = self.context.get_output_segment_writers()?; + let writers = self.context.get_segment_writers()?; { self.num_uncompleted_tasks_by_segment @@ -255,7 +255,7 @@ impl ApplyLogsOrchestrator { materialized_logs.clone(), writers.record_reader.clone(), self.context - .get_output_collection_info()? + .get_collection_info()? .collection .schema .clone(), @@ -356,7 +356,7 @@ impl ApplyLogsOrchestrator { .add(self.num_materialized_logs, &[]); self.state = ExecutionState::Register; - let collection_info = match self.context.get_output_collection_info() { + let collection_info = match self.context.get_collection_info() { Ok(collection_info) => collection_info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -525,7 +525,7 @@ impl Handler info, Err(err) => { return self.terminate_with_result(Err(err.into()), ctx).await; @@ -582,9 +582,7 @@ impl Handler writer, None => return, @@ -614,7 +612,7 @@ impl Handler info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; diff --git a/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs b/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs index a51547b5876..ac5ea337845 100644 --- a/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs @@ -230,7 +230,7 @@ impl AttachedFunctionOrchestrator { &self, ) -> Result<&CollectionCompactInfo, AttachedFunctionOrchestratorError> { self.output_context - .get_output_collection_info() + .get_collection_info() .map_err(AttachedFunctionOrchestratorError::CompactionContext) } @@ -239,7 +239,7 @@ impl AttachedFunctionOrchestrator { &self, ) -> Result { self.output_context - .get_output_collection_info() + .get_collection_info() .map(|info| info.collection_id) .map_err(AttachedFunctionOrchestratorError::CompactionContext) } @@ -249,9 +249,7 @@ impl AttachedFunctionOrchestrator { &mut self, collection_info: CollectionCompactInfo, ) -> Result<(), CollectionCompactInfo> { - self.output_context - .output_collection_info - .set(collection_info) + self.output_context.collection_info.set(collection_info) } /// Get the function context if it has been set @@ -339,7 +337,7 @@ impl AttachedFunctionOrchestrator { // NOTE: We allow writers to be uninitialized for the case when the materialized 
logs are empty let record_reader = self .output_context - .get_output_segment_writers() + .get_segment_writers() .ok() .and_then(|writers| writers.record_reader); @@ -363,7 +361,7 @@ impl AttachedFunctionOrchestrator { } }; - let collection_info = match self.output_context.get_output_collection_info_mut() { + let collection_info = match self.output_context.get_collection_info_mut() { Ok(info) => info, Err(err) => { return self.terminate_with_result(Err(err.into()), ctx).await; @@ -490,7 +488,7 @@ impl Handler, - pub output_collection_info: OnceCell, - pub attached_function_context: OnceCell, + pub collection_info: OnceCell, pub log: Log, pub sysdb: SysDb, pub blockfile_provider: BlockfileProvider, @@ -118,6 +116,7 @@ pub struct CompactionContext { pub fetch_log_batch_size: u32, pub max_compaction_size: usize, pub max_partition_size: usize, + pub hnsw_index_uuids: HashSet, // TODO(tanujnay112): Remove after direct hnsw is solidified #[cfg(test)] pub poison_offset: Option, } @@ -126,9 +125,7 @@ impl Clone for CompactionContext { fn clone(&self) -> Self { let orchestrator_context = OrchestratorContext::new(self.dispatcher.clone()); Self { - input_collection_info: self.input_collection_info.clone(), - output_collection_info: self.output_collection_info.clone(), - attached_function_context: self.attached_function_context.clone(), + collection_info: self.collection_info.clone(), log: self.log.clone(), sysdb: self.sysdb.clone(), blockfile_provider: self.blockfile_provider.clone(), @@ -140,6 +137,32 @@ impl Clone for CompactionContext { fetch_log_batch_size: self.fetch_log_batch_size, max_compaction_size: self.max_compaction_size, max_partition_size: self.max_partition_size, + hnsw_index_uuids: self.hnsw_index_uuids.clone(), + #[cfg(test)] + poison_offset: self.poison_offset, + } + } +} + +impl CompactionContext { + /// Create an empty output context for attached function orchestrator + /// This creates a new context with an empty collection_info OnceCell + fn clone_for_new_collection(&self) -> Self { + let orchestrator_context = OrchestratorContext::new(self.dispatcher.clone()); + Self { + collection_info: OnceCell::new(), // Start empty for output context + log: self.log.clone(), + sysdb: self.sysdb.clone(), + blockfile_provider: self.blockfile_provider.clone(), + hnsw_provider: self.hnsw_provider.clone(), + spann_provider: self.spann_provider.clone(), + dispatcher: self.dispatcher.clone(), + orchestrator_context, + is_rebuild: self.is_rebuild, + fetch_log_batch_size: self.fetch_log_batch_size, + max_compaction_size: self.max_compaction_size, + max_partition_size: self.max_partition_size, + hnsw_index_uuids: self.hnsw_index_uuids.clone(), #[cfg(test)] poison_offset: self.poison_offset, } @@ -243,9 +266,7 @@ impl CompactionContext { ) -> Self { let orchestrator_context = OrchestratorContext::new(dispatcher.clone()); CompactionContext { - input_collection_info: OnceCell::new(), - output_collection_info: OnceCell::new(), - attached_function_context: OnceCell::new(), + collection_info: OnceCell::new(), is_rebuild, fetch_log_batch_size, max_compaction_size, @@ -257,6 +278,7 @@ impl CompactionContext { spann_provider, dispatcher, orchestrator_context, + hnsw_index_uuids: HashSet::new(), #[cfg(test)] poison_offset: None, } @@ -267,67 +289,35 @@ impl CompactionContext { self.poison_offset = Some(offset); } - pub fn get_output_segment_writers(&self) -> Result { - self.get_output_collection_info()?.writers.clone().ok_or( - CompactionContextError::InvariantViolation( - "Output segment writers should have 
been set", - ), - ) - } - - pub fn get_input_segment_writers(&self) -> Result { - self.get_input_collection_info()?.writers.clone().ok_or( - CompactionContextError::InvariantViolation( - "Input segment writers should have been set", - ), - ) - } - - pub fn get_input_collection_info( - &self, - ) -> Result<&CollectionCompactInfo, CompactionContextError> { - self.input_collection_info + pub fn get_collection_info(&self) -> Result<&CollectionCompactInfo, CompactionContextError> { + self.collection_info .get() .ok_or(CompactionContextError::InvariantViolation( "Collection info should have been set", )) } - pub fn get_output_collection_info( - &self, - ) -> Result<&CollectionCompactInfo, CompactionContextError> { - self.output_collection_info - .get() - .ok_or(CompactionContextError::InvariantViolation( - "Collection info should have been set", - )) - } - - pub fn get_input_collection_info_mut( - &mut self, - ) -> Result<&mut CollectionCompactInfo, CompactionContextError> { - self.input_collection_info - .get_mut() - .ok_or(CompactionContextError::InvariantViolation( - "Collection info mut should have been set", - )) + pub fn get_segment_writers(&self) -> Result { + self.get_collection_info()?.writers.clone().ok_or( + CompactionContextError::InvariantViolation("Segment writers should have been set"), + ) } - pub fn get_output_collection_info_mut( + pub fn get_collection_info_mut( &mut self, ) -> Result<&mut CollectionCompactInfo, CompactionContextError> { - self.output_collection_info + self.collection_info .get_mut() .ok_or(CompactionContextError::InvariantViolation( "Collection info mut should have been set", )) } - pub fn get_output_segment_writer_by_id( + pub fn get_segment_writer_by_id( &self, segment_id: SegmentUuid, ) -> Result, CompactionContextError> { - let writers = self.get_output_segment_writers()?; + let writers = self.get_segment_writers()?; if writers.metadata_writer.id == segment_id { return Ok(ChromaSegmentWriter::MetadataSegment( @@ -382,12 +372,16 @@ impl CompactionContext { let materialized = success.materialized; let collection_info = success.collection_info; - self.input_collection_info + self.collection_info .set(collection_info.clone()) .map_err(|_| { CompactionContextError::InvariantViolation("Collection info already set") })?; + if let Some(hnsw_index_uuid) = collection_info.hnsw_index_uuid { + self.hnsw_index_uuids.insert(hnsw_index_uuid); + } + Ok(Success::new(materialized, collection_info.clone()).into()) } LogFetchOrchestratorResponse::RequireCompactionOffsetRepair(repair) => Ok( @@ -402,7 +396,7 @@ impl CompactionContext { log_fetch_records: Arc>, system: System, ) -> Result { - let collection_info = self.get_input_collection_info()?; + let collection_info = self.get_collection_info()?; if log_fetch_records.is_empty() { return Ok(ApplyLogsOrchestratorResponse::new_with_empty_results( collection_info.collection_id.into(), @@ -410,7 +404,7 @@ impl CompactionContext { )); } - if self.get_output_collection_info().is_err() { + if self.get_collection_info().is_err() { return Err(ApplyLogsOrchestratorError::InvariantViolation( "Output collection info should have been set before running apply logs", )); @@ -437,7 +431,7 @@ impl CompactionContext { } }; - let collection_info = self.get_output_collection_info_mut()?; + let collection_info = self.get_collection_info_mut()?; collection_info.schema = apply_logs_response.schema.clone(); collection_info.collection.total_records_post_compaction = apply_logs_response.total_records_post_compaction; @@ -451,11 +445,10 @@ impl 
CompactionContext { data_fetch_records: Arc>, system: System, ) -> Result { - let input_collection_info = self.get_input_collection_info()?.clone(); - let input_collection_info_clone = input_collection_info.clone(); + let collection_info = self.get_collection_info()?.clone(); let attached_function_orchestrator = AttachedFunctionOrchestrator::new( - input_collection_info, - self.clone(), + collection_info, + self.clone_for_new_collection(), self.dispatcher.clone(), data_fetch_records, ); @@ -473,32 +466,43 @@ impl CompactionContext { // Set the output collection info based on the response match &attached_function_response { - AttachedFunctionOrchestratorResponse::NoAttachedFunction { .. } => { - self.output_collection_info - .set(input_collection_info_clone) - .map_err(|_| { - AttachedFunctionOrchestratorError::InvariantViolation( - "Collection info should not have been already set".to_string(), - ) - })?; - } + AttachedFunctionOrchestratorResponse::NoAttachedFunction { .. } => {} AttachedFunctionOrchestratorResponse::Success { output_collection_info, .. } => { - self.output_collection_info - .set(output_collection_info.clone()) - .map_err(|_| { - AttachedFunctionOrchestratorError::InvariantViolation( - "Collection info should not have been already set".to_string(), - ) - })?; + // We are replacing the output collection info with the attached function output + self.collection_info = OnceCell::from(output_collection_info.clone()); + + if let Some(hnsw_index_uuid) = output_collection_info.hnsw_index_uuid { + self.hnsw_index_uuids.insert(hnsw_index_uuid); + } } } Ok(attached_function_response) } + async fn run_regular_compaction_workflow( + &mut self, + log_fetch_records: Arc>, + system: System, + ) -> Result { + let apply_logs_response = self.run_apply_logs(log_fetch_records, system).await?; + + // Build CollectionRegisterInfo from the updated context + let collection_info = self + .get_collection_info() + .map_err(CompactionError::CompactionContextError)? + .clone(); + + Ok(CollectionRegisterInfo { + collection_info, + flush_results: apply_logs_response.flush_results, + collection_logical_size_bytes: apply_logs_response.collection_logical_size_bytes, + }) + } + async fn run_attached_function_workflow( &mut self, log_fetch_records: Arc>, @@ -517,7 +521,7 @@ impl CompactionContext { completion_offset, } => { // Update self to use the output collection for apply_logs - self.output_collection_info = OnceCell::from(output_collection_info.clone()); + self.collection_info = OnceCell::from(output_collection_info.clone()); // Apply materialized output to output collection let apply_logs_response = self @@ -597,29 +601,11 @@ impl CompactionContext { let log_fetch_records = Arc::new(log_fetch_records); let log_fetch_records_clone = log_fetch_records.clone(); - let input_collection_info = - self.input_collection_info - .get() - .ok_or(CompactionError::InvariantViolation( - "Input collection info should not be None", - ))?; - - // Clone first - both clones will have empty output_collection_info let mut self_clone_fn = self.clone(); let mut self_clone_compact = self.clone(); let system_clone_fn = system.clone(); let system_clone_compact = system.clone(); - // Set output_collection_info on self and the apply_logs clone - // The attached function orchestrator will set its own separate output_collection_info - self.output_collection_info - .set(input_collection_info.clone()) - .map_err(|_| { - CompactionError::InvariantViolation( - "Collection info should not have been already set", - ) - })?; - // 1. 
Attached function execution + apply output to output collection // 2. Apply input logs to input collection // Box the futures to avoid stack overflow with large state machines @@ -633,33 +619,11 @@ impl CompactionContext { let compact_future = Box::pin(async move { self_clone_compact - .output_collection_info - .set(input_collection_info.clone()) - .map_err(|_| { - CompactionError::InvariantViolation( - "Collection info should not have been already set", - ) - })?; - let apply_logs_response = self_clone_compact - .run_apply_logs(log_fetch_records, system_clone_compact) - .await?; - - // Build CollectionRegisterInfo from the updated context - let collection_info = self_clone_compact - .get_output_collection_info() - .map_err(CompactionError::CompactionContextError)? - .clone(); - Ok::(CollectionRegisterInfo { - collection_info, - flush_results: apply_logs_response.flush_results, - collection_logical_size_bytes: apply_logs_response.collection_logical_size_bytes, - }) + .run_regular_compaction_workflow(log_fetch_records, system_clone_compact) + .await }); - let (fn_result, compact_result) = tokio::join!(fn_future, compact_future); - - let fn_result = fn_result?; - let compact_result = compact_result?; + let (fn_result, compact_result) = tokio::try_join!(fn_future, compact_future)?; // Collect results let mut attached_function_context = None; @@ -696,24 +660,12 @@ impl CompactionContext { } pub(crate) async fn cleanup(self) { - if let Some(collection_info) = self.input_collection_info.get() { - if let Some(hnsw_index_uuid) = collection_info.hnsw_index_uuid { - let _ = HnswIndexProvider::purge_one_id( - self.hnsw_provider.temporary_storage_path.as_path(), - hnsw_index_uuid, - ) - .await; - } - } - - if let Some(collection_info) = self.output_collection_info.get() { - if let Some(hnsw_index_uuid) = collection_info.hnsw_index_uuid { - let _ = HnswIndexProvider::purge_one_id( - self.hnsw_provider.temporary_storage_path.as_path(), - hnsw_index_uuid, - ) - .await; - } + for hnsw_index_uuid in self.hnsw_index_uuids { + let _ = HnswIndexProvider::purge_one_id( + self.hnsw_provider.temporary_storage_path.as_path(), + hnsw_index_uuid, + ) + .await; } } } @@ -2526,10 +2478,7 @@ mod tests { .expect("Apply should have succeeded."); let register_info = vec![CollectionRegisterInfo { - collection_info: compaction_context_1 - .get_input_collection_info() - .unwrap() - .clone(), + collection_info: compaction_context_1.get_collection_info().unwrap().clone(), flush_results: compaction_1_apply_response.flush_results, collection_logical_size_bytes: compaction_1_apply_response .collection_logical_size_bytes, diff --git a/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs b/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs index 888c4471db3..4dbde1de58b 100644 --- a/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs @@ -313,7 +313,7 @@ impl LogFetchOrchestrator { // NOTE: We allow writers to be uninitialized for the case when the materialized logs are empty let record_reader = self .context - .get_input_segment_writers() + .get_segment_writers() .ok() .and_then(|writers| writers.record_reader); @@ -334,7 +334,7 @@ impl LogFetchOrchestrator { } }; - let collection_info = match self.context.get_input_collection_info_mut() { + let collection_info = match self.context.get_collection_info_mut() { Ok(info) => info, Err(err) => { return self.terminate_with_result(Err(err.into()), ctx).await; @@ -447,7 +447,7 
@@ impl Handler info, - None => { - self.terminate_with_result( - Err(LogFetchOrchestratorError::InvariantViolation( - "Collection info should have been initialized", - )), - ctx, - ) - .await; + let collection_info = match self.context.get_collection_info_mut() { + Ok(info) => info, + Err(err) => { + self.terminate_with_result(Err(err.into()), ctx).await; return; } }; @@ -644,7 +638,7 @@ impl Handler> for LogFetchOrchestrator tracing::info!("Pulled Records: {}", output.len()); match output.iter().last() { Some((rec, _)) => { - let collection_info = match self.context.get_input_collection_info_mut() { + let collection_info = match self.context.get_collection_info_mut() { Ok(info) => info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -659,7 +653,7 @@ impl Handler> for LogFetchOrchestrator } None => { tracing::warn!("No logs were pulled from the log service, this can happen when the log compaction offset is behind the sysdb."); - let collection_info = match self.context.get_input_collection_info() { + let collection_info = match self.context.get_collection_info() { Ok(info) => info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -699,7 +693,7 @@ impl Handler> }; tracing::info!("Sourced Records: {}", output.len()); // Each record should correspond to a log - let collection_info = match self.context.get_input_collection_info_mut() { + let collection_info = match self.context.get_collection_info_mut() { Ok(info) => info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -708,7 +702,7 @@ impl Handler> }; collection_info.collection.total_records_post_compaction = output.len() as u64; - let collection_info = match self.context.get_input_collection_info() { + let collection_info = match self.context.get_collection_info() { Ok(info) => info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -766,7 +760,7 @@ impl Handler> } self.num_uncompleted_materialization_tasks -= 1; if self.num_uncompleted_materialization_tasks == 0 { - let collection_info = match self.context.input_collection_info.take() { + let collection_info = match self.context.collection_info.take() { Some(info) => info, None => { self.terminate_with_result( diff --git a/rust/worker/src/execution/orchestration/register_orchestrator.rs b/rust/worker/src/execution/orchestration/register_orchestrator.rs index c18234836c9..3d3dae85bc8 100644 --- a/rust/worker/src/execution/orchestration/register_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/register_orchestrator.rs @@ -267,7 +267,7 @@ impl Handler<TaskResult<RegisterOutput, RegisterError>> for RegisterOrchestrator message: TaskResult<RegisterOutput, RegisterError>, ctx: &ComponentContext, ) { - let collection_info = match self.context.get_input_collection_info() { + let collection_info = match self.context.get_collection_info() { Ok(collection_info) => collection_info, Err(e) => { self.terminate_with_result(Err(e.into()), ctx).await; @@ -297,7 +297,7 @@ impl Handler, ctx: &ComponentContext, ) { - let collection_info = match self.context.get_input_collection_info() { + let collection_info = match self.context.get_collection_info() { Ok(collection_info) => collection_info, Err(e) => { self.terminate_with_result(Err(e.into()), ctx).await; From 8659f3dcbce77dab7a436bc886b19fc3283771ba Mon Sep 17 00:00:00 2001 From: Tanuj Nayak Date: Fri, 21 Nov 2025 14:19:14 -0800 Subject: [PATCH 3/3] address hammad + sicheng comments --- .../list_attached_functions_test.go | 4 ++ go/pkg/sysdb/coordinator/task.go | 5 +- go/pkg/sysdb/grpc/collection_service.go | 3 +
idl/chromadb/proto/coordinator.proto | 1 - rust/segment/src/types.rs | 1 - rust/sysdb/src/sysdb.rs | 59 ++++++------------- rust/sysdb/src/test_sysdb.rs | 1 - rust/types/src/flush.rs | 47 +++++++++++++-- .../src/execution/operators/execute_task.rs | 4 +- .../execution/operators/materialize_logs.rs | 2 +- .../orchestration/apply_logs_orchestrator.rs | 6 +- .../attached_function_orchestrator.rs | 38 ++++++++---- .../src/execution/orchestration/compact.rs | 22 +++---- 13 files changed, 115 insertions(+), 78 deletions(-) diff --git a/go/pkg/sysdb/coordinator/list_attached_functions_test.go b/go/pkg/sysdb/coordinator/list_attached_functions_test.go index b62f7142132..516cd93d558 100644 --- a/go/pkg/sysdb/coordinator/list_attached_functions_test.go +++ b/go/pkg/sysdb/coordinator/list_attached_functions_test.go @@ -66,6 +66,7 @@ func (suite *ListAttachedFunctionsTestSuite) TestListAttachedFunctions_Success() MinRecordsForInvocation: 5, CreatedAt: now, UpdatedAt: now, + IsReady: true, }, { ID: uuid.New(), @@ -80,6 +81,7 @@ func (suite *ListAttachedFunctionsTestSuite) TestListAttachedFunctions_Success() MinRecordsForInvocation: 15, CreatedAt: now, UpdatedAt: now, + IsReady: true, }, } @@ -157,6 +159,7 @@ func (suite *ListAttachedFunctionsTestSuite) TestListAttachedFunctions_FunctionD MinRecordsForInvocation: 1, CreatedAt: now, UpdatedAt: now, + IsReady: true, } suite.mockMetaDomain.On("AttachedFunctionDb", ctx).Return(suite.mockAttachedFunctionDb).Once() @@ -191,6 +194,7 @@ func (suite *ListAttachedFunctionsTestSuite) TestListAttachedFunctions_InvalidPa MinRecordsForInvocation: 1, CreatedAt: now, UpdatedAt: now, + IsReady: true, } suite.mockMetaDomain.On("AttachedFunctionDb", ctx).Return(suite.mockAttachedFunctionDb).Once() diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index 1bd63a4af19..83eba59cd60 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -222,6 +222,10 @@ func attachedFunctionToProto(attachedFunction *dbmodel.AttachedFunction, functio return nil, status.Errorf(codes.Internal, "attached function has invalid completion_offset: %d", attachedFunction.CompletionOffset) } + if !attachedFunction.IsReady { + return nil, status.Errorf(codes.Internal, "serialized attached function is not ready") + } + attachedFunctionProto := &coordinatorpb.AttachedFunction{ Id: attachedFunction.ID.String(), Name: attachedFunction.Name, @@ -236,7 +240,6 @@ func attachedFunctionToProto(attachedFunction *dbmodel.AttachedFunction, functio DatabaseId: attachedFunction.DatabaseID, CreatedAt: uint64(attachedFunction.CreatedAt.UnixMicro()), UpdatedAt: uint64(attachedFunction.UpdatedAt.UnixMicro()), - IsReady: attachedFunction.IsReady, } if attachedFunction.OutputCollectionID != nil { attachedFunctionProto.OutputCollectionId = attachedFunction.OutputCollectionID diff --git a/go/pkg/sysdb/grpc/collection_service.go b/go/pkg/sysdb/grpc/collection_service.go index 11dade5ec13..c97f10a5e97 100644 --- a/go/pkg/sysdb/grpc/collection_service.go +++ b/go/pkg/sysdb/grpc/collection_service.go @@ -579,6 +579,9 @@ func (s *Server) FlushCollectionCompactionAndAttachedFunction(ctx context.Contex log.Error("FlushCollectionCompactionAndAttachedFunction failed. flush_compactions is empty") return nil, grpcutils.BuildInternalGrpcError("at least one flush_compaction is required") } + + // Currently we only expect 1 or 2 flush_compactions. We expect the former in the case of backfills + // and the latter in the case of normal compactions with an attached function. 
if len(flushReqs) > 2 { log.Error("FlushCollectionCompactionAndAttachedFunction failed. too many flush_compactions", zap.Int("count", len(flushReqs))) return nil, grpcutils.BuildInternalGrpcError("expected 1 or 2 flush_compactions") diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 768df8221e1..3bd5c9b1cdd 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -588,7 +588,6 @@ message AttachedFunction { uint64 created_at = 15; uint64 updated_at = 16; string function_id = 17; - bool is_ready = 18; } message GetAttachedFunctionByNameResponse { diff --git a/rust/segment/src/types.rs b/rust/segment/src/types.rs index 8979a884e63..406f50dffa2 100644 --- a/rust/segment/src/types.rs +++ b/rust/segment/src/types.rs @@ -504,7 +504,6 @@ impl MaterializeLogsResult { /// # Note /// This is primarily intended for testing and should not be used in production code. /// Use the `materialize_logs` function instead for proper log materialization. - #[doc(hidden)] pub fn from_logs_for_test(logs: Chunk) -> Result { let mut materialized = Vec::new(); for (index, (log_record, _)) in logs.iter().enumerate() { diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index 64c368ac4af..4d60ff36120 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -9,16 +9,16 @@ use chroma_types::chroma_proto::sys_db_client::SysDbClient; use chroma_types::chroma_proto::VersionListForCollection; use chroma_types::{ chroma_proto, chroma_proto::CollectionVersionInfo, CollectionAndSegments, CollectionFlushInfo, - CollectionMetadataUpdate, CountCollectionsError, CreateCollectionError, CreateDatabaseError, - CreateDatabaseResponse, CreateTenantError, CreateTenantResponse, Database, - DeleteCollectionError, DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionByCrnError, - GetCollectionSizeError, GetCollectionWithSegmentsError, GetCollectionsError, GetDatabaseError, - GetDatabaseResponse, GetSegmentsError, GetTenantError, GetTenantResponse, - InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, - ListAttachedFunctionsError, ListCollectionVersionsError, ListDatabasesError, - ListDatabasesResponse, Metadata, ResetError, ResetResponse, SegmentFlushInfo, - SegmentFlushInfoConversionError, SegmentUuid, UpdateCollectionError, UpdateTenantError, - UpdateTenantResponse, + CollectionFlushInfoConversionError, CollectionMetadataUpdate, CountCollectionsError, + CreateCollectionError, CreateDatabaseError, CreateDatabaseResponse, CreateTenantError, + CreateTenantResponse, Database, DeleteCollectionError, DeleteDatabaseError, + DeleteDatabaseResponse, GetCollectionByCrnError, GetCollectionSizeError, + GetCollectionWithSegmentsError, GetCollectionsError, GetDatabaseError, GetDatabaseResponse, + GetSegmentsError, GetTenantError, GetTenantResponse, InternalCollectionConfiguration, + InternalUpdateCollectionConfiguration, ListAttachedFunctionsError, ListCollectionVersionsError, + ListDatabasesError, ListDatabasesResponse, Metadata, ResetError, ResetResponse, + SegmentFlushInfo, SegmentFlushInfoConversionError, SegmentUuid, UpdateCollectionError, + UpdateTenantError, UpdateTenantResponse, }; use chroma_types::{ AttachedFunctionUpdateInfo, AttachedFunctionUuid, BatchGetCollectionSoftDeleteStatusError, @@ -1657,38 +1657,10 @@ impl GrpcSysDb { attached_function_update: AttachedFunctionUpdateInfo, ) -> Result { // Process all collections into flush compaction requests - let mut flush_compactions = 
Vec::with_capacity(collections.len()); - - for collection in collections { - let segment_compaction_info = collection - .segment_flush_info - .iter() - .map(|segment_flush_info| segment_flush_info.try_into()) - .collect::, - SegmentFlushInfoConversionError, - >>()?; - - let schema_str = collection.schema.and_then(|s| { - serde_json::to_string(&s).ok().or_else(|| { - tracing::error!( - "Failed to serialize schema for flush_compaction_and_attached_function" - ); - None - }) - }); - - flush_compactions.push(chroma_proto::FlushCollectionCompactionRequest { - tenant_id: collection.tenant_id, - collection_id: collection.collection_id.0.to_string(), - log_position: collection.log_position, - collection_version: collection.collection_version, - segment_compaction_info, - total_records_post_compaction: collection.total_records_post_compaction, - size_bytes_post_compaction: collection.size_bytes_post_compaction, - schema_str, - }); - } + let flush_compactions = collections + .into_iter() + .map(|collection| collection.try_into()) + .collect::, _>>()?; let attached_function_update_proto = Some(chroma_proto::AttachedFunctionUpdateInfo { id: attached_function_update.attached_function_id.0.to_string(), @@ -2137,6 +2109,8 @@ pub enum FlushCompactionError { FailedToFlushCompaction(#[from] tonic::Status), #[error("Failed to convert segment flush info")] SegmentFlushInfoConversionError(#[from] SegmentFlushInfoConversionError), + #[error("Failed to convert collection flush info")] + CollectionFlushInfoConversionError(#[from] CollectionFlushInfoConversionError), #[error("Failed to convert flush compaction response")] FlushCompactionResponseConversionError(#[from] FlushCompactionResponseConversionError), #[error("Collection not found in sysdb")] @@ -2158,6 +2132,7 @@ impl ChromaError for FlushCompactionError { } } FlushCompactionError::SegmentFlushInfoConversionError(_) => ErrorCodes::Internal, + FlushCompactionError::CollectionFlushInfoConversionError(_) => ErrorCodes::Internal, FlushCompactionError::FlushCompactionResponseConversionError(_) => ErrorCodes::Internal, FlushCompactionError::CollectionNotFound => ErrorCodes::Internal, FlushCompactionError::SegmentNotFound => ErrorCodes::Internal, diff --git a/rust/sysdb/src/test_sysdb.rs b/rust/sysdb/src/test_sysdb.rs index 8f9cb12a7f5..3997d13e888 100644 --- a/rust/sysdb/src/test_sysdb.rs +++ b/rust/sysdb/src/test_sysdb.rs @@ -716,7 +716,6 @@ fn attached_function_to_proto( created_at: system_time_to_micros(attached_function.created_at), updated_at: system_time_to_micros(attached_function.updated_at), function_id: attached_function.function_id.to_string(), - is_ready: false, // Default value since Rust struct doesn't track this field } } diff --git a/rust/types/src/flush.rs b/rust/types/src/flush.rs index 13e21ccc5c8..4b879d0bb05 100644 --- a/rust/types/src/flush.rs +++ b/rust/types/src/flush.rs @@ -1,8 +1,6 @@ use super::{AttachedFunctionUuid, CollectionUuid, ConversionError, Schema}; use crate::{ - chroma_proto::{ - FilePaths, FlushCollectionCompactionAndAttachedFunctionResponse, FlushSegmentCompactionInfo, - }, + chroma_proto::{self, FilePaths, FlushSegmentCompactionInfo}, SegmentUuid, }; use chroma_error::{ChromaError, ErrorCodes}; @@ -130,6 +128,45 @@ pub enum SegmentFlushInfoConversionError { DecodeError(#[from] ConversionError), } +#[derive(Error, Debug)] +pub enum CollectionFlushInfoConversionError { + #[error("Failed to convert segment flush info: {0}")] + SegmentConversionError(#[from] SegmentFlushInfoConversionError), + #[error("Failed to serialize 
schema")] + SchemaSerializationError, +} + +impl TryFrom for chroma_proto::FlushCollectionCompactionRequest { + type Error = CollectionFlushInfoConversionError; + + fn try_from(collection: CollectionFlushInfo) -> Result { + let segment_compaction_info = collection + .segment_flush_info + .iter() + .map(|segment_flush_info| segment_flush_info.try_into()) + .collect::, _>>()?; + + let schema_str = collection + .schema + .map(|s| { + serde_json::to_string(&s) + .map_err(|_| CollectionFlushInfoConversionError::SchemaSerializationError) + }) + .transpose()?; + + Ok(crate::chroma_proto::FlushCollectionCompactionRequest { + tenant_id: collection.tenant_id, + collection_id: collection.collection_id.0.to_string(), + log_position: collection.log_position, + collection_version: collection.collection_version, + segment_compaction_info, + total_records_post_compaction: collection.total_records_post_compaction, + size_bytes_post_compaction: collection.size_bytes_post_compaction, + schema_str, + }) + } +} + #[derive(Debug)] pub struct FlushCompactionResponse { pub collection_id: CollectionUuid, @@ -158,13 +195,13 @@ impl FlushCompactionResponse { } } -impl TryFrom +impl TryFrom for FlushCompactionAndAttachedFunctionResponse { type Error = FlushCompactionResponseConversionError; fn try_from( - value: FlushCollectionCompactionAndAttachedFunctionResponse, + value: chroma_proto::FlushCollectionCompactionAndAttachedFunctionResponse, ) -> Result { // Parse all collections from the repeated field let mut collections = Vec::with_capacity(value.collections.len()); diff --git a/rust/worker/src/execution/operators/execute_task.rs b/rust/worker/src/execution/operators/execute_task.rs index 4458e3d6fef..995586da26a 100644 --- a/rust/worker/src/execution/operators/execute_task.rs +++ b/rust/worker/src/execution/operators/execute_task.rs @@ -69,7 +69,7 @@ impl AttachedFunctionExecutor for CountAttachedFunction { }, }; - Ok(Chunk::new(std::sync::Arc::from(vec![output_record]))) + Ok(Chunk::new(Arc::from(vec![output_record]))) } } @@ -115,7 +115,7 @@ impl ExecuteAttachedFunctionOperator { #[derive(Debug)] pub struct ExecuteAttachedFunctionInput { /// The materialized log outputs to process - pub materialized_logs: Arc>, + pub materialized_logs: Vec, /// The tenant ID pub tenant_id: String, /// The output collection ID where results are written diff --git a/rust/worker/src/execution/operators/materialize_logs.rs b/rust/worker/src/execution/operators/materialize_logs.rs index cdc752326b7..8f0daa88c1b 100644 --- a/rust/worker/src/execution/operators/materialize_logs.rs +++ b/rust/worker/src/execution/operators/materialize_logs.rs @@ -56,7 +56,7 @@ impl MaterializeLogInput { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct MaterializeLogOutput { pub result: MaterializeLogsResult, pub collection_logical_size_delta: i64, diff --git a/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs b/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs index 27f2c8b23d2..7533fadc428 100644 --- a/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; @@ -60,7 +60,7 @@ pub struct ApplyLogsOrchestrator { segment_spans: HashMap, // Store the materialized outputs from LogFetchOrchestrator - materialized_log_data: Option>>, + materialized_log_data: Option>, 
 metrics: CompactionMetrics, } @@ -181,7 +181,7 @@ impl ApplyLogsOrchestratorResponse { impl ApplyLogsOrchestrator { pub fn new( context: &CompactionContext, - materialized_log_data: Option<Arc<Vec<MaterializeLogOutput>>>, + materialized_log_data: Option<Vec<MaterializeLogOutput>>, ) -> Self { ApplyLogsOrchestrator { context: context.clone(), diff --git a/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs b/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs index ac5ea337845..201a4249b70 100644 --- a/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs @@ -65,7 +65,7 @@ pub struct AttachedFunctionOrchestrator { >, // Store the materialized outputs from DataFetchOrchestrator - materialized_log_data: Arc<Vec<MaterializeLogOutput>>, + materialized_log_data: Vec<MaterializeLogOutput>, // Function context function_context: OnceCell<FunctionContext>, @@ -204,7 +204,7 @@ impl AttachedFunctionOrchestrator { input_collection_info: CollectionCompactInfo, output_context: CompactionContext, dispatcher: ComponentHandle, - data_fetch_records: Arc<Vec<MaterializeLogOutput>>, + data_fetch_records: Vec<MaterializeLogOutput>, ) -> Self { let orchestrator_context = OrchestratorContext::new(dispatcher.clone()); @@ -308,7 +308,12 @@ impl AttachedFunctionOrchestrator { // Get the completion offset from the input collection's pulled log offset let completion_offset = collection_info.pulled_log_offset as u64; - println!( + let materialized_output = materialized_output + .into_iter() + .filter(|output| !output.result.is_empty()) + .collect::<Vec<_>>(); + + tracing::info!( "Attached function finished successfully with {} records", materialized_output.len() ); @@ -487,12 +492,23 @@ impl Handler> let collection_info = self.get_input_collection_info(); let input = ExecuteAttachedFunctionInput { - materialized_logs: Arc::clone(&self.materialized_log_data), // Use the actual materialized logs from data fetch - tenant_id: "default".to_string(), // TODO: Get actual tenant ID + materialized_logs: self.materialized_log_data.clone(), // Use the actual materialized logs from data fetch + tenant_id: "default".to_string(), // TODO: Get actual tenant ID output_collection_id: message.collection.collection_id, completion_offset: collection_info.pulled_log_offset as u64, // Use the completion offset from input collection output_record_segment: message.record_segment.clone(), diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index 492c280ad6b..9f6a35327c0 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -1,4 +1,4 @@ -use std::{cell::OnceCell, collections::HashSet, sync::Arc}; +use std::{cell::OnceCell, collections::HashSet}; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; @@ -393,7 +393,7 @@ impl CompactionContext { pub(crate) async fn run_apply_logs( &mut self, - log_fetch_records: Arc<Vec<MaterializeLogOutput>>, + log_fetch_records: Vec<MaterializeLogOutput>, system: System, ) -> Result<ApplyLogsOrchestratorResponse, ApplyLogsOrchestratorError> { let collection_info = self.get_collection_info()?; @@ -442,7 +442,7 @@ impl CompactionContext { // Should be invoked on output collection context pub(crate) async fn run_attached_function( &mut self, - data_fetch_records: Arc<Vec<MaterializeLogOutput>>, + data_fetch_records: Vec<MaterializeLogOutput>, system: System, ) -> Result<AttachedFunctionOrchestratorResponse, AttachedFunctionOrchestratorError> { let collection_info = self.get_collection_info()?.clone(); @@ -485,7 +485,7 @@ impl CompactionContext { async fn run_regular_compaction_workflow( &mut self, - log_fetch_records: Arc<Vec<MaterializeLogOutput>>, + log_fetch_records: Vec<MaterializeLogOutput>, system: System, ) -> Result<CollectionRegisterInfo, CompactionError> { let
apply_logs_response = self.run_apply_logs(log_fetch_records, system).await?; @@ -505,7 +505,7 @@ impl CompactionContext { async fn run_attached_function_workflow( &mut self, - log_fetch_records: Arc<Vec<MaterializeLogOutput>>, + log_fetch_records: Vec<MaterializeLogOutput>, system: System, ) -> Result<Option<(FunctionContext, CollectionRegisterInfo)>, CompactionError> { let attached_function_result = @@ -525,7 +525,7 @@ impl CompactionContext { // Apply materialized output to output collection let apply_logs_response = self - .run_apply_logs(Arc::new(materialized_output), system.clone()) + .run_apply_logs(materialized_output, system.clone()) .await?; let function_context = FunctionContext { @@ -534,8 +534,11 @@ impl CompactionContext { updated_completion_offset: completion_offset, }; + // Get updated collection info after running apply logs. + let output_collection_info = self.get_collection_info()?; + let collection_register_info = CollectionRegisterInfo { - collection_info: output_collection_info, + collection_info: output_collection_info.clone(), flush_results: apply_logs_response.flush_results, collection_logical_size_bytes: apply_logs_response .collection_logical_size_bytes, @@ -598,10 +601,9 @@ impl CompactionContext { }; - // Wrap in Arc to avoid cloning large MaterializeLogOutput data - let log_fetch_records = Arc::new(log_fetch_records); let log_fetch_records_clone = log_fetch_records.clone(); let mut self_clone_fn = self.clone(); + // TODO(tanujnay112): Think about a better way to pass mutable state to these futures let mut self_clone_compact = self.clone(); let system_clone_fn = system.clone(); let system_clone_compact = system.clone(); @@ -728,7 +731,6 @@ mod tests { }; use std::collections::HashMap; use std::path::{Path, PathBuf}; - use std::sync::Arc; use tokio::fs; use chroma_blockstore::arrow::config::{BlockManagerConfig, TEST_MAX_BLOCK_SIZE_BYTES}; @@ -2473,7 +2475,7 @@ mod tests { compaction_1_log_records.len() ); let compaction_1_apply_response = compaction_context_1 - .run_apply_logs(Arc::new(compaction_1_log_records), system.clone()) + .run_apply_logs(compaction_1_log_records, system.clone()) .await .expect("Apply should have succeeded.");
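
For readers tracing the new control flow: the net effect of these patches is that run_compaction fans out into two concurrent workflows over the same fetched logs (the attached-function workflow, which may produce an output-collection register info, and the regular compaction workflow, which always produces one) and then registers one or two collections at once. The standalone sketch below reproduces that shape under stated assumptions: RegisterInfo and WorkflowError are hypothetical stand-ins for CollectionRegisterInfo and CompactionError, the record type is a plain String, and tokio is assumed available with the macros and rt-multi-thread features.

use std::sync::Arc;

#[derive(Debug)]
struct RegisterInfo {
    collection_id: String,
    size_bytes: u64,
}

#[derive(Debug)]
enum WorkflowError {
    InvariantViolation(&'static str),
}

// Stand-in for the attached-function workflow: Ok(None) mirrors the
// NoAttachedFunction response, Ok(Some(..)) mirrors a successful function run
// that produced an output collection to register.
async fn run_attached_function_workflow(
    records: Arc<Vec<String>>,
) -> Result<Option<RegisterInfo>, WorkflowError> {
    if records.is_empty() {
        return Ok(None);
    }
    Ok(Some(RegisterInfo {
        collection_id: "output".into(),
        size_bytes: records.len() as u64,
    }))
}

// Stand-in for the regular compaction workflow on the input collection.
async fn run_regular_compaction_workflow(
    records: Arc<Vec<String>>,
) -> Result<RegisterInfo, WorkflowError> {
    if records.is_empty() {
        return Err(WorkflowError::InvariantViolation("expected fetched logs"));
    }
    Ok(RegisterInfo {
        collection_id: "input".into(),
        size_bytes: records.len() as u64,
    })
}

#[tokio::main]
async fn main() -> Result<(), WorkflowError> {
    let records = Arc::new(vec!["log-1".to_string(), "log-2".to_string()]);

    // Box the futures so their combined state machines live on the heap, then
    // drive both workflows concurrently; try_join! propagates the first error.
    let fn_future = Box::pin(run_attached_function_workflow(records.clone()));
    let compact_future = Box::pin(run_regular_compaction_workflow(records));
    let (fn_result, compact_result) = tokio::try_join!(fn_future, compact_future)?;

    // Collect one or two register infos: output collection first (if any),
    // then the input collection, matching the 1-or-2 invariant in run_register.
    let mut results: Vec<RegisterInfo> = Vec::new();
    if let Some(info) = fn_result {
        results.push(info);
    }
    results.push(compact_result);
    assert!(!results.is_empty() && results.len() <= 2);
    println!("registering: {results:?}");
    Ok(())
}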
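
The switch of test_concurrent_compactions from #[tokio::test(flavor = "multi_thread")] to a plain #[test] relies on the same stack-size trick in isolation: spawn a dedicated thread with an explicit 8 MB stack and drive a fresh runtime on it, so the deeply nested boxed futures cannot overflow the default 2 MB thread stack. A minimal self-contained version, where my_async_test_impl is a placeholder for the real async test body and tokio's rt feature is assumed:

// Placeholder for an async test body that builds a large future state machine.
async fn my_async_test_impl() {
    assert_eq!(2 + 2, 4);
}

fn main() {
    // Run the async body on a thread with an enlarged stack; join() propagates
    // any panic from the test body back to the caller.
    std::thread::Builder::new()
        .stack_size(8 * 1024 * 1024) // 8 MB stack
        .spawn(|| {
            tokio::runtime::Runtime::new()
                .unwrap()
                .block_on(my_async_test_impl())
        })
        .unwrap()
        .join()
        .unwrap();
}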