Skip to content

Commit 503dbdb

Browse files
committed
estuary-cdk: Support multi-subtask capture with independent states
Enhance the capture `common` module (estuary_cdk/capture/common.py) to support running multiple subtasks for incremental and backfill captures, each with its own independent state. This allows more flexible capture configurations in which different subtasks track their own progress and state within a single binding.
1 parent 8a47397 commit 503dbdb

File tree

2 files changed

+152
-47
lines changed

2 files changed

+152
-47
lines changed

estuary-cdk/estuary_cdk/capture/common.py

+136-44
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import asyncio
3+
import functools
34
from enum import Enum
45
from dataclasses import dataclass
56
from datetime import UTC, datetime, timedelta
@@ -54,7 +55,7 @@
5455

5556

5657
class Triggers(Enum):
57-
BACKFILL = 'BACKFILL'
58+
BACKFILL = "BACKFILL"
5859

5960

6061
class BaseDocument(BaseModel):
@@ -114,13 +115,14 @@ def path(self) -> list[str]:
114115
_ResourceConfig = TypeVar("_ResourceConfig", bound=ResourceConfig)
115116

116117

117-
CRON_REGEX = (r"^"
118+
CRON_REGEX = (
119+
r"^"
118120
r"((?:[0-5]?\d(?:-[0-5]?\d)?|\*(?:/[0-5]?\d)?)(?:,(?:[0-5]?\d(?:-[0-5]?\d)?|\*(?:/[0-5]?\d)?))*)\s+" # minute
119121
r"((?:[01]?\d|2[0-3]|(?:[01]?\d|2[0-3])-(?:[01]?\d|2[0-3])|\*(?:/[01]?\d|/2[0-3])?)(?:,(?:[01]?\d|2[0-3]|(?:[01]?\d|2[0-3])-(?:[01]?\d|2[0-3])|\*(?:/[01]?\d|/2[0-3])?))*)\s+" # hour
120122
r"((?:0?[1-9]|[12]\d|3[01]|(?:0?[1-9]|[12]\d|3[01])-(?:0?[1-9]|[12]\d|3[01])|\*(?:/[0-9]|/1[0-9]|/2[0-9]|/3[01])?)(?:,(?:0?[1-9]|[12]\d|3[01]|(?:0?[1-9]|[12]\d|3[01])-(?:0?[1-9]|[12]\d|3[01])|\*(?:/[0-9]|/1[0-9]|/2[0-9]|/3[01])?))*)\s+" # day of month
121123
r"((?:[1-9]|1[0-2]|(?:[1-9]|1[0-2])-(?:[1-9]|1[0-2])|\*(?:/[1-9]|/1[0-2])?)(?:,(?:[1-9]|1[0-2]|(?:[1-9]|1[0-2])-(?:[1-9]|1[0-2])|\*(?:/[1-9]|/1[0-2])?))*)\s+" # month
122124
r"((?:[0-6]|(?:[0-6])-(?:[0-6])|\*(?:/[0-6])?)(?:,(?:[0-6]|(?:[0-6])-(?:[0-6])|\*(?:/[0-6])?))*)" # day of week
123-
r"$|^$" # Empty string to signify no schedule
125+
r"$|^$" # Empty string to signify no schedule
124126
)
125127

126128

@@ -129,7 +131,7 @@ class ResourceConfigWithSchedule(ResourceConfig):
129131
default="",
130132
title="Schedule",
131133
description="Schedule to automatically rebackfill this binding. Accepts a cron expression.",
132-
pattern=CRON_REGEX
134+
pattern=CRON_REGEX,
133135
)
134136

135137

@@ -164,8 +166,7 @@ class Backfill(BaseModel, extra="forbid"):
164166
description="LogCursor at which incremental replication began"
165167
)
166168
next_page: PageCursor = Field(
167-
description="PageCursor of the next page to fetch",
168-
default=None
169+
description="PageCursor of the next page to fetch", default=None
169170
)
170171

171172
class Snapshot(BaseModel, extra="forbid"):
@@ -179,18 +180,20 @@ class Snapshot(BaseModel, extra="forbid"):
179180
description="The xxh3_128 hex digest of documents of this resource in the last snapshot"
180181
)
181182

182-
inc: Incremental | None = Field(
183+
inc: Incremental | dict[str, Incremental | None] | None = Field(
183184
default=None, description="Incremental capture progress"
184185
)
185186

186-
backfill: Backfill | None = Field(
187+
backfill: Backfill | dict[str, Backfill | None] | None = Field(
187188
default=None,
188189
description="Backfill progress, or None if no backfill is occurring",
189190
)
190191

191192
snapshot: Snapshot | None = Field(default=None, description="Snapshot progress")
192193

193-
last_initialized: datetime | None = Field(default=None, description="The last time this state was initialized.")
194+
last_initialized: datetime | None = Field(
195+
default=None, description="The last time this state was initialized."
196+
)
194197

195198

196199
_ResourceState = TypeVar("_ResourceState", bound=ResourceState)
@@ -213,6 +216,7 @@ class AssociatedDocument(Generic[_BaseDocument]):
213216
You might use this if your data model requires you to load "child" documents when capturing a "parent" document,
214217
instead of independently loading the child data stream.
215218
"""
219+
216220
doc: _BaseDocument
217221
binding: int
218222

@@ -317,7 +321,7 @@ class FixedSchema:
317321
CaptureBinding[_ResourceConfig],
318322
"Resource[_BaseDocument, _ResourceConfig, _ResourceState]",
319323
]
320-
]
324+
],
321325
],
322326
None,
323327
]
@@ -363,7 +367,6 @@ def resolve_bindings(
363367
resources: list[Resource[Any, _BaseResourceConfig, Any]],
364368
resource_term="Resource",
365369
) -> list[tuple[_ResolvableBinding, Resource[Any, _BaseResourceConfig, Any]]]:
366-
367370
resolved: list[
368371
tuple[_ResolvableBinding, Resource[Any, _BaseResourceConfig, Any]]
369372
] = []
@@ -397,7 +400,6 @@ def validated(
397400
]
398401
],
399402
) -> response.Validated:
400-
401403
return response.Validated(
402404
bindings=[
403405
response.ValidatedBinding(resourcePath=b[0].resourceConfig.path())
@@ -415,7 +417,6 @@ def open(
415417
]
416418
],
417419
) -> tuple[response.Opened, Callable[[Task], Awaitable[None]]]:
418-
419420
async def _run(task: Task):
420421
backfill_requests = []
421422
if open.state.backfillRequests is not None:
@@ -445,17 +446,20 @@ async def _run(task: Task):
445446
if state.last_initialized is None:
446447
state.last_initialized = datetime.now(tz=UTC)
447448
task.checkpoint(
448-
ConnectorState(
449-
bindingStateV1={binding.stateKey: state}
450-
)
449+
ConnectorState(bindingStateV1={binding.stateKey: state})
451450
)
452451

453452
if isinstance(binding.resourceConfig, ResourceConfigWithSchedule):
454453
cron_schedule = binding.resourceConfig.schedule
455-
next_scheduled_initialization = next_fire(cron_schedule, state.last_initialized)
454+
next_scheduled_initialization = next_fire(
455+
cron_schedule, state.last_initialized
456+
)
456457

457-
if next_scheduled_initialization and next_scheduled_initialization < datetime.now(tz=UTC):
458-
# Re-initialize the binding if we missed a scheduled re-initialization.
458+
if (
459+
next_scheduled_initialization
460+
and next_scheduled_initialization < datetime.now(tz=UTC)
461+
):
462+
# Re-initialize the binding if we missed a scheduled re-initialization.
459463
should_initialize = True
460464
if state.backfill:
461465
task.log.warning(
@@ -464,12 +468,22 @@ async def _run(task: Task):
464468
" complete before the next scheduled backfill starts."
465469
)
466470

467-
next_scheduled_initialization = next_fire(cron_schedule, datetime.now(tz=UTC))
471+
next_scheduled_initialization = next_fire(
472+
cron_schedule, datetime.now(tz=UTC)
473+
)
468474

469-
if next_scheduled_initialization and soonest_future_scheduled_initialization:
470-
soonest_future_scheduled_initialization = min(soonest_future_scheduled_initialization, next_scheduled_initialization)
475+
if (
476+
next_scheduled_initialization
477+
and soonest_future_scheduled_initialization
478+
):
479+
soonest_future_scheduled_initialization = min(
480+
soonest_future_scheduled_initialization,
481+
next_scheduled_initialization,
482+
)
471483
elif next_scheduled_initialization:
472-
soonest_future_scheduled_initialization = next_scheduled_initialization
484+
soonest_future_scheduled_initialization = (
485+
next_scheduled_initialization
486+
)
473487

474488
if should_initialize:
475489
# Checkpoint the binding's initialized state prior to any processing.
@@ -478,7 +492,7 @@ async def _run(task: Task):
478492

479493
task.checkpoint(
480494
ConnectorState(
481-
bindingStateV1={binding.stateKey: state}
495+
bindingStateV1={binding.stateKey: state},
482496
)
483497
)
484498

@@ -487,7 +501,7 @@ async def _run(task: Task):
487501
index,
488502
state,
489503
task,
490-
resolved_bindings
504+
resolved_bindings,
491505
)
492506

493507
async def scheduled_stop(future_dt: datetime | None) -> None:
@@ -510,8 +524,12 @@ def open_binding(
510524
binding_index: int,
511525
state: _ResourceState,
512526
task: Task,
513-
fetch_changes: FetchChangesFn[_BaseDocument] | None = None,
514-
fetch_page: FetchPageFn[_BaseDocument] | None = None,
527+
fetch_changes: FetchChangesFn[_BaseDocument]
528+
| dict[str, FetchChangesFn[_BaseDocument]]
529+
| None = None,
530+
fetch_page: FetchPageFn[_BaseDocument]
531+
| dict[str, FetchPageFn[_BaseDocument]]
532+
| None = None,
515533
fetch_snapshot: FetchSnapshotFn[_BaseDocument] | None = None,
516534
tombstone: _BaseDocument | None = None,
517535
):
@@ -520,30 +538,96 @@ def open_binding(
520538
521539
It does 'heavy lifting' to actually capture a binding.
522540
523-
TODO(johnny): Separate into snapshot vs incremental tasks?
541+
When fetch_changes, fetch_page, or fetch_snapshot are provided as dictionaries,
542+
each function will be run as a separate subtask with its own independent state.
543+
The dictionary keys are used as subtask IDs and are used to store and retrieve
544+
the state for each subtask in state.inc, state.backfill, or state.snapshot.
524545
"""
525546

526547
prefix = ".".join(binding.resourceConfig.path())
527548

528549
if fetch_changes:
529550

530-
async def closure(task: Task):
531-
assert state.inc
551+
async def incremental_closure(
552+
task: Task,
553+
fetch_changes: FetchChangesFn[_BaseDocument],
554+
state: ResourceState.Incremental,
555+
):
556+
assert state and not isinstance(state, dict)
532557
await _binding_incremental_task(
533-
binding, binding_index, fetch_changes, state.inc, task,
558+
binding,
559+
binding_index,
560+
fetch_changes,
561+
state,
562+
task,
534563
)
535564

536-
task.spawn_child(f"{prefix}.incremental", closure)
565+
if isinstance(fetch_changes, dict):
566+
assert state.inc and isinstance(state.inc, dict)
567+
for subtask_id, subtask_fetch_changes in fetch_changes.items():
568+
inc_state = state.inc.get(subtask_id)
569+
assert inc_state
570+
571+
task.spawn_child(
572+
f"{prefix}.incremental.{subtask_id}",
573+
functools.partial(
574+
incremental_closure,
575+
fetch_changes=subtask_fetch_changes,
576+
state=inc_state,
577+
),
578+
)
579+
else:
580+
assert state.inc and not isinstance(state.inc, dict)
581+
task.spawn_child(
582+
f"{prefix}.incremental",
583+
functools.partial(
584+
incremental_closure,
585+
fetch_changes=fetch_changes,
586+
state=state.inc,
587+
),
588+
)
537589

538590
if fetch_page and state.backfill:
539591

540-
async def closure(task: Task):
541-
assert state.backfill
592+
async def backfill_closure(
593+
task: Task,
594+
fetch_page: FetchPageFn[_BaseDocument],
595+
state: ResourceState.Backfill,
596+
):
597+
assert state and not isinstance(state, dict)
542598
await _binding_backfill_task(
543-
binding, binding_index, fetch_page, state.backfill, task,
599+
binding,
600+
binding_index,
601+
fetch_page,
602+
state,
603+
task,
544604
)
545605

546-
task.spawn_child(f"{prefix}.backfill", closure)
606+
if isinstance(fetch_page, dict):
607+
assert state.backfill and isinstance(state.backfill, dict)
608+
for subtask_id, subtask_fetch_page in fetch_page.items():
609+
backfill_state = state.backfill.get(subtask_id)
610+
assert backfill_state
611+
612+
task.spawn_child(
613+
f"{prefix}.backfill.{subtask_id}",
614+
functools.partial(
615+
backfill_closure,
616+
fetch_page=subtask_fetch_page,
617+
state=backfill_state,
618+
),
619+
)
620+
621+
else:
622+
assert state.backfill and not isinstance(state.backfill, dict)
623+
task.spawn_child(
624+
f"{prefix}.backfill",
625+
functools.partial(
626+
backfill_closure,
627+
fetch_page=fetch_page,
628+
state=state.backfill,
629+
),
630+
)
547631

548632
if fetch_snapshot:
549633

@@ -612,7 +696,7 @@ async def _binding_snapshot_task(
612696
if isinstance(doc, dict):
613697
doc["meta_"] = {
614698
"op": "u" if count < state.last_count else "c",
615-
"row_id": count
699+
"row_id": count,
616700
}
617701
else:
618702
doc.meta_ = BaseDocument.Meta(
@@ -719,7 +803,10 @@ async def _binding_incremental_task(
719803

720804
if lag < binding.resourceConfig.interval:
721805
sleep_for = binding.resourceConfig.interval - lag
722-
task.log.info("incremental task ran recently, sleeping until `interval` has fully elapsed", {"sleep_for": sleep_for, "interval": binding.resourceConfig.interval})
806+
task.log.info(
807+
"incremental task ran recently, sleeping until `interval` has fully elapsed",
808+
{"sleep_for": sleep_for, "interval": binding.resourceConfig.interval},
809+
)
723810

724811
while True:
725812
try:
@@ -747,9 +834,7 @@ async def _binding_incremental_task(
747834
task.log.info("incremental task triggered backfill")
748835
task.stopping.event.set()
749836
task.checkpoint(
750-
ConnectorState(
751-
backfillRequests={binding.stateKey: True}
752-
)
837+
ConnectorState(backfillRequests={binding.stateKey: True})
753838
)
754839
return
755840
else:
@@ -759,7 +844,12 @@ async def _binding_incremental_task(
759844
is_larger = item > state.cursor
760845
elif isinstance(item, datetime) and isinstance(state.cursor, datetime):
761846
is_larger = item > state.cursor
762-
elif isinstance(item, tuple) and isinstance(state.cursor, tuple) and isinstance(item[0], str) and isinstance(state.cursor[0], str):
847+
elif (
848+
isinstance(item, tuple)
849+
and isinstance(state.cursor, tuple)
850+
and isinstance(item[0], str)
851+
and isinstance(state.cursor[0], str)
852+
):
763853
is_larger = item[0] > state.cursor[0]
764854
else:
765855
raise RuntimeError(
@@ -786,7 +876,7 @@ async def _binding_incremental_task(
786876
sleep_for = binding.resourceConfig.interval
787877

788878
elif isinstance(state.cursor, datetime):
789-
lag = (datetime.now(tz=UTC) - state.cursor)
879+
lag = datetime.now(tz=UTC) - state.cursor
790880

791881
if lag > binding.resourceConfig.interval:
792882
# We're not idle. Attempt to fetch the next changes.
@@ -800,4 +890,6 @@ async def _binding_incremental_task(
800890
sleep_for = timedelta()
801891
continue
802892

803-
task.log.debug("incremental task is idle", {"sleep_for": sleep_for, "cursor": state.cursor})
893+
task.log.debug(
894+
"incremental task is idle", {"sleep_for": sleep_for, "cursor": state.cursor}
895+
)

0 commit comments

Comments (0)