Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 852fe04

Browse files
committed Mar 5, 2025
estuary-cdk: Support multi-subtask capture with independent states
Enhance the capture common module to support running multiple subtasks for incremental and backfill captures with independent states. This allows more flexible capture configurations where different subtasks can track their own progress and state within a single binding.
1 parent 8a47397 commit 852fe04

File tree

1 file changed

+132
-44
lines changed

1 file changed

+132
-44
lines changed
 

estuary-cdk/estuary_cdk/capture/common.py

+132-44
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import asyncio
3+
import functools
34
from enum import Enum
45
from dataclasses import dataclass
56
from datetime import UTC, datetime, timedelta
@@ -54,7 +55,7 @@
5455

5556

5657
class Triggers(Enum):
57-
BACKFILL = 'BACKFILL'
58+
BACKFILL = "BACKFILL"
5859

5960

6061
class BaseDocument(BaseModel):
@@ -114,13 +115,14 @@ def path(self) -> list[str]:
114115
_ResourceConfig = TypeVar("_ResourceConfig", bound=ResourceConfig)
115116

116117

117-
CRON_REGEX = (r"^"
118+
CRON_REGEX = (
119+
r"^"
118120
r"((?:[0-5]?\d(?:-[0-5]?\d)?|\*(?:/[0-5]?\d)?)(?:,(?:[0-5]?\d(?:-[0-5]?\d)?|\*(?:/[0-5]?\d)?))*)\s+" # minute
119121
r"((?:[01]?\d|2[0-3]|(?:[01]?\d|2[0-3])-(?:[01]?\d|2[0-3])|\*(?:/[01]?\d|/2[0-3])?)(?:,(?:[01]?\d|2[0-3]|(?:[01]?\d|2[0-3])-(?:[01]?\d|2[0-3])|\*(?:/[01]?\d|/2[0-3])?))*)\s+" # hour
120122
r"((?:0?[1-9]|[12]\d|3[01]|(?:0?[1-9]|[12]\d|3[01])-(?:0?[1-9]|[12]\d|3[01])|\*(?:/[0-9]|/1[0-9]|/2[0-9]|/3[01])?)(?:,(?:0?[1-9]|[12]\d|3[01]|(?:0?[1-9]|[12]\d|3[01])-(?:0?[1-9]|[12]\d|3[01])|\*(?:/[0-9]|/1[0-9]|/2[0-9]|/3[01])?))*)\s+" # day of month
121123
r"((?:[1-9]|1[0-2]|(?:[1-9]|1[0-2])-(?:[1-9]|1[0-2])|\*(?:/[1-9]|/1[0-2])?)(?:,(?:[1-9]|1[0-2]|(?:[1-9]|1[0-2])-(?:[1-9]|1[0-2])|\*(?:/[1-9]|/1[0-2])?))*)\s+" # month
122124
r"((?:[0-6]|(?:[0-6])-(?:[0-6])|\*(?:/[0-6])?)(?:,(?:[0-6]|(?:[0-6])-(?:[0-6])|\*(?:/[0-6])?))*)" # day of week
123-
r"$|^$" # Empty string to signify no schedule
125+
r"$|^$" # Empty string to signify no schedule
124126
)
125127

126128

@@ -129,7 +131,7 @@ class ResourceConfigWithSchedule(ResourceConfig):
129131
default="",
130132
title="Schedule",
131133
description="Schedule to automatically rebackfill this binding. Accepts a cron expression.",
132-
pattern=CRON_REGEX
134+
pattern=CRON_REGEX,
133135
)
134136

135137

@@ -164,8 +166,7 @@ class Backfill(BaseModel, extra="forbid"):
164166
description="LogCursor at which incremental replication began"
165167
)
166168
next_page: PageCursor = Field(
167-
description="PageCursor of the next page to fetch",
168-
default=None
169+
description="PageCursor of the next page to fetch", default=None
169170
)
170171

171172
class Snapshot(BaseModel, extra="forbid"):
@@ -179,18 +180,20 @@ class Snapshot(BaseModel, extra="forbid"):
179180
description="The xxh3_128 hex digest of documents of this resource in the last snapshot"
180181
)
181182

182-
inc: Incremental | None = Field(
183+
inc: Incremental | dict[str, Incremental | None] | None = Field(
183184
default=None, description="Incremental capture progress"
184185
)
185186

186-
backfill: Backfill | None = Field(
187+
backfill: Backfill | dict[str, Backfill | None] | None = Field(
187188
default=None,
188189
description="Backfill progress, or None if no backfill is occurring",
189190
)
190191

191192
snapshot: Snapshot | None = Field(default=None, description="Snapshot progress")
192193

193-
last_initialized: datetime | None = Field(default=None, description="The last time this state was initialized.")
194+
last_initialized: datetime | None = Field(
195+
default=None, description="The last time this state was initialized."
196+
)
194197

195198

196199
_ResourceState = TypeVar("_ResourceState", bound=ResourceState)
@@ -213,6 +216,7 @@ class AssociatedDocument(Generic[_BaseDocument]):
213216
You might use this if your data model requires you to load "child" documents when capturing a "parent" document,
214217
instead of independently loading the child data stream.
215218
"""
219+
216220
doc: _BaseDocument
217221
binding: int
218222

@@ -317,7 +321,7 @@ class FixedSchema:
317321
CaptureBinding[_ResourceConfig],
318322
"Resource[_BaseDocument, _ResourceConfig, _ResourceState]",
319323
]
320-
]
324+
],
321325
],
322326
None,
323327
]
@@ -363,7 +367,6 @@ def resolve_bindings(
363367
resources: list[Resource[Any, _BaseResourceConfig, Any]],
364368
resource_term="Resource",
365369
) -> list[tuple[_ResolvableBinding, Resource[Any, _BaseResourceConfig, Any]]]:
366-
367370
resolved: list[
368371
tuple[_ResolvableBinding, Resource[Any, _BaseResourceConfig, Any]]
369372
] = []
@@ -397,7 +400,6 @@ def validated(
397400
]
398401
],
399402
) -> response.Validated:
400-
401403
return response.Validated(
402404
bindings=[
403405
response.ValidatedBinding(resourcePath=b[0].resourceConfig.path())
@@ -415,7 +417,6 @@ def open(
415417
]
416418
],
417419
) -> tuple[response.Opened, Callable[[Task], Awaitable[None]]]:
418-
419420
async def _run(task: Task):
420421
backfill_requests = []
421422
if open.state.backfillRequests is not None:
@@ -445,17 +446,20 @@ async def _run(task: Task):
445446
if state.last_initialized is None:
446447
state.last_initialized = datetime.now(tz=UTC)
447448
task.checkpoint(
448-
ConnectorState(
449-
bindingStateV1={binding.stateKey: state}
450-
)
449+
ConnectorState(bindingStateV1={binding.stateKey: state})
451450
)
452451

453452
if isinstance(binding.resourceConfig, ResourceConfigWithSchedule):
454453
cron_schedule = binding.resourceConfig.schedule
455-
next_scheduled_initialization = next_fire(cron_schedule, state.last_initialized)
454+
next_scheduled_initialization = next_fire(
455+
cron_schedule, state.last_initialized
456+
)
456457

457-
if next_scheduled_initialization and next_scheduled_initialization < datetime.now(tz=UTC):
458-
# Re-initialize the binding if we missed a scheduled re-initialization.
458+
if (
459+
next_scheduled_initialization
460+
and next_scheduled_initialization < datetime.now(tz=UTC)
461+
):
462+
# Re-initialize the binding if we missed a scheduled re-initialization.
459463
should_initialize = True
460464
if state.backfill:
461465
task.log.warning(
@@ -464,12 +468,22 @@ async def _run(task: Task):
464468
" complete before the next scheduled backfill starts."
465469
)
466470

467-
next_scheduled_initialization = next_fire(cron_schedule, datetime.now(tz=UTC))
471+
next_scheduled_initialization = next_fire(
472+
cron_schedule, datetime.now(tz=UTC)
473+
)
468474

469-
if next_scheduled_initialization and soonest_future_scheduled_initialization:
470-
soonest_future_scheduled_initialization = min(soonest_future_scheduled_initialization, next_scheduled_initialization)
475+
if (
476+
next_scheduled_initialization
477+
and soonest_future_scheduled_initialization
478+
):
479+
soonest_future_scheduled_initialization = min(
480+
soonest_future_scheduled_initialization,
481+
next_scheduled_initialization,
482+
)
471483
elif next_scheduled_initialization:
472-
soonest_future_scheduled_initialization = next_scheduled_initialization
484+
soonest_future_scheduled_initialization = (
485+
next_scheduled_initialization
486+
)
473487

474488
if should_initialize:
475489
# Checkpoint the binding's initialized state prior to any processing.
@@ -478,7 +492,7 @@ async def _run(task: Task):
478492

479493
task.checkpoint(
480494
ConnectorState(
481-
bindingStateV1={binding.stateKey: state}
495+
bindingStateV1={binding.stateKey: state},
482496
)
483497
)
484498

@@ -487,7 +501,7 @@ async def _run(task: Task):
487501
index,
488502
state,
489503
task,
490-
resolved_bindings
504+
resolved_bindings,
491505
)
492506

493507
async def scheduled_stop(future_dt: datetime | None) -> None:
@@ -510,8 +524,12 @@ def open_binding(
510524
binding_index: int,
511525
state: _ResourceState,
512526
task: Task,
513-
fetch_changes: FetchChangesFn[_BaseDocument] | None = None,
514-
fetch_page: FetchPageFn[_BaseDocument] | None = None,
527+
fetch_changes: FetchChangesFn[_BaseDocument]
528+
| dict[str, FetchChangesFn[_BaseDocument]]
529+
| None = None,
530+
fetch_page: FetchPageFn[_BaseDocument]
531+
| dict[str, FetchPageFn[_BaseDocument]]
532+
| None = None,
515533
fetch_snapshot: FetchSnapshotFn[_BaseDocument] | None = None,
516534
tombstone: _BaseDocument | None = None,
517535
):
@@ -520,30 +538,92 @@ def open_binding(
520538
521539
It does 'heavy lifting' to actually capture a binding.
522540
523-
TODO(johnny): Separate into snapshot vs incremental tasks?
541+
When fetch_changes, fetch_page, or fetch_snapshot are provided as dictionaries,
542+
each function will be run as a separate subtask with its own independent state.
543+
The dictionary keys are used as subtask IDs and are used to store and retrieve
544+
the state for each subtask in state.inc, state.backfill, or state.snapshot.
524545
"""
525546

526547
prefix = ".".join(binding.resourceConfig.path())
527548

528549
if fetch_changes:
529550

530-
async def closure(task: Task):
531-
assert state.inc
551+
async def closure(
552+
task: Task,
553+
fetch_changes: FetchChangesFn[_BaseDocument],
554+
state: ResourceState.Incremental,
555+
):
556+
assert state and not isinstance(state, dict)
532557
await _binding_incremental_task(
533-
binding, binding_index, fetch_changes, state.inc, task,
558+
binding,
559+
binding_index,
560+
fetch_changes,
561+
state,
562+
task,
534563
)
535564

536-
task.spawn_child(f"{prefix}.incremental", closure)
565+
if isinstance(fetch_changes, dict):
566+
for subtask_id, subtask_fetch_changes in fetch_changes.items():
567+
inc_state = state.inc.get(subtask_id)
568+
assert inc_state
569+
570+
task.spawn_child(
571+
f"{prefix}.incremental.{subtask_id}",
572+
functools.partial(
573+
closure,
574+
fetch_changes=subtask_fetch_changes,
575+
state=inc_state,
576+
),
577+
)
578+
else:
579+
task.spawn_child(
580+
f"{prefix}.incremental",
581+
functools.partial(
582+
closure,
583+
fetch_changes=fetch_changes,
584+
state=state.inc,
585+
),
586+
)
537587

538588
if fetch_page and state.backfill:
539589

540-
async def closure(task: Task):
541-
assert state.backfill
590+
async def closure(
591+
task: Task,
592+
fetch_page: FetchPageFn[_BaseDocument],
593+
state: ResourceState.Backfill,
594+
):
595+
assert state and not isinstance(state, dict)
542596
await _binding_backfill_task(
543-
binding, binding_index, fetch_page, state.backfill, task,
597+
binding,
598+
binding_index,
599+
fetch_page,
600+
state,
601+
task,
544602
)
545603

546-
task.spawn_child(f"{prefix}.backfill", closure)
604+
if isinstance(fetch_page, dict):
605+
for subtask_id, subtask_fetch_page in fetch_page.items():
606+
backfill_state = state.backfill.get(subtask_id)
607+
assert backfill_state
608+
609+
task.spawn_child(
610+
f"{prefix}.backfill.{subtask_id}",
611+
functools.partial(
612+
closure,
613+
fetch_page=subtask_fetch_page,
614+
state=backfill_state,
615+
),
616+
)
617+
618+
else:
619+
task.spawn_child(
620+
f"{prefix}.backfill",
621+
functools.partial(
622+
closure,
623+
fetch_page=fetch_page,
624+
state=state.backfill,
625+
),
626+
)
547627

548628
if fetch_snapshot:
549629

@@ -612,7 +692,7 @@ async def _binding_snapshot_task(
612692
if isinstance(doc, dict):
613693
doc["meta_"] = {
614694
"op": "u" if count < state.last_count else "c",
615-
"row_id": count
695+
"row_id": count,
616696
}
617697
else:
618698
doc.meta_ = BaseDocument.Meta(
@@ -719,7 +799,10 @@ async def _binding_incremental_task(
719799

720800
if lag < binding.resourceConfig.interval:
721801
sleep_for = binding.resourceConfig.interval - lag
722-
task.log.info("incremental task ran recently, sleeping until `interval` has fully elapsed", {"sleep_for": sleep_for, "interval": binding.resourceConfig.interval})
802+
task.log.info(
803+
"incremental task ran recently, sleeping until `interval` has fully elapsed",
804+
{"sleep_for": sleep_for, "interval": binding.resourceConfig.interval},
805+
)
723806

724807
while True:
725808
try:
@@ -747,9 +830,7 @@ async def _binding_incremental_task(
747830
task.log.info("incremental task triggered backfill")
748831
task.stopping.event.set()
749832
task.checkpoint(
750-
ConnectorState(
751-
backfillRequests={binding.stateKey: True}
752-
)
833+
ConnectorState(backfillRequests={binding.stateKey: True})
753834
)
754835
return
755836
else:
@@ -759,7 +840,12 @@ async def _binding_incremental_task(
759840
is_larger = item > state.cursor
760841
elif isinstance(item, datetime) and isinstance(state.cursor, datetime):
761842
is_larger = item > state.cursor
762-
elif isinstance(item, tuple) and isinstance(state.cursor, tuple) and isinstance(item[0], str) and isinstance(state.cursor[0], str):
843+
elif (
844+
isinstance(item, tuple)
845+
and isinstance(state.cursor, tuple)
846+
and isinstance(item[0], str)
847+
and isinstance(state.cursor[0], str)
848+
):
763849
is_larger = item[0] > state.cursor[0]
764850
else:
765851
raise RuntimeError(
@@ -786,7 +872,7 @@ async def _binding_incremental_task(
786872
sleep_for = binding.resourceConfig.interval
787873

788874
elif isinstance(state.cursor, datetime):
789-
lag = (datetime.now(tz=UTC) - state.cursor)
875+
lag = datetime.now(tz=UTC) - state.cursor
790876

791877
if lag > binding.resourceConfig.interval:
792878
# We're not idle. Attempt to fetch the next changes.
@@ -800,4 +886,6 @@ async def _binding_incremental_task(
800886
sleep_for = timedelta()
801887
continue
802888

803-
task.log.debug("incremental task is idle", {"sleep_for": sleep_for, "cursor": state.cursor})
889+
task.log.debug(
890+
"incremental task is idle", {"sleep_for": sleep_for, "cursor": state.cursor}
891+
)

0 commit comments

Comments (0)
Please sign in to comment.