poc: migrate to concurrent cursors in declarative package #475

Draft: wants to merge 2 commits into main
27 changes: 3 additions & 24 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -361,28 +361,15 @@ def _group_streams(
== DatetimeBasedCursorModel.__name__
and hasattr(declarative_stream.retriever, "stream_slicer")
and isinstance(
- declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor
+ declarative_stream.retriever.stream_slicer, ConcurrentPerPartitionCursor
)
):
stream_state = self._connector_state_manager.get_stream_state(
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
)
stream_state = self._migrate_state(declarative_stream, stream_state)

- partition_router = declarative_stream.retriever.stream_slicer._partition_router

- perpartition_cursor = (
-     self._constructor.create_concurrent_cursor_from_perpartition_cursor(
-         state_manager=self._connector_state_manager,
-         model_type=DatetimeBasedCursorModel,
-         component_definition=incremental_sync_component_definition,
-         stream_name=declarative_stream.name,
-         stream_namespace=declarative_stream.namespace,
-         config=config or {},
-         stream_state=stream_state,
-         partition_router=partition_router,
-     )
- )
+ perpartition_cursor = declarative_stream.retriever.stream_slicer

retriever = self._get_retriever(declarative_stream, stream_state)

@@ -464,15 +451,7 @@ def _get_retriever(
if retriever.cursor:
retriever.cursor.set_initial_state(stream_state=stream_state)

- # Similar to above, the ClientSideIncrementalRecordFilterDecorator cursor is a separate instance
- # from the one initialized on the SimpleRetriever, so it must also have state initialized
- # for semi-incremental streams using is_client_side_incremental to filter properly
- if isinstance(retriever.record_selector, RecordSelector) and isinstance(
-     retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
- ):
-     retriever.record_selector.record_filter._cursor.set_initial_state(
-         stream_state=stream_state
-     )  # type: ignore # After non-concurrent cursors are deprecated we can remove these cursor workarounds
+ # FIXME comment: Removing this as the concurrent state should already have the information

# We zero it out here, but since this is a cursor reference, the state is still properly
# instantiated for the other components that reference it
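
For context, a toy sketch of the shared-reference point the comment above makes; the class names are invented and this is not CDK code:

```python
# Toy illustration only: zeroing one component's reference to a shared cursor
# does not affect the other components that hold the same object.
class Cursor:
    def __init__(self, state):
        self.state = state

class Retriever:
    def __init__(self, cursor):
        self.cursor = cursor

shared = Cursor({"updated_at": "2024-01-01"})
retriever = Retriever(shared)
record_filter_cursor = shared  # another component, same object
retriever.cursor = None        # "zero it out" on the retriever
assert record_filter_cursor.state == {"updated_at": "2024-01-01"}  # still initialized
```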
airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
@@ -11,11 +11,18 @@
from datetime import timedelta
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional

+ from airbyte_cdk.models import (
+     AirbyteStateBlob,
+     AirbyteStateMessage,
+     AirbyteStateType,
+     AirbyteStreamState,
+     StreamDescriptor,
+ )
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
Timer,
iterate_with_last_flag_and_state,
- )
+ )  # FIXME since it relies on the declarative package, this can generate circular import errors
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import (
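
The FIXME above flags a possible import cycle. A hedged sketch of one conventional workaround, assuming the cycle runs through the declarative package: defer the import until call time so both modules finish loading before the name is resolved. This is illustrative, not the PR's fix.

```python
# Sketch only: a lazy wrapper that avoids importing the declarative package
# at module import time.
def _iterate_with_last_flag_and_state(generator, get_stream_state_func):
    from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
        iterate_with_last_flag_and_state,
    )
    return iterate_with_last_flag_and_state(generator, get_stream_state_func)
```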
@@ -150,6 +157,7 @@ def close_partition(self, partition: Partition) -> None:
raise ValueError("stream_slice cannot be None")

partition_key = self._to_partition_key(stream_slice.partition)
logger.warning(f"close_partition... Semaphore for {partition_key}")
with self._lock:
self._semaphore_per_partition[partition_key].acquire()
if not self._use_global_cursor:
@@ -204,6 +212,7 @@ def _check_and_update_parent_state(self) -> None:
for p_key in list(self._semaphore_per_partition.keys()):
sem = self._semaphore_per_partition[p_key]
if p_key in self._finished_partitions and sem._value == 0:
logger.warning(f"_check_and_update_parent_state delete semaphore for {p_key}")
del self._semaphore_per_partition[p_key]
logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0")
if p_key == earliest_key:
@@ -261,6 +270,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
slices, self._partition_router.get_stream_state
):
yield from self._generate_slices_from_partition(partition, parent_state)
+ self._parent_state = self._partition_router.get_stream_state()

def _generate_slices_from_partition(
self, partition: StreamSlice, parent_state: Mapping[str, Any]
@@ -289,13 +299,15 @@ def _generate_slices_from_partition(
]
!= parent_state
):
print(f"GODO:\n\t{parent_state}") # FIXME parent state needs to be tracked in substream partition router
Contributor review comment:
⚠️ Potential issue

Debug print statement needs to be removed before merging

There's a debug print statement (tagged GODO) that should be removed before production.

-                print(f"GODO:\n\t{parent_state}")  # FIXME parent state needs to be tracked in substream partition router
+                # parent state needs to be tracked in substream partition router

Would you like to convert this to a proper logging statement instead?

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"GODO:\n\t{parent_state}") # FIXME parent state needs to be tracked in substream partition router
# parent state needs to be tracked in substream partition router
🧰 Tools
🪛 GitHub Actions: Linters

[warning] Would reformat: airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py
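
A minimal sketch of the logging-based replacement the reviewer suggests, assuming the module-level `logger` already used elsewhere in this file; the message text is illustrative:

```python
# Hedged sketch, not the PR's code: route the debug output through the module
# logger instead of print(); partition_key is in scope at this point in
# _generate_slices_from_partition.
logger.debug(
    "Parent state for partition %s: %s", partition_key, parent_state
)  # FIXME parent state needs to be tracked in substream partition router
```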

self._partition_parent_state_map[partition_key] = deepcopy(parent_state)

for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
cursor.stream_slices(),
lambda: None,
):
self._semaphore_per_partition[partition_key].release()
logger.warning(f"Generating... Semaphore for {partition_key} is {self._semaphore_per_partition[partition_key]._value}")
if is_last_slice:
self._finished_partitions.add(partition_key)
yield StreamSlice(
@@ -418,7 +430,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
self._parent_state = stream_state["parent_state"]

# Set parent state for partition routers based on parent streams
- self._partition_router.set_initial_state(stream_state)
+ self._partition_router.set_initial_state(stream_state)  # FIXME can we remove this? this would probably be a breaking change though...

def _set_global_state(self, stream_state: Mapping[str, Any]) -> None:
"""
@@ -489,10 +501,31 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor:
partition_key = self._to_partition_key(record.associated_slice.partition)
if partition_key not in self._cursor_per_partition:
raise ValueError(
"Invalid state as stream slices that are emitted should refer to an existing cursor"
f"Invalid state as stream slices that are emitted should refer to an existing cursor but {partition_key} is unknown"
)
cursor = self._cursor_per_partition[partition_key]
return cursor

def limit_reached(self) -> bool:
return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT

+ @staticmethod
+ def get_parent_state(stream_state: Optional[StreamState], parent_stream_name: str) -> Optional[AirbyteStateMessage]:
+     return AirbyteStateMessage(
+         type=AirbyteStateType.STREAM,
+         stream=AirbyteStreamState(
+             stream_descriptor=StreamDescriptor(parent_stream_name, None),
+             stream_state=AirbyteStateBlob(stream_state["parent_state"][parent_stream_name]),
+         ),
+     ) if stream_state and "parent_state" in stream_state else None
+
+ @staticmethod
+ def get_global_state(stream_state: Optional[StreamState], parent_stream_name: str) -> Optional[AirbyteStateMessage]:
+     return AirbyteStateMessage(
+         type=AirbyteStateType.STREAM,
+         stream=AirbyteStreamState(
+             stream_descriptor=StreamDescriptor(parent_stream_name, None),
+             stream_state=AirbyteStateBlob(stream_state["state"]),
+         ),
+     ) if stream_state and "state" in stream_state else None
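
A hedged usage sketch of the two new helpers, assuming a stream state shaped like the one this cursor reads; the stream name and state values below are invented for illustration:

```python
# Illustrative only: the payload and "posts" stream name are made up.
stream_state = {
    "state": {"updated_at": "2024-01-01T00:00:00Z"},  # global cursor value
    "parent_state": {"posts": {"updated_at": "2023-12-31T00:00:00Z"}},
}

parent_msg = ConcurrentPerPartitionCursor.get_parent_state(stream_state, "posts")
# -> AirbyteStateMessage for stream "posts" wrapping the per-parent state blob

global_msg = ConcurrentPerPartitionCursor.get_global_state(stream_state, "posts")
# -> AirbyteStateMessage wrapping the top-level "state" blob

missing = ConcurrentPerPartitionCursor.get_parent_state(None, "posts")  # -> None
```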

airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py
@@ -1,7 +1,7 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

+ import logging
import threading
import time
from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union
@@ -12,7 +12,7 @@
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState

T = TypeVar("T")

+ logger = logging.getLogger(__name__)

def iterate_with_last_flag_and_state(
generator: Iterable[T], get_stream_state_func: Callable[[], Optional[Mapping[str, StreamState]]]
@@ -40,6 +40,7 @@ def iterate_with_last_flag_and_state(
return # Return an empty iterator

for next_item in iterator:
logger.info(f"slice: {current}, state: {state}")
yield current, False, state
current = next_item
state = get_stream_state_func()
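
For context, a hedged sketch of how the (item, is_last, state) triples from iterate_with_last_flag_and_state behave, assuming the function closes with a final `yield current, True, state` as the visible prefix suggests; the inputs are invented:

```python
# Illustrative only: three slices and a state getter that returns a counter.
counter = iter(range(100))
triples = list(
    iterate_with_last_flag_and_state(["a", "b", "c"], lambda: {"n": next(counter)})
)
# Each item is paired with the state captured right after it was pulled,
# and only the final item carries the last-flag:
# [("a", False, {"n": 0}), ("b", False, {"n": 1}), ("c", True, {"n": 2})]
```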