diff --git a/airbyte_cdk/__init__.py b/airbyte_cdk/__init__.py index 262d162cc..df6e4c873 100644 --- a/airbyte_cdk/__init__.py +++ b/airbyte_cdk/__init__.py @@ -176,7 +176,7 @@ InternalConfig, ResourceSchemaLoader, check_config_against_spec_or_exit, - expand_refs, + expand_refs, # noqa: F401 split_config, ) from .sources.utils.transform import TransformConfig, TypeTransformer @@ -187,6 +187,7 @@ from .utils.spec_schema_transformations import resolve_refs from .utils.stream_status_utils import as_airbyte_message + __all__ = [ # Availability strategy "AvailabilityStrategy", @@ -200,7 +201,6 @@ "ConcurrentSourceAdapter", "Cursor", "CursorField", - "DEFAULT_CONCURRENCY", "EpochValueConcurrentStreamStateConverter", "FinalStateCursor", "IsoMillisConcurrentStreamStateConverter", @@ -258,7 +258,6 @@ "RequestOption", "RequestOptionType", "Requester", - "ResponseStatus", "SimpleRetriever", "SinglePartitionRouter", "StopConditionPaginationStrategyDecorator", @@ -276,13 +275,11 @@ "DefaultBackoffException", "default_backoff_handler", "HttpAPIBudget", - "HttpAuthenticator", "HttpRequestMatcher", "HttpStream", "HttpSubStream", "LimiterSession", "MovingWindowCallRatePolicy", - "MultipleTokenAuthenticator", "Oauth2Authenticator", "Rate", "SingleUseRefreshTokenOauth2Authenticator", @@ -317,7 +314,6 @@ # Stream "IncrementalMixin", "Stream", - "StreamData", "package_name_from_class", # Utils "AirbyteTracedException", @@ -354,5 +350,5 @@ third_choice=_dunamai.Version.from_any_vcs, fallback=_dunamai.Version("0.0.0+dev"), ).serialize() -except: +except: # noqa: E722 __version__ = "0.0.0+dev" diff --git a/airbyte_cdk/cli/source_declarative_manifest/__init__.py b/airbyte_cdk/cli/source_declarative_manifest/__init__.py index 0ea86fa7b..ea87aca1d 100644 --- a/airbyte_cdk/cli/source_declarative_manifest/__init__.py +++ b/airbyte_cdk/cli/source_declarative_manifest/__init__.py @@ -1,5 +1,7 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. from airbyte_cdk.cli.source_declarative_manifest._run import run + __all__ = [ "run", ] diff --git a/airbyte_cdk/cli/source_declarative_manifest/_run.py b/airbyte_cdk/cli/source_declarative_manifest/_run.py index 5def00602..44cf69016 100644 --- a/airbyte_cdk/cli/source_declarative_manifest/_run.py +++ b/airbyte_cdk/cli/source_declarative_manifest/_run.py @@ -43,7 +43,7 @@ ConcurrentDeclarativeSource, ) from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource -from airbyte_cdk.sources.source import TState +from airbyte_cdk.sources.source import TState # noqa: TC001 class SourceLocalYaml(YamlDeclarativeSource): @@ -56,7 +56,7 @@ def __init__( catalog: ConfiguredAirbyteCatalog | None, config: Mapping[str, Any] | None, state: TState, - **kwargs: Any, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> None: """ HACK! 
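Note on the `# noqa` markers that start appearing above and recur throughout this patch: they look like per-line suppressions of Ruff/flake8 rule codes (for example F401 = unused import, E722 = bare `except:`, ANN401 = `Any`-typed argument, ARG002 = unused method argument), added so existing behaviour is preserved while the lint rules are enabled. A minimal sketch of why E722 exists, using a hypothetical helper rather than the real dunamai lookup:

    def version_or_fallback() -> str:  # hypothetical helper, for illustration only
        try:
            raise RuntimeError("no VCS metadata available")  # stand-in for the version lookup failing
        except Exception:  # unlike a bare `except:`, this lets SystemExit/KeyboardInterrupt propagate
            return "0.0.0+dev"

    print(version_or_fallback())  # -> 0.0.0+dev

The patch keeps the bare `except:` and silences the rule instead, which leaves runtime behaviour untouched.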
@@ -77,7 +77,7 @@ def __init__( ) -def _is_local_manifest_command(args: list[str]) -> bool: +def _is_local_manifest_command(args: list[str]) -> bool: # noqa: ARG001 # Check for a local manifest.yaml file return Path("/airbyte/integration_code/source_declarative_manifest/manifest.yaml").exists() @@ -111,7 +111,7 @@ def _get_local_yaml_source(args: list[str]) -> SourceLocalYaml: ) ).decode() ) - raise error + raise error # noqa: TRY201 def handle_local_manifest_command(args: list[str]) -> None: @@ -167,12 +167,12 @@ def create_declarative_source( state: list[AirbyteStateMessage] config, catalog, state = _parse_inputs_into_config_catalog_state(args) if config is None or "__injected_declarative_manifest" not in config: - raise ValueError( + raise ValueError( # noqa: TRY301 "Invalid config: `__injected_declarative_manifest` should be provided at the root " f"of the config but config only has keys: {list(config.keys() if config else [])}" ) if not isinstance(config["__injected_declarative_manifest"], dict): - raise ValueError( + raise ValueError( # noqa: TRY004, TRY301 "Invalid config: `__injected_declarative_manifest` should be a dictionary, " f"but got type: {type(config['__injected_declarative_manifest'])}" ) @@ -181,7 +181,7 @@ def create_declarative_source( config=config, catalog=catalog, state=state, - source_config=cast(dict[str, Any], config["__injected_declarative_manifest"]), + source_config=cast(dict[str, Any], config["__injected_declarative_manifest"]), # noqa: TC006 ) except Exception as error: print( @@ -201,7 +201,7 @@ def create_declarative_source( ) ).decode() ) - raise error + raise error # noqa: TRY201 def _parse_inputs_into_config_catalog_state( diff --git a/airbyte_cdk/config_observation.py b/airbyte_cdk/config_observation.py index ae85e8277..1233786a1 100644 --- a/airbyte_cdk/config_observation.py +++ b/airbyte_cdk/config_observation.py @@ -7,8 +7,9 @@ ) import time +from collections.abc import MutableMapping from copy import copy -from typing import Any, List, MutableMapping +from typing import Any import orjson @@ -27,7 +28,7 @@ def __init__( self, non_observed_mapping: MutableMapping[Any, Any], observer: ConfigObserver, - update_on_unchanged_value: bool = True, + update_on_unchanged_value: bool = True, # noqa: FBT001, FBT002 ) -> None: non_observed_mapping = copy(non_observed_mapping) self.observer = observer @@ -38,13 +39,13 @@ def __init__( non_observed_mapping[item] = ObservedDict(value, observer) # Observe nested list of dicts - if isinstance(value, List): + if isinstance(value, list): for i, sub_value in enumerate(value): if isinstance(sub_value, MutableMapping): value[i] = ObservedDict(sub_value, observer) super().__init__(non_observed_mapping) - def __setitem__(self, item: Any, value: Any) -> None: + def __setitem__(self, item: Any, value: Any) -> None: # noqa: ANN401 """Override dict.__setitem__ by: 1. Observing the new value if it is a dict 2. 
Call observer update if the new value is different from the previous one @@ -52,11 +53,11 @@ def __setitem__(self, item: Any, value: Any) -> None: previous_value = self.get(item) if isinstance(value, MutableMapping): value = ObservedDict(value, self.observer) - if isinstance(value, List): + if isinstance(value, list): for i, sub_value in enumerate(value): if isinstance(sub_value, MutableMapping): value[i] = ObservedDict(sub_value, self.observer) - super(ObservedDict, self).__setitem__(item, value) + super(ObservedDict, self).__setitem__(item, value) # noqa: UP008 if self.update_on_unchanged_value or value != previous_value: self.observer.update() @@ -77,7 +78,7 @@ def observe_connector_config( non_observed_connector_config: MutableMapping[str, Any], ) -> ObservedDict: if isinstance(non_observed_connector_config, ObservedDict): - raise ValueError("This connector configuration is already observed") + raise ValueError("This connector configuration is already observed") # noqa: TRY004 connector_config_observer = ConfigObserver() observed_connector_config = ObservedDict( non_observed_connector_config, connector_config_observer diff --git a/airbyte_cdk/connector.py b/airbyte_cdk/connector.py index 342ecee2d..e534e5d49 100644 --- a/airbyte_cdk/connector.py +++ b/airbyte_cdk/connector.py @@ -8,7 +8,8 @@ import os import pkgutil from abc import ABC, abstractmethod -from typing import Any, Generic, Mapping, Optional, Protocol, TypeVar +from collections.abc import Mapping +from typing import Any, Generic, Protocol, TypeVar import yaml @@ -19,7 +20,7 @@ ) -def load_optional_package_file(package: str, filename: str) -> Optional[bytes]: +def load_optional_package_file(package: str, filename: str) -> bytes | None: """Gets a resource from a package, returning None if it does not exist""" try: return pkgutil.get_data(package, filename) @@ -45,29 +46,28 @@ def read_config(config_path: str) -> Mapping[str, Any]: config = BaseConnector._read_json_file(config_path) if isinstance(config, Mapping): return config - else: - raise ValueError( - f"The content of {config_path} is not an object and therefore is not a valid config. Please ensure the file represent a config." - ) + raise ValueError( + f"The content of {config_path} is not an object and therefore is not a valid config. Please ensure the file represent a config." + ) @staticmethod - def _read_json_file(file_path: str) -> Any: - with open(file_path, "r") as file: + def _read_json_file(file_path: str) -> Any: # noqa: ANN401 + with open(file_path) as file: # noqa: FURB101, PLW1514, PTH123 contents = file.read() try: return json.loads(contents) except json.JSONDecodeError as error: - raise ValueError( + raise ValueError( # noqa: B904 f"Could not read json file {file_path}: {error}. Please ensure that it is a valid JSON." ) @staticmethod def write_config(config: TConfig, config_path: str) -> None: - with open(config_path, "w") as fh: + with open(config_path, "w") as fh: # noqa: FURB103, PLW1514, PTH123 fh.write(json.dumps(config)) - def spec(self, logger: logging.Logger) -> ConnectorSpecification: + def spec(self, logger: logging.Logger) -> ConnectorSpecification: # noqa: ARG002 """ Returns the spec for this integration. The spec is a JSON-Schema object describing the required configurations (e.g: username and password) required to run this integration. By default, this will be loaded from a "spec.yaml" or a "spec.json" in the package root. 
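The dominant mechanical change in the hunks above, and in most files below, is an annotation rewrite: `typing.List`/`Dict`/`Tuple`/`Optional`/`Union` become built-in generics and PEP 604 unions, and abstract container types are imported from `collections.abc` instead of `typing`. A condensed before/after sketch of the pattern, reusing the `connector.py` signatures from this file and assuming a runtime (Python 3.10+) where `X | None` is valid in annotations:

    # Before:
    #   from typing import Any, List, Mapping, Optional
    #   def load_optional_package_file(package: str, filename: str) -> Optional[bytes]: ...
    # After:
    from collections.abc import Mapping
    from typing import Any

    def load_optional_package_file(package: str, filename: str) -> bytes | None:
        return None  # body elided; only the annotation style changes

    def read_config(config_path: str) -> Mapping[str, Any]:
        return {"config_path": config_path}

Behaviour is unchanged; only the spelling of the type hints differs.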
@@ -89,7 +89,7 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: try: spec_obj = json.loads(json_spec) except json.JSONDecodeError as error: - raise ValueError( + raise ValueError( # noqa: B904 f"Could not read json spec file: {error}. Please ensure that it is a valid JSON." ) else: @@ -115,7 +115,7 @@ class DefaultConnectorMixin: def configure( self: _WriteConfigProtocol, config: Mapping[str, Any], temp_dir: str ) -> Mapping[str, Any]: - config_path = os.path.join(temp_dir, "config.json") + config_path = os.path.join(temp_dir, "config.json") # noqa: PTH118 self.write_config(config, config_path) return config diff --git a/airbyte_cdk/connector_builder/connector_builder_handler.py b/airbyte_cdk/connector_builder/connector_builder_handler.py index b2a728570..028c97a66 100644 --- a/airbyte_cdk/connector_builder/connector_builder_handler.py +++ b/airbyte_cdk/connector_builder/connector_builder_handler.py @@ -3,8 +3,9 @@ # import dataclasses +from collections.abc import Mapping from datetime import datetime -from typing import Any, List, Mapping +from typing import Any from airbyte_cdk.connector_builder.message_grouper import MessageGrouper from airbyte_cdk.models import ( @@ -23,6 +24,7 @@ from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException + DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5 DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5 DEFAULT_MAXIMUM_RECORDS = 100 @@ -69,7 +71,7 @@ def read_stream( source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, - state: List[AirbyteStateMessage], + state: list[AirbyteStateMessage], limits: TestReadLimits, ) -> AirbyteMessage: try: @@ -90,7 +92,7 @@ def read_stream( error = AirbyteTracedException.from_exception( exc, message=filter_secrets( - f"Error reading stream with config={config} and catalog={configured_catalog}: {str(exc)}" + f"Error reading stream with config={config} and catalog={configured_catalog}: {exc!s}" ), ) return error.as_airbyte_message() @@ -108,7 +110,7 @@ def resolve_manifest(source: ManifestDeclarativeSource) -> AirbyteMessage: ) except Exception as exc: error = AirbyteTracedException.from_exception( - exc, message=f"Error resolving manifest: {str(exc)}" + exc, message=f"Error resolving manifest: {exc!s}" ) return error.as_airbyte_message() diff --git a/airbyte_cdk/connector_builder/main.py b/airbyte_cdk/connector_builder/main.py index e122cee8c..453c6fffb 100644 --- a/airbyte_cdk/connector_builder/main.py +++ b/airbyte_cdk/connector_builder/main.py @@ -4,7 +4,8 @@ import sys -from typing import Any, List, Mapping, Optional, Tuple +from collections.abc import Mapping +from typing import Any import orjson @@ -30,8 +31,8 @@ def get_config_and_catalog_from_args( - args: List[str], -) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]: + args: list[str], +) -> tuple[str, Mapping[str, Any], ConfiguredAirbyteCatalog | None, Any]: # TODO: Add functionality for the `debug` logger. # Currently, no one `debug` level log will be displayed during `read` a stream for a connector created through `connector-builder`. 
parsed_args = AirbyteEntrypoint.parse_args(args) @@ -70,22 +71,21 @@ def handle_connector_builder_request( source: ManifestDeclarativeSource, command: str, config: Mapping[str, Any], - catalog: Optional[ConfiguredAirbyteCatalog], - state: List[AirbyteStateMessage], + catalog: ConfiguredAirbyteCatalog | None, + state: list[AirbyteStateMessage], limits: TestReadLimits, ) -> AirbyteMessage: if command == "resolve_manifest": return resolve_manifest(source) - elif command == "test_read": - assert ( - catalog is not None - ), "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None." + if command == "test_read": + assert catalog is not None, ( + "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None." + ) return read_stream(source, config, catalog, state, limits) - else: - raise ValueError(f"Unrecognized command {command}.") + raise ValueError(f"Unrecognized command {command}.") -def handle_request(args: List[str]) -> str: +def handle_request(args: list[str]) -> str: command, config, catalog, state = get_config_and_catalog_from_args(args) limits = get_limits(config) source = create_source(config, limits) @@ -101,7 +101,7 @@ def handle_request(args: List[str]) -> str: print(handle_request(sys.argv[1:])) except Exception as exc: error = AirbyteTracedException.from_exception( - exc, message=f"Error handling request: {str(exc)}" + exc, message=f"Error handling request: {exc!s}" ) m = error.as_airbyte_message() print(orjson.dumps(AirbyteMessageSerializer.dump(m)).decode()) diff --git a/airbyte_cdk/connector_builder/message_grouper.py b/airbyte_cdk/connector_builder/message_grouper.py index ce43afab8..e2bf457ae 100644 --- a/airbyte_cdk/connector_builder/message_grouper.py +++ b/airbyte_cdk/connector_builder/message_grouper.py @@ -4,9 +4,10 @@ import json import logging +from collections.abc import Iterable, Iterator, Mapping from copy import deepcopy from json import JSONDecodeError -from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Union +from typing import Any from airbyte_cdk.connector_builder.models import ( AuxiliaryRequest, @@ -40,14 +41,14 @@ class MessageGrouper: logger = logging.getLogger("airbyte.connector-builder") - def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: int = 1000): + def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: int = 1000): # noqa: ANN204 self._max_pages_per_slice = max_pages_per_slice self._max_slices = max_slices self._max_record_limit = max_record_limit def _pk_to_nested_and_composite_field( - self, field: Optional[Union[str, List[str], List[List[str]]]] - ) -> List[List[str]]: + self, field: str | list[str] | list[list[str]] | None + ) -> list[list[str]]: if not field: return [[]] @@ -61,8 +62,8 @@ def _pk_to_nested_and_composite_field( return field # type: ignore # the type of field is expected to be List[List[str]] here def _cursor_field_to_nested_and_composite_field( - self, field: Union[str, List[str]] - ) -> List[List[str]]: + self, field: str | list[str] + ) -> list[list[str]]: if not field: return [[]] @@ -80,8 +81,8 @@ def get_message_groups( source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, - state: List[AirbyteStateMessage], - record_limit: Optional[int] = None, + state: list[AirbyteStateMessage], + record_limit: int | None = None, ) -> StreamRead: if record_limit is not None and not (1 <= record_limit <= self._max_record_limit): raise ValueError( @@ -113,20 +114,16 @@ def get_message_groups( ): if 
isinstance(message_group, AirbyteLogMessage): log_messages.append( - LogMessage( - **{"message": message_group.message, "level": message_group.level.value} - ) + LogMessage(message=message_group.message, level=message_group.level.value) ) elif isinstance(message_group, AirbyteTraceMessage): if message_group.type == TraceType.ERROR: log_messages.append( LogMessage( - **{ - "message": message_group.error.message, - "level": "ERROR", - "internal_message": message_group.error.internal_message, - "stacktrace": message_group.error.stack_trace, - } + message=message_group.error.message, + level="ERROR", + internal_message=message_group.error.internal_message, + stacktrace=message_group.error.stack_trace, ) ) elif isinstance(message_group, AirbyteControlMessage): @@ -140,7 +137,7 @@ def get_message_groups( elif isinstance(message_group, StreamReadSlices): slices.append(message_group) else: - raise ValueError(f"Unknown message group type: {type(message_group)}") + raise ValueError(f"Unknown message group type: {type(message_group)}") # noqa: TRY004 try: # The connector builder currently only supports reading from a single stream at a time @@ -148,7 +145,7 @@ def get_message_groups( schema = schema_inferrer.get_stream_schema(configured_stream.stream.name) except SchemaValidationException as exception: for validation_error in exception.validation_errors: - log_messages.append(LogMessage(validation_error, "ERROR")) + log_messages.append(LogMessage(validation_error, "ERROR")) # noqa: PERF401 schema = exception.schema return StreamRead( @@ -163,20 +160,18 @@ def get_message_groups( inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(), ) - def _get_message_groups( + def _get_message_groups( # noqa: PLR0912, PLR0915 self, messages: Iterator[AirbyteMessage], schema_inferrer: SchemaInferrer, datetime_format_inferrer: DatetimeFormatInferrer, limit: int, ) -> Iterable[ - Union[ - StreamReadPages, - AirbyteControlMessage, - AirbyteLogMessage, - AirbyteTraceMessage, - AuxiliaryRequest, - ] + StreamReadPages + | AirbyteControlMessage + | AirbyteLogMessage + | AirbyteTraceMessage + | AuxiliaryRequest ]: """ Message groups are partitioned according to when request log messages are received. 
Subsequent response log messages @@ -195,12 +190,12 @@ def _get_message_groups( """ records_count = 0 at_least_one_page_in_group = False - current_page_records: List[Mapping[str, Any]] = [] - current_slice_descriptor: Optional[Dict[str, Any]] = None - current_slice_pages: List[StreamReadPages] = [] - current_page_request: Optional[HttpRequest] = None - current_page_response: Optional[HttpResponse] = None - latest_state_message: Optional[Dict[str, Any]] = None + current_page_records: list[Mapping[str, Any]] = [] + current_slice_descriptor: dict[str, Any] | None = None + current_slice_pages: list[StreamReadPages] = [] + current_page_request: HttpRequest | None = None + current_page_response: HttpResponse | None = None + latest_state_message: dict[str, Any] | None = None while records_count < limit and (message := next(messages, None)): json_object = self._parse_json(message.log) if message.type == MessageType.LOG else None @@ -208,7 +203,7 @@ def _get_message_groups( raise ValueError( f"Expected log message to be a dict, got {json_object} of type {type(json_object)}" ) - json_message: Optional[Dict[str, JsonType]] = json_object + json_message: dict[str, JsonType] | None = json_object if self._need_to_close_page(at_least_one_page_in_group, message, json_message): self._close_page( current_page_request, @@ -285,25 +280,24 @@ def _get_message_groups( yield message.control elif message.type == MessageType.STATE: latest_state_message = message.state # type: ignore[assignment] - else: - if current_page_request or current_page_response or current_page_records: - self._close_page( - current_page_request, - current_page_response, - current_slice_pages, - current_page_records, - ) - yield StreamReadSlices( - pages=current_slice_pages, - slice_descriptor=current_slice_descriptor, - state=[latest_state_message] if latest_state_message else [], - ) + if current_page_request or current_page_response or current_page_records: + self._close_page( + current_page_request, + current_page_response, + current_slice_pages, + current_page_records, + ) + yield StreamReadSlices( + pages=current_slice_pages, + slice_descriptor=current_slice_descriptor, + state=[latest_state_message] if latest_state_message else [], + ) @staticmethod def _need_to_close_page( - at_least_one_page_in_group: bool, + at_least_one_page_in_group: bool, # noqa: FBT001 message: AirbyteMessage, - json_message: Optional[Dict[str, Any]], + json_message: dict[str, Any] | None, ) -> bool: return ( at_least_one_page_in_group @@ -315,20 +309,19 @@ def _need_to_close_page( ) @staticmethod - def _is_page_http_request(json_message: Optional[Dict[str, Any]]) -> bool: + def _is_page_http_request(json_message: dict[str, Any] | None) -> bool: if not json_message: return False - else: - return MessageGrouper._is_http_log( - json_message - ) and not MessageGrouper._is_auxiliary_http_request(json_message) + return MessageGrouper._is_http_log( + json_message + ) and not MessageGrouper._is_auxiliary_http_request(json_message) @staticmethod - def _is_http_log(message: Dict[str, JsonType]) -> bool: - return bool(message.get("http", False)) + def _is_http_log(message: dict[str, JsonType]) -> bool: + return bool(message.get("http")) @staticmethod - def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool: + def _is_auxiliary_http_request(message: dict[str, Any] | None) -> bool: """ A auxiliary request is a request that is performed and will not directly lead to record for the specific stream it is being queried. 
A couple of examples are: @@ -343,10 +336,10 @@ def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool: @staticmethod def _close_page( - current_page_request: Optional[HttpRequest], - current_page_response: Optional[HttpResponse], - current_slice_pages: List[StreamReadPages], - current_page_records: List[Mapping[str, Any]], + current_page_request: HttpRequest | None, + current_page_response: HttpResponse | None, + current_slice_pages: list[StreamReadPages], + current_page_records: list[Mapping[str, Any]], ) -> None: """ Close a page when parsing message groups @@ -365,7 +358,7 @@ def _read_stream( source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, - state: List[AirbyteStateMessage], + state: list[AirbyteStateMessage], ) -> Iterator[AirbyteMessage]: # the generator can raise an exception # iterate over the generated messages. if next raise an exception, catch it and yield it as an AirbyteLogMessage @@ -398,12 +391,12 @@ def _parse_json(log_message: AirbyteLogMessage) -> JsonType: # protocol change is worked on. try: json_object: JsonType = json.loads(log_message.message) - return json_object + return json_object # noqa: TRY300 except JSONDecodeError: return None @staticmethod - def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpRequest: + def _create_request_from_log_message(json_http_message: dict[str, Any]) -> HttpRequest: url = json_http_message.get("url", {}).get("full", "") request = json_http_message.get("http", {}).get("request", {}) return HttpRequest( @@ -414,14 +407,14 @@ def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpR ) @staticmethod - def _create_response_from_log_message(json_http_message: Dict[str, Any]) -> HttpResponse: + def _create_response_from_log_message(json_http_message: dict[str, Any]) -> HttpResponse: response = json_http_message.get("http", {}).get("response", {}) body = response.get("body", {}).get("content", "") return HttpResponse( status=response.get("status_code"), body=body, headers=response.get("headers") ) - def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool: + def _has_reached_limit(self, slices: list[StreamReadSlices]) -> bool: if len(slices) >= self._max_slices: return True @@ -436,13 +429,13 @@ def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool: return True return False - def _parse_slice_description(self, log_message: str) -> Dict[str, Any]: + def _parse_slice_description(self, log_message: str) -> dict[str, Any]: return json.loads(log_message.replace(SliceLogger.SLICE_LOG_PREFIX, "", 1)) # type: ignore @staticmethod - def _clean_config(config: Dict[str, Any]) -> Dict[str, Any]: + def _clean_config(config: dict[str, Any]) -> dict[str, Any]: cleaned_config = deepcopy(config) - for key in config.keys(): + for key in config: if key.startswith("__"): del cleaned_config[key] return cleaned_config diff --git a/airbyte_cdk/connector_builder/models.py b/airbyte_cdk/connector_builder/models.py index 50eb8eb95..f4bf28cdd 100644 --- a/airbyte_cdk/connector_builder/models.py +++ b/airbyte_cdk/connector_builder/models.py @@ -3,44 +3,44 @@ # from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any @dataclass class HttpResponse: status: int - body: Optional[str] = None - headers: Optional[Dict[str, Any]] = None + body: str | None = None + headers: dict[str, Any] | None = None @dataclass class HttpRequest: url: str - headers: Optional[Dict[str, Any]] + headers: 
dict[str, Any] | None http_method: str - body: Optional[str] = None + body: str | None = None @dataclass class StreamReadPages: - records: List[object] - request: Optional[HttpRequest] = None - response: Optional[HttpResponse] = None + records: list[object] + request: HttpRequest | None = None + response: HttpResponse | None = None @dataclass class StreamReadSlices: - pages: List[StreamReadPages] - slice_descriptor: Optional[Dict[str, Any]] - state: Optional[List[Dict[str, Any]]] = None + pages: list[StreamReadPages] + slice_descriptor: dict[str, Any] | None + state: list[dict[str, Any]] | None = None @dataclass class LogMessage: message: str level: str - internal_message: Optional[str] = None - stacktrace: Optional[str] = None + internal_message: str | None = None + stacktrace: str | None = None @dataclass @@ -52,20 +52,20 @@ class AuxiliaryRequest: @dataclass -class StreamRead(object): - logs: List[LogMessage] - slices: List[StreamReadSlices] +class StreamRead: + logs: list[LogMessage] + slices: list[StreamReadSlices] test_read_limit_reached: bool - auxiliary_requests: List[AuxiliaryRequest] - inferred_schema: Optional[Dict[str, Any]] - inferred_datetime_formats: Optional[Dict[str, str]] - latest_config_update: Optional[Dict[str, Any]] + auxiliary_requests: list[AuxiliaryRequest] + inferred_schema: dict[str, Any] | None + inferred_datetime_formats: dict[str, str] | None + latest_config_update: dict[str, Any] | None @dataclass class StreamReadRequestBody: - manifest: Dict[str, Any] + manifest: dict[str, Any] stream: str - config: Dict[str, Any] - state: Optional[Dict[str, Any]] - record_limit: Optional[int] + config: dict[str, Any] + state: dict[str, Any] | None + record_limit: int | None diff --git a/airbyte_cdk/destinations/__init__.py b/airbyte_cdk/destinations/__init__.py index 3a641025b..6b8bbe8f0 100644 --- a/airbyte_cdk/destinations/__init__.py +++ b/airbyte_cdk/destinations/__init__.py @@ -3,6 +3,7 @@ from .destination import Destination + __all__ = [ "Destination", ] diff --git a/airbyte_cdk/destinations/destination.py b/airbyte_cdk/destinations/destination.py index 547f96684..10113b3a0 100644 --- a/airbyte_cdk/destinations/destination.py +++ b/airbyte_cdk/destinations/destination.py @@ -7,7 +7,8 @@ import logging import sys from abc import ABC, abstractmethod -from typing import Any, Iterable, List, Mapping +from collections.abc import Iterable, Mapping +from typing import Any import orjson @@ -23,11 +24,12 @@ from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit from airbyte_cdk.utils.traced_exception import AirbyteTracedException + logger = logging.getLogger("airbyte") class Destination(Connector, ABC): - VALID_CMDS = {"spec", "check", "write"} + VALID_CMDS = {"spec", "check", "write"} # noqa: RUF012 @abstractmethod def write( @@ -59,7 +61,7 @@ def _run_write( input_stream: io.TextIOWrapper, ) -> Iterable[AirbyteMessage]: catalog = ConfiguredAirbyteCatalogSerializer.load( - orjson.loads(open(configured_catalog_path).read()) + orjson.loads(open(configured_catalog_path).read()) # noqa: PLW1514, PTH123, SIM115 ) input_messages = self._parse_input_stream(input_stream) logger.info("Begin writing to the destination...") @@ -68,7 +70,7 @@ def _run_write( ) logger.info("Writing complete.") - def parse_args(self, args: List[str]) -> argparse.Namespace: + def parse_args(self, args: list[str]) -> argparse.Namespace: """ :param args: commandline arguments :return: @@ -107,18 +109,18 @@ def parse_args(self, args: List[str]) -> argparse.Namespace: 
parsed_args = main_parser.parse_args(args) cmd = parsed_args.command if not cmd: - raise Exception("No command entered. ") - elif cmd not in ["spec", "check", "write"]: + raise Exception("No command entered. ") # noqa: TRY002 + if cmd not in ["spec", "check", "write"]: # This is technically dead code since parse_args() would fail if this was the case # But it's non-obvious enough to warrant placing it here anyways - raise Exception(f"Unknown command entered: {cmd}") + raise Exception(f"Unknown command entered: {cmd}") # noqa: TRY002 return parsed_args def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]: cmd = parsed_args.command if cmd not in self.VALID_CMDS: - raise Exception(f"Unrecognized command: {cmd}") + raise Exception(f"Unrecognized command: {cmd}") # noqa: TRY002 spec = self.spec(logger) if cmd == "spec": @@ -133,7 +135,7 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]: if connection_status and cmd == "check": yield connection_status return - raise traced_exc + raise traced_exc # noqa: TRY201 if cmd == "check": yield self._run_check(config=config) @@ -146,7 +148,7 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]: input_stream=wrapped_stdin, ) - def run(self, args: List[str]) -> None: + def run(self, args: list[str]) -> None: init_uncaught_exception_handler(logger) parsed_args = self.parse_args(args) output_messages = self.run_cmd(parsed_args) diff --git a/airbyte_cdk/destinations/vector_db_based/__init__.py b/airbyte_cdk/destinations/vector_db_based/__init__.py index 86ae207f6..696bd2bf5 100644 --- a/airbyte_cdk/destinations/vector_db_based/__init__.py +++ b/airbyte_cdk/destinations/vector_db_based/__init__.py @@ -16,8 +16,8 @@ from .indexer import Indexer from .writer import Writer + __all__ = [ - "AzureOpenAIEmbedder", "AzureOpenAIEmbeddingConfigModel", "Chunk", "CohereEmbedder", @@ -26,10 +26,8 @@ "Embedder", "FakeEmbedder", "FakeEmbeddingConfigModel", - "FromFieldEmbedder", "FromFieldEmbeddingConfigModel", "Indexer", - "OpenAICompatibleEmbedder", "OpenAICompatibleEmbeddingConfigModel", "OpenAIEmbedder", "OpenAIEmbeddingConfigModel", diff --git a/airbyte_cdk/destinations/vector_db_based/config.py b/airbyte_cdk/destinations/vector_db_based/config.py index c7c40ecaa..e2e412139 100644 --- a/airbyte_cdk/destinations/vector_db_based/config.py +++ b/airbyte_cdk/destinations/vector_db_based/config.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Union import dpath from pydantic.v1 import BaseModel, Field @@ -13,7 +13,7 @@ class SeparatorSplitterConfigModel(BaseModel): mode: Literal["separator"] = Field("separator", const=True) - separators: List[str] = Field( + separators: list[str] = Field( default=['"\\n\\n"', '"\\n"', '" "', '""'], title="Separators", description='List of separator strings to split text fields by. The separator itself needs to be wrapped in double quotes, e.g. to split by the dot character, use ".". 
To split by a newline, use "\\n".', @@ -77,7 +77,7 @@ class Config(OneOfOptionConfig): discriminator = "mode" -TextSplitterConfigModel = Union[ +TextSplitterConfigModel = Union[ # noqa: UP007 SeparatorSplitterConfigModel, MarkdownHeaderSplitterConfigModel, CodeSplitterConfigModel ] @@ -102,14 +102,14 @@ class ProcessingConfigModel(BaseModel): description="Size of overlap between chunks in tokens to store in vector store to better capture relevant context", default=0, ) - text_fields: Optional[List[str]] = Field( + text_fields: list[str] | None = Field( default=[], title="Text fields to embed", description="List of fields in the record that should be used to calculate the embedding. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered text fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `names` fields in all entries of the `users` array.", always_show=True, examples=["text", "user.name", "users.*.name"], ) - metadata_fields: Optional[List[str]] = Field( + metadata_fields: list[str] | None = Field( default=[], title="Fields to store as metadata", description="List of fields in the record that should be stored as metadata. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered metadata fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `names` fields in all entries of the `users` array. When specifying nested paths, all matching values are flattened into an array set to a field named by the path.", @@ -123,14 +123,14 @@ class ProcessingConfigModel(BaseModel): type="object", description="Split text fields into chunks based on the specified method.", ) - field_name_mappings: Optional[List[FieldNameMappingConfigModel]] = Field( + field_name_mappings: list[FieldNameMappingConfigModel] | None = Field( default=[], title="Field name mappings", description="List of fields to rename. Not applicable for nested fields, but can be used to rename fields already flattened via dot notation.", ) class Config: - schema_extra = {"group": "processing"} + schema_extra = {"group": "processing"} # noqa: RUF012 class OpenAIEmbeddingConfigModel(BaseModel): @@ -251,13 +251,13 @@ class VectorDBConfigModel(BaseModel): Processing, embedding and advanced configuration are provided by this base class, while the indexing configuration is provided by the destination connector in the sub class. 
""" - embedding: Union[ - OpenAIEmbeddingConfigModel, - CohereEmbeddingConfigModel, - FakeEmbeddingConfigModel, - AzureOpenAIEmbeddingConfigModel, - OpenAICompatibleEmbeddingConfigModel, - ] = Field( + embedding: ( + OpenAIEmbeddingConfigModel + | CohereEmbeddingConfigModel + | FakeEmbeddingConfigModel + | AzureOpenAIEmbeddingConfigModel + | OpenAICompatibleEmbeddingConfigModel + ) = Field( ..., title="Embedding", description="Embedding configuration", @@ -275,7 +275,7 @@ class VectorDBConfigModel(BaseModel): class Config: title = "Destination Config" - schema_extra = { + schema_extra = { # noqa: RUF012 "groups": [ {"id": "processing", "title": "Processing"}, {"id": "embedding", "title": "Embedding"}, @@ -285,14 +285,14 @@ class Config: } @staticmethod - def remove_discriminator(schema: Dict[str, Any]) -> None: + def remove_discriminator(schema: dict[str, Any]) -> None: """pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" dpath.delete(schema, "properties/**/discriminator") @classmethod - def schema(cls, by_alias: bool = True, ref_template: str = "") -> Dict[str, Any]: + def schema(cls, by_alias: bool = True, ref_template: str = "") -> dict[str, Any]: # noqa: FBT001, FBT002, ARG003 """we're overriding the schema classmethod to enable some post-processing""" - schema: Dict[str, Any] = super().schema() + schema: dict[str, Any] = super().schema() schema = resolve_refs(schema) cls.remove_discriminator(schema) return schema diff --git a/airbyte_cdk/destinations/vector_db_based/document_processor.py b/airbyte_cdk/destinations/vector_db_based/document_processor.py index c007bf9e2..fdd8a2244 100644 --- a/airbyte_cdk/destinations/vector_db_based/document_processor.py +++ b/airbyte_cdk/destinations/vector_db_based/document_processor.py @@ -4,8 +4,9 @@ import json import logging +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional, Tuple +from typing import Any import dpath from langchain.text_splitter import Language, RecursiveCharacterTextSplitter @@ -26,6 +27,7 @@ ) from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType + METADATA_STREAM_FIELD = "_ab_stream" METADATA_RECORD_ID_FIELD = "_ab_record_id" @@ -34,10 +36,10 @@ @dataclass class Chunk: - page_content: Optional[str] - metadata: Dict[str, Any] + page_content: str | None + metadata: dict[str, Any] record: AirbyteRecordMessage - embedding: Optional[List[float]] = None + embedding: list[float] | None = None headers_to_split_on = [ @@ -69,7 +71,7 @@ class DocumentProcessor: streams: Mapping[str, ConfiguredAirbyteStream] @staticmethod - def check_config(config: ProcessingConfigModel) -> Optional[str]: + def check_config(config: ProcessingConfigModel) -> str | None: if config.text_splitter is not None and config.text_splitter.mode == "separator": for s in config.text_splitter.separators: try: @@ -80,11 +82,11 @@ def check_config(config: ProcessingConfigModel) -> Optional[str]: return f"Invalid separator: {s}. Separator needs to be a valid JSON string using double quotes." 
return None - def _get_text_splitter( + def _get_text_splitter( # noqa: RET503 self, chunk_size: int, chunk_overlap: int, - splitter_config: Optional[TextSplitterConfigModel], + splitter_config: TextSplitterConfigModel | None, ) -> RecursiveCharacterTextSplitter: if splitter_config is None: splitter_config = SeparatorSplitterConfigModel(mode="separator") @@ -105,7 +107,7 @@ def _get_text_splitter( keep_separator=True, disallowed_special=(), ) - if splitter_config.mode == "code": + if splitter_config.mode == "code": # noqa: RET503, RUF100 return RecursiveCharacterTextSplitter.from_tiktoken_encoder( chunk_size=chunk_size, chunk_overlap=chunk_overlap, @@ -115,7 +117,7 @@ def _get_text_splitter( disallowed_special=(), ) - def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog): + def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog): # noqa: ANN204 self.streams = { create_stream_identifier(stream.stream): stream for stream in catalog.streams } @@ -128,13 +130,13 @@ def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCata self.field_name_mappings = config.field_name_mappings self.logger = logging.getLogger("airbyte.document_processor") - def process(self, record: AirbyteRecordMessage) -> Tuple[List[Chunk], Optional[str]]: + def process(self, record: AirbyteRecordMessage) -> tuple[list[Chunk], str | None]: """ Generate documents from records. :param records: List of AirbyteRecordMessages :return: Tuple of (List of document chunks, record id to delete if a stream is in dedup mode to avoid stale documents in the vector store) """ - if CDC_DELETED_FIELD in record.data and record.data[CDC_DELETED_FIELD]: + if record.data.get(CDC_DELETED_FIELD): return [], self._extract_primary_key(record) doc = self._generate_document(record) if doc is None: @@ -153,13 +155,13 @@ def process(self, record: AirbyteRecordMessage) -> Tuple[List[Chunk], Optional[s for chunk_document in self._split_document(doc) ] id_to_delete = ( - doc.metadata[METADATA_RECORD_ID_FIELD] + doc.metadata[METADATA_RECORD_ID_FIELD] # noqa: SIM401 if METADATA_RECORD_ID_FIELD in doc.metadata else None ) return chunks, id_to_delete - def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document]: + def _generate_document(self, record: AirbyteRecordMessage) -> Document | None: relevant_fields = self._extract_relevant_fields(record, self.text_fields) if len(relevant_fields) == 0: return None @@ -168,8 +170,8 @@ def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document] return Document(page_content=text, metadata=metadata) def _extract_relevant_fields( - self, record: AirbyteRecordMessage, fields: Optional[List[str]] - ) -> Dict[str, Any]: + self, record: AirbyteRecordMessage, fields: list[str] | None + ) -> dict[str, Any]: relevant_fields = {} if fields and len(fields) > 0: for field in fields: @@ -180,7 +182,7 @@ def _extract_relevant_fields( relevant_fields = record.data return self._remap_field_names(relevant_fields) - def _extract_metadata(self, record: AirbyteRecordMessage) -> Dict[str, Any]: + def _extract_metadata(self, record: AirbyteRecordMessage) -> dict[str, Any]: metadata = self._extract_relevant_fields(record, self.metadata_fields) metadata[METADATA_STREAM_FIELD] = create_stream_identifier(record) primary_key = self._extract_primary_key(record) @@ -188,7 +190,7 @@ def _extract_metadata(self, record: AirbyteRecordMessage) -> Dict[str, Any]: metadata[METADATA_RECORD_ID_FIELD] = primary_key return metadata - def 
_extract_primary_key(self, record: AirbyteRecordMessage) -> Optional[str]: + def _extract_primary_key(self, record: AirbyteRecordMessage) -> str | None: stream_identifier = create_stream_identifier(record) current_stream: ConfiguredAirbyteStream = self.streams[stream_identifier] # if the sync mode is deduping, use the primary key to upsert existing records instead of appending new ones @@ -207,11 +209,11 @@ def _extract_primary_key(self, record: AirbyteRecordMessage) -> Optional[str]: stringified_primary_key = "_".join(primary_key) return f"{stream_identifier}_{stringified_primary_key}" - def _split_document(self, doc: Document) -> List[Document]: - chunks: List[Document] = self.splitter.split_documents([doc]) + def _split_document(self, doc: Document) -> list[Document]: + chunks: list[Document] = self.splitter.split_documents([doc]) return chunks - def _remap_field_names(self, fields: Dict[str, Any]) -> Dict[str, Any]: + def _remap_field_names(self, fields: dict[str, Any]) -> dict[str, Any]: if not self.field_name_mappings: return fields diff --git a/airbyte_cdk/destinations/vector_db_based/embedder.py b/airbyte_cdk/destinations/vector_db_based/embedder.py index 6889c8e16..5d491d7d5 100644 --- a/airbyte_cdk/destinations/vector_db_based/embedder.py +++ b/airbyte_cdk/destinations/vector_db_based/embedder.py @@ -5,7 +5,7 @@ import os from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import cast from langchain.embeddings.cohere import CohereEmbeddings from langchain.embeddings.fake import FakeEmbeddings @@ -41,15 +41,15 @@ class Embedder(ABC): The CDK defines basic embedders that should be supported in each destination. It is possible to implement custom embedders for special destinations if needed. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: B027 pass @abstractmethod - def check(self) -> Optional[str]: + def check(self) -> str | None: pass @abstractmethod - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: """ Embed the text of each chunk and return the resulting embedding vectors. If a chunk cannot be embedded or is configured to not be embedded, return None for that chunk. @@ -68,19 +68,19 @@ def embedding_dimensions(self) -> int: class BaseOpenAIEmbedder(Embedder): - def __init__(self, embeddings: OpenAIEmbeddings, chunk_size: int): + def __init__(self, embeddings: OpenAIEmbeddings, chunk_size: int): # noqa: ANN204 super().__init__() self.embeddings = embeddings self.chunk_size = chunk_size - def check(self) -> Optional[str]: + def check(self) -> str | None: try: self.embeddings.embed_query("test") except Exception as e: return format_exception(e) return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: """ Embed the text of each chunk and return the resulting embedding vectors. 
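The `embed_documents` override below batches requests against a token budget: each chunk may hold at most `chunk_size` tokens, so the number of chunks that fit into one request is the provider's token limit divided by `chunk_size`, and `create_chunks` (defined later in this patch, in `vector_db_based/utils.py`) slices the documents accordingly. A small, self-contained sketch of that arithmetic with hypothetical numbers:

    from collections.abc import Iterable, Iterator
    from itertools import islice
    from typing import Any

    def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[tuple[Any, ...]]:
        # behaves like the helper in vector_db_based/utils.py
        it = iter(iterable)
        while chunk := tuple(islice(it, batch_size)):
            yield chunk

    TOKEN_LIMIT = 150_000   # hypothetical per-request budget
    CHUNK_SIZE = 8_191      # hypothetical maximum tokens per chunk
    batch_size = TOKEN_LIMIT // CHUNK_SIZE  # -> 18 chunks per embedding request
    documents = [f"doc-{i}" for i in range(40)]
    print([len(batch) for batch in create_chunks(documents, batch_size)])  # -> [18, 18, 4]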
@@ -91,7 +91,7 @@ def embed_documents(self, documents: List[Document]) -> List[Optional[List[float # Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of documents that can be embedded at once without exhausting the limit in a single request embedding_batch_size = OPEN_AI_TOKEN_LIMIT // self.chunk_size batches = create_chunks(documents, batch_size=embedding_batch_size) - embeddings: List[Optional[List[float]]] = [] + embeddings: list[list[float] | None] = [] for batch in batches: embeddings.extend( self.embeddings.embed_documents([chunk.page_content for chunk in batch]) @@ -105,7 +105,7 @@ def embedding_dimensions(self) -> int: class OpenAIEmbedder(BaseOpenAIEmbedder): - def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int): + def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int): # noqa: ANN204 super().__init__( OpenAIEmbeddings( # type: ignore [call-arg] openai_api_key=config.openai_key, max_retries=15, disallowed_special=() @@ -115,7 +115,7 @@ def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int): class AzureOpenAIEmbedder(BaseOpenAIEmbedder): - def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int): + def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int): # noqa: ANN204 # Azure OpenAI API has — as of 20230927 — a limit of 16 documents per request super().__init__( OpenAIEmbeddings( # type: ignore [call-arg] @@ -136,23 +136,23 @@ def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int): class CohereEmbedder(Embedder): - def __init__(self, config: CohereEmbeddingConfigModel): + def __init__(self, config: CohereEmbeddingConfigModel): # noqa: ANN204 super().__init__() # Client is set internally self.embeddings = CohereEmbeddings( cohere_api_key=config.cohere_key, model="embed-english-light-v2.0" ) # type: ignore - def check(self) -> Optional[str]: + def check(self) -> str | None: try: self.embeddings.embed_query("test") except Exception as e: return format_exception(e) return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: return cast( - List[Optional[List[float]]], + list[list[float] | None], # noqa: TC006 self.embeddings.embed_documents([document.page_content for document in documents]), ) @@ -163,20 +163,20 @@ def embedding_dimensions(self) -> int: class FakeEmbedder(Embedder): - def __init__(self, config: FakeEmbeddingConfigModel): + def __init__(self, config: FakeEmbeddingConfigModel): # noqa: ANN204, ARG002 super().__init__() self.embeddings = FakeEmbeddings(size=OPEN_AI_VECTOR_SIZE) - def check(self) -> Optional[str]: + def check(self) -> str | None: try: self.embeddings.embed_query("test") except Exception as e: return format_exception(e) return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: return cast( - List[Optional[List[float]]], + list[list[float] | None], # noqa: TC006 self.embeddings.embed_documents([document.page_content for document in documents]), ) @@ -190,7 +190,7 @@ def embedding_dimensions(self) -> int: class OpenAICompatibleEmbedder(Embedder): - def __init__(self, config: OpenAICompatibleEmbeddingConfigModel): + def __init__(self, config: OpenAICompatibleEmbeddingConfigModel): # noqa: ANN204 super().__init__() self.config = config # 
Client is set internally @@ -203,7 +203,7 @@ def __init__(self, config: OpenAICompatibleEmbeddingConfigModel): disallowed_special=(), ) # type: ignore - def check(self) -> Optional[str]: + def check(self) -> str | None: deployment_mode = os.environ.get("DEPLOYMENT_MODE", "") if ( deployment_mode.casefold() == CLOUD_DEPLOYMENT_MODE @@ -217,9 +217,9 @@ def check(self) -> Optional[str]: return format_exception(e) return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: return cast( - List[Optional[List[float]]], + list[list[float] | None], # noqa: TC006 self.embeddings.embed_documents([document.page_content for document in documents]), ) @@ -230,19 +230,19 @@ def embedding_dimensions(self) -> int: class FromFieldEmbedder(Embedder): - def __init__(self, config: FromFieldEmbeddingConfigModel): + def __init__(self, config: FromFieldEmbeddingConfigModel): # noqa: ANN204 super().__init__() self.config = config - def check(self) -> Optional[str]: + def check(self) -> str | None: return None - def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: list[Document]) -> list[list[float] | None]: """ From each chunk, pull the embedding from the field specified in the config. Check that the field exists, is a list of numbers and is the correct size. If not, raise an AirbyteTracedException explaining the problem. """ - embeddings: List[Optional[List[float]]] = [] + embeddings: list[list[float] | None] = [] for document in documents: data = document.record.data if self.config.field_name not in data: @@ -252,7 +252,7 @@ def embed_documents(self, documents: List[Document]) -> List[Optional[List[float message=f"Record {str(data)[:250]}... in stream {document.record.stream} does not contain embedding vector field {self.config.field_name}. 
Please check your embedding configuration, the embedding vector field has to be set correctly on every record.", ) field = data[self.config.field_name] - if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field): + if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field): # noqa: UP038 raise AirbyteTracedException( internal_message="Embedding vector field not a list of numbers", failure_type=FailureType.config_error, @@ -284,20 +284,17 @@ def embedding_dimensions(self) -> int: def create_from_config( - embedding_config: Union[ - AzureOpenAIEmbeddingConfigModel, - CohereEmbeddingConfigModel, - FakeEmbeddingConfigModel, - FromFieldEmbeddingConfigModel, - OpenAIEmbeddingConfigModel, - OpenAICompatibleEmbeddingConfigModel, - ], + embedding_config: AzureOpenAIEmbeddingConfigModel + | CohereEmbeddingConfigModel + | FakeEmbeddingConfigModel + | FromFieldEmbeddingConfigModel + | OpenAIEmbeddingConfigModel + | OpenAICompatibleEmbeddingConfigModel, processing_config: ProcessingConfigModel, ) -> Embedder: - if embedding_config.mode == "azure_openai" or embedding_config.mode == "openai": + if embedding_config.mode == "azure_openai" or embedding_config.mode == "openai": # noqa: PLR1714 return cast( - Embedder, + Embedder, # noqa: TC006 embedder_map[embedding_config.mode](embedding_config, processing_config.chunk_size), ) - else: - return cast(Embedder, embedder_map[embedding_config.mode](embedding_config)) + return cast(Embedder, embedder_map[embedding_config.mode](embedding_config)) # noqa: TC006 diff --git a/airbyte_cdk/destinations/vector_db_based/indexer.py b/airbyte_cdk/destinations/vector_db_based/indexer.py index c49f576a6..74804b75c 100644 --- a/airbyte_cdk/destinations/vector_db_based/indexer.py +++ b/airbyte_cdk/destinations/vector_db_based/indexer.py @@ -4,7 +4,8 @@ import itertools from abc import ABC, abstractmethod -from typing import Any, Generator, Iterable, List, Optional, Tuple, TypeVar +from collections.abc import Generator, Iterable +from typing import Any, TypeVar from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog @@ -18,11 +19,11 @@ class Indexer(ABC): In a destination connector, implement a custom indexer by extending this class and implementing the abstract methods. """ - def __init__(self, config: Any): + def __init__(self, config: Any): # noqa: ANN204, ANN401 self.config = config pass - def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None: + def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None: # noqa: B027 """ Run before the sync starts. This method should be used to make sure all records in the destination that belong to streams with a destination mode of overwrite are deleted. @@ -31,14 +32,14 @@ def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None: """ pass - def post_sync(self) -> List[AirbyteMessage]: + def post_sync(self) -> list[AirbyteMessage]: """ Run after the sync finishes. This method should be used to perform any cleanup operations and can return a list of AirbyteMessages to be logged. """ return [] @abstractmethod - def index(self, document_chunks: List[Chunk], namespace: str, stream: str) -> None: + def index(self, document_chunks: list[Chunk], namespace: str, stream: str) -> None: """ Index a list of document chunks. 
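For orientation, the `Indexer` base class in this file defines the contract a destination must satisfy: `index` (above), plus `delete` and `check` (in the next hunk) are abstract, while `pre_sync`/`post_sync` are optional hooks. A minimal in-memory implementation, assuming only the interface shown in these hunks (the dict-backed storage and the `_ab_record_id` metadata key are used purely for illustration):

    from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk
    from airbyte_cdk.destinations.vector_db_based.indexer import Indexer

    class InMemoryIndexer(Indexer):
        def __init__(self, config: dict) -> None:
            super().__init__(config)
            self.store: dict[str, list[Chunk]] = {}

        def index(self, document_chunks: list[Chunk], namespace: str, stream: str) -> None:
            self.store.setdefault(f"{namespace}_{stream}", []).extend(document_chunks)

        def delete(self, delete_ids: list[str], namespace: str, stream: str) -> None:
            key = f"{namespace}_{stream}"
            self.store[key] = [
                chunk
                for chunk in self.store.get(key, [])
                if chunk.metadata.get("_ab_record_id") not in delete_ids
            ]

        def check(self) -> str | None:
            return None  # nothing to validate in this toy example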
@@ -48,7 +49,7 @@ def index(self, document_chunks: List[Chunk], namespace: str, stream: str) -> No pass @abstractmethod - def delete(self, delete_ids: List[str], namespace: str, stream: str) -> None: + def delete(self, delete_ids: list[str], namespace: str, stream: str) -> None: """ Delete document chunks belonging to certain record ids. @@ -59,7 +60,7 @@ def delete(self, delete_ids: List[str], namespace: str, stream: str) -> None: pass @abstractmethod - def check(self) -> Optional[str]: + def check(self) -> str | None: """ Check if the indexer is configured correctly. This method should be used to check if the indexer is configured correctly and return an error message if it is not. """ @@ -69,7 +70,7 @@ def check(self) -> Optional[str]: T = TypeVar("T") -def chunks(iterable: Iterable[T], batch_size: int) -> Generator[Tuple[T, ...], None, None]: +def chunks(iterable: Iterable[T], batch_size: int) -> Generator[tuple[T, ...], None, None]: """A helper function to break an iterable into chunks of size batch_size.""" it = iter(iterable) chunk = tuple(itertools.islice(it, batch_size)) diff --git a/airbyte_cdk/destinations/vector_db_based/test_utils.py b/airbyte_cdk/destinations/vector_db_based/test_utils.py index a2f3d3d83..8f29b386a 100644 --- a/airbyte_cdk/destinations/vector_db_based/test_utils.py +++ b/airbyte_cdk/destinations/vector_db_based/test_utils.py @@ -4,7 +4,7 @@ import json import unittest -from typing import Any, Dict +from typing import Any from airbyte_cdk.models import ( AirbyteMessage, @@ -47,7 +47,7 @@ def _get_configured_catalog( return ConfiguredAirbyteCatalog(streams=[overwrite_stream]) - def _state(self, data: Dict[str, Any]) -> AirbyteMessage: + def _state(self, data: dict[str, Any]) -> AirbyteMessage: return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=data)) def _record(self, stream: str, str_value: str, int_value: int) -> AirbyteMessage: @@ -59,5 +59,5 @@ def _record(self, stream: str, str_value: str, int_value: int) -> AirbyteMessage ) def setUp(self) -> None: - with open("secrets/config.json", "r") as f: + with open("secrets/config.json") as f: # noqa: FURB101, PLW1514, PTH123 self.config = json.loads(f.read()) diff --git a/airbyte_cdk/destinations/vector_db_based/utils.py b/airbyte_cdk/destinations/vector_db_based/utils.py index dbb1f4714..0c77dc2df 100644 --- a/airbyte_cdk/destinations/vector_db_based/utils.py +++ b/airbyte_cdk/destinations/vector_db_based/utils.py @@ -4,7 +4,8 @@ import itertools import traceback -from typing import Any, Iterable, Iterator, Tuple, Union +from collections.abc import Iterable, Iterator +from typing import Any from airbyte_cdk.models import AirbyteRecordMessage, AirbyteStream @@ -17,7 +18,7 @@ def format_exception(exception: Exception) -> str: ) -def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[Tuple[Any, ...]]: +def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[tuple[Any, ...]]: """A helper function to break an iterable into chunks of size batch_size.""" it = iter(iterable) chunk = tuple(itertools.islice(it, batch_size)) @@ -26,10 +27,7 @@ def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[Tuple[An chunk = tuple(itertools.islice(it, batch_size)) -def create_stream_identifier(stream: Union[AirbyteStream, AirbyteRecordMessage]) -> str: +def create_stream_identifier(stream: AirbyteStream | AirbyteRecordMessage) -> str: if isinstance(stream, AirbyteStream): return str(stream.name if stream.namespace is None else f"{stream.namespace}_{stream.name}") - 
else: - return str( - stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}" - ) + return str(stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}") diff --git a/airbyte_cdk/destinations/vector_db_based/writer.py b/airbyte_cdk/destinations/vector_db_based/writer.py index 45c7c7326..14f13cbdd 100644 --- a/airbyte_cdk/destinations/vector_db_based/writer.py +++ b/airbyte_cdk/destinations/vector_db_based/writer.py @@ -4,7 +4,7 @@ from collections import defaultdict -from typing import Dict, Iterable, List, Tuple +from collections.abc import Iterable from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk, DocumentProcessor @@ -32,7 +32,7 @@ def __init__( indexer: Indexer, embedder: Embedder, batch_size: int, - omit_raw_text: bool, + omit_raw_text: bool, # noqa: FBT001 ) -> None: self.processing_config = processing_config self.indexer = indexer @@ -42,8 +42,8 @@ def __init__( self._init_batch() def _init_batch(self) -> None: - self.chunks: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list) - self.ids_to_delete: Dict[Tuple[str, str], List[str]] = defaultdict(list) + self.chunks: dict[tuple[str, str], list[Chunk]] = defaultdict(list) + self.ids_to_delete: dict[tuple[str, str], list[str]] = defaultdict(list) self.number_of_chunks = 0 def _convert_to_document(self, chunk: Chunk) -> Document: @@ -59,9 +59,9 @@ def _process_batch(self) -> None: self.indexer.delete(ids, namespace, stream) for (namespace, stream), chunks in self.chunks.items(): - embeddings = self.embedder.embed_documents( - [self._convert_to_document(chunk) for chunk in chunks] - ) + embeddings = self.embedder.embed_documents([ + self._convert_to_document(chunk) for chunk in chunks + ]) for i, document in enumerate(chunks): document.embedding = embeddings[i] if self.omit_raw_text: @@ -84,14 +84,14 @@ def write( elif message.type == Type.RECORD: record_chunks, record_id_to_delete = self.processor.process(message.record) self.chunks[ - ( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]" + ( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]" # noqa: RUF031 message.record.namespace, # type: ignore [union-attr] # record not None message.record.stream, # type: ignore [union-attr] # record not None ) ].extend(record_chunks) if record_id_to_delete is not None: self.ids_to_delete[ - ( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]" + ( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]" # noqa: RUF031 message.record.namespace, # type: ignore [union-attr] # record not None message.record.stream, # type: ignore [union-attr] # record not None ) diff --git a/airbyte_cdk/entrypoint.py b/airbyte_cdk/entrypoint.py index a5052a575..092025f7d 100644 --- a/airbyte_cdk/entrypoint.py +++ b/airbyte_cdk/entrypoint.py @@ -12,8 +12,9 @@ import sys import tempfile from collections import defaultdict +from collections.abc import Iterable, Mapping from functools import wraps -from typing import Any, DefaultDict, Iterable, List, Mapping, Optional +from typing import Any from urllib.parse import urlparse import orjson @@ -43,6 +44,7 @@ from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH from airbyte_cdk.utils.traced_exception import AirbyteTracedException + logger = init_logger("airbyte") VALID_URL_SCHEMES = 
["https"] @@ -50,8 +52,8 @@ _HAS_LOGGED_FOR_SERIALIZATION_ERROR = False -class AirbyteEntrypoint(object): - def __init__(self, source: Source): +class AirbyteEntrypoint: + def __init__(self, source: Source): # noqa: ANN204 init_uncaught_exception_handler(logger) # Deployment mode is read when instantiating the entrypoint because it is the common path shared by syncs and connector builder test requests @@ -62,7 +64,7 @@ def __init__(self, source: Source): self.logger = logging.getLogger(f"airbyte.{getattr(source, 'name', '')}") @staticmethod - def parse_args(args: List[str]) -> argparse.Namespace: + def parse_args(args: list[str]) -> argparse.Namespace: # set up parent parsers parent_parser = argparse.ArgumentParser(add_help=False) parent_parser.add_argument( @@ -120,7 +122,7 @@ def parse_args(args: List[str]) -> argparse.Namespace: def run(self, parsed_args: argparse.Namespace) -> Iterable[str]: cmd = parsed_args.command if not cmd: - raise Exception("No command passed") + raise Exception("No command passed") # noqa: TRY002 if hasattr(parsed_args, "debug") and parsed_args.debug: self.logger.setLevel(logging.DEBUG) @@ -173,7 +175,7 @@ def run(self, parsed_args: argparse.Namespace) -> Iterable[str]: self.read(source_spec, config, config_catalog, state), ) else: - raise Exception("Unexpected command " + cmd) + raise Exception("Unexpected command " + cmd) # noqa: TRY002 finally: yield from [ self.airbyte_message_to_string(queued_message) @@ -191,7 +193,7 @@ def check( # The platform uses the exit code to surface unexpected failures so we raise the exception if the failure type not a config error # If the failure is not exceptional, we'll emit a failed connection status message and return if traced_exc.failure_type != FailureType.config_error: - raise traced_exc + raise traced_exc # noqa: TRY201 if connection_status: yield from self._emit_queued_messages(self.source) yield connection_status @@ -204,7 +206,7 @@ def check( # The platform uses the exit code to surface unexpected failures so we raise the exception if the failure type not a config error # If the failure is not exceptional, we'll emit a failed connection status message and return if traced_exc.failure_type != FailureType.config_error: - raise traced_exc + raise traced_exc # noqa: TRY201 else: yield AirbyteMessage( type=Type.CONNECTION_STATUS, @@ -233,14 +235,18 @@ def discover( yield AirbyteMessage(type=Type.CATALOG, catalog=catalog) def read( - self, source_spec: ConnectorSpecification, config: TConfig, catalog: Any, state: list[Any] + self, + source_spec: ConnectorSpecification, + config: TConfig, + catalog: Any, # noqa: ANN401 + state: list[Any], ) -> Iterable[AirbyteMessage]: self.set_up_secret_filter(config, source_spec.connectionSpecification) if self.source.check_config_against_spec: self.validate_connection(source_spec, config) # The Airbyte protocol dictates that counts be expressed as float/double to better protect against integer overflows - stream_message_counter: DefaultDict[HashableStreamDescriptor, float] = defaultdict(float) + stream_message_counter: defaultdict[HashableStreamDescriptor, float] = defaultdict(float) for message in self.source.read(self.logger, config, catalog, state): yield self.handle_record_counts(message, stream_message_counter) for message in self._emit_queued_messages(self.source): @@ -248,7 +254,7 @@ def read( @staticmethod def handle_record_counts( - message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, float] + message: AirbyteMessage, stream_message_count: 
defaultdict[HashableStreamDescriptor, float] ) -> AirbyteMessage: match message.type: case Type.RECORD: @@ -306,21 +312,21 @@ def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str: return json.dumps(serialized_message) @classmethod - def extract_state(cls, args: List[str]) -> Optional[Any]: + def extract_state(cls, args: list[str]) -> Any | None: # noqa: ANN401 parsed_args = cls.parse_args(args) if hasattr(parsed_args, "state"): return parsed_args.state return None @classmethod - def extract_catalog(cls, args: List[str]) -> Optional[Any]: + def extract_catalog(cls, args: list[str]) -> Any | None: # noqa: ANN401 parsed_args = cls.parse_args(args) if hasattr(parsed_args, "catalog"): return parsed_args.catalog return None @classmethod - def extract_config(cls, args: List[str]) -> Optional[Any]: + def extract_config(cls, args: list[str]) -> Any | None: # noqa: ANN401 parsed_args = cls.parse_args(args) if hasattr(parsed_args, "config"): return parsed_args.config @@ -332,7 +338,7 @@ def _emit_queued_messages(self, source: Source) -> Iterable[AirbyteMessage]: return -def launch(source: Source, args: List[str]) -> None: +def launch(source: Source, args: list[str]) -> None: source_entrypoint = AirbyteEntrypoint(source) parsed_args = source_entrypoint.parse_args(args) # temporarily removes the PrintBuffer because we're seeing weird print behavior for concurrent syncs @@ -351,12 +357,12 @@ def _init_internal_request_filter() -> None: wrapped_fn = Session.send @wraps(wrapped_fn) - def filtered_send(self: Any, request: PreparedRequest, **kwargs: Any) -> Response: + def filtered_send(self: Any, request: PreparedRequest, **kwargs: Any) -> Response: # noqa: ANN401 parsed_url = urlparse(request.url) if parsed_url.scheme not in VALID_URL_SCHEMES: raise requests.exceptions.InvalidSchema( - "Invalid Protocol Scheme: The endpoint that data is being requested from is using an invalid or insecure " + "Invalid Protocol Scheme: The endpoint that data is being requested from is using an invalid or insecure " # noqa: ISC003 + f"protocol {parsed_url.scheme!r}. Valid protocol schemes: {','.join(VALID_URL_SCHEMES)}" ) @@ -377,7 +383,7 @@ def filtered_send(self: Any, request: PreparedRequest, **kwargs: Any) -> Respons # This is a special case where the developer specifies an IP address string that is not formatted correctly like trailing # whitespace which will fail the socket IP lookup. This only happens when using IP addresses and not text hostnames. 
# Knowing that this is a request using the requests library, we will mock the exception without calling the lib - raise requests.exceptions.InvalidURL(f"Invalid URL {parsed_url}: {exception}") + raise requests.exceptions.InvalidURL(f"Invalid URL {parsed_url}: {exception}") # noqa: B904 return wrapped_fn(self, request, **kwargs) @@ -409,6 +415,6 @@ def main() -> None: source = impl() if not isinstance(source, Source): - raise Exception("Source implementation provided does not implement Source class!") + raise Exception("Source implementation provided does not implement Source class!") # noqa: TRY002, TRY004 launch(source, sys.argv[1:]) diff --git a/airbyte_cdk/exception_handler.py b/airbyte_cdk/exception_handler.py index 84aa39ba1..78170375b 100644 --- a/airbyte_cdk/exception_handler.py +++ b/airbyte_cdk/exception_handler.py @@ -4,8 +4,9 @@ import logging import sys +from collections.abc import Mapping from types import TracebackType -from typing import Any, List, Mapping, Optional +from typing import Any from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException @@ -28,8 +29,8 @@ def init_uncaught_exception_handler(logger: logging.Logger) -> None: def hook_fn( exception_type: type[BaseException], exception_value: BaseException, - traceback_: Optional[TracebackType], - ) -> Any: + traceback_: TracebackType | None, + ) -> Any: # noqa: ANN401 # For developer ergonomics, we want to see the stack trace in the logs when we do a ctrl-c if issubclass(exception_type, KeyboardInterrupt): sys.__excepthook__(exception_type, exception_value, traceback_) @@ -45,12 +46,10 @@ def hook_fn( sys.excepthook = hook_fn -def generate_failed_streams_error_message(stream_failures: Mapping[str, List[Exception]]) -> str: - failures = "\n".join( - [ - f"{stream}: {filter_secrets(exception.__repr__())}" - for stream, exceptions in stream_failures.items() - for exception in exceptions - ] - ) +def generate_failed_streams_error_message(stream_failures: Mapping[str, list[Exception]]) -> str: + failures = "\n".join([ + f"{stream}: {filter_secrets(exception.__repr__())}" # noqa: PLC2801 + for stream, exceptions in stream_failures.items() + for exception in exceptions + ]) return f"During the sync, the following streams did not sync successfully: {failures}" diff --git a/airbyte_cdk/logger.py b/airbyte_cdk/logger.py index 78061b605..75f29c7ae 100644 --- a/airbyte_cdk/logger.py +++ b/airbyte_cdk/logger.py @@ -5,7 +5,8 @@ import json import logging import logging.config -from typing import Any, Callable, Mapping, Optional, Tuple +from collections.abc import Callable, Mapping +from typing import Any import orjson @@ -18,6 +19,7 @@ ) from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets + LOGGING_CONFIG = { "version": 1, "disable_existing_loggers": False, @@ -37,7 +39,7 @@ } -def init_logger(name: Optional[str] = None) -> logging.Logger: +def init_logger(name: str | None = None) -> logging.Logger: """Initial set up of logger""" logger = logging.getLogger(name) logger.setLevel(logging.INFO) @@ -57,7 +59,7 @@ class AirbyteLogFormatter(logging.Formatter): """Output log records using AirbyteMessage""" # Transforming Python log levels to Airbyte protocol log levels - level_mapping = { + level_mapping = { # noqa: RUF012 logging.FATAL: Level.FATAL, logging.ERROR: Level.ERROR, logging.WARNING: Level.WARN, @@ -72,13 +74,12 @@ def format(self, record: logging.LogRecord) -> str: extras = self.extract_extra_args_from_record(record) debug_dict = 
{"type": "DEBUG", "message": record.getMessage(), "data": extras} return filter_secrets(json.dumps(debug_dict)) - else: - message = super().format(record) - message = filter_secrets(message) - log_message = AirbyteMessage( - type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message) - ) - return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode() + message = super().format(record) + message = filter_secrets(message) + log_message = AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message) + ) + return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode() @staticmethod def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, Any]: @@ -91,7 +92,7 @@ def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, An return {k: str(getattr(record, k)) for k in extra_keys if hasattr(record, k)} -def log_by_prefix(msg: str, default_level: str) -> Tuple[int, str]: +def log_by_prefix(msg: str, default_level: str) -> tuple[int, str]: """Custom method, which takes log level from first word of message""" valid_log_types = ["FATAL", "ERROR", "WARN", "INFO", "DEBUG", "TRACE"] split_line = msg.split() diff --git a/airbyte_cdk/models/__init__.py b/airbyte_cdk/models/__init__.py index 3fa24be49..1d4fc1f38 100644 --- a/airbyte_cdk/models/__init__.py +++ b/airbyte_cdk/models/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +# # The earlier versions of airbyte-cdk (0.28.0<=) had the airbyte_protocol python classes # declared inline in the airbyte-cdk code. However, somewhere around Feb 2023 the # Airbyte Protocol moved to its own repo/PyPi package, called airbyte-protocol-models. @@ -6,66 +8,66 @@ # to make the airbyte_protocol python classes available to the airbyte-cdk consumer as part # of airbyte-cdk rather than a standalone package. 
from .airbyte_protocol import ( - AdvancedAuth, - AirbyteAnalyticsTraceMessage, - AirbyteCatalog, - AirbyteConnectionStatus, - AirbyteControlConnectorConfigMessage, - AirbyteControlMessage, - AirbyteErrorTraceMessage, - AirbyteEstimateTraceMessage, - AirbyteGlobalState, - AirbyteLogMessage, - AirbyteMessage, - AirbyteProtocol, - AirbyteRecordMessage, - AirbyteStateBlob, - AirbyteStateMessage, - AirbyteStateStats, - AirbyteStateType, - AirbyteStream, - AirbyteStreamState, - AirbyteStreamStatus, - AirbyteStreamStatusReason, - AirbyteStreamStatusReasonType, - AirbyteStreamStatusTraceMessage, - AirbyteTraceMessage, - AuthFlowType, - ConfiguredAirbyteCatalog, - ConfiguredAirbyteStream, - ConnectorSpecification, - DestinationSyncMode, - EstimateType, - FailureType, - Level, - OAuthConfigSpecification, - OauthConnectorInputSpecification, - OrchestratorType, - State, - Status, - StreamDescriptor, - SyncMode, - TraceType, - Type, + AdvancedAuth, # noqa: F401 + AirbyteAnalyticsTraceMessage, # noqa: F401 + AirbyteCatalog, # noqa: F401 + AirbyteConnectionStatus, # noqa: F401 + AirbyteControlConnectorConfigMessage, # noqa: F401 + AirbyteControlMessage, # noqa: F401 + AirbyteErrorTraceMessage, # noqa: F401 + AirbyteEstimateTraceMessage, # noqa: F401 + AirbyteGlobalState, # noqa: F401 + AirbyteLogMessage, # noqa: F401 + AirbyteMessage, # noqa: F401 + AirbyteProtocol, # noqa: F401 + AirbyteRecordMessage, # noqa: F401 + AirbyteStateBlob, # noqa: F401 + AirbyteStateMessage, # noqa: F401 + AirbyteStateStats, # noqa: F401 + AirbyteStateType, # noqa: F401 + AirbyteStream, # noqa: F401 + AirbyteStreamState, # noqa: F401 + AirbyteStreamStatus, # noqa: F401 + AirbyteStreamStatusReason, # noqa: F401 + AirbyteStreamStatusReasonType, # noqa: F401 + AirbyteStreamStatusTraceMessage, # noqa: F401 + AirbyteTraceMessage, # noqa: F401 + AuthFlowType, # noqa: F401 + ConfiguredAirbyteCatalog, # noqa: F401 + ConfiguredAirbyteStream, # noqa: F401 + ConnectorSpecification, # noqa: F401 + DestinationSyncMode, # noqa: F401 + EstimateType, # noqa: F401 + FailureType, # noqa: F401 + Level, # noqa: F401 + OAuthConfigSpecification, # noqa: F401 + OauthConnectorInputSpecification, # noqa: F401 + OrchestratorType, # noqa: F401 + State, # noqa: F401 + Status, # noqa: F401 + StreamDescriptor, # noqa: F401 + SyncMode, # noqa: F401 + TraceType, # noqa: F401 + Type, # noqa: F401 ) from .airbyte_protocol_serializers import ( - AirbyteMessageSerializer, - AirbyteStateMessageSerializer, - AirbyteStreamStateSerializer, - ConfiguredAirbyteCatalogSerializer, - ConfiguredAirbyteStreamSerializer, - ConnectorSpecificationSerializer, + AirbyteMessageSerializer, # noqa: F401 + AirbyteStateMessageSerializer, # noqa: F401 + AirbyteStreamStateSerializer, # noqa: F401 + ConfiguredAirbyteCatalogSerializer, # noqa: F401 + ConfiguredAirbyteStreamSerializer, # noqa: F401 + ConnectorSpecificationSerializer, # noqa: F401 ) from .well_known_types import ( - BinaryData, - Boolean, - Date, - Integer, - Model, - Number, - String, - TimestampWithoutTimezone, - TimestampWithTimezone, - TimeWithoutTimezone, - TimeWithTimezone, + BinaryData, # noqa: F401 + Boolean, # noqa: F401 + Date, # noqa: F401 + Integer, # noqa: F401 + Model, # noqa: F401 + Number, # noqa: F401 + String, # noqa: F401 + TimestampWithoutTimezone, # noqa: F401 + TimestampWithTimezone, # noqa: F401 + TimeWithoutTimezone, # noqa: F401 + TimeWithTimezone, # noqa: F401 ) diff --git a/airbyte_cdk/models/airbyte_protocol.py b/airbyte_cdk/models/airbyte_protocol.py index 2528f7d7e..592cd138b 100644 --- 
a/airbyte_cdk/models/airbyte_protocol.py +++ b/airbyte_cdk/models/airbyte_protocol.py @@ -2,19 +2,22 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Annotated, Any, Dict, List, Mapping, Optional, Union +from typing import Annotated, Any -from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*' from serpyco_rs.metadata import Alias +from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*' + from airbyte_cdk.models.file_transfer_record_message import AirbyteFileTransferRecordMessage + # ruff: noqa: F405 # ignore fuzzy import issues with 'import *' @dataclass -class AirbyteStateBlob: +class AirbyteStateBlob: # noqa: PLW1641 """ A dataclass that dynamically sets attributes based on provided keyword arguments and positional arguments. Used to "mimic" pydantic Basemodel with ConfigDict(extra='allow') option. @@ -37,7 +40,7 @@ class AirbyteStateBlob: kwargs: InitVar[Mapping[str, Any]] - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 # Set any attribute passed in through kwargs for arg in args: self.__dict__.update(arg) @@ -56,35 +59,35 @@ def __eq__(self, other: object) -> bool: @dataclass class AirbyteStreamState: stream_descriptor: StreamDescriptor # type: ignore [name-defined] - stream_state: Optional[AirbyteStateBlob] = None + stream_state: AirbyteStateBlob | None = None @dataclass class AirbyteGlobalState: - stream_states: List[AirbyteStreamState] - shared_state: Optional[AirbyteStateBlob] = None + stream_states: list[AirbyteStreamState] + shared_state: AirbyteStateBlob | None = None @dataclass class AirbyteStateMessage: - type: Optional[AirbyteStateType] = None # type: ignore [name-defined] - stream: Optional[AirbyteStreamState] = None + type: AirbyteStateType | None = None # type: ignore [name-defined] + stream: AirbyteStreamState | None = None global_: Annotated[AirbyteGlobalState | None, Alias("global")] = ( None # "global" is a reserved keyword in python ⇒ Alias is used for (de-)serialization ) - data: Optional[Dict[str, Any]] = None - sourceStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined] - destinationStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined] + data: dict[str, Any] | None = None + sourceStats: AirbyteStateStats | None = None # type: ignore [name-defined] # noqa: N815 + destinationStats: AirbyteStateStats | None = None # type: ignore [name-defined] # noqa: N815 @dataclass class AirbyteMessage: type: Type # type: ignore [name-defined] - log: Optional[AirbyteLogMessage] = None # type: ignore [name-defined] - spec: Optional[ConnectorSpecification] = None # type: ignore [name-defined] - connectionStatus: Optional[AirbyteConnectionStatus] = None # type: ignore [name-defined] - catalog: Optional[AirbyteCatalog] = None # type: ignore [name-defined] - record: Optional[Union[AirbyteFileTransferRecordMessage, AirbyteRecordMessage]] = None # type: ignore [name-defined] - state: Optional[AirbyteStateMessage] = None - trace: Optional[AirbyteTraceMessage] = None # type: ignore [name-defined] - control: Optional[AirbyteControlMessage] = None # type: ignore [name-defined] + log: AirbyteLogMessage | None = None # type: ignore [name-defined] + spec: ConnectorSpecification | None = None # type: ignore [name-defined] + connectionStatus: AirbyteConnectionStatus | None = None # type: ignore [name-defined] # noqa: N815 + catalog: 
AirbyteCatalog | None = None # type: ignore [name-defined] + record: AirbyteFileTransferRecordMessage | AirbyteRecordMessage | None = None # type: ignore [name-defined] + state: AirbyteStateMessage | None = None + trace: AirbyteTraceMessage | None = None # type: ignore [name-defined] + control: AirbyteControlMessage | None = None # type: ignore [name-defined] diff --git a/airbyte_cdk/models/airbyte_protocol_serializers.py b/airbyte_cdk/models/airbyte_protocol_serializers.py index 129556acc..6ce15d130 100644 --- a/airbyte_cdk/models/airbyte_protocol_serializers.py +++ b/airbyte_cdk/models/airbyte_protocol_serializers.py @@ -1,5 +1,5 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. -from typing import Any, Dict +from typing import Any from serpyco_rs import CustomType, Serializer @@ -14,19 +14,19 @@ ) -class AirbyteStateBlobType(CustomType[AirbyteStateBlob, Dict[str, Any]]): - def serialize(self, value: AirbyteStateBlob) -> Dict[str, Any]: +class AirbyteStateBlobType(CustomType[AirbyteStateBlob, dict[str, Any]]): + def serialize(self, value: AirbyteStateBlob) -> dict[str, Any]: # cant use orjson.dumps() directly because private attributes are excluded, e.g. "__ab_full_refresh_sync_complete" return {k: v for k, v in value.__dict__.items()} - def deserialize(self, value: Dict[str, Any]) -> AirbyteStateBlob: + def deserialize(self, value: dict[str, Any]) -> AirbyteStateBlob: return AirbyteStateBlob(value) - def get_json_schema(self) -> Dict[str, Any]: + def get_json_schema(self) -> dict[str, Any]: return {"type": "object"} -def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, Dict[str, Any]] | None: +def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, dict[str, Any]] | None: return AirbyteStateBlobType() if t is AirbyteStateBlob else None diff --git a/airbyte_cdk/models/file_transfer_record_message.py b/airbyte_cdk/models/file_transfer_record_message.py index dcc1b7a92..8bde8b408 100644 --- a/airbyte_cdk/models/file_transfer_record_message.py +++ b/airbyte_cdk/models/file_transfer_record_message.py @@ -1,13 +1,13 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any @dataclass class AirbyteFileTransferRecordMessage: stream: str - file: Dict[str, Any] + file: dict[str, Any] emitted_at: int - namespace: Optional[str] = None - data: Optional[Dict[str, Any]] = None + namespace: str | None = None + data: dict[str, Any] | None = None diff --git a/airbyte_cdk/sources/__init__.py b/airbyte_cdk/sources/__init__.py index a6560a503..162ac01f2 100644 --- a/airbyte_cdk/sources/__init__.py +++ b/airbyte_cdk/sources/__init__.py @@ -8,6 +8,7 @@ from .config import BaseConfig from .source import Source + # As part of the CDK sources, we do not control what the APIs return and it is possible that a key is empty. 
# Reasons why we are doing this at the airbyte_cdk level: # * As of today, all the use cases should allow for empty keys diff --git a/airbyte_cdk/sources/abstract_source.py b/airbyte_cdk/sources/abstract_source.py index ab9ee48b8..5eef75933 100644 --- a/airbyte_cdk/sources/abstract_source.py +++ b/airbyte_cdk/sources/abstract_source.py @@ -4,17 +4,9 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator, Mapping, MutableMapping from typing import ( Any, - Dict, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, - Optional, - Tuple, - Union, ) from airbyte_cdk.exception_handler import generate_failed_streams_error_message @@ -46,6 +38,7 @@ ) from airbyte_cdk.utils.traced_exception import AirbyteTracedException + _default_message_repository = InMemoryMessageRepository() @@ -58,7 +51,7 @@ class AbstractSource(Source, ABC): @abstractmethod def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: """ :param logger: source logger :param config: The user-provided configuration as specified by the source's spec. @@ -71,7 +64,7 @@ def check_connection( """ @abstractmethod - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: """ :param config: The user-provided configuration as specified by the source's spec. Any stream construction related operation should happen here. @@ -79,10 +72,10 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: """ # Stream name to instance map for applying output object transformation - _stream_to_instance_map: Dict[str, Stream] = {} + _stream_to_instance_map: dict[str, Stream] = {} # noqa: RUF012 _slice_logger: SliceLogger = DebugSliceLogger() - def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: + def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: # noqa: ARG002 """Implements the Discover operation from the Airbyte Specification. See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/#discover. """ @@ -98,17 +91,17 @@ def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCon return AirbyteConnectionStatus(status=Status.FAILED, message=repr(error)) return AirbyteConnectionStatus(status=Status.SUCCEEDED) - def read( + def read( # noqa: PLR0915 self, logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: """Implements the Read operation from the Airbyte Specification. See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/.""" logger.info(f"Starting syncing {self.name}") config, internal_config = split_config(config) - # TODO assert all streams exist in the connector + # TODO assert all streams exist in the connector # noqa: TD004 # get the streams once in case the connector needs to make any queries to generate them stream_instances = {s.name: s for s in self.streams(config)} state_manager = ConnectorStateManager(state=state) @@ -137,7 +130,7 @@ def read( # Use configured_stream as stream_instance to support references in error handling. stream_instance = configured_stream.stream - raise AirbyteTracedException( + raise AirbyteTracedException( # noqa: TRY301 message="A stream listed in your configuration was not found in the source. 
Please check the logs for more " "details.", internal_message=error_message, @@ -198,9 +191,9 @@ def read( logger.info(timer.report()) if len(stream_name_to_exception) > 0: - error_message = generate_failed_streams_error_message( - {key: [value] for key, value in stream_name_to_exception.items()} - ) + error_message = generate_failed_streams_error_message({ + key: [value] for key, value in stream_name_to_exception.items() + }) logger.info(error_message) # We still raise at least one exception when a stream raises an exception because the platform currently relies # on a non-zero exit code to determine if a sync attempt has failed. We also raise the exception as a config_error @@ -212,7 +205,7 @@ def read( @staticmethod def _serialize_exception( - stream_descriptor: StreamDescriptor, e: Exception, stream_instance: Optional[Stream] = None + stream_descriptor: StreamDescriptor, e: Exception, stream_instance: Stream | None = None ) -> AirbyteTracedException: display_message = stream_instance.get_error_display_message(e) if stream_instance else None if display_message: @@ -294,7 +287,7 @@ def _emit_queued_messages(self) -> Iterable[AirbyteMessage]: return def _get_message( - self, record_data_or_message: Union[StreamData, AirbyteMessage], stream: Stream + self, record_data_or_message: StreamData | AirbyteMessage, stream: Stream ) -> AirbyteMessage: """ Converts the input to an AirbyteMessage if it is a StreamData. Returns the input as is if it is already an AirbyteMessage @@ -311,7 +304,7 @@ def _get_message( ) @property - def message_repository(self) -> Union[None, MessageRepository]: + def message_repository(self) -> None | MessageRepository: # noqa: RUF036 return _default_message_repository @property diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index f57db7e14..688c6beda 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# import logging -from typing import Dict, Iterable, List, Optional, Set +from collections.abc import Iterable from airbyte_cdk.exception_handler import generate_failed_streams_error_message from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus, FailureType, StreamDescriptor @@ -28,9 +28,9 @@ class ConcurrentReadProcessor: - def __init__( + def __init__( # noqa: ANN204 self, - stream_instances_to_read_from: List[AbstractStream], + stream_instances_to_read_from: list[AbstractStream], partition_enqueuer: PartitionEnqueuer, thread_pool_manager: ThreadPoolManager, logger: logging.Logger, @@ -50,20 +50,20 @@ def __init__( """ self._stream_name_to_instance = {s.name: s for s in stream_instances_to_read_from} self._record_counter = {} - self._streams_to_running_partitions: Dict[str, Set[Partition]] = {} + self._streams_to_running_partitions: dict[str, set[Partition]] = {} for stream in stream_instances_to_read_from: self._streams_to_running_partitions[stream.name] = set() self._record_counter[stream.name] = 0 self._thread_pool_manager = thread_pool_manager self._partition_enqueuer = partition_enqueuer self._stream_instances_to_start_partition_generation = stream_instances_to_read_from - self._streams_currently_generating_partitions: List[str] = [] + self._streams_currently_generating_partitions: list[str] = [] self._logger = logger self._slice_logger = slice_logger self._message_repository = message_repository self._partition_reader = partition_reader - self._streams_done: Set[str] = set() - self._exceptions_per_stream_name: dict[str, List[Exception]] = {} + self._streams_done: set[str] = set() + self._exceptions_per_stream_name: dict[str, list[Exception]] = {} def on_partition_generation_completed( self, sentinel: PartitionGenerationCompletedSentinel @@ -186,7 +186,7 @@ def on_exception(self, exception: StreamThreadException) -> Iterable[AirbyteMess def _flag_exception(self, stream_name: str, exception: Exception) -> None: self._exceptions_per_stream_name.setdefault(stream_name, []).append(exception) - def start_next_partition_generator(self) -> Optional[AirbyteMessage]: + def start_next_partition_generator(self) -> AirbyteMessage | None: """ Start the next partition generator. 1. Pop the next stream to read from @@ -204,8 +204,7 @@ def start_next_partition_generator(self) -> Optional[AirbyteMessage]: stream.as_airbyte_stream(), AirbyteStreamStatus.STARTED, ) - else: - return None + return None def is_done(self) -> bool: """ @@ -215,12 +214,10 @@ def is_done(self) -> bool: 2. There are no more streams to read from 3. 
All partitions for all streams are closed """ - is_done = all( - [ - self._is_stream_done(stream_name) - for stream_name in self._stream_name_to_instance.keys() - ] - ) + is_done = all([ # noqa: C419 + self._is_stream_done(stream_name) + for stream_name in self._stream_name_to_instance.keys() # noqa: SIM118 + ]) if is_done and self._exceptions_per_stream_name: error_message = generate_failed_streams_error_message(self._exceptions_per_stream_name) self._logger.info(error_message) diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py index bc7d97cdd..59f847cb9 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -3,8 +3,8 @@ # import concurrent import logging +from collections.abc import Iterable, Iterator from queue import Queue -from typing import Iterable, Iterator, List from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor @@ -49,9 +49,9 @@ def create( too_many_generator = ( not is_single_threaded and initial_number_of_partitions_to_generate >= num_workers ) - assert ( - not too_many_generator - ), "It is required to have more workers than threads generating partitions" + assert not too_many_generator, ( + "It is required to have more workers than threads generating partitions" + ) threadpool = ThreadPoolManager( concurrent.futures.ThreadPoolExecutor( max_workers=num_workers, thread_name_prefix="workerpool" @@ -71,8 +71,8 @@ def __init__( self, threadpool: ThreadPoolManager, logger: logging.Logger, - slice_logger: SliceLogger = DebugSliceLogger(), - message_repository: MessageRepository = InMemoryMessageRepository(), + slice_logger: SliceLogger = DebugSliceLogger(), # noqa: B008 + message_repository: MessageRepository = InMemoryMessageRepository(), # noqa: B008 initial_number_partitions_to_generate: int = 1, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, ) -> None: @@ -93,7 +93,7 @@ def __init__( def read( self, - streams: List[AbstractStream], + streams: list[AbstractStream], ) -> Iterator[AirbyteMessage]: self._logger.info("Starting syncing") @@ -162,4 +162,4 @@ def _handle_item( elif isinstance(queue_item, Record): yield from concurrent_stream_processor.on_record(queue_item) else: - raise ValueError(f"Unknown queue item type: {type(queue_item)}") + raise ValueError(f"Unknown queue item type: {type(queue_item)}") # noqa: TRY004 diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py index c150dc956..91b7eb58d 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py @@ -4,8 +4,9 @@ import logging from abc import ABC +from collections.abc import Callable, Iterator, Mapping, MutableMapping from datetime import timedelta -from typing import Any, Callable, Iterator, List, Mapping, MutableMapping, Optional, Tuple +from typing import Any from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog from airbyte_cdk.sources import AbstractSource @@ -27,11 +28,12 @@ AbstractStreamStateConverter, ) + DEFAULT_LOOKBACK_SECONDS = 0 class ConcurrentSourceAdapter(AbstractSource, ABC): - def __init__(self, concurrent_source: ConcurrentSource, **kwargs: Any) -> None: + def __init__(self, concurrent_source: ConcurrentSource, 
**kwargs: Any) -> None: # noqa: ANN401 """ ConcurrentSourceAdapter is a Source that wraps a concurrent source and exposes it as a regular source. @@ -47,7 +49,7 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: abstract_streams = self._select_abstract_streams(config, catalog) concurrent_stream_names = {stream.name for stream in abstract_streams} @@ -65,13 +67,13 @@ def read( def _select_abstract_streams( self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog - ) -> List[AbstractStream]: + ) -> list[AbstractStream]: """ Selects streams that can be processed concurrently and returns their abstract representations. """ all_streams = self.streams(config) stream_name_to_instance: Mapping[str, Stream] = {s.name: s for s in all_streams} - abstract_streams: List[AbstractStream] = [] + abstract_streams: list[AbstractStream] = [] for configured_stream in configured_catalog.streams: stream_instance = stream_name_to_instance.get(configured_stream.stream.name) if not stream_instance: @@ -86,7 +88,7 @@ def convert_to_concurrent_stream( logger: logging.Logger, stream: Stream, state_manager: ConnectorStateManager, - cursor: Optional[Cursor] = None, + cursor: Cursor | None = None, ) -> Stream: """ Prepares a stream for concurrent processing by initializing or assigning a cursor, @@ -113,12 +115,12 @@ def initialize_cursor( stream: Stream, state_manager: ConnectorStateManager, converter: AbstractStreamStateConverter, - slice_boundary_fields: Optional[Tuple[str, str]], - start: Optional[CursorValueType], + slice_boundary_fields: tuple[str, str] | None, + start: CursorValueType | None, end_provider: Callable[[], CursorValueType], - lookback_window: Optional[GapType] = None, - slice_range: Optional[GapType] = None, - ) -> Optional[ConcurrentCursor]: + lookback_window: GapType | None = None, + slice_range: GapType | None = None, + ) -> ConcurrentCursor | None: lookback_window = lookback_window or timedelta(seconds=DEFAULT_LOOKBACK_SECONDS) cursor_field_name = stream.cursor_field diff --git a/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py b/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py index b6643042b..4cdbd221f 100644 --- a/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py +++ b/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py @@ -1,24 +1,23 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream -class PartitionGenerationCompletedSentinel: +class PartitionGenerationCompletedSentinel: # noqa: PLW1641 """ A sentinel object indicating all partitions for a stream were produced. Includes a pointer to the stream that was processed. 
""" - def __init__(self, stream: AbstractStream): + def __init__(self, stream: AbstractStream): # noqa: ANN204 """ :param stream: The stream that was processed """ self.stream = stream - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, PartitionGenerationCompletedSentinel): return self.stream == other.stream return False diff --git a/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py b/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py index c865bef59..da034575d 100644 --- a/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py +++ b/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py @@ -1,10 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. -from typing import Any - -class StreamThreadException(Exception): - def __init__(self, exception: Exception, stream_name: str): +class StreamThreadException(Exception): # noqa: PLW1641 + def __init__(self, exception: Exception, stream_name: str): # noqa: ANN204 self._exception = exception self._stream_name = stream_name @@ -19,7 +17,7 @@ def exception(self) -> Exception: def __str__(self) -> str: return f"Exception while syncing stream {self._stream_name}: {self._exception}" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, StreamThreadException): return self._exception == other._exception and self._stream_name == other._stream_name return False diff --git a/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py b/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py index 59f8a1f0b..556d0117a 100644 --- a/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py +++ b/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py @@ -3,8 +3,9 @@ # import logging import threading +from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, List, Optional +from typing import Any class ThreadPoolManager: @@ -14,7 +15,7 @@ class ThreadPoolManager: DEFAULT_MAX_QUEUE_SIZE = 10_000 - def __init__( + def __init__( # noqa: ANN204 self, threadpool: ThreadPoolExecutor, logger: logging.Logger, @@ -28,9 +29,9 @@ def __init__( self._threadpool = threadpool self._logger = logger self._max_concurrent_tasks = max_concurrent_tasks - self._futures: List[Future[Any]] = [] + self._futures: list[Future[Any]] = [] self._lock = threading.Lock() - self._most_recently_seen_exception: Optional[Exception] = None + self._most_recently_seen_exception: Exception | None = None self._logging_threshold = max_concurrent_tasks * 2 @@ -42,10 +43,10 @@ def prune_to_validate_has_reached_futures_limit(self) -> bool: ) return len(self._futures) >= self._max_concurrent_tasks - def submit(self, function: Callable[..., Any], *args: Any) -> None: + def submit(self, function: Callable[..., Any], *args: Any) -> None: # noqa: ANN401 self._futures.append(self._threadpool.submit(function, *args)) - def _prune_futures(self, futures: List[Future[Any]]) -> None: + def _prune_futures(self, futures: list[Future[Any]]) -> None: """ Take a list in input and remove the futures that are completed. If a future has an exception, it'll raise and kill the stream operation. 
@@ -79,7 +80,7 @@ def _shutdown(self) -> None: self._threadpool.shutdown(wait=False, cancel_futures=True) def is_done(self) -> bool: - return all([f.done() for f in self._futures]) + return all([f.done() for f in self._futures]) # noqa: C419 def check_for_errors_and_shutdown(self) -> None: """ diff --git a/airbyte_cdk/sources/config.py b/airbyte_cdk/sources/config.py index ea91b17f3..d0ce9c896 100644 --- a/airbyte_cdk/sources/config.py +++ b/airbyte_cdk/sources/config.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, Dict +from typing import Any from pydantic.v1 import BaseModel @@ -18,7 +18,7 @@ class BaseConfig(BaseModel): """ @classmethod - def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def schema(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401 """We're overriding the schema classmethod to enable some post-processing""" schema = super().schema(*args, **kwargs) rename_key(schema, old_key="anyOf", new_key="oneOf") # UI supports only oneOf diff --git a/airbyte_cdk/sources/connector_state_manager.py b/airbyte_cdk/sources/connector_state_manager.py index 914374a55..50fb16542 100644 --- a/airbyte_cdk/sources/connector_state_manager.py +++ b/airbyte_cdk/sources/connector_state_manager.py @@ -3,8 +3,9 @@ # import copy +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass -from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union, cast +from typing import Any, cast from airbyte_cdk.models import ( AirbyteMessage, @@ -15,7 +16,7 @@ StreamDescriptor, ) from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.models.airbyte_protocol import AirbyteGlobalState, AirbyteStateBlob +from airbyte_cdk.models.airbyte_protocol import AirbyteGlobalState, AirbyteStateBlob # noqa: F811 @dataclass(frozen=True) @@ -26,7 +27,7 @@ class HashableStreamDescriptor: """ name: str - namespace: Optional[str] = None + namespace: str | None = None class ConnectorStateManager: @@ -35,7 +36,7 @@ class ConnectorStateManager: interface. It also provides methods to extract and update state """ - def __init__(self, state: Optional[List[AirbyteStateMessage]] = None): + def __init__(self, state: list[AirbyteStateMessage] | None = None): # noqa: ANN204 shared_state, per_stream_states = self._extract_from_state_message(state) # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are @@ -50,9 +51,7 @@ def __init__(self, state: Optional[List[AirbyteStateMessage]] = None): ) self.per_stream_states = per_stream_states - def get_stream_state( - self, stream_name: str, namespace: Optional[str] - ) -> MutableMapping[str, Any]: + def get_stream_state(self, stream_name: str, namespace: str | None) -> MutableMapping[str, Any]: """ Retrieves the state of a given stream based on its descriptor (name + namespace). 
:param stream_name: Name of the stream being fetched @@ -67,7 +66,7 @@ def get_stream_state( return {} def update_state_for_stream( - self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any] + self, stream_name: str, namespace: str | None, value: Mapping[str, Any] ) -> None: """ Overwrites the state blob of a specific stream based on the provided stream name and optional namespace @@ -78,7 +77,7 @@ def update_state_for_stream( stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace) self.per_stream_states[stream_descriptor] = AirbyteStateBlob(value) - def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage: + def create_state_message(self, stream_name: str, namespace: str | None) -> AirbyteMessage: """ Generates an AirbyteMessage using the current per-stream state of a specified stream :param stream_name: The name of the stream for the message that is being created @@ -102,10 +101,10 @@ def create_state_message(self, stream_name: str, namespace: Optional[str]) -> Ai @classmethod def _extract_from_state_message( cls, - state: Optional[List[AirbyteStateMessage]], - ) -> Tuple[ - Optional[AirbyteStateBlob], - MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]], + state: list[AirbyteStateMessage] | None, + ) -> tuple[ + AirbyteStateBlob | None, + MutableMapping[HashableStreamDescriptor, AirbyteStateBlob | None], ]: """ Takes an incoming list of state messages or a global state message and extracts state attributes according to @@ -120,10 +119,11 @@ def _extract_from_state_message( if is_global: # We already validate that this is a global state message, not None: - global_state = cast(AirbyteGlobalState, state[0].global_) + global_state = cast(AirbyteGlobalState, state[0].global_) # noqa: TC006 # global_state has shared_state, also not None: shared_state: AirbyteStateBlob = cast( - AirbyteStateBlob, copy.deepcopy(global_state.shared_state, {}) + "AirbyteStateBlob", + copy.deepcopy(global_state.shared_state, {}), ) streams = { HashableStreamDescriptor( @@ -133,22 +133,21 @@ def _extract_from_state_message( for per_stream_state in global_state.stream_states # type: ignore[union-attr] # global_state has shared_state } return shared_state, streams - else: - streams = { - HashableStreamDescriptor( - name=per_stream_state.stream.stream_descriptor.name, # type: ignore[union-attr] # stream has stream_descriptor - namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # stream has stream_descriptor - ): per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state - for per_stream_state in state - if per_stream_state.type == AirbyteStateType.STREAM - and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True - } - return None, streams + streams = { + HashableStreamDescriptor( + name=per_stream_state.stream.stream_descriptor.name, # type: ignore[union-attr] # stream has stream_descriptor + namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # stream has stream_descriptor + ): per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state + for per_stream_state in state + if per_stream_state.type == AirbyteStateType.STREAM + and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True + } + return None, streams @staticmethod - def 
_is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool: + def _is_global_state(state: list[AirbyteStateMessage] | MutableMapping[str, Any]) -> bool: return ( - isinstance(state, List) + isinstance(state, list) and len(state) == 1 and isinstance(state[0], AirbyteStateMessage) and state[0].type == AirbyteStateType.GLOBAL @@ -156,6 +155,6 @@ def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, @staticmethod def _is_per_stream_state( - state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]], + state: list[AirbyteStateMessage] | MutableMapping[str, Any], ) -> bool: - return isinstance(state, List) + return isinstance(state, list) diff --git a/airbyte_cdk/sources/declarative/async_job/job.py b/airbyte_cdk/sources/declarative/async_job/job.py index b075b61e2..7283a34e4 100644 --- a/airbyte_cdk/sources/declarative/async_job/job.py +++ b/airbyte_cdk/sources/declarative/async_job/job.py @@ -2,13 +2,11 @@ from datetime import timedelta -from typing import Optional +from .status import AsyncJobStatus from airbyte_cdk.sources.declarative.async_job.timer import Timer from airbyte_cdk.sources.types import StreamSlice -from .status import AsyncJobStatus - class AsyncJob: """ @@ -19,13 +17,13 @@ class AsyncJob: """ def __init__( - self, api_job_id: str, job_parameters: StreamSlice, timeout: Optional[timedelta] = None + self, api_job_id: str, job_parameters: StreamSlice, timeout: timedelta | None = None ) -> None: self._api_job_id = api_job_id self._job_parameters = job_parameters self._status = AsyncJobStatus.RUNNING - timeout = timeout if timeout else timedelta(minutes=60) + timeout = timeout if timeout else timedelta(minutes=60) # noqa: FURB110 self._timer = Timer(timeout) self._timer.start() diff --git a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py index 3938b8c07..bb9232eff 100644 --- a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py +++ b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py @@ -5,18 +5,11 @@ import time import traceback import uuid +from collections.abc import Generator, Iterable, Mapping from datetime import timedelta from typing import ( Any, - Generator, Generic, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - Type, TypeVar, ) @@ -34,6 +27,7 @@ from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException + LOGGER = logging.getLogger("airbyte") _NO_TIMEOUT = timedelta.max _API_SIDE_RUNNING_STATUS = {AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT} @@ -46,30 +40,30 @@ class AsyncPartition: _MAX_NUMBER_OF_ATTEMPTS = 3 - def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None: - self._attempts_per_job = {job: 1 for job in jobs} + def __init__(self, jobs: list[AsyncJob], stream_slice: StreamSlice) -> None: + self._attempts_per_job = {job: 1 for job in jobs} # noqa: C420 self._stream_slice = stream_slice def has_reached_max_attempt(self) -> bool: return any( - map( + map( # noqa: C417 lambda attempt_count: attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS, self._attempts_per_job.values(), ) ) - def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> None: + def replace_job(self, job_to_replace: AsyncJob, new_jobs: list[AsyncJob]) -> None: current_attempt_count = self._attempts_per_job.pop(job_to_replace, None) if current_attempt_count is None: raise ValueError("Could not find job to 
replace") - elif current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS: + if current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS: raise ValueError(f"Max attempt reached for job in partition {self._stream_slice}") new_attempt_count = current_attempt_count + 1 for job in new_jobs: self._attempts_per_job[job] = new_attempt_count - def should_split(self, job: AsyncJob) -> bool: + def should_split(self, job: AsyncJob) -> bool: # noqa: ARG002 """ Not used right now but once we support job split, we should split based on the number of attempts """ @@ -88,20 +82,19 @@ def status(self) -> AsyncJobStatus: """ Given different job statuses, the priority is: FAILED, TIMED_OUT, RUNNING. Else, it means everything is completed. """ - statuses = set(map(lambda job: job.status(), self.jobs)) + statuses = set(map(lambda job: job.status(), self.jobs)) # noqa: C417 if statuses == {AsyncJobStatus.COMPLETED}: return AsyncJobStatus.COMPLETED - elif AsyncJobStatus.FAILED in statuses: + if AsyncJobStatus.FAILED in statuses: return AsyncJobStatus.FAILED - elif AsyncJobStatus.TIMED_OUT in statuses: + if AsyncJobStatus.TIMED_OUT in statuses: return AsyncJobStatus.TIMED_OUT - else: - return AsyncJobStatus.RUNNING + return AsyncJobStatus.RUNNING def __repr__(self) -> str: return f"AsyncPartition(stream_slice={self._stream_slice}, attempt_per_job={self._attempts_per_job})" - def __json_serializable__(self) -> Any: + def __json_serializable__(self) -> Any: # noqa: ANN401, PLW3201 return self._stream_slice @@ -111,7 +104,7 @@ def __json_serializable__(self) -> Any: class LookaheadIterator(Generic[T]): def __init__(self, iterable: Iterable[T]) -> None: self._iterator = iter(iterable) - self._buffer: List[T] = [] + self._buffer: list[T] = [] def __iter__(self) -> "LookaheadIterator[T]": return self @@ -119,8 +112,7 @@ def __iter__(self) -> "LookaheadIterator[T]": def __next__(self) -> T: if self._buffer: return self._buffer.pop() - else: - return next(self._iterator) + return next(self._iterator) def has_next(self) -> bool: if self._buffer: @@ -134,18 +126,18 @@ def has_next(self) -> bool: return True def add_at_the_beginning(self, item: T) -> None: - self._buffer = [item] + self._buffer + self._buffer = [item] + self._buffer # noqa: RUF005 class AsyncJobOrchestrator: _WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS = 5 - _KNOWN_JOB_STATUSES = { + _KNOWN_JOB_STATUSES = { # noqa: RUF012 AsyncJobStatus.COMPLETED, AsyncJobStatus.FAILED, AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT, } - _RUNNING_ON_API_SIDE_STATUS = {AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT} + _RUNNING_ON_API_SIDE_STATUS = {AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT} # noqa: RUF012 def __init__( self, @@ -153,8 +145,8 @@ def __init__( slices: Iterable[StreamSlice], job_tracker: JobTracker, message_repository: MessageRepository, - exceptions_to_break_on: Iterable[Type[Exception]] = tuple(), - has_bulk_parent: bool = False, + exceptions_to_break_on: Iterable[type[Exception]] = tuple(), # noqa: C408 + has_bulk_parent: bool = False, # noqa: FBT001, FBT002 ) -> None: """ If the stream slices provided as a parameters relies on a async job streams that relies on the same JobTracker, `has_bulk_parent` @@ -170,13 +162,13 @@ def __init__( self._job_repository: AsyncJobRepository = job_repository self._slice_iterator = LookaheadIterator(slices) - self._running_partitions: List[AsyncPartition] = [] + self._running_partitions: list[AsyncPartition] = [] self._job_tracker = job_tracker self._message_repository = message_repository - self._exceptions_to_break_on: 
Tuple[Type[Exception], ...] = tuple(exceptions_to_break_on) + self._exceptions_to_break_on: tuple[type[Exception], ...] = tuple(exceptions_to_break_on) self._has_bulk_parent = has_bulk_parent - self._non_breaking_exceptions: List[Exception] = [] + self._non_breaking_exceptions: list[Exception] = [] def _replace_failed_jobs(self, partition: AsyncPartition) -> None: failed_status_jobs = (AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT) @@ -225,7 +217,7 @@ def _start_jobs(self) -> None: "Waiting before creating more jobs as the limit of concurrent jobs has been reached. Will try again later..." ) - def _start_job(self, _slice: StreamSlice, previous_job_id: Optional[str] = None) -> AsyncJob: + def _start_job(self, _slice: StreamSlice, previous_job_id: str | None = None) -> AsyncJob: if previous_job_id: id_to_replace = previous_job_id lazy_log(LOGGER, logging.DEBUG, lambda: f"Attempting to replace job {id_to_replace}...") @@ -235,16 +227,19 @@ def _start_job(self, _slice: StreamSlice, previous_job_id: Optional[str] = None) try: job = self._job_repository.start(_slice) self._job_tracker.add_job(id_to_replace, job.api_job_id()) - return job + return job # noqa: TRY300 except Exception as exception: LOGGER.warning(f"Exception has occurred during job creation: {exception}") if self._is_breaking_exception(exception): self._job_tracker.remove_job(id_to_replace) - raise exception + raise exception # noqa: TRY201 return self._keep_api_budget_with_failed_job(_slice, exception, id_to_replace) def _keep_api_budget_with_failed_job( - self, _slice: StreamSlice, exception: Exception, intent: str + self, + slice_: StreamSlice, + exception: Exception, + intent: str, ) -> AsyncJob: """ We have a mechanism to retry job. It is used when a job status is FAILED or TIMED_OUT. The easiest way to retry is to have this job @@ -252,7 +247,7 @@ def _keep_api_budget_with_failed_job( retrying jobs that couldn't be started. """ LOGGER.warning( - f"Could not start job for slice {_slice}. Job will be flagged as failed and retried if max number of attempts not reached: {exception}" + f"Could not start job for slice {slice_}. Job will be flagged as failed and retried if max number of attempts not reached: {exception}" ) traced_exception = ( exception @@ -262,7 +257,7 @@ def _keep_api_budget_with_failed_job( # Even though we're not sure this will break the stream, we will emit here for simplicity's sake. If we wanted to be more accurate, # we would keep the exceptions in-memory until we know that we have reached the max attempt. self._message_repository.emit_message(traced_exception.as_airbyte_message()) - job = self._create_failed_job(_slice) + job = self._create_failed_job(slice_) self._job_tracker.add_job(intent, job.api_job_id()) return job @@ -271,7 +266,7 @@ def _create_failed_job(self, stream_slice: StreamSlice) -> AsyncJob: job.update_status(AsyncJobStatus.FAILED) return job - def _get_running_jobs(self) -> Set[AsyncJob]: + def _get_running_jobs(self) -> set[AsyncJob]: """ Returns a set of running AsyncJob objects. @@ -324,7 +319,7 @@ def _process_completed_partition(self, partition: AsyncPartition) -> None: Args: partition (AsyncPartition): The completed partition to process. """ - job_ids = list(map(lambda job: job.api_job_id(), {job for job in partition.jobs})) + job_ids = list(map(lambda job: job.api_job_id(), {job for job in partition.jobs})) # noqa: C417 LOGGER.info( f"The following jobs for stream slice {partition.stream_slice} have been completed: {job_ids}." 
) @@ -346,7 +341,7 @@ def _process_running_partitions_and_yield_completed_ones( Raises: Any: Any exception raised during processing. """ - current_running_partitions: List[AsyncPartition] = [] + current_running_partitions: list[AsyncPartition] = [] for partition in self._running_partitions: match partition.status: case AsyncJobStatus.COMPLETED: @@ -384,7 +379,7 @@ def _stop_timed_out_jobs(self, partition: AsyncPartition) -> None: # we don't free allocation here because it is expected to retry the job self._abort_job(job, free_job_allocation=False) - def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None: + def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None: # noqa: FBT001, FBT002 try: self._job_repository.abort(job) if free_job_allocation: @@ -442,7 +437,7 @@ def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]: f"Caught exception that stops the processing of the jobs: {exception}" ) self._abort_all_running_jobs() - raise exception + raise exception # noqa: TRY201 self._non_breaking_exceptions.append(exception) @@ -454,12 +449,10 @@ def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]: # call of `create_and_get_completed_partitions` knows that there was an issue with some partitions and the sync is incomplete. raise AirbyteTracedException( message="", - internal_message="\n".join( - [ - filter_secrets(exception.__repr__()) - for exception in self._non_breaking_exceptions - ] - ), + internal_message="\n".join([ + filter_secrets(exception.__repr__()) # noqa: PLC2801 + for exception in self._non_breaking_exceptions + ]), failure_type=FailureType.config_error, ) diff --git a/airbyte_cdk/sources/declarative/async_job/job_tracker.py b/airbyte_cdk/sources/declarative/async_job/job_tracker.py index b47fc4cad..25c3c570e 100644 --- a/airbyte_cdk/sources/declarative/async_job/job_tracker.py +++ b/airbyte_cdk/sources/declarative/async_job/job_tracker.py @@ -3,10 +3,10 @@ import logging import threading import uuid -from typing import Set from airbyte_cdk.logger import lazy_log + LOGGER = logging.getLogger("airbyte") @@ -15,8 +15,8 @@ class ConcurrentJobLimitReached(Exception): class JobTracker: - def __init__(self, limit: int): - self._jobs: Set[str] = set() + def __init__(self, limit: int): # noqa: ANN204 + self._jobs: set[str] = set() self._limit = limit self._lock = threading.Lock() @@ -31,7 +31,7 @@ def try_to_get_intent(self) -> str: raise ConcurrentJobLimitReached( "Can't allocate more jobs right now: limit already reached" ) - intent = f"intent_{str(uuid.uuid4())}" + intent = f"intent_{uuid.uuid4()!s}" lazy_log( LOGGER, logging.DEBUG, diff --git a/airbyte_cdk/sources/declarative/async_job/repository.py b/airbyte_cdk/sources/declarative/async_job/repository.py index 21581ec4f..df46748b4 100644 --- a/airbyte_cdk/sources/declarative/async_job/repository.py +++ b/airbyte_cdk/sources/declarative/async_job/repository.py @@ -1,7 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
from abc import abstractmethod -from typing import Any, Iterable, Mapping, Set +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.declarative.async_job.job import AsyncJob from airbyte_cdk.sources.types import StreamSlice @@ -13,7 +14,7 @@ def start(self, stream_slice: StreamSlice) -> AsyncJob: pass @abstractmethod - def update_jobs_status(self, jobs: Set[AsyncJob]) -> None: + def update_jobs_status(self, jobs: set[AsyncJob]) -> None: pass @abstractmethod diff --git a/airbyte_cdk/sources/declarative/async_job/status.py b/airbyte_cdk/sources/declarative/async_job/status.py index 586e79889..abd7bb4b9 100644 --- a/airbyte_cdk/sources/declarative/async_job/status.py +++ b/airbyte_cdk/sources/declarative/async_job/status.py @@ -3,6 +3,7 @@ from enum import Enum + _TERMINAL = True @@ -12,7 +13,7 @@ class AsyncJobStatus(Enum): FAILED = ("FAILED", _TERMINAL) TIMED_OUT = ("TIMED_OUT", _TERMINAL) - def __init__(self, value: str, is_terminal: bool) -> None: + def __init__(self, value: str, is_terminal: bool) -> None: # noqa: FBT001 self._value = value self._is_terminal = is_terminal diff --git a/airbyte_cdk/sources/declarative/async_job/timer.py b/airbyte_cdk/sources/declarative/async_job/timer.py index c4e5a9a1d..a1686bc04 100644 --- a/airbyte_cdk/sources/declarative/async_job/timer.py +++ b/airbyte_cdk/sources/declarative/async_job/timer.py @@ -1,12 +1,11 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. from datetime import datetime, timedelta, timezone -from typing import Optional class Timer: def __init__(self, timeout: timedelta) -> None: - self._start_datetime: Optional[datetime] = None - self._end_datetime: Optional[datetime] = None + self._start_datetime: datetime | None = None + self._end_datetime: datetime | None = None self._timeout = timeout def start(self) -> None: @@ -21,7 +20,7 @@ def is_started(self) -> bool: return self._start_datetime is not None @property - def elapsed_time(self) -> Optional[timedelta]: + def elapsed_time(self) -> timedelta | None: if not self._start_datetime: return None diff --git a/airbyte_cdk/sources/declarative/auth/__init__.py b/airbyte_cdk/sources/declarative/auth/__init__.py index 810437810..a0b61072d 100644 --- a/airbyte_cdk/sources/declarative/auth/__init__.py +++ b/airbyte_cdk/sources/declarative/auth/__init__.py @@ -5,4 +5,5 @@ from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator from airbyte_cdk.sources.declarative.auth.oauth import DeclarativeOauth2Authenticator + __all__ = ["DeclarativeOauth2Authenticator", "JwtAuthenticator"] diff --git a/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py b/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py index b749718fa..5a7949eb2 100644 --- a/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py +++ b/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Union +from typing import Any from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import ( AbstractHeaderAuthenticator, @@ -20,7 +21,7 @@ def get_request_params(self) -> Mapping[str, Any]: """HTTP request parameter to add to the requests""" return {} - def get_request_body_data(self) -> Union[Mapping[str, Any], str]: + def get_request_body_data(self) -> Mapping[str, Any] | str: """Form-encoded body data to set on the requests""" return {} diff --git a/airbyte_cdk/sources/declarative/auth/jwt.py b/airbyte_cdk/sources/declarative/auth/jwt.py index d7dd59282..005c336a5 100644 --- a/airbyte_cdk/sources/declarative/auth/jwt.py +++ b/airbyte_cdk/sources/declarative/auth/jwt.py @@ -3,9 +3,10 @@ # import base64 +from collections.abc import Mapping from dataclasses import InitVar, dataclass from datetime import datetime -from typing import Any, Mapping, Optional, Union +from typing import Any import jwt @@ -15,7 +16,7 @@ from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString -class JwtAlgorithm(str): +class JwtAlgorithm(str): # noqa: SLOT000 """ Enum for supported JWT algorithms """ @@ -60,19 +61,19 @@ class JwtAuthenticator(DeclarativeAuthenticator): config: Mapping[str, Any] parameters: InitVar[Mapping[str, Any]] - secret_key: Union[InterpolatedString, str] - algorithm: Union[str, JwtAlgorithm] - token_duration: Optional[int] - base64_encode_secret_key: Optional[Union[InterpolatedBoolean, str, bool]] = False - header_prefix: Optional[Union[InterpolatedString, str]] = None - kid: Optional[Union[InterpolatedString, str]] = None - typ: Optional[Union[InterpolatedString, str]] = None - cty: Optional[Union[InterpolatedString, str]] = None - iss: Optional[Union[InterpolatedString, str]] = None - sub: Optional[Union[InterpolatedString, str]] = None - aud: Optional[Union[InterpolatedString, str]] = None - additional_jwt_headers: Optional[Mapping[str, Any]] = None - additional_jwt_payload: Optional[Mapping[str, Any]] = None + secret_key: InterpolatedString | str + algorithm: str | JwtAlgorithm + token_duration: int | None + base64_encode_secret_key: InterpolatedBoolean | str | bool | None = False + header_prefix: InterpolatedString | str | None = None + kid: InterpolatedString | str | None = None + typ: InterpolatedString | str | None = None + cty: InterpolatedString | str | None = None + iss: InterpolatedString | str | None = None + sub: InterpolatedString | str | None = None + aud: InterpolatedString | str | None = None + additional_jwt_headers: Mapping[str, Any] | None = None + additional_jwt_payload: Mapping[str, Any] | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters) @@ -158,7 +159,7 @@ def _get_secret_key(self) -> str: else secret_key ) - def _get_signed_token(self) -> Union[str, Any]: + def _get_signed_token(self) -> str | Any: # noqa: ANN401 """ Signed the JWT using the provided secret key and algorithm and the generated headers and payload. 
For additional information on PyJWT see: https://pyjwt.readthedocs.io/en/stable/ """ @@ -170,9 +171,9 @@ def _get_signed_token(self) -> Union[str, Any]: headers=self._get_jwt_headers(), ) except Exception as e: - raise ValueError(f"Failed to sign token: {e}") + raise ValueError(f"Failed to sign token: {e}") # noqa: B904 - def _get_header_prefix(self) -> Union[str, None]: + def _get_header_prefix(self) -> str | None: """ Returns the header prefix to be used when attaching the token to the request. """ diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index 36508fd7e..a7dda3921 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, List, Mapping, Optional, Union +from typing import Any import pendulum @@ -44,33 +45,33 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut message_repository (MessageRepository): the message repository used to emit logs on HTTP requests """ - client_id: Union[InterpolatedString, str] - client_secret: Union[InterpolatedString, str] + client_id: InterpolatedString | str + client_secret: InterpolatedString | str config: Mapping[str, Any] parameters: InitVar[Mapping[str, Any]] - token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None - refresh_token: Optional[Union[InterpolatedString, str]] = None - scopes: Optional[List[str]] = None - token_expiry_date: Optional[Union[InterpolatedString, str]] = None - _token_expiry_date: Optional[pendulum.DateTime] = field(init=False, repr=False, default=None) - token_expiry_date_format: Optional[str] = None + token_refresh_endpoint: InterpolatedString | str | None = None + refresh_token: InterpolatedString | str | None = None + scopes: list[str] | None = None + token_expiry_date: InterpolatedString | str | None = None + _token_expiry_date: pendulum.DateTime | None = field(init=False, repr=False, default=None) + token_expiry_date_format: str | None = None token_expiry_is_time_of_expiration: bool = False - access_token_name: Union[InterpolatedString, str] = "access_token" - access_token_value: Optional[Union[InterpolatedString, str]] = None - client_id_name: Union[InterpolatedString, str] = "client_id" - client_secret_name: Union[InterpolatedString, str] = "client_secret" - expires_in_name: Union[InterpolatedString, str] = "expires_in" - refresh_token_name: Union[InterpolatedString, str] = "refresh_token" - refresh_request_body: Optional[Mapping[str, Any]] = None - refresh_request_headers: Optional[Mapping[str, Any]] = None - grant_type_name: Union[InterpolatedString, str] = "grant_type" - grant_type: Union[InterpolatedString, str] = "refresh_token" - message_repository: MessageRepository = NoopMessageRepository() + access_token_name: InterpolatedString | str = "access_token" + access_token_value: InterpolatedString | str | None = None + client_id_name: InterpolatedString | str = "client_id" + client_secret_name: InterpolatedString | str = "client_secret" + expires_in_name: InterpolatedString | str = "expires_in" + refresh_token_name: InterpolatedString | str = "refresh_token" + refresh_request_body: Mapping[str, Any] | None = None + refresh_request_headers: Mapping[str, Any] | None = None + grant_type_name: InterpolatedString | str = "grant_type" + grant_type: InterpolatedString | str = 
"refresh_token" + message_repository: MessageRepository = NoopMessageRepository() # noqa: RUF009 def __post_init__(self, parameters: Mapping[str, Any]) -> None: super().__init__() if self.token_refresh_endpoint is not None: - self._token_refresh_endpoint: Optional[InterpolatedString] = InterpolatedString.create( + self._token_refresh_endpoint: InterpolatedString | None = InterpolatedString.create( self.token_refresh_endpoint, parameters=parameters ) else: @@ -85,7 +86,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.refresh_token_name, parameters=parameters ) if self.refresh_token is not None: - self._refresh_token: Optional[InterpolatedString] = InterpolatedString.create( + self._refresh_token: InterpolatedString | None = InterpolatedString.create( self.refresh_token, parameters=parameters ) else: @@ -122,7 +123,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: else: self._access_token_value = None - self._access_token: Optional[str] = ( + self._access_token: str | None = ( self._access_token_value if self.access_token_value else None ) @@ -131,7 +132,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: "OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`" ) - def get_token_refresh_endpoint(self) -> Optional[str]: + def get_token_refresh_endpoint(self) -> str | None: if self._token_refresh_endpoint is not None: refresh_token_endpoint: str = self._token_refresh_endpoint.eval(self.config) if not refresh_token_endpoint: @@ -162,10 +163,10 @@ def get_client_secret(self) -> str: def get_refresh_token_name(self) -> str: return self._refresh_token_name.eval(self.config) # type: ignore # eval returns a string in this context - def get_refresh_token(self) -> Optional[str]: + def get_refresh_token(self) -> str | None: return None if self._refresh_token is None else str(self._refresh_token.eval(self.config)) - def get_scopes(self) -> List[str]: + def get_scopes(self) -> list[str]: return self.scopes or [] def get_access_token_name(self) -> str: @@ -189,7 +190,7 @@ def get_refresh_request_headers(self) -> Mapping[str, Any]: def get_token_expiry_date(self) -> pendulum.DateTime: return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks - def set_token_expiry_date(self, value: Union[str, int]) -> None: + def set_token_expiry_date(self, value: str | int) -> None: self._token_expiry_date = self._parse_token_expiration_date(value) @property @@ -218,5 +219,5 @@ class DeclarativeSingleUseRefreshTokenOauth2Authenticator( Declarative version of SingleUseRefreshTokenOauth2Authenticator which can be used in declarative connectors. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 super().__init__(*args, **kwargs) diff --git a/airbyte_cdk/sources/declarative/auth/selective_authenticator.py b/airbyte_cdk/sources/declarative/auth/selective_authenticator.py index 3a84150bf..fdd0b4deb 100644 --- a/airbyte_cdk/sources/declarative/auth/selective_authenticator.py +++ b/airbyte_cdk/sources/declarative/auth/selective_authenticator.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, List, Mapping +from typing import Any import dpath @@ -16,16 +17,16 @@ class SelectiveAuthenticator(DeclarativeAuthenticator): config: Mapping[str, Any] authenticators: Mapping[str, DeclarativeAuthenticator] - authenticator_selection_path: List[str] + authenticator_selection_path: list[str] # returns "DeclarativeAuthenticator", but must return a subtype of "SelectiveAuthenticator" def __new__( # type: ignore[misc] cls, config: Mapping[str, Any], authenticators: Mapping[str, DeclarativeAuthenticator], - authenticator_selection_path: List[str], - *arg: Any, - **kwargs: Any, + authenticator_selection_path: list[str], + *arg: Any, # noqa: ANN401, ARG003 + **kwargs: Any, # noqa: ANN401, ARG003 ) -> DeclarativeAuthenticator: try: selected_key = str( diff --git a/airbyte_cdk/sources/declarative/auth/token.py b/airbyte_cdk/sources/declarative/auth/token.py index 12fb899b9..0dbf0cf2a 100644 --- a/airbyte_cdk/sources/declarative/auth/token.py +++ b/airbyte_cdk/sources/declarative/auth/token.py @@ -1,11 +1,12 @@ -# +# # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # import base64 import logging +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Union +from typing import Any import requests from cachetools import TTLCache, cached @@ -68,7 +69,7 @@ def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, A def get_request_params(self) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.request_parameter) - def get_request_body_data(self) -> Union[Mapping[str, Any], str]: + def get_request_body_data(self) -> Mapping[str, Any] | str: return self._get_request_options(RequestOptionType.body_data) def get_request_body_json(self) -> Mapping[str, Any]: @@ -118,10 +119,10 @@ class BasicHttpAuthenticator(DeclarativeAuthenticator): parameters (Mapping[str, Any]): Additional runtime parameters to be used for string interpolation """ - username: Union[InterpolatedString, str] + username: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - password: Union[InterpolatedString, str] = "" + password: InterpolatedString | str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._username = InterpolatedString.create(self.username, parameters=parameters) @@ -134,7 +135,7 @@ def auth_header(self) -> str: @property def token(self) -> str: auth_string = ( - f"{self._username.eval(self.config)}:{self._password.eval(self.config)}".encode("utf8") + f"{self._username.eval(self.config)}:{self._password.eval(self.config)}".encode() ) b64_encoded = base64.b64encode(auth_string).decode("utf8") return f"Basic {b64_encoded}" @@ -148,7 +149,7 @@ def token(self) -> str: i.e. 
by adding another item the cache would exceed its maximum size, the cache must choose which item(s) to discard ttl=86400 means that cached token will live for 86400 seconds (one day) """ -cacheSessionTokenAuthenticator: TTLCache[str, str] = TTLCache(maxsize=1000, ttl=86400) +cacheSessionTokenAuthenticator: TTLCache[str, str] = TTLCache(maxsize=1000, ttl=86400) # noqa: N816 @cached(cacheSessionTokenAuthenticator) @@ -201,16 +202,16 @@ class LegacySessionTokenAuthenticator(DeclarativeAuthenticator): validate_session_url (Union[InterpolatedString, str]): Url to validate passed session token """ - api_url: Union[InterpolatedString, str] - header: Union[InterpolatedString, str] - session_token: Union[InterpolatedString, str] - session_token_response_key: Union[InterpolatedString, str] - username: Union[InterpolatedString, str] + api_url: InterpolatedString | str + header: InterpolatedString | str + session_token: InterpolatedString | str + session_token_response_key: InterpolatedString | str + username: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - login_url: Union[InterpolatedString, str] - validate_session_url: Union[InterpolatedString, str] - password: Union[InterpolatedString, str] = "" + login_url: InterpolatedString | str + validate_session_url: InterpolatedString | str + password: InterpolatedString | str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._username = InterpolatedString.create(self.username, parameters=parameters) @@ -234,7 +235,7 @@ def auth_header(self) -> str: @property def token(self) -> str: - if self._session_token.eval(self.config): + if self._session_token.eval(self.config): # noqa: SIM102 if self.is_valid_session_token(): return str(self._session_token.eval(self.config)) if self._password.eval(self.config) and self._username.eval(self.config): @@ -259,14 +260,12 @@ def is_valid_session_token(self) -> bool: response.raise_for_status() except requests.exceptions.HTTPError as e: if e.response.status_code == requests.codes["unauthorized"]: - self.logger.info(f"Unable to connect by session token from config due to {str(e)}") + self.logger.info(f"Unable to connect by session token from config due to {e!s}") return False - else: - raise ConnectionError(f"Error while validating session token: {e}") + raise ConnectionError(f"Error while validating session token: {e}") # noqa: B904 if response.ok: self.logger.info("Connection check for source is successful.") return True - else: - raise ConnectionError( - f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}" - ) + raise ConnectionError( + f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}" + ) diff --git a/airbyte_cdk/sources/declarative/auth/token_provider.py b/airbyte_cdk/sources/declarative/auth/token_provider.py index ed933bc59..1b1a9eabc 100644 --- a/airbyte_cdk/sources/declarative/auth/token_provider.py +++ b/airbyte_cdk/sources/declarative/auth/token_provider.py @@ -5,8 +5,9 @@ import datetime from abc import abstractmethod +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, List, Mapping, Optional, Union +from typing import Any import dpath import pendulum @@ -32,14 +33,14 @@ def get_token(self) -> str: @dataclass class SessionTokenProvider(TokenProvider): login_requester: Requester - session_token_path: List[str] - expiration_duration: Optional[Union[datetime.timedelta, Duration]] + 
session_token_path: list[str] + expiration_duration: datetime.timedelta | Duration | None parameters: InitVar[Mapping[str, Any]] - message_repository: MessageRepository = NoopMessageRepository() + message_repository: MessageRepository = NoopMessageRepository() # noqa: RUF009 decoder: Decoder = field(default_factory=lambda: JsonDecoder(parameters={})) - _next_expiration_time: Optional[DateTime] = None - _token: Optional[str] = None + _next_expiration_time: DateTime | None = None + _token: str | None = None def get_token(self) -> str: self._refresh_if_necessary() @@ -72,7 +73,7 @@ def _refresh(self) -> None: @dataclass class InterpolatedStringTokenProvider(TokenProvider): config: Config - api_token: Union[InterpolatedString, str] + api_token: InterpolatedString | str parameters: Mapping[str, Any] def __post_init__(self) -> None: diff --git a/airbyte_cdk/sources/declarative/checks/__init__.py b/airbyte_cdk/sources/declarative/checks/__init__.py index 6362e0613..41f6d832a 100644 --- a/airbyte_cdk/sources/declarative/checks/__init__.py +++ b/airbyte_cdk/sources/declarative/checks/__init__.py @@ -2,7 +2,7 @@ # Copyright (c) 2025 Airbyte, Inc., all rights reserved. # -from typing import Mapping +from collections.abc import Mapping from pydantic.v1 import BaseModel @@ -16,6 +16,7 @@ CheckStream as CheckStreamModel, ) + COMPONENTS_CHECKER_TYPE_MAPPING: Mapping[str, type[BaseModel]] = { "CheckStream": CheckStreamModel, "CheckDynamicStream": CheckDynamicStreamModel, diff --git a/airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py b/airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py index 75807c400..0abaa3b15 100644 --- a/airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py +++ b/airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py @@ -4,8 +4,9 @@ import logging import traceback +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, List, Mapping, Tuple +from typing import Any from airbyte_cdk import AbstractSource from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker @@ -29,7 +30,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def check_connection( self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Any]: + ) -> tuple[bool, Any]: streams = source.streams(config=config) if len(streams) == 0: return False, f"No streams to connect to from source {source}" diff --git a/airbyte_cdk/sources/declarative/checks/check_stream.py b/airbyte_cdk/sources/declarative/checks/check_stream.py index c45159ec9..2f65675f8 100644 --- a/airbyte_cdk/sources/declarative/checks/check_stream.py +++ b/airbyte_cdk/sources/declarative/checks/check_stream.py @@ -4,8 +4,9 @@ import logging import traceback +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, List, Mapping, Tuple +from typing import Any from airbyte_cdk import AbstractSource from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker @@ -21,7 +22,7 @@ class CheckStream(ConnectionChecker): stream_name (List[str]): names of streams to check """ - stream_names: List[str] + stream_names: list[str] parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -29,13 +30,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def check_connection( self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, 
Any]: + ) -> tuple[bool, Any]: streams = source.streams(config=config) stream_name_to_stream = {s.name: s for s in streams} if len(streams) == 0: return False, f"No streams to connect to from source {source}" for stream_name in self.stream_names: - if stream_name not in stream_name_to_stream.keys(): + if stream_name not in stream_name_to_stream: raise ValueError( f"{stream_name} is not part of the catalog. Expected one of {stream_name_to_stream.keys()}." ) diff --git a/airbyte_cdk/sources/declarative/checks/connection_checker.py b/airbyte_cdk/sources/declarative/checks/connection_checker.py index fd1d1bba2..ee1b10783 100644 --- a/airbyte_cdk/sources/declarative/checks/connection_checker.py +++ b/airbyte_cdk/sources/declarative/checks/connection_checker.py @@ -4,7 +4,8 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Mapping, Tuple +from collections.abc import Mapping +from typing import Any from airbyte_cdk import AbstractSource @@ -17,7 +18,7 @@ class ConnectionChecker(ABC): @abstractmethod def check_connection( self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Any]: + ) -> tuple[bool, Any]: """ Tests if the input configuration can be used to successfully connect to the integration e.g: if a provided Stripe API token can be used to connect to the Stripe API. diff --git a/airbyte_cdk/sources/declarative/concurrency_level/__init__.py b/airbyte_cdk/sources/declarative/concurrency_level/__init__.py index 6c55c15c9..669500b8c 100644 --- a/airbyte_cdk/sources/declarative/concurrency_level/__init__.py +++ b/airbyte_cdk/sources/declarative/concurrency_level/__init__.py @@ -4,4 +4,5 @@ from airbyte_cdk.sources.declarative.concurrency_level.concurrency_level import ConcurrencyLevel + __all__ = ["ConcurrencyLevel"] diff --git a/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py b/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py index f5cd24f00..0f6f36af2 100644 --- a/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py +++ b/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py @@ -2,8 +2,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation import InterpolatedString from airbyte_cdk.sources.types import Config @@ -19,14 +20,14 @@ class ConcurrencyLevel: max_concurrency (Optional[int]): The maximum number of worker threads to use when the default_concurrency is exceeded """ - default_concurrency: Union[int, str] - max_concurrency: Optional[int] + default_concurrency: int | str + max_concurrency: int | None config: Config parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.default_concurrency, int): - self._default_concurrency: Union[int, InterpolatedString] = self.default_concurrency + self._default_concurrency: int | InterpolatedString = self.default_concurrency elif "config" in self.default_concurrency and not self.max_concurrency: raise ValueError( "ConcurrencyLevel requires that max_concurrency be defined if the default_concurrency can be user-specified" ) @@ -40,11 +41,10 @@ def get_concurrency_level(self) -> int: if isinstance(self._default_concurrency, InterpolatedString): evaluated_default_concurrency = self._default_concurrency.eval(config=self.config) if not isinstance(evaluated_default_concurrency, int): - raise ValueError("default_concurrency did not evaluate to an integer") + raise ValueError("default_concurrency did not evaluate to an integer") # noqa: TRY004 return ( min(evaluated_default_concurrency, self.max_concurrency) if self.max_concurrency else evaluated_default_concurrency ) - else: - return self._default_concurrency + return self._default_concurrency diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 5db0b0909..ab1bead26 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -3,7 +3,8 @@ # import logging -from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple +from collections.abc import Iterator, Mapping +from typing import Any, Generic from airbyte_cdk.models import ( AirbyteCatalog, @@ -57,14 +58,14 @@ class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]): def __init__( self, - catalog: Optional[ConfiguredAirbyteCatalog], - config: Optional[Mapping[str, Any]], + catalog: ConfiguredAirbyteCatalog | None, # noqa: ARG002 + config: Mapping[str, Any] | None, state: TState, source_config: ConnectionDefinition, - debug: bool = False, - emit_connector_builder_messages: bool = False, - component_factory: Optional[ModelToComponentFactory] = None, - **kwargs: Any, + debug: bool = False, # noqa: FBT001, FBT002 + emit_connector_builder_messages: bool = False, # noqa: FBT001, FBT002 + component_factory: ModelToComponentFactory | None = None, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> None: # To reduce the complexity of the concurrent framework, we are not enabling RFR with synthetic # cursors. We do this by no longer automatically instantiating RFR cursors when converting @@ -83,7 +84,7 @@ def __init__( component_factory=component_factory, ) - # todo: We could remove state from initialization. Now that streams are grouped during the read(), a source + # TODO: We could remove state from initialization.
Now that streams are grouped during the read(), a source # no longer needs to store the original incoming state. But maybe there's an edge case? self._state = state @@ -120,7 +121,7 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: concurrent_streams, _ = self._group_streams(config=config) @@ -128,7 +129,7 @@ def read( # the concurrent streams must be saved so that they can be removed from the catalog before starting # synchronous streams if len(concurrent_streams) > 0: - concurrent_stream_names = set( + concurrent_stream_names = set( # noqa: C403 [concurrent_stream.name for concurrent_stream in concurrent_streams] ) @@ -151,7 +152,7 @@ def read( yield from super().read(logger, config, filtered_catalog, state) - def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: + def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: # noqa: ARG002 concurrent_streams, synchronous_streams = self._group_streams(config=config) return AirbyteCatalog( streams=[ @@ -159,7 +160,7 @@ def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> Airbyte ] ) - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: """ The `streams` method is used as part of the AbstractSource in the following cases: * ConcurrentDeclarativeSource.check -> ManifestDeclarativeSource.check -> AbstractSource.check -> DeclarativeSource.check_connection -> CheckStream.check_connection -> streams @@ -172,9 +173,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: def _group_streams( self, config: Mapping[str, Any] - ) -> Tuple[List[AbstractStream], List[Stream]]: - concurrent_streams: List[AbstractStream] = [] - synchronous_streams: List[Stream] = [] + ) -> tuple[list[AbstractStream], list[Stream]]: + concurrent_streams: list[AbstractStream] = [] + synchronous_streams: list[Stream] = [] state_manager = ConnectorStateManager(state=self._state) # type: ignore # state is always in the form of List[AirbyteStateMessage]. The ConnectorStateManager should use generics, but this can be done later @@ -268,7 +269,7 @@ def _group_streams( if hasattr(cursor, "cursor_field") and hasattr( cursor.cursor_field, "cursor_field_key" - ) # FIXME this will need to be updated once we do the per partition + ) # FIXME this will need to be updated once we do the per partition # noqa: FIX001, TD001, TD004 else None, logger=self.logger, cursor=cursor, @@ -347,13 +348,13 @@ def _stream_supports_concurrent_partition_processing( declarative_stream.retriever.requester, HttpRequester ): http_requester = declarative_stream.retriever.requester - if "stream_state" in http_requester._path.string: + if "stream_state" in http_requester._path.string: # noqa: SLF001 self.logger.warning( f"Low-code stream '{declarative_stream.name}' uses interpolation of stream_state in the HttpRequester which is not thread-safe. Defaulting to synchronous processing" ) return False - request_options_provider = http_requester._request_options_provider + request_options_provider = http_requester._request_options_provider # noqa: SLF001 if request_options_provider.request_options_contain_stream_state(): self.logger.warning( f"Low-code stream '{declarative_stream.name}' uses interpolation of stream_state in the HttpRequester which is not thread-safe. 
Defaulting to synchronous processing" @@ -397,10 +398,10 @@ def _stream_supports_concurrent_partition_processing( @staticmethod def _select_streams( - streams: List[AbstractStream], configured_catalog: ConfiguredAirbyteCatalog - ) -> List[AbstractStream]: + streams: list[AbstractStream], configured_catalog: ConfiguredAirbyteCatalog + ) -> list[AbstractStream]: stream_name_to_instance: Mapping[str, AbstractStream] = {s.name: s for s in streams} - abstract_streams: List[AbstractStream] = [] + abstract_streams: list[AbstractStream] = [] for configured_stream in configured_catalog.streams: stream_instance = stream_name_to_instance.get(configured_stream.stream.name) if stream_instance: diff --git a/airbyte_cdk/sources/declarative/datetime/__init__.py b/airbyte_cdk/sources/declarative/datetime/__init__.py index bf1f13e1e..dbf4fb5b8 100644 --- a/airbyte_cdk/sources/declarative/datetime/__init__.py +++ b/airbyte_cdk/sources/declarative/datetime/__init__.py @@ -4,4 +4,5 @@ from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime + __all__ = ["MinMaxDatetime"] diff --git a/airbyte_cdk/sources/declarative/datetime/datetime_parser.py b/airbyte_cdk/sources/declarative/datetime/datetime_parser.py index 93122e29c..bafd3d8bb 100644 --- a/airbyte_cdk/sources/declarative/datetime/datetime_parser.py +++ b/airbyte_cdk/sources/declarative/datetime/datetime_parser.py @@ -3,7 +3,6 @@ # import datetime -from typing import Union class DatetimeParser: @@ -18,7 +17,7 @@ class DatetimeParser: _UNIX_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) - def parse(self, date: Union[str, int], format: str) -> datetime.datetime: + def parse(self, date: str | int, format: str) -> datetime.datetime: # noqa: A002 # "%s" is a valid (but unreliable) directive for formatting, but not for parsing # It is defined as # The number of seconds since the Epoch, 1970-01-01 00:00:00+0000 (UTC). 
https://man7.org/linux/man-pages/man3/strptime.3.html @@ -27,9 +26,9 @@ def parse(self, date: Union[str, int], format: str) -> datetime.datetime: # See https://stackoverflow.com/a/4974930 if format == "%s": return datetime.datetime.fromtimestamp(int(date), tz=datetime.timezone.utc) - elif format == "%s_as_float": + if format == "%s_as_float": return datetime.datetime.fromtimestamp(float(date), tz=datetime.timezone.utc) - elif format == "%ms": + if format == "%ms": return self._UNIX_EPOCH + datetime.timedelta(milliseconds=int(date)) parsed_datetime = datetime.datetime.strptime(str(date), format) @@ -37,7 +36,7 @@ def parse(self, date: Union[str, int], format: str) -> datetime.datetime: return parsed_datetime.replace(tzinfo=datetime.timezone.utc) return parsed_datetime - def format(self, dt: datetime.datetime, format: str) -> str: + def format(self, dt: datetime.datetime, format: str) -> str: # noqa: A002 # strftime("%s") is unreliable because it ignores the time zone information and assumes the time zone of the system it's running on # It's safer to use the timestamp() method than the %s directive # See https://stackoverflow.com/a/4974930 @@ -48,8 +47,7 @@ def format(self, dt: datetime.datetime, format: str) -> str: if format == "%ms": # timestamp() returns a float representing the number of seconds since the unix epoch return str(int(dt.timestamp() * 1000)) - else: - return dt.strftime(format) + return dt.strftime(format) def _is_naive(self, dt: datetime.datetime) -> bool: return dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None diff --git a/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py b/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py index eb407db44..cb1c65140 100644 --- a/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py +++ b/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py @@ -3,8 +3,9 @@ # import datetime as dt +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Optional, Union +from typing import Any, Union from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -28,14 +29,14 @@ class MinMaxDatetime: max_datetime (Union[InterpolatedString, str]): Represents the maximum allowed datetime value.
""" - datetime: Union[InterpolatedString, str] + datetime: InterpolatedString | str parameters: InitVar[Mapping[str, Any]] # datetime_format is a unique case where we inherit it from the parent if it is not specified before using the default value # which is why we need dedicated getter/setter methods and private dataclass field datetime_format: str _datetime_format: str = field(init=False, repr=False, default="") - min_datetime: Union[InterpolatedString, str] = "" - max_datetime: Union[InterpolatedString, str] = "" + min_datetime: InterpolatedString | str = "" + max_datetime: InterpolatedString | str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.datetime = InterpolatedString.create(self.datetime, parameters=parameters or {}) @@ -104,15 +105,14 @@ def datetime_format(self, value: str) -> None: def create( cls, interpolated_string_or_min_max_datetime: Union[InterpolatedString, str, "MinMaxDatetime"], - parameters: Optional[Mapping[str, Any]] = None, + parameters: Mapping[str, Any] | None = None, ) -> "MinMaxDatetime": if parameters is None: parameters = {} - if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance( + if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance( # noqa: SIM101 interpolated_string_or_min_max_datetime, str ): return MinMaxDatetime( # type: ignore [call-arg] datetime=interpolated_string_or_min_max_datetime, parameters=parameters ) - else: - return interpolated_string_or_min_max_datetime + return interpolated_string_or_min_max_datetime diff --git a/airbyte_cdk/sources/declarative/declarative_source.py b/airbyte_cdk/sources/declarative/declarative_source.py index 77bf427a1..769407d3a 100644 --- a/airbyte_cdk/sources/declarative/declarative_source.py +++ b/airbyte_cdk/sources/declarative/declarative_source.py @@ -4,7 +4,8 @@ import logging from abc import abstractmethod -from typing import Any, Mapping, Tuple +from collections.abc import Mapping +from typing import Any from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker @@ -22,7 +23,7 @@ def connection_checker(self) -> ConnectionChecker: def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Any]: + ) -> tuple[bool, Any]: """ :param logger: The source logger :param config: The user-provided configuration as specified by the source's spec. diff --git a/airbyte_cdk/sources/declarative/declarative_stream.py b/airbyte_cdk/sources/declarative/declarative_stream.py index 12cdd3337..ac11ea870 100644 --- a/airbyte_cdk/sources/declarative/declarative_stream.py +++ b/airbyte_cdk/sources/declarative/declarative_stream.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# import logging +from collections.abc import Iterable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union +from typing import Any from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.declarative.incremental import ( @@ -46,12 +47,12 @@ class DeclarativeStream(Stream): config: Config parameters: InitVar[Mapping[str, Any]] name: str - primary_key: Optional[Union[str, List[str], List[List[str]]]] - state_migrations: List[StateMigration] = field(repr=True, default_factory=list) - schema_loader: Optional[SchemaLoader] = None + primary_key: str | list[str] | list[list[str]] | None + state_migrations: list[StateMigration] = field(repr=True, default_factory=list) + schema_loader: SchemaLoader | None = None _name: str = field(init=False, repr=False, default="") _primary_key: str = field(init=False, repr=False, default="") - stream_cursor_field: Optional[Union[InterpolatedString, str]] = None + stream_cursor_field: InterpolatedString | str | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._stream_cursor_field = ( @@ -60,13 +61,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: else self.stream_cursor_field ) self._schema_loader = ( - self.schema_loader + self.schema_loader # noqa: FURB110 if self.schema_loader else DefaultSchemaLoader(config=self.config, parameters=parameters) ) @property # type: ignore - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: return self._primary_key @primary_key.setter @@ -109,18 +110,20 @@ def state(self, value: MutableMapping[str, Any]) -> None: self.retriever.state = state def get_updated_state( - self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any] + self, + current_stream_state: MutableMapping[str, Any], # noqa: ARG002 + latest_record: Mapping[str, Any], # noqa: ARG002 ) -> MutableMapping[str, Any]: return self.state @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: """ Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. :return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor. """ cursor = self._stream_cursor_field.eval(self.config) # type: ignore # _stream_cursor_field is always cast to interpolated string - return cursor if cursor else [] + return cursor if cursor else [] # noqa: FURB110 @property def is_resumable(self) -> bool: @@ -130,10 +133,10 @@ def is_resumable(self) -> bool: def read_records( self, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Iterable[Mapping[str, Any]]: """ :param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state. @@ -146,7 +149,7 @@ def read_records( # empty slice which seems to make sense. 
stream_slice = StreamSlice(partition={}, cursor_slice={}) if not isinstance(stream_slice, StreamSlice): - raise ValueError( + raise ValueError( # noqa: TRY004 f"DeclarativeStream does not support stream_slices that are not StreamSlice. Got {stream_slice}" ) yield from self.retriever.read_records(self.get_json_schema(), stream_slice) # type: ignore # records are of the correct type @@ -163,10 +166,10 @@ def get_json_schema(self) -> Mapping[str, Any]: # type: ignore def stream_slices( self, *, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[StreamSlice]]: + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_state: Mapping[str, Any] | None = None, # noqa: ARG002 + ) -> Iterable[StreamSlice | None]: """ Override to define the slices for this stream. See the stream slicing section of the docs for more information. @@ -178,7 +181,7 @@ def stream_slices( return self.retriever.stream_slices() @property - def state_checkpoint_interval(self) -> Optional[int]: + def state_checkpoint_interval(self) -> int | None: """ We explicitly disable checkpointing here. There are a couple reasons for that and not all are documented here but: * In the case where records are not ordered, the granularity of what is ordered is the slice. Therefore, we will only update the @@ -188,7 +191,7 @@ def state_checkpoint_interval(self) -> Optional[int]: """ return None - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: if self.retriever and isinstance(self.retriever, SimpleRetriever): return self.retriever.cursor return None @@ -196,7 +199,7 @@ def get_cursor(self) -> Optional[Cursor]: def _get_checkpoint_reader( self, logger: logging.Logger, - cursor_field: Optional[List[str]], + cursor_field: list[str] | None, sync_mode: SyncMode, stream_state: MutableMapping[str, Any], ) -> CheckpointReader: @@ -212,14 +215,14 @@ def _get_checkpoint_reader( """ mappings_or_slices = self.stream_slices( cursor_field=cursor_field, - sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior + sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior stream_state=stream_state, ) cursor = self.get_cursor() checkpoint_mode = self._checkpoint_mode - if isinstance( + if isinstance( # noqa: UP038 cursor, (GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor) ): self.has_multiple_slices = True diff --git a/airbyte_cdk/sources/declarative/decoders/__init__.py b/airbyte_cdk/sources/declarative/decoders/__init__.py index 45eaf5599..e81195349 100644 --- a/airbyte_cdk/sources/declarative/decoders/__init__.py +++ b/airbyte_cdk/sources/declarative/decoders/__init__.py @@ -4,9 +4,9 @@ from airbyte_cdk.sources.declarative.decoders.composite_raw_decoder import ( CompositeRawDecoder, - GzipParser, + GzipParser, # noqa: F401 JsonParser, - Parser, + Parser, # noqa: F401 ) from airbyte_cdk.sources.declarative.decoders.decoder import Decoder from airbyte_cdk.sources.declarative.decoders.json_decoder import ( @@ -22,6 +22,7 @@ from airbyte_cdk.sources.declarative.decoders.xml_decoder import XmlDecoder from airbyte_cdk.sources.declarative.decoders.zipfile_decoder import ZipfileDecoder + __all__ = [ "Decoder", "CompositeRawDecoder", diff --git a/airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py b/airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py index 
4d670db11..340b46848 100644 --- a/airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py @@ -1,11 +1,13 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. import csv import gzip import json import logging from abc import ABC, abstractmethod +from collections.abc import Generator, MutableMapping from dataclasses import dataclass from io import BufferedIOBase, TextIOWrapper -from typing import Any, Generator, MutableMapping, Optional +from typing import Any import orjson import requests @@ -14,6 +16,7 @@ from airbyte_cdk.sources.declarative.decoders.decoder import Decoder from airbyte_cdk.utils import AirbyteTracedException + logger = logging.getLogger("airbyte") @@ -59,7 +62,7 @@ def parse(self, data: BufferedIOBase) -> Generator[MutableMapping[str, Any], Non if body_json is None: raise AirbyteTracedException( message="Response JSON data failed to be parsed. See logs for more information.", - internal_message=f"Response JSON data failed to be parsed.", + internal_message="Response JSON data failed to be parsed.", failure_type=FailureType.system_error, ) @@ -68,7 +71,7 @@ def parse(self, data: BufferedIOBase) -> Generator[MutableMapping[str, Any], Non else: yield from [body_json] - def _parse_orjson(self, raw_data: bytes) -> Optional[Any]: + def _parse_orjson(self, raw_data: bytes) -> Any | None: # noqa: ANN401 try: return orjson.loads(raw_data.decode(self.encoding)) except Exception as exc: @@ -77,7 +80,7 @@ def _parse_orjson(self, raw_data: bytes) -> Optional[Any]: ) return None - def _parse_json(self, raw_data: bytes) -> Optional[Any]: + def _parse_json(self, raw_data: bytes) -> Any | None: # noqa: ANN401 try: return json.loads(raw_data.decode(self.encoding)) except Exception as exc: @@ -87,7 +90,7 @@ def _parse_json(self, raw_data: bytes) -> Optional[Any]: @dataclass class JsonLineParser(Parser): - encoding: Optional[str] = "utf-8" + encoding: str | None = "utf-8" def parse( self, @@ -103,8 +106,8 @@ def parse( @dataclass class CsvParser(Parser): # TODO: migrate implementation to re-use file-base classes - encoding: Optional[str] = "utf-8" - delimiter: Optional[str] = "," + encoding: str | None = "utf-8" + delimiter: str | None = "," def parse( self, diff --git a/airbyte_cdk/sources/declarative/decoders/decoder.py b/airbyte_cdk/sources/declarative/decoders/decoder.py index 5fa9dc8f6..e6b6aa9bd 100644 --- a/airbyte_cdk/sources/declarative/decoders/decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/decoder.py @@ -3,8 +3,9 @@ # from abc import abstractmethod +from collections.abc import Generator, MutableMapping from dataclasses import dataclass -from typing import Any, Generator, MutableMapping +from typing import Any import requests diff --git a/airbyte_cdk/sources/declarative/decoders/json_decoder.py b/airbyte_cdk/sources/declarative/decoders/json_decoder.py index cab572ef4..e47400758 100644 --- a/airbyte_cdk/sources/declarative/decoders/json_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/json_decoder.py @@ -3,15 +3,17 @@ # import codecs import logging +from collections.abc import Generator, Mapping, MutableMapping from dataclasses import InitVar, dataclass from gzip import decompress -from typing import Any, Generator, List, Mapping, MutableMapping, Optional +from typing import Any import orjson import requests from airbyte_cdk.sources.declarative.decoders.decoder import Decoder + logger = logging.getLogger("airbyte") @@ -43,7 +45,7 @@ def decode( @staticmethod def 
parse_body_json( - body_json: MutableMapping[str, Any] | List[MutableMapping[str, Any]], + body_json: MutableMapping[str, Any] | list[MutableMapping[str, Any]], ) -> Generator[MutableMapping[str, Any], None, None]: if not isinstance(body_json, list): body_json = [body_json] @@ -85,7 +87,7 @@ def is_stream_response(self) -> bool: def decode( self, response: requests.Response ) -> Generator[MutableMapping[str, Any], None, None]: - # TODO???: set delimiter? usually it is `\n` but maybe it would be useful to set optional? + # TODO???: set delimiter? usually it is `\n` but maybe it would be useful to set optional? # noqa: TD004 # https://github.com/airbytehq/airbyte-internal-issues/issues/8436 for record in response.iter_lines(): yield orjson.loads(record) @@ -93,14 +95,14 @@ def decode( @dataclass class GzipJsonDecoder(JsonDecoder): - encoding: Optional[str] + encoding: str | None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.encoding: try: codecs.lookup(self.encoding) except LookupError: - raise ValueError( + raise ValueError( # noqa: B904 f"Invalid encoding '{self.encoding}'. Please check provided encoding" ) diff --git a/airbyte_cdk/sources/declarative/decoders/noop_decoder.py b/airbyte_cdk/sources/declarative/decoders/noop_decoder.py index cf0bc56eb..145303539 100644 --- a/airbyte_cdk/sources/declarative/decoders/noop_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/noop_decoder.py @@ -1,12 +1,14 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. import logging -from typing import Any, Generator, Mapping +from collections.abc import Generator, Mapping +from typing import Any import requests from airbyte_cdk.sources.declarative.decoders.decoder import Decoder + logger = logging.getLogger("airbyte") @@ -16,6 +18,6 @@ def is_stream_response(self) -> bool: def decode( # type: ignore[override] # Signature doesn't match base class self, - response: requests.Response, + response: requests.Response, # noqa: ARG002 ) -> Generator[Mapping[str, Any], None, None]: yield from [{}] diff --git a/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py b/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py index e5a152711..6928e0d53 100644 --- a/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py +++ b/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py @@ -3,13 +3,15 @@ # import logging +from collections.abc import Generator, MutableMapping from dataclasses import dataclass -from typing import Any, Generator, MutableMapping +from typing import Any import requests from airbyte_cdk.sources.declarative.decoders import Decoder + logger = logging.getLogger("airbyte") @@ -19,7 +21,7 @@ class PaginationDecoderDecorator(Decoder): Decoder to wrap other decoders when instantiating a DefaultPaginator in order to bypass decoding if the response is streamed. 
""" - def __init__(self, decoder: Decoder): + def __init__(self, decoder: Decoder): # noqa: ANN204 self._decoder = decoder @property diff --git a/airbyte_cdk/sources/declarative/decoders/xml_decoder.py b/airbyte_cdk/sources/declarative/decoders/xml_decoder.py index 0786c3202..471ac0abf 100644 --- a/airbyte_cdk/sources/declarative/decoders/xml_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/xml_decoder.py @@ -3,8 +3,9 @@ # import logging +from collections.abc import Generator, Mapping, MutableMapping from dataclasses import InitVar, dataclass -from typing import Any, Generator, Mapping, MutableMapping +from typing import Any from xml.parsers.expat import ExpatError import requests @@ -12,6 +13,7 @@ from airbyte_cdk.sources.declarative.decoders.decoder import Decoder + logger = logging.getLogger("airbyte") diff --git a/airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py b/airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py index a937a1e4d..9510134cb 100644 --- a/airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py +++ b/airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py @@ -4,11 +4,11 @@ import logging import zipfile +from collections.abc import Generator, MutableMapping from dataclasses import dataclass from io import BytesIO -from typing import Any, Generator, MutableMapping +from typing import Any -import orjson import requests from airbyte_cdk.models import FailureType @@ -18,6 +18,7 @@ ) from airbyte_cdk.utils import AirbyteTracedException + logger = logging.getLogger("airbyte") diff --git a/airbyte_cdk/sources/declarative/extractors/__init__.py b/airbyte_cdk/sources/declarative/extractors/__init__.py index 8f1d18d12..9cef0d121 100644 --- a/airbyte_cdk/sources/declarative/extractors/__init__.py +++ b/airbyte_cdk/sources/declarative/extractors/__init__.py @@ -11,6 +11,7 @@ ) from airbyte_cdk.sources.declarative.extractors.type_transformer import TypeTransformer + __all__ = [ "TypeTransformer", "HttpSelector", diff --git a/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py b/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py index 9c97773e3..d1575e4b6 100644 --- a/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py +++ b/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Iterable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Iterable, List, Mapping, MutableMapping, Union +from typing import Any import dpath import requests @@ -53,7 +54,7 @@ class DpathExtractor(RecordExtractor): decoder (Decoder): The decoder responsible to transform the response in a Mapping """ - field_path: List[Union[InterpolatedString, str]] + field_path: list[InterpolatedString | str] config: Config parameters: InitVar[Mapping[str, Any]] decoder: Decoder = field(default_factory=lambda: JsonDecoder(parameters={})) diff --git a/airbyte_cdk/sources/declarative/extractors/http_selector.py b/airbyte_cdk/sources/declarative/extractors/http_selector.py index 846071125..8db96c796 100644 --- a/airbyte_cdk/sources/declarative/extractors/http_selector.py +++ b/airbyte_cdk/sources/declarative/extractors/http_selector.py @@ -3,7 +3,8 @@ # from abc import abstractmethod -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any import requests @@ -22,8 +23,8 @@ def select_records( response: requests.Response, stream_state: StreamState, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Record]: """ Selects records from the response diff --git a/airbyte_cdk/sources/declarative/extractors/record_extractor.py b/airbyte_cdk/sources/declarative/extractors/record_extractor.py index 5de6a84a7..6890950af 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_extractor.py +++ b/airbyte_cdk/sources/declarative/extractors/record_extractor.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # from abc import abstractmethod +from collections.abc import Iterable, Mapping from dataclasses import dataclass -from typing import Any, Iterable, Mapping +from typing import Any import requests diff --git a/airbyte_cdk/sources/declarative/extractors/record_filter.py b/airbyte_cdk/sources/declarative/extractors/record_filter.py index b744c9796..5cdca8618 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_filter.py +++ b/airbyte_cdk/sources/declarative/extractors/record_filter.py @@ -1,8 +1,9 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
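# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this diff): how a dpath-style field path,
# as used by the DpathExtractor above, pulls records out of an already
# decoded response body. The payload shape and the ["data", "items"] path are
# made-up examples; dpath>=2.0 module-level functions are assumed.
# ---------------------------------------------------------------------------
from collections.abc import Iterable, Mapping
from typing import Any

import dpath


def extract_records(
    body: Mapping[str, Any] | list[Mapping[str, Any]],
    field_path: list[str],
) -> Iterable[Mapping[str, Any]]:
    """Yield the records found under `field_path` in the decoded body."""
    # An empty path means the whole body is the record container.
    matches = dpath.values(body, list(field_path)) if field_path else [body]
    for match in matches:
        # A path can point either at a list of records or at a single record.
        yield from match if isinstance(match, list) else [match]


if __name__ == "__main__":
    payload = {"data": {"items": [{"id": 1}, {"id": 2}]}}
    print(list(extract_records(payload, ["data", "items"])))  # [{'id': 1}, {'id': 2}]
# ---------------------------------------------------------------------------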
# +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.incremental import ( DatetimeBasedCursor, @@ -35,8 +36,8 @@ def filter_records( self, records: Iterable[Mapping[str, Any]], stream_state: StreamState, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: kwargs = { "stream_state": stream_state, @@ -57,11 +58,11 @@ class ClientSideIncrementalRecordFilterDecorator(RecordFilter): :param PerPartitionCursor per_partition_cursor: Optional Cursor used for mapping cursor value in nested stream_state """ - def __init__( + def __init__( # noqa: ANN204 self, date_time_based_cursor: DatetimeBasedCursor, - substream_cursor: Optional[Union[PerPartitionWithGlobalCursor, GlobalSubstreamCursor]], - **kwargs: Any, + substream_cursor: PerPartitionWithGlobalCursor | GlobalSubstreamCursor | None, + **kwargs: Any, # noqa: ANN401 ): super().__init__(**kwargs) self._date_time_based_cursor = date_time_based_cursor @@ -71,8 +72,8 @@ def filter_records( self, records: Iterable[Mapping[str, Any]], stream_state: StreamState, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: records = ( record diff --git a/airbyte_cdk/sources/declarative/extractors/record_selector.py b/airbyte_cdk/sources/declarative/extractors/record_selector.py index f29b8a75a..92905fa6c 100644 --- a/airbyte_cdk/sources/declarative/extractors/record_selector.py +++ b/airbyte_cdk/sources/declarative/extractors/record_selector.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
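# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this diff): the idea behind the client-side
# incremental filtering shown above -- drop records whose cursor value falls
# behind what the stream state already covers. The "updated_at" field name
# and ISO-8601 cursor format are assumptions for this example only.
# ---------------------------------------------------------------------------
from collections.abc import Iterable, Mapping
from datetime import datetime
from typing import Any


def filter_records(
    records: Iterable[Mapping[str, Any]],
    stream_state: Mapping[str, Any],
    cursor_field: str = "updated_at",
) -> Iterable[Mapping[str, Any]]:
    """Yield only records at or past the state's cursor value."""
    state_value = stream_state.get(cursor_field)
    threshold = datetime.fromisoformat(state_value) if state_value else None
    for record in records:
        record_value = record.get(cursor_field)
        if threshold is None or (
            record_value and datetime.fromisoformat(record_value) >= threshold
        ):
            yield record


if __name__ == "__main__":
    state = {"updated_at": "2024-01-01T00:00:00"}
    rows = [{"updated_at": "2023-12-31T23:59:00"}, {"updated_at": "2024-02-01T08:00:00"}]
    print(list(filter_records(rows, state)))  # only the 2024-02-01 record remains
# ---------------------------------------------------------------------------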
# +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Iterable, List, Mapping, Optional, Union +from typing import Any import requests @@ -14,7 +15,6 @@ TypeTransformer as DeclarativeTypeTransformer, ) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString -from airbyte_cdk.sources.declarative.models import SchemaNormalization from airbyte_cdk.sources.declarative.transformations import RecordTransformation from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from airbyte_cdk.sources.utils.transform import TypeTransformer @@ -36,11 +36,11 @@ class RecordSelector(HttpSelector): extractor: RecordExtractor config: Config parameters: InitVar[Mapping[str, Any]] - schema_normalization: Union[TypeTransformer, DeclarativeTypeTransformer] + schema_normalization: TypeTransformer | DeclarativeTypeTransformer name: str - _name: Union[InterpolatedString, str] = field(init=False, repr=False, default="") - record_filter: Optional[RecordFilter] = None - transformations: List[RecordTransformation] = field(default_factory=lambda: []) + _name: InterpolatedString | str = field(init=False, repr=False, default="") + record_filter: RecordFilter | None = None + transformations: list[RecordTransformation] = field(default_factory=list) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._parameters = parameters @@ -71,8 +71,8 @@ def select_records( response: requests.Response, stream_state: StreamState, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Record]: """ Selects records from the response @@ -93,8 +93,8 @@ def filter_and_transform( all_data: Iterable[Mapping[str, Any]], stream_state: StreamState, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Record]: """ There is an issue with the selector as of 2024-08-30: it does technology-agnostic processing like filtering, transformation and @@ -111,7 +111,7 @@ def filter_and_transform( yield Record(data=data, stream_name=self.name, associated_slice=stream_slice) def _normalize_by_schema( - self, records: Iterable[Mapping[str, Any]], schema: Optional[Mapping[str, Any]] + self, records: Iterable[Mapping[str, Any]], schema: Mapping[str, Any] | None ) -> Iterable[Mapping[str, Any]]: if schema: # record has type Mapping[str, Any], but dict[str, Any] expected @@ -126,8 +126,8 @@ def _filter( self, records: Iterable[Mapping[str, Any]], stream_state: StreamState, - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, ) -> Iterable[Mapping[str, Any]]: if self.record_filter: yield from self.record_filter.filter_records( @@ -143,7 +143,7 @@ def _transform( self, records: Iterable[Mapping[str, Any]], stream_state: StreamState, - stream_slice: Optional[StreamSlice] = None, + stream_slice: StreamSlice | None = None, ) -> Iterable[Mapping[str, Any]]: for record in records: for transformation in self.transformations: diff --git a/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py 
b/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py index 0215ddb45..0ee34d540 100644 --- a/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py +++ b/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py @@ -5,9 +5,10 @@ import os import uuid import zlib +from collections.abc import Iterable, Mapping from contextlib import closing from dataclasses import InitVar, dataclass -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple +from typing import Any import pandas as pd import requests @@ -15,6 +16,7 @@ from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor + EMPTY_STR: str = "" DEFAULT_ENCODING: str = "utf-8" DOWNLOAD_CHUNK_SIZE: int = 1024 * 10 @@ -35,7 +37,7 @@ class ResponseToFileExtractor(RecordExtractor): def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.logger = logging.getLogger("airbyte") - def _get_response_encoding(self, headers: Dict[str, Any]) -> str: + def _get_response_encoding(self, headers: dict[str, Any]) -> str: """ Get the encoding of the response based on the provided headers. This method is heavily inspired by the requests library implementation. @@ -78,7 +80,7 @@ def _filter_null_bytes(self, b: bytes) -> bytes: ) return res - def _save_to_file(self, response: requests.Response) -> Tuple[str, str]: + def _save_to_file(self, response: requests.Response) -> tuple[str, str]: """ Saves the binary data from the given response to a temporary file and returns the filepath and response encoding. @@ -96,7 +98,7 @@ def _save_to_file(self, response: requests.Response) -> Tuple[str, str]: needs_decompression = True # we will assume at first that the response is compressed and change the flag if not tmp_file = str(uuid.uuid4()) - with closing(response) as response, open(tmp_file, "wb") as data_file: + with closing(response) as response, open(tmp_file, "wb") as data_file: # noqa: PTH123, PLR1704 response_encoding = self._get_response_encoding(dict(response.headers or {})) for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE): try: @@ -110,12 +112,11 @@ def _save_to_file(self, response: requests.Response) -> Tuple[str, str]: needs_decompression = False # check the file exists - if os.path.isfile(tmp_file): + if os.path.isfile(tmp_file): # noqa: PTH113 return tmp_file, response_encoding - else: - raise ValueError( - f"The IO/Error occured while verifying binary data. Tmp file {tmp_file} doesn't exist." - ) + raise ValueError( + f"The IO/Error occured while verifying binary data. Tmp file {tmp_file} doesn't exist." + ) def _read_with_chunks( self, path: str, file_encoding: str, chunk_size: int = 100 @@ -136,25 +137,25 @@ def _read_with_chunks( """ try: - with open(path, "r", encoding=file_encoding) as data: + with open(path, encoding=file_encoding) as data: # noqa: PTH123 chunks = pd.read_csv( data, chunksize=chunk_size, iterator=True, dialect="unix", dtype=object ) for chunk in chunks: - chunk = chunk.replace({nan: None}).to_dict(orient="records") - for row in chunk: + chunk = chunk.replace({nan: None}).to_dict(orient="records") # noqa: PLW2901 + for row in chunk: # noqa: UP028 yield row except pd.errors.EmptyDataError as e: self.logger.info(f"Empty data received. {e}") yield from [] - except IOError as ioe: - raise ValueError(f"The IO/Error occured while reading tmp data. Called: {path}", ioe) + except OSError as ioe: + raise ValueError(f"The IO/Error occured while reading tmp data. 
Called: {path}", ioe) # noqa: B904 finally: # remove binary tmp file, after data is read - os.remove(path) + os.remove(path) # noqa: PTH107 def extract_records( - self, response: Optional[requests.Response] = None + self, response: requests.Response | None = None ) -> Iterable[Mapping[str, Any]]: """ Extracts records from the given response by: diff --git a/airbyte_cdk/sources/declarative/extractors/type_transformer.py b/airbyte_cdk/sources/declarative/extractors/type_transformer.py index fe307684f..9060bed35 100644 --- a/airbyte_cdk/sources/declarative/extractors/type_transformer.py +++ b/airbyte_cdk/sources/declarative/extractors/type_transformer.py @@ -3,8 +3,9 @@ # from abc import ABC, abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Dict, Mapping +from typing import Any @dataclass @@ -35,7 +36,7 @@ class TypeTransformer(ABC): @abstractmethod def transform( self, - record: Dict[str, Any], + record: dict[str, Any], schema: Mapping[str, Any], ) -> None: """ diff --git a/airbyte_cdk/sources/declarative/incremental/__init__.py b/airbyte_cdk/sources/declarative/incremental/__init__.py index 7ce54a07a..66f48438a 100644 --- a/airbyte_cdk/sources/declarative/incremental/__init__.py +++ b/airbyte_cdk/sources/declarative/incremental/__init__.py @@ -19,6 +19,7 @@ ResumableFullRefreshCursor, ) + __all__ = [ "CursorFactory", "DatetimeBasedCursor", diff --git a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py index d6d329aec..a38be1613 100644 --- a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py @@ -3,9 +3,10 @@ # import datetime +from collections.abc import Callable, Iterable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field from datetime import timedelta -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union +from typing import Any from isodate import Duration, duration_isoformat, parse_duration @@ -52,28 +53,28 @@ class DatetimeBasedCursor(DeclarativeCursor): lookback_window (Optional[InterpolatedString]): how many days before start_datetime to read data for (ISO8601 duration) """ - start_datetime: Union[MinMaxDatetime, str] - cursor_field: Union[InterpolatedString, str] + start_datetime: MinMaxDatetime | str + cursor_field: InterpolatedString | str datetime_format: str config: Config parameters: InitVar[Mapping[str, Any]] - _highest_observed_cursor_field_value: Optional[str] = field( + _highest_observed_cursor_field_value: str | None = field( repr=False, default=None ) # tracks the latest observed datetime, which may not be safe to emit in the case of out-of-order records - _cursor: Optional[str] = field( + _cursor: str | None = field( repr=False, default=None ) # tracks the latest observed datetime that is appropriate to emit as stream state - end_datetime: Optional[Union[MinMaxDatetime, str]] = None - step: Optional[Union[InterpolatedString, str]] = None - cursor_granularity: Optional[str] = None - start_time_option: Optional[RequestOption] = None - end_time_option: Optional[RequestOption] = None - partition_field_start: Optional[str] = None - partition_field_end: Optional[str] = None - lookback_window: Optional[Union[InterpolatedString, str]] = None - message_repository: Optional[MessageRepository] = None - is_compare_strictly: Optional[bool] = False - cursor_datetime_formats: List[str] 
= field(default_factory=lambda: []) + end_datetime: MinMaxDatetime | str | None = None + step: InterpolatedString | str | None = None + cursor_granularity: str | None = None + start_time_option: RequestOption | None = None + end_time_option: RequestOption | None = None + partition_field_start: str | None = None + partition_field_end: str | None = None + lookback_window: InterpolatedString | str | None = None + message_repository: MessageRepository | None = None + is_compare_strictly: bool | None = False + cursor_datetime_formats: list[str] = field(default_factory=list) def __post_init__(self, parameters: Mapping[str, Any]) -> None: if (self.step and not self.cursor_granularity) or ( @@ -166,12 +167,12 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: ): self._highest_observed_cursor_field_value = record_cursor_value - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002 if stream_slice.partition: raise ValueError( f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}." ) - cursor_value_str_by_cursor_value_datetime = dict( + cursor_value_str_by_cursor_value_datetime = dict( # noqa: C417 map( # we need to ensure the cursor value is preserved as is in the state else the CATs might complain of something like # 2023-01-04T17:30:19.000Z' <= '2023-01-04T17:30:19.000000Z' @@ -202,7 +203,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: start_datetime = self._calculate_earliest_possible_value(self.select_best_end_datetime()) return self._partition_daterange(start_datetime, end_datetime, self._step) - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # noqa: ARG002 # Datetime based cursors operate over slices made up of datetime ranges. Stream state is based on the progress # through each slice and does not belong to a specific slice. We just return stream state as it is. return self.get_stream_state() @@ -254,8 +255,8 @@ def _partition_daterange( self, start: datetime.datetime, end: datetime.datetime, - step: Union[datetime.timedelta, Duration], - ) -> List[StreamSlice]: + step: datetime.timedelta | Duration, + ) -> list[StreamSlice]: start_field = self._partition_field_start.eval(self.config) end_field = self._partition_field_end.eval(self.config) dates = [] @@ -303,7 +304,7 @@ def _get_date( return comparator(cursor_date, default_date) def parse_date(self, date: str) -> datetime.datetime: - for datetime_format in self.cursor_datetime_formats + [self.datetime_format]: + for datetime_format in self.cursor_datetime_formats + [self.datetime_format]: # noqa: RUF005 try: return self._parser.parse(date, datetime_format) except ValueError: @@ -311,7 +312,7 @@ def parse_date(self, date: str) -> datetime.datetime: raise ValueError(f"No format in {self.cursor_datetime_formats} matching {date}") @classmethod - def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta, Duration]: + def _parse_timedelta(cls, time_str: str | None) -> datetime.timedelta | Duration: """ :return Parses an ISO 8601 durations into datetime.timedelta or Duration objects. 
""" @@ -322,36 +323,36 @@ def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta, def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.request_parameter, stream_slice) def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.header, stream_slice) def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_data, stream_slice) def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_json, stream_slice) @@ -360,7 +361,7 @@ def request_kwargs(self) -> Mapping[str, Any]: return {} def _get_request_options( - self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + self, option_type: RequestOptionType, stream_slice: StreamSlice | None ) -> Mapping[str, Any]: options: MutableMapping[str, Any] = {} if not stream_slice: @@ -395,8 +396,8 @@ def should_be_synced(self, record: Record) -> bool: def _is_within_daterange_boundaries( self, record: Record, - start_datetime_boundary: Union[datetime.datetime, str], - end_datetime_boundary: Union[datetime.datetime, str], + start_datetime_boundary: datetime.datetime | str, + end_datetime_boundary: datetime.datetime | str, ) -> bool: cursor_field = self.cursor_field.eval(self.config) # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__ record_cursor_value = record.get(cursor_field) @@ -429,10 +430,9 @@ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: second_cursor_value = second.get(cursor_field) if first_cursor_value and second_cursor_value: return self.parse_date(first_cursor_value) >= self.parse_date(second_cursor_value) - elif first_cursor_value: + if first_cursor_value: # noqa: SIM103 return True - else: - return False + return False def set_runtime_lookback_window(self, lookback_window_in_seconds: int) -> None: """ diff --git a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py index 3b3636236..a27436d0c 100644 --- a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py +++ 
b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py @@ -4,18 +4,20 @@ import threading import time -from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union +from collections.abc import Callable, Iterable, Mapping +from typing import Any, TypeVar from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.types import Record, StreamSlice, StreamState + T = TypeVar("T") def iterate_with_last_flag_and_state( - generator: Iterable[T], get_stream_state_func: Callable[[], Optional[Mapping[str, StreamState]]] + generator: Iterable[T], get_stream_state_func: Callable[[], Mapping[str, StreamState] | None] ) -> Iterable[tuple[T, bool, Any]]: """ Iterates over the given generator, yielding tuples containing the element, a flag @@ -53,16 +55,15 @@ class Timer: """ def __init__(self) -> None: - self._start: Optional[int] = None + self._start: int | None = None def start(self) -> None: self._start = time.perf_counter_ns() def finish(self) -> int: if self._start: - return ((time.perf_counter_ns() - self._start) / 1e9).__ceil__() - else: - raise RuntimeError("Global substream cursor timer not started") + return ((time.perf_counter_ns() - self._start) / 1e9).__ceil__() # noqa: PLC2801 + raise RuntimeError("Global substream cursor timer not started") class GlobalSubstreamCursor(DeclarativeCursor): @@ -79,7 +80,7 @@ class GlobalSubstreamCursor(DeclarativeCursor): - When using the `incremental_dependency` option, the sync will progress through parent records, preventing the sync from getting infinitely stuck. However, it is crucial to understand the requirements for both the `global_substream_cursor` and `incremental_dependency` options to avoid data loss. """ - def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: PartitionRouter): + def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: PartitionRouter): # noqa: ANN204 self._stream_cursor = stream_cursor self._partition_router = partition_router self._timer = Timer() @@ -88,10 +89,10 @@ def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: Partiti 0 ) # Start with 0, indicating no slices being tracked self._all_slices_yielded = False - self._lookback_window: Optional[int] = None - self._current_partition: Optional[Mapping[str, Any]] = None + self._lookback_window: int | None = None + self._current_partition: Mapping[str, Any] | None = None self._last_slice: bool = False - self._parent_state: Optional[Mapping[str, Any]] = None + self._parent_state: Mapping[str, Any] | None = None def start_slices_generation(self) -> None: self._timer.start() @@ -118,7 +119,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: ) self.start_slices_generation() - for slice, last, state in iterate_with_last_flag_and_state( + for slice, last, state in iterate_with_last_flag_and_state( # noqa: A001 slice_generator, self._partition_router.get_stream_state ): self._parent_state = state @@ -134,7 +135,7 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str yield from slice_generator - def register_slice(self, last: bool) -> None: + def register_slice(self, last: bool) -> None: # noqa: FBT001 """ Tracks the processing of a stream slice. 
@@ -213,7 +214,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), record ) - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401 """ Close the current stream slice. @@ -227,7 +228,7 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: """ with self._lock: self._slice_semaphore.acquire() - if self._all_slices_yielded and self._slice_semaphore._value == 0: + if self._all_slices_yielded and self._slice_semaphore._value == 0: # noqa: SLF001 self._lookback_window = self._timer.finish() self._stream_cursor.close_slice( StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args @@ -244,16 +245,16 @@ def get_stream_state(self) -> StreamState: return state - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # noqa: ARG002 # stream_slice is ignored as cursor is global return self._stream_cursor.get_stream_state() def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_params( # type: ignore # this always returns a mapping @@ -265,15 +266,14 @@ def get_request_params( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request params") + raise ValueError("A partition needs to be provided in order to get request params") def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_headers( # type: ignore # this always returns a mapping @@ -285,16 +285,15 @@ def get_request_headers( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request headers") + raise ValueError("A partition needs to be provided in order to get request headers") def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: if stream_slice: return self._partition_router.get_request_body_data( # type: ignore # this always returns a mapping stream_state=stream_state, @@ -305,15 +304,14 @@ def get_request_body_data( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get 
request body data") + raise ValueError("A partition needs to be provided in order to get request body data") def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_body_json( # type: ignore # this always returns a mapping @@ -325,8 +323,7 @@ def get_request_body_json( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request body json") + raise ValueError("A partition needs to be provided in order to get request body json") def should_be_synced(self, record: Record) -> bool: return self._stream_cursor.should_be_synced(self._convert_record_to_cursor_record(record)) diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py index 8241f7761..9d65ba5fb 100644 --- a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py @@ -4,7 +4,8 @@ import logging from collections import OrderedDict -from typing import Any, Callable, Iterable, Mapping, Optional, Union +from collections.abc import Callable, Iterable, Mapping +from typing import Any from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter @@ -13,11 +14,12 @@ ) from airbyte_cdk.sources.types import Record, StreamSlice, StreamState + logger = logging.getLogger("airbyte") class CursorFactory: - def __init__(self, create_function: Callable[[], DeclarativeCursor]): + def __init__(self, create_function: Callable[[], DeclarativeCursor]): # noqa: ANN204 self._create_function = create_function def create(self) -> DeclarativeCursor: @@ -49,7 +51,7 @@ class PerPartitionCursor(DeclarativeCursor): _VALUE = 1 _state_to_migrate_from: Mapping[str, Any] = {} - def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter): + def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter): # noqa: ANN204 self._cursor_factory = cursor_factory self._partition_router = partition_router # The dict is ordered to ensure that once the maximum number of partitions is reached, @@ -70,7 +72,7 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition)) if not cursor: partition_state = ( - self._state_to_migrate_from + self._state_to_migrate_from # noqa: FURB110 if self._state_to_migrate_from else self._NO_CURSOR_STATE ) @@ -154,14 +156,14 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), record ) - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401 try: self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].close_slice( StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args ) except KeyError as exception: - 
raise ValueError( - f"Partition {str(exception)} could not be found in current state based on the record. This is unexpected because " + raise ValueError( # noqa: B904 + f"Partition {exception!s} could not be found in current state based on the record. This is unexpected because " f"we should only update state for partitions that were emitted during `stream_slices`" ) @@ -170,12 +172,10 @@ def get_stream_state(self) -> StreamState: for partition_tuple, cursor in self._cursor_per_partition.items(): cursor_state = cursor.get_stream_state() if cursor_state: - states.append( - { - "partition": self._to_dict(partition_tuple), - "cursor": cursor_state, - } - ) + states.append({ + "partition": self._to_dict(partition_tuple), + "cursor": cursor_state, + }) state: dict[str, Any] = {"states": states} parent_state = self._partition_router.get_stream_state() @@ -183,7 +183,7 @@ def get_stream_state(self) -> StreamState: state["parent_state"] = parent_state return state - def _get_state_for_partition(self, partition: Mapping[str, Any]) -> Optional[StreamState]: + def _get_state_for_partition(self, partition: Mapping[str, Any]) -> StreamState | None: cursor = self._cursor_per_partition.get(self._to_partition_key(partition)) if cursor: return cursor.get_stream_state() @@ -200,7 +200,7 @@ def _to_partition_key(self, partition: Mapping[str, Any]) -> str: def _to_dict(self, partition_key: str) -> Mapping[str, Any]: return self._partition_serializer.to_partition(partition_key) - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: if not stream_slice: raise ValueError("A partition needs to be provided in order to extract a state") @@ -209,7 +209,7 @@ def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[S return self._get_state_for_partition(stream_slice.partition) - def _create_cursor(self, cursor_state: Any) -> DeclarativeCursor: + def _create_cursor(self, cursor_state: Any) -> DeclarativeCursor: # noqa: ANN401 cursor = self._cursor_factory.create() cursor.set_initial_state(cursor_state) return cursor @@ -217,9 +217,9 @@ def _create_cursor(self, cursor_state: Any) -> DeclarativeCursor: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_params( # type: ignore # this always returns a mapping @@ -233,15 +233,14 @@ def get_request_params( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request params") + raise ValueError("A partition needs to be provided in order to get request params") def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_headers( # type: ignore # this always returns a mapping @@ -255,16 +254,15 @@ def 
get_request_headers( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request headers") + raise ValueError("A partition needs to be provided in order to get request headers") def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: if stream_slice: return self._partition_router.get_request_body_data( # type: ignore # this always returns a mapping stream_state=stream_state, @@ -277,15 +275,14 @@ def get_request_body_data( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request body data") + raise ValueError("A partition needs to be provided in order to get request body data") def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: if stream_slice: return self._partition_router.get_request_body_json( # type: ignore # this always returns a mapping @@ -299,8 +296,7 @@ def get_request_body_json( stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), next_page_token=next_page_token, ) - else: - raise ValueError("A partition needs to be provided in order to get request body json") + raise ValueError("A partition needs to be provided in order to get request body json") def should_be_synced(self, record: Record) -> bool: return self._get_cursor(record).should_be_synced( diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py b/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py index defa2d897..7a86a7b9b 100644 --- a/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py +++ b/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py @@ -1,7 +1,8 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, Iterable, Mapping, MutableMapping, Optional, Union +from collections.abc import Iterable, Mapping, MutableMapping +from typing import Any from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor @@ -66,7 +67,7 @@ class PerPartitionWithGlobalCursor(DeclarativeCursor): Suitable for streams where the number of partitions may vary significantly, requiring dynamic switching between per-partition and global state management to ensure data consistency and efficient synchronization. 
""" - def __init__( + def __init__( # noqa: ANN204 self, cursor_factory: CursorFactory, partition_router: PartitionRouter, @@ -76,11 +77,11 @@ def __init__( self._per_partition_cursor = PerPartitionCursor(cursor_factory, partition_router) self._global_cursor = GlobalSubstreamCursor(stream_cursor, partition_router) self._use_global_cursor = False - self._current_partition: Optional[Mapping[str, Any]] = None + self._current_partition: Mapping[str, Any] | None = None self._last_slice: bool = False - self._parent_state: Optional[Mapping[str, Any]] = None + self._parent_state: Mapping[str, Any] | None = None - def _get_active_cursor(self) -> Union[PerPartitionCursor, GlobalSubstreamCursor]: + def _get_active_cursor(self) -> PerPartitionCursor | GlobalSubstreamCursor: return self._global_cursor if self._use_global_cursor else self._per_partition_cursor def stream_slices(self) -> Iterable[StreamSlice]: @@ -92,7 +93,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: ): # Generate slices for the current cursor and handle the last slice using the flag self._parent_state = parent_state - for slice, is_last_slice, _ in iterate_with_last_flag_and_state( + for slice, is_last_slice, _ in iterate_with_last_flag_and_state( # noqa: A001 self._get_active_cursor().generate_slices_from_partition(partition=partition), lambda: None, ): @@ -120,7 +121,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: self._per_partition_cursor.observe(stream_slice, record) self._global_cursor.observe(stream_slice, record) - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401 if not self._use_global_cursor: self._per_partition_cursor.close_slice(stream_slice, *args) self._global_cursor.close_slice(stream_slice, *args) @@ -138,15 +139,15 @@ def get_stream_state(self) -> StreamState: return final_state - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: return self._get_active_cursor().select_state(stream_slice) def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_active_cursor().get_request_params( stream_state=stream_state, @@ -157,9 +158,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_active_cursor().get_request_headers( stream_state=stream_state, @@ -170,10 +171,10 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return self._get_active_cursor().get_request_body_data( 
stream_state=stream_state, stream_slice=stream_slice, @@ -183,9 +184,9 @@ def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_active_cursor().get_request_body_json( stream_state=stream_state, diff --git a/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py b/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py index a0b4665f1..82d1d70d4 100644 --- a/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py @@ -1,7 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, Mapping, Optional +from typing import Any from airbyte_cdk.sources.declarative.incremental import DeclarativeCursor from airbyte_cdk.sources.declarative.types import Record, StreamSlice, StreamState @@ -27,7 +28,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: """ pass - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002 # The ResumableFullRefreshCursor doesn't support nested streams yet so receiving a partition is unexpected if stream_slice.partition: raise ValueError( @@ -35,20 +36,20 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: ) self._cursor = stream_slice.cursor_slice - def should_be_synced(self, record: Record) -> bool: + def should_be_synced(self, record: Record) -> bool: # noqa: ARG002 """ Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages that don't have filterable bounds. We should always return them. """ return True - def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: + def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: # noqa: ARG002 """ RFR records don't have an ordering to be compared between one another.
""" return False - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # noqa: ARG002 # A top-level RFR cursor only manages the state of a single partition return self._cursor @@ -65,36 +66,36 @@ def stream_slices(self) -> Iterable[StreamSlice]: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} @@ -109,7 +110,7 @@ class ChildPartitionResumableFullRefreshCursor(ResumableFullRefreshCursor): Check the `close_slice` method overide for more info about the actual behaviour of this cursor. 
""" - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002 """ Once the current slice has finished syncing: - paginator returns None diff --git a/airbyte_cdk/sources/declarative/interpolation/__init__.py b/airbyte_cdk/sources/declarative/interpolation/__init__.py index d721b99f1..082e7ffb2 100644 --- a/airbyte_cdk/sources/declarative/interpolation/__init__.py +++ b/airbyte_cdk/sources/declarative/interpolation/__init__.py @@ -6,4 +6,5 @@ from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString + __all__ = ["InterpolatedBoolean", "InterpolatedMapping", "InterpolatedString"] diff --git a/airbyte_cdk/sources/declarative/interpolation/filters.py b/airbyte_cdk/sources/declarative/interpolation/filters.py index 52d76cab6..aaf7cc4a6 100644 --- a/airbyte_cdk/sources/declarative/interpolation/filters.py +++ b/airbyte_cdk/sources/declarative/interpolation/filters.py @@ -5,10 +5,10 @@ import hashlib import json import re -from typing import Any, Optional +from typing import Any -def hash(value: Any, hash_type: str = "md5", salt: Optional[str] = None) -> str: +def hash(value: Any, hash_type: str = "md5", salt: str | None = None) -> str: # noqa: ANN401, A001 """ Implementation of a custom Jinja2 hash filter Hash type defaults to 'md5' if one is not specified. @@ -49,7 +49,7 @@ def hash(value: Any, hash_type: str = "md5", salt: Optional[str] = None) -> str: hash_obj.update(str(salt).encode("utf-8")) computed_hash: str = hash_obj.hexdigest() else: - raise AttributeError("No hashing function named {hname}".format(hname=hash_type)) + raise AttributeError(f"No hashing function named {hash_type}") return computed_hash @@ -92,7 +92,7 @@ def base64decode(value: str) -> str: return base64.b64decode(value.encode("utf-8")).decode() -def string(value: Any) -> str: +def string(value: Any) -> str: # noqa: ANN401 """ Converts the input value to a string. If the value is already a string, it is returned as is. diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py index 78569b350..8ebb12d3b 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py @@ -2,13 +2,15 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Final, List, Mapping +from typing import Any, Final from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config -FALSE_VALUES: Final[List[Any]] = [ + +FALSE_VALUES: Final[list[Any]] = [ "False", "false", "{}", @@ -33,7 +35,7 @@ class InterpolatedBoolean: Attributes: condition (str): The string representing the condition to evaluate to a boolean - """ + """ # noqa: B021 condition: str parameters: InitVar[Mapping[str, Any]] @@ -42,7 +44,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._interpolation = JinjaInterpolation() self._parameters = parameters - def eval(self, config: Config, **additional_parameters: Any) -> bool: + def eval(self, config: Config, **additional_parameters: Any) -> bool: # noqa: ANN401 """ Interpolates the predicate condition string using the config and other optional arguments passed as parameter. @@ -52,15 +54,14 @@ def eval(self, config: Config, **additional_parameters: Any) -> bool: """ if isinstance(self.condition, bool): return self.condition - else: - evaluated = self._interpolation.eval( - self.condition, - config, - self._default, - parameters=self._parameters, - **additional_parameters, - ) - if evaluated in FALSE_VALUES: - return False - # The presence of a value is generally regarded as truthy, so we treat it as such - return True + evaluated = self._interpolation.eval( + self.condition, + config, + self._default, + parameters=self._parameters, + **additional_parameters, + ) + if evaluated in FALSE_VALUES: # noqa: SIM103 + return False + # The presence of a value is generally regarded as truthy, so we treat it as such + return True diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py index 11b2dac97..fa8919ec5 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py @@ -3,8 +3,9 @@ # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, Mapping, Optional +from typing import Any from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config @@ -22,11 +23,11 @@ class InterpolatedMapping: mapping: Mapping[str, str] parameters: InitVar[Mapping[str, Any]] - def __post_init__(self, parameters: Optional[Mapping[str, Any]]) -> None: + def __post_init__(self, parameters: Mapping[str, Any] | None) -> None: self._interpolation = JinjaInterpolation() self._parameters = parameters - def eval(self, config: Config, **additional_parameters: Any) -> Dict[str, Any]: + def eval(self, config: Config, **additional_parameters: Any) -> dict[str, Any]: # noqa: ANN401 """ Wrapper around a Mapping[str, str] that allows for both keys and values to be interpolated. 
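# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this diff): a standalone version of the
# `hash` Jinja filter logic touched in filters.py above, registered on a
# sandboxed environment. The template string and config values are made-up
# examples.
# ---------------------------------------------------------------------------
import hashlib

from jinja2.sandbox import SandboxedEnvironment


def hash_value(value: object, hash_type: str = "md5", salt: str | None = None) -> str:
    """Hash `value` (optionally salted) with any algorithm hashlib exposes."""
    hash_func = getattr(hashlib, hash_type, None)
    if hash_func is None:
        raise AttributeError(f"No hashing function named {hash_type}")
    hash_obj = hash_func()
    hash_obj.update(str(value).encode("utf-8"))
    if salt:
        hash_obj.update(str(salt).encode("utf-8"))
    return hash_obj.hexdigest()


environment = SandboxedEnvironment()
environment.filters["hash"] = hash_value

if __name__ == "__main__":
    template = environment.from_string("{{ config['api_key'] | hash('sha256') }}")
    print(template.render(config={"api_key": "s3cr3t"}))
# ---------------------------------------------------------------------------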
@@ -47,10 +48,9 @@ def eval(self, config: Config, **additional_parameters: Any) -> Dict[str, Any]: for name, value in self.mapping.items() } - def _eval(self, value: str, config: Config, **kwargs: Any) -> Any: + def _eval(self, value: str, config: Config, **kwargs: Any) -> Any: # noqa: ANN401 # The values in self._mapping can be of Any type # We only want to interpolate them if they are strings if isinstance(value, str): return self._interpolation.eval(value, config, parameters=self._parameters, **kwargs) - else: - return value + return value diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py index 82454919e..3c656d501 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py @@ -3,16 +3,18 @@ # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any, Union from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config -NestedMappingEntry = Union[ + +NestedMappingEntry = Union[ # noqa: UP007 dict[str, "NestedMapping"], list["NestedMapping"], str, int, float, bool, None ] -NestedMapping = Union[dict[str, NestedMappingEntry], str] +NestedMapping = Union[dict[str, NestedMappingEntry], str] # noqa: UP007 @dataclass @@ -27,26 +29,28 @@ class InterpolatedNestedMapping: mapping: NestedMapping parameters: InitVar[Mapping[str, Any]] - def __post_init__(self, parameters: Optional[Mapping[str, Any]]) -> None: + def __post_init__(self, parameters: Mapping[str, Any] | None) -> None: self._interpolation = JinjaInterpolation() self._parameters = parameters - def eval(self, config: Config, **additional_parameters: Any) -> Any: + def eval(self, config: Config, **additional_parameters: Any) -> Any: # noqa: ANN401 return self._eval(self.mapping, config, **additional_parameters) def _eval( - self, value: Union[NestedMapping, NestedMappingEntry], config: Config, **kwargs: Any - ) -> Any: + self, + value: NestedMapping | NestedMappingEntry, + config: Config, + **kwargs: Any, # noqa: ANN401 + ) -> Any: # noqa: ANN401 # Recursively interpolate dictionaries and lists if isinstance(value, str): return self._interpolation.eval(value, config, parameters=self._parameters, **kwargs) - elif isinstance(value, dict): + if isinstance(value, dict): interpolated_dict = { self._eval(k, config, **kwargs): self._eval(v, config, **kwargs) for k, v in value.items() } return {k: v for k, v in interpolated_dict.items() if v is not None} - elif isinstance(value, list): + if isinstance(value, list): return [self._eval(v, config, **kwargs) for v in value] - else: - return value + return value diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py index 542fa8068..55eb72fdf 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py @@ -2,15 +2,16 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
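# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this diff): the recursion pattern used by
# InterpolatedNestedMapping above -- walk nested dicts and lists and render
# each leaf string as a Jinja template with the config in scope. The request
# body and config values are made-up examples.
# ---------------------------------------------------------------------------
from jinja2.sandbox import SandboxedEnvironment

NestedValue = dict | list | str | int | float | bool | None

_env = SandboxedEnvironment()


def interpolate(value: NestedValue, config: dict) -> NestedValue:
    """Recursively render every string leaf against `config`."""
    if isinstance(value, str):
        return _env.from_string(value).render(config=config)
    if isinstance(value, dict):
        return {interpolate(k, config): interpolate(v, config) for k, v in value.items()}
    if isinstance(value, list):
        return [interpolate(v, config) for v in value]
    return value  # numbers, booleans and None pass through unchanged


if __name__ == "__main__":
    body = {"query": {"account": "{{ config['account_id'] }}", "limit": 50}}
    print(interpolate(body, {"account_id": "abc-123"}))
# ---------------------------------------------------------------------------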
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any, Union from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation from airbyte_cdk.sources.types import Config @dataclass -class InterpolatedString: +class InterpolatedString: # noqa: PLW1641 """ Wrapper around a raw string to be interpolated with the Jinja2 templating engine @@ -22,7 +23,7 @@ class InterpolatedString: string: str parameters: InitVar[Mapping[str, Any]] - default: Optional[str] = None + default: str | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.default = self.default or self.string @@ -32,7 +33,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: # This allows for optimization, but we do not know it yet at this stage self._is_plain_string = None - def eval(self, config: Config, **kwargs: Any) -> Any: + def eval(self, config: Config, **kwargs: Any) -> Any: # noqa: ANN401 """ Interpolates the input string using the config and other optional arguments passed as parameter. @@ -54,7 +55,7 @@ def eval(self, config: Config, **kwargs: Any) -> Any: self.string, config, self.default, parameters=self._parameters, **kwargs ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, InterpolatedString): return False return self.string == other.string and self.default == other.default @@ -75,5 +76,4 @@ def create( """ if isinstance(string_or_interpolated, str): return InterpolatedString(string=string_or_interpolated, parameters=parameters) - else: - return string_or_interpolated + return string_or_interpolated diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolation.py b/airbyte_cdk/sources/declarative/interpolation/interpolation.py index 5af61905e..2158996f9 100644 --- a/airbyte_cdk/sources/declarative/interpolation/interpolation.py +++ b/airbyte_cdk/sources/declarative/interpolation/interpolation.py @@ -3,7 +3,7 @@ # from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from airbyte_cdk.sources.types import Config @@ -18,9 +18,9 @@ def eval( self, input_str: str, config: Config, - default: Optional[str] = None, - **additional_options: Any, - ) -> Any: + default: str | None = None, + **additional_options: Any, # noqa: ANN401 + ) -> Any: # noqa: ANN401 """ Interpolates the input string using the config, and additional options passed as parameter. 
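# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this diff): the render-then-literal_eval
# pattern that jinja.py below relies on -- render the template in a sandboxed
# environment, then try to parse the resulting text back into a Python
# literal so that expressions like "{{ [1, 2] }}" round-trip as real lists.
# The sample templates are made-up examples.
# ---------------------------------------------------------------------------
import ast
from typing import Any

from jinja2.sandbox import SandboxedEnvironment

_env = SandboxedEnvironment()


def evaluate(template: str, context: dict[str, Any]) -> Any:
    """Render `template`, returning a Python literal when the output parses as one."""
    rendered = _env.from_string(template).render(**context)
    try:
        return ast.literal_eval(rendered)
    except (ValueError, SyntaxError):
        return rendered  # plain strings stay strings


if __name__ == "__main__":
    print(evaluate("{{ [start, start + 10] }}", {"start": 5}))      # [5, 15] as a real list
    print(evaluate("page={{ page_number }}", {"page_number": 2}))   # 'page=2' stays a string
# ---------------------------------------------------------------------------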
diff --git a/airbyte_cdk/sources/declarative/interpolation/jinja.py b/airbyte_cdk/sources/declarative/interpolation/jinja.py index 3bb9b0c20..96aeb6005 100644 --- a/airbyte_cdk/sources/declarative/interpolation/jinja.py +++ b/airbyte_cdk/sources/declarative/interpolation/jinja.py @@ -3,8 +3,9 @@ # import ast +from collections.abc import Mapping from functools import cache -from typing import Any, Mapping, Optional, Set, Tuple, Type +from typing import Any from jinja2 import meta from jinja2.environment import Template @@ -24,8 +25,8 @@ class StreamPartitionAccessEnvironment(SandboxedEnvironment): parameter """ - def is_safe_attribute(self, obj: Any, attr: str, value: Any) -> bool: - if attr in ["_partition"]: + def is_safe_attribute(self, obj: Any, attr: str, value: Any) -> bool: # noqa: ANN401 + if attr in ["_partition"]: # noqa: FURB171 return True return super().is_safe_attribute(obj, attr, value) # type: ignore # for some reason, mypy says 'Returning Any from function declared to return "bool"' @@ -80,10 +81,10 @@ def eval( self, input_str: str, config: Config, - default: Optional[str] = None, - valid_types: Optional[Tuple[Type[Any]]] = None, - **additional_parameters: Any, - ) -> Any: + default: str | None = None, + valid_types: tuple[type[Any]] | None = None, + **additional_parameters: Any, # noqa: ANN401 + ) -> Any: # noqa: ANN401 context = {"config": config, **additional_parameters} for alias, equivalent in _ALIASES.items(): @@ -92,7 +93,7 @@ def eval( raise ValueError( f"Found reserved keyword {alias} in interpolation context. This is unexpected and indicative of a bug in the CDK." ) - elif equivalent in context: + if equivalent in context: context[alias] = context[equivalent] try: @@ -102,14 +103,14 @@ def eval( return self._literal_eval(result, valid_types) else: # If input is not a string, return it as is - raise Exception(f"Expected a string, got {input_str}") + raise Exception(f"Expected a string, got {input_str}") # noqa: TRY002, TRY004 except UndefinedError: pass # If result is empty or resulted in an undefined error, evaluate and return the default string return self._literal_eval(self._eval(default, context), valid_types) - def _literal_eval(self, result: Optional[str], valid_types: Optional[Tuple[Type[Any]]]) -> Any: + def _literal_eval(self, result: str | None, valid_types: tuple[type[Any]] | None) -> Any: # noqa: ANN401 try: evaluated = ast.literal_eval(result) # type: ignore # literal_eval is able to handle None except (ValueError, SyntaxError): @@ -118,7 +119,7 @@ def _literal_eval(self, result: Optional[str], valid_types: Optional[Tuple[Type[ return evaluated return result - def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]: + def _eval(self, s: str | None, context: Mapping[str, Any]) -> str | None: try: undeclared = self._find_undeclared_variables(s) undeclared_not_in_context = {var for var in undeclared if var not in context} @@ -132,15 +133,15 @@ def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]: # It can be returned as is return s - @cache - def _find_undeclared_variables(self, s: Optional[str]) -> Set[str]: + @cache # noqa: B019 + def _find_undeclared_variables(self, s: str | None) -> set[str]: """ Find undeclared variables and cache them """ ast = _ENVIRONMENT.parse(s) # type: ignore # parse is able to handle None return meta.find_undeclared_variables(ast) - @cache + @cache # noqa: B019 def _compile(self, s: str) -> Template: """ We must cache the Jinja Template ourselves because we're using 
`from_string` instead of a template loader diff --git a/airbyte_cdk/sources/declarative/interpolation/macros.py b/airbyte_cdk/sources/declarative/interpolation/macros.py index 1ca5b31f0..a208a1438 100644 --- a/airbyte_cdk/sources/declarative/interpolation/macros.py +++ b/airbyte_cdk/sources/declarative/interpolation/macros.py @@ -5,13 +5,13 @@ import builtins import datetime import typing -from typing import Optional, Union import isodate import pytz from dateutil import parser from isodate import parse_duration + """ This file contains macros that can be evaluated by a `JinjaInterpolation` object """ @@ -47,7 +47,7 @@ def today_with_timezone(timezone: str) -> datetime.date: return datetime.datetime.now(tz=pytz.timezone(timezone)).date() -def timestamp(dt: Union[float, str]) -> Union[int, float]: +def timestamp(dt: float | str) -> int | float: """ Converts a number or a string to a timestamp @@ -60,10 +60,9 @@ def timestamp(dt: Union[float, str]) -> Union[int, float]: :param dt: datetime to convert to timestamp :return: unix timestamp """ - if isinstance(dt, (int, float)): + if isinstance(dt, (int, float)): # noqa: UP038 return int(dt) - else: - return _str_to_datetime(dt).astimezone(pytz.utc).timestamp() + return _str_to_datetime(dt).astimezone(pytz.utc).timestamp() def _str_to_datetime(s: str) -> datetime.datetime: @@ -74,7 +73,7 @@ def _str_to_datetime(s: str) -> datetime.datetime: return parsed_date.astimezone(pytz.utc) -def max(*args: typing.Any) -> typing.Any: +def max(*args: typing.Any) -> typing.Any: # noqa: ANN401, A001 """ Returns biggest object of an iterable, or two or more arguments. @@ -94,7 +93,7 @@ def max(*args: typing.Any) -> typing.Any: return builtins.max(*args) -def min(*args: typing.Any) -> typing.Any: +def min(*args: typing.Any) -> typing.Any: # noqa: ANN401, A001 """ Returns smallest object of an iterable, or two or more arguments. 
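As a side note on the macros being retyped above, a tiny sketch (not from the diff) of the documented `timestamp()` behaviour: numeric input is returned as an int, string input is parsed and converted to a UTC unix timestamp.

    from airbyte_cdk.sources.declarative.interpolation.macros import timestamp

    # ints and floats pass through as int(dt)
    assert timestamp(1658505815) == 1658505815
    # strings are parsed and converted to a UTC epoch value
    assert timestamp("2022-01-01T00:00:00Z") == 1640995200.0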
@@ -114,7 +113,7 @@ def min(*args: typing.Any) -> typing.Any: return builtins.min(*args) -def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str: +def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str: # noqa: A002 """ Returns datetime of now() + num_days @@ -129,7 +128,7 @@ def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str: ).strftime(format) -def duration(datestring: str) -> Union[datetime.timedelta, isodate.Duration]: +def duration(datestring: str) -> datetime.timedelta | isodate.Duration: """ Converts ISO8601 duration to datetime.timedelta @@ -140,7 +139,9 @@ def duration(datestring: str) -> Union[datetime.timedelta, isodate.Duration]: def format_datetime( - dt: Union[str, datetime.datetime], format: str, input_format: Optional[str] = None + dt: str | datetime.datetime, + format: str, # noqa: A002 + input_format: str | None = None, ) -> str: """ Converts datetime to another format diff --git a/airbyte_cdk/sources/declarative/manifest_declarative_source.py b/airbyte_cdk/sources/declarative/manifest_declarative_source.py index deef5a3be..a58ffed72 100644 --- a/airbyte_cdk/sources/declarative/manifest_declarative_source.py +++ b/airbyte_cdk/sources/declarative/manifest_declarative_source.py @@ -5,10 +5,11 @@ import json import logging import pkgutil +from collections.abc import Iterator, Mapping from copy import deepcopy from importlib import metadata from types import ModuleType -from typing import Any, Dict, Iterator, List, Mapping, Optional, Set +from typing import Any import yaml from jsonschema.exceptions import ValidationError @@ -26,9 +27,6 @@ from airbyte_cdk.sources.declarative.checks import COMPONENTS_CHECKER_TYPE_MAPPING from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource -from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( - CheckStream as CheckStreamModel, -) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( DeclarativeStream as DeclarativeStreamModel, ) @@ -60,14 +58,14 @@ class ManifestDeclarativeSource(DeclarativeSource): """Declarative source defined by a manifest of low-code components that define source connector behavior""" - def __init__( + def __init__( # noqa: ANN204 self, source_config: ConnectionDefinition, *, config: Mapping[str, Any] | None = None, debug: bool = False, emit_connector_builder_messages: bool = False, - component_factory: Optional[ModelToComponentFactory] = None, + component_factory: ModelToComponentFactory | None = None, ): """ Args: @@ -94,7 +92,7 @@ def __init__( self._debug = debug self._emit_connector_builder_messages = emit_connector_builder_messages self._constructor = ( - component_factory + component_factory # noqa: FURB110 if component_factory else ModelToComponentFactory(emit_connector_builder_messages) ) @@ -121,17 +119,16 @@ def connection_checker(self) -> ConnectionChecker: check_stream = self._constructor.create_component( COMPONENTS_CHECKER_TYPE_MAPPING[check["type"]], check, - dict(), + dict(), # noqa: C408 emit_connector_builder_messages=self._emit_connector_builder_messages, ) if isinstance(check_stream, ConnectionChecker): return check_stream - else: - raise ValueError( - f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}" - ) + raise ValueError( + f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}" + ) - 
def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: self._emit_manifest_debug_message( extra_args={"source_name": self.name, "parsed_config": json.dumps(self._source_config)} ) @@ -154,8 +151,8 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: @staticmethod def _initialize_cache_for_parent_streams( - stream_configs: List[Dict[str, Any]], - ) -> List[Dict[str, Any]]: + stream_configs: list[dict[str, Any]], + ) -> list[dict[str, Any]]: parent_streams = set() def update_with_cache_parent_configs(parent_configs: list[dict[str, Any]]) -> None: @@ -204,10 +201,9 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: if spec: if "type" not in spec: spec["type"] = "Spec" - spec_component = self._constructor.create_component(SpecModel, spec, dict()) + spec_component = self._constructor.create_component(SpecModel, spec, dict()) # noqa: C408 return spec_component.generate_spec() - else: - return super().spec(logger) + return super().spec(logger) def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus: self._configure_logger_level(logger) @@ -218,7 +214,7 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: self._configure_logger_level(logger) yield from super().read(logger, config, catalog, state) @@ -247,7 +243,7 @@ def _validate_source(self) -> None: "Failed to read manifest component json schema required for validation" ) except FileNotFoundError as e: - raise FileNotFoundError( + raise FileNotFoundError( # noqa: B904 f"Failed to read manifest component json schema required for validation: {e}" ) @@ -314,9 +310,9 @@ def _parse_version( # No exception return parsed_version - def _stream_configs(self, manifest: Mapping[str, Any]) -> List[Dict[str, Any]]: + def _stream_configs(self, manifest: Mapping[str, Any]) -> list[dict[str, Any]]: # This has a warning flag for static, but after we finish part 4 we'll replace manifest with self._source_config - stream_configs: List[Dict[str, Any]] = manifest.get("streams", []) + stream_configs: list[dict[str, Any]] = manifest.get("streams", []) for s in stream_configs: if "type" not in s: s["type"] = "DeclarativeStream" @@ -324,10 +320,10 @@ def _stream_configs(self, manifest: Mapping[str, Any]) -> List[Dict[str, Any]]: def _dynamic_stream_configs( self, manifest: Mapping[str, Any], config: Mapping[str, Any] - ) -> List[Dict[str, Any]]: - dynamic_stream_definitions: List[Dict[str, Any]] = manifest.get("dynamic_streams", []) - dynamic_stream_configs: List[Dict[str, Any]] = [] - seen_dynamic_streams: Set[str] = set() + ) -> list[dict[str, Any]]: + dynamic_stream_definitions: list[dict[str, Any]] = manifest.get("dynamic_streams", []) + dynamic_stream_configs: list[dict[str, Any]] = [] + seen_dynamic_streams: set[str] = set() for dynamic_definition in dynamic_stream_definitions: components_resolver_config = dynamic_definition["components_resolver"] diff --git a/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py b/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py index 830646fe9..03c3ed249 100644 --- a/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py +++ b/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py @@ 
-1,6 +1,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration @@ -33,7 +34,7 @@ class LegacyToPerPartitionStateMigration(StateMigration): } """ - def __init__( + def __init__( # noqa: ANN204 self, partition_router: SubstreamPartitionRouter, cursor: CustomIncrementalSync | DatetimeBasedCursor, @@ -79,7 +80,7 @@ def should_migrate(self, stream_state: Mapping[str, Any]) -> bool: } """ if stream_state: - for key, value in stream_state.items(): + for key, value in stream_state.items(): # noqa: B007, PERF102 if isinstance(value, dict): keys = list(value.keys()) if len(keys) != 1: diff --git a/airbyte_cdk/sources/declarative/migrations/state_migration.py b/airbyte_cdk/sources/declarative/migrations/state_migration.py index 9cf7f3cfe..49bf5aa57 100644 --- a/airbyte_cdk/sources/declarative/migrations/state_migration.py +++ b/airbyte_cdk/sources/declarative/migrations/state_migration.py @@ -1,7 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. from abc import abstractmethod -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any class StateMigration: diff --git a/airbyte_cdk/sources/declarative/models/__init__.py b/airbyte_cdk/sources/declarative/models/__init__.py index 81f2e2f33..7fcd706df 100644 --- a/airbyte_cdk/sources/declarative/models/__init__.py +++ b/airbyte_cdk/sources/declarative/models/__init__.py @@ -1,2 +1,2 @@ # generated by bin/generate_component_manifest_files.py -from .declarative_component_schema import * +from .declarative_component_schema import * # noqa: F403 diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index fa4a00d18..38a272a69 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -4,7 +4,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal from pydantic.v1 import BaseModel, Extra, Field @@ -22,13 +22,13 @@ class BasicHttpAuthenticator(BaseModel): examples=["{{ config['username'] }}", "{{ config['api_key'] }}"], title="Username", ) - password: Optional[str] = Field( + password: str | None = Field( "", description="The password that will be combined with the username, base64 encoded and used to make requests. 
Fill it in the user inputs.", examples=["{{ config['password'] }}", ""], title="Password", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class BearerAuthenticator(BaseModel): @@ -39,12 +39,12 @@ class BearerAuthenticator(BaseModel): examples=["{{ config['api_key'] }}", "{{ config['token'] }}"], title="Bearer Token", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CheckStream(BaseModel): type: Literal["CheckStream"] - stream_names: List[str] = Field( + stream_names: list[str] = Field( ..., description="Names of the streams to try reading from when running a check operation.", examples=[["users"], ["users", "contacts"]], @@ -62,31 +62,31 @@ class CheckDynamicStream(BaseModel): class ConcurrencyLevel(BaseModel): - type: Optional[Literal["ConcurrencyLevel"]] = None - default_concurrency: Union[int, str] = Field( + type: Literal["ConcurrencyLevel"] | None = None + default_concurrency: int | str = Field( ..., description="The amount of concurrency that will applied during a sync. This value can be hardcoded or user-defined in the config if different users have varying volume thresholds in the target API.", examples=[10, "{{ config['num_workers'] or 10 }}"], title="Default Concurrency", ) - max_concurrency: Optional[int] = Field( + max_concurrency: int | None = Field( None, description="The maximum level of concurrency that will be used during a sync. This becomes a required field when the default_concurrency derives from the config, because it serves as a safeguard against a user-defined threshold that is too high.", examples=[20, 100], title="Max Concurrency", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ConstantBackoffStrategy(BaseModel): type: Literal["ConstantBackoffStrategy"] - backoff_time_in_seconds: Union[float, str] = Field( + backoff_time_in_seconds: float | str = Field( ..., description="Backoff time in seconds.", examples=[30, 30.5, "{{ config['backoff_time'] }}"], title="Backoff Time", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CursorPagination(BaseModel): @@ -101,13 +101,13 @@ class CursorPagination(BaseModel): ], title="Cursor Value", ) - page_size: Optional[int] = Field( + page_size: int | None = Field( None, description="The number of records to include in each pages.", examples=[100], title="Page Size", ) - stop_condition: Optional[str] = Field( + stop_condition: str | None = Field( None, description="Template string evaluating when to stop paginating.", examples=[ @@ -116,7 +116,7 @@ class CursorPagination(BaseModel): ], title="Stop Condition", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomAuthenticator(BaseModel): @@ -130,7 +130,7 @@ class Config: examples=["source_railz.components.ShortLivedTokenAuthenticator"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomBackoffStrategy(BaseModel): @@ -144,7 +144,7 @@ class Config: examples=["source_railz.components.MyCustomBackoffStrategy"], title="Class Name", ) - parameters: 
Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomErrorHandler(BaseModel): @@ -158,7 +158,7 @@ class Config: examples=["source_railz.components.MyCustomErrorHandler"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomIncrementalSync(BaseModel): @@ -176,7 +176,7 @@ class Config: ..., description="The location of the value on a record that will be used as a bookmark during sync.", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomPaginationStrategy(BaseModel): @@ -190,7 +190,7 @@ class Config: examples=["source_railz.components.MyCustomPaginationStrategy"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomRecordExtractor(BaseModel): @@ -204,7 +204,7 @@ class Config: examples=["source_railz.components.MyCustomRecordExtractor"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomRecordFilter(BaseModel): @@ -218,7 +218,7 @@ class Config: examples=["source_railz.components.MyCustomCustomRecordFilter"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomRequester(BaseModel): @@ -232,7 +232,7 @@ class Config: examples=["source_railz.components.MyCustomRecordExtractor"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomRetriever(BaseModel): @@ -246,7 +246,7 @@ class Config: examples=["source_railz.components.MyCustomRetriever"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomPartitionRouter(BaseModel): @@ -260,7 +260,7 @@ class Config: examples=["source_railz.components.MyCustomPartitionRouter"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomSchemaLoader(BaseModel): @@ -274,7 +274,7 @@ class Config: examples=["source_railz.components.MyCustomSchemaLoader"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomSchemaNormalization(BaseModel): @@ -290,7 +290,7 @@ class Config: ], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomStateMigration(BaseModel): @@ -304,7 +304,7 @@ class Config: examples=["source_railz.components.MyCustomStateMigration"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class CustomTransformation(BaseModel): @@ -318,14 +318,14 @@ class Config: examples=["source_railz.components.MyCustomTransformation"], title="Class 
Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class LegacyToPerPartitionStateMigration(BaseModel): class Config: extra = Extra.allow - type: Optional[Literal["LegacyToPerPartitionStateMigration"]] = None + type: Literal["LegacyToPerPartitionStateMigration"] | None = None class Algorithm(Enum): @@ -349,19 +349,19 @@ class JwtHeaders(BaseModel): class Config: extra = Extra.forbid - kid: Optional[str] = Field( + kid: str | None = Field( None, description="Private key ID for user account.", examples=["{{ config['kid'] }}"], title="Key Identifier", ) - typ: Optional[str] = Field( + typ: str | None = Field( "JWT", description="The media type of the complete JWT.", examples=["JWT"], title="Type", ) - cty: Optional[str] = Field( + cty: str | None = Field( None, description="Content type of JWT header.", examples=["JWT"], @@ -373,18 +373,18 @@ class JwtPayload(BaseModel): class Config: extra = Extra.forbid - iss: Optional[str] = Field( + iss: str | None = Field( None, description="The user/principal that issued the JWT. Commonly a value unique to the user.", examples=["{{ config['iss'] }}"], title="Issuer", ) - sub: Optional[str] = Field( + sub: str | None = Field( None, description="The subject of the JWT. Commonly defined by the API.", title="Subject", ) - aud: Optional[str] = Field( + aud: str | None = Field( None, description="The recipient that the JWT is intended for. Commonly defined by the API.", examples=["appstoreconnect-v1"], @@ -399,8 +399,8 @@ class JwtAuthenticator(BaseModel): description="Secret used to sign the JSON web token.", examples=["{{ config['secret_key'] }}"], ) - base64_encode_secret_key: Optional[bool] = Field( - False, + base64_encode_secret_key: bool | None = Field( + False, # noqa: FBT003 description='When set to true, the secret key will be base64 encoded prior to being encoded as part of the JWT. 
Only set to "true" when required by the API.', ) algorithm: Algorithm = Field( @@ -408,79 +408,79 @@ class JwtAuthenticator(BaseModel): description="Algorithm used to sign the JSON web token.", examples=["ES256", "HS256", "RS256", "{{ config['algorithm'] }}"], ) - token_duration: Optional[int] = Field( + token_duration: int | None = Field( 1200, description="The amount of time in seconds a JWT token can be valid after being issued.", examples=[1200, 3600], title="Token Duration", ) - header_prefix: Optional[str] = Field( + header_prefix: str | None = Field( None, description="The prefix to be used within the Authentication header.", examples=["Bearer", "Basic"], title="Header Prefix", ) - jwt_headers: Optional[JwtHeaders] = Field( + jwt_headers: JwtHeaders | None = Field( None, description="JWT headers used when signing JSON web token.", title="JWT Headers", ) - additional_jwt_headers: Optional[Dict[str, Any]] = Field( + additional_jwt_headers: dict[str, Any] | None = Field( None, description="Additional headers to be included with the JWT headers object.", title="Additional JWT Headers", ) - jwt_payload: Optional[JwtPayload] = Field( + jwt_payload: JwtPayload | None = Field( None, description="JWT Payload used when signing JSON web token.", title="JWT Payload", ) - additional_jwt_payload: Optional[Dict[str, Any]] = Field( + additional_jwt_payload: dict[str, Any] | None = Field( None, description="Additional properties to be added to the JWT payload.", title="Additional JWT Payload Properties", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class RefreshTokenUpdater(BaseModel): - refresh_token_name: Optional[str] = Field( + refresh_token_name: str | None = Field( "refresh_token", description="The name of the property which contains the updated refresh token in the response from the token refresh endpoint.", examples=["refresh_token"], title="Refresh Token Property Name", ) - access_token_config_path: Optional[List[str]] = Field( + access_token_config_path: list[str] | None = Field( ["credentials", "access_token"], description="Config path to the access token. Make sure the field actually exists in the config.", examples=[["credentials", "access_token"], ["access_token"]], title="Config Path To Access Token", ) - refresh_token_config_path: Optional[List[str]] = Field( + refresh_token_config_path: list[str] | None = Field( ["credentials", "refresh_token"], description="Config path to the access token. Make sure the field actually exists in the config.", examples=[["credentials", "refresh_token"], ["refresh_token"]], title="Config Path To Refresh Token", ) - token_expiry_date_config_path: Optional[List[str]] = Field( + token_expiry_date_config_path: list[str] | None = Field( ["credentials", "token_expiry_date"], description="Config path to the expiry date. Make sure actually exists in the config.", examples=[["credentials", "token_expiry_date"]], title="Config Path To Expiry Date", ) - refresh_token_error_status_codes: Optional[List[int]] = Field( + refresh_token_error_status_codes: list[int] | None = Field( [], description="Status Codes to Identify refresh token error in response (Refresh Token Error Key and Refresh Token Error Values should be also specified). 
Responses with one of the error status code and containing an error value will be flagged as a config error", examples=[[400, 500]], title="Refresh Token Error Status Codes", ) - refresh_token_error_key: Optional[str] = Field( + refresh_token_error_key: str | None = Field( "", description="Key to Identify refresh token error in response (Refresh Token Error Status Codes and Refresh Token Error Values should be also specified).", examples=["error"], title="Refresh Token Error Key", ) - refresh_token_error_values: Optional[List[str]] = Field( + refresh_token_error_values: list[str] | None = Field( [], description='List of values to check for exception during token refresh process. Used to check if the error found in the response matches the key from the Refresh Token Error Key field (e.g. response={"error": "invalid_grant"}). Only responses with one of the error status code and containing an error value will be flagged as a config error', examples=[["invalid_grant", "invalid_permissions"]], @@ -490,7 +490,7 @@ class RefreshTokenUpdater(BaseModel): class OAuthAuthenticator(BaseModel): type: Literal["OAuthAuthenticator"] - client_id_name: Optional[str] = Field( + client_id_name: str | None = Field( "client_id", description="The name of the property to use to refresh the `access_token`.", examples=["custom_app_id"], @@ -502,7 +502,7 @@ class OAuthAuthenticator(BaseModel): examples=["{{ config['client_id }}", "{{ config['credentials']['client_id }}"], title="Client ID", ) - client_secret_name: Optional[str] = Field( + client_secret_name: str | None = Field( "client_secret", description="The name of the property to use to refresh the `access_token`.", examples=["custom_app_secret"], @@ -517,13 +517,13 @@ class OAuthAuthenticator(BaseModel): ], title="Client Secret", ) - refresh_token_name: Optional[str] = Field( + refresh_token_name: str | None = Field( "refresh_token", description="The name of the property to use to refresh the `access_token`.", examples=["custom_app_refresh_value"], title="Refresh Token Property Name", ) - refresh_token: Optional[str] = Field( + refresh_token: str | None = Field( None, description="Credential artifact used to get a new access token.", examples=[ @@ -532,43 +532,43 @@ class OAuthAuthenticator(BaseModel): ], title="Refresh Token", ) - token_refresh_endpoint: Optional[str] = Field( + token_refresh_endpoint: str | None = Field( None, description="The full URL to call to obtain a new access token.", examples=["https://connect.squareup.com/oauth2/token"], title="Token Refresh Endpoint", ) - access_token_name: Optional[str] = Field( + access_token_name: str | None = Field( "access_token", description="The name of the property which contains the access token in the response from the token refresh endpoint.", examples=["access_token"], title="Access Token Property Name", ) - access_token_value: Optional[str] = Field( + access_token_value: str | None = Field( None, description="The value of the access_token to bypass the token refreshing using `refresh_token`.", examples=["secret_access_token_value"], title="Access Token Value", ) - expires_in_name: Optional[str] = Field( + expires_in_name: str | None = Field( "expires_in", description="The name of the property which contains the expiry date in the response from the token refresh endpoint.", examples=["expires_in"], title="Token Expiry Property Name", ) - grant_type_name: Optional[str] = Field( + grant_type_name: str | None = Field( "grant_type", description="The name of the property to use to refresh the 
`access_token`.", examples=["custom_grant_type"], title="Grant Type Property Name", ) - grant_type: Optional[str] = Field( + grant_type: str | None = Field( "refresh_token", description="Specifies the OAuth2 grant type. If set to refresh_token, the refresh_token needs to be provided as well. For client_credentials, only client id and secret are required. Other grant types are not officially supported.", examples=["refresh_token", "client_credentials"], title="Grant Type", ) - refresh_request_body: Optional[Dict[str, Any]] = Field( + refresh_request_body: dict[str, Any] | None = Field( None, description="Body of the request sent to get a new access token.", examples=[ @@ -580,7 +580,7 @@ class OAuthAuthenticator(BaseModel): ], title="Refresh Request Body", ) - refresh_request_headers: Optional[Dict[str, Any]] = Field( + refresh_request_headers: dict[str, Any] | None = Field( None, description="Headers of the request sent to get a new access token.", examples=[ @@ -591,35 +591,35 @@ class OAuthAuthenticator(BaseModel): ], title="Refresh Request Headers", ) - scopes: Optional[List[str]] = Field( + scopes: list[str] | None = Field( None, description="List of scopes that should be granted to the access token.", examples=[["crm.list.read", "crm.objects.contacts.read", "crm.schema.contacts.read"]], title="Scopes", ) - token_expiry_date: Optional[str] = Field( + token_expiry_date: str | None = Field( None, description="The access token expiry date.", examples=["2023-04-06T07:12:10.421833+00:00", 1680842386], title="Token Expiry Date", ) - token_expiry_date_format: Optional[str] = Field( + token_expiry_date_format: str | None = Field( None, description="The format of the time to expiration datetime. Provide it if the time is returned as a date-time string instead of seconds.", examples=["%Y-%m-%d %H:%M:%S.%f+00:00"], title="Token Expiry Date Format", ) - refresh_token_updater: Optional[RefreshTokenUpdater] = Field( + refresh_token_updater: RefreshTokenUpdater | None = Field( None, description="When the token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.", title="Token Updater", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DpathExtractor(BaseModel): type: Literal["DpathExtractor"] - field_path: List[str] = Field( + field_path: list[str] = Field( ..., description='List of potentially nested fields describing the full path of the field to extract. Use "*" to extract all values from an array. 
See more info in the [docs](https://docs.airbyte.com/connector-development/config-based/understanding-the-yaml-file/record-selector).', examples=[ @@ -630,23 +630,23 @@ class DpathExtractor(BaseModel): ], title="Field Path", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ResponseToFileExtractor(BaseModel): type: Literal["ResponseToFileExtractor"] - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ExponentialBackoffStrategy(BaseModel): type: Literal["ExponentialBackoffStrategy"] - factor: Optional[Union[float, str]] = Field( + factor: float | str | None = Field( 5, description="Multiplicative constant applied on each retry.", examples=[5, 5.5, "10"], title="Factor", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SessionTokenRequestBearerAuthenticator(BaseModel): @@ -674,37 +674,37 @@ class FailureType(Enum): class HttpResponseFilter(BaseModel): type: Literal["HttpResponseFilter"] - action: Optional[Action] = Field( + action: Action | None = Field( None, description="Action to execute if a response matches the filter.", examples=["SUCCESS", "FAIL", "RETRY", "IGNORE", "RATE_LIMITED"], title="Action", ) - failure_type: Optional[FailureType] = Field( + failure_type: FailureType | None = Field( None, description="Failure type of traced exception if a response matches the filter.", examples=["system_error", "config_error", "transient_error"], title="Failure Type", ) - error_message: Optional[str] = Field( + error_message: str | None = Field( None, description="Error Message to display if the response matches the filter.", title="Error Message", ) - error_message_contains: Optional[str] = Field( + error_message_contains: str | None = Field( None, description="Match the response if its error message contains the substring.", example=["This API operation is not enabled for this site"], title="Error Message Substring", ) - http_codes: Optional[List[int]] = Field( + http_codes: list[int] | None = Field( None, description="Match the response if its HTTP code is included in this list.", examples=[[420, 429], [500]], title="HTTP Codes", unique_items=True, ) - predicate: Optional[str] = Field( + predicate: str | None = Field( None, description="Match the response if the predicate evaluates to true.", examples=[ @@ -713,39 +713,39 @@ class HttpResponseFilter(BaseModel): ], title="Predicate", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class TypesMap(BaseModel): - target_type: Union[str, List[str]] - current_type: Union[str, List[str]] - condition: Optional[str] + target_type: str | list[str] + current_type: str | list[str] + condition: str | None class SchemaTypeIdentifier(BaseModel): - type: Optional[Literal["SchemaTypeIdentifier"]] = None - schema_pointer: Optional[List[str]] = Field( + type: Literal["SchemaTypeIdentifier"] | None = None + schema_pointer: list[str] | None = Field( [], description="List of nested fields defining the schema field path to extract. 
Defaults to [].", title="Schema Path", ) - key_pointer: List[str] = Field( + key_pointer: list[str] = Field( ..., description="List of potentially nested fields describing the full path of the field key to extract.", title="Key Path", ) - type_pointer: Optional[List[str]] = Field( + type_pointer: list[str] | None = Field( None, description="List of potentially nested fields describing the full path of the field type to extract.", title="Type Path", ) - types_mapping: Optional[List[TypesMap]] = None - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + types_mapping: list[TypesMap] | None = None + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class InlineSchemaLoader(BaseModel): type: Literal["InlineSchemaLoader"] - schema_: Optional[Dict[str, Any]] = Field( + schema_: dict[str, Any] | None = Field( None, alias="schema", description='Describes a streams\' schema. Refer to the Data Types documentation for more details on which types are valid.', @@ -755,13 +755,13 @@ class InlineSchemaLoader(BaseModel): class JsonFileSchemaLoader(BaseModel): type: Literal["JsonFileSchemaLoader"] - file_path: Optional[str] = Field( + file_path: str | None = Field( None, description="Path to the JSON file defining the schema. The path is relative to the connector module's root.", example=["./schemas/users.json"], title="File Path", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class JsonDecoder(BaseModel): @@ -774,27 +774,27 @@ class JsonlDecoder(BaseModel): class KeysToLower(BaseModel): type: Literal["KeysToLower"] - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class KeysToSnakeCase(BaseModel): type: Literal["KeysToSnakeCase"] - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class FlattenFields(BaseModel): type: Literal["FlattenFields"] - flatten_lists: Optional[bool] = Field( - True, + flatten_lists: bool | None = Field( + True, # noqa: FBT003 description="Whether to flatten lists or leave it as is. Default is True.", title="Flatten Lists", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DpathFlattenFields(BaseModel): type: Literal["DpathFlattenFields"] - field_path: List[str] = Field( + field_path: list[str] = Field( ..., description="A path to field that needs to be flattened.", examples=[ @@ -803,12 +803,12 @@ class DpathFlattenFields(BaseModel): ], title="Field Path", ) - delete_origin_value: Optional[bool] = Field( - False, + delete_origin_value: bool | None = Field( + False, # noqa: FBT003 description="Whether to delete the origin value or keep it. 
Default is False.", title="Delete Origin Value", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class KeysReplace(BaseModel): @@ -835,7 +835,7 @@ class KeysReplace(BaseModel): ], title="New value", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class IterableDecoder(BaseModel): @@ -857,7 +857,7 @@ class Config: examples=["source_amazon_ads.components.GzipJsonlDecoder"], title="Class Name", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class GzipJsonDecoder(BaseModel): @@ -865,8 +865,8 @@ class Config: extra = Extra.allow type: Literal["GzipJsonDecoder"] - encoding: Optional[str] = "utf-8" - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + encoding: str | None = "utf-8" + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class MinMaxDatetime(BaseModel): @@ -877,30 +877,30 @@ class MinMaxDatetime(BaseModel): examples=["2021-01-01", "2021-01-01T00:00:00Z", "{{ config['start_time'] }}"], title="Datetime", ) - datetime_format: Optional[str] = Field( + datetime_format: str | None = Field( "", description='Format of the datetime value. Defaults to "%Y-%m-%dT%H:%M:%S.%f%z" if left empty. Use placeholders starting with "%" to describe the format the API is using. The following placeholders are available:\n * **%s**: Epoch unix timestamp - `1686218963`\n * **%s_as_float**: Epoch unix timestamp in seconds as float with microsecond precision - `1686218963.123456`\n * **%ms**: Epoch unix timestamp - `1686218963123`\n * **%a**: Weekday (abbreviated) - `Sun`\n * **%A**: Weekday (full) - `Sunday`\n * **%w**: Weekday (decimal) - `0` (Sunday), `6` (Saturday)\n * **%d**: Day of the month (zero-padded) - `01`, `02`, ..., `31`\n * **%b**: Month (abbreviated) - `Jan`\n * **%B**: Month (full) - `January`\n * **%m**: Month (zero-padded) - `01`, `02`, ..., `12`\n * **%y**: Year (without century, zero-padded) - `00`, `01`, ..., `99`\n * **%Y**: Year (with century) - `0001`, `0002`, ..., `9999`\n * **%H**: Hour (24-hour, zero-padded) - `00`, `01`, ..., `23`\n * **%I**: Hour (12-hour, zero-padded) - `01`, `02`, ..., `12`\n * **%p**: AM/PM indicator\n * **%M**: Minute (zero-padded) - `00`, `01`, ..., `59`\n * **%S**: Second (zero-padded) - `00`, `01`, ..., `59`\n * **%f**: Microsecond (zero-padded to 6 digits) - `000000`, `000001`, ..., `999999`\n * **%z**: UTC offset - `(empty)`, `+0000`, `-04:00`\n * **%Z**: Time zone name - `(empty)`, `UTC`, `GMT`\n * **%j**: Day of the year (zero-padded) - `001`, `002`, ..., `366`\n * **%U**: Week number of the year (Sunday as first day) - `00`, `01`, ..., `53`\n * **%W**: Week number of the year (Monday as first day) - `00`, `01`, ..., `53`\n * **%c**: Date and time representation - `Tue Aug 16 21:30:00 1988`\n * **%x**: Date representation - `08/16/1988`\n * **%X**: Time representation - `21:30:00`\n * **%%**: Literal \'%\' character\n\n Some placeholders depend on the locale of the underlying system - in most cases this locale is configured as en/US. 
For more information see the [Python documentation](https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes).\n', examples=["%Y-%m-%dT%H:%M:%S.%f%z", "%Y-%m-%d", "%s"], title="Datetime Format", ) - max_datetime: Optional[str] = Field( + max_datetime: str | None = Field( None, description="Ceiling applied on the datetime value. Must be formatted with the datetime_format field.", examples=["2021-01-01T00:00:00Z", "2021-01-01"], title="Max Datetime", ) - min_datetime: Optional[str] = Field( + min_datetime: str | None = Field( None, description="Floor applied on the datetime value. Must be formatted with the datetime_format field.", examples=["2010-01-01T00:00:00Z", "2010-01-01"], title="Min Datetime", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class NoAuth(BaseModel): type: Literal["NoAuth"] - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class NoPagination(BaseModel): @@ -928,7 +928,7 @@ class Config: ], title="Consent URL", ) - scope: Optional[str] = Field( + scope: str | None = Field( None, description="The DeclarativeOAuth Specific string of the scopes needed to be grant for authenticated user.", examples=["user:read user:read_orders workspaces:read"], @@ -942,7 +942,7 @@ class Config: ], title="Access Token URL", ) - access_token_headers: Optional[Dict[str, Any]] = Field( + access_token_headers: dict[str, Any] | None = Field( None, description="The DeclarativeOAuth Specific optional headers to inject while exchanging the `auth_code` to `access_token` during `completeOAuthFlow` step.", examples=[ @@ -952,7 +952,7 @@ class Config: ], title="Access Token Headers", ) - access_token_params: Optional[Dict[str, Any]] = Field( + access_token_params: dict[str, Any] | None = Field( None, description="The DeclarativeOAuth Specific optional query parameters to inject while exchanging the `auth_code` to `access_token` during `completeOAuthFlow` step.\nWhen this property is provided, the query params will be encoded as `Json` and included in the outgoing API request.", examples=[ @@ -964,49 +964,49 @@ class Config: ], title="Access Token Query Params (Json Encoded)", ) - extract_output: Optional[List[str]] = Field( + extract_output: list[str] | None = Field( None, description="The DeclarativeOAuth Specific list of strings to indicate which keys should be extracted and returned back to the input config.", examples=[["access_token", "refresh_token", "other_field"]], title="Extract Output", ) - state: Optional[State] = Field( + state: State | None = Field( None, description="The DeclarativeOAuth Specific object to provide the criteria of how the `state` query param should be constructed,\nincluding length and complexity.", examples=[{"min": 7, "max": 128}], title="Configurable State Query Param", ) - client_id_key: Optional[str] = Field( + client_id_key: str | None = Field( None, description="The DeclarativeOAuth Specific optional override to provide the custom `client_id` key name, if required by data-provider.", examples=["my_custom_client_id_key_name"], title="Client ID Key Override", ) - client_secret_key: Optional[str] = Field( + client_secret_key: str | None = Field( None, description="The DeclarativeOAuth Specific optional override to provide the custom `client_secret` key name, if required by data-provider.", examples=["my_custom_client_secret_key_name"], title="Client 
Secret Key Override", ) - scope_key: Optional[str] = Field( + scope_key: str | None = Field( None, description="The DeclarativeOAuth Specific optional override to provide the custom `scope` key name, if required by data-provider.", examples=["my_custom_scope_key_key_name"], title="Scopes Key Override", ) - state_key: Optional[str] = Field( + state_key: str | None = Field( None, description="The DeclarativeOAuth Specific optional override to provide the custom `state` key name, if required by data-provider.", examples=["my_custom_state_key_key_name"], title="State Key Override", ) - auth_code_key: Optional[str] = Field( + auth_code_key: str | None = Field( None, description="The DeclarativeOAuth Specific optional override to provide the custom `code` key name to something like `auth_code` or `custom_auth_code`, if required by data-provider.", examples=["my_custom_auth_code_key_name"], title="Auth Code Key Override", ) - redirect_uri_key: Optional[str] = Field( + redirect_uri_key: str | None = Field( None, description="The DeclarativeOAuth Specific optional override to provide the custom `redirect_uri` key name to something like `callback_uri`, if required by data-provider.", examples=["my_custom_redirect_uri_key_name"], @@ -1018,7 +1018,7 @@ class OAuthConfigSpecification(BaseModel): class Config: extra = Extra.allow - oauth_user_input_from_connector_config_specification: Optional[Dict[str, Any]] = Field( + oauth_user_input_from_connector_config_specification: dict[str, Any] | None = Field( None, description="OAuth specific blob. This is a Json Schema used to validate Json configurations used as input to OAuth.\nMust be a valid non-nested JSON that refers to properties from ConnectorSpecification.connectionSpecification\nusing special annotation 'path_in_connector_config'.\nThese are input values the user is entering through the UI to authenticate to the connector, that might also shared\nas inputs for syncing data via the connector.\nExamples:\nif no connector values is shared during oauth flow, oauth_user_input_from_connector_config_specification=[]\nif connector values such as 'app_id' inside the top level are used to generate the API url for the oauth flow,\n oauth_user_input_from_connector_config_specification={\n app_id: {\n type: string\n path_in_connector_config: ['app_id']\n }\n }\nif connector values such as 'info.app_id' nested inside another object are used to generate the API url for the oauth flow,\n oauth_user_input_from_connector_config_specification={\n app_id: {\n type: string\n path_in_connector_config: ['info', 'app_id']\n }\n }", examples=[ @@ -1032,12 +1032,12 @@ class Config: ], title="OAuth user input", ) - oauth_connector_input_specification: Optional[OauthConnectorInputSpecification] = Field( + oauth_connector_input_specification: OauthConnectorInputSpecification | None = Field( None, description='The DeclarativeOAuth specific blob.\nPertains to the fields defined by the connector relating to the OAuth flow.\n\nInterpolation capabilities:\n- The variables placeholders are declared as `{{my_var}}`.\n- The nested resolution variables like `{{ {{my_nested_var}} }}` is allowed as well.\n\n- The allowed interpolation context is:\n + base64Encoder - encode to `base64`, {{ {{my_var_a}}:{{my_var_b}} | base64Encoder }}\n + base64Decorer - decode from `base64` encoded string, {{ {{my_string_variable_or_string_value}} | base64Decoder }}\n + urlEncoder - encode the input string to URL-like format, {{ https://test.host.com/endpoint | urlEncoder}}\n + urlDecorer - decode the 
input url-encoded string into text format, {{ urlDecoder:https%3A%2F%2Fairbyte.io | urlDecoder}}\n + codeChallengeS256 - get the `codeChallenge` encoded value to provide additional data-provider specific authorisation values, {{ {{state_value}} | codeChallengeS256 }}\n\nExamples:\n - The TikTok Marketing DeclarativeOAuth spec:\n {\n "oauth_connector_input_specification": {\n "type": "object",\n "additionalProperties": false,\n "properties": {\n "consent_url": "https://ads.tiktok.com/marketing_api/auth?{{client_id_key}}={{client_id_value}}&{{redirect_uri_key}}={{ {{redirect_uri_value}} | urlEncoder}}&{{state_key}}={{state_value}}",\n "access_token_url": "https://business-api.tiktok.com/open_api/v1.3/oauth2/access_token/",\n "access_token_params": {\n "{{ auth_code_key }}": "{{ auth_code_value }}",\n "{{ client_id_key }}": "{{ client_id_value }}",\n "{{ client_secret_key }}": "{{ client_secret_value }}"\n },\n "access_token_headers": {\n "Content-Type": "application/json",\n "Accept": "application/json"\n },\n "extract_output": ["data.access_token"],\n "client_id_key": "app_id",\n "client_secret_key": "secret",\n "auth_code_key": "auth_code"\n }\n }\n }', title="DeclarativeOAuth Connector Specification", ) - complete_oauth_output_specification: Optional[Dict[str, Any]] = Field( + complete_oauth_output_specification: dict[str, Any] | None = Field( None, description="OAuth specific blob. This is a Json Schema used to validate Json configurations produced by the OAuth flows as they are\nreturned by the distant OAuth APIs.\nMust be a valid JSON describing the fields to merge back to `ConnectorSpecification.connectionSpecification`.\nFor each field, a special annotation `path_in_connector_config` can be specified to determine where to merge it,\nExamples:\n complete_oauth_output_specification={\n refresh_token: {\n type: string,\n path_in_connector_config: ['credentials', 'refresh_token']\n }\n }", examples=[ @@ -1050,13 +1050,13 @@ class Config: ], title="OAuth output specification", ) - complete_oauth_server_input_specification: Optional[Dict[str, Any]] = Field( + complete_oauth_server_input_specification: dict[str, Any] | None = Field( None, description="OAuth specific blob. This is a Json Schema used to validate Json configurations persisted as Airbyte Server configurations.\nMust be a valid non-nested JSON describing additional fields configured by the Airbyte Instance or Workspace Admins to be used by the\nserver when completing an OAuth flow (typically exchanging an auth code for refresh token).\nExamples:\n complete_oauth_server_input_specification={\n client_id: {\n type: string\n },\n client_secret: {\n type: string\n }\n }", examples=[{"client_id": {"type": "string"}, "client_secret": {"type": "string"}}], title="OAuth input specification", ) - complete_oauth_server_output_specification: Optional[Dict[str, Any]] = Field( + complete_oauth_server_output_specification: dict[str, Any] | None = Field( None, description="OAuth specific blob. This is a Json Schema used to validate Json configurations persisted as Airbyte Server configurations that\nalso need to be merged back into the connector configuration at runtime.\nThis is a subset configuration of `complete_oauth_server_input_specification` that filters fields out to retain only the ones that\nare necessary for the connector to function with OAuth. 
(some fields could be used during oauth flows but not needed afterwards, therefore\nthey would be listed in the `complete_oauth_server_input_specification` but not `complete_oauth_server_output_specification`)\nMust be a valid non-nested JSON describing additional fields configured by the Airbyte Instance or Workspace Admins to be used by the\nconnector when using OAuth flow APIs.\nThese fields are to be merged back to `ConnectorSpecification.connectionSpecification`.\nFor each field, a special annotation `path_in_connector_config` can be specified to determine where to merge it,\nExamples:\n complete_oauth_server_output_specification={\n client_id: {\n type: string,\n path_in_connector_config: ['credentials', 'client_id']\n },\n client_secret: {\n type: string,\n path_in_connector_config: ['credentials', 'client_secret']\n }\n }", examples=[ @@ -1077,44 +1077,44 @@ class Config: class OffsetIncrement(BaseModel): type: Literal["OffsetIncrement"] - page_size: Optional[Union[int, str]] = Field( + page_size: int | str | None = Field( None, description="The number of records to include in each pages.", examples=[100, "{{ config['page_size'] }}"], title="Limit", ) - inject_on_first_request: Optional[bool] = Field( - False, + inject_on_first_request: bool | None = Field( + False, # noqa: FBT003 description="Using the `offset` with value `0` during the first request", title="Inject Offset", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class PageIncrement(BaseModel): type: Literal["PageIncrement"] - page_size: Optional[Union[int, str]] = Field( + page_size: int | str | None = Field( None, description="The number of records to include in each pages.", examples=[100, "100", "{{ config['page_size'] }}"], title="Page Size", ) - start_from_page: Optional[int] = Field( + start_from_page: int | None = Field( 0, description="Index of the first page to request.", examples=[0, 1], title="Start From Page", ) - inject_on_first_request: Optional[bool] = Field( - False, + inject_on_first_request: bool | None = Field( + False, # noqa: FBT003 description="Using the `page number` with value defined by `start_from_page` during the first request", title="Inject Page Number", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class PrimaryKey(BaseModel): - __root__: Union[str, List[str], List[List[str]]] = Field( + __root__: str | list[str] | list[list[str]] = Field( ..., description="The stream field to be used to distinguish unique records. Can either be a single field, an array of fields representing a composite key, or an array of arrays representing a composite key where the fields are nested fields.", examples=["id", ["code", "type"]], @@ -1124,7 +1124,7 @@ class PrimaryKey(BaseModel): class RecordFilter(BaseModel): type: Literal["RecordFilter"] - condition: Optional[str] = Field( + condition: str | None = Field( "", description="The predicate to filter a record. 
Records will be removed if evaluated to False.", examples=[ @@ -1132,7 +1132,7 @@ class RecordFilter(BaseModel): "{{ record.status in ['active', 'expired'] }}", ], ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SchemaNormalization(Enum): @@ -1142,7 +1142,7 @@ class SchemaNormalization(Enum): class RemoveFields(BaseModel): type: Literal["RemoveFields"] - condition: Optional[str] = Field( + condition: str | None = Field( "", description="The predicate to filter a property by a property value. Property will be removed if it is empty OR expression is evaluated to True.,", examples=[ @@ -1152,7 +1152,7 @@ class RemoveFields(BaseModel): "{{ property == 'some_string_to_match' }}", ], ) - field_pointers: List[List[str]] = Field( + field_pointers: list[list[str]] = Field( ..., description="Array of paths defining the field to remove. Each item is an array whose field describe the path of a field to remove.", examples=[["tags"], [["content", "html"], ["content", "plain_text"]]], @@ -1208,7 +1208,7 @@ class LegacySessionTokenAuthenticator(BaseModel): examples=["session"], title="Login Path", ) - session_token: Optional[str] = Field( + session_token: str | None = Field( None, description="Session token to use if using a pre-defined token. Not needed if authenticating with username + password pair", example=["{{ config['session_token'] }}"], @@ -1220,13 +1220,13 @@ class LegacySessionTokenAuthenticator(BaseModel): examples=["id"], title="Response Token Response Key", ) - username: Optional[str] = Field( + username: str | None = Field( None, description="Username used to authenticate and obtain a session token", examples=[" {{ config['username'] }}"], title="Username", ) - password: Optional[str] = Field( + password: str | None = Field( "", description="Password used to authenticate and obtain a session token", examples=["{{ config['password'] }}", ""], @@ -1238,31 +1238,31 @@ class LegacySessionTokenAuthenticator(BaseModel): examples=["user/current"], title="Validate Session Path", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class JsonParser(BaseModel): type: Literal["JsonParser"] - encoding: Optional[str] = "utf-8" + encoding: str | None = "utf-8" class JsonLineParser(BaseModel): type: Literal["JsonLineParser"] - encoding: Optional[str] = "utf-8" + encoding: str | None = "utf-8" class CsvParser(BaseModel): type: Literal["CsvParser"] - encoding: Optional[str] = "utf-8" - delimiter: Optional[str] = "," + encoding: str | None = "utf-8" + delimiter: str | None = "," class AsyncJobStatusMap(BaseModel): - type: Optional[Literal["AsyncJobStatusMap"]] = None - running: List[str] - completed: List[str] - failed: List[str] - timeout: List[str] + type: Literal["AsyncJobStatusMap"] | None = None + running: list[str] + completed: list[str] + failed: list[str] + timeout: list[str] class ValueType(Enum): @@ -1280,19 +1280,19 @@ class WaitTimeFromHeader(BaseModel): examples=["Retry-After"], title="Response Header Name", ) - regex: Optional[str] = Field( + regex: str | None = Field( None, description="Optional regex to apply on the header to extract its value. 
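# Illustrative sketch (not part of the patch): `WaitTimeFromHeader` reads the backoff
# interval from a response header, optionally through a regex with a single capture
# group, and stops the stream when the extracted value exceeds
# `max_waiting_time_in_seconds`. A rough stdlib-only sketch of that behaviour;
# `wait_time_from_header` is a hypothetical helper, not the CDK implementation.
import re


def wait_time_from_header(
    headers: dict[str, str],
    header: str,
    regex: str | None = None,
    max_waiting_time_in_seconds: float | None = None,
) -> float | None:
    raw = headers.get(header)
    if raw is None:
        return None
    if regex:
        match = re.search(regex, raw)
        if not match:
            return None
        raw = match.group(1)
    wait = float(raw)
    if max_waiting_time_in_seconds is not None and wait > max_waiting_time_in_seconds:
        raise RuntimeError(f"Extracted wait time {wait}s exceeds the configured maximum; stopping.")
    return wait


assert wait_time_from_header({"Retry-After": "120"}, "Retry-After", regex=r"([-+]?\d+)") == 120.0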
The regex should define a capture group defining the wait time.", examples=["([-+]?\\d+)"], title="Extraction Regex", ) - max_waiting_time_in_seconds: Optional[float] = Field( + max_waiting_time_in_seconds: float | None = Field( None, description="Given the value extracted from the header is greater than this value, stop the stream.", examples=[3600], title="Max Waiting Time in Seconds", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class WaitUntilTimeFromHeader(BaseModel): @@ -1303,24 +1303,24 @@ class WaitUntilTimeFromHeader(BaseModel): examples=["wait_time"], title="Response Header", ) - min_wait: Optional[Union[float, str]] = Field( + min_wait: float | str | None = Field( None, description="Minimum time to wait before retrying.", examples=[10, "60"], title="Minimum Wait Time", ) - regex: Optional[str] = Field( + regex: str | None = Field( None, description="Optional regex to apply on the header to extract its value. The regex should define a capture group defining the wait time.", examples=["([-+]?\\d+)"], title="Extraction Regex", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ComponentMappingDefinition(BaseModel): type: Literal["ComponentMappingDefinition"] - field_path: List[str] = Field( + field_path: list[str] = Field( ..., description="A list of potentially nested fields indicating the full path where value will be added or updated.", examples=[ @@ -1345,35 +1345,35 @@ class ComponentMappingDefinition(BaseModel): ], title="Value", ) - value_type: Optional[ValueType] = Field( + value_type: ValueType | None = Field( None, description="The expected data type of the value. If omitted, the type will be inferred from the value provided.", title="Value Type", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class StreamConfig(BaseModel): type: Literal["StreamConfig"] - configs_pointer: List[str] = Field( + configs_pointer: list[str] = Field( ..., description="A list of potentially nested fields indicating the full path in source config file where streams configs located.", examples=[["data"], ["data", "streams"], ["data", "{{ parameters.name }}"]], title="Configs Pointer", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ConfigComponentsResolver(BaseModel): type: Literal["ConfigComponentsResolver"] stream_config: StreamConfig - components_mapping: List[ComponentMappingDefinition] - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + components_mapping: list[ComponentMappingDefinition] + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AddedFieldDefinition(BaseModel): type: Literal["AddedFieldDefinition"] - path: List[str] = Field( + path: list[str] = Field( ..., description="List of strings defining the path where to add the value on the record.", examples=[["segment_id"], ["metadata", "segment_id"]], @@ -1389,39 +1389,39 @@ class AddedFieldDefinition(BaseModel): ], title="Value", ) - value_type: Optional[ValueType] = Field( + value_type: ValueType | None = Field( None, description="Type of the value. 
If not specified, the type will be inferred from the value.", title="Value Type", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AddFields(BaseModel): type: Literal["AddFields"] - fields: List[AddedFieldDefinition] = Field( + fields: list[AddedFieldDefinition] = Field( ..., description="List of transformations (path and corresponding value) that will be added to the record.", title="Fields", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ApiKeyAuthenticator(BaseModel): type: Literal["ApiKeyAuthenticator"] - api_token: Optional[str] = Field( + api_token: str | None = Field( None, description="The API key to inject in the request. Fill it in the user inputs.", examples=["{{ config['api_key'] }}", "Token token={{ config['api_key'] }}"], title="API Key", ) - header: Optional[str] = Field( + header: str | None = Field( None, description="The name of the HTTP header that will be set to the API key. This setting is deprecated, use inject_into instead. Header and inject_into can not be defined at the same time.", examples=["Authorization", "Api-Token", "X-Auth-Token"], title="Header Name", ) - inject_into: Optional[RequestOption] = Field( + inject_into: RequestOption | None = Field( None, description="Configure how the API Key will be sent in requests to the source API. Either inject_into or header has to be defined.", examples=[ @@ -1430,26 +1430,26 @@ class ApiKeyAuthenticator(BaseModel): ], title="Inject API Key Into Outgoing HTTP Request", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AuthFlow(BaseModel): - auth_flow_type: Optional[AuthFlowType] = Field( + auth_flow_type: AuthFlowType | None = Field( None, description="The type of auth to use", title="Auth flow type" ) - predicate_key: Optional[List[str]] = Field( + predicate_key: list[str] | None = Field( None, description="JSON path to a field in the connectorSpecification that should exist for the advanced auth to be applicable.", examples=[["credentials", "auth_type"]], title="Predicate key", ) - predicate_value: Optional[str] = Field( + predicate_value: str | None = Field( None, description="Value of the predicate_key fields for the advanced auth to be applicable.", examples=["Oauth"], title="Predicate value", ) - oauth_config_specification: Optional[OAuthConfigSpecification] = None + oauth_config_specification: OAuthConfigSpecification | None = None class DatetimeBasedCursor(BaseModel): @@ -1466,129 +1466,128 @@ class DatetimeBasedCursor(BaseModel): examples=["%Y-%m-%dT%H:%M:%S.%f%z", "%Y-%m-%d", "%s", "%ms", "%s_as_float"], title="Outgoing Datetime Format", ) - start_datetime: Union[str, MinMaxDatetime] = Field( + start_datetime: str | MinMaxDatetime = Field( ..., description="The datetime that determines the earliest record that should be synced.", examples=["2020-01-1T00:00:00Z", "{{ config['start_time'] }}"], title="Start Datetime", ) - cursor_datetime_formats: Optional[List[str]] = Field( + cursor_datetime_formats: list[str] | None = Field( None, description="The possible formats for the cursor field, in order of preference. The first format that matches the cursor field value will be used to parse it. 
If not provided, the `datetime_format` will be used.", title="Cursor Datetime Formats", ) - cursor_granularity: Optional[str] = Field( + cursor_granularity: str | None = Field( None, description="Smallest increment the datetime_format has (ISO 8601 duration) that is used to ensure the start of a slice does not overlap with the end of the previous one, e.g. for %Y-%m-%d the granularity should be P1D, for %Y-%m-%dT%H:%M:%SZ the granularity should be PT1S. Given this field is provided, `step` needs to be provided as well.", examples=["PT1S"], title="Cursor Granularity", ) - end_datetime: Optional[Union[str, MinMaxDatetime]] = Field( + end_datetime: str | MinMaxDatetime | None = Field( None, description="The datetime that determines the last record that should be synced. If not provided, `{{ now_utc() }}` will be used.", examples=["2021-01-1T00:00:00Z", "{{ now_utc() }}", "{{ day_delta(-1) }}"], title="End Datetime", ) - end_time_option: Optional[RequestOption] = Field( + end_time_option: RequestOption | None = Field( None, description="Optionally configures how the end datetime will be sent in requests to the source API.", title="Inject End Time Into Outgoing HTTP Request", ) - is_data_feed: Optional[bool] = Field( + is_data_feed: bool | None = Field( None, description="A data feed API is an API that does not allow filtering and paginates the content from the most recent to the least recent. Given this, the CDK needs to know when to stop paginating and this field will generate a stop condition for pagination.", title="Whether the target API is formatted as a data feed", ) - is_client_side_incremental: Optional[bool] = Field( + is_client_side_incremental: bool | None = Field( None, description="If the target API endpoint does not take cursor values to filter records and returns all records anyway, the connector with this cursor will filter out records locally, and only emit new records from the last sync, hence incremental. This means that all records would be read from the API, but only new records will be emitted to the destination.", title="Whether the target API does not support filtering and returns all data (the cursor filters records in the client instead of the API side)", ) - is_compare_strictly: Optional[bool] = Field( - False, + is_compare_strictly: bool | None = Field( + False, # noqa: FBT003 description="Set to True if the target API does not accept queries where the start time equal the end time.", title="Whether to skip requests if the start time equals the end time", ) - global_substream_cursor: Optional[bool] = Field( - False, + global_substream_cursor: bool | None = Field( + False, # noqa: FBT003 description="This setting optimizes performance when the parent stream has thousands of partitions by storing the cursor as a single value rather than per partition. Notably, the substream state is updated only at the end of the sync, which helps prevent data loss in case of a sync failure. See more info in the [docs](https://docs.airbyte.com/connector-development/config-based/understanding-the-yaml-file/incremental-syncs).", title="Whether to store cursor as one value instead of per partition", ) - lookback_window: Optional[str] = Field( + lookback_window: str | None = Field( None, description="Time interval before the start_datetime to read data for, e.g. 
P1M for looking back one month.", examples=["P1D", "P{{ config['lookback_days'] }}D"], title="Lookback Window", ) - partition_field_end: Optional[str] = Field( + partition_field_end: str | None = Field( None, description="Name of the partition start time field.", examples=["ending_time"], title="Partition Field End", ) - partition_field_start: Optional[str] = Field( + partition_field_start: str | None = Field( None, description="Name of the partition end time field.", examples=["starting_time"], title="Partition Field Start", ) - start_time_option: Optional[RequestOption] = Field( + start_time_option: RequestOption | None = Field( None, description="Optionally configures how the start datetime will be sent in requests to the source API.", title="Inject Start Time Into Outgoing HTTP Request", ) - step: Optional[str] = Field( + step: str | None = Field( None, description="The size of the time window (ISO8601 duration). Given this field is provided, `cursor_granularity` needs to be provided as well.", examples=["P1W", "{{ config['step_increment'] }}"], title="Step", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DefaultErrorHandler(BaseModel): type: Literal["DefaultErrorHandler"] - backoff_strategies: Optional[ - List[ - Union[ - ConstantBackoffStrategy, - CustomBackoffStrategy, - ExponentialBackoffStrategy, - WaitTimeFromHeader, - WaitUntilTimeFromHeader, - ] + backoff_strategies: ( + list[ + ConstantBackoffStrategy + | CustomBackoffStrategy + | ExponentialBackoffStrategy + | WaitTimeFromHeader + | WaitUntilTimeFromHeader ] - ] = Field( + | None + ) = Field( None, description="List of backoff strategies to use to determine how long to wait before retrying a retryable request.", title="Backoff Strategies", ) - max_retries: Optional[int] = Field( + max_retries: int | None = Field( 5, description="The maximum number of time to retry a retryable request before giving up and failing.", examples=[5, 0, 10], title="Max Retry Count", ) - response_filters: Optional[List[HttpResponseFilter]] = Field( + response_filters: list[HttpResponseFilter] | None = Field( None, description="List of response filters to iterate on when deciding how to handle an error. 
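# Illustrative sketch (not part of the patch): `DatetimeBasedCursor` above slices the
# [start_datetime - lookback_window, end_datetime] interval into windows of size `step`,
# separated by `cursor_granularity` so consecutive slices do not overlap. A rough sketch
# using plain `timedelta` values where the manifest uses ISO 8601 durations;
# `generate_slices` is a hypothetical helper with simplified boundary handling.
from datetime import datetime, timedelta, timezone


def generate_slices(
    start: datetime,
    end: datetime,
    step: timedelta,
    cursor_granularity: timedelta,
    lookback_window: timedelta = timedelta(0),
) -> list[tuple[datetime, datetime]]:
    slices = []
    cursor = start - lookback_window
    while cursor <= end:
        # End each slice one granularity unit before the next slice starts.
        slice_end = min(cursor + step - cursor_granularity, end)
        slices.append((cursor, slice_end))
        cursor = slice_end + cursor_granularity
    return slices


start = datetime(2020, 1, 1, tzinfo=timezone.utc)
end = datetime(2020, 1, 3, tzinfo=timezone.utc)
for window in generate_slices(start, end, step=timedelta(days=1), cursor_granularity=timedelta(seconds=1)):
    print(window)  # non-overlapping one-day windows covering the interval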
When using an array of multiple filters, the filters will be applied sequentially and the response will be selected if it matches any of the filter's predicate.", title="Response Filters", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DefaultPaginator(BaseModel): type: Literal["DefaultPaginator"] - pagination_strategy: Union[ - CursorPagination, CustomPaginationStrategy, OffsetIncrement, PageIncrement - ] = Field( + pagination_strategy: ( + CursorPagination | CustomPaginationStrategy | OffsetIncrement | PageIncrement + ) = Field( ..., description="Strategy defining how records are paginated.", title="Pagination Strategy", ) - page_size_option: Optional[RequestOption] = None - page_token_option: Optional[Union[RequestOption, RequestPath]] = None - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + page_size_option: RequestOption | None = None + page_token_option: RequestOption | RequestPath | None = None + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SessionTokenRequestApiKeyAuthenticator(BaseModel): @@ -1612,55 +1611,55 @@ class ListPartitionRouter(BaseModel): examples=["section", "{{ config['section_key'] }}"], title="Current Partition Value Identifier", ) - values: Union[str, List[str]] = Field( + values: str | list[str] = Field( ..., description="The list of attributes being iterated over and used as input for the requests made to the source API.", examples=[["section_a", "section_b", "section_c"], "{{ config['sections'] }}"], title="Partition Values", ) - request_option: Optional[RequestOption] = Field( + request_option: RequestOption | None = Field( None, description="A request option describing where the list value should be injected into and under what field name if applicable.", title="Inject Partition Value Into Outgoing HTTP Request", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class RecordSelector(BaseModel): type: Literal["RecordSelector"] - extractor: Union[CustomRecordExtractor, DpathExtractor] - record_filter: Optional[Union[CustomRecordFilter, RecordFilter]] = Field( + extractor: CustomRecordExtractor | DpathExtractor + record_filter: CustomRecordFilter | RecordFilter | None = Field( None, description="Responsible for filtering records to be emitted by the Source.", title="Record Filter", ) - schema_normalization: Optional[Union[SchemaNormalization, CustomSchemaNormalization]] = Field( + schema_normalization: SchemaNormalization | CustomSchemaNormalization | None = Field( SchemaNormalization.None_, description="Responsible for normalization according to the schema.", title="Schema Normalization", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class GzipParser(BaseModel): type: Literal["GzipParser"] - inner_parser: Union[JsonLineParser, CsvParser, JsonParser] + inner_parser: JsonLineParser | CsvParser | JsonParser class Spec(BaseModel): type: Literal["Spec"] - connection_specification: Dict[str, Any] = Field( + connection_specification: dict[str, Any] = Field( ..., description="A connection specification describing how a the connector can be configured.", title="Connection Specification", ) - documentation_url: Optional[str] = Field( + documentation_url: str | None = Field( None, description="URL of the connector's 
documentation page.", examples=["https://docs.airbyte.com/integrations/sources/dremio"], title="Documentation URL", ) - advanced_auth: Optional[AuthFlow] = Field( + advanced_auth: AuthFlow | None = Field( None, description="Advanced specification for configuring the authentication flow.", title="Advanced Auth", @@ -1669,12 +1668,12 @@ class Spec(BaseModel): class CompositeErrorHandler(BaseModel): type: Literal["CompositeErrorHandler"] - error_handlers: List[Union[CompositeErrorHandler, DefaultErrorHandler]] = Field( + error_handlers: list[CompositeErrorHandler | DefaultErrorHandler] = Field( ..., description="List of error handlers to iterate on to determine how to handle a failed response.", title="Error Handlers", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ZipfileDecoder(BaseModel): @@ -1682,7 +1681,7 @@ class Config: extra = Extra.allow type: Literal["ZipfileDecoder"] - parser: Union[GzipParser, JsonParser, JsonLineParser, CsvParser] = Field( + parser: GzipParser | JsonParser | JsonLineParser | CsvParser = Field( ..., description="Parser to parse the decompressed data from the zipfile(s).", title="Parser", @@ -1691,7 +1690,7 @@ class Config: class CompositeRawDecoder(BaseModel): type: Literal["CompositeRawDecoder"] - parser: Union[GzipParser, JsonParser, JsonLineParser, CsvParser] + parser: GzipParser | JsonParser | JsonLineParser | CsvParser class DeclarativeSource1(BaseModel): @@ -1699,22 +1698,22 @@ class Config: extra = Extra.forbid type: Literal["DeclarativeSource"] - check: Union[CheckStream, CheckDynamicStream] - streams: List[DeclarativeStream] - dynamic_streams: Optional[List[DynamicDeclarativeStream]] = None + check: CheckStream | CheckDynamicStream + streams: list[DeclarativeStream] + dynamic_streams: list[DynamicDeclarativeStream] | None = None version: str = Field( ..., description="The version of the Airbyte CDK used to build and test the source.", ) - schemas: Optional[Schemas] = None - definitions: Optional[Dict[str, Any]] = None - spec: Optional[Spec] = None - concurrency_level: Optional[ConcurrencyLevel] = None - metadata: Optional[Dict[str, Any]] = Field( + schemas: Schemas | None = None + definitions: dict[str, Any] | None = None + spec: Spec | None = None + concurrency_level: ConcurrencyLevel | None = None + metadata: dict[str, Any] | None = Field( None, description="For internal Airbyte use only - DO NOT modify manually. Used by consumers of declarative manifests for storing related metadata.", ) - description: Optional[str] = Field( + description: str | None = Field( None, description="A description of the connector. 
It will be presented on the Source documentation page.", ) @@ -1725,22 +1724,22 @@ class Config: extra = Extra.forbid type: Literal["DeclarativeSource"] - check: Union[CheckStream, CheckDynamicStream] - streams: Optional[List[DeclarativeStream]] = None - dynamic_streams: List[DynamicDeclarativeStream] + check: CheckStream | CheckDynamicStream + streams: list[DeclarativeStream] | None = None + dynamic_streams: list[DynamicDeclarativeStream] version: str = Field( ..., description="The version of the Airbyte CDK used to build and test the source.", ) - schemas: Optional[Schemas] = None - definitions: Optional[Dict[str, Any]] = None - spec: Optional[Spec] = None - concurrency_level: Optional[ConcurrencyLevel] = None - metadata: Optional[Dict[str, Any]] = Field( + schemas: Schemas | None = None + definitions: dict[str, Any] | None = None + spec: Spec | None = None + concurrency_level: ConcurrencyLevel | None = None + metadata: dict[str, Any] | None = Field( None, description="For internal Airbyte use only - DO NOT modify manually. Used by consumers of declarative manifests for storing related metadata.", ) - description: Optional[str] = Field( + description: str | None = Field( None, description="A description of the connector. It will be presented on the Source documentation page.", ) @@ -1750,7 +1749,7 @@ class DeclarativeSource(BaseModel): class Config: extra = Extra.forbid - __root__: Union[DeclarativeSource1, DeclarativeSource2] = Field( + __root__: DeclarativeSource1 | DeclarativeSource2 = Field( ..., description="An API source that extracts data according to its declarative components.", title="DeclarativeSource", @@ -1762,25 +1761,23 @@ class Config: extra = Extra.allow type: Literal["SelectiveAuthenticator"] - authenticator_selection_path: List[str] = Field( + authenticator_selection_path: list[str] = Field( ..., description="Path of the field in config with selected authenticator name", examples=[["auth"], ["auth", "type"]], title="Authenticator Selection Path", ) - authenticators: Dict[ + authenticators: dict[ str, - Union[ - ApiKeyAuthenticator, - BasicHttpAuthenticator, - BearerAuthenticator, - CustomAuthenticator, - OAuthAuthenticator, - JwtAuthenticator, - NoAuth, - SessionTokenAuthenticator, - LegacySessionTokenAuthenticator, - ], + ApiKeyAuthenticator + | BasicHttpAuthenticator + | BearerAuthenticator + | CustomAuthenticator + | OAuthAuthenticator + | JwtAuthenticator + | NoAuth + | SessionTokenAuthenticator + | LegacySessionTokenAuthenticator, ] = Field( ..., description="Authenticators to select from.", @@ -1795,7 +1792,7 @@ class Config: ], title="Authenticators", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DeclarativeStream(BaseModel): @@ -1803,58 +1800,52 @@ class Config: extra = Extra.allow type: Literal["DeclarativeStream"] - retriever: Union[AsyncRetriever, CustomRetriever, SimpleRetriever] = Field( + retriever: AsyncRetriever | CustomRetriever | SimpleRetriever = Field( ..., description="Component used to coordinate how records are extracted across stream slices and request pages.", title="Retriever", ) - incremental_sync: Optional[Union[CustomIncrementalSync, DatetimeBasedCursor]] = Field( + incremental_sync: CustomIncrementalSync | DatetimeBasedCursor | None = Field( None, description="Component used to fetch data incrementally based on a time field in the data.", title="Incremental Sync", ) - name: Optional[str] = Field("", description="The stream name.", 
example=["Users"], title="Name") - primary_key: Optional[PrimaryKey] = Field( + name: str | None = Field("", description="The stream name.", example=["Users"], title="Name") + primary_key: PrimaryKey | None = Field( "", description="The primary key of the stream.", title="Primary Key" ) - schema_loader: Optional[ - Union[ - DynamicSchemaLoader, - InlineSchemaLoader, - JsonFileSchemaLoader, - CustomSchemaLoader, - ] - ] = Field( + schema_loader: ( + DynamicSchemaLoader | InlineSchemaLoader | JsonFileSchemaLoader | CustomSchemaLoader | None + ) = Field( None, description="Component used to retrieve the schema for the current stream.", title="Schema Loader", ) - transformations: Optional[ - List[ - Union[ - AddFields, - CustomTransformation, - RemoveFields, - KeysToLower, - KeysToSnakeCase, - FlattenFields, - DpathFlattenFields, - KeysReplace, - ] + transformations: ( + list[ + AddFields + | CustomTransformation + | RemoveFields + | KeysToLower + | KeysToSnakeCase + | FlattenFields + | DpathFlattenFields + | KeysReplace ] - ] = Field( + | None + ) = Field( None, description="A list of transformations to be applied to each output record.", title="Transformations", ) - state_migrations: Optional[ - List[Union[LegacyToPerPartitionStateMigration, CustomStateMigration]] - ] = Field( - [], - description="Array of state migrations to be applied on the input state", - title="State Migrations", + state_migrations: list[LegacyToPerPartitionStateMigration | CustomStateMigration] | None = ( + Field( + [], + description="Array of state migrations to be applied on the input state", + title="State Migrations", + ) ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SessionTokenAuthenticator(BaseModel): @@ -1876,29 +1867,29 @@ class SessionTokenAuthenticator(BaseModel): ], title="Login Requester", ) - session_token_path: List[str] = Field( + session_token_path: list[str] = Field( ..., description="The path in the response body returned from the login requester to the session token.", examples=[["access_token"], ["result", "token"]], title="Session Token Path", ) - expiration_duration: Optional[str] = Field( + expiration_duration: str | None = Field( None, description="The duration in ISO 8601 duration notation after which the session token expires, starting from the time it was obtained. 
Omitting it will result in the session token being refreshed for every request.", examples=["PT1H", "P1D"], title="Expiration Duration", ) - request_authentication: Union[ - SessionTokenRequestApiKeyAuthenticator, SessionTokenRequestBearerAuthenticator - ] = Field( + request_authentication: ( + SessionTokenRequestApiKeyAuthenticator | SessionTokenRequestBearerAuthenticator + ) = Field( ..., description="Authentication method to use for requests sent to the API, specifying how to inject the session token.", title="Data Request Authentication", ) - decoder: Optional[Union[JsonDecoder, XmlDecoder, CompositeRawDecoder]] = Field( + decoder: JsonDecoder | XmlDecoder | CompositeRawDecoder | None = Field( None, description="Component used to decode the response.", title="Decoder" ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class HttpRequester(BaseModel): @@ -1922,38 +1913,35 @@ class HttpRequester(BaseModel): ], title="URL Path", ) - authenticator: Optional[ - Union[ - ApiKeyAuthenticator, - BasicHttpAuthenticator, - BearerAuthenticator, - CustomAuthenticator, - OAuthAuthenticator, - JwtAuthenticator, - NoAuth, - SessionTokenAuthenticator, - LegacySessionTokenAuthenticator, - SelectiveAuthenticator, - ] - ] = Field( + authenticator: ( + ApiKeyAuthenticator + | BasicHttpAuthenticator + | BearerAuthenticator + | CustomAuthenticator + | OAuthAuthenticator + | JwtAuthenticator + | NoAuth + | SessionTokenAuthenticator + | LegacySessionTokenAuthenticator + | SelectiveAuthenticator + | None + ) = Field( None, description="Authentication method to use for requests sent to the API.", title="Authenticator", ) - error_handler: Optional[ - Union[DefaultErrorHandler, CustomErrorHandler, CompositeErrorHandler] - ] = Field( + error_handler: DefaultErrorHandler | CustomErrorHandler | CompositeErrorHandler | None = Field( None, description="Error handler component that defines how to handle errors.", title="Error Handler", ) - http_method: Optional[HttpMethod] = Field( + http_method: HttpMethod | None = Field( HttpMethod.GET, description="The HTTP method used to fetch data from the source (can be GET or POST).", examples=["GET", "POST"], title="HTTP Method", ) - request_body_data: Optional[Union[str, Dict[str, str]]] = Field( + request_body_data: str | dict[str, str] | None = Field( None, description="Specifies how to populate the body of the request with a non-JSON payload. Plain text will be sent as is, whereas objects will be converted to a urlencoded form.", examples=[ @@ -1961,7 +1949,7 @@ class HttpRequester(BaseModel): ], title="Request Body Payload (Non-JSON)", ) - request_body_json: Optional[Union[str, Dict[str, Any]]] = Field( + request_body_json: str | dict[str, Any] | None = Field( None, description="Specifies how to populate the body of the request with a JSON payload. Can contain nested objects.", examples=[ @@ -1971,13 +1959,13 @@ class HttpRequester(BaseModel): ], title="Request Body JSON Payload", ) - request_headers: Optional[Union[str, Dict[str, str]]] = Field( + request_headers: str | dict[str, str] | None = Field( None, description="Return any non-auth headers. 
Authentication headers will overwrite any overlapping headers returned from this method.", examples=[{"Output-Format": "JSON"}, {"Version": "{{ config['version'] }}"}], title="Request Headers", ) - request_parameters: Optional[Union[str, Dict[str, str]]] = Field( + request_parameters: str | dict[str, str] | None = Field( None, description="Specifies the query parameters that should be set on an outgoing HTTP request given the inputs.", examples=[ @@ -1990,41 +1978,40 @@ class HttpRequester(BaseModel): ], title="Query Parameters", ) - use_cache: Optional[bool] = Field( - False, + use_cache: bool | None = Field( + False, # noqa: FBT003 description="Enables stream requests caching. This field is automatically set by the CDK.", title="Use Cache", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DynamicSchemaLoader(BaseModel): type: Literal["DynamicSchemaLoader"] - retriever: Union[AsyncRetriever, CustomRetriever, SimpleRetriever] = Field( + retriever: AsyncRetriever | CustomRetriever | SimpleRetriever = Field( ..., description="Component used to coordinate how records are extracted across stream slices and request pages.", title="Retriever", ) - schema_transformations: Optional[ - List[ - Union[ - AddFields, - CustomTransformation, - RemoveFields, - KeysToLower, - KeysToSnakeCase, - FlattenFields, - DpathFlattenFields, - KeysReplace, - ] + schema_transformations: ( + list[ + AddFields + | CustomTransformation + | RemoveFields + | KeysToLower + | KeysToSnakeCase + | FlattenFields + | DpathFlattenFields + | KeysReplace ] - ] = Field( + | None + ) = Field( None, description="A list of transformations to be applied to the schema.", title="Schema Transformations", ) schema_type_identifier: SchemaTypeIdentifier - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class ParentStreamConfig(BaseModel): @@ -2044,22 +2031,22 @@ class ParentStreamConfig(BaseModel): examples=["parent_id", "{{ config['parent_partition_field'] }}"], title="Current Parent Key Value Identifier", ) - request_option: Optional[RequestOption] = Field( + request_option: RequestOption | None = Field( None, description="A request option describing where the parent key value should be injected into and under what field name if applicable.", title="Request Option", ) - incremental_dependency: Optional[bool] = Field( - False, + incremental_dependency: bool | None = Field( + False, # noqa: FBT003 description="Indicates whether the parent stream should be read incrementally based on updates in the child stream.", title="Incremental Dependency", ) - extra_fields: Optional[List[List[str]]] = Field( + extra_fields: list[list[str]] | None = Field( None, description="Array of field paths to include as additional fields in the stream slice. Each path is an array of strings representing keys to access fields in the respective parent record. Accessible via `stream_slice.extra_fields`. 
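# Illustrative sketch (not part of the patch): `ParentStreamConfig` above tells a
# `SubstreamPartitionRouter` how each parent record becomes one child-stream partition:
# `parent_key` is read from the parent record, exposed under `partition_field`, and any
# `extra_fields` paths are copied into `stream_slice.extra_fields`, with missing fields
# set to None. The helper and the plain-dict "slice" (including how extras are keyed)
# are simplifications, not CDK classes.
from typing import Any


def build_partitions(
    parent_records: list[dict[str, Any]],
    parent_key: str,
    partition_field: str,
    extra_fields: list[list[str]] | None = None,
) -> list[dict[str, Any]]:
    partitions = []
    for record in parent_records:
        extras: dict[str, Any] = {}
        for path in extra_fields or []:
            node: Any = record
            for key in path:
                node = node.get(key) if isinstance(node, dict) else None
            extras["/".join(path)] = node
        partitions.append({partition_field: record.get(parent_key), "extra_fields": extras})
    return partitions


parents = [{"id": 10, "owner": {"email": "a@example.com"}}, {"id": 11}]
print(build_partitions(parents, parent_key="id", partition_field="parent_id", extra_fields=[["owner", "email"]]))
# [{'parent_id': 10, 'extra_fields': {'owner/email': 'a@example.com'}},
#  {'parent_id': 11, 'extra_fields': {'owner/email': None}}]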
Missing fields are set to `None`.", title="Extra Fields", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SimpleRetriever(BaseModel): @@ -2068,47 +2055,45 @@ class SimpleRetriever(BaseModel): ..., description="Component that describes how to extract records from a HTTP response.", ) - requester: Union[CustomRequester, HttpRequester] = Field( + requester: CustomRequester | HttpRequester = Field( ..., description="Requester component that describes how to prepare HTTP requests to send to the source API.", ) - paginator: Optional[Union[DefaultPaginator, NoPagination]] = Field( + paginator: DefaultPaginator | NoPagination | None = Field( None, description="Paginator component that describes how to navigate through the API's pages.", ) - ignore_stream_slicer_parameters_on_paginated_requests: Optional[bool] = Field( - False, + ignore_stream_slicer_parameters_on_paginated_requests: bool | None = Field( + False, # noqa: FBT003 description="If true, the partition router and incremental request options will be ignored when paginating requests. Request options set directly on the requester will not be ignored.", ) - partition_router: Optional[ - Union[ - CustomPartitionRouter, - ListPartitionRouter, - SubstreamPartitionRouter, - List[Union[CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter]], - ] - ] = Field( + partition_router: ( + CustomPartitionRouter + | ListPartitionRouter + | SubstreamPartitionRouter + | list[CustomPartitionRouter | ListPartitionRouter | SubstreamPartitionRouter] + | None + ) = Field( [], description="PartitionRouter component that describes how to partition the stream, enabling incremental syncs and checkpointing.", title="Partition Router", ) - decoder: Optional[ - Union[ - CustomDecoder, - JsonDecoder, - JsonlDecoder, - IterableDecoder, - XmlDecoder, - GzipJsonDecoder, - CompositeRawDecoder, - ZipfileDecoder, - ] - ] = Field( + decoder: ( + CustomDecoder + | JsonDecoder + | JsonlDecoder + | IterableDecoder + | XmlDecoder + | GzipJsonDecoder + | CompositeRawDecoder + | ZipfileDecoder + | None + ) = Field( None, description="Component decoding the response so records can be extracted.", title="Decoder", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class AsyncRetriever(BaseModel): @@ -2120,110 +2105,107 @@ class AsyncRetriever(BaseModel): status_mapping: AsyncJobStatusMap = Field( ..., description="Async Job Status to Airbyte CDK Async Job Status mapping." ) - status_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field( + status_extractor: CustomRecordExtractor | DpathExtractor = Field( ..., description="Responsible for fetching the actual status of the async job." 
) - urls_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field( + urls_extractor: CustomRecordExtractor | DpathExtractor = Field( ..., description="Responsible for fetching the final result `urls` provided by the completed / finished / ready async job.", ) - download_extractor: Optional[ - Union[CustomRecordExtractor, DpathExtractor, ResponseToFileExtractor] - ] = Field(None, description="Responsible for fetching the records from provided urls.") - creation_requester: Union[CustomRequester, HttpRequester] = Field( + download_extractor: CustomRecordExtractor | DpathExtractor | ResponseToFileExtractor | None = ( + Field(None, description="Responsible for fetching the records from provided urls.") + ) + creation_requester: CustomRequester | HttpRequester = Field( ..., description="Requester component that describes how to prepare HTTP requests to send to the source API to create the async server-side job.", ) - polling_requester: Union[CustomRequester, HttpRequester] = Field( + polling_requester: CustomRequester | HttpRequester = Field( ..., description="Requester component that describes how to prepare HTTP requests to send to the source API to fetch the status of the running async job.", ) - url_requester: Optional[Union[CustomRequester, HttpRequester]] = Field( + url_requester: CustomRequester | HttpRequester | None = Field( None, description="Requester component that describes how to prepare HTTP requests to send to the source API to extract the url from polling response by the completed async job.", ) - download_requester: Union[CustomRequester, HttpRequester] = Field( + download_requester: CustomRequester | HttpRequester = Field( ..., description="Requester component that describes how to prepare HTTP requests to send to the source API to download the data provided by the completed async job.", ) - download_paginator: Optional[Union[DefaultPaginator, NoPagination]] = Field( + download_paginator: DefaultPaginator | NoPagination | None = Field( None, description="Paginator component that describes how to navigate through the API's pages during download.", ) - abort_requester: Optional[Union[CustomRequester, HttpRequester]] = Field( + abort_requester: CustomRequester | HttpRequester | None = Field( None, description="Requester component that describes how to prepare HTTP requests to send to the source API to abort a job once it is timed out from the source's perspective.", ) - delete_requester: Optional[Union[CustomRequester, HttpRequester]] = Field( + delete_requester: CustomRequester | HttpRequester | None = Field( None, description="Requester component that describes how to prepare HTTP requests to send to the source API to delete a job once the records are extracted.", ) - partition_router: Optional[ - Union[ - CustomPartitionRouter, - ListPartitionRouter, - SubstreamPartitionRouter, - List[Union[CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter]], - ] - ] = Field( + partition_router: ( + CustomPartitionRouter + | ListPartitionRouter + | SubstreamPartitionRouter + | list[CustomPartitionRouter | ListPartitionRouter | SubstreamPartitionRouter] + | None + ) = Field( [], description="PartitionRouter component that describes how to partition the stream, enabling incremental syncs and checkpointing.", title="Partition Router", ) - decoder: Optional[ - Union[ - CustomDecoder, - JsonDecoder, - JsonlDecoder, - IterableDecoder, - XmlDecoder, - GzipJsonDecoder, - CompositeRawDecoder, - ZipfileDecoder, - ] - ] = Field( + decoder: ( + CustomDecoder + | JsonDecoder + | 
JsonlDecoder + | IterableDecoder + | XmlDecoder + | GzipJsonDecoder + | CompositeRawDecoder + | ZipfileDecoder + | None + ) = Field( None, description="Component decoding the response so records can be extracted.", title="Decoder", ) - download_decoder: Optional[ - Union[ - CustomDecoder, - JsonDecoder, - JsonlDecoder, - IterableDecoder, - XmlDecoder, - GzipJsonDecoder, - CompositeRawDecoder, - ZipfileDecoder, - ] - ] = Field( + download_decoder: ( + CustomDecoder + | JsonDecoder + | JsonlDecoder + | IterableDecoder + | XmlDecoder + | GzipJsonDecoder + | CompositeRawDecoder + | ZipfileDecoder + | None + ) = Field( None, description="Component decoding the download response so records can be extracted.", title="Download Decoder", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class SubstreamPartitionRouter(BaseModel): type: Literal["SubstreamPartitionRouter"] - parent_stream_configs: List[ParentStreamConfig] = Field( + parent_stream_configs: list[ParentStreamConfig] = Field( ..., description="Specifies which parent streams are being iterated over and how parent records should be used to partition the child stream data set.", title="Parent Stream Configs", ) - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class HttpComponentsResolver(BaseModel): type: Literal["HttpComponentsResolver"] - retriever: Union[AsyncRetriever, CustomRetriever, SimpleRetriever] = Field( + retriever: AsyncRetriever | CustomRetriever | SimpleRetriever = Field( ..., description="Component used to coordinate how records are extracted across stream slices and request pages.", title="Retriever", ) - components_mapping: List[ComponentMappingDefinition] - parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + components_mapping: list[ComponentMappingDefinition] + parameters: dict[str, Any] | None = Field(None, alias="$parameters") class DynamicDeclarativeStream(BaseModel): @@ -2231,7 +2213,7 @@ class DynamicDeclarativeStream(BaseModel): stream_template: DeclarativeStream = Field( ..., description="Reference to the stream template.", title="Stream Template" ) - components_resolver: Union[HttpComponentsResolver, ConfigComponentsResolver] = Field( + components_resolver: HttpComponentsResolver | ConfigComponentsResolver = Field( ..., description="Component resolve and populates stream templates with components values.", title="Components Resolver", diff --git a/airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py b/airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py index 8a6638fad..2b805b50e 100644 --- a/airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py +++ b/airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py @@ -5,9 +5,8 @@ import sys from collections.abc import Mapping from types import ModuleType -from typing import Any, cast +from typing import Any, Literal, cast -from typing_extensions import Literal ChecksumType = Literal["md5", "sha256"] CHECKSUM_FUNCTIONS = { @@ -107,10 +106,10 @@ def get_registered_components_module( # Check for `components` or `source_declarative_manifest.components`. 
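# Illustrative sketch (not part of the patch): `custom_code_compiler` keeps
# CHECKSUM_FUNCTIONS keyed by the `ChecksumType` literals ("md5", "sha256"), presumably
# so injected custom-components source text can be verified against an expected digest
# before it is registered as a module. A generic stdlib sketch of that kind of check;
# the names below are illustrative, not the CDK's API.
import hashlib

CHECKSUMS = {"md5": hashlib.md5, "sha256": hashlib.sha256}


def verify_source_text(source_code: str, expected_digest: str, algorithm: str = "sha256") -> None:
    actual = CHECKSUMS[algorithm](source_code.encode("utf-8")).hexdigest()
    if actual != expected_digest:
        raise ValueError(f"{algorithm} checksum mismatch: expected {expected_digest}, got {actual}")


code_text = "def my_custom_extractor():\n    return []\n"
verify_source_text(code_text, hashlib.sha256(code_text.encode("utf-8")).hexdigest())  # passes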
if SDM_COMPONENTS_MODULE_NAME in sys.modules: - return cast(ModuleType, sys.modules.get(SDM_COMPONENTS_MODULE_NAME)) + return cast(ModuleType, sys.modules.get(SDM_COMPONENTS_MODULE_NAME)) # noqa: TC006 if COMPONENTS_MODULE_NAME in sys.modules: - return cast(ModuleType, sys.modules.get(COMPONENTS_MODULE_NAME)) + return cast(ModuleType, sys.modules.get(COMPONENTS_MODULE_NAME)) # noqa: TC006 # Could not find module 'components' in `sys.modules` # and INJECTED_COMPONENTS_PY was not provided in config. diff --git a/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py b/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py index f2719bb14..73ce6b526 100644 --- a/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py +++ b/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py @@ -3,8 +3,9 @@ # import copy -import typing -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any + PARAMETERS_STR = "$parameters" @@ -149,7 +150,7 @@ def propagate_types_and_parameters( ) if excluded_parameter: current_parameters[field_name] = excluded_parameter - elif isinstance(field_value, typing.List): + elif isinstance(field_value, list): # We exclude propagating a parameter that matches the current field name because that would result in an infinite cycle excluded_parameter = current_parameters.pop(field_name, None) for i, element in enumerate(field_value): diff --git a/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py b/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py index 045ea9a2c..9fb80d0c4 100644 --- a/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py +++ b/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py @@ -3,13 +3,15 @@ # import re -from typing import Any, Mapping, Set, Tuple, Union +from collections.abc import Mapping +from typing import Any from airbyte_cdk.sources.declarative.parsers.custom_exceptions import ( CircularReferenceException, UndefinedReferenceException, ) + REF_TAG = "$ref" @@ -106,7 +108,7 @@ def preprocess_manifest(self, manifest: Mapping[str, Any]) -> Mapping[str, Any]: """ return self._evaluate_node(manifest, manifest, set()) # type: ignore[no-any-return] - def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: Set[Any]) -> Any: + def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: set[Any]) -> Any: # noqa: ANN401 if isinstance(node, dict): evaluated_dict = { k: self._evaluate_node(v, manifest, visited) @@ -118,24 +120,21 @@ def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: Set[An evaluated_ref = self._evaluate_node(node[REF_TAG], manifest, visited) if not isinstance(evaluated_ref, dict): return evaluated_ref - else: - # The values defined on the component take precedence over the reference values - return evaluated_ref | evaluated_dict - else: - return evaluated_dict - elif isinstance(node, list): + # The values defined on the component take precedence over the reference values + return evaluated_ref | evaluated_dict + return evaluated_dict + if isinstance(node, list): return [self._evaluate_node(v, manifest, visited) for v in node] - elif self._is_ref(node): + if self._is_ref(node): if node in visited: raise CircularReferenceException(node) visited.add(node) ret = self._evaluate_node(self._lookup_ref_value(node, manifest), manifest, visited) visited.remove(node) return ret - else: - return node + return node - def 
_lookup_ref_value(self, ref: str, manifest: Mapping[str, Any]) -> Any: + def _lookup_ref_value(self, ref: str, manifest: Mapping[str, Any]) -> Any: # noqa: ANN401 ref_match = re.match(r"#/(.*)", ref) if not ref_match: raise ValueError(f"Invalid reference format {ref}") @@ -143,10 +142,10 @@ def _lookup_ref_value(self, ref: str, manifest: Mapping[str, Any]) -> Any: path = ref_match.groups()[0] return self._read_ref_value(path, manifest) except (AttributeError, KeyError, IndexError): - raise UndefinedReferenceException(path, ref) + raise UndefinedReferenceException(path, ref) # noqa: B904 @staticmethod - def _is_ref(node: Any) -> bool: + def _is_ref(node: Any) -> bool: # noqa: ANN401 return isinstance(node, str) and node.startswith("#/") @staticmethod @@ -154,7 +153,7 @@ def _is_ref_key(key: str) -> bool: return bool(key == REF_TAG) @staticmethod - def _read_ref_value(ref: str, manifest_node: Mapping[str, Any]) -> Any: + def _read_ref_value(ref: str, manifest_node: Mapping[str, Any]) -> Any: # noqa: ANN401 """ Read the value at the referenced location of the manifest. @@ -185,7 +184,7 @@ def _read_ref_value(ref: str, manifest_node: Mapping[str, Any]) -> Any: return manifest_node -def _parse_path(ref: str) -> Tuple[Union[str, int], str]: +def _parse_path(ref: str) -> tuple[str | int, str]: """ Return the next path component, together with the rest of the path. diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index c39ae0a68..dac53a066 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -8,31 +8,23 @@ import importlib import inspect import re -import sys +from collections.abc import Callable, Mapping, MutableMapping from functools import partial from typing import ( Any, - Callable, - Dict, - List, - Mapping, - MutableMapping, - Optional, - Type, - Union, get_args, get_origin, get_type_hints, ) from isodate import parse_duration -from pydantic.v1 import BaseModel +from pydantic.v1 import BaseModel # noqa: TC002 from airbyte_cdk.models import FailureType, Level -from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager +from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager # noqa: TC001 from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker -from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository +from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository # noqa: TC001 from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator, JwtAuthenticator from airbyte_cdk.sources.declarative.auth.declarative_authenticator import ( @@ -189,7 +181,7 @@ CustomRetriever as CustomRetrieverModel, ) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( - CustomSchemaLoader as CustomSchemaLoader, + CustomSchemaLoader as CustomSchemaLoader, # noqa: PLC0414 ) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( CustomSchemaNormalization as CustomSchemaNormalizationModel, @@ -364,10 +356,6 @@ from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( ZipfileDecoder as ZipfileDecoderModel, ) -from 
airbyte_cdk.sources.declarative.parsers.custom_code_compiler import ( - COMPONENTS_MODULE_NAME, - SDM_COMPONENTS_MODULE_NAME, -) from airbyte_cdk.sources.declarative.partition_routers import ( CartesianProductStreamSlicer, ListPartitionRouter, @@ -435,7 +423,7 @@ TypesMap, ) from airbyte_cdk.sources.declarative.spec import Spec -from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer +from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer # noqa: TC001 from airbyte_cdk.sources.declarative.transformations import ( AddFields, RecordTransformation, @@ -468,9 +456,10 @@ DateTimeStreamStateConverter, ) from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction -from airbyte_cdk.sources.types import Config +from airbyte_cdk.sources.types import Config # noqa: TC001 from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer + ComponentDefinition = Mapping[str, Any] SCHEMA_TRANSFORMER_TYPE_MAPPING = { @@ -479,18 +468,18 @@ } -class ModelToComponentFactory: +class ModelToComponentFactory: # noqa: PLR0904 EPOCH_DATETIME_FORMAT = "%s" - def __init__( + def __init__( # noqa: ANN204 self, - limit_pages_fetched_per_slice: Optional[int] = None, - limit_slices_fetched: Optional[int] = None, - emit_connector_builder_messages: bool = False, - disable_retries: bool = False, - disable_cache: bool = False, - disable_resumable_full_refresh: bool = False, - message_repository: Optional[MessageRepository] = None, + limit_pages_fetched_per_slice: int | None = None, + limit_slices_fetched: int | None = None, + emit_connector_builder_messages: bool = False, # noqa: FBT001, FBT002 + disable_retries: bool = False, # noqa: FBT001, FBT002 + disable_cache: bool = False, # noqa: FBT001, FBT002 + disable_resumable_full_refresh: bool = False, # noqa: FBT001, FBT002 + message_repository: MessageRepository | None = None, ): self._init_mappings() self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice @@ -504,7 +493,7 @@ def __init__( ) def _init_mappings(self) -> None: - self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[Type[BaseModel], Callable[..., Any]] = { + self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[type[BaseModel], Callable[..., Any]] = { AddedFieldDefinitionModel: self.create_added_field_definition, AddFieldsModel: self.create_add_fields, ApiKeyAuthenticatorModel: self.create_api_key_authenticator, @@ -595,11 +584,11 @@ def _init_mappings(self) -> None: def create_component( self, - model_type: Type[BaseModel], + model_type: type[BaseModel], component_definition: ComponentDefinition, config: Config, - **kwargs: Any, - ) -> Any: + **kwargs: Any, # noqa: ANN401 + ) -> Any: # noqa: ANN401 """ Takes a given Pydantic model type and Mapping representing a component definition and creates a declarative component and subcomponents which will be used at runtime. 
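# Illustrative sketch (not part of the patch): the factory pattern in this module is a
# mapping from Pydantic model classes to constructor callables; `create_component`
# parses the raw component definition into the model, checks the resulting type, then
# dispatches to the registered constructor. A stripped-down sketch of that flow; the
# toy model and constructor below are simplified stand-ins, not the generated schema
# classes or the CDK's constructors.
from collections.abc import Callable, Mapping
from typing import Any

from pydantic.v1 import BaseModel


class AddedFieldDefinitionModel(BaseModel):  # stand-in for a generated schema model
    type: str
    path: list[str]
    value: str


def create_added_field_definition(model: AddedFieldDefinitionModel, config: Mapping[str, Any]) -> dict[str, Any]:
    return {"path": model.path, "value": model.value}  # toy runtime component


CONSTRUCTORS: Mapping[type[BaseModel], Callable[..., Any]] = {
    AddedFieldDefinitionModel: create_added_field_definition,
}


def create_component(model_type: type[BaseModel], definition: Mapping[str, Any], config: Mapping[str, Any]) -> Any:
    parsed = model_type.parse_obj(definition)
    if not isinstance(parsed, model_type):
        raise ValueError(f"Expected {model_type.__name__}, got {type(parsed).__name__}")
    return CONSTRUCTORS[model_type](parsed, config)


definition = {"type": "AddedFieldDefinition", "path": ["segment_id"], "value": "{{ record['id'] }}"}
print(create_component(AddedFieldDefinitionModel, definition, config={}))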
This is done by first parsing the mapping into a Pydantic model and then creating @@ -620,7 +609,7 @@ def create_component( declarative_component_model = model_type.parse_obj(component_definition) if not isinstance(declarative_component_model, model_type): - raise ValueError( + raise ValueError( # noqa: TRY004 f"Expected {model_type.__name__} component, but received {declarative_component_model.__class__.__name__}" ) @@ -628,7 +617,7 @@ def create_component( model=declarative_component_model, config=config, **kwargs ) - def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs: Any) -> Any: + def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs: Any) -> Any: # noqa: ANN401 if model.__class__ not in self.PYDANTIC_MODEL_TO_CONSTRUCTOR: raise ValueError( f"{model.__class__} with attributes {model} is not a valid component type" @@ -640,7 +629,9 @@ def _create_component_from_model(self, model: BaseModel, config: Config, **kwarg @staticmethod def create_added_field_definition( - model: AddedFieldDefinitionModel, config: Config, **kwargs: Any + model: AddedFieldDefinitionModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> AddedFieldDefinition: interpolated_value = InterpolatedString.create( model.value, parameters=model.parameters or {} @@ -652,7 +643,7 @@ def create_added_field_definition( parameters=model.parameters or {}, ) - def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any) -> AddFields: + def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any) -> AddFields: # noqa: ANN401, ARG002 added_field_definitions = [ self._create_component_from_model( model=added_field_definition_model, @@ -666,33 +657,48 @@ def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any return AddFields(fields=added_field_definitions, parameters=model.parameters or {}) def create_keys_to_lower_transformation( - self, model: KeysToLowerModel, config: Config, **kwargs: Any + self, + model: KeysToLowerModel, # noqa: ARG002 + config: Config, # noqa: ARG002 + **kwargs: Any, # noqa: ANN401, ARG002 ) -> KeysToLowerTransformation: return KeysToLowerTransformation() def create_keys_to_snake_transformation( - self, model: KeysToSnakeCaseModel, config: Config, **kwargs: Any + self, + model: KeysToSnakeCaseModel, # noqa: ARG002 + config: Config, # noqa: ARG002 + **kwargs: Any, # noqa: ANN401, ARG002 ) -> KeysToSnakeCaseTransformation: return KeysToSnakeCaseTransformation() def create_keys_replace_transformation( - self, model: KeysReplaceModel, config: Config, **kwargs: Any + self, + model: KeysReplaceModel, + config: Config, # noqa: ARG002 + **kwargs: Any, # noqa: ANN401, ARG002 ) -> KeysReplaceTransformation: return KeysReplaceTransformation( old=model.old, new=model.new, parameters=model.parameters or {} ) def create_flatten_fields( - self, model: FlattenFieldsModel, config: Config, **kwargs: Any + self, + model: FlattenFieldsModel, + config: Config, # noqa: ARG002 + **kwargs: Any, # noqa: ANN401, ARG002 ) -> FlattenFields: return FlattenFields( flatten_lists=model.flatten_lists if model.flatten_lists is not None else True ) def create_dpath_flatten_fields( - self, model: DpathFlattenFieldsModel, config: Config, **kwargs: Any + self, + model: DpathFlattenFieldsModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> DpathFlattenFields: - model_field_path: List[Union[InterpolatedString, str]] = [x for x in model.field_path] + model_field_path: 
list[InterpolatedString | str] = [x for x in model.field_path] return DpathFlattenFields( config=config, field_path=model_field_path, @@ -703,7 +709,7 @@ def create_dpath_flatten_fields( ) @staticmethod - def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[Type[Any]]: + def _json_schema_type_name_to_type(value_type: ValueType | None) -> type[Any] | None: if not value_type: return None names_to_types = { @@ -718,8 +724,8 @@ def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[ def create_api_key_authenticator( model: ApiKeyAuthenticatorModel, config: Config, - token_provider: Optional[TokenProvider] = None, - **kwargs: Any, + token_provider: TokenProvider | None = None, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> ApiKeyAuthenticator: if model.inject_into is None and model.header is None: raise ValueError( @@ -731,7 +737,7 @@ def create_api_key_authenticator( "inject_into and header cannot be set both for ApiKeyAuthenticator - remove the deprecated header option" ) - if token_provider is not None and model.api_token != "": + if token_provider is not None and model.api_token != "": # noqa: PLC1901 raise ValueError( "If token_provider is set, api_token is ignored and has to be set to empty string." ) @@ -766,20 +772,20 @@ def create_api_key_authenticator( def create_legacy_to_per_partition_state_migration( self, - model: LegacyToPerPartitionStateMigrationModel, + model: LegacyToPerPartitionStateMigrationModel, # noqa: ARG002 config: Mapping[str, Any], declarative_stream: DeclarativeStreamModel, ) -> LegacyToPerPartitionStateMigration: retriever = declarative_stream.retriever if not isinstance(retriever, SimpleRetrieverModel): - raise ValueError( + raise ValueError( # noqa: TRY004 f"LegacyToPerPartitionStateMigrations can only be applied on a DeclarativeStream with a SimpleRetriever. Got {type(retriever)}" ) partition_router = retriever.partition_router - if not isinstance( + if not isinstance( # noqa: UP038 partition_router, (SubstreamPartitionRouterModel, CustomPartitionRouterModel) ): - raise ValueError( + raise ValueError( # noqa: TRY004 f"LegacyToPerPartitionStateMigrations can only be applied on a SimpleRetriever with a Substream partition router. 
Got {type(partition_router)}" ) if not hasattr(partition_router, "parent_stream_configs"): @@ -800,8 +806,12 @@ def create_legacy_to_per_partition_state_migration( ) def create_session_token_authenticator( - self, model: SessionTokenAuthenticatorModel, config: Config, name: str, **kwargs: Any - ) -> Union[ApiKeyAuthenticator, BearerAuthenticator]: + self, + model: SessionTokenAuthenticatorModel, + config: Config, + name: str, + **kwargs: Any, # noqa: ANN401, ARG002 + ) -> ApiKeyAuthenticator | BearerAuthenticator: decoder = ( self._create_component_from_model(model=model.decoder, config=config) if model.decoder @@ -829,20 +839,21 @@ def create_session_token_authenticator( config, token_provider=token_provider, ) - else: - return ModelToComponentFactory.create_api_key_authenticator( - ApiKeyAuthenticatorModel( - type="ApiKeyAuthenticator", - api_token="", - inject_into=model.request_authentication.inject_into, - ), # type: ignore # $parameters and headers default to None - config=config, - token_provider=token_provider, - ) + return ModelToComponentFactory.create_api_key_authenticator( + ApiKeyAuthenticatorModel( + type="ApiKeyAuthenticator", + api_token="", + inject_into=model.request_authentication.inject_into, + ), # type: ignore # $parameters and headers default to None + config=config, + token_provider=token_provider, + ) @staticmethod def create_basic_http_authenticator( - model: BasicHttpAuthenticatorModel, config: Config, **kwargs: Any + model: BasicHttpAuthenticatorModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> BasicHttpAuthenticator: return BasicHttpAuthenticator( password=model.password or "", @@ -855,10 +866,10 @@ def create_basic_http_authenticator( def create_bearer_authenticator( model: BearerAuthenticatorModel, config: Config, - token_provider: Optional[TokenProvider] = None, - **kwargs: Any, + token_provider: TokenProvider | None = None, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> BearerAuthenticator: - if token_provider is not None and model.api_token != "": + if token_provider is not None and model.api_token != "": # noqa: PLC1901 raise ValueError( "If token_provider is set, api_token is ignored and has to be set to empty string." 
) @@ -877,17 +888,22 @@ def create_bearer_authenticator( ) @staticmethod - def create_check_stream(model: CheckStreamModel, config: Config, **kwargs: Any) -> CheckStream: + def create_check_stream(model: CheckStreamModel, config: Config, **kwargs: Any) -> CheckStream: # noqa: ANN401, ARG004 return CheckStream(stream_names=model.stream_names, parameters={}) @staticmethod def create_check_dynamic_stream( - model: CheckDynamicStreamModel, config: Config, **kwargs: Any + model: CheckDynamicStreamModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> CheckDynamicStream: return CheckDynamicStream(stream_count=model.stream_count, parameters={}) def create_composite_error_handler( - self, model: CompositeErrorHandlerModel, config: Config, **kwargs: Any + self, + model: CompositeErrorHandlerModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> CompositeErrorHandler: error_handlers = [ self._create_component_from_model(model=error_handler_model, config=config) @@ -899,7 +915,9 @@ def create_composite_error_handler( @staticmethod def create_concurrency_level( - model: ConcurrencyLevelModel, config: Config, **kwargs: Any + model: ConcurrencyLevelModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> ConcurrencyLevel: return ConcurrencyLevel( default_concurrency=model.default_concurrency, @@ -908,16 +926,16 @@ def create_concurrency_level( parameters={}, ) - def create_concurrent_cursor_from_datetime_based_cursor( + def create_concurrent_cursor_from_datetime_based_cursor( # noqa: PLR0914 self, state_manager: ConnectorStateManager, - model_type: Type[BaseModel], + model_type: type[BaseModel], component_definition: ComponentDefinition, stream_name: str, - stream_namespace: Optional[str], + stream_namespace: str | None, config: Config, stream_state: MutableMapping[str, Any], - **kwargs: Any, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> ConcurrentCursor: component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: @@ -928,7 +946,7 @@ def create_concurrent_cursor_from_datetime_based_cursor( datetime_based_cursor_model = model_type.parse_obj(component_definition) if not isinstance(datetime_based_cursor_model, DatetimeBasedCursorModel): - raise ValueError( + raise ValueError( # noqa: TRY004 f"Expected {model_type.__name__} component, but received {datetime_based_cursor_model.__class__.__name__}" ) @@ -982,7 +1000,7 @@ def create_concurrent_cursor_from_datetime_based_cursor( cursor_granularity=cursor_granularity, ) - start_date_runtime_value: Union[InterpolatedString, str, MinMaxDatetime] + start_date_runtime_value: InterpolatedString | str | MinMaxDatetime if isinstance(datetime_based_cursor_model.start_datetime, MinMaxDatetimeModel): start_date_runtime_value = self.create_min_max_datetime( model=datetime_based_cursor_model.start_datetime, config=config @@ -990,7 +1008,7 @@ def create_concurrent_cursor_from_datetime_based_cursor( else: start_date_runtime_value = datetime_based_cursor_model.start_datetime - end_date_runtime_value: Optional[Union[InterpolatedString, str, MinMaxDatetime]] + end_date_runtime_value: InterpolatedString | str | MinMaxDatetime | None if isinstance(datetime_based_cursor_model.end_datetime, MinMaxDatetimeModel): end_date_runtime_value = self.create_min_max_datetime( model=datetime_based_cursor_model.end_datetime, config=config @@ -1066,7 +1084,9 @@ def create_concurrent_cursor_from_datetime_based_cursor( @staticmethod def create_constant_backoff_strategy( - model: 
ConstantBackoffStrategyModel, config: Config, **kwargs: Any + model: ConstantBackoffStrategyModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> ConstantBackoffStrategy: return ConstantBackoffStrategy( backoff_time_in_seconds=model.backoff_time_in_seconds, @@ -1075,7 +1095,11 @@ def create_constant_backoff_strategy( ) def create_cursor_pagination( - self, model: CursorPaginationModel, config: Config, decoder: Decoder, **kwargs: Any + self, + model: CursorPaginationModel, + config: Config, + decoder: Decoder, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> CursorPaginationStrategy: if isinstance(decoder, PaginationDecoderDecorator): inner_decoder = decoder.decoder @@ -1099,7 +1123,7 @@ def create_cursor_pagination( parameters=model.parameters or {}, ) - def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> Any: + def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> Any: # noqa: ANN401 """ Generically creates a custom component based on the model type and a class_name reference to the custom Python class being instantiated. Only the model's additional properties that match the custom class definition are passed to the constructor @@ -1154,7 +1178,7 @@ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> kwargs = { class_field: model_args[class_field] - for class_field in component_fields.keys() + for class_field in component_fields.keys() # noqa: SIM118 if class_field in model_args } return custom_component_class(**kwargs) @@ -1162,7 +1186,7 @@ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> @staticmethod def _get_class_from_fully_qualified_class_name( full_qualified_class_name: str, - ) -> Any: + ) -> Any: # noqa: ANN401 """Get a class from its fully qualified name. 
If a custom components module is needed, we assume it is already registered - probably @@ -1194,7 +1218,7 @@ def _get_class_from_fully_qualified_class_name( ) from e @staticmethod - def _derive_component_type_from_type_hints(field_type: Any) -> Optional[str]: + def _derive_component_type_from_type_hints(field_type: Any) -> str | None: # noqa: ANN401 interface = field_type while True: origin = get_origin(interface) @@ -1211,22 +1235,25 @@ def _derive_component_type_from_type_hints(field_type: Any) -> Optional[str]: return None @staticmethod - def is_builtin_type(cls: Optional[Type[Any]]) -> bool: + def is_builtin_type(cls: type[Any] | None) -> bool: # noqa: PLW0211 if not cls: return False return cls.__module__ == "builtins" @staticmethod - def _extract_missing_parameters(error: TypeError) -> List[str]: + def _extract_missing_parameters(error: TypeError) -> list[str]: parameter_search = re.search(r"keyword-only.*:\s(.*)", str(error)) if parameter_search: return re.findall(r"\'(.+?)\'", parameter_search.group(1)) - else: - return [] + return [] def _create_nested_component( - self, model: Any, model_field: str, model_value: Any, config: Config - ) -> Any: + self, + model: Any, # noqa: ANN401 + model_field: str, # noqa: ARG002 + model_value: Any, # noqa: ANN401 + config: Config, + ) -> Any: # noqa: ANN401 type_name = model_value.get("type", None) if not type_name: # If no type is specified, we can assume this is a dictionary object which can be returned instead of a subcomponent @@ -1256,16 +1283,14 @@ def _create_nested_component( except TypeError as error: missing_parameters = self._extract_missing_parameters(error) if missing_parameters: - raise ValueError( + raise ValueError( # noqa: B904 f"Error creating component '{type_name}' with parent custom component {model.class_name}: Please provide " + ", ".join( - ( - f"{type_name}.$parameters.{parameter}" - for parameter in missing_parameters - ) + f"{type_name}.$parameters.{parameter}" + for parameter in missing_parameters ) ) - raise TypeError( + raise TypeError( # noqa: B904 f"Error creating component '{type_name}' with parent custom component {model.class_name}: {error}" ) else: @@ -1274,18 +1299,21 @@ def _create_nested_component( ) @staticmethod - def _is_component(model_value: Any) -> bool: + def _is_component(model_value: Any) -> bool: # noqa: ANN401 return isinstance(model_value, dict) and model_value.get("type") is not None def create_datetime_based_cursor( - self, model: DatetimeBasedCursorModel, config: Config, **kwargs: Any + self, + model: DatetimeBasedCursorModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> DatetimeBasedCursor: - start_datetime: Union[str, MinMaxDatetime] = ( + start_datetime: str | MinMaxDatetime = ( model.start_datetime if isinstance(model.start_datetime, str) else self.create_min_max_datetime(model.start_datetime, config) ) - end_datetime: Union[str, MinMaxDatetime, None] = None + end_datetime: str | MinMaxDatetime | None = None if model.is_data_feed and model.end_datetime: raise ValueError("Data feed does not support end_datetime") if model.is_data_feed and model.is_client_side_incremental: @@ -1320,7 +1348,7 @@ def create_datetime_based_cursor( return DatetimeBasedCursor( cursor_field=model.cursor_field, - cursor_datetime_formats=model.cursor_datetime_formats + cursor_datetime_formats=model.cursor_datetime_formats # noqa: FURB110 if model.cursor_datetime_formats else [], cursor_granularity=model.cursor_granularity, @@ -1340,7 +1368,10 @@ def create_datetime_based_cursor( ) def 
create_declarative_stream( - self, model: DeclarativeStreamModel, config: Config, **kwargs: Any + self, + model: DeclarativeStreamModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> DeclarativeStream: # When constructing a declarative stream, we assemble the incremental_sync component and retriever's partition_router field # components if they exist into a single CartesianProductStreamSlicer. This is then passed back as an argument when constructing the @@ -1375,7 +1406,7 @@ def create_declarative_stream( ), "substream_cursor": ( combined_slicers - if isinstance( + if isinstance( # noqa: UP038 combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor) ) else None @@ -1418,7 +1449,7 @@ def create_declarative_stream( transformations = [] if model.transformations: for transformation_model in model.transformations: - transformations.append( + transformations.append( # noqa: PERF401 self._create_component_from_model(model=transformation_model, config=config) ) retriever = self._create_component_from_model( @@ -1465,9 +1496,9 @@ def create_declarative_stream( def _build_stream_slicer_from_partition_router( self, - model: Union[AsyncRetrieverModel, CustomRetrieverModel, SimpleRetrieverModel], + model: AsyncRetrieverModel | CustomRetrieverModel | SimpleRetrieverModel, config: Config, - ) -> Optional[PartitionRouter]: + ) -> PartitionRouter | None: if ( hasattr(model, "partition_router") and isinstance(model, SimpleRetrieverModel) @@ -1483,16 +1514,15 @@ def _build_stream_slicer_from_partition_router( ], parameters={}, ) - else: - return self._create_component_from_model(model=stream_slicer_model, config=config) # type: ignore[no-any-return] - # Will be created PartitionRouter as stream_slicer_model is model.partition_router + return self._create_component_from_model(model=stream_slicer_model, config=config) # type: ignore[no-any-return] + # Will be created PartitionRouter as stream_slicer_model is model.partition_router return None def _build_resumable_cursor_from_paginator( self, - model: Union[AsyncRetrieverModel, CustomRetrieverModel, SimpleRetrieverModel], - stream_slicer: Optional[StreamSlicer], - ) -> Optional[StreamSlicer]: + model: AsyncRetrieverModel | CustomRetrieverModel | SimpleRetrieverModel, + stream_slicer: StreamSlicer | None, + ) -> StreamSlicer | None: if hasattr(model, "paginator") and model.paginator and not stream_slicer: # For the regular Full-Refresh streams, we use the high level `ResumableFullRefreshCursor` return ResumableFullRefreshCursor(parameters={}) @@ -1500,7 +1530,7 @@ def _build_resumable_cursor_from_paginator( def _merge_stream_slicers( self, model: DeclarativeStreamModel, config: Config - ) -> Optional[StreamSlicer]: + ) -> StreamSlicer | None: stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config) if model.incremental_sync and stream_slicer: @@ -1515,28 +1545,27 @@ def _merge_stream_slicers( return GlobalSubstreamCursor( stream_cursor=cursor_component, partition_router=stream_slicer ) - else: - cursor_component = self._create_component_from_model( - model=incremental_sync_model, config=config - ) - return PerPartitionWithGlobalCursor( - cursor_factory=CursorFactory( - lambda: self._create_component_from_model( - model=incremental_sync_model, config=config - ), + cursor_component = self._create_component_from_model( + model=incremental_sync_model, config=config + ) + return PerPartitionWithGlobalCursor( + cursor_factory=CursorFactory( + lambda: self._create_component_from_model( + 
model=incremental_sync_model, config=config ), - partition_router=stream_slicer, - stream_cursor=cursor_component, - ) - elif model.incremental_sync: + ), + partition_router=stream_slicer, + stream_cursor=cursor_component, + ) + if model.incremental_sync: return ( self._create_component_from_model(model=model.incremental_sync, config=config) if model.incremental_sync else None ) - elif self._disable_resumable_full_refresh: + if self._disable_resumable_full_refresh: return stream_slicer - elif stream_slicer: + if stream_slicer: # For the Full-Refresh sub-streams, we use the nested `ChildPartitionResumableFullRefreshCursor` return PerPartitionCursor( cursor_factory=CursorFactory( @@ -1547,19 +1576,22 @@ def _merge_stream_slicers( return self._build_resumable_cursor_from_paginator(model.retriever, stream_slicer) def create_default_error_handler( - self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any + self, + model: DefaultErrorHandlerModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> DefaultErrorHandler: backoff_strategies = [] if model.backoff_strategies: for backoff_strategy_model in model.backoff_strategies: - backoff_strategies.append( + backoff_strategies.append( # noqa: PERF401 self._create_component_from_model(model=backoff_strategy_model, config=config) ) response_filters = [] if model.response_filters: for response_filter_model in model.response_filters: - response_filters.append( + response_filters.append( # noqa: PERF401 self._create_component_from_model(model=response_filter_model, config=config) ) response_filters.append( @@ -1580,9 +1612,9 @@ def create_default_paginator( config: Config, *, url_base: str, - decoder: Optional[Decoder] = None, - cursor_used_for_stop_condition: Optional[DeclarativeCursor] = None, - ) -> Union[DefaultPaginator, PaginatorTestReadDecorator]: + decoder: Decoder | None = None, + cursor_used_for_stop_condition: DeclarativeCursor | None = None, + ) -> DefaultPaginator | PaginatorTestReadDecorator: if decoder: if self._is_supported_decoder_for_pagination(decoder): decoder_to_use = PaginationDecoderDecorator(decoder=decoder) @@ -1624,14 +1656,14 @@ def create_dpath_extractor( self, model: DpathExtractorModel, config: Config, - decoder: Optional[Decoder] = None, - **kwargs: Any, + decoder: Decoder | None = None, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> DpathExtractor: - if decoder: + if decoder: # noqa: SIM108 decoder_to_use = decoder else: decoder_to_use = JsonDecoder(parameters={}) - model_field_path: List[Union[InterpolatedString, str]] = [x for x in model.field_path] + model_field_path: list[InterpolatedString | str] = [x for x in model.field_path] return DpathExtractor( decoder=decoder_to_use, field_path=model_field_path, @@ -1642,7 +1674,7 @@ def create_dpath_extractor( def create_response_to_file_extractor( self, model: ResponseToFileExtractorModel, - **kwargs: Any, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> ResponseToFileExtractor: return ResponseToFileExtractor(parameters=model.parameters or {}) @@ -1658,7 +1690,7 @@ def create_http_requester( self, model: HttpRequesterModel, config: Config, - decoder: Decoder = JsonDecoder(parameters={}), + decoder: Decoder = JsonDecoder(parameters={}), # noqa: B008 *, name: str, ) -> HttpRequester: @@ -1717,9 +1749,11 @@ def create_http_requester( @staticmethod def create_http_response_filter( - model: HttpResponseFilterModel, config: Config, **kwargs: Any + model: HttpResponseFilterModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> HttpResponseFilter: - 
if model.action: + if model.action: # noqa: SIM108 action = ResponseAction(model.action.value) else: action = None @@ -1743,12 +1777,14 @@ def create_http_response_filter( @staticmethod def create_inline_schema_loader( - model: InlineSchemaLoaderModel, config: Config, **kwargs: Any + model: InlineSchemaLoaderModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> InlineSchemaLoader: return InlineSchemaLoader(schema=model.schema_ or {}, parameters={}) @staticmethod - def create_types_map(model: TypesMapModel, **kwargs: Any) -> TypesMap: + def create_types_map(model: TypesMapModel, **kwargs: Any) -> TypesMap: # noqa: ANN401, ARG004 return TypesMap( target_type=model.target_type, current_type=model.current_type, @@ -1756,21 +1792,22 @@ def create_types_map(model: TypesMapModel, **kwargs: Any) -> TypesMap: ) def create_schema_type_identifier( - self, model: SchemaTypeIdentifierModel, config: Config, **kwargs: Any + self, + model: SchemaTypeIdentifierModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> SchemaTypeIdentifier: types_mapping = [] if model.types_mapping: - types_mapping.extend( - [ - self._create_component_from_model(types_map, config=config) - for types_map in model.types_mapping - ] - ) - model_schema_pointer: List[Union[InterpolatedString, str]] = ( + types_mapping.extend([ + self._create_component_from_model(types_map, config=config) + for types_map in model.types_mapping + ]) + model_schema_pointer: list[InterpolatedString | str] = ( [x for x in model.schema_pointer] if model.schema_pointer else [] ) - model_key_pointer: List[Union[InterpolatedString, str]] = [x for x in model.key_pointer] - model_type_pointer: Optional[List[Union[InterpolatedString, str]]] = ( + model_key_pointer: list[InterpolatedString | str] = [x for x in model.key_pointer] + model_type_pointer: list[InterpolatedString | str] | None = ( [x for x in model.type_pointer] if model.type_pointer else None ) @@ -1783,7 +1820,10 @@ def create_schema_type_identifier( ) def create_dynamic_schema_loader( - self, model: DynamicSchemaLoaderModel, config: Config, **kwargs: Any + self, + model: DynamicSchemaLoaderModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> DynamicSchemaLoader: stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config) combined_slicers = self._build_resumable_cursor_from_paginator( @@ -1793,7 +1833,7 @@ def create_dynamic_schema_loader( schema_transformations = [] if model.schema_transformations: for transformation_model in model.schema_transformations: - schema_transformations.append( + schema_transformations.append( # noqa: PERF401 self._create_component_from_model(model=transformation_model, config=config) ) @@ -1817,67 +1857,86 @@ def create_dynamic_schema_loader( ) @staticmethod - def create_json_decoder(model: JsonDecoderModel, config: Config, **kwargs: Any) -> JsonDecoder: + def create_json_decoder(model: JsonDecoderModel, config: Config, **kwargs: Any) -> JsonDecoder: # noqa: ANN401, ARG004 return JsonDecoder(parameters={}) @staticmethod - def create_json_parser(model: JsonParserModel, config: Config, **kwargs: Any) -> JsonParser: - encoding = model.encoding if model.encoding else "utf-8" + def create_json_parser(model: JsonParserModel, config: Config, **kwargs: Any) -> JsonParser: # noqa: ANN401, ARG004 + encoding = model.encoding if model.encoding else "utf-8" # noqa: FURB110 return JsonParser(encoding=encoding) @staticmethod def create_jsonl_decoder( - model: JsonlDecoderModel, config: Config, 
**kwargs: Any + model: JsonlDecoderModel, # noqa: ARG004 + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> JsonlDecoder: return JsonlDecoder(parameters={}) @staticmethod def create_json_line_parser( - model: JsonLineParserModel, config: Config, **kwargs: Any + model: JsonLineParserModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> JsonLineParser: return JsonLineParser(encoding=model.encoding) @staticmethod def create_iterable_decoder( - model: IterableDecoderModel, config: Config, **kwargs: Any + model: IterableDecoderModel, # noqa: ARG004 + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> IterableDecoder: return IterableDecoder(parameters={}) @staticmethod - def create_xml_decoder(model: XmlDecoderModel, config: Config, **kwargs: Any) -> XmlDecoder: + def create_xml_decoder(model: XmlDecoderModel, config: Config, **kwargs: Any) -> XmlDecoder: # noqa: ANN401, ARG004 return XmlDecoder(parameters={}) @staticmethod def create_gzipjson_decoder( - model: GzipJsonDecoderModel, config: Config, **kwargs: Any + model: GzipJsonDecoderModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> GzipJsonDecoder: return GzipJsonDecoder(parameters={}, encoding=model.encoding) def create_zipfile_decoder( - self, model: ZipfileDecoderModel, config: Config, **kwargs: Any + self, + model: ZipfileDecoderModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> ZipfileDecoder: parser = self._create_component_from_model(model=model.parser, config=config) return ZipfileDecoder(parser=parser) def create_gzip_parser( - self, model: GzipParserModel, config: Config, **kwargs: Any + self, + model: GzipParserModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> GzipParser: inner_parser = self._create_component_from_model(model=model.inner_parser, config=config) return GzipParser(inner_parser=inner_parser) @staticmethod - def create_csv_parser(model: CsvParserModel, config: Config, **kwargs: Any) -> CsvParser: + def create_csv_parser(model: CsvParserModel, config: Config, **kwargs: Any) -> CsvParser: # noqa: ANN401, ARG004 return CsvParser(encoding=model.encoding, delimiter=model.delimiter) def create_composite_raw_decoder( - self, model: CompositeRawDecoderModel, config: Config, **kwargs: Any + self, + model: CompositeRawDecoderModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> CompositeRawDecoder: parser = self._create_component_from_model(model=model.parser, config=config) return CompositeRawDecoder(parser=parser) @staticmethod def create_json_file_schema_loader( - model: JsonFileSchemaLoaderModel, config: Config, **kwargs: Any + model: JsonFileSchemaLoaderModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> JsonFileSchemaLoader: return JsonFileSchemaLoader( file_path=model.file_path or "", config=config, parameters=model.parameters or {} @@ -1885,7 +1944,9 @@ def create_json_file_schema_loader( @staticmethod def create_jwt_authenticator( - model: JwtAuthenticatorModel, config: Config, **kwargs: Any + model: JwtAuthenticatorModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> JwtAuthenticator: jwt_headers = model.jwt_headers or JwtHeadersModel(kid=None, typ="JWT", cty=None) jwt_payload = model.jwt_payload or JwtPayloadModel(iss=None, sub=None, aud=None) @@ -1909,7 +1970,9 @@ def create_jwt_authenticator( @staticmethod def create_list_partition_router( - model: ListPartitionRouterModel, config: Config, **kwargs: Any + model: 
ListPartitionRouterModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> ListPartitionRouter: request_option = ( RequestOption( @@ -1930,7 +1993,9 @@ def create_list_partition_router( @staticmethod def create_min_max_datetime( - model: MinMaxDatetimeModel, config: Config, **kwargs: Any + model: MinMaxDatetimeModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> MinMaxDatetime: return MinMaxDatetime( datetime=model.datetime, @@ -1941,17 +2006,22 @@ def create_min_max_datetime( ) @staticmethod - def create_no_auth(model: NoAuthModel, config: Config, **kwargs: Any) -> NoAuth: + def create_no_auth(model: NoAuthModel, config: Config, **kwargs: Any) -> NoAuth: # noqa: ANN401, ARG004 return NoAuth(parameters=model.parameters or {}) @staticmethod def create_no_pagination( - model: NoPaginationModel, config: Config, **kwargs: Any + model: NoPaginationModel, # noqa: ARG004 + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> NoPagination: return NoPagination(parameters={}) def create_oauth_authenticator( - self, model: OAuthAuthenticatorModel, config: Config, **kwargs: Any + self, + model: OAuthAuthenticatorModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> DeclarativeOauth2Authenticator: if model.refresh_token_updater: # ignore type error because fixing it would have a lot of dependencies, revisit later @@ -2028,7 +2098,11 @@ def create_oauth_authenticator( ) def create_offset_increment( - self, model: OffsetIncrementModel, config: Config, decoder: Decoder, **kwargs: Any + self, + model: OffsetIncrementModel, + config: Config, + decoder: Decoder, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> OffsetIncrement: if isinstance(decoder, PaginationDecoderDecorator): inner_decoder = decoder.decoder @@ -2053,7 +2127,9 @@ def create_offset_increment( @staticmethod def create_page_increment( - model: PageIncrementModel, config: Config, **kwargs: Any + model: PageIncrementModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> PageIncrement: return PageIncrement( page_size=model.page_size, @@ -2064,7 +2140,10 @@ def create_page_increment( ) def create_parent_stream_config( - self, model: ParentStreamConfigModel, config: Config, **kwargs: Any + self, + model: ParentStreamConfigModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> ParentStreamConfig: declarative_stream = self._create_component_from_model(model.stream, config=config) request_option = ( @@ -2085,19 +2164,23 @@ def create_parent_stream_config( @staticmethod def create_record_filter( - model: RecordFilterModel, config: Config, **kwargs: Any + model: RecordFilterModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> RecordFilter: return RecordFilter( condition=model.condition or "", config=config, parameters=model.parameters or {} ) @staticmethod - def create_request_path(model: RequestPathModel, config: Config, **kwargs: Any) -> RequestPath: + def create_request_path(model: RequestPathModel, config: Config, **kwargs: Any) -> RequestPath: # noqa: ANN401, ARG004 return RequestPath(parameters={}) @staticmethod def create_request_option( - model: RequestOptionModel, config: Config, **kwargs: Any + model: RequestOptionModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> RequestOption: inject_into = RequestOptionType(model.inject_into.value) return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={}) @@ -2108,10 +2191,10 @@ def create_record_selector( config: 
Config, *, name: str, - transformations: List[RecordTransformation] | None = None, + transformations: list[RecordTransformation] | None = None, decoder: Decoder | None = None, - client_side_incremental_sync: Dict[str, Any] | None = None, - **kwargs: Any, + client_side_incremental_sync: dict[str, Any] | None = None, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> RecordSelector: extractor = self._create_component_from_model( model=model.extractor, decoder=decoder, config=config @@ -2148,14 +2231,19 @@ def create_record_selector( @staticmethod def create_remove_fields( - model: RemoveFieldsModel, config: Config, **kwargs: Any + model: RemoveFieldsModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> RemoveFields: return RemoveFields( field_pointers=model.field_pointers, condition=model.condition or "", parameters={} ) def create_selective_authenticator( - self, model: SelectiveAuthenticatorModel, config: Config, **kwargs: Any + self, + model: SelectiveAuthenticatorModel, + config: Config, + **kwargs: Any, # noqa: ANN401 ) -> DeclarativeAuthenticator: authenticators = { name: self._create_component_from_model(model=auth, config=config) @@ -2171,7 +2259,11 @@ def create_selective_authenticator( @staticmethod def create_legacy_session_token_authenticator( - model: LegacySessionTokenAuthenticatorModel, config: Config, *, url_base: str, **kwargs: Any + model: LegacySessionTokenAuthenticatorModel, + config: Config, + *, + url_base: str, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> LegacySessionTokenAuthenticator: return LegacySessionTokenAuthenticator( api_url=url_base, @@ -2186,18 +2278,18 @@ def create_legacy_session_token_authenticator( parameters=model.parameters or {}, ) - def create_simple_retriever( + def create_simple_retriever( # noqa: PLR0913 self, model: SimpleRetrieverModel, config: Config, *, name: str, - primary_key: Optional[Union[str, List[str], List[List[str]]]], - stream_slicer: Optional[StreamSlicer], - request_options_provider: Optional[RequestOptionsProvider] = None, + primary_key: str | list[str] | list[list[str]] | None, + stream_slicer: StreamSlicer | None, + request_options_provider: RequestOptionsProvider | None = None, stop_condition_on_cursor: bool = False, - client_side_incremental_sync: Optional[Dict[str, Any]] = None, - transformations: List[RecordTransformation], + client_side_incremental_sync: dict[str, Any] | None = None, + transformations: list[RecordTransformation], ) -> SimpleRetriever: decoder = ( self._create_component_from_model(model=model.decoder, config=config) @@ -2285,7 +2377,10 @@ def create_simple_retriever( ) def _create_async_job_status_mapping( - self, model: AsyncJobStatusMapModel, config: Config, **kwargs: Any + self, + model: AsyncJobStatusMapModel, + config: Config, # noqa: ARG002 + **kwargs: Any, # noqa: ANN401, ARG002 ) -> Mapping[str, AsyncJobStatus]: api_status_to_cdk_status = {} for cdk_status, api_statuses in model.dict().items(): @@ -2314,19 +2409,20 @@ def _get_async_job_status(self, status: str) -> AsyncJobStatus: case _: raise ValueError(f"Unsupported CDK status {status}") - def create_async_retriever( + def create_async_retriever( # noqa: PLR0914 self, model: AsyncRetrieverModel, config: Config, *, name: str, - primary_key: Optional[ - Union[str, List[str], List[List[str]]] - ], # this seems to be needed to match create_simple_retriever - stream_slicer: Optional[StreamSlicer], - client_side_incremental_sync: Optional[Dict[str, Any]] = None, - transformations: List[RecordTransformation], - **kwargs: Any, + 
primary_key: str # noqa: ARG002 + | list[str] + | list[list[str]] + | None, # this seems to be needed to match create_simple_retriever + stream_slicer: StreamSlicer | None, + client_side_incremental_sync: dict[str, Any] | None = None, + transformations: list[RecordTransformation], + **kwargs: Any, # noqa: ANN401, ARG002 ) -> AsyncRetriever: decoder = ( self._create_component_from_model(model=model.decoder, config=config) @@ -2457,10 +2553,10 @@ def create_async_retriever( job_repository, stream_slices, JobTracker(1), - # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1 + # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1 # noqa: FIX001, TD001, TD004 self._message_repository, has_bulk_parent=False, - # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk + # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk # noqa: FIX001, TD001, TD004 ), stream_slicer=stream_slicer, config=config, @@ -2475,7 +2571,7 @@ def create_async_retriever( ) @staticmethod - def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec: + def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec: # noqa: ANN401, ARG004 return Spec( connection_specification=model.connection_specification, documentation_url=model.documentation_url, @@ -2484,18 +2580,19 @@ def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec: ) def create_substream_partition_router( - self, model: SubstreamPartitionRouterModel, config: Config, **kwargs: Any + self, + model: SubstreamPartitionRouterModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG002 ) -> SubstreamPartitionRouter: parent_stream_configs = [] if model.parent_stream_configs: - parent_stream_configs.extend( - [ - self._create_message_repository_substream_wrapper( - model=parent_stream_config, config=config - ) - for parent_stream_config in model.parent_stream_configs - ] - ) + parent_stream_configs.extend([ + self._create_message_repository_substream_wrapper( + model=parent_stream_config, config=config + ) + for parent_stream_config in model.parent_stream_configs + ]) return SubstreamPartitionRouter( parent_stream_configs=parent_stream_configs, @@ -2505,7 +2602,7 @@ def create_substream_partition_router( def _create_message_repository_substream_wrapper( self, model: ParentStreamConfigModel, config: Config - ) -> Any: + ) -> Any: # noqa: ANN401 substream_factory = ModelToComponentFactory( limit_pages_fetched_per_slice=self._limit_pages_fetched_per_slice, limit_slices_fetched=self._limit_slices_fetched, @@ -2518,11 +2615,13 @@ def _create_message_repository_substream_wrapper( self._evaluate_log_level(self._emit_connector_builder_messages), ), ) - return substream_factory._create_component_from_model(model=model, config=config) + return substream_factory._create_component_from_model(model=model, config=config) # noqa: SLF001 @staticmethod def create_wait_time_from_header( - model: WaitTimeFromHeaderModel, config: Config, **kwargs: Any + model: WaitTimeFromHeaderModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> WaitTimeFromHeaderBackoffStrategy: return WaitTimeFromHeaderBackoffStrategy( header=model.header, @@ -2536,7 +2635,9 @@ def create_wait_time_from_header( @staticmethod def create_wait_until_time_from_header( - model: WaitUntilTimeFromHeaderModel, config: Config, **kwargs: Any + model: 
WaitUntilTimeFromHeaderModel, + config: Config, + **kwargs: Any, # noqa: ANN401, ARG004 ) -> WaitUntilTimeFromHeaderBackoffStrategy: return WaitUntilTimeFromHeaderBackoffStrategy( header=model.header, @@ -2549,12 +2650,14 @@ def create_wait_until_time_from_header( def get_message_repository(self) -> MessageRepository: return self._message_repository - def _evaluate_log_level(self, emit_connector_builder_messages: bool) -> Level: + def _evaluate_log_level(self, emit_connector_builder_messages: bool) -> Level: # noqa: FBT001 return Level.DEBUG if emit_connector_builder_messages else Level.INFO @staticmethod def create_components_mapping_definition( - model: ComponentMappingDefinitionModel, config: Config, **kwargs: Any + model: ComponentMappingDefinitionModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> ComponentMappingDefinition: interpolated_value = InterpolatedString.create( model.value, parameters=model.parameters or {} @@ -2572,7 +2675,7 @@ def create_components_mapping_definition( def create_http_components_resolver( self, model: HttpComponentsResolverModel, config: Config - ) -> Any: + ) -> Any: # noqa: ANN401 stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config) combined_slicers = self._build_resumable_cursor_from_paginator( model.retriever, stream_slicer @@ -2583,7 +2686,7 @@ def create_http_components_resolver( config=config, name="", primary_key=None, - stream_slicer=stream_slicer if stream_slicer else combined_slicers, + stream_slicer=stream_slicer if stream_slicer else combined_slicers, # noqa: FURB110 transformations=[], ) @@ -2607,9 +2710,11 @@ def create_http_components_resolver( @staticmethod def create_stream_config( - model: StreamConfigModel, config: Config, **kwargs: Any + model: StreamConfigModel, + config: Config, # noqa: ARG004 + **kwargs: Any, # noqa: ANN401, ARG004 ) -> StreamConfig: - model_configs_pointer: List[Union[InterpolatedString, str]] = ( + model_configs_pointer: list[InterpolatedString | str] = ( [x for x in model.configs_pointer] if model.configs_pointer else [] ) @@ -2620,7 +2725,7 @@ def create_stream_config( def create_config_components_resolver( self, model: ConfigComponentsResolverModel, config: Config - ) -> Any: + ) -> Any: # noqa: ANN401 stream_config = self._create_component_from_model( model.stream_config, config=config, parameters=model.parameters or {} ) @@ -2650,17 +2755,15 @@ def create_config_components_resolver( ) def _is_supported_decoder_for_pagination(self, decoder: Decoder) -> bool: - if isinstance(decoder, (JsonDecoder, XmlDecoder)): + if isinstance(decoder, (JsonDecoder, XmlDecoder)): # noqa: UP038 return True - elif isinstance(decoder, CompositeRawDecoder): + if isinstance(decoder, CompositeRawDecoder): return self._is_supported_parser_for_pagination(decoder.parser) - else: - return False + return False def _is_supported_parser_for_pagination(self, parser: Parser) -> bool: if isinstance(parser, JsonParser): return True - elif isinstance(parser, GzipParser): + if isinstance(parser, GzipParser): return isinstance(parser.inner_parser, JsonParser) - else: - return False + return False diff --git a/airbyte_cdk/sources/declarative/partition_routers/__init__.py b/airbyte_cdk/sources/declarative/partition_routers/__init__.py index f35647402..5a6ecfcd0 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/__init__.py +++ b/airbyte_cdk/sources/declarative/partition_routers/__init__.py @@ -19,6 +19,7 @@ SubstreamPartitionRouter, ) + __all__ = [ 
"AsyncJobPartitionRouter", "CartesianProductStreamSlicer", diff --git a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py index 0f11820f7..2eccd8fe4 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py @@ -1,7 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from collections.abc import Callable, Iterable, Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Callable, Iterable, Mapping, Optional +from typing import Any from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.async_job.job_orchestrator import ( @@ -33,7 +34,7 @@ class AsyncJobPartitionRouter(StreamSlicer): def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._job_orchestrator_factory = self.job_orchestrator_factory - self._job_orchestrator: Optional[AsyncJobOrchestrator] = None + self._job_orchestrator: AsyncJobOrchestrator | None = None self._parameters = parameters def stream_slices(self) -> Iterable[StreamSlice]: diff --git a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py index 8718004bf..47bfc36b2 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py +++ b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py @@ -5,9 +5,9 @@ import itertools import logging from collections import ChainMap -from collections.abc import Callable +from collections.abc import Callable, Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, List, Mapping, Optional +from typing import Any from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ( @@ -31,7 +31,7 @@ def check_for_substream_in_slicers( if isinstance(slicer, SubstreamPartitionRouter): log_warning("Parent state handling is not supported for CartesianProductStreamSlicer.") return - elif isinstance(slicer, CartesianProductStreamSlicer): + if isinstance(slicer, CartesianProductStreamSlicer): # Recursively check sub-slicers within CartesianProductStreamSlicer check_for_substream_in_slicers(slicer.stream_slicers, log_warning) @@ -57,7 +57,7 @@ class CartesianProductStreamSlicer(PartitionRouter): stream_slicers (List[PartitionRouter]): Underlying stream slicers. The RequestOptions (e.g: Request headers, parameters, etc..) returned by this slicer are the combination of the RequestOptions of its input slicers. If there are conflicts e.g: two slicers define the same header or request param, the conflict is resolved by taking the value from the first slicer, where ordering is determined by the order in which slicers were input to this composite slicer. 
""" - stream_slicers: List[PartitionRouter] + stream_slicers: list[PartitionRouter] parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -66,81 +66,73 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return dict( - ChainMap( - *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons - s.get_request_params( - stream_state=stream_state, - stream_slice=stream_slice, - next_page_token=next_page_token, - ) - for s in self.stream_slicers - ] - ) + ChainMap(*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons + s.get_request_params( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) + for s in self.stream_slicers + ]) ) def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return dict( - ChainMap( - *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons - s.get_request_headers( - stream_state=stream_state, - stream_slice=stream_slice, - next_page_token=next_page_token, - ) - for s in self.stream_slicers - ] - ) + ChainMap(*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons + s.get_request_headers( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) + for s in self.stream_slicers + ]) ) def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return dict( - ChainMap( - *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons - s.get_request_body_data( - stream_state=stream_state, - stream_slice=stream_slice, - next_page_token=next_page_token, - ) - for s in self.stream_slicers - ] - ) + ChainMap(*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons + s.get_request_body_data( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) + for s in self.stream_slicers + ]) ) def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return dict( - ChainMap( - *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons - s.get_request_body_json( - stream_state=stream_state, - stream_slice=stream_slice, - next_page_token=next_page_token, - ) - for s in self.stream_slicers - ] - ) + ChainMap(*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] 
for reasons + s.get_request_body_json( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ) + for s in self.stream_slicers + ]) ) def stream_slices(self) -> Iterable[StreamSlice]: @@ -153,7 +145,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: raise ValueError( f"There should only be a single cursor slice. Found {cursor_slices}" ) - if cursor_slices: + if cursor_slices: # noqa: SIM108 cursor_slice = cursor_slices[0] else: cursor_slice = {} @@ -165,7 +157,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: """ pass - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: + def get_stream_state(self) -> Mapping[str, StreamState] | None: """ Parent stream states are not supported for cartesian product stream slicer """ diff --git a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py index 29b700b04..ed5c127e1 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, List, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter @@ -27,11 +28,11 @@ class ListPartitionRouter(PartitionRouter): request_option (Optional[RequestOption]): The request option to configure the HTTP request """ - values: Union[str, List[str]] - cursor_field: Union[InterpolatedString, str] + values: str | list[str] + cursor_field: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - request_option: Optional[RequestOption] = None + request_option: RequestOption | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.values, str): @@ -48,36 +49,36 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.request_parameter, stream_slice) def get_request_headers( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.header, stream_slice) def get_request_body_data( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = 
None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_data, stream_slice) def get_request_body_json( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_json, stream_slice) @@ -91,7 +92,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: ] def _get_request_option( - self, request_option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + self, request_option_type: RequestOptionType, stream_slice: StreamSlice | None ) -> Mapping[str, Any]: if ( self.request_option @@ -101,10 +102,8 @@ def _get_request_option( slice_value = stream_slice.get(self._cursor_field.eval(self.config)) if slice_value: return {self.request_option.field_name.eval(self.config): slice_value} # type: ignore # field_name is always casted to InterpolatedString - else: - return {} - else: return {} + return {} def set_initial_state(self, stream_state: StreamState) -> None: """ @@ -112,7 +111,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: """ pass - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: + def get_stream_state(self) -> Mapping[str, StreamState] | None: """ ListPartitionRouter doesn't have parent streams """ diff --git a/airbyte_cdk/sources/declarative/partition_routers/partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/partition_router.py index 3a9bc3abf..d0a1c062a 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/partition_router.py @@ -3,8 +3,8 @@ # from abc import abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Mapping, Optional from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer from airbyte_cdk.sources.types import StreamState @@ -41,7 +41,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: """ @abstractmethod - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: + def get_stream_state(self) -> Mapping[str, StreamState] | None: """ Get the state of the parent streams. diff --git a/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py index 32e6a353d..b6f5a3b7f 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, Mapping, Optional +from typing import Any from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.types import StreamSlice, StreamState @@ -17,33 +18,33 @@ class SinglePartitionRouter(PartitionRouter): def get_request_params( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} def get_request_headers( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} def get_request_body_data( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} def get_request_body_json( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} @@ -56,7 +57,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: """ pass - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: + def get_stream_state(self) -> Mapping[str, StreamState] | None: """ SinglePartitionRouter doesn't have parent streams """ diff --git a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py index 1c7bb6961..ca0c5e61e 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py @@ -3,8 +3,9 @@ # import copy import logging +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any import dpath @@ -19,6 +20,7 @@ from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from airbyte_cdk.utils import AirbyteTracedException + if TYPE_CHECKING: from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream @@ -37,14 +39,14 @@ class ParentStreamConfig: """ stream: "DeclarativeStream" # Parent streams must be DeclarativeStream because we can't know which part of the stream slice is a partition for regular Stream - parent_key: Union[InterpolatedString, str] - partition_field: Union[InterpolatedString, str] + parent_key: InterpolatedString | str + partition_field: InterpolatedString | str config: 
Config parameters: InitVar[Mapping[str, Any]] - extra_fields: Optional[Union[List[List[str]], List[List[InterpolatedString]]]] = ( + extra_fields: list[list[str]] | list[list[InterpolatedString]] | None = ( None # List of field paths (arrays of strings) ) - request_option: Optional[RequestOption] = None + request_option: RequestOption | None = None incremental_dependency: bool = False def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -70,7 +72,7 @@ class SubstreamPartitionRouter(PartitionRouter): parent_stream_configs (List[ParentStreamConfig]): parent streams to iterate over and their config """ - parent_stream_configs: List[ParentStreamConfig] + parent_stream_configs: list[ParentStreamConfig] config: Config parameters: InitVar[Mapping[str, Any]] @@ -81,42 +83,42 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.request_parameter, stream_slice) def get_request_headers( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.header, stream_slice) def get_request_body_data( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_data, stream_slice) def get_request_body_json( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_json, stream_slice) def _get_request_option( - self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + self, option_type: RequestOptionType, stream_slice: StreamSlice | None ) -> Mapping[str, Any]: params = {} if stream_slice: @@ -128,13 +130,11 @@ def _get_request_option( key = parent_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string value = stream_slice.get(key) 
if value: - params.update( - { - parent_config.request_option.field_name.eval( # type: ignore [union-attr] - config=self.config - ): value - } - ) + params.update({ + parent_config.request_option.field_name.eval( # type: ignore [union-attr] + config=self.config + ): value + }) return params def stream_slices(self) -> Iterable[StreamSlice]: @@ -176,7 +176,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: f"Parent stream {parent_stream.name} returns records of type AirbyteMessage. This SubstreamPartitionRouter is not able to checkpoint incremental parent state." ) if parent_record.type == MessageType.RECORD: - parent_record = parent_record.record.data # type: ignore[union-attr, assignment] # record is always a Record + parent_record = parent_record.record.data # type: ignore[union-attr, assignment] # record is always a Record # noqa: PLW2901 else: continue elif isinstance(parent_record, Record): @@ -185,7 +185,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: if parent_record.associated_slice else {} ) - parent_record = parent_record.data + parent_record = parent_record.data # noqa: PLW2901 elif not isinstance(parent_record, Mapping): # The parent_record should only take the form of a Record, AirbyteMessage, or Mapping. Anything else is invalid raise AirbyteTracedException( @@ -214,7 +214,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: def _extract_extra_fields( self, parent_record: Mapping[str, Any] | AirbyteMessage, - extra_fields: Optional[List[List[str]]] = None, + extra_fields: list[list[str]] | None = None, ) -> Mapping[str, Any]: """ Extracts additional fields specified by their paths from the parent record. @@ -289,7 +289,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: # If `parent_state` doesn't exist and at least one parent stream has an incremental dependency, # copy the child state to parent streams with incremental dependencies. incremental_dependency = any( - [parent_config.incremental_dependency for parent_config in self.parent_stream_configs] + [parent_config.incremental_dependency for parent_config in self.parent_stream_configs] # noqa: C419 ) if not parent_state and not incremental_dependency: return @@ -313,7 +313,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: if parent_config.incremental_dependency: parent_config.stream.state = parent_state.get(parent_config.stream.name, {}) - def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: + def get_stream_state(self) -> Mapping[str, StreamState] | None: """ Get the state of the parent streams. 
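The hunks above and below all apply the same mechanical modernization: typing-module generics, Optional, and Union are replaced with collections.abc imports, builtin generics, and PEP 604 unions, with targeted # noqa pragmas where a ruff rule is deliberately suppressed. A minimal sketch of the before/after pattern on Python 3.10+, using placeholder names rather than real CDK signatures:

# Before: typing-module generics with Optional/Union
from typing import Any, List, Mapping, Optional, Union

def request_params_old(
    stream_state: Optional[Mapping[str, Any]] = None,
    http_codes: Optional[List[int]] = None,
) -> Union[Mapping[str, Any], str]:
    # Placeholder body; real routers build params from the stream slice.
    return {}

# After: collections.abc ABCs, builtin generics, and `X | None` unions
# (re-importing Mapping here shadows the typing alias above; shown only to contrast the two styles)
from collections.abc import Mapping
from typing import Any

def request_params_new(
    stream_state: Mapping[str, Any] | None = None,
    http_codes: list[int] | None = None,
) -> Mapping[str, Any] | str:
    # Same behavior; only the annotations changed.
    return {}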
diff --git a/airbyte_cdk/sources/declarative/requesters/__init__.py b/airbyte_cdk/sources/declarative/requesters/__init__.py index e5266ea7c..ae6241428 100644 --- a/airbyte_cdk/sources/declarative/requesters/__init__.py +++ b/airbyte_cdk/sources/declarative/requesters/__init__.py @@ -6,4 +6,5 @@ from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption from airbyte_cdk.sources.declarative.requesters.requester import Requester + __all__ = ["HttpRequester", "RequestOption", "Requester"] diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/__init__.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/__init__.py index 099aa4286..6cbeaec1a 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/__init__.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/__init__.py @@ -16,6 +16,7 @@ HttpResponseFilter, ) + __all__ = [ "BackoffStrategy", "CompositeErrorHandler", diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/__init__.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/__init__.py index 26ecafbde..ce3cc015d 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/__init__.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/__init__.py @@ -15,6 +15,7 @@ WaitUntilTimeFromHeaderBackoffStrategy, ) + __all__ = [ "ConstantBackoffStrategy", "ExponentialBackoffStrategy", diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py index 26c7c7673..5c12ada8a 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests @@ -21,7 +22,7 @@ class ConstantBackoffStrategy(BackoffStrategy): backoff_time_in_seconds (float): time to backoff before retrying a retryable request. 
""" - backoff_time_in_seconds: Union[float, InterpolatedString, str] + backoff_time_in_seconds: float | InterpolatedString | str parameters: InitVar[Mapping[str, Any]] config: Config @@ -39,7 +40,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], - attempt_count: int, - ) -> Optional[float]: + response_or_exception: requests.Response | requests.RequestException | None, # noqa: ARG002 + attempt_count: int, # noqa: ARG002 + ) -> float | None: return self.backoff_time_in_seconds.eval(self.config) # type: ignore # backoff_time_in_seconds is always cast to an interpolated string diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py index cdd1fe650..475c23782 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests @@ -23,7 +24,7 @@ class ExponentialBackoffStrategy(BackoffStrategy): parameters: InitVar[Mapping[str, Any]] config: Config - factor: Union[float, InterpolatedString, str] = 5 + factor: float | InterpolatedString | str = 5 def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not isinstance(self.factor, InterpolatedString): @@ -39,7 +40,7 @@ def _retry_factor(self) -> float: def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, # noqa: ARG002 attempt_count: int, - ) -> Optional[float]: + ) -> float | None: return self._retry_factor * 2**attempt_count # type: ignore # factor is always cast to an interpolated string diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py index 60103f343..2a2214793 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py @@ -4,14 +4,13 @@ import numbers from re import Pattern -from typing import Optional import requests def get_numeric_value_from_header( - response: requests.Response, header: str, regex: Optional[Pattern[str]] -) -> Optional[float]: + response: requests.Response, header: str, regex: Pattern[str] | None +) -> float | None: """ Extract a header value from the response as a float :param response: response the extract header value from @@ -28,13 +27,12 @@ def get_numeric_value_from_header( if match: header_value = match.group() return _as_float(header_value) - elif isinstance(header_value, numbers.Number): + if isinstance(header_value, numbers.Number): return float(header_value) # type: ignore[arg-type] - else: - return None + return None -def _as_float(s: str) -> Optional[float]: +def _as_float(s: str) -> float | None: try: return float(s) except ValueError: diff --git 
a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py index 5cda96a4d..7f42802c9 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py @@ -3,8 +3,9 @@ # import re +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests @@ -31,11 +32,11 @@ class WaitTimeFromHeaderBackoffStrategy(BackoffStrategy): max_waiting_time_in_seconds: (Optional[float]): given the value extracted from the header is greater than this value, stop the stream """ - header: Union[InterpolatedString, str] + header: InterpolatedString | str parameters: InitVar[Mapping[str, Any]] config: Config - regex: Optional[Union[InterpolatedString, str]] = None - max_waiting_time_in_seconds: Optional[float] = None + regex: InterpolatedString | str | None = None + max_waiting_time_in_seconds: float | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.regex = ( @@ -45,9 +46,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], - attempt_count: int, - ) -> Optional[float]: + response_or_exception: requests.Response | requests.RequestException | None, + attempt_count: int, # noqa: ARG002 + ) -> float | None: header = self.header.eval(config=self.config) # type: ignore # header is always cast to an interpolated stream if self.regex: evaled_regex = self.regex.eval(self.config) # type: ignore # header is always cast to an interpolated string diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py index 1220e198f..2bbad2722 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py @@ -5,8 +5,9 @@ import numbers import re import time +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests @@ -32,11 +33,11 @@ class WaitUntilTimeFromHeaderBackoffStrategy(BackoffStrategy): regex (Optional[str]): optional regex to apply on the header to extract its value """ - header: Union[InterpolatedString, str] + header: InterpolatedString | str parameters: InitVar[Mapping[str, Any]] config: Config - min_wait: Optional[Union[float, InterpolatedString, str]] = None - regex: Optional[Union[InterpolatedString, str]] = None + min_wait: float | InterpolatedString | str | None = None + regex: InterpolatedString | str | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.header = InterpolatedString.create(self.header, parameters=parameters) @@ -48,9 +49,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def backoff_time( self, - 
response_or_exception: Optional[Union[requests.Response, requests.RequestException]], - attempt_count: int, - ) -> Optional[float]: + response_or_exception: requests.Response | requests.RequestException | None, + attempt_count: int, # noqa: ARG002 + ) -> float | None: now = time.time() header = self.header.eval(self.config) # type: ignore # header is always cast to an interpolated string if self.regex: @@ -72,6 +73,6 @@ def backoff_time( return float(min_wait) if min_wait: return float(max(wait_time, min_wait)) - elif wait_time < 0: + if wait_time < 0: return None return wait_time diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py index fc4219134..3737b3dd0 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, List, Mapping, Optional, Union +from typing import Any import requests @@ -40,7 +41,7 @@ class CompositeErrorHandler(ErrorHandler): error_handlers (List[ErrorHandler]): list of error handlers """ - error_handlers: List[ErrorHandler] + error_handlers: list[ErrorHandler] parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -48,15 +49,15 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: raise ValueError("CompositeErrorHandler expects at least 1 underlying error handler") @property - def max_retries(self) -> Optional[int]: + def max_retries(self) -> int | None: return self.error_handlers[0].max_retries @property - def max_time(self) -> Optional[int]: - return max([error_handler.max_time or 0 for error_handler in self.error_handlers]) + def max_time(self) -> int | None: + return max([error_handler.max_time or 0 for error_handler in self.error_handlers]) # noqa: C419 def interpret_response( - self, response_or_exception: Optional[Union[requests.Response, Exception]] + self, response_or_exception: requests.Response | Exception | None ) -> ErrorResolution: matched_error_resolution = None for error_handler in self.error_handlers: @@ -69,7 +70,7 @@ def interpret_response( return matched_error_resolution if ( - matched_error_resolution.response_action == ResponseAction.RETRY + matched_error_resolution.response_action == ResponseAction.RETRY # noqa: PLR1714 or matched_error_resolution.response_action == ResponseAction.IGNORE ): return matched_error_resolution diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py index b70ceaaeb..3ebe3dc95 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, List, Mapping, MutableMapping, Optional, Union +from typing import Any import requests @@ -95,12 +96,12 @@ class DefaultErrorHandler(ErrorHandler): parameters: InitVar[Mapping[str, Any]] config: Config - response_filters: Optional[List[HttpResponseFilter]] = None - max_retries: Optional[int] = 5 + response_filters: list[HttpResponseFilter] | None = None + max_retries: int | None = 5 max_time: int = 60 * 10 _max_retries: int = field(init=False, repr=False, default=5) _max_time: int = field(init=False, repr=False, default=60 * 10) - backoff_strategies: Optional[List[BackoffStrategy]] = None + backoff_strategies: list[BackoffStrategy] | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not self.response_filters: @@ -109,7 +110,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._last_request_to_attempt_count: MutableMapping[requests.PreparedRequest, int] = {} def interpret_response( - self, response_or_exception: Optional[Union[requests.Response, Exception]] + self, response_or_exception: requests.Response | Exception | None ) -> ErrorResolution: if self.response_filters: for response_filter in self.response_filters: @@ -118,7 +119,7 @@ def interpret_response( ) if matched_error_resolution: return matched_error_resolution - if isinstance(response_or_exception, requests.Response): + if isinstance(response_or_exception, requests.Response): # noqa: SIM102 if response_or_exception.ok: return SUCCESS_RESOLUTION @@ -126,16 +127,16 @@ def interpret_response( default_response_filter_resolution = default_reponse_filter.matches(response_or_exception) return ( - default_response_filter_resolution + default_response_filter_resolution # noqa: FURB110 if default_response_filter_resolution else create_fallback_error_resolution(response_or_exception) ) def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int = 0, - ) -> Optional[float]: + ) -> float | None: backoff = None if self.backoff_strategies: for backoff_strategy in self.backoff_strategies: diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py index 9943a0d6a..fbde89f9e 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py @@ -2,7 +2,6 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# -from typing import Optional, Union import requests @@ -20,12 +19,12 @@ class DefaultHttpResponseFilter(HttpResponseFilter): def matches( - self, response_or_exception: Optional[Union[requests.Response, Exception]] - ) -> Optional[ErrorResolution]: + self, response_or_exception: requests.Response | Exception | None + ) -> ErrorResolution | None: default_mapped_error_resolution = None - if isinstance(response_or_exception, (requests.Response, Exception)): - mapped_key: Union[int, type] = ( + if isinstance(response_or_exception, (requests.Response, Exception)): # noqa: UP038 + mapped_key: int | type = ( response_or_exception.status_code if isinstance(response_or_exception, requests.Response) else response_or_exception.__class__ @@ -34,7 +33,7 @@ def matches( default_mapped_error_resolution = DEFAULT_ERROR_MAPPING.get(mapped_key) return ( - default_mapped_error_resolution + default_mapped_error_resolution # noqa: FURB110 if default_mapped_error_resolution else create_fallback_error_resolution(response_or_exception) ) diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py index a2fc80007..34a9273ea 100644 --- a/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py +++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Set, Union +from typing import Any import requests @@ -40,12 +41,12 @@ class HttpResponseFilter: config: Config parameters: InitVar[Mapping[str, Any]] - action: Optional[Union[ResponseAction, str]] = None - failure_type: Optional[Union[FailureType, str]] = None - http_codes: Optional[Set[int]] = None - error_message_contains: Optional[str] = None - predicate: Union[InterpolatedBoolean, str] = "" - error_message: Union[InterpolatedString, str] = "" + action: ResponseAction | str | None = None + failure_type: FailureType | str | None = None + http_codes: set[int] | None = None + error_message_contains: str | None = None + predicate: InterpolatedBoolean | str = "" + error_message: InterpolatedString | str = "" def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.action is not None: @@ -57,7 +58,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: raise ValueError( "HttpResponseFilter requires a filter condition if an action is specified" ) - elif isinstance(self.action, str): + if isinstance(self.action, str): self.action = ResponseAction[self.action] self.http_codes = self.http_codes or set() if isinstance(self.predicate, str): @@ -70,8 +71,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.failure_type = FailureType[self.failure_type] def matches( - self, response_or_exception: Optional[Union[requests.Response, Exception]] - ) -> Optional[ErrorResolution]: + self, response_or_exception: requests.Response | Exception | None + ) -> ErrorResolution | None: filter_action = self._matches_filter(response_or_exception) mapped_key = ( response_or_exception.status_code @@ -79,7 +80,7 @@ def matches( else response_or_exception.__class__ ) - if isinstance(mapped_key, (int, Exception)): + if isinstance(mapped_key, (int, Exception)): # noqa: UP038 default_mapped_error_resolution = self._match_default_error_mapping(mapped_key) else: 
default_mapped_error_resolution = None @@ -118,13 +119,13 @@ def matches( return None def _match_default_error_mapping( - self, mapped_key: Union[int, type[Exception]] - ) -> Optional[ErrorResolution]: + self, mapped_key: int | type[Exception] + ) -> ErrorResolution | None: return DEFAULT_ERROR_MAPPING.get(mapped_key) def _matches_filter( - self, response_or_exception: Optional[Union[requests.Response, Exception]] - ) -> Optional[ResponseAction]: + self, response_or_exception: requests.Response | Exception | None + ) -> ResponseAction | None: """ Apply the HTTP filter on the response and return the action to execute if it matches :param response: The HTTP response to evaluate @@ -145,7 +146,7 @@ def _safe_response_json(response: requests.Response) -> dict[str, Any]: except requests.exceptions.JSONDecodeError: return {} - def _create_error_message(self, response: requests.Response) -> Optional[str]: + def _create_error_message(self, response: requests.Response) -> str | None: """ Construct an error message based on the specified message template of the filter. :param response: The HTTP response which can be used during interpolation @@ -172,8 +173,5 @@ def _response_matches_predicate(self, response: requests.Response) -> bool: def _response_contains_error_message(self, response: requests.Response) -> bool: if not self.error_message_contains: return False - else: - error_message = self._error_message_parser.parse_response_error_message( - response=response - ) - return bool(error_message and self.error_message_contains in error_message) + error_message = self._error_message_parser.parse_response_error_message(response=response) + return bool(error_message and self.error_message_contains in error_message) diff --git a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py index fce146fd8..cababa2d2 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py +++ b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py @@ -1,9 +1,10 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
import logging import uuid +from collections.abc import Iterable, Mapping from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import Any import requests from requests import Response @@ -26,6 +27,7 @@ from airbyte_cdk.sources.types import Record, StreamSlice from airbyte_cdk.utils import AirbyteTracedException + LOGGER = logging.getLogger("airbyte") @@ -38,23 +40,23 @@ class AsyncHttpJobRepository(AsyncJobRepository): creation_requester: Requester polling_requester: Requester download_retriever: SimpleRetriever - abort_requester: Optional[Requester] - delete_requester: Optional[Requester] + abort_requester: Requester | None + delete_requester: Requester | None status_extractor: DpathExtractor status_mapping: Mapping[str, AsyncJobStatus] urls_extractor: DpathExtractor - job_timeout: Optional[timedelta] = None + job_timeout: timedelta | None = None record_extractor: RecordExtractor = field( init=False, repr=False, default_factory=lambda: ResponseToFileExtractor({}) ) - url_requester: Optional[Requester] = ( + url_requester: Requester | None = ( None # use it in case polling_requester provides some and extra request is needed to obtain list of urls to download from ) def __post_init__(self) -> None: - self._create_job_response_by_id: Dict[str, Response] = {} - self._polling_job_response_by_id: Dict[str, Response] = {} + self._create_job_response_by_id: dict[str, Response] = {} + self._polling_job_response_by_id: dict[str, Response] = {} def _get_validated_polling_response(self, stream_slice: StreamSlice) -> requests.Response: """ @@ -70,7 +72,7 @@ def _get_validated_polling_response(self, stream_slice: StreamSlice) -> requests AirbyteTracedException: If the polling request returns an empty response. """ - polling_response: Optional[requests.Response] = self.polling_requester.send_request( + polling_response: requests.Response | None = self.polling_requester.send_request( stream_slice=stream_slice ) if polling_response is None: @@ -117,7 +119,7 @@ def _start_job_and_validate_response(self, stream_slice: StreamSlice) -> request AirbyteTracedException: If no response is received from the creation requester. 
""" - response: Optional[requests.Response] = self.creation_requester.send_request( + response: requests.Response | None = self.creation_requester.send_request( stream_slice=stream_slice ) if not response: @@ -168,13 +170,13 @@ def update_jobs_status(self, jobs: Iterable[AsyncJob]) -> None: lazy_log( LOGGER, logging.DEBUG, - lambda: f"Status of job {job.api_job_id()} changed from {job.status()} to {job_status}", + lambda: f"Status of job {job.api_job_id()} changed from {job.status()} to {job_status}", # noqa: B023 ) else: lazy_log( LOGGER, logging.DEBUG, - lambda: f"Status of job {job.api_job_id()} is still {job.status()}", + lambda: f"Status of job {job.api_job_id()} is still {job.status()}", # noqa: B023 ) job.update_status(job_status) @@ -206,7 +208,7 @@ def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: elif isinstance(message, AirbyteMessage): if message.type == Type.RECORD: yield message.record.data # type: ignore # message.record won't be None here as the message is a record - elif isinstance(message, (dict, Mapping)): + elif isinstance(message, (dict, Mapping)): # noqa: UP038 yield message else: raise TypeError(f"Unknown type `{type(message)}` for message") diff --git a/airbyte_cdk/sources/declarative/requesters/http_requester.py b/airbyte_cdk/sources/declarative/requesters/http_requester.py index 35d4b0f11..49615b4c1 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_requester.py +++ b/airbyte_cdk/sources/declarative/requesters/http_requester.py @@ -4,8 +4,9 @@ import logging import os +from collections.abc import Callable, Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Callable, Mapping, MutableMapping, Optional, Union +from typing import Any from urllib.parse import urljoin import requests @@ -47,16 +48,16 @@ class HttpRequester(Requester): """ name: str - url_base: Union[InterpolatedString, str] - path: Union[InterpolatedString, str] + url_base: InterpolatedString | str + path: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - authenticator: Optional[DeclarativeAuthenticator] = None - http_method: Union[str, HttpMethod] = HttpMethod.GET - request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None - error_handler: Optional[ErrorHandler] = None + authenticator: DeclarativeAuthenticator | None = None + http_method: str | HttpMethod = HttpMethod.GET + request_options_provider: InterpolatedRequestOptionsProvider | None = None + error_handler: ErrorHandler | None = None disable_retries: bool = False - message_repository: MessageRepository = NoopMessageRepository() + message_repository: MessageRepository = NoopMessageRepository() # noqa: RUF009 use_cache: bool = False _exit_on_rate_limit: bool = False stream_response: bool = False @@ -110,14 +111,14 @@ def get_authenticator(self) -> DeclarativeAuthenticator: return self._authenticator def get_url_base(self) -> str: - return os.path.join(self._url_base.eval(self.config), "") + return os.path.join(self._url_base.eval(self.config), "") # noqa: PTH118 def get_path( self, *, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, ) -> str: kwargs = { "stream_state": stream_state, @@ -133,9 +134,9 @@ def get_method(self) -> HttpMethod: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - 
stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return self._request_options_provider.get_request_params( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token @@ -144,9 +145,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._request_options_provider.get_request_headers( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token @@ -156,10 +157,10 @@ def get_request_headers( def get_request_body_data( # type: ignore self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return ( self._request_options_provider.get_request_body_data( stream_state=stream_state, @@ -173,10 +174,10 @@ def get_request_body_data( # type: ignore def get_request_body_json( # type: ignore self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | None: return self._request_options_provider.get_request_body_json( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token ) @@ -187,36 +188,34 @@ def logger(self) -> logging.Logger: def _get_request_options( self, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - requester_method: Callable[..., Optional[Union[Mapping[str, Any], str]]], - auth_options_method: Callable[..., Optional[Union[Mapping[str, Any], str]]], - extra_options: Optional[Union[Mapping[str, Any], str]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + requester_method: Callable[..., Mapping[str, Any] | str | None], + auth_options_method: Callable[..., Mapping[str, Any] | str | None], + extra_options: Mapping[str, Any] | str | None = None, + ) -> Mapping[str, Any] | str: """ Get the request_option from the requester, the authenticator and extra_options passed in. 
Raise a ValueError if there's a key collision Returned merged mapping otherwise """ - return combine_mappings( - [ - requester_method( - stream_state=stream_state, - stream_slice=stream_slice, - next_page_token=next_page_token, - ), - auth_options_method(), - extra_options, - ] - ) + return combine_mappings([ + requester_method( + stream_state=stream_state, + stream_slice=stream_slice, + next_page_token=next_page_token, + ), + auth_options_method(), + extra_options, + ]) def _request_headers( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - extra_headers: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + extra_headers: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Specifies request headers. @@ -231,15 +230,15 @@ def _request_headers( extra_headers, ) if isinstance(headers, str): - raise ValueError("Request headers cannot be a string") + raise ValueError("Request headers cannot be a string") # noqa: TRY004 return {str(k): str(v) for k, v in headers.items()} def _request_params( self, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - extra_params: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + extra_params: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. @@ -255,11 +254,11 @@ def _request_params( extra_params, ) if isinstance(options, str): - raise ValueError("Request params cannot be a string") + raise ValueError("Request params cannot be a string") # noqa: TRY004 for k, v in options.items(): - if isinstance(v, (dict,)): - raise ValueError( + if isinstance(v, (dict,)): # noqa: UP038 + raise ValueError( # noqa: TRY004 f"Invalid value for `{k}` parameter. The values of request params cannot be an object." ) @@ -267,11 +266,11 @@ def _request_params( def _request_body_data( self, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - extra_body_data: Optional[Union[Mapping[str, Any], str]] = None, - ) -> Optional[Union[Mapping[str, Any], str]]: + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + extra_body_data: Mapping[str, Any] | str | None = None, + ) -> Mapping[str, Any] | str | None: """ Specifies how to populate the body of the request with a non-JSON payload. @@ -293,11 +292,11 @@ def _request_body_data( def _request_body_json( self, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - extra_body_json: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + extra_body_json: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | None: """ Specifies how to populate the body of the request with a JSON payload. 
@@ -313,26 +312,26 @@ def _request_body_json( extra_body_json, ) if isinstance(options, str): - raise ValueError("Request body json cannot be a string") + raise ValueError("Request body json cannot be a string") # noqa: TRY004 return options @classmethod def _join_url(cls, url_base: str, path: str) -> str: return urljoin(url_base, path) - def send_request( + def send_request( # noqa: PLR0913, PLR0917 self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - path: Optional[str] = None, - request_headers: Optional[Mapping[str, Any]] = None, - request_params: Optional[Mapping[str, Any]] = None, - request_body_data: Optional[Union[Mapping[str, Any], str]] = None, - request_body_json: Optional[Mapping[str, Any]] = None, - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - ) -> Optional[requests.Response]: - request, response = self._http_client.send_request( + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + path: str | None = None, + request_headers: Mapping[str, Any] | None = None, + request_params: Mapping[str, Any] | None = None, + request_body_data: Mapping[str, Any] | str | None = None, + request_body_json: Mapping[str, Any] | None = None, + log_formatter: Callable[[requests.Response], Any] | None = None, + ) -> requests.Response | None: + request, response = self._http_client.send_request( # noqa: F841 http_method=self.get_method().value, url=self._join_url( self.get_url_base(), diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/__init__.py b/airbyte_cdk/sources/declarative/requesters/paginators/__init__.py index 3b077ec0c..9c5463752 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/__init__.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/__init__.py @@ -12,6 +12,7 @@ PaginationStrategy, ) + __all__ = [ "DefaultPaginator", "NoPagination", diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py index 59255c75b..79fadee6a 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any import requests @@ -97,13 +98,13 @@ class DefaultPaginator(Paginator): pagination_strategy: PaginationStrategy config: Config - url_base: Union[InterpolatedString, str] + url_base: InterpolatedString | str parameters: InitVar[Mapping[str, Any]] decoder: Decoder = field( default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) ) - page_size_option: Optional[RequestOption] = None - page_token_option: Optional[Union[RequestPath, RequestOption]] = None + page_size_option: RequestOption | None = None + page_token_option: RequestPath | RequestOption | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.page_size_option and not self.pagination_strategy.get_page_size(): @@ -113,7 +114,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.url_base, str): self.url_base = InterpolatedString(string=self.url_base, parameters=parameters) - def get_initial_token(self) -> Optional[Any]: + def get_initial_token(self) -> Any | None: # noqa: ANN401 """ Return the page token that should be used for the first request of a stream @@ -126,9 +127,9 @@ def next_page_token( self, response: requests.Response, last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any] = None, - ) -> Optional[Mapping[str, Any]]: + last_record: Record | None, + last_page_token_value: Any | None = None, # noqa: ANN401 + ) -> Mapping[str, Any] | None: next_page_token = self.pagination_strategy.next_page_token( response=response, last_page_size=last_page_size, @@ -137,55 +138,53 @@ def next_page_token( ) if next_page_token: return {"next_page_token": next_page_token} - else: - return None + return None - def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]: + def path(self, next_page_token: Mapping[str, Any] | None) -> str | None: token = next_page_token.get("next_page_token") if next_page_token else None if token and self.page_token_option and isinstance(self.page_token_option, RequestPath): # Replace url base to only return the path return str(token).replace(self.url_base.eval(self.config), "") # type: ignore # url_base is casted to a InterpolatedString in __post_init__ - else: - return None + return None def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: return self._get_request_options(RequestOptionType.request_parameter, next_page_token) def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, str]: return self._get_request_options(RequestOptionType.header, next_page_token) def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + 
stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_data, next_page_token) def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_json, next_page_token) def _get_request_options( - self, option_type: RequestOptionType, next_page_token: Optional[Mapping[str, Any]] + self, option_type: RequestOptionType, next_page_token: Mapping[str, Any] | None ) -> MutableMapping[str, Any]: options = {} @@ -228,7 +227,7 @@ def __init__(self, decorated: Paginator, maximum_number_of_pages: int = 5) -> No self._decorated = decorated self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL - def get_initial_token(self) -> Optional[Any]: + def get_initial_token(self) -> Any | None: # noqa: ANN401 self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL return self._decorated.get_initial_token() @@ -236,9 +235,9 @@ def next_page_token( self, response: requests.Response, last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any] = None, - ) -> Optional[Mapping[str, Any]]: + last_record: Record | None, + last_page_token_value: Any | None = None, # noqa: ANN401 + ) -> Mapping[str, Any] | None: if self._page_count >= self._maximum_number_of_pages: return None @@ -247,15 +246,15 @@ def next_page_token( response, last_page_size, last_record, last_page_token_value ) - def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]: + def path(self, next_page_token: Mapping[str, Any] | None) -> str | None: return self._decorated.path(next_page_token) def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._decorated.get_request_params( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token @@ -264,9 +263,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, str]: return self._decorated.get_request_headers( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token @@ -275,10 +274,10 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return self._decorated.get_request_body_data( 
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token ) @@ -286,9 +285,9 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._decorated.get_request_body_json( stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py index 7de91f5e9..35973b3ec 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any import requests @@ -19,53 +20,53 @@ class NoPagination(Paginator): parameters: InitVar[Mapping[str, Any]] - def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]: + def path(self, next_page_token: Mapping[str, Any] | None) -> str | None: # noqa: ARG002 return None def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> MutableMapping[str, Any]: return {} def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, str]: return {} def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 + ) -> Mapping[str, Any] | str: return {} def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} - def get_initial_token(self) -> Optional[Any]: + def get_initial_token(self) -> Any | None: # noqa: ANN401 return None def next_page_token( self, - response: requests.Response, - last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any], - ) -> Optional[Mapping[str, Any]]: + response: requests.Response, # noqa: ARG002 + last_page_size: int, # noqa: ARG002 + 
last_record: Record | None, # noqa: ARG002 + last_page_token_value: Any | None, # noqa: ANN401, ARG002 + ) -> Mapping[str, Any] | None: return {} diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py index 8b1fea69b..80b42d41d 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py @@ -3,8 +3,9 @@ # from abc import ABC, abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Mapping, Optional +from typing import Any import requests @@ -24,7 +25,7 @@ class Paginator(ABC, RequestOptionsProvider): """ @abstractmethod - def get_initial_token(self) -> Optional[Any]: + def get_initial_token(self) -> Any | None: # noqa: ANN401 """ Get the page token that should be included in the request to get the first page of records """ @@ -34,9 +35,9 @@ def next_page_token( self, response: requests.Response, last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any], - ) -> Optional[Mapping[str, Any]]: + last_record: Record | None, + last_page_token_value: Any | None, # noqa: ANN401 + ) -> Mapping[str, Any] | None: """ Returns the next_page_token to use to fetch the next page of records. @@ -49,7 +50,7 @@ def next_page_token( pass @abstractmethod - def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]: + def path(self, next_page_token: Mapping[str, Any] | None) -> str | None: """ Returns the URL path to hit to fetch the next page of records diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/__init__.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/__init__.py index c1f9ff105..ae2663571 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/__init__.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/__init__.py @@ -16,6 +16,7 @@ StopConditionPaginationStrategyDecorator, ) + __all__ = [ "CursorPaginationStrategy", "CursorStopCondition", diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py index e35c84c7c..34fd6b9d1 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any import requests @@ -33,11 +34,11 @@ class CursorPaginationStrategy(PaginationStrategy): decoder (Decoder): decoder to decode the response """ - cursor_value: Union[InterpolatedString, str] + cursor_value: InterpolatedString | str config: Config parameters: InitVar[Mapping[str, Any]] - page_size: Optional[int] = None - stop_condition: Optional[Union[InterpolatedBoolean, str]] = None + page_size: int | None = None + stop_condition: InterpolatedBoolean | str | None = None decoder: Decoder = field( default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) ) @@ -48,14 +49,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: else: self._cursor_value = self.cursor_value if isinstance(self.stop_condition, str): - self._stop_condition: Optional[InterpolatedBoolean] = InterpolatedBoolean( + self._stop_condition: InterpolatedBoolean | None = InterpolatedBoolean( condition=self.stop_condition, parameters=parameters ) else: self._stop_condition = self.stop_condition @property - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: # noqa: ANN401 """ CursorPaginationStrategy does not have an initial value because the next cursor is typically included in the response of the first request. For Resumable Full Refresh streams that checkpoint the page @@ -67,14 +68,14 @@ def next_page_token( self, response: requests.Response, last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any] = None, - ) -> Optional[Any]: + last_record: Record | None, + last_page_token_value: Any | None = None, # noqa: ANN401, ARG002 + ) -> Any | None: # noqa: ANN401 decoded_response = next(self.decoder.decode(response)) # The default way that link is presented in requests.Response is a string of various links (last, next, etc). This # is not indexable or useful for parsing the cursor, so we replace it with the link dictionary from response.links - headers: Dict[str, Any] = dict(response.headers) + headers: dict[str, Any] = dict(response.headers) headers["link"] = response.links if self._stop_condition: should_stop = self._stop_condition.eval( @@ -93,7 +94,7 @@ def next_page_token( last_record=last_record, last_page_size=last_page_size, ) - return token if token else None + return token if token else None # noqa: FURB110 - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: return self.page_size diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py index 512d8143c..b61755257 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Optional, Union +from typing import Any import requests @@ -44,7 +45,7 @@ class OffsetIncrement(PaginationStrategy): """ config: Config - page_size: Optional[Union[str, int]] + page_size: str | int | None parameters: InitVar[Mapping[str, Any]] decoder: Decoder = field( default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={})) @@ -54,14 +55,14 @@ class OffsetIncrement(PaginationStrategy): def __post_init__(self, parameters: Mapping[str, Any]) -> None: page_size = str(self.page_size) if isinstance(self.page_size, int) else self.page_size if page_size: - self._page_size: Optional[InterpolatedString] = InterpolatedString( + self._page_size: InterpolatedString | None = InterpolatedString( page_size, parameters=parameters ) else: self._page_size = None @property - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: # noqa: ANN401 if self.inject_on_first_request: return 0 return None @@ -70,9 +71,9 @@ def next_page_token( self, response: requests.Response, last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any] = None, - ) -> Optional[Any]: + last_record: Record | None, # noqa: ARG002 + last_page_token_value: Any | None = None, # noqa: ANN401 + ) -> Any | None: # noqa: ANN401 decoded_response = next(self.decoder.decode(response)) # Stop paginating when there are fewer records than the page size or the current page has no records @@ -81,22 +82,20 @@ def next_page_token( and last_page_size < self._page_size.eval(self.config, response=decoded_response) ) or last_page_size == 0: return None - elif last_page_token_value is None: + if last_page_token_value is None: # If the OffsetIncrement strategy does not inject on the first request, the incoming last_page_token_value # will be None. For this case, we assume that None was the first page and progress to the next offset return 0 + last_page_size - elif not isinstance(last_page_token_value, int): - raise ValueError( + if not isinstance(last_page_token_value, int): + raise ValueError( # noqa: TRY004 f"Last page token value {last_page_token_value} for OffsetIncrement pagination strategy was not an integer" ) - else: - return last_page_token_value + last_page_size + return last_page_token_value + last_page_size - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: if self._page_size: page_size = self._page_size.eval(self.config) if not isinstance(page_size, int): - raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}") + raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}") # noqa: TRY002 return page_size - else: - return None + return None diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py index 2e1643b56..8e2762f42 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
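The OffsetIncrement change above keeps the same semantics as before: the next offset is the previous token plus the size of the last page, the very first page (token None) is treated as offset 0, and pagination stops once a page comes back empty or smaller than page_size. A small standalone sketch of that progression with a hypothetical page_size of 2 (illustrative only, not the CDK class):

    # Standalone sketch of the OffsetIncrement progression described above.
    def next_offset(last_page_size: int, last_offset: int | None, page_size: int) -> int | None:
        if last_page_size < page_size or last_page_size == 0:
            return None  # short or empty page -> stop paginating
        if last_offset is None:
            return 0 + last_page_size  # first page was fetched without an offset
        return last_offset + last_page_size


    PAGE_SIZE = 2
    print(next_offset(2, None, PAGE_SIZE))  # 2
    print(next_offset(2, 2, PAGE_SIZE))     # 4
    print(next_offset(1, 4, PAGE_SIZE))     # None -> stop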
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any import requests @@ -25,7 +26,7 @@ class PageIncrement(PaginationStrategy): """ config: Config - page_size: Optional[Union[str, int]] + page_size: str | int | None parameters: InitVar[Mapping[str, Any]] start_from_page: int = 0 inject_on_first_request: bool = False @@ -36,36 +37,35 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: else: page_size = InterpolatedString(self.page_size, parameters=parameters).eval(self.config) if not isinstance(page_size, int): - raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}") + raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}") # noqa: TRY002 self._page_size = page_size @property - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: # noqa: ANN401 if self.inject_on_first_request: return self.start_from_page return None def next_page_token( self, - response: requests.Response, + response: requests.Response, # noqa: ARG002 last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any], - ) -> Optional[Any]: + last_record: Record | None, # noqa: ARG002 + last_page_token_value: Any | None, # noqa: ANN401 + ) -> Any | None: # noqa: ANN401 # Stop paginating when there are fewer records than the page size or the current page has no records if (self._page_size and last_page_size < self._page_size) or last_page_size == 0: return None - elif last_page_token_value is None: + if last_page_token_value is None: # If the PageIncrement strategy does not inject on the first request, the incoming last_page_token_value # may be None. When this is the case, we assume we've already requested the first page specified by # start_from_page and must now get the next page return self.start_from_page + 1 - elif not isinstance(last_page_token_value, int): - raise ValueError( + if not isinstance(last_page_token_value, int): + raise ValueError( # noqa: TRY004 f"Last page token value {last_page_token_value} for PageIncrement pagination strategy was not an integer" ) - else: - return last_page_token_value + 1 + return last_page_token_value + 1 - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: return self._page_size diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py index dae02ba13..04945ea28 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py @@ -4,7 +4,7 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import requests @@ -19,7 +19,7 @@ class PaginationStrategy: @property @abstractmethod - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: # noqa: ANN401 """ Return the initial value of the token """ @@ -29,9 +29,9 @@ def next_page_token( self, response: requests.Response, last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any], - ) -> Optional[Any]: + last_record: Record | None, + last_page_token_value: Any | None, # noqa: ANN401 + ) -> Any | None: # noqa: ANN401 """ :param response: response to process :param last_page_size: the number of records read 
from the response @@ -42,7 +42,7 @@ def next_page_token( pass @abstractmethod - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: """ :return: page size: The number of records to fetch in a page. Returns None if unspecified """ diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py index 7c89ba552..65a241ab4 100644 --- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py +++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py @@ -3,7 +3,7 @@ # from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any import requests @@ -23,11 +23,11 @@ def is_met(self, record: Record) -> bool: :param record: a record used to evaluate the condition """ - raise NotImplementedError() + raise NotImplementedError class CursorStopCondition(PaginationStopCondition): - def __init__( + def __init__( # noqa: ANN204 self, cursor: DeclarativeCursor | ConcurrentCursor, # migrate to use both old and concurrent versions @@ -39,7 +39,7 @@ def is_met(self, record: Record) -> bool: class StopConditionPaginationStrategyDecorator(PaginationStrategy): - def __init__(self, _delegate: PaginationStrategy, stop_condition: PaginationStopCondition): + def __init__(self, _delegate: PaginationStrategy, stop_condition: PaginationStopCondition): # noqa: ANN204 self._delegate = _delegate self._stop_condition = stop_condition @@ -47,9 +47,9 @@ def next_page_token( self, response: requests.Response, last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any] = None, - ) -> Optional[Any]: + last_record: Record | None, + last_page_token_value: Any | None = None, # noqa: ANN401 + ) -> Any | None: # noqa: ANN401 # We evaluate in reverse order because the assumption is that most of the APIs using data feed structure # will return records in descending order. In terms of performance/memory, we return the records lazily if last_record and self._stop_condition.is_met(last_record): @@ -58,9 +58,9 @@ def next_page_token( response, last_page_size, last_record, last_page_token_value ) - def get_page_size(self) -> Optional[int]: + def get_page_size(self) -> int | None: return self._delegate.get_page_size() @property - def initial_token(self) -> Optional[Any]: + def initial_token(self) -> Any | None: # noqa: ANN401 return self._delegate.initial_token diff --git a/airbyte_cdk/sources/declarative/requesters/request_option.py b/airbyte_cdk/sources/declarative/requesters/request_option.py index d13d20566..9f8699a9e 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_option.py +++ b/airbyte_cdk/sources/declarative/requesters/request_option.py @@ -2,9 +2,10 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
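StopConditionPaginationStrategyDecorator, refactored above, wraps another PaginationStrategy and returns None as soon as the last record read satisfies the stop condition (typically "this record is older than what the cursor has already seen"); otherwise it defers to the wrapped strategy. A simplified standalone sketch of that decorator pattern, with hypothetical names and a string-compared date for brevity (not the CDK classes):

    # Simplified sketch of the stop-condition decorator pattern used above.
    from collections.abc import Callable


    class StopOnCondition:
        def __init__(self, delegate_next: Callable[[int], int], is_met: Callable[[dict], bool]):
            self._delegate_next = delegate_next
            self._is_met = is_met

        def next_page_token(self, last_record: dict | None, last_token: int) -> int | None:
            if last_record and self._is_met(last_record):
                return None  # condition met -> stop before consulting the delegate
            return self._delegate_next(last_token)


    # e.g. stop once records are older than a cursor value that was already synced
    paginator = StopOnCondition(
        delegate_next=lambda offset: offset + 10,
        is_met=lambda record: record["updated_at"] < "2024-01-01",
    )
    print(paginator.next_page_token({"updated_at": "2024-06-01"}, 10))  # 20
    print(paginator.next_page_token({"updated_at": "2023-12-31"}, 20))  # None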
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass from enum import Enum -from typing import Any, Mapping, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -30,7 +31,7 @@ class RequestOption: inject_into (RequestOptionType): Describes where in the HTTP request to inject the parameter """ - field_name: Union[InterpolatedString, str] + field_name: InterpolatedString | str inject_into: RequestOptionType parameters: InitVar[Mapping[str, Any]] diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/__init__.py b/airbyte_cdk/sources/declarative/requesters/request_options/__init__.py index a63705832..b77f51ee5 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/__init__.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/__init__.py @@ -15,6 +15,7 @@ RequestOptionsProvider, ) + __all__ = [ "DatetimeBasedRequestOptionsProvider", "DefaultRequestOptionsProvider", diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py index 05e06db71..846b5af16 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py @@ -2,8 +2,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.requesters.request_option import ( @@ -25,10 +26,10 @@ class DatetimeBasedRequestOptionsProvider(RequestOptionsProvider): config: Config parameters: InitVar[Mapping[str, Any]] - start_time_option: Optional[RequestOption] = None - end_time_option: Optional[RequestOption] = None - partition_field_start: Optional[str] = None - partition_field_end: Optional[str] = None + start_time_option: RequestOption | None = None + end_time_option: RequestOption | None = None + partition_field_start: str | None = None + partition_field_end: str | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._partition_field_start = InterpolatedString.create( @@ -41,41 +42,41 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.request_parameter, stream_slice) def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.header, stream_slice) def 
get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 + ) -> Mapping[str, Any] | str: return self._get_request_options(RequestOptionType.body_data, stream_slice) def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return self._get_request_options(RequestOptionType.body_json, stream_slice) def _get_request_options( - self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice] + self, option_type: RequestOptionType, stream_slice: StreamSlice | None ) -> Mapping[str, Any]: options: MutableMapping[str, Any] = {} if not stream_slice: diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py index 449da977f..f2d04ac0b 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py @@ -2,8 +2,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( RequestOptionsProvider, @@ -26,35 +27,35 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 + ) -> Mapping[str, Any] | str: return {} def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, # noqa: 
ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: return {} diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py index 6403417c9..16cfa0fc8 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import ( InterpolatedNestedMapping, @@ -20,14 +21,12 @@ class InterpolatedNestedRequestInputProvider: """ parameters: InitVar[Mapping[str, Any]] - request_inputs: Optional[Union[str, NestedMapping]] = field(default=None) + request_inputs: str | NestedMapping | None = field(default=None) config: Config = field(default_factory=dict) - _interpolator: Optional[Union[InterpolatedString, InterpolatedNestedMapping]] = field( - init=False, repr=False, default=None - ) - _request_inputs: Optional[Union[str, NestedMapping]] = field( + _interpolator: InterpolatedString | InterpolatedNestedMapping | None = field( init=False, repr=False, default=None ) + _request_inputs: str | NestedMapping | None = field(init=False, repr=False, default=None) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._request_inputs = self.request_inputs or {} @@ -42,9 +41,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def eval_request_inputs( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Returns the request inputs to set on an outgoing HTTP request diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py index 0278df351..aabdbbaba 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Optional, Tuple, Type, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -17,14 +18,12 @@ class InterpolatedRequestInputProvider: """ parameters: InitVar[Mapping[str, Any]] - request_inputs: Optional[Union[str, Mapping[str, str]]] = field(default=None) + request_inputs: str | Mapping[str, str] | None = field(default=None) config: Config = field(default_factory=dict) - _interpolator: Optional[Union[InterpolatedString, InterpolatedMapping]] = field( - init=False, repr=False, default=None - ) - _request_inputs: Optional[Union[str, Mapping[str, str]]] = field( + _interpolator: InterpolatedString | InterpolatedMapping | None = field( init=False, repr=False, default=None ) + _request_inputs: str | Mapping[str, str] | None = field(init=False, repr=False, default=None) def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._request_inputs = self.request_inputs or {} @@ -37,11 +36,11 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def eval_request_inputs( self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - valid_key_types: Optional[Tuple[Type[Any]]] = None, - valid_value_types: Optional[Tuple[Type[Any], ...]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + valid_key_types: tuple[type[Any]] | None = None, + valid_value_types: tuple[type[Any], ...] | None = None, ) -> Mapping[str, Any]: """ Returns the request inputs to set on an outgoing HTTP request diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py index c327b83da..a284a542e 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# +from collections.abc import Mapping, MutableMapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any, Union from typing_extensions import deprecated @@ -20,7 +21,8 @@ from airbyte_cdk.sources.source import ExperimentalClassWarning from airbyte_cdk.sources.types import Config, StreamSlice, StreamState -RequestInput = Union[str, Mapping[str, str]] + +RequestInput = Union[str, Mapping[str, str]] # noqa: UP007 ValidRequestTypes = (str, list) @@ -39,10 +41,10 @@ class InterpolatedRequestOptionsProvider(RequestOptionsProvider): parameters: InitVar[Mapping[str, Any]] config: Config = field(default_factory=dict) - request_parameters: Optional[RequestInput] = None - request_headers: Optional[RequestInput] = None - request_body_data: Optional[RequestInput] = None - request_body_json: Optional[NestedMapping] = None + request_parameters: RequestInput | None = None + request_headers: RequestInput | None = None + request_body_data: RequestInput | None = None + request_body_json: NestedMapping | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: if self.request_parameters is None: @@ -75,9 +77,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: interpolated_value = self._parameter_interpolator.eval_request_inputs( stream_state, @@ -93,9 +95,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._headers_interpolator.eval_request_inputs( stream_state, stream_slice, next_page_token @@ -104,10 +106,10 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: return self._body_data_interpolator.eval_request_inputs( stream_state, stream_slice, @@ -119,9 +121,9 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: return self._body_json_interpolator.eval_request_inputs( stream_state, stream_slice, next_page_token @@ -147,18 +149,17 @@ def request_options_contain_stream_state(self) -> bool: @staticmethod def _check_if_interpolation_uses_stream_state( - request_input: Optional[Union[RequestInput, NestedMapping]], + request_input: RequestInput | NestedMapping | None, ) -> bool: if not request_input: return False - elif isinstance(request_input, 
str): + if isinstance(request_input, str): return "stream_state" in request_input - else: - for key, val in request_input.items(): - # Covers the case of RequestInput in the form of a string or Mapping[str, str]. It also covers the case - # of a NestedMapping where the value is a string. - # Note: Doesn't account for nested mappings for request_body_json, but I don't see stream_state used in that way - # in our code - if "stream_state" in key or (isinstance(val, str) and "stream_state" in val): - return True + for key, val in request_input.items(): + # Covers the case of RequestInput in the form of a string or Mapping[str, str]. It also covers the case + # of a NestedMapping where the value is a string. + # Note: Doesn't account for nested mappings for request_body_json, but I don't see stream_state used in that way + # in our code + if "stream_state" in key or (isinstance(val, str) and "stream_state" in val): + return True return False diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py index f0a94ecb9..f5147d481 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py @@ -3,8 +3,9 @@ # from abc import abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Mapping, Optional, Union +from typing import Any from airbyte_cdk.sources.types import StreamSlice, StreamState @@ -25,9 +26,9 @@ class RequestOptionsProvider: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. @@ -40,9 +41,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method.""" @@ -50,10 +51,10 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: """ Specifies how to populate the body of the request with a non-JSON payload. 
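The request-options providers in the hunks above all share one job: evaluate templated request_parameters / headers / body definitions against the current stream_state, stream_slice and next_page_token to produce the concrete values for a single HTTP call. A rough standalone illustration of that evaluation, using str.format as a stand-in for the CDK's interpolation engine (keys and templates are invented for the example):

    # Rough illustration of interpolated request options; str.format stands in
    # for the CDK's real interpolation, and the templates are made up.
    from collections.abc import Mapping
    from typing import Any


    def eval_request_inputs(
        templates: Mapping[str, str],
        stream_slice: Mapping[str, Any],
        next_page_token: Mapping[str, Any],
    ) -> dict[str, str]:
        context = {"stream_slice": stream_slice, "next_page_token": next_page_token}
        return {key: template.format(**context) for key, template in templates.items()}


    params = eval_request_inputs(
        {"updated_since": "{stream_slice[start_time]}", "page": "{next_page_token[next_page_token]}"},
        stream_slice={"start_time": "2024-01-01"},
        next_page_token={"next_page_token": 3},
    )
    print(params)  # {'updated_since': '2024-01-01', 'page': '3'}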
@@ -68,9 +69,9 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Specifies how to populate the body of the request with a JSON payload. diff --git a/airbyte_cdk/sources/declarative/requesters/request_path.py b/airbyte_cdk/sources/declarative/requesters/request_path.py index 378ea6220..d6b289f9d 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_path.py +++ b/airbyte_cdk/sources/declarative/requesters/request_path.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping +from typing import Any @dataclass diff --git a/airbyte_cdk/sources/declarative/requesters/requester.py b/airbyte_cdk/sources/declarative/requesters/requester.py index 604b2faba..030cb5155 100644 --- a/airbyte_cdk/sources/declarative/requesters/requester.py +++ b/airbyte_cdk/sources/declarative/requesters/requester.py @@ -3,8 +3,9 @@ # from abc import abstractmethod +from collections.abc import Callable, Mapping, MutableMapping from enum import Enum -from typing import Any, Callable, Mapping, MutableMapping, Optional, Union +from typing import Any import requests @@ -44,9 +45,9 @@ def get_url_base(self) -> str: def get_path( self, *, - stream_state: Optional[StreamState], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], + stream_state: StreamState | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, ) -> str: """ Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity" @@ -62,9 +63,9 @@ def get_method(self) -> HttpMethod: def get_request_params( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> MutableMapping[str, Any]: """ Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. @@ -76,9 +77,9 @@ def get_request_params( def get_request_headers( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method. @@ -88,10 +89,10 @@ def get_request_headers( def get_request_body_data( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: """ Specifies how to populate the body of the request with a non-JSON payload. 
@@ -106,9 +107,9 @@ def get_request_body_data( def get_request_body_json( self, *, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Specifies how to populate the body of the request with a JSON payload. @@ -117,18 +118,18 @@ def get_request_body_json( """ @abstractmethod - def send_request( + def send_request( # noqa: PLR0913, PLR0917 self, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - path: Optional[str] = None, - request_headers: Optional[Mapping[str, Any]] = None, - request_params: Optional[Mapping[str, Any]] = None, - request_body_data: Optional[Union[Mapping[str, Any], str]] = None, - request_body_json: Optional[Mapping[str, Any]] = None, - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - ) -> Optional[requests.Response]: + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + path: str | None = None, + request_headers: Mapping[str, Any] | None = None, + request_params: Mapping[str, Any] | None = None, + request_body_data: Mapping[str, Any] | str | None = None, + request_body_json: Mapping[str, Any] | None = None, + log_formatter: Callable[[requests.Response], Any] | None = None, + ) -> requests.Response | None: """ Sends a request and returns the response. Might return no response if the error handler chooses to ignore the response or throw an exception in case of an error. If path is set, the path configured on the requester itself is ignored. diff --git a/airbyte_cdk/sources/declarative/resolvers/__init__.py b/airbyte_cdk/sources/declarative/resolvers/__init__.py index dba2f60b8..ba2de5695 100644 --- a/airbyte_cdk/sources/declarative/resolvers/__init__.py +++ b/airbyte_cdk/sources/declarative/resolvers/__init__.py @@ -2,7 +2,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# -from typing import Mapping +from collections.abc import Mapping from pydantic.v1 import BaseModel @@ -25,6 +25,7 @@ HttpComponentsResolver, ) + COMPONENTS_RESOLVER_TYPE_MAPPING: Mapping[str, type[BaseModel]] = { "HttpComponentsResolver": HttpComponentsResolverModel, "ConfigComponentsResolver": ConfigComponentsResolverModel, diff --git a/airbyte_cdk/sources/declarative/resolvers/components_resolver.py b/airbyte_cdk/sources/declarative/resolvers/components_resolver.py index 5975b3082..6507b99d3 100644 --- a/airbyte_cdk/sources/declarative/resolvers/components_resolver.py +++ b/airbyte_cdk/sources/declarative/resolvers/components_resolver.py @@ -3,12 +3,13 @@ # from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, Iterable, List, Mapping, Optional, Type, Union +from typing import Any, Union from typing_extensions import deprecated -from airbyte_cdk.sources.declarative.interpolation import InterpolatedString +from airbyte_cdk.sources.declarative.interpolation import InterpolatedString # noqa: TC001 from airbyte_cdk.sources.source import ExperimentalClassWarning @@ -18,9 +19,9 @@ class ComponentMappingDefinition: what field in the stream template should be updated with value, supporting dynamic interpolation and type enforcement.""" - field_path: List["InterpolatedString"] + field_path: list["InterpolatedString"] value: Union["InterpolatedString", str] - value_type: Optional[Type[Any]] + value_type: type[Any] | None parameters: InitVar[Mapping[str, Any]] @@ -30,9 +31,9 @@ class ResolvedComponentMappingDefinition: what field in the stream template should be updated with value, supporting dynamic interpolation and type enforcement.""" - field_path: List["InterpolatedString"] + field_path: list["InterpolatedString"] value: "InterpolatedString" - value_type: Optional[Type[Any]] + value_type: type[Any] | None parameters: InitVar[Mapping[str, Any]] @@ -45,8 +46,8 @@ class ComponentsResolver(ABC): @abstractmethod def resolve_components( - self, stream_template_config: Dict[str, Any] - ) -> Iterable[Dict[str, Any]]: + self, stream_template_config: dict[str, Any] + ) -> Iterable[dict[str, Any]]: """ Maps and populates values into a stream template configuration. :param stream_template_config: The stream template with placeholders for components. diff --git a/airbyte_cdk/sources/declarative/resolvers/config_components_resolver.py b/airbyte_cdk/sources/declarative/resolvers/config_components_resolver.py index 0308ea5da..c0b276ab3 100644 --- a/airbyte_cdk/sources/declarative/resolvers/config_components_resolver.py +++ b/airbyte_cdk/sources/declarative/resolvers/config_components_resolver.py @@ -2,9 +2,10 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from collections.abc import Iterable, Mapping from copy import deepcopy from dataclasses import InitVar, dataclass, field -from typing import Any, Dict, Iterable, List, Mapping, Union +from typing import Any import dpath from typing_extensions import deprecated @@ -26,7 +27,7 @@ class StreamConfig: Identifies stream config details for dynamic schema extraction and processing. 
""" - configs_pointer: List[Union[InterpolatedString, str]] + configs_pointer: list[InterpolatedString | str] parameters: InitVar[Mapping[str, Any]] def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -50,9 +51,9 @@ class ConfigComponentsResolver(ComponentsResolver): stream_config: StreamConfig config: Config - components_mapping: List[ComponentMappingDefinition] + components_mapping: list[ComponentMappingDefinition] parameters: InitVar[Mapping[str, Any]] - _resolved_components: List[ResolvedComponentMappingDefinition] = field( + _resolved_components: list[ResolvedComponentMappingDefinition] = field( init=False, repr=False, default_factory=list ) @@ -65,7 +66,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: """ for component_mapping in self.components_mapping: - if isinstance(component_mapping.value, (str, InterpolatedString)): + if isinstance(component_mapping.value, (str, InterpolatedString)): # noqa: UP038 interpolated_value = ( InterpolatedString.create(component_mapping.value, parameters=parameters) if isinstance(component_mapping.value, str) @@ -86,7 +87,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: ) ) else: - raise ValueError( + raise ValueError( # noqa: TRY004 f"Expected a string or InterpolatedString for value in mapping: {component_mapping}" ) @@ -104,8 +105,8 @@ def _stream_config(self) -> Iterable[Mapping[str, Any]]: return stream_config def resolve_components( - self, stream_template_config: Dict[str, Any] - ) -> Iterable[Dict[str, Any]]: + self, stream_template_config: dict[str, Any] + ) -> Iterable[dict[str, Any]]: """ Resolves components in the stream template configuration by populating values. diff --git a/airbyte_cdk/sources/declarative/resolvers/http_components_resolver.py b/airbyte_cdk/sources/declarative/resolvers/http_components_resolver.py index 6e85fc578..8213c9af0 100644 --- a/airbyte_cdk/sources/declarative/resolvers/http_components_resolver.py +++ b/airbyte_cdk/sources/declarative/resolvers/http_components_resolver.py @@ -2,9 +2,10 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from collections.abc import Iterable, Mapping from copy import deepcopy from dataclasses import InitVar, dataclass, field -from typing import Any, Dict, Iterable, List, Mapping +from typing import Any import dpath from typing_extensions import deprecated @@ -35,9 +36,9 @@ class HttpComponentsResolver(ComponentsResolver): retriever: Retriever config: Config - components_mapping: List[ComponentMappingDefinition] + components_mapping: list[ComponentMappingDefinition] parameters: InitVar[Mapping[str, Any]] - _resolved_components: List[ResolvedComponentMappingDefinition] = field( + _resolved_components: list[ResolvedComponentMappingDefinition] = field( init=False, repr=False, default_factory=list ) @@ -49,7 +50,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: parameters (Mapping[str, Any]): Parameters for interpolation. 
""" for component_mapping in self.components_mapping: - if isinstance(component_mapping.value, (str, InterpolatedString)): + if isinstance(component_mapping.value, (str, InterpolatedString)): # noqa: UP038 interpolated_value = ( InterpolatedString.create(component_mapping.value, parameters=parameters) if isinstance(component_mapping.value, str) @@ -70,13 +71,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: ) ) else: - raise ValueError( + raise ValueError( # noqa: TRY004 f"Expected a string or InterpolatedString for value in mapping: {component_mapping}" ) def resolve_components( - self, stream_template_config: Dict[str, Any] - ) -> Iterable[Dict[str, Any]]: + self, stream_template_config: dict[str, Any] + ) -> Iterable[dict[str, Any]]: """ Resolves components in the stream template configuration by populating values. diff --git a/airbyte_cdk/sources/declarative/retrievers/__init__.py b/airbyte_cdk/sources/declarative/retrievers/__init__.py index 177d141a3..ce6ab5842 100644 --- a/airbyte_cdk/sources/declarative/retrievers/__init__.py +++ b/airbyte_cdk/sources/declarative/retrievers/__init__.py @@ -9,4 +9,5 @@ SimpleRetrieverTestReadDecorator, ) + __all__ = ["Retriever", "SimpleRetriever", "SimpleRetrieverTestReadDecorator", "AsyncRetriever"] diff --git a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py index 1b8860289..f362ce606 100644 --- a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py @@ -1,8 +1,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from collections.abc import Iterable, Mapping from dataclasses import InitVar, dataclass -from typing import Any, Iterable, Mapping, Optional +from typing import Any from typing_extensions import deprecated @@ -58,7 +59,7 @@ def _get_stream_state(self) -> StreamState: return self.state def _validate_and_get_stream_slice_partition( - self, stream_slice: Optional[StreamSlice] = None + self, stream_slice: StreamSlice | None = None ) -> AsyncPartition: """ Validates the stream_slice argument and returns the partition from it. 
@@ -80,13 +81,13 @@ def _validate_and_get_stream_slice_partition( ) return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices - def stream_slices(self) -> Iterable[Optional[StreamSlice]]: + def stream_slices(self) -> Iterable[StreamSlice | None]: return self.stream_slicer.stream_slices() def read_records( self, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, + stream_slice: StreamSlice | None = None, ) -> Iterable[StreamData]: stream_state: StreamState = self._get_stream_state() partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice) diff --git a/airbyte_cdk/sources/declarative/retrievers/retriever.py b/airbyte_cdk/sources/declarative/retrievers/retriever.py index 155de5782..f4ba620b0 100644 --- a/airbyte_cdk/sources/declarative/retrievers/retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/retriever.py @@ -3,7 +3,8 @@ # from abc import abstractmethod -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import StreamSlice from airbyte_cdk.sources.streams.core import StreamData @@ -19,7 +20,7 @@ class Retriever: def read_records( self, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, + stream_slice: StreamSlice | None = None, ) -> Iterable[StreamData]: """ Fetch a stream's records from an HTTP API source @@ -30,7 +31,7 @@ def read_records( """ @abstractmethod - def stream_slices(self) -> Iterable[Optional[StreamSlice]]: + def stream_slices(self) -> Iterable[StreamSlice | None]: """Returns the stream slices""" @property diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index d167a84bc..1413faf23 100644 --- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -3,10 +3,11 @@ # import json +from collections.abc import Callable, Iterable, Mapping from dataclasses import InitVar, dataclass, field from functools import partial from itertools import islice -from typing import Any, Callable, Iterable, List, Mapping, Optional, Set, Tuple, Union +from typing import Any import requests @@ -32,6 +33,7 @@ from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from airbyte_cdk.utils.mapping_helpers import combine_mappings + FULL_REFRESH_SYNC_COMPLETE_KEY = "__ab_full_refresh_sync_complete" @@ -64,17 +66,17 @@ class SimpleRetriever(Retriever): config: Config parameters: InitVar[Mapping[str, Any]] name: str - _name: Union[InterpolatedString, str] = field(init=False, repr=False, default="") - primary_key: Optional[Union[str, List[str], List[List[str]]]] + _name: InterpolatedString | str = field(init=False, repr=False, default="") + primary_key: str | list[str] | list[list[str]] | None _primary_key: str = field(init=False, repr=False, default="") - paginator: Optional[Paginator] = None + paginator: Paginator | None = None stream_slicer: StreamSlicer = field( default_factory=lambda: SinglePartitionRouter(parameters={}) ) request_option_provider: RequestOptionsProvider = field( default_factory=lambda: DefaultRequestOptionsProvider(parameters={}) ) - cursor: Optional[DeclarativeCursor] = None + cursor: DeclarativeCursor | None = None ignore_stream_slicer_parameters_on_paginated_requests: bool = 
False def __post_init__(self, parameters: Mapping[str, Any]) -> None: @@ -103,8 +105,10 @@ def name(self, value: str) -> None: self._name = value def _get_mapping( - self, method: Callable[..., Optional[Union[Mapping[str, Any], str]]], **kwargs: Any - ) -> Tuple[Union[Mapping[str, Any], str], Set[str]]: + self, + method: Callable[..., Mapping[str, Any] | str | None], + **kwargs: Any, # noqa: ANN401 + ) -> tuple[Mapping[str, Any] | str, set[str]]: """ Get mapping from the provided method, and get the keys of the mapping. If the method returns a string, it will return the string and an empty set. @@ -116,18 +120,18 @@ def _get_mapping( def _get_request_options( self, - stream_state: Optional[StreamData], - stream_slice: Optional[StreamSlice], - next_page_token: Optional[Mapping[str, Any]], - paginator_method: Callable[..., Optional[Union[Mapping[str, Any], str]]], - stream_slicer_method: Callable[..., Optional[Union[Mapping[str, Any], str]]], - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamData | None, + stream_slice: StreamSlice | None, + next_page_token: Mapping[str, Any] | None, + paginator_method: Callable[..., Mapping[str, Any] | str | None], + stream_slicer_method: Callable[..., Mapping[str, Any] | str | None], + ) -> Mapping[str, Any] | str: """ Get the request_option from the paginator and the stream slicer. Raise a ValueError if there's a key collision Returned merged mapping otherwise """ - # FIXME we should eventually remove the usage of stream_state as part of the interpolation + # FIXME we should eventually remove the usage of stream_state as part of the interpolation # noqa: FIX001, TD001, TD004 mappings = [ paginator_method( stream_state=stream_state, @@ -147,9 +151,9 @@ def _get_request_options( def _request_headers( self, - stream_state: Optional[StreamData] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamData | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Specifies request headers. @@ -163,14 +167,14 @@ def _request_headers( self.stream_slicer.get_request_headers, ) if isinstance(headers, str): - raise ValueError("Request headers cannot be a string") + raise ValueError("Request headers cannot be a string") # noqa: TRY004 return {str(k): str(v) for k, v in headers.items()} def _request_params( self, - stream_state: Optional[StreamData] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: StreamData | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Mapping[str, Any]: """ Specifies the query parameters that should be set on an outgoing HTTP request given the inputs. 
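The _get_request_options helper above merges the options produced by the paginator with those produced by the stream slicer and, per its docstring, raises a ValueError on a key collision instead of silently overwriting (the CDK delegates the merge to combine_mappings). A compact sketch of that merge rule, written from scratch for illustration rather than copied from the CDK:

    # Illustrative merge-with-collision-check, mirroring the rule described above.
    from collections.abc import Mapping
    from typing import Any


    def merge_options(*mappings: Mapping[str, Any]) -> dict[str, Any]:
        merged: dict[str, Any] = {}
        for mapping in mappings:
            for key, value in mapping.items():
                if key in merged:
                    raise ValueError(f"Duplicate request option: {key}")
                merged[key] = value
        return merged


    print(merge_options({"page": 2}, {"per_page": 100}))  # {'page': 2, 'per_page': 100}
    # merge_options({"page": 2}, {"page": 3}) would raise ValueError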
@@ -185,15 +189,15 @@ def _request_params( self.request_option_provider.get_request_params, ) if isinstance(params, str): - raise ValueError("Request params cannot be a string") + raise ValueError("Request params cannot be a string") # noqa: TRY004 return params def _request_body_data( self, - stream_state: Optional[StreamData] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Union[Mapping[str, Any], str]: + stream_state: StreamData | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | str: """ Specifies how to populate the body of the request with a non-JSON payload. @@ -213,10 +217,10 @@ def _request_body_data( def _request_body_json( self, - stream_state: Optional[StreamData] = None, - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: + stream_state: StreamData | None = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> Mapping[str, Any] | None: """ Specifies how to populate the body of the request with a JSON payload. @@ -230,10 +234,10 @@ def _request_body_json( self.request_option_provider.get_request_body_json, ) if isinstance(body_json, str): - raise ValueError("Request body json cannot be a string") + raise ValueError("Request body json cannot be a string") # noqa: TRY004 return body_json - def _paginator_path(self, next_page_token: Optional[Mapping[str, Any]] = None) -> Optional[str]: + def _paginator_path(self, next_page_token: Mapping[str, Any] | None = None) -> str | None: """ If the paginator points to a path, follow it, else return nothing so the requester is used. :param next_page_token: @@ -243,11 +247,11 @@ def _paginator_path(self, next_page_token: Optional[Mapping[str, Any]] = None) - def _parse_response( self, - response: Optional[requests.Response], + response: requests.Response | None, stream_state: StreamState, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: StreamSlice | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Record]: if not response: yield from [] @@ -261,7 +265,7 @@ def _parse_response( ) @property # type: ignore - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: """The stream's primary key""" return self._primary_key @@ -274,9 +278,9 @@ def _next_page_token( self, response: requests.Response, last_page_size: int, - last_record: Optional[Record], - last_page_token_value: Optional[Any], - ) -> Optional[Mapping[str, Any]]: + last_record: Record | None, + last_page_token_value: Any | None, # noqa: ANN401 + ) -> Mapping[str, Any] | None: """ Specifies a pagination strategy. @@ -295,8 +299,8 @@ def _fetch_next_page( self, stream_state: Mapping[str, Any], stream_slice: StreamSlice, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[requests.Response]: + next_page_token: Mapping[str, Any] | None = None, + ) -> requests.Response | None: return self.requester.send_request( path=self._paginator_path(next_page_token=next_page_token), stream_state=stream_state, @@ -327,20 +331,20 @@ def _fetch_next_page( # This logic is similar to _read_pages in the HttpStream class. When making changes here, consider making changes there as well. 
def _read_pages( self, - records_generator_fn: Callable[[Optional[requests.Response]], Iterable[Record]], + records_generator_fn: Callable[[requests.Response | None], Iterable[Record]], stream_state: Mapping[str, Any], stream_slice: StreamSlice, ) -> Iterable[Record]: pagination_complete = False initial_token = self._paginator.get_initial_token() - next_page_token: Optional[Mapping[str, Any]] = ( + next_page_token: Mapping[str, Any] | None = ( {"next_page_token": initial_token} if initial_token else None ) while not pagination_complete: response = self._fetch_next_page(stream_state, stream_slice, next_page_token) last_page_size = 0 - last_record: Optional[Record] = None + last_record: Record | None = None for record in records_generator_fn(response): last_page_size += 1 last_record = record @@ -366,21 +370,21 @@ def _read_pages( def _read_single_page( self, - records_generator_fn: Callable[[Optional[requests.Response]], Iterable[Record]], + records_generator_fn: Callable[[requests.Response | None], Iterable[Record]], stream_state: Mapping[str, Any], stream_slice: StreamSlice, ) -> Iterable[StreamData]: initial_token = stream_state.get("next_page_token") if initial_token is None: initial_token = self._paginator.get_initial_token() - next_page_token: Optional[Mapping[str, Any]] = ( + next_page_token: Mapping[str, Any] | None = ( {"next_page_token": initial_token} if initial_token else None ) response = self._fetch_next_page(stream_state, stream_slice, next_page_token) last_page_size = 0 - last_record: Optional[Record] = None + last_record: Record | None = None for record in records_generator_fn(response): last_page_size += 1 last_record = record @@ -410,7 +414,7 @@ def _read_single_page( def read_records( self, records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice] = None, + stream_slice: StreamSlice | None = None, ) -> Iterable[StreamData]: """ Fetch a stream's records from an HTTP API source @@ -419,13 +423,16 @@ def read_records( :param stream_slice: The stream slice to read data for :return: The records read from the API source """ - _slice = stream_slice or StreamSlice(partition={}, cursor_slice={}) # None-check + slice_ = stream_slice or StreamSlice( + partition={}, + cursor_slice={}, + ) most_recent_record_from_slice = None record_generator = partial( self._parse_records, stream_state=self.state or {}, - stream_slice=_slice, + stream_slice=slice_, records_schema=records_schema, ) @@ -438,46 +445,42 @@ def read_records( if stream_state.get(FULL_REFRESH_SYNC_COMPLETE_KEY): return - yield from self._read_single_page(record_generator, stream_state, _slice) + yield from self._read_single_page(record_generator, stream_state, slice_) else: - for stream_data in self._read_pages(record_generator, self.state, _slice): - current_record = self._extract_record(stream_data, _slice) + for stream_data in self._read_pages(record_generator, self.state, slice_): + current_record = self._extract_record(stream_data, slice_) if self.cursor and current_record: - self.cursor.observe(_slice, current_record) + self.cursor.observe(slice_, current_record) # Latest record read, not necessarily within slice boundaries. - # TODO Remove once all custom components implement `observe` method. + # TODO Remove once all custom components implement `observe` method. 
# noqa: TD004 # https://github.com/airbytehq/airbyte-internal-issues/issues/6955 most_recent_record_from_slice = self._get_most_recent_record( - most_recent_record_from_slice, current_record, _slice + most_recent_record_from_slice, current_record, slice_ ) yield stream_data if self.cursor: - self.cursor.close_slice(_slice, most_recent_record_from_slice) + self.cursor.close_slice(slice_, most_recent_record_from_slice) return def _get_most_recent_record( self, - current_most_recent: Optional[Record], - current_record: Optional[Record], - stream_slice: StreamSlice, - ) -> Optional[Record]: + current_most_recent: Record | None, + current_record: Record | None, + stream_slice: StreamSlice, # noqa: ARG002 + ) -> Record | None: if self.cursor and current_record: if not current_most_recent: return current_record - else: - return ( - current_most_recent - if self.cursor.is_greater_than_or_equal(current_most_recent, current_record) - else current_record - ) - else: - return None + return ( + current_most_recent + if self.cursor.is_greater_than_or_equal(current_most_recent, current_record) + else current_record + ) + return None - def _extract_record( - self, stream_data: StreamData, stream_slice: StreamSlice - ) -> Optional[Record]: + def _extract_record(self, stream_data: StreamData, stream_slice: StreamSlice) -> Record | None: """ As we allow the output of _read_pages to be StreamData, it can be multiple things. Therefore, we need to filter out and normalize to data to streamline the rest of the process. @@ -485,11 +488,11 @@ def _extract_record( if isinstance(stream_data, Record): # Record is not part of `StreamData` but is the most common implementation of `Mapping[str, Any]` which is part of `StreamData` return stream_data - elif isinstance(stream_data, (dict, Mapping)): + if isinstance(stream_data, (dict, Mapping)): # noqa: UP038 return Record( data=dict(stream_data), associated_slice=stream_slice, stream_name=self.name ) - elif isinstance(stream_data, AirbyteMessage) and stream_data.record: + if isinstance(stream_data, AirbyteMessage) and stream_data.record: return Record( data=stream_data.record.data, # type:ignore # AirbyteMessage always has record.data associated_slice=stream_slice, @@ -498,7 +501,7 @@ def _extract_record( return None # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever - def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore + def stream_slices(self) -> Iterable[StreamSlice | None]: # type: ignore """ Specifies the slices for this stream. See the stream slicing section of the docs for more information. @@ -521,10 +524,10 @@ def state(self, value: StreamState) -> None: def _parse_records( self, - response: Optional[requests.Response], + response: requests.Response | None, stream_state: Mapping[str, Any], records_schema: Mapping[str, Any], - stream_slice: Optional[StreamSlice], + stream_slice: StreamSlice | None, ) -> Iterable[Record]: yield from self._parse_response( response, @@ -537,7 +540,7 @@ def must_deduplicate_query_params(self) -> bool: return True @staticmethod - def _to_partition_key(to_serialize: Any) -> str: + def _to_partition_key(to_serialize: Any) -> str: # noqa: ANN401 # separators have changed in Python 3.4. 
To avoid being impacted by further change, we explicitly specify our own value return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True) @@ -559,15 +562,15 @@ def __post_init__(self, options: Mapping[str, Any]) -> None: ) # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever - def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore + def stream_slices(self) -> Iterable[StreamSlice | None]: # type: ignore return islice(super().stream_slices(), self.maximum_number_of_slices) def _fetch_next_page( self, stream_state: Mapping[str, Any], stream_slice: StreamSlice, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[requests.Response]: + next_page_token: Mapping[str, Any] | None = None, + ) -> requests.Response | None: return self.requester.send_request( path=self._paginator_path(next_page_token=next_page_token), stream_state=stream_state, diff --git a/airbyte_cdk/sources/declarative/schema/__init__.py b/airbyte_cdk/sources/declarative/schema/__init__.py index b5b6a7d31..e41407845 100644 --- a/airbyte_cdk/sources/declarative/schema/__init__.py +++ b/airbyte_cdk/sources/declarative/schema/__init__.py @@ -12,6 +12,7 @@ from airbyte_cdk.sources.declarative.schema.json_file_schema_loader import JsonFileSchemaLoader from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader + __all__ = [ "JsonFileSchemaLoader", "DefaultSchemaLoader", diff --git a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py index a9b625e7d..7f3ba5611 100644 --- a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py @@ -3,8 +3,9 @@ # import logging +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping +from typing import Any from airbyte_cdk.sources.declarative.schema.json_file_schema_loader import JsonFileSchemaLoader from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader @@ -41,7 +42,7 @@ def get_json_schema(self) -> Mapping[str, Any]: # A slight hack since we don't directly have the stream name. 
However, when building the default filepath we assume the # runtime options stores stream name 'name' so we'll do the same here stream_name = self._parameters.get("name", "") - logging.info( + logging.info( # noqa: LOG015 f"Could not find schema for stream {stream_name}, defaulting to the empty schema" ) return {} diff --git a/airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py b/airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py index d65890b70..68cb2efaf 100644 --- a/airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py @@ -3,9 +3,10 @@ # +from collections.abc import Mapping, MutableMapping from copy import deepcopy from dataclasses import InitVar, dataclass, field -from typing import Any, List, Mapping, MutableMapping, Optional, Union +from typing import Any import dpath from typing_extensions import deprecated @@ -18,6 +19,7 @@ from airbyte_cdk.sources.source import ExperimentalClassWarning from airbyte_cdk.sources.types import Config, StreamSlice, StreamState + AIRBYTE_DATA_TYPES: Mapping[str, Mapping[str, Any]] = { "string": {"type": ["null", "string"]}, "boolean": {"type": ["null", "boolean"]}, @@ -52,9 +54,9 @@ class TypesMap: Represents a mapping between a current type and its corresponding target type. """ - target_type: Union[List[str], str] - current_type: Union[List[str], str] - condition: Optional[str] + target_type: list[str] | str + current_type: list[str] | str + condition: str | None @deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning) @@ -64,11 +66,11 @@ class SchemaTypeIdentifier: Identifies schema details for dynamic schema extraction and processing. """ - key_pointer: List[Union[InterpolatedString, str]] + key_pointer: list[InterpolatedString | str] parameters: InitVar[Mapping[str, Any]] - type_pointer: Optional[List[Union[InterpolatedString, str]]] = None - types_mapping: Optional[List[TypesMap]] = None - schema_pointer: Optional[List[Union[InterpolatedString, str]]] = None + type_pointer: list[InterpolatedString | str] | None = None + types_mapping: list[TypesMap] | None = None + schema_pointer: list[InterpolatedString | str] | None = None def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.schema_pointer = ( @@ -81,8 +83,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: @staticmethod def _update_pointer( - pointer: Optional[List[Union[InterpolatedString, str]]], parameters: Mapping[str, Any] - ) -> Optional[List[Union[InterpolatedString, str]]]: + pointer: list[InterpolatedString | str] | None, parameters: Mapping[str, Any] + ) -> list[InterpolatedString | str] | None: return ( [ InterpolatedString.create(path, parameters=parameters) @@ -106,7 +108,7 @@ class DynamicSchemaLoader(SchemaLoader): config: Config parameters: InitVar[Mapping[str, Any]] schema_type_identifier: SchemaTypeIdentifier - schema_transformations: List[RecordTransformation] = field(default_factory=lambda: []) + schema_transformations: list[RecordTransformation] = field(default_factory=list) def get_json_schema(self) -> Mapping[str, Any]: """ @@ -143,8 +145,8 @@ def get_json_schema(self) -> Mapping[str, Any]: def _transform( self, properties: Mapping[str, Any], - stream_state: StreamState, - stream_slice: Optional[StreamSlice] = None, + stream_state: StreamState, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: for transformation in self.schema_transformations: 
transformation.transform( @@ -156,21 +158,21 @@ def _transform( def _get_key( self, raw_schema: MutableMapping[str, Any], - field_key_path: List[Union[InterpolatedString, str]], + field_key_path: list[InterpolatedString | str], ) -> str: """ Extracts the key field from the schema using the specified path. """ field_key = self._extract_data(raw_schema, field_key_path) if not isinstance(field_key, str): - raise ValueError(f"Expected key to be a string. Got {field_key}") + raise ValueError(f"Expected key to be a string. Got {field_key}") # noqa: TRY004 return field_key def _get_type( self, raw_schema: MutableMapping[str, Any], - field_type_path: Optional[List[Union[InterpolatedString, str]]], - ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]: + field_type_path: list[InterpolatedString | str] | None, + ) -> Mapping[str, Any] | list[Mapping[str, Any]]: """ Determines the JSON Schema type for a field, supporting nullable and combined types. """ @@ -182,24 +184,23 @@ def _get_type( mapped_field_type = self._replace_type_if_not_valid(raw_field_type, raw_schema) if ( isinstance(mapped_field_type, list) - and len(mapped_field_type) == 2 + and len(mapped_field_type) == 2 # noqa: PLR2004 and all(isinstance(item, str) for item in mapped_field_type) ): first_type = self._get_airbyte_type(mapped_field_type[0]) second_type = self._get_airbyte_type(mapped_field_type[1]) return {"oneOf": [first_type, second_type]} - elif isinstance(mapped_field_type, str): + if isinstance(mapped_field_type, str): return self._get_airbyte_type(mapped_field_type) - else: - raise ValueError( - f"Invalid data type. Available string or two items list of string. Got {mapped_field_type}." - ) + raise ValueError( + f"Invalid data type. Available string or two items list of string. Got {mapped_field_type}." + ) def _replace_type_if_not_valid( self, - field_type: Union[List[str], str], + field_type: list[str] | str, raw_schema: MutableMapping[str, Any], - ) -> Union[List[str], str]: + ) -> list[str] | str: """ Replaces a field type if it matches a type mapping in `types_map`. """ @@ -228,9 +229,9 @@ def _get_airbyte_type(field_type: str) -> Mapping[str, Any]: def _extract_data( self, body: Mapping[str, Any], - extraction_path: Optional[List[Union[InterpolatedString, str]]] = None, - default: Any = None, - ) -> Any: + extraction_path: list[InterpolatedString | str] | None = None, + default: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 """ Extracts data from the body based on the provided extraction path. """ diff --git a/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py b/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py index 72a46b7e5..6675badde 100644 --- a/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
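An aside on the type resolution above: the two branches (a plain type name versus a two-item list of type names turned into a oneOf) can be illustrated with a small standalone sketch. The AIRBYTE_TYPES table below is a trimmed-down, hypothetical stand-in for AIRBYTE_DATA_TYPES, and resolve_field_type is not the CDK API, only an illustration of the same idea.

from typing import Any

# Hypothetical, trimmed-down stand-in for the AIRBYTE_DATA_TYPES mapping.
AIRBYTE_TYPES: dict[str, dict[str, Any]] = {
    "string": {"type": ["null", "string"]},
    "integer": {"type": ["null", "integer"]},
    "boolean": {"type": ["null", "boolean"]},
}


def resolve_field_type(raw_type: list[str] | str) -> dict[str, Any]:
    # A two-item list of type names becomes a oneOf of both mapped types.
    if (
        isinstance(raw_type, list)
        and len(raw_type) == 2
        and all(isinstance(item, str) for item in raw_type)
    ):
        return {"oneOf": [AIRBYTE_TYPES[raw_type[0]], AIRBYTE_TYPES[raw_type[1]]]}
    # A single type name maps directly to its JSON Schema definition.
    if isinstance(raw_type, str):
        return AIRBYTE_TYPES[raw_type]
    raise ValueError(f"Expected a string or a two-item list of strings, got {raw_type!r}")


assert resolve_field_type("integer") == {"type": ["null", "integer"]}
assert resolve_field_type(["string", "integer"]) == {
    "oneOf": [{"type": ["null", "string"]}, {"type": ["null", "integer"]}]
}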
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, Mapping +from typing import Any from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader @@ -12,7 +13,7 @@ class InlineSchemaLoader(SchemaLoader): """Describes a stream's schema""" - schema: Dict[str, Any] + schema: dict[str, Any] parameters: InitVar[Mapping[str, Any]] def get_json_schema(self) -> Mapping[str, Any]: diff --git a/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py b/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py index af51fe5db..486223800 100644 --- a/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py @@ -5,8 +5,9 @@ import json import pkgutil import sys +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Tuple, Union +from typing import Any from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader @@ -43,7 +44,7 @@ class JsonFileSchemaLoader(ResourceSchemaLoader, SchemaLoader): config: Config parameters: InitVar[Mapping[str, Any]] - file_path: Union[InterpolatedString, str] = field(default="") + file_path: InterpolatedString | str = field(default="") def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not self.file_path: @@ -51,14 +52,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.file_path = InterpolatedString.create(self.file_path, parameters=parameters) def get_json_schema(self) -> Mapping[str, Any]: - # todo: It is worth revisiting if we can replace file_path with just file_name if every schema is in the /schemas directory + # TODO: It is worth revisiting if we can replace file_path with just file_name if every schema is in the /schemas directory # this would require that we find a creative solution to store or retrieve source_name in here since the files are mounted there json_schema_path = self._get_json_filepath() resource, schema_path = self.extract_resource_and_schema_path(json_schema_path) raw_json_file = pkgutil.get_data(resource, schema_path) if not raw_json_file: - raise IOError(f"Cannot find file {json_schema_path}") + raise OSError(f"Cannot find file {json_schema_path}") try: raw_schema = json.loads(raw_json_file) except ValueError as err: @@ -66,11 +67,11 @@ def get_json_schema(self) -> Mapping[str, Any]: self.package_name = resource return self._resolve_schema_references(raw_schema) - def _get_json_filepath(self) -> Any: + def _get_json_filepath(self) -> Any: # noqa: ANN401 return self.file_path.eval(self.config) # type: ignore # file_path is always cast to an interpolated string @staticmethod - def extract_resource_and_schema_path(json_schema_path: str) -> Tuple[str, str]: + def extract_resource_and_schema_path(json_schema_path: str) -> tuple[str, str]: """ When the connector is running on a docker container, package_data is accessible from the resource (source_), so we extract the resource from the first part of the schema path and the remaining path is used to find the schema file. 
This is a slight @@ -80,7 +81,7 @@ def extract_resource_and_schema_path(json_schema_path: str) -> Tuple[str, str]: """ split_path = json_schema_path.split("/") - if split_path[0] == "" or split_path[0] == ".": + if split_path[0] == "" or split_path[0] == ".": # noqa: PLC1901 split_path = split_path[1:] if len(split_path) == 0: diff --git a/airbyte_cdk/sources/declarative/schema/schema_loader.py b/airbyte_cdk/sources/declarative/schema/schema_loader.py index a6beb70ae..fb7f45cb6 100644 --- a/airbyte_cdk/sources/declarative/schema/schema_loader.py +++ b/airbyte_cdk/sources/declarative/schema/schema_loader.py @@ -3,8 +3,9 @@ # from abc import abstractmethod +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Mapping +from typing import Any @dataclass diff --git a/airbyte_cdk/sources/declarative/spec/__init__.py b/airbyte_cdk/sources/declarative/spec/__init__.py index 1c13ed67c..4302b8749 100644 --- a/airbyte_cdk/sources/declarative/spec/__init__.py +++ b/airbyte_cdk/sources/declarative/spec/__init__.py @@ -4,4 +4,5 @@ from airbyte_cdk.sources.declarative.spec.spec import Spec + __all__ = ["Spec"] diff --git a/airbyte_cdk/sources/declarative/spec/spec.py b/airbyte_cdk/sources/declarative/spec/spec.py index 914e99e93..7892e3c45 100644 --- a/airbyte_cdk/sources/declarative/spec/spec.py +++ b/airbyte_cdk/sources/declarative/spec/spec.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Mapping, Optional +from typing import Any from airbyte_cdk.models import ( AdvancedAuth, @@ -25,8 +26,8 @@ class Spec: connection_specification: Mapping[str, Any] parameters: InitVar[Mapping[str, Any]] - documentation_url: Optional[str] = None - advanced_auth: Optional[AuthFlow] = None + documentation_url: str | None = None + advanced_auth: AuthFlow | None = None def generate_spec(self) -> ConnectorSpecification: """ diff --git a/airbyte_cdk/sources/declarative/stream_slicers/__init__.py b/airbyte_cdk/sources/declarative/stream_slicers/__init__.py index 7bacc3ca8..d2df979f5 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/__init__.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/__init__.py @@ -4,4 +4,5 @@ from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer + __all__ = ["StreamSlicer"] diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py index 91ce28e7a..d6e5dd1b1 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py @@ -1,6 +1,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
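To make the resource/schema-path split described above concrete, here is a simplified, standalone sketch of the behaviour documented for extract_resource_and_schema_path; the package name and path are made up, and the pkgutil step is left as a comment because it only works for a real installed package that ships the schema as package data.

def split_resource_and_schema_path(json_schema_path: str) -> tuple[str, str]:
    # Drop a leading "" or "." segment, then treat the first segment as the
    # package (resource) and the remainder as the path to the schema file.
    parts = json_schema_path.split("/")
    if parts and parts[0] in ("", "."):
        parts = parts[1:]
    if len(parts) < 2:
        raise ValueError(f"Cannot split resource and schema path from {json_schema_path!r}")
    return parts[0], "/".join(parts[1:])


assert split_resource_and_schema_path("./source_example/schemas/users.json") == (
    "source_example",
    "schemas/users.json",
)

# Loading then mirrors get_json_schema above (only valid for a real package):
#   import json, pkgutil
#   raw = pkgutil.get_data("source_example", "schemas/users.json")
#   schema = json.loads(raw) if raw else {}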
-from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.declarative.retrievers import Retriever from airbyte_cdk.sources.message import MessageRepository @@ -40,7 +41,7 @@ def create(self, stream_slice: StreamSlice) -> Partition: class DeclarativePartition(Partition): - def __init__( + def __init__( # noqa: ANN204 self, stream_name: str, json_schema: Mapping[str, Any], @@ -66,7 +67,7 @@ def read(self) -> Iterable[Record]: else: self._message_repository.emit_message(stream_data) - def to_slice(self) -> Optional[Mapping[str, Any]]: + def to_slice(self) -> Mapping[str, Any] | None: return self._stream_slice def stream_name(self) -> str: diff --git a/airbyte_cdk/sources/declarative/transformations/__init__.py b/airbyte_cdk/sources/declarative/transformations/__init__.py index e18712a01..9bd1e758e 100644 --- a/airbyte_cdk/sources/declarative/transformations/__init__.py +++ b/airbyte_cdk/sources/declarative/transformations/__init__.py @@ -10,8 +10,10 @@ # so we add the split directive below to tell isort to sort imports while keeping RecordTransformation as the first import from .transformation import RecordTransformation + # isort: split from .add_fields import AddFields from .remove_fields import RemoveFields + __all__ = ["AddFields", "RecordTransformation", "RemoveFields"] diff --git a/airbyte_cdk/sources/declarative/transformations/add_fields.py b/airbyte_cdk/sources/declarative/transformations/add_fields.py index 4c9d5366c..b4c13c7de 100644 --- a/airbyte_cdk/sources/declarative/transformations/add_fields.py +++ b/airbyte_cdk/sources/declarative/transformations/add_fields.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass, field -from typing import Any, Dict, List, Mapping, Optional, Type, Union +from typing import Any import dpath @@ -17,8 +18,8 @@ class AddedFieldDefinition: """Defines the field to add on a record""" path: FieldPointer - value: Union[InterpolatedString, str] - value_type: Optional[Type[Any]] + value: InterpolatedString | str + value_type: type[Any] | None parameters: InitVar[Mapping[str, Any]] @@ -28,12 +29,12 @@ class ParsedAddFieldDefinition: path: FieldPointer value: InterpolatedString - value_type: Optional[Type[Any]] + value_type: type[Any] | None parameters: InitVar[Mapping[str, Any]] @dataclass -class AddFields(RecordTransformation): +class AddFields(RecordTransformation): # noqa: PLW1641 """ Transformation which adds field to an output record. The path of the added field can be nested. Adding nested fields will create all necessary parent objects (like mkdir -p). 
Adding fields to an array will extend the array to that index (filling intermediate @@ -84,9 +85,9 @@ class AddFields(RecordTransformation): fields (List[AddedFieldDefinition]): A list of transformations (path and corresponding value) that will be added to the record """ - fields: List[AddedFieldDefinition] + fields: list[AddedFieldDefinition] parameters: InitVar[Mapping[str, Any]] - _parsed_fields: List[ParsedAddFieldDefinition] = field( + _parsed_fields: list[ParsedAddFieldDefinition] = field( init=False, repr=False, default_factory=list ) @@ -100,15 +101,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not isinstance(add_field.value, InterpolatedString): if not isinstance(add_field.value, str): raise f"Expected a string value for the AddFields transformation: {add_field}" - else: - self._parsed_fields.append( - ParsedAddFieldDefinition( - add_field.path, - InterpolatedString.create(add_field.value, parameters=parameters), - value_type=add_field.value_type, - parameters=parameters, - ) + self._parsed_fields.append( + ParsedAddFieldDefinition( + add_field.path, + InterpolatedString.create(add_field.value, parameters=parameters), + value_type=add_field.value_type, + parameters=parameters, ) + ) else: self._parsed_fields.append( ParsedAddFieldDefinition( @@ -121,10 +121,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, ) -> None: if config is None: config = {} @@ -134,5 +134,5 @@ def transform( value = parsed_field.value.eval(config, valid_types=valid_types, **kwargs) dpath.new(record, parsed_field.path, value) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return bool(self.__dict__ == other.__dict__) diff --git a/airbyte_cdk/sources/declarative/transformations/dpath_flatten_fields.py b/airbyte_cdk/sources/declarative/transformations/dpath_flatten_fields.py index 73162d848..4fd5efc04 100644 --- a/airbyte_cdk/sources/declarative/transformations/dpath_flatten_fields.py +++ b/airbyte_cdk/sources/declarative/transformations/dpath_flatten_fields.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any import dpath @@ -19,7 +20,7 @@ class DpathFlattenFields(RecordTransformation): """ config: Config - field_path: List[Union[InterpolatedString, str]] + field_path: list[InterpolatedString | str] parameters: InitVar[Mapping[str, Any]] delete_origin_value: bool = False @@ -35,10 +36,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, # noqa: ARG002 + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 ) -> None: path = [path.eval(self.config) for path in self._field_path] if "*" in path: diff --git a/airbyte_cdk/sources/declarative/transformations/flatten_fields.py b/airbyte_cdk/sources/declarative/transformations/flatten_fields.py index 24bfba660..4976918c1 100644 --- 
a/airbyte_cdk/sources/declarative/transformations/flatten_fields.py +++ b/airbyte_cdk/sources/declarative/transformations/flatten_fields.py @@ -3,7 +3,7 @@ # from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any from airbyte_cdk.sources.declarative.transformations import RecordTransformation from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -15,18 +15,18 @@ class FlattenFields(RecordTransformation): def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, # noqa: ARG002 + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 ) -> None: transformed_record = self.flatten_record(record) record.clear() record.update(transformed_record) - def flatten_record(self, record: Dict[str, Any]) -> Dict[str, Any]: + def flatten_record(self, record: dict[str, Any]) -> dict[str, Any]: stack = [(record, "_")] - transformed_record: Dict[str, Any] = {} + transformed_record: dict[str, Any] = {} force_with_parent_name = False while stack: diff --git a/airbyte_cdk/sources/declarative/transformations/keys_replace_transformation.py b/airbyte_cdk/sources/declarative/transformations/keys_replace_transformation.py index 8fe0bbffb..00deb5130 100644 --- a/airbyte_cdk/sources/declarative/transformations/keys_replace_transformation.py +++ b/airbyte_cdk/sources/declarative/transformations/keys_replace_transformation.py @@ -2,8 +2,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, Mapping, Optional +from typing import Any from airbyte_cdk import InterpolatedString from airbyte_cdk.sources.declarative.transformations import RecordTransformation @@ -34,10 +35,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, ) -> None: if config is None: config = {} @@ -46,7 +47,7 @@ def transform( old_key = str(self._old.eval(config, **kwargs)) new_key = str(self._new.eval(config, **kwargs)) - def _transform(data: Dict[str, Any]) -> Dict[str, Any]: + def _transform(data: dict[str, Any]) -> dict[str, Any]: result = {} for key, value in data.items(): updated_key = key.replace(old_key, new_key) diff --git a/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py b/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py index 53db3d49a..cabc6761d 100644 --- a/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py +++ b/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py @@ -3,7 +3,7 @@ # from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any from airbyte_cdk.sources.declarative.transformations import RecordTransformation from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -13,10 +13,10 @@ class KeysToLowerTransformation(RecordTransformation): def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - 
stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, # noqa: ARG002 + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 ) -> None: for key in set(record.keys()): record[key.lower()] = record.pop(key) diff --git a/airbyte_cdk/sources/declarative/transformations/keys_to_snake_transformation.py b/airbyte_cdk/sources/declarative/transformations/keys_to_snake_transformation.py index 86e25c399..c6e645706 100644 --- a/airbyte_cdk/sources/declarative/transformations/keys_to_snake_transformation.py +++ b/airbyte_cdk/sources/declarative/transformations/keys_to_snake_transformation.py @@ -4,7 +4,7 @@ import re from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any import unidecode @@ -20,16 +20,16 @@ class KeysToSnakeCaseTransformation(RecordTransformation): def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, # noqa: ARG002 + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 ) -> None: transformed_record = self._transform_record(record) record.clear() record.update(transformed_record) - def _transform_record(self, record: Dict[str, Any]) -> Dict[str, Any]: + def _transform_record(self, record: dict[str, Any]) -> dict[str, Any]: transformed_record = {} for key, value in record.items(): transformed_key = self.process_key(key) @@ -50,19 +50,19 @@ def process_key(self, key: str) -> str: def normalize_key(self, key: str) -> str: return unidecode.unidecode(key) - def tokenize_key(self, key: str) -> List[str]: + def tokenize_key(self, key: str) -> list[str]: tokens = [] for match in self.token_pattern.finditer(key): token = match.group(0) if match.group("NoToken") is None else "" tokens.append(token) return tokens - def filter_tokens(self, tokens: List[str]) -> List[str]: - if len(tokens) >= 3: + def filter_tokens(self, tokens: list[str]) -> list[str]: + if len(tokens) >= 3: # noqa: PLR2004 tokens = tokens[:1] + [t for t in tokens[1:-1] if t] + tokens[-1:] if tokens and tokens[0].isdigit(): tokens.insert(0, "") return tokens - def tokens_to_snake_case(self, tokens: List[str]) -> str: + def tokens_to_snake_case(self, tokens: list[str]) -> str: return "_".join(token.lower() for token in tokens) diff --git a/airbyte_cdk/sources/declarative/transformations/remove_fields.py b/airbyte_cdk/sources/declarative/transformations/remove_fields.py index f5d8164df..1aee83bb8 100644 --- a/airbyte_cdk/sources/declarative/transformations/remove_fields.py +++ b/airbyte_cdk/sources/declarative/transformations/remove_fields.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
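All of the transformations in this part of the diff follow the same in-place contract: transform() mutates the record dict it receives and returns None, so the caller keeps working with the same object. A minimal sketch of that contract, written as a plain function in the spirit of KeysToLowerTransformation rather than as the CDK class itself:

from typing import Any


def keys_to_lower(record: dict[str, Any]) -> None:
    # Mutate the record in place; nothing is returned, matching the
    # RecordTransformation.transform contract shown in this diff.
    for key in list(record.keys()):
        record[key.lower()] = record.pop(key)


record = {"UserId": 1, "Email": "a@example.com"}
keys_to_lower(record)
assert record == {"userid": 1, "email": "a@example.com"}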
# +from collections.abc import Mapping from dataclasses import InitVar, dataclass -from typing import Any, Dict, List, Mapping, Optional +from typing import Any import dpath import dpath.exceptions @@ -40,7 +41,7 @@ class RemoveFields(RecordTransformation): field_pointers (List[FieldPointer]): pointers to the fields that should be removed """ - field_pointers: List[FieldPointer] + field_pointers: list[FieldPointer] parameters: InitVar[Mapping[str, Any]] condition: str = "" @@ -51,10 +52,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, + stream_state: StreamState | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, # noqa: ARG002 ) -> None: """ :param record: The record to be transformed @@ -62,7 +63,7 @@ def transform( """ for pointer in self.field_pointers: # the dpath library by default doesn't delete fields from arrays - try: + try: # noqa: SIM105 dpath.delete( record, pointer, diff --git a/airbyte_cdk/sources/declarative/transformations/transformation.py b/airbyte_cdk/sources/declarative/transformations/transformation.py index f5b226429..15b67b0b9 100644 --- a/airbyte_cdk/sources/declarative/transformations/transformation.py +++ b/airbyte_cdk/sources/declarative/transformations/transformation.py @@ -4,13 +4,13 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @dataclass -class RecordTransformation: +class RecordTransformation: # noqa: PLW1641 """ Implementations of this class define transformations that can be applied to records of a stream. """ @@ -18,10 +18,10 @@ class RecordTransformation: @abstractmethod def transform( self, - record: Dict[str, Any], - config: Optional[Config] = None, - stream_state: Optional[StreamState] = None, - stream_slice: Optional[StreamSlice] = None, + record: dict[str, Any], + config: Config | None = None, + stream_state: StreamState | None = None, + stream_slice: StreamSlice | None = None, ) -> None: """ Transform a record by adding, deleting, or mutating fields directly from the record reference passed in argument. diff --git a/airbyte_cdk/sources/declarative/types.py b/airbyte_cdk/sources/declarative/types.py index a4d0aeb1d..80f297e7f 100644 --- a/airbyte_cdk/sources/declarative/types.py +++ b/airbyte_cdk/sources/declarative/types.py @@ -1,4 +1,4 @@ -# +# # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # @@ -13,13 +13,14 @@ StreamState, ) + # Note: This package originally contained class definitions for low-code CDK types, but we promoted them into the Python CDK. # We've migrated connectors in the repository to reference the new location, but these assignments are used to retain backwards # compatibility for sources created by OSS customers or on forks. This can be removed when we start bumping major versions. 
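For RemoveFields above, the deletion itself is delegated to dpath.delete, one call per configured field pointer, with unresolvable pointers simply skipped. A rough usage sketch, assuming dpath is installed; the record and pointers are invented for illustration:

import dpath
import dpath.exceptions

record = {"user": {"email": "a@example.com", "password": "secret"}, "id": 1}
field_pointers = [["user", "password"], ["does", "not", "exist"]]

for pointer in field_pointers:
    # dpath raises PathNotFound when a pointer does not resolve, so missing
    # fields are ignored rather than failing the transformation.
    try:
        dpath.delete(record, pointer)
    except dpath.exceptions.PathNotFound:
        pass

assert record == {"user": {"email": "a@example.com"}, "id": 1}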
-FieldPointer = FieldPointer -Config = Config -ConnectionDefinition = ConnectionDefinition -StreamState = StreamState -Record = Record -StreamSlice = StreamSlice +FieldPointer = FieldPointer # noqa: PLW0127 +Config = Config # noqa: PLW0127 +ConnectionDefinition = ConnectionDefinition # noqa: PLW0127 +StreamState = StreamState # noqa: PLW0127 +Record = Record # noqa: PLW0127 +StreamSlice = StreamSlice # noqa: PLW0127 diff --git a/airbyte_cdk/sources/declarative/yaml_declarative_source.py b/airbyte_cdk/sources/declarative/yaml_declarative_source.py index 04ccda4cf..b94b387a3 100644 --- a/airbyte_cdk/sources/declarative/yaml_declarative_source.py +++ b/airbyte_cdk/sources/declarative/yaml_declarative_source.py @@ -3,7 +3,8 @@ # import pkgutil -from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any import yaml @@ -14,16 +15,16 @@ from airbyte_cdk.sources.types import ConnectionDefinition -class YamlDeclarativeSource(ConcurrentDeclarativeSource[List[AirbyteStateMessage]]): +class YamlDeclarativeSource(ConcurrentDeclarativeSource[list[AirbyteStateMessage]]): """Declarative source defined by a yaml file""" def __init__( self, path_to_yaml: str, - debug: bool = False, - catalog: Optional[ConfiguredAirbyteCatalog] = None, - config: Optional[Mapping[str, Any]] = None, - state: Optional[List[AirbyteStateMessage]] = None, + debug: bool = False, # noqa: FBT001, FBT002, ARG002 + catalog: ConfiguredAirbyteCatalog | None = None, + config: Mapping[str, Any] | None = None, + state: list[AirbyteStateMessage] | None = None, ) -> None: """ :param path_to_yaml: Path to the yaml file describing the source @@ -45,8 +46,7 @@ def _read_and_parse_yaml_file(self, path_to_yaml_file: str) -> ConnectionDefinit if yaml_config: decoded_yaml = yaml_config.decode() return self._parse(decoded_yaml) - else: - return {} + return {} def _emit_manifest_debug_message(self, extra_args: dict[str, Any]) -> None: extra_args["path_to_yaml"] = self._path_to_yaml diff --git a/airbyte_cdk/sources/embedded/base_integration.py b/airbyte_cdk/sources/embedded/base_integration.py index 77917b0a1..f16a03ab2 100644 --- a/airbyte_cdk/sources/embedded/base_integration.py +++ b/airbyte_cdk/sources/embedded/base_integration.py @@ -3,7 +3,8 @@ # from abc import ABC, abstractmethod -from typing import Generic, Iterable, Optional, TypeVar +from collections.abc import Iterable +from typing import Generic, TypeVar from airbyte_cdk.connector import TConfig from airbyte_cdk.models import AirbyteRecordMessage, AirbyteStateMessage, SyncMode, Type @@ -16,27 +17,28 @@ from airbyte_cdk.sources.embedded.tools import get_defined_id from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit + TOutput = TypeVar("TOutput") class BaseEmbeddedIntegration(ABC, Generic[TConfig, TOutput]): - def __init__(self, runner: SourceRunner[TConfig], config: TConfig): + def __init__(self, runner: SourceRunner[TConfig], config: TConfig): # noqa: ANN204 check_config_against_spec_or_exit(config, runner.spec()) self.source = runner self.config = config - self.last_state: Optional[AirbyteStateMessage] = None + self.last_state: AirbyteStateMessage | None = None @abstractmethod - def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Optional[TOutput]: + def _handle_record(self, record: AirbyteRecordMessage, id: str | None) -> TOutput | None: # noqa: A002 """ Turn an Airbyte record into the appropriate output type for the integration. 
""" pass def _load_data( - self, stream_name: str, state: Optional[AirbyteStateMessage] = None + self, stream_name: str, state: AirbyteStateMessage | None = None ) -> Iterable[TOutput]: catalog = self.source.discover(self.config) stream = get_stream(catalog, stream_name) diff --git a/airbyte_cdk/sources/embedded/catalog.py b/airbyte_cdk/sources/embedded/catalog.py index 62c7a623d..935bfca97 100644 --- a/airbyte_cdk/sources/embedded/catalog.py +++ b/airbyte_cdk/sources/embedded/catalog.py @@ -2,7 +2,6 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import List, Optional from airbyte_cdk.models import ( AirbyteCatalog, @@ -15,11 +14,11 @@ from airbyte_cdk.sources.embedded.tools import get_first -def get_stream(catalog: AirbyteCatalog, stream_name: str) -> Optional[AirbyteStream]: +def get_stream(catalog: AirbyteCatalog, stream_name: str) -> AirbyteStream | None: return get_first(catalog.streams, lambda s: s.name == stream_name) -def get_stream_names(catalog: AirbyteCatalog) -> List[str]: +def get_stream_names(catalog: AirbyteCatalog) -> list[str]: return [stream.name for stream in catalog.streams] @@ -27,8 +26,8 @@ def to_configured_stream( stream: AirbyteStream, sync_mode: SyncMode = SyncMode.full_refresh, destination_sync_mode: DestinationSyncMode = DestinationSyncMode.append, - cursor_field: Optional[List[str]] = None, - primary_key: Optional[List[List[str]]] = None, + cursor_field: list[str] | None = None, + primary_key: list[list[str]] | None = None, ) -> ConfiguredAirbyteStream: return ConfiguredAirbyteStream( stream=stream, @@ -40,7 +39,7 @@ def to_configured_stream( def to_configured_catalog( - configured_streams: List[ConfiguredAirbyteStream], + configured_streams: list[ConfiguredAirbyteStream], ) -> ConfiguredAirbyteCatalog: return ConfiguredAirbyteCatalog(streams=configured_streams) diff --git a/airbyte_cdk/sources/embedded/runner.py b/airbyte_cdk/sources/embedded/runner.py index 43217f156..da45b6e21 100644 --- a/airbyte_cdk/sources/embedded/runner.py +++ b/airbyte_cdk/sources/embedded/runner.py @@ -5,7 +5,8 @@ import logging from abc import ABC, abstractmethod -from typing import Generic, Iterable, Optional +from collections.abc import Iterable +from typing import Generic from airbyte_cdk.connector import TConfig from airbyte_cdk.models import ( @@ -32,13 +33,13 @@ def read( self, config: TConfig, catalog: ConfiguredAirbyteCatalog, - state: Optional[AirbyteStateMessage], + state: AirbyteStateMessage | None, ) -> Iterable[AirbyteMessage]: pass class CDKRunner(SourceRunner[TConfig]): - def __init__(self, source: Source, name: str): + def __init__(self, source: Source, name: str): # noqa: ANN204 self._source = source self._logger = logging.getLogger(name) @@ -52,6 +53,6 @@ def read( self, config: TConfig, catalog: ConfiguredAirbyteCatalog, - state: Optional[AirbyteStateMessage], + state: AirbyteStateMessage | None, ) -> Iterable[AirbyteMessage]: return self._source.read(self._logger, config, catalog, state=[state] if state else []) diff --git a/airbyte_cdk/sources/embedded/tools.py b/airbyte_cdk/sources/embedded/tools.py index 1ddb29b3a..dbfd1de07 100644 --- a/airbyte_cdk/sources/embedded/tools.py +++ b/airbyte_cdk/sources/embedded/tools.py @@ -2,7 +2,8 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# -from typing import Any, Callable, Dict, Iterable, Optional +from collections.abc import Callable, Iterable +from typing import Any import dpath @@ -10,12 +11,13 @@ def get_first( - iterable: Iterable[Any], predicate: Callable[[Any], bool] = lambda m: True -) -> Optional[Any]: + iterable: Iterable[Any], + predicate: Callable[[Any], bool] = lambda m: True, # noqa: ARG005 +) -> Any | None: # noqa: ANN401 return next(filter(predicate, iterable), None) -def get_defined_id(stream: AirbyteStream, data: Dict[str, Any]) -> Optional[str]: +def get_defined_id(stream: AirbyteStream, data: dict[str, Any]) -> str | None: if not stream.source_defined_primary_key: return None primary_key = [] diff --git a/airbyte_cdk/sources/file_based/__init__.py b/airbyte_cdk/sources/file_based/__init__.py index 6ea0ca31e..81bd9d8ec 100644 --- a/airbyte_cdk/sources/file_based/__init__.py +++ b/airbyte_cdk/sources/file_based/__init__.py @@ -3,11 +3,12 @@ from .config.file_based_stream_config import FileBasedStreamConfig from .config.jsonl_format import JsonlFormat from .exceptions import CustomFileBasedException, ErrorListingFiles, FileBasedSourceError -from .file_based_source import DEFAULT_CONCURRENCY, FileBasedSource +from .file_based_source import DEFAULT_CONCURRENCY, FileBasedSource # noqa: F401 from .file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode from .remote_file import RemoteFile from .stream.cursor import DefaultFileBasedCursor + __all__ = [ "AbstractFileBasedSpec", "AbstractFileBasedStreamReader", diff --git a/airbyte_cdk/sources/file_based/availability_strategy/__init__.py b/airbyte_cdk/sources/file_based/availability_strategy/__init__.py index 8134a89e0..2356a9a2d 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/__init__.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/__init__.py @@ -4,6 +4,7 @@ ) from .default_file_based_availability_strategy import DefaultFileBasedAvailabilityStrategy + __all__ = [ "AbstractFileBasedAvailabilityStrategy", "AbstractFileBasedAvailabilityStrategyWrapper", diff --git a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py index 12e1740b6..ee83a23d0 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py @@ -6,9 +6,9 @@ import logging from abc import abstractmethod -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING -from airbyte_cdk.sources import Source +from airbyte_cdk.sources import Source # noqa: TC001 from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( AbstractAvailabilityStrategy, @@ -16,7 +16,8 @@ StreamAvailable, StreamUnavailable, ) -from airbyte_cdk.sources.streams.core import Stream +from airbyte_cdk.sources.streams.core import Stream # noqa: TC001 + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream @@ -28,8 +29,8 @@ def check_availability( # type: ignore[override] # Signature doesn't match bas self, stream: Stream, logger: logging.Logger, - _: Optional[Source], - ) -> Tuple[bool, Optional[str]]: + _: Source | None, + ) -> tuple[bool, str | None]: """ Perform a connection check for the stream. 
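The embedded catalog helpers above compose naturally: look a stream up by name in a discovered catalog, then wrap it into a configured catalog for a read. A usage sketch along those lines; the stream definition is invented, and the snippet illustrates the helpers shown in this diff rather than a canonical example.

from airbyte_cdk.models import AirbyteCatalog, AirbyteStream, DestinationSyncMode, SyncMode
from airbyte_cdk.sources.embedded.catalog import (
    get_stream,
    get_stream_names,
    to_configured_catalog,
    to_configured_stream,
)

# A one-stream catalog, as a discover call might return it.
catalog = AirbyteCatalog(
    streams=[
        AirbyteStream(
            name="users",
            json_schema={"type": "object"},
            supported_sync_modes=[SyncMode.full_refresh],
        )
    ]
)

assert get_stream_names(catalog) == ["users"]

stream = get_stream(catalog, "users")
if stream is not None:
    configured_catalog = to_configured_catalog(
        [
            to_configured_stream(
                stream,
                sync_mode=SyncMode.full_refresh,
                destination_sync_mode=DestinationSyncMode.append,
            )
        ]
    )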
@@ -42,8 +43,8 @@ def check_availability_and_parsability( self, stream: AbstractFileBasedStream, logger: logging.Logger, - _: Optional[Source], - ) -> Tuple[bool, Optional[str]]: + _: Source | None, + ) -> tuple[bool, str | None]: """ Performs a connection check for the stream, as well as additional checks that verify that the connection is working as expected. @@ -65,9 +66,7 @@ def check_availability(self, logger: logging.Logger) -> StreamAvailability: return StreamAvailable() return StreamUnavailable(reason or "") - def check_availability_and_parsability( - self, logger: logging.Logger - ) -> Tuple[bool, Optional[str]]: + def check_availability_and_parsability(self, logger: logging.Logger) -> tuple[bool, str | None]: return self.stream.availability_strategy.check_availability_and_parsability( self.stream, logger, None ) diff --git a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py index c9d416a72..b6304a50a 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py @@ -2,14 +2,14 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from __future__ import annotations +from __future__ import annotations # noqa: I001 import logging import traceback -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING from airbyte_cdk import AirbyteTracedException -from airbyte_cdk.sources import Source +from airbyte_cdk.sources import Source # noqa: TC001 from airbyte_cdk.sources.file_based.availability_strategy import ( AbstractFileBasedAvailabilityStrategy, ) @@ -18,10 +18,11 @@ CustomFileBasedException, FileBasedSourceError, ) -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader -from airbyte_cdk.sources.file_based.remote_file import RemoteFile +from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader # noqa: TC001 +from airbyte_cdk.sources.file_based.remote_file import RemoteFile # noqa: TC001 from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream @@ -33,9 +34,9 @@ def __init__(self, stream_reader: AbstractFileBasedStreamReader) -> None: def check_availability( # type: ignore[override] # Signature doesn't match base class self, stream: AbstractFileBasedStream, - logger: logging.Logger, - _: Optional[Source], - ) -> Tuple[bool, Optional[str]]: + logger: logging.Logger, # noqa: ARG002 + _: Source | None, + ) -> tuple[bool, str | None]: """ Perform a connection check for the stream (verify that we can list files from the stream). @@ -52,8 +53,8 @@ def check_availability_and_parsability( self, stream: AbstractFileBasedStream, logger: logging.Logger, - _: Optional[Source], - ) -> Tuple[bool, Optional[str]]: + _: Source | None, + ) -> tuple[bool, str | None]: """ Perform a connection check for the stream. 
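The wrapper above bridges two availability APIs: the file-based strategy reports an (available, reason) tuple, while the concurrent framework expects a StreamAvailability object. A condensed sketch of that adaptation, using only the constructors that appear in this diff:

from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
    StreamAvailable,
    StreamUnavailable,
)


def to_stream_availability(available: bool, reason: str | None):
    # Mirrors the wrapper above: a passing check maps to StreamAvailable,
    # a failing one to StreamUnavailable carrying the reason (or "" when None).
    return StreamAvailable() if available else StreamUnavailable(reason or "")


ok = to_stream_availability(True, None)
failed = to_stream_availability(False, "could not list files")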
@@ -77,14 +78,14 @@ def check_availability_and_parsability( return False, config_check_error_message try: file = self._check_list_files(stream) - if not parser.parser_max_n_files_for_parsability == 0: + if parser.parser_max_n_files_for_parsability != 0: self._check_parse_record(stream, file, logger) else: # If the parser is set to not check parsability, we still want to check that we can open the file. handle = stream.stream_reader.open_file(file, parser.file_read_mode, None, logger) handle.close() except AirbyteTracedException as ate: - raise ate + raise ate # noqa: TRY201 except CheckAvailabilityError: return False, "".join(traceback.format_exc()) @@ -99,7 +100,7 @@ def _check_list_files(self, stream: AbstractFileBasedStream) -> RemoteFile: try: file = next(iter(stream.get_files())) except StopIteration: - raise CheckAvailabilityError(FileBasedSourceError.EMPTY_STREAM, stream=stream.name) + raise CheckAvailabilityError(FileBasedSourceError.EMPTY_STREAM, stream=stream.name) # noqa: B904 except CustomFileBasedException as exc: raise CheckAvailabilityError(str(exc), stream=stream.name) from exc except Exception as exc: @@ -131,14 +132,14 @@ def _check_parse_record( # we skip the schema validation check. return except AirbyteTracedException as ate: - raise ate + raise ate # noqa: TRY201 except Exception as exc: raise CheckAvailabilityError( FileBasedSourceError.ERROR_READING_FILE, stream=stream.name, file=file.uri ) from exc schema = stream.catalog_schema or stream.config.input_schema - if schema and stream.validation_policy.validate_schema_before_sync: + if schema and stream.validation_policy.validate_schema_before_sync: # noqa: SIM102 if not conforms_to_schema(record, schema): # type: ignore raise CheckAvailabilityError( FileBasedSourceError.ERROR_VALIDATING_RECORD, @@ -146,4 +147,4 @@ def _check_parse_record( file=file.uri, ) - return None + return diff --git a/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py b/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py index 626d50fef..26fd8da14 100644 --- a/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py +++ b/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py @@ -4,7 +4,7 @@ import copy from abc import abstractmethod -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal import dpath from pydantic.v1 import AnyUrl, BaseModel, Field @@ -49,7 +49,7 @@ class AbstractFileBasedSpec(BaseModel): that are needed when users configure a file-based source. """ - start_date: Optional[str] = Field( + start_date: str | None = Field( title="Start Date", description="UTC date and time in the format 2017-01-25T00:00:00.000000Z. Any file modified before this date will not be replicated.", examples=["2021-01-01T00:00:00.000000Z"], @@ -59,13 +59,13 @@ class AbstractFileBasedSpec(BaseModel): order=1, ) - streams: List[FileBasedStreamConfig] = Field( + streams: list[FileBasedStreamConfig] = Field( title="The list of streams to sync", description='Each instance of this configuration defines a stream. Use this to define which files belong in the stream, their format, and how they should be parsed and validated. 
When sending data to warehouse destination such as Snowflake or BigQuery, each stream is a separate table.', order=10, ) - delivery_method: Union[DeliverRecords, DeliverRawFiles] = Field( + delivery_method: DeliverRecords | DeliverRawFiles = Field( title="Delivery Method", discriminator="delivery_type", type="object", @@ -84,12 +84,12 @@ def documentation_url(cls) -> AnyUrl: """ @classmethod - def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def schema(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401 """ Generates the mapping comprised of the config fields """ schema = super().schema(*args, **kwargs) - transformed_schema: Dict[str, Any] = copy.deepcopy(schema) + transformed_schema: dict[str, Any] = copy.deepcopy(schema) schema_helpers.expand_refs(transformed_schema) cls.replace_enum_allOf_and_anyOf(transformed_schema) cls.remove_discriminator(transformed_schema) @@ -97,12 +97,12 @@ def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: return transformed_schema @staticmethod - def remove_discriminator(schema: Dict[str, Any]) -> None: + def remove_discriminator(schema: dict[str, Any]) -> None: """pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references""" dpath.delete(schema, "properties/**/discriminator") @staticmethod - def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]: + def replace_enum_allOf_and_anyOf(schema: dict[str, Any]) -> dict[str, Any]: # noqa: N802 """ allOfs are not supported by the UI, but pydantic is automatically writing them for enums. Unpacks the enums under allOf and moves them up a level under the enum key @@ -112,7 +112,7 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]: objects_to_check = schema["properties"]["streams"]["items"]["properties"]["format"] objects_to_check["type"] = "object" objects_to_check["oneOf"] = objects_to_check.pop("anyOf", []) - for format in objects_to_check["oneOf"]: + for format in objects_to_check["oneOf"]: # noqa: A001 for key in format["properties"]: object_property = format["properties"][key] AbstractFileBasedSpec.move_enum_to_root(object_property) @@ -133,7 +133,7 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]: csv_format_schemas = list( filter( - lambda format: format["properties"]["filetype"]["default"] == "csv", + lambda format: format["properties"]["filetype"]["default"] == "csv", # noqa: A006 schema["properties"]["streams"]["items"]["properties"]["format"]["oneOf"], ) ) @@ -146,7 +146,7 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]: return schema @staticmethod - def move_enum_to_root(object_property: Dict[str, Any]) -> None: + def move_enum_to_root(object_property: dict[str, Any]) -> None: if "allOf" in object_property and "enum" in object_property["allOf"][0]: object_property["enum"] = object_property["allOf"][0]["enum"] object_property.pop("allOf") diff --git a/airbyte_cdk/sources/file_based/config/csv_format.py b/airbyte_cdk/sources/file_based/config/csv_format.py index 1441d8411..9ce100892 100644 --- a/airbyte_cdk/sources/file_based/config/csv_format.py +++ b/airbyte_cdk/sources/file_based/config/csv_format.py @@ -4,7 +4,7 @@ import codecs from enum import Enum -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any from pydantic.v1 import BaseModel, Field, root_validator, validator from pydantic.v1.error_wrappers import ValidationError @@ -60,7 +60,7 @@ class 
Config(OneOfOptionConfig): CsvHeaderDefinitionType.USER_PROVIDED.value, const=True, ) - column_names: List[str] = Field( + column_names: list[str] = Field( title="Column Names", description="The column names that will be used while emitting the CSV records", ) @@ -69,7 +69,7 @@ def has_header_row(self) -> bool: return False @validator("column_names") - def validate_column_names(cls, v: List[str]) -> List[str]: + def validate_column_names(cls, v: list[str]) -> list[str]: if not v: raise ValueError( "At least one column name needs to be provided when using user provided headers" @@ -100,12 +100,12 @@ class Config(OneOfOptionConfig): default='"', description="The character used for quoting CSV values. To disallow quoting, make this field blank.", ) - escape_char: Optional[str] = Field( + escape_char: str | None = Field( title="Escape Character", default=None, description="The character used for escaping special characters. To disallow escaping, leave this field blank.", ) - encoding: Optional[str] = Field( + encoding: str | None = Field( default="utf8", description='The character encoding of the CSV data. Leave blank to default to UTF8. See list of python encodings for allowable options.', ) @@ -114,7 +114,7 @@ class Config(OneOfOptionConfig): default=True, description="Whether two quotes in a quoted CSV value denote a single quote in the data.", ) - null_values: Set[str] = Field( + null_values: set[str] = Field( title="Null Values", default=[], description="A set of case-sensitive strings that should be interpreted as null values. For example, if the value 'NA' should be interpreted as null, enter 'NA' in this field.", @@ -134,19 +134,17 @@ class Config(OneOfOptionConfig): default=0, description="The number of rows to skip after the header row.", ) - header_definition: Union[CsvHeaderFromCsv, CsvHeaderAutogenerated, CsvHeaderUserProvided] = ( - Field( - title="CSV Header Definition", - default=CsvHeaderFromCsv(header_definition_type=CsvHeaderDefinitionType.FROM_CSV.value), - description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.", - ) + header_definition: CsvHeaderFromCsv | CsvHeaderAutogenerated | CsvHeaderUserProvided = Field( + title="CSV Header Definition", + default=CsvHeaderFromCsv(header_definition_type=CsvHeaderDefinitionType.FROM_CSV.value), + description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. 
If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.", ) - true_values: Set[str] = Field( + true_values: set[str] = Field( title="True Values", default=DEFAULT_TRUE_VALUES, description="A set of case-sensitive strings that should be interpreted as true values.", ) - false_values: Set[str] = Field( + false_values: set[str] = Field( title="False Values", default=DEFAULT_FALSE_VALUES, description="A set of case-sensitive strings that should be interpreted as false values.", @@ -190,11 +188,11 @@ def validate_encoding(cls, v: str) -> str: try: codecs.lookup(v) except LookupError: - raise ValueError(f"invalid encoding format: {v}") + raise ValueError(f"invalid encoding format: {v}") # noqa: B904 return v @root_validator - def validate_optional_args(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def validate_optional_args(cls, values: dict[str, Any]) -> dict[str, Any]: definition_type = values.get("header_definition_type") column_names = values.get("user_provided_column_names") if definition_type == CsvHeaderDefinitionType.USER_PROVIDED and not column_names: diff --git a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py index eb592a4aa..7b481c9df 100644 --- a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py +++ b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py @@ -2,8 +2,9 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. # +from collections.abc import Mapping from enum import Enum -from typing import Any, List, Mapping, Optional, Union +from typing import Any from pydantic.v1 import BaseModel, Field, validator @@ -16,7 +17,8 @@ from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError from airbyte_cdk.sources.file_based.schema_helpers import type_mapping_to_jsonschema -PrimaryKeyType = Optional[Union[str, List[str]]] + +PrimaryKeyType = str | list[str] | None class ValidationPolicy(Enum): @@ -27,13 +29,13 @@ class ValidationPolicy(Enum): class FileBasedStreamConfig(BaseModel): name: str = Field(title="Name", description="The name of the stream.") - globs: Optional[List[str]] = Field( + globs: list[str] | None = Field( default=["**"], title="Globs", description='The pattern used to specify which files should be selected from the file system. For more information on glob pattern matching look here.', order=1, ) - legacy_prefix: Optional[str] = Field( + legacy_prefix: str | None = Field( title="Legacy Prefix", description="The path prefix configured in v3 versions of the S3 connector. This option is deprecated in favor of a single glob.", airbyte_hidden=True, @@ -43,11 +45,11 @@ class FileBasedStreamConfig(BaseModel): description="The name of the validation policy that dictates sync behavior when a record does not adhere to the stream schema.", default=ValidationPolicy.emit_record, ) - input_schema: Optional[str] = Field( + input_schema: str | None = Field( title="Input Schema", description="The schema that will be used to validate records extracted from the file. This will override the stream schema that is auto-detected from incoming files.", ) - primary_key: Optional[str] = Field( + primary_key: str | None = Field( title="Primary Key", description="The column or columns (for a composite key) that serves as the unique identifier of a record. 
If empty, the primary key will default to the parser's default primary key.", airbyte_hidden=True, # Users can create/modify primary keys in the connection configuration so we shouldn't duplicate it here. @@ -57,9 +59,9 @@ class FileBasedStreamConfig(BaseModel): description="When the state history of the file store is full, syncs will only read files that were last modified in the provided day range.", default=3, ) - format: Union[ - AvroFormat, CsvFormat, JsonlFormat, ParquetFormat, UnstructuredFormat, ExcelFormat - ] = Field( + format: ( + AvroFormat | CsvFormat | JsonlFormat | ParquetFormat | UnstructuredFormat | ExcelFormat + ) = Field( title="Format", description="The configuration options that are used to alter how to read incoming files that deviate from the standard formatting.", ) @@ -68,7 +70,7 @@ class FileBasedStreamConfig(BaseModel): description="When enabled, syncs will not validate or structure records against the stream's schema.", default=False, ) - recent_n_files_to_read_for_schema_discovery: Optional[int] = Field( + recent_n_files_to_read_for_schema_discovery: int | None = Field( title="Files To Read For Schema Discover", description="The number of resent files which will be used to discover the schema for this stream.", default=None, @@ -76,15 +78,14 @@ class FileBasedStreamConfig(BaseModel): ) @validator("input_schema", pre=True) - def validate_input_schema(cls, v: Optional[str]) -> Optional[str]: + def validate_input_schema(cls, v: str | None) -> str | None: if v: if type_mapping_to_jsonschema(v): return v - else: - raise ConfigValidationError(FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA) + raise ConfigValidationError(FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA) return None - def get_input_schema(self) -> Optional[Mapping[str, Any]]: + def get_input_schema(self) -> Mapping[str, Any] | None: """ User defined input_schema is defined as a string in the config. This method takes the string representation and converts it into a Mapping[str, Any] which is used by file-based CDK components. diff --git a/airbyte_cdk/sources/file_based/config/unstructured_format.py b/airbyte_cdk/sources/file_based/config/unstructured_format.py index c03540ce6..7f8ae9369 100644 --- a/airbyte_cdk/sources/file_based/config/unstructured_format.py +++ b/airbyte_cdk/sources/file_based/config/unstructured_format.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import List, Literal, Optional, Union +from typing import Literal from pydantic.v1 import BaseModel, Field @@ -50,7 +50,7 @@ class APIProcessingConfigModel(BaseModel): examples=["https://api.unstructured.com"], ) - parameters: Optional[List[APIParameterConfigModel]] = Field( + parameters: list[APIParameterConfigModel] | None = Field( default=[], always_show=True, title="Additional URL Parameters", @@ -90,10 +90,7 @@ class Config(OneOfOptionConfig): description="The strategy used to parse documents. `fast` extracts text directly from the document which doesn't work for all files. `ocr_only` is more reliable, but slower. `hi_res` is the most reliable, but requires an API key and a hosted instance of unstructured and can't be used with local mode. 
See the unstructured.io documentation for more details: https://unstructured-io.github.io/unstructured/core/partition.html#partition-pdf", ) - processing: Union[ - LocalProcessingConfigModel, - APIProcessingConfigModel, - ] = Field( + processing: LocalProcessingConfigModel | APIProcessingConfigModel = Field( default=LocalProcessingConfigModel(mode="local"), title="Processing", description="Processing configuration", diff --git a/airbyte_cdk/sources/file_based/discovery_policy/__init__.py b/airbyte_cdk/sources/file_based/discovery_policy/__init__.py index 6d0f231a3..7d5bdd887 100644 --- a/airbyte_cdk/sources/file_based/discovery_policy/__init__.py +++ b/airbyte_cdk/sources/file_based/discovery_policy/__init__.py @@ -5,4 +5,5 @@ DefaultDiscoveryPolicy, ) + __all__ = ["AbstractDiscoveryPolicy", "DefaultDiscoveryPolicy"] diff --git a/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py b/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py index f651c2ce1..24a0d9158 100644 --- a/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py +++ b/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py @@ -7,6 +7,7 @@ ) from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser + DEFAULT_N_CONCURRENT_REQUESTS = 10 DEFAULT_MAX_N_FILES_FOR_STREAM_SCHEMA_INFERENCE = 10 diff --git a/airbyte_cdk/sources/file_based/exceptions.py b/airbyte_cdk/sources/file_based/exceptions.py index b0d38947f..f2af7c2a8 100644 --- a/airbyte_cdk/sources/file_based/exceptions.py +++ b/airbyte_cdk/sources/file_based/exceptions.py @@ -3,7 +3,7 @@ # from enum import Enum -from typing import Any, List, Union +from typing import Any from airbyte_cdk.models import AirbyteMessage, FailureType from airbyte_cdk.utils import AirbyteTracedException @@ -43,9 +43,9 @@ class FileBasedErrorsCollector: The placeholder for all errors collected. 
""" - errors: List[AirbyteMessage] = [] + errors: list[AirbyteMessage] = [] # noqa: RUF012 - def yield_and_raise_collected(self) -> Any: + def yield_and_raise_collected(self) -> Any: # noqa: ANN401 if self.errors: # emit collected logged messages yield from self.errors @@ -63,7 +63,7 @@ def collect(self, logged_error: AirbyteMessage) -> None: class BaseFileBasedSourceError(Exception): - def __init__(self, error: Union[FileBasedSourceError, str], **kwargs): # type: ignore # noqa + def __init__(self, error: FileBasedSourceError | str, **kwargs) -> None: # type: ignore if isinstance(error, FileBasedSourceError): error = FileBasedSourceError(error).value super().__init__( @@ -112,7 +112,7 @@ class ErrorListingFiles(BaseFileBasedSourceError): class DuplicatedFilesError(BaseFileBasedSourceError): - def __init__(self, duplicated_files_names: List[dict[str, List[str]]], **kwargs: Any): + def __init__(self, duplicated_files_names: list[dict[str, list[str]]], **kwargs) -> None: self._duplicated_files_names = duplicated_files_names self._stream_name: str = kwargs["stream"] super().__init__(self._format_duplicate_files_error_message(), **kwargs) diff --git a/airbyte_cdk/sources/file_based/file_based_source.py b/airbyte_cdk/sources/file_based/file_based_source.py index 0eb90ac24..99fa3bcc1 100644 --- a/airbyte_cdk/sources/file_based/file_based_source.py +++ b/airbyte_cdk/sources/file_based/file_based_source.py @@ -6,7 +6,8 @@ import traceback from abc import ABC from collections import Counter -from typing import Any, Iterator, List, Mapping, Optional, Tuple, Type, Union +from collections.abc import Iterator, Mapping +from typing import Any from pydantic.v1.error_wrappers import ValidationError @@ -63,6 +64,7 @@ from airbyte_cdk.utils.analytics_message import create_analytics_message from airbyte_cdk.utils.traced_exception import AirbyteTracedException + DEFAULT_CONCURRENCY = 100 MAX_CONCURRENCY = 100 INITIAL_N_PARTITIONS = MAX_CONCURRENCY // 2 @@ -72,21 +74,21 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC): # We make each source override the concurrency level to give control over when they are upgraded. 
_concurrency_level = None - def __init__( + def __init__( # noqa: ANN204, PLR0913, PLR0917 self, stream_reader: AbstractFileBasedStreamReader, - spec_class: Type[AbstractFileBasedSpec], - catalog: Optional[ConfiguredAirbyteCatalog], - config: Optional[Mapping[str, Any]], - state: Optional[List[AirbyteStateMessage]], - availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None, - discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(), - parsers: Mapping[Type[Any], FileTypeParser] = default_parsers, + spec_class: type[AbstractFileBasedSpec], + catalog: ConfiguredAirbyteCatalog | None, + config: Mapping[str, Any] | None, + state: list[AirbyteStateMessage] | None, + availability_strategy: AbstractFileBasedAvailabilityStrategy | None = None, + discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(), # noqa: B008 + parsers: Mapping[type[Any], FileTypeParser] = default_parsers, validation_policies: Mapping[ ValidationPolicy, AbstractSchemaValidationPolicy ] = DEFAULT_SCHEMA_VALIDATION_POLICIES, - cursor_cls: Type[ - Union[AbstractConcurrentFileBasedCursor, AbstractFileBasedCursor] + cursor_cls: type[ + AbstractConcurrentFileBasedCursor | AbstractFileBasedCursor ] = FileBasedConcurrentCursor, ): self.stream_reader = stream_reader @@ -106,7 +108,7 @@ def __init__( self.cursor_cls = cursor_cls self.logger = init_logger(f"airbyte.{self.name}") self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector() - self._message_repository: Optional[MessageRepository] = None + self._message_repository: MessageRepository | None = None concurrent_source = ConcurrentSource.create( MAX_CONCURRENCY, INITIAL_N_PARTITIONS, @@ -127,7 +129,7 @@ def message_repository(self) -> MessageRepository: def check_connection( self, logger: logging.Logger, config: Mapping[str, Any] - ) -> Tuple[bool, Optional[Any]]: + ) -> tuple[bool, Any | None]: """ Check that the source can be accessed using the user-provided configuration. @@ -140,7 +142,7 @@ def check_connection( try: streams = self.streams(config) except Exception as config_exception: - raise AirbyteTracedException( + raise AirbyteTracedException( # noqa: B904 internal_message="Please check the logged errors for more information.", message=FileBasedSourceError.CONFIG_VALIDATION_ERROR.value, exception=AirbyteTracedException(exception=config_exception), @@ -158,7 +160,7 @@ def check_connection( tracebacks = [] for stream in streams: if not isinstance(stream, AbstractFileBasedStream): - raise ValueError(f"Stream {stream} is not a file-based stream.") + raise ValueError(f"Stream {stream} is not a file-based stream.") # noqa: TRY004 try: parsed_config = self._get_parsed_config(config) availability_method = ( @@ -191,7 +193,7 @@ def check_connection( message=f"{errors[0]}", failure_type=FailureType.config_error, ) - elif len(errors) > 1: + if len(errors) > 1: raise AirbyteTracedException( internal_message="\n".join(tracebacks), message=f"{len(errors)} streams with errors: {', '.join(error for error in errors)}", @@ -200,12 +202,12 @@ def check_connection( return not bool(errors), (errors or None) - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> list[Stream]: """ Return a list of this source's streams. """ - if self.catalog: + if self.catalog: # noqa: SIM108 state_manager = ConnectorStateManager(state=self.state) else: # During `check` operations we don't have a catalog so cannot create a state manager. 
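(For reference: ruff's SIM108 rule, suppressed in the hunk above, asks for an if/else assignment to be collapsed into a conditional expression. A minimal, hypothetical sketch of that form, reusing the names from this hunk and assuming the else branch assigns None, would be:

    state_manager = ConnectorStateManager(state=self.state) if self.catalog else None

Keeping the explicit if/else plus the noqa marker leaves room for the comment explaining why no state manager exists during `check` operations.)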
@@ -215,7 +217,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: try: parsed_config = self._get_parsed_config(config) self.stream_reader.config = parsed_config - streams: List[Stream] = [] + streams: list[Stream] = [] for stream_config in parsed_config.streams: # Like state_manager, `catalog_stream` may be None during `check` catalog_stream = self._get_stream_from_catalog(stream_config) @@ -256,9 +258,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: and hasattr(self, "_concurrency_level") and self._concurrency_level is not None ): - assert ( - state_manager is not None - ), "No ConnectorStateManager was created, but it is required for incremental syncs. This is unexpected. Please contact Support." + assert state_manager is not None, ( + "No ConnectorStateManager was created, but it is required for incremental syncs. This is unexpected. Please contact Support." + ) cursor = self.cursor_cls( stream_config, @@ -289,7 +291,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: ) streams.append(stream) - return streams + return streams # noqa: TRY300 except ValidationError as exc: raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc @@ -297,7 +299,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: def _make_default_stream( self, stream_config: FileBasedStreamConfig, - cursor: Optional[AbstractFileBasedCursor], + cursor: AbstractFileBasedCursor | None, parsed_config: AbstractFileBasedSpec, ) -> AbstractFileBasedStream: return DefaultFileBasedStream( @@ -316,14 +318,14 @@ def _make_default_stream( def _get_stream_from_catalog( self, stream_config: FileBasedStreamConfig - ) -> Optional[AirbyteStream]: + ) -> AirbyteStream | None: if self.catalog: for stream in self.catalog.streams or []: if stream.stream.name == stream_config.name: return stream.stream return None - def _get_sync_mode_from_catalog(self, stream_name: str) -> Optional[SyncMode]: + def _get_sync_mode_from_catalog(self, stream_name: str) -> SyncMode | None: if self.catalog: for catalog_stream in self.catalog.streams: if stream_name == catalog_stream.stream.name: @@ -336,7 +338,7 @@ def read( logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, + state: list[AirbyteStateMessage] | None = None, ) -> Iterator[AirbyteMessage]: yield from super().read(logger, config, catalog, state) # emit all the errors collected @@ -348,7 +350,7 @@ def read( ).items(): yield create_analytics_message(f"file-cdk-{parser}-stream-count", count) - def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification: + def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification: # noqa: ANN401, ARG002 """ Returns the specification describing what fields can be configured by a user when setting up a file-based source. 
""" diff --git a/airbyte_cdk/sources/file_based/file_based_stream_reader.py b/airbyte_cdk/sources/file_based/file_based_stream_reader.py index 065125621..fcb6e0a31 100644 --- a/airbyte_cdk/sources/file_based/file_based_stream_reader.py +++ b/airbyte_cdk/sources/file_based/file_based_stream_reader.py @@ -4,11 +4,12 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Iterable from datetime import datetime from enum import Enum from io import IOBase from os import makedirs, path -from typing import Any, Dict, Iterable, List, Optional, Set +from typing import Any from wcmatch.glob import GLOBSTAR, globmatch @@ -28,7 +29,7 @@ def __init__(self) -> None: self._config = None @property - def config(self) -> Optional[AbstractFileBasedSpec]: + def config(self) -> AbstractFileBasedSpec | None: return self._config @config.setter @@ -47,7 +48,7 @@ def config(self, value: AbstractFileBasedSpec) -> None: @abstractmethod def open_file( - self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger + self, file: RemoteFile, mode: FileReadMode, encoding: str | None, logger: logging.Logger ) -> IOBase: """ Return a file handle for reading. @@ -63,8 +64,8 @@ def open_file( @abstractmethod def get_matching_files( self, - globs: List[str], - prefix: Optional[str], + globs: list[str], + prefix: str | None, logger: logging.Logger, ) -> Iterable[RemoteFile]: """ @@ -84,7 +85,7 @@ def get_matching_files( ... def filter_files_by_globs_and_start_date( - self, files: List[RemoteFile], globs: List[str] + self, files: list[RemoteFile], globs: list[str] ) -> Iterable[RemoteFile]: """ Utility method for filtering files based on globs. @@ -97,7 +98,7 @@ def filter_files_by_globs_and_start_date( seen = set() for file in files: - if self.file_matches_globs(file, globs): + if self.file_matches_globs(file, globs): # noqa: SIM102 if file.uri not in seen and (not start_date or file.last_modified >= start_date): seen.add(file.uri) yield file @@ -113,13 +114,13 @@ def file_size(self, file: RemoteFile) -> int: ... @staticmethod - def file_matches_globs(file: RemoteFile, globs: List[str]) -> bool: + def file_matches_globs(file: RemoteFile, globs: list[str]) -> bool: # Use the GLOBSTAR flag to enable recursive ** matching # (https://facelessuser.github.io/wcmatch/wcmatch/#globstar) return any(globmatch(file.uri, g, flags=GLOBSTAR) for g in globs) @staticmethod - def get_prefixes_from_globs(globs: List[str]) -> Set[str]: + def get_prefixes_from_globs(globs: list[str]) -> set[str]: """ Utility method for extracting prefixes from the globs. """ @@ -149,7 +150,7 @@ def preserve_directory_structure(self) -> bool: @abstractmethod def get_file( self, file: RemoteFile, local_directory: str, logger: logging.Logger - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ This is required for connectors that will support writing to files. It will handle the logic to download,get,read,acquire or @@ -170,16 +171,16 @@ def get_file( """ ... 
- def _get_file_transfer_paths(self, file: RemoteFile, local_directory: str) -> List[str]: + def _get_file_transfer_paths(self, file: RemoteFile, local_directory: str) -> list[str]: preserve_directory_structure = self.preserve_directory_structure() if preserve_directory_structure: # Remove left slashes from source path format to make relative path for writing locally file_relative_path = file.uri.lstrip("/") else: - file_relative_path = path.basename(file.uri) - local_file_path = path.join(local_directory, file_relative_path) + file_relative_path = path.basename(file.uri) # noqa: PTH119 + local_file_path = path.join(local_directory, file_relative_path) # noqa: PTH118 # Ensure the local directory exists - makedirs(path.dirname(local_file_path), exist_ok=True) - absolute_file_path = path.abspath(local_file_path) + makedirs(path.dirname(local_file_path), exist_ok=True) # noqa: PTH103, PTH120 + absolute_file_path = path.abspath(local_file_path) # noqa: PTH100 return [file_relative_path, local_file_path, absolute_file_path] diff --git a/airbyte_cdk/sources/file_based/file_types/__init__.py b/airbyte_cdk/sources/file_based/file_types/__init__.py index b9d8f1d52..50446f893 100644 --- a/airbyte_cdk/sources/file_based/file_types/__init__.py +++ b/airbyte_cdk/sources/file_based/file_types/__init__.py @@ -1,11 +1,5 @@ -from typing import Any, Mapping, Type - -from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat -from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat -from airbyte_cdk.sources.file_based.config.excel_format import ExcelFormat -from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat -from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat -from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat +from collections.abc import Mapping +from typing import Any, Type # noqa: F401, UP035 from .avro_parser import AvroParser from .csv_parser import CsvParser @@ -15,8 +9,15 @@ from .jsonl_parser import JsonlParser from .parquet_parser import ParquetParser from .unstructured_parser import UnstructuredParser +from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat +from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat +from airbyte_cdk.sources.file_based.config.excel_format import ExcelFormat +from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat +from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat +from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat + -default_parsers: Mapping[Type[Any], FileTypeParser] = { +default_parsers: Mapping[type[Any], FileTypeParser] = { AvroFormat: AvroParser(), CsvFormat: CsvParser(), ExcelFormat: ExcelParser(), diff --git a/airbyte_cdk/sources/file_based/file_types/avro_parser.py b/airbyte_cdk/sources/file_based/file_types/avro_parser.py index e1aa2c4cb..6da06900b 100644 --- a/airbyte_cdk/sources/file_based/file_types/avro_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/avro_parser.py @@ -3,7 +3,8 @@ # import logging -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, cast +from collections.abc import Iterable, Mapping +from typing import Any, cast import fastavro @@ -18,6 +19,7 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType + AVRO_TYPE_TO_JSON_TYPE = { "null": "null", "boolean": "boolean", @@ -46,7 +48,7 @@ class 
AvroParser(FileTypeParser): ENCODING = None - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002 """ AvroParser does not require config checks, implicit pydantic validation is enough. """ @@ -61,12 +63,12 @@ async def infer_schema( ) -> SchemaType: avro_format = config.format if not isinstance(avro_format, AvroFormat): - raise ValueError(f"Expected ParquetFormat, got {avro_format}") + raise ValueError(f"Expected ParquetFormat, got {avro_format}") # noqa: TRY004 with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: avro_reader = fastavro.reader(fp) # type: ignore [arg-type] avro_schema = avro_reader.writer_schema - if not avro_schema["type"] == "record": # type: ignore [index, call-overload] + if avro_schema["type"] != "record": # type: ignore [index, call-overload] unsupported_type = avro_schema["type"] # type: ignore [index, call-overload] raise ValueError( f"Only record based avro files are supported. Found {unsupported_type}" @@ -82,7 +84,7 @@ async def infer_schema( return json_schema @classmethod - def _convert_avro_type_to_json( + def _convert_avro_type_to_json( # noqa: PLR0911, PLR0912 cls, avro_format: AvroFormat, field_name: str, avro_field: str ) -> Mapping[str, Any]: if isinstance(avro_field, str) and avro_field in AVRO_TYPE_TO_JSON_TYPE: @@ -101,7 +103,7 @@ def _convert_avro_type_to_json( for object_field in avro_field["fields"] }, } - elif avro_field["type"] == "array": + if avro_field["type"] == "array": if "items" not in avro_field: raise ValueError( f"{field_name} array type does not have a required field items" @@ -112,7 +114,7 @@ def _convert_avro_type_to_json( avro_format, "", avro_field["items"] ), } - elif avro_field["type"] == "enum": + if avro_field["type"] == "enum": if "symbols" not in avro_field: raise ValueError( f"{field_name} enum type does not have a required field symbols" @@ -120,7 +122,7 @@ def _convert_avro_type_to_json( if "name" not in avro_field: raise ValueError(f"{field_name} enum type does not have a required field name") return {"type": "string", "enum": avro_field["symbols"]} - elif avro_field["type"] == "map": + if avro_field["type"] == "map": if "values" not in avro_field: raise ValueError(f"{field_name} map type does not have a required field values") return { @@ -129,7 +131,7 @@ def _convert_avro_type_to_json( avro_format, "", avro_field["values"] ), } - elif avro_field["type"] == "fixed" and avro_field.get("logicalType") != "duration": + if avro_field["type"] == "fixed" and avro_field.get("logicalType") != "duration": if "size" not in avro_field: raise ValueError(f"{field_name} fixed type does not have a required field size") if not isinstance(avro_field["size"], int): @@ -138,7 +140,7 @@ def _convert_avro_type_to_json( "type": "string", "pattern": f"^[0-9A-Fa-f]{{{avro_field['size'] * 2}}}$", } - elif avro_field.get("logicalType") == "decimal": + if avro_field.get("logicalType") == "decimal": if "precision" not in avro_field: raise ValueError( f"{field_name} decimal type does not have a required field precision" @@ -154,18 +156,16 @@ def _convert_avro_type_to_json( # For example: ^-?\d{1,5}(?:\.\d{1,3})?$ would accept 12345.123 and 123456.12345 would be rejected return { "type": "string", - "pattern": f"^-?\\d{{{1,max_whole_number_range}}}(?:\\.\\d{1,decimal_range})?$", + "pattern": f"^-?\\d{{{1, max_whole_number_range}}}(?:\\.\\d{1, decimal_range})?$", } - elif "logicalType" in 
avro_field: + if "logicalType" in avro_field: if avro_field["logicalType"] not in AVRO_LOGICAL_TYPE_TO_JSON: raise ValueError( f"{avro_field['logicalType']} is not a valid Avro logical type" ) return AVRO_LOGICAL_TYPE_TO_JSON[avro_field["logicalType"]] - else: - raise ValueError(f"Unsupported avro type: {avro_field}") - else: raise ValueError(f"Unsupported avro type: {avro_field}") + raise ValueError(f"Unsupported avro type: {avro_field}") def parse_records( self, @@ -173,11 +173,11 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None, # noqa: ARG002 + ) -> Iterable[dict[str, Any]]: avro_format = config.format or AvroFormat(filetype="avro") if not isinstance(avro_format, AvroFormat): - raise ValueError(f"Expected ParquetFormat, got {avro_format}") + raise ValueError(f"Expected ParquetFormat, got {avro_format}") # noqa: TRY004 line_no = 0 try: @@ -185,7 +185,7 @@ def parse_records( avro_reader = fastavro.reader(fp) # type: ignore [arg-type] schema = avro_reader.writer_schema schema_field_name_to_type = { - field["name"]: cast(dict[str, Any], field["type"]) # type: ignore [index] + field["name"]: cast(dict[str, Any], field["type"]) # type: ignore [index] # noqa: TC006 for field in schema["fields"] # type: ignore [index, call-overload] # If schema is not dict, it is not subscriptable by strings } for record in avro_reader: @@ -193,7 +193,7 @@ def parse_records( yield { record_field: self._to_output_value( avro_format, - schema_field_name_to_type[record_field], # type: ignore [index] # Any not subscriptable + schema_field_name_to_type[record_field], # type: ignore [index] # Any not subscriptable # noqa: PLR1733 record[record_field], # type: ignore [index] # Any not subscriptable ) for record_field, record_value in schema_field_name_to_type.items() @@ -208,26 +208,27 @@ def file_read_mode(self) -> FileReadMode: return FileReadMode.READ_BINARY @staticmethod - def _to_output_value( - avro_format: AvroFormat, record_type: Mapping[str, Any], record_value: Any - ) -> Any: + def _to_output_value( # noqa: PLR0911 + avro_format: AvroFormat, + record_type: Mapping[str, Any], + record_value: Any, # noqa: ANN401 + ) -> Any: # noqa: ANN401 if isinstance(record_value, bytes): return record_value.decode() - elif not isinstance(record_type, Mapping): + if not isinstance(record_type, Mapping): if record_type == "double" and avro_format.double_as_string: return str(record_value) return record_value if record_type.get("logicalType") in ("decimal", "uuid"): return str(record_value) - elif record_type.get("logicalType") == "date": + if record_type.get("logicalType") == "date": return record_value.isoformat() - elif record_type.get("logicalType") == "timestamp-millis": + if record_type.get("logicalType") == "timestamp-millis": return record_value.isoformat(sep="T", timespec="milliseconds") - elif record_type.get("logicalType") == "timestamp-micros": + if record_type.get("logicalType") == "timestamp-micros": return record_value.isoformat(sep="T", timespec="microseconds") - elif record_type.get("logicalType") == "local-timestamp-millis": + if record_type.get("logicalType") == "local-timestamp-millis": return record_value.isoformat(sep="T", timespec="milliseconds") - elif record_type.get("logicalType") == "local-timestamp-micros": + if record_type.get("logicalType") == "local-timestamp-micros": return 
record_value.isoformat(sep="T", timespec="microseconds") - else: - return record_value + return record_value diff --git a/airbyte_cdk/sources/file_based/file_types/csv_parser.py b/airbyte_cdk/sources/file_based/file_types/csv_parser.py index e3010690e..cc56e29ee 100644 --- a/airbyte_cdk/sources/file_based/file_types/csv_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/csv_parser.py @@ -7,9 +7,10 @@ import logging from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Callable, Generator, Iterable, Mapping from functools import partial from io import IOBase -from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple +from typing import Any from uuid import uuid4 import orjson @@ -32,6 +33,7 @@ from airbyte_cdk.sources.file_based.schema_helpers import TYPE_PYTHON_MAPPING, SchemaType from airbyte_cdk.utils.traced_exception import AirbyteTracedException + DIALECT_NAME = "_config_dialect" @@ -43,7 +45,7 @@ def read_data( stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, file_read_mode: FileReadMode, - ) -> Generator[Dict[str, Any], None, None]: + ) -> Generator[dict[str, Any], None, None]: config_format = _extract_format(config) lineno = 0 @@ -52,7 +54,7 @@ def read_data( # Give each stream's dialect a unique name; otherwise, when we are doing a concurrent sync we can end up # with a race condition where a thread attempts to use a dialect before a separate thread has finished # registering it. - dialect_name = f"{config.name}_{str(uuid4())}_{DIALECT_NAME}" + dialect_name = f"{config.name}_{uuid4()!s}_{DIALECT_NAME}" csv.register_dialect( dialect_name, delimiter=config_format.delimiter, @@ -65,7 +67,7 @@ def read_data( try: headers = self._get_headers(fp, config_format, dialect_name) except UnicodeError: - raise AirbyteTracedException( + raise AirbyteTracedException( # noqa: B904 message=f"{FileBasedSourceError.ENCODING_ERROR.value} Expected encoding: {config_format.encoding}", ) @@ -111,7 +113,7 @@ def read_data( # due to RecordParseError or GeneratorExit csv.unregister_dialect(dialect_name) - def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> List[str]: + def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> list[str]: """ Assumes the fp is pointing to the beginning of the files and will reset it as such """ @@ -133,7 +135,7 @@ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) fp.seek(0) return headers - def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> List[str]: + def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> list[str]: """ Generates field names as [f0, f1, ...] in the same way as pyarrow's csv reader with autogenerate_column_names=True. See https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html @@ -154,14 +156,14 @@ def _skip_rows(fp: IOBase, rows_to_skip: int) -> None: class CsvParser(FileTypeParser): _MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE = 1_000_000 - def __init__(self, csv_reader: Optional[_CsvReader] = None, csv_field_max_bytes: int = 2**31): + def __init__(self, csv_reader: _CsvReader | None = None, csv_field_max_bytes: int = 2**31): # noqa: ANN204 # Increase the maximum length of data that can be parsed in a single CSV field. 
The default is 128k, which is typically sufficient # but given the use of Airbyte in loading a large variety of data it is best to allow for a larger maximum field size to avoid # skipping data on load. https://stackoverflow.com/questions/15063936/csv-error-field-larger-than-field-limit-131072 csv.field_size_limit(csv_field_max_bytes) - self._csv_reader = csv_reader if csv_reader else _CsvReader() + self._csv_reader = csv_reader if csv_reader else _CsvReader() # noqa: FURB110 - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002 """ CsvParser does not require config checks, implicit pydantic validation is enough. """ @@ -178,10 +180,10 @@ async def infer_schema( if input_schema: return input_schema - # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual + # TODO: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual # sources will likely require one. Rather than modify the interface now we can wait until the real use case config_format = _extract_format(config) - type_inferrer_by_field: Dict[str, _TypeInferrer] = defaultdict( + type_inferrer_by_field: dict[str, _TypeInferrer] = defaultdict( lambda: _JsonTypeInferrer( config_format.true_values, config_format.false_values, config_format.null_values ) @@ -221,8 +223,8 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None, + ) -> Iterable[dict[str, Any]]: line_no = 0 try: config_format = _extract_format(config) @@ -263,7 +265,7 @@ def _get_cast_function( deduped_property_types: Mapping[str, str], config_format: CsvFormat, logger: logging.Logger, - schemaless: bool, + schemaless: bool, # noqa: FBT001 ) -> Callable[[Mapping[str, str]], Mapping[str, str]]: # Only cast values if the schema is provided if deduped_property_types and not schemaless: @@ -273,17 +275,16 @@ def _get_cast_function( config_format=config_format, logger=logger, ) - else: - # If no schema is provided, yield the rows as they are - return _no_cast + # If no schema is provided, yield the rows as they are + return _no_cast @staticmethod def _to_nullable( row: Mapping[str, str], deduped_property_types: Mapping[str, str], - null_values: Set[str], - strings_can_be_null: bool, - ) -> Dict[str, Optional[str]]: + null_values: set[str], + strings_can_be_null: bool, # noqa: FBT001 + ) -> dict[str, str | None]: nullable = { k: None if CsvParser._value_is_none( @@ -296,15 +297,15 @@ def _to_nullable( @staticmethod def _value_is_none( - value: Any, - deduped_property_type: Optional[str], - null_values: Set[str], - strings_can_be_null: bool, + value: Any, # noqa: ANN401 + deduped_property_type: str | None, + null_values: set[str], + strings_can_be_null: bool, # noqa: FBT001 ) -> bool: return value in null_values and (strings_can_be_null or deduped_property_type != "string") @staticmethod - def _pre_propcess_property_types(property_types: Dict[str, Any]) -> Mapping[str, str]: + def _pre_propcess_property_types(property_types: dict[str, Any]) -> Mapping[str, str]: """ Transform the property types to be non-nullable and remove duplicate types if any. 
Sample input: @@ -335,11 +336,11 @@ def _pre_propcess_property_types(property_types: Dict[str, Any]) -> Mapping[str, @staticmethod def _cast_types( - row: Dict[str, str], + row: dict[str, str], deduped_property_types: Mapping[str, str], config_format: CsvFormat, logger: logging.Logger, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Casts the values in the input 'row' dictionary according to the types defined in the JSON schema. @@ -358,7 +359,7 @@ def _cast_types( _, python_type = TYPE_PYTHON_MAPPING[prop_type] if python_type is None: - if value == "": + if value == "": # noqa: PLC1901 cast_value = None else: warnings.append(_format_warning(key, value, prop_type)) @@ -401,7 +402,7 @@ def _cast_types( class _TypeInferrer(ABC): @abstractmethod - def add_value(self, value: Any) -> None: + def add_value(self, value: Any) -> None: # noqa: ANN401 pass @abstractmethod @@ -410,7 +411,7 @@ def infer(self) -> str: class _DisabledTypeInferrer(_TypeInferrer): - def add_value(self, value: Any) -> None: + def add_value(self, value: Any) -> None: # noqa: ANN401 pass def infer(self) -> str: @@ -425,14 +426,14 @@ class _JsonTypeInferrer(_TypeInferrer): _STRING_TYPE = "string" def __init__( - self, boolean_trues: Set[str], boolean_falses: Set[str], null_values: Set[str] + self, boolean_trues: set[str], boolean_falses: set[str], null_values: set[str] ) -> None: self._boolean_trues = boolean_trues self._boolean_falses = boolean_falses self._null_values = null_values - self._values: Set[str] = set() + self._values: set[str] = set() - def add_value(self, value: Any) -> None: + def add_value(self, value: Any) -> None: # noqa: ANN401 self._values.add(value) def infer(self) -> str: @@ -447,13 +448,13 @@ def infer(self) -> str: types = set.intersection(*types_excluding_null_values) if self._BOOLEAN_TYPE in types: return self._BOOLEAN_TYPE - elif self._INTEGER_TYPE in types: + if self._INTEGER_TYPE in types: return self._INTEGER_TYPE - elif self._NUMBER_TYPE in types: + if self._NUMBER_TYPE in types: return self._NUMBER_TYPE return self._STRING_TYPE - def _infer_type(self, value: str) -> Set[str]: + def _infer_type(self, value: str) -> set[str]: inferred_types = set() if value in self._null_values: @@ -472,7 +473,7 @@ def _infer_type(self, value: str) -> Set[str]: def _is_boolean(self, value: str) -> bool: try: _value_to_bool(value, self._boolean_trues, self._boolean_falses) - return True + return True # noqa: TRY300 except ValueError: return False @@ -480,7 +481,7 @@ def _is_boolean(self, value: str) -> bool: def _is_integer(value: str) -> bool: try: _value_to_python_type(value, int) - return True + return True # noqa: TRY300 except ValueError: return False @@ -488,12 +489,12 @@ def _is_integer(value: str) -> bool: def _is_number(value: str) -> bool: try: _value_to_python_type(value, float) - return True + return True # noqa: TRY300 except ValueError: return False -def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> bool: +def _value_to_bool(value: str, true_values: set[str], false_values: set[str]) -> bool: if value in true_values: return True if value in false_values: @@ -501,18 +502,18 @@ def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> raise ValueError(f"Value {value} is not a valid boolean value") -def _value_to_list(value: str) -> List[Any]: +def _value_to_list(value: str) -> list[Any]: parsed_value = json.loads(value) if isinstance(parsed_value, list): return parsed_value raise ValueError(f"Value {parsed_value} is not a valid list value") -def 
_value_to_python_type(value: str, python_type: type) -> Any: +def _value_to_python_type(value: str, python_type: type) -> Any: # noqa: ANN401 return python_type(value) -def _format_warning(key: str, value: str, expected_type: Optional[Any]) -> str: +def _format_warning(key: str, value: str, expected_type: Any | None) -> str: # noqa: ANN401 return f"{key}: value={value},expected_type={expected_type}" @@ -523,5 +524,5 @@ def _no_cast(row: Mapping[str, str]) -> Mapping[str, str]: def _extract_format(config: FileBasedStreamConfig) -> CsvFormat: config_format = config.format if not isinstance(config_format, CsvFormat): - raise ValueError(f"Invalid format config: {config_format}") + raise ValueError(f"Invalid format config: {config_format}") # noqa: TRY004 return config_format diff --git a/airbyte_cdk/sources/file_based/file_types/excel_parser.py b/airbyte_cdk/sources/file_based/file_types/excel_parser.py index 5a0332171..dd3117b7c 100644 --- a/airbyte_cdk/sources/file_based/file_types/excel_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/excel_parser.py @@ -3,9 +3,10 @@ # import logging +from collections.abc import Iterable, Mapping from io import IOBase from pathlib import Path -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union +from typing import Any import orjson import pandas as pd @@ -34,7 +35,7 @@ class ExcelParser(FileTypeParser): ENCODING = None - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002 """ ExcelParser does not require config checks, implicit pydantic validation is enough. """ @@ -63,7 +64,7 @@ async def infer_schema( # Validate the format of the config self.validate_format(config.format, logger) - fields: Dict[str, str] = {} + fields: dict[str, str] = {} with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: df = self.open_and_parse_file(fp) @@ -91,8 +92,8 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]] = None, - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None = None, # noqa: ARG002 + ) -> Iterable[dict[str, Any]]: """ Parses records from an Excel file based on the provided configuration. @@ -140,7 +141,7 @@ def file_read_mode(self) -> FileReadMode: @staticmethod def dtype_to_json_type( - current_type: Optional[str], + current_type: str | None, dtype: dtype_, # type: ignore [type-arg] ) -> str: """ @@ -183,7 +184,7 @@ def validate_format(excel_format: BaseModel, logger: logging.Logger) -> None: raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) @staticmethod - def open_and_parse_file(fp: Union[IOBase, str, Path]) -> pd.DataFrame: + def open_and_parse_file(fp: IOBase | str | Path) -> pd.DataFrame: """ Opens and parses the Excel file. 
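A recurring pattern across these file diffs is the move from `typing.Optional`, `Union`, `List`, `Dict`, and `Set` to PEP 604 unions (`X | None`) and builtin generics (`list[str]`, `dict[str, Any]`), with abstract collection types imported from `collections.abc` rather than `typing`. A minimal, standalone sketch of that annotation style (Python 3.10+; the function and names below are illustrative only, not part of the CDK):

    from collections.abc import Iterable, Mapping
    from typing import Any


    def count_values(records: Iterable[Mapping[str, Any]], key: str) -> dict[str, int]:
        """Count how often each non-null value of `key` appears across the records."""
        counts: dict[str, int] = {}
        for record in records:
            value = record.get(key)
            if value is None:
                continue
            counts[str(value)] = counts.get(str(value), 0) + 1
        return counts

The older spellings still work, but the newer syntax reads closer to the runtime types and avoids the legacy `typing` aliases.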
diff --git a/airbyte_cdk/sources/file_based/file_types/file_transfer.py b/airbyte_cdk/sources/file_based/file_types/file_transfer.py index 154b6ff44..e8ca27e4a 100644 --- a/airbyte_cdk/sources/file_based/file_types/file_transfer.py +++ b/airbyte_cdk/sources/file_based/file_types/file_transfer.py @@ -3,12 +3,14 @@ # import logging import os -from typing import Any, Dict, Iterable +from collections.abc import Iterable +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader from airbyte_cdk.sources.file_based.remote_file import RemoteFile + AIRBYTE_STAGING_DIRECTORY = os.getenv("AIRBYTE_STAGING_DIRECTORY", "/staging/files") DEFAULT_LOCAL_DIRECTORY = "/tmp/airbyte-file-transfer" @@ -17,21 +19,21 @@ class FileTransfer: def __init__(self) -> None: self._local_directory = ( AIRBYTE_STAGING_DIRECTORY - if os.path.exists(AIRBYTE_STAGING_DIRECTORY) + if os.path.exists(AIRBYTE_STAGING_DIRECTORY) # noqa: PTH110 else DEFAULT_LOCAL_DIRECTORY ) def get_file( self, - config: FileBasedStreamConfig, + config: FileBasedStreamConfig, # noqa: ARG002 file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - ) -> Iterable[Dict[str, Any]]: + ) -> Iterable[dict[str, Any]]: try: yield stream_reader.get_file( file=file, local_directory=self._local_directory, logger=logger ) except Exception as ex: logger.error("An error has occurred while getting file: %s", str(ex)) - raise ex + raise ex # noqa: TRY201 diff --git a/airbyte_cdk/sources/file_based/file_types/file_type_parser.py b/airbyte_cdk/sources/file_based/file_types/file_type_parser.py index e6a9c5cb1..e27e27870 100644 --- a/airbyte_cdk/sources/file_based/file_types/file_type_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/file_type_parser.py @@ -4,7 +4,8 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.file_based_stream_reader import ( @@ -14,7 +15,8 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import SchemaType -Record = Dict[str, Any] + +Record = dict[str, Any] class FileTypeParser(ABC): @@ -24,27 +26,27 @@ class FileTypeParser(ABC): """ @property - def parser_max_n_files_for_schema_inference(self) -> Optional[int]: + def parser_max_n_files_for_schema_inference(self) -> int | None: """ The discovery policy decides how many files are loaded for schema inference. This method can provide a parser-specific override. If it's defined, the smaller of the two values will be used. """ return None @property - def parser_max_n_files_for_parsability(self) -> Optional[int]: + def parser_max_n_files_for_parsability(self) -> int | None: """ The availability policy decides how many files are loaded for checking whether parsing works correctly. This method can provide a parser-specific override. If it's defined, the smaller of the two values will be used. """ return None - def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> Optional[str]: + def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> str | None: # noqa: ARG002 """ The parser can define a primary key. 
If no user-defined primary key is provided, this will be used. """ return None @abstractmethod - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: """ Check whether the config is valid for this file type. If it is, return True and None. If it's not, return False and an error message explaining why it's invalid. """ @@ -70,7 +72,7 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], + discovered_schema: Mapping[str, SchemaType] | None, ) -> Iterable[Record]: """ Parse and emit each record. diff --git a/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py b/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py index 722ad329b..c7c5f7b4b 100644 --- a/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py @@ -4,7 +4,8 @@ import json import logging -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping +from typing import Any import orjson @@ -27,7 +28,7 @@ class JsonlParser(FileTypeParser): MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE = 1_000_000 ENCODING = "utf8" - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002 """ JsonlParser does not require config checks, implicit pydantic validation is enough. """ @@ -35,7 +36,7 @@ def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[st async def infer_schema( self, - config: FileBasedStreamConfig, + config: FileBasedStreamConfig, # noqa: ARG002 file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, @@ -54,12 +55,12 @@ async def infer_schema( def parse_records( self, - config: FileBasedStreamConfig, + config: FileBasedStreamConfig, # noqa: ARG002 file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None, # noqa: ARG002 + ) -> Iterable[dict[str, Any]]: """ This code supports parsing json objects over multiple lines even though this does not align with the JSONL format. This is for backward compatibility reasons i.e. the previous source-s3 parser did support this. 
The drawback is: @@ -73,7 +74,7 @@ def parse_records( yield from self._parse_jsonl_entries(file, stream_reader, logger) @classmethod - def _infer_schema_for_record(cls, record: Dict[str, Any]) -> Dict[str, Any]: + def _infer_schema_for_record(cls, record: dict[str, Any]) -> dict[str, Any]: record_schema = {} for key, value in record.items(): if value is None: @@ -92,8 +93,8 @@ def _parse_jsonl_entries( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - read_limit: bool = False, - ) -> Iterable[Dict[str, Any]]: + read_limit: bool = False, # noqa: FBT001, FBT002 + ) -> Iterable[dict[str, Any]]: with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: read_bytes = 0 @@ -137,9 +138,9 @@ def _parse_jsonl_entries( FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line ) - @staticmethod - def _instantiate_accumulator(line: Union[bytes, str]) -> Union[bytes, str]: + @staticmethod # noqa: RET503 + def _instantiate_accumulator(line: bytes | str) -> bytes | str: if isinstance(line, bytes): return bytes("", json.detect_encoding(line)) - elif isinstance(line, str): + if isinstance(line, str): # noqa: RET503, RUF100 return "" diff --git a/airbyte_cdk/sources/file_based/file_types/parquet_parser.py b/airbyte_cdk/sources/file_based/file_types/parquet_parser.py index 28cfb14c9..b3391d866 100644 --- a/airbyte_cdk/sources/file_based/file_types/parquet_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/parquet_parser.py @@ -5,7 +5,8 @@ import json import logging import os -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping +from typing import Any from urllib.parse import unquote import pyarrow as pa @@ -33,7 +34,7 @@ class ParquetParser(FileTypeParser): ENCODING = None - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002 """ ParquetParser does not require config checks, implicit pydantic validation is enough. 
""" @@ -48,7 +49,7 @@ async def infer_schema( ) -> SchemaType: parquet_format = config.format if not isinstance(parquet_format, ParquetFormat): - raise ValueError(f"Expected ParquetFormat, got {parquet_format}") + raise ValueError(f"Expected ParquetFormat, got {parquet_format}") # noqa: TRY004 with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp: parquet_file = pq.ParquetFile(fp) @@ -74,8 +75,8 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: + discovered_schema: Mapping[str, SchemaType] | None, # noqa: ARG002 + ) -> Iterable[dict[str, Any]]: parquet_format = config.format if not isinstance(parquet_format, ParquetFormat): logger.info(f"Expected ParquetFormat, got {parquet_format}") @@ -109,8 +110,8 @@ def parse_records( ) from exc @staticmethod - def _extract_partitions(filepath: str) -> List[str]: - return [unquote(partition) for partition in filepath.split(os.sep) if "=" in partition] + def _extract_partitions(filepath: str) -> list[str]: + return [unquote(partition) for partition in filepath.split(os.sep) if "=" in partition] # noqa: PTH206 @property def file_read_mode(self) -> FileReadMode: @@ -118,18 +119,17 @@ def file_read_mode(self) -> FileReadMode: @staticmethod def _to_output_value( - parquet_value: Union[Scalar, DictionaryArray], parquet_format: ParquetFormat - ) -> Any: + parquet_value: Scalar | DictionaryArray, parquet_format: ParquetFormat + ) -> Any: # noqa: ANN401 """ Convert an entry in a pyarrow table to a value that can be output by the source. """ if isinstance(parquet_value, DictionaryArray): return ParquetParser._dictionary_array_to_python_value(parquet_value) - else: - return ParquetParser._scalar_to_python_value(parquet_value, parquet_format) + return ParquetParser._scalar_to_python_value(parquet_value, parquet_format) @staticmethod - def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat) -> Any: + def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat) -> Any: # noqa: ANN401, PLR0911 """ Convert a pyarrow scalar to a value that can be output by the source. 
""" @@ -155,8 +155,7 @@ def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat if pa.types.is_decimal(parquet_value.type): if parquet_format.decimal_as_float: return float(parquet_value.as_py()) - else: - return str(parquet_value.as_py()) + return str(parquet_value.as_py()) if pa.types.is_map(parquet_value.type): return {k: v for k, v in parquet_value.as_py()} @@ -170,19 +169,17 @@ def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat duration_seconds = duration.total_seconds() if parquet_value.type.unit == "s": return duration_seconds - elif parquet_value.type.unit == "ms": + if parquet_value.type.unit == "ms": return duration_seconds * 1000 - elif parquet_value.type.unit == "us": + if parquet_value.type.unit == "us": return duration_seconds * 1_000_000 - elif parquet_value.type.unit == "ns": + if parquet_value.type.unit == "ns": return duration_seconds * 1_000_000_000 + duration.nanoseconds - else: - raise ValueError(f"Unknown duration unit: {parquet_value.type.unit}") - else: - return parquet_value.as_py() + raise ValueError(f"Unknown duration unit: {parquet_value.type.unit}") + return parquet_value.as_py() @staticmethod - def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> Dict[str, Any]: + def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> dict[str, Any]: """ Convert a pyarrow dictionary array to a value that can be output by the source. @@ -196,7 +193,7 @@ def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> Dict[st } @staticmethod - def parquet_type_to_schema_type( + def parquet_type_to_schema_type( # noqa: PLR0911 parquet_type: pa.DataType, parquet_format: ParquetFormat ) -> Mapping[str, str]: """ @@ -206,24 +203,23 @@ def parquet_type_to_schema_type( if pa.types.is_timestamp(parquet_type): return {"type": "string", "format": "date-time"} - elif pa.types.is_date(parquet_type): + if pa.types.is_date(parquet_type): return {"type": "string", "format": "date"} - elif ParquetParser._is_string(parquet_type, parquet_format): + if ParquetParser._is_string(parquet_type, parquet_format): return {"type": "string"} - elif pa.types.is_boolean(parquet_type): + if pa.types.is_boolean(parquet_type): return {"type": "boolean"} - elif ParquetParser._is_integer(parquet_type): + if ParquetParser._is_integer(parquet_type): return {"type": "integer"} - elif ParquetParser._is_float(parquet_type, parquet_format): + if ParquetParser._is_float(parquet_type, parquet_format): return {"type": "number"} - elif ParquetParser._is_object(parquet_type): + if ParquetParser._is_object(parquet_type): return {"type": "object"} - elif ParquetParser._is_list(parquet_type): + if ParquetParser._is_list(parquet_type): return {"type": "array"} - elif pa.types.is_null(parquet_type): + if pa.types.is_null(parquet_type): return {"type": "null"} - else: - raise ValueError(f"Unsupported parquet type: {parquet_type}") + raise ValueError(f"Unsupported parquet type: {parquet_type}") @staticmethod def _is_binary(parquet_type: pa.DataType) -> bool: @@ -241,22 +237,20 @@ def _is_integer(parquet_type: pa.DataType) -> bool: def _is_float(parquet_type: pa.DataType, parquet_format: ParquetFormat) -> bool: if pa.types.is_decimal(parquet_type): return parquet_format.decimal_as_float - else: - return bool(pa.types.is_floating(parquet_type)) + return bool(pa.types.is_floating(parquet_type)) @staticmethod def _is_string(parquet_type: pa.DataType, parquet_format: ParquetFormat) -> bool: if pa.types.is_decimal(parquet_type): return 
not parquet_format.decimal_as_float - else: - return bool( - pa.types.is_time(parquet_type) - or pa.types.is_string(parquet_type) - or pa.types.is_large_string(parquet_type) - or ParquetParser._is_binary( - parquet_type - ) # Best we can do is return as a string since we do not support binary - ) + return bool( + pa.types.is_time(parquet_type) + or pa.types.is_string(parquet_type) + or pa.types.is_large_string(parquet_type) + or ParquetParser._is_binary( + parquet_type + ) # Best we can do is return as a string since we do not support binary + ) @staticmethod def _is_object(parquet_type: pa.DataType) -> bool: diff --git a/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py b/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py index f55675e0a..f1ad81a09 100644 --- a/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py @@ -4,9 +4,10 @@ import logging import os import traceback +from collections.abc import Iterable, Mapping from datetime import datetime from io import BytesIO, IOBase -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any import backoff import dpath @@ -39,6 +40,7 @@ from airbyte_cdk.utils import is_cloud_environment from airbyte_cdk.utils.traced_exception import AirbyteTracedException + unstructured_partition_pdf = None unstructured_partition_docx = None unstructured_partition_pptx = None @@ -54,10 +56,10 @@ def get_nltk_temp_folder() -> str: """ try: nltk_data_dir = AIRBYTE_NLTK_DATA_DIR - os.makedirs(nltk_data_dir, exist_ok=True) + os.makedirs(nltk_data_dir, exist_ok=True) # noqa: PTH103 except OSError: nltk_data_dir = TMP_NLTK_DATA_DIR - os.makedirs(nltk_data_dir, exist_ok=True) + os.makedirs(nltk_data_dir, exist_ok=True) # noqa: PTH103 return nltk_data_dir @@ -73,7 +75,7 @@ def get_nltk_temp_folder() -> str: nltk.download("averaged_perceptron_tagger_eng", download_dir=nltk_data_dir, quiet=True) -def optional_decode(contents: Union[str, bytes]) -> str: +def optional_decode(contents: str | bytes) -> str: if isinstance(contents, bytes): return contents.decode("utf-8") return contents @@ -81,12 +83,12 @@ def optional_decode(contents: Union[str, bytes]) -> str: def _import_unstructured() -> None: """Dynamically imported as needed, due to slow import speed.""" - global unstructured_partition_pdf + global unstructured_partition_pdf # noqa: FURB154 global unstructured_partition_docx global unstructured_partition_pptx - from unstructured.partition.docx import partition_docx - from unstructured.partition.pdf import partition_pdf - from unstructured.partition.pptx import partition_pptx + from unstructured.partition.docx import partition_docx # noqa: PLC0415 + from unstructured.partition.pdf import partition_pdf # noqa: PLC0415 + from unstructured.partition.pptx import partition_pptx # noqa: PLC0415 # separate global variables to properly propagate typing unstructured_partition_pdf = partition_pdf @@ -102,7 +104,7 @@ def user_error(e: Exception) -> bool: return False if not isinstance(e, requests.exceptions.RequestException): return False - return bool(e.response and 400 <= e.response.status_code < 500) + return bool(e.response and 400 <= e.response.status_code < 500) # noqa: PLR2004 CLOUD_DEPLOYMENT_MODE = "cloud" @@ -110,20 +112,20 @@ def user_error(e: Exception) -> bool: class UnstructuredParser(FileTypeParser): @property - def parser_max_n_files_for_schema_inference(self) -> Optional[int]: + def 
parser_max_n_files_for_schema_inference(self) -> int | None: """ Just check one file as the schema is static """ return 1 @property - def parser_max_n_files_for_parsability(self) -> Optional[int]: + def parser_max_n_files_for_parsability(self) -> int | None: """ Do not check any files for parsability because it might be an expensive operation and doesn't give much confidence whether the sync will succeed. """ return 0 - def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> Optional[str]: + def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> str | None: # noqa: ARG002 """ Return the document_key field as the primary key. @@ -138,7 +140,7 @@ async def infer_schema( stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, ) -> SchemaType: - format = _extract_format(config) + format = _extract_format(config) # noqa: A001 with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle: filetype = self._get_filetype(file_handle, file) if filetype not in self._supported_file_types() and not format.skip_unprocessable_files: @@ -168,9 +170,9 @@ def parse_records( file: RemoteFile, stream_reader: AbstractFileBasedStreamReader, logger: logging.Logger, - discovered_schema: Optional[Mapping[str, SchemaType]], - ) -> Iterable[Dict[str, Any]]: - format = _extract_format(config) + discovered_schema: Mapping[str, SchemaType] | None, # noqa: ARG002 + ) -> Iterable[dict[str, Any]]: + format = _extract_format(config) # noqa: A001 with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle: try: markdown = self._read_file(file_handle, file, format, logger) @@ -193,18 +195,18 @@ def parse_records( } logger.warn(f"File {file.uri} cannot be parsed. Skipping it.") else: - raise e + raise e # noqa: TRY201 except Exception as e: exception_str = str(e) logger.error(f"File {file.uri} caused an error during parsing: {exception_str}.") - raise e + raise e # noqa: TRY201 - def _read_file( + def _read_file( # noqa: RET503 self, file_handle: IOBase, remote_file: RemoteFile, - format: UnstructuredFormat, - logger: logging.Logger, + format: UnstructuredFormat, # noqa: A002 + logger: logging.Logger, # noqa: ARG002 ) -> str: _import_unstructured() if ( @@ -213,7 +215,7 @@ def _read_file( or (not unstructured_partition_pptx) ): # check whether unstructured library is actually available for better error message and to ensure proper typing (can't be None after this point) - raise Exception("unstructured library is not available") + raise Exception("unstructured library is not available") # noqa: TRY002 filetype: FileType | None = self._get_filetype(file_handle, remote_file) @@ -233,7 +235,7 @@ def _read_file( format.strategy, remote_file, ) - elif format.processing.mode == "api": + if format.processing.mode == "api": # noqa: RET503, RUF100 try: result: str = self._read_file_remotely_with_retries( file_handle, @@ -248,17 +250,17 @@ def _read_file( # For other exceptions, re-throw as config error so the sync is stopped as problems with the external API need to be resolved by the user and are not considered part of the SLA. # Once this parser leaves experimental stage, we should consider making this a system error instead for issues that might be transient. 
if isinstance(e, RecordParseError): - raise e - raise AirbyteTracedException.from_exception( + raise e # noqa: TRY201 + raise AirbyteTracedException.from_exception( # noqa: B904 e, failure_type=FailureType.config_error ) return result def _params_to_dict( - self, params: Optional[List[APIParameterConfigModel]], strategy: str - ) -> Dict[str, Union[str, List[str]]]: - result_dict: Dict[str, Union[str, List[str]]] = {"strategy": strategy} + self, params: list[APIParameterConfigModel] | None, strategy: str + ) -> dict[str, str | list[str]]: + result_dict: dict[str, str | list[str]] = {"strategy": strategy} if params is None: return result_dict for item in params: @@ -277,7 +279,7 @@ def _params_to_dict( return result_dict - def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]: + def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: """ Perform a connection check for the parser config: - Verify that encryption is enabled if the API is hosted on a cloud instance. @@ -313,7 +315,7 @@ def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[st def _read_file_remotely_with_retries( self, file_handle: IOBase, - format: APIProcessingConfigModel, + format: APIProcessingConfigModel, # noqa: A002 filetype: FileType, strategy: str, remote_file: RemoteFile, @@ -326,7 +328,7 @@ def _read_file_remotely_with_retries( def _read_file_remotely( self, file_handle: IOBase, - format: APIProcessingConfigModel, + format: APIProcessingConfigModel, # noqa: A002 filetype: FileType, strategy: str, remote_file: RemoteFile, @@ -341,12 +343,11 @@ def _read_file_remotely( f"{format.api_url}/general/v0/general", headers=headers, data=data, files=file_data ) - if response.status_code == 422: + if response.status_code == 422: # noqa: PLR2004 # 422 means the file couldn't be processed, but the API is working. Treat this as a parsing error (passing an error record to the destination). raise self._create_parse_error(remote_file, response.json()) - else: - # Other error statuses are raised as requests exceptions (retry everything except user errors) - response.raise_for_status() + # Other error statuses are raised as requests exceptions (retry everything except user errors) + response.raise_for_status() json_response = response.json() @@ -362,7 +363,7 @@ def _read_file_locally( or (not unstructured_partition_pptx) ): # check whether unstructured library is actually available for better error message and to ensure proper typing (can't be None after this point) - raise Exception("unstructured library is not available") + raise Exception("unstructured library is not available") # noqa: TRY002 file: Any = file_handle @@ -383,7 +384,7 @@ def _read_file_locally( elif filetype == FileType.PPTX: elements = unstructured_partition_pptx(file=file) except Exception as e: - raise self._create_parse_error(remote_file, str(e)) + raise self._create_parse_error(remote_file, str(e)) # noqa: B904 return self._render_markdown([element.to_dict() for element in elements]) @@ -396,7 +397,7 @@ def _create_parse_error( FileBasedSourceError.ERROR_PARSING_RECORD, filename=remote_file.uri, message=message ) - def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileType]: + def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> FileType | None: """ Detect the file type based on the file name and the file content. 
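The API-processing path above retries the partition request but gives up early on client errors via the `user_error` predicate shown earlier in this file. A minimal sketch of that retry pattern, assuming the `backoff` and `requests` libraries this module already uses; the helper name and retry settings below are illustrative, not the CDK's actual configuration:

```python
import backoff
import requests


def user_error(e: Exception) -> bool:
    """Treat 4xx responses as non-retryable client errors, mirroring the predicate above."""
    if not isinstance(e, requests.exceptions.RequestException):
        return False
    return bool(e.response is not None and 400 <= e.response.status_code < 500)


@backoff.on_exception(
    backoff.expo,
    requests.exceptions.RequestException,
    max_tries=5,
    giveup=user_error,  # stop retrying when the failure is the caller's fault
)
def call_partition_api(api_url: str, file_data: dict) -> dict:
    # Hypothetical helper: server errors and network failures are retried with
    # exponential backoff; 4xx responses propagate immediately via `giveup`.
    response = requests.post(f"{api_url}/general/v0/general", files=file_data, timeout=60)
    response.raise_for_status()
    return response.json()
```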
@@ -416,7 +417,7 @@ def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileT # if possible, try to leverage the file name to detect the file type # if the file name is not available, use the file content file_type: FileType | None = None - try: + try: # noqa: SIM105 file_type = detect_filetype( filename=remote_file.uri, ) @@ -439,34 +440,33 @@ def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileT return None - def _supported_file_types(self) -> List[Any]: + def _supported_file_types(self) -> list[Any]: return [FileType.MD, FileType.PDF, FileType.DOCX, FileType.PPTX, FileType.TXT] def _get_file_type_error_message( self, file_type: FileType | None, ) -> str: - supported_file_types = ", ".join([str(type) for type in self._supported_file_types()]) + supported_file_types = ", ".join([str(type) for type in self._supported_file_types()]) # noqa: A001 return f"File type {file_type or 'None'!s} is not supported. Supported file types are {supported_file_types}" - def _render_markdown(self, elements: List[Any]) -> str: - return "\n\n".join((self._convert_to_markdown(el) for el in elements)) + def _render_markdown(self, elements: list[Any]) -> str: + return "\n\n".join(self._convert_to_markdown(el) for el in elements) - def _convert_to_markdown(self, el: Dict[str, Any]) -> str: + def _convert_to_markdown(self, el: dict[str, Any]) -> str: if dpath.get(el, "type") == "Title": category_depth = dpath.get(el, "metadata/category_depth", default=1) or 1 if not isinstance(category_depth, int): category_depth = ( - int(category_depth) if isinstance(category_depth, (str, float)) else 1 + int(category_depth) if isinstance(category_depth, (str, float)) else 1 # noqa: UP038 ) heading_str = "#" * category_depth return f"{heading_str} {dpath.get(el, 'text')}" - elif dpath.get(el, "type") == "ListItem": + if dpath.get(el, "type") == "ListItem": return f"- {dpath.get(el, 'text')}" - elif dpath.get(el, "type") == "Formula": + if dpath.get(el, "type") == "Formula": return f"```\n{dpath.get(el, 'text')}\n```" - else: - return str(dpath.get(el, "text", default="")) + return str(dpath.get(el, "text", default="")) @property def file_read_mode(self) -> FileReadMode: @@ -476,5 +476,5 @@ def file_read_mode(self) -> FileReadMode: def _extract_format(config: FileBasedStreamConfig) -> UnstructuredFormat: config_format = config.format if not isinstance(config_format, UnstructuredFormat): - raise ValueError(f"Invalid format config: {config_format}") + raise ValueError(f"Invalid format config: {config_format}") # noqa: TRY004 return config_format diff --git a/airbyte_cdk/sources/file_based/remote_file.py b/airbyte_cdk/sources/file_based/remote_file.py index 0197a35fd..48d4e2513 100644 --- a/airbyte_cdk/sources/file_based/remote_file.py +++ b/airbyte_cdk/sources/file_based/remote_file.py @@ -3,7 +3,6 @@ # from datetime import datetime -from typing import Optional from pydantic.v1 import BaseModel @@ -15,4 +14,4 @@ class RemoteFile(BaseModel): uri: str last_modified: datetime - mime_type: Optional[str] = None + mime_type: str | None = None diff --git a/airbyte_cdk/sources/file_based/schema_helpers.py b/airbyte_cdk/sources/file_based/schema_helpers.py index 1b653db67..98ff86ac8 100644 --- a/airbyte_cdk/sources/file_based/schema_helpers.py +++ b/airbyte_cdk/sources/file_based/schema_helpers.py @@ -3,10 +3,11 @@ # import json +from collections.abc import Mapping from copy import deepcopy from enum import Enum from functools import total_ordering -from typing import Any, Dict, List, 
Literal, Mapping, Optional, Tuple, Type, Union +from typing import Any, Literal, Union from airbyte_cdk.sources.file_based.exceptions import ( ConfigValidationError, @@ -14,7 +15,8 @@ SchemaInferenceError, ) -JsonSchemaSupportedType = Union[List[str], Literal["string"], str] + +JsonSchemaSupportedType = Union[list[str], Literal["string"], str] # noqa: PYI051, UP007 SchemaType = Mapping[str, Mapping[str, JsonSchemaSupportedType]] schemaless_schema = {"type": "object", "properties": {"data": {"type": "object"}}} @@ -33,14 +35,13 @@ class ComparableType(Enum): STRING = 4 OBJECT = 5 - def __lt__(self, other: Any) -> bool: + def __lt__(self, other: Any) -> bool: # noqa: ANN401 if self.__class__ is other.__class__: return self.value < other.value # type: ignore - else: - return NotImplemented + return NotImplemented -TYPE_PYTHON_MAPPING: Mapping[str, Tuple[str, Optional[Type[Any]]]] = { +TYPE_PYTHON_MAPPING: Mapping[str, tuple[str, type[Any] | None]] = { "null": ("null", None), "array": ("array", list), "boolean": ("boolean", bool), @@ -53,7 +54,7 @@ def __lt__(self, other: Any) -> bool: PYTHON_TYPE_MAPPING = {t: k for k, (_, t) in TYPE_PYTHON_MAPPING.items()} -def get_comparable_type(value: Any) -> Optional[ComparableType]: +def get_comparable_type(value: Any) -> ComparableType | None: # noqa: ANN401, PLR0911 if value == "null": return ComparableType.NULL if value == "boolean": @@ -66,11 +67,10 @@ def get_comparable_type(value: Any) -> Optional[ComparableType]: return ComparableType.STRING if value == "object": return ComparableType.OBJECT - else: - return None + return None -def get_inferred_type(value: Any) -> Optional[ComparableType]: +def get_inferred_type(value: Any) -> ComparableType | None: # noqa: ANN401, PLR0911 if value is None: return ComparableType.NULL if isinstance(value, bool): @@ -83,8 +83,7 @@ def get_inferred_type(value: Any) -> Optional[ComparableType]: return ComparableType.STRING if isinstance(value, dict): return ComparableType.OBJECT - else: - return None + return None def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType: @@ -107,7 +106,7 @@ def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType: if not isinstance(t, dict) or "type" not in t or not _is_valid_type(t["type"]): raise SchemaInferenceError(FileBasedSourceError.UNRECOGNIZED_TYPE, key=k, type=t) - merged_schema: Dict[str, Any] = deepcopy(schema1) # type: ignore # as of 2023-08-08, deepcopy can copy Mapping + merged_schema: dict[str, Any] = deepcopy(schema1) # type: ignore # as of 2023-08-08, deepcopy can copy Mapping for k2, t2 in schema2.items(): t1 = merged_schema.get(k2) if t1 is None: @@ -136,7 +135,7 @@ def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) - detected_types=f"{t1},{t2}", ) # Schemas can still be merged if a key contains a null value in either t1 or t2, but it is still an object - elif ( + if ( (t1_type == "object" or t2_type == "object") and t1_type != "null" and t2_type != "null" @@ -148,24 +147,23 @@ def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) - key=key, detected_types=f"{t1},{t2}", ) - else: - comparable_t1 = get_comparable_type( - TYPE_PYTHON_MAPPING[t1_type][0] - ) # accessing the type_mapping value - comparable_t2 = get_comparable_type( - TYPE_PYTHON_MAPPING[t2_type][0] - ) # accessing the type_mapping value - if not comparable_t1 and comparable_t2: - raise SchemaInferenceError( - FileBasedSourceError.UNRECOGNIZED_TYPE, key=key, detected_types=f"{t1},{t2}" - ) - return max( - [t1, t2], - 
key=lambda x: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[x["type"]][0])), - ) # accessing the type_mapping value + comparable_t1 = get_comparable_type( + TYPE_PYTHON_MAPPING[t1_type][0] + ) # accessing the type_mapping value + comparable_t2 = get_comparable_type( + TYPE_PYTHON_MAPPING[t2_type][0] + ) # accessing the type_mapping value + if not comparable_t1 and comparable_t2: + raise SchemaInferenceError( + FileBasedSourceError.UNRECOGNIZED_TYPE, key=key, detected_types=f"{t1},{t2}" + ) + return max( + [t1, t2], + key=lambda x: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[x["type"]][0])), + ) # accessing the type_mapping value -def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool: +def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool: # noqa: ANN401 if isinstance(value, list): # We do not compare lists directly; the individual items are compared. # If we hit this condition, it means that the expected type is not @@ -180,7 +178,7 @@ def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool: return ComparableType(inferred_type) <= ComparableType(get_comparable_type(expected_type)) -def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool: +def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool: # noqa: PLR0911 """ Return true iff the record conforms to the supplied schema. @@ -202,9 +200,9 @@ def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> if value is not None: if isinstance(expected_type, list): return any(is_equal_or_narrower_type(value, e) for e in expected_type) - elif expected_type == "object": + if expected_type == "object": return isinstance(value, dict) - elif expected_type == "array": + if expected_type == "array": if not isinstance(value, list): return False array_type = definition.get("items", {}).get("type") @@ -216,7 +214,7 @@ def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> return True -def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[Mapping[str, str]]: +def _parse_json_input(input_schema: str | Mapping[str, str]) -> Mapping[str, str] | None: try: if isinstance(input_schema, str): schema: Mapping[str, str] = json.loads(input_schema) @@ -235,8 +233,8 @@ def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[M def type_mapping_to_jsonschema( - input_schema: Optional[Union[str, Mapping[str, str]]], -) -> Optional[Mapping[str, Any]]: + input_schema: str | Mapping[str, str] | None, +) -> Mapping[str, Any] | None: """ Return the user input schema (type mapping), transformed to JSON Schema format. 
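The flattened `_choose_wider_type` above still resolves type conflicts by keeping the "wider" of the two candidates under the `ComparableType` ordering. A small sketch of that selection using the module's own helpers; it assumes the integer/number members elided from the hunk rank below STRING and OBJECT, in the ascending order the visible members suggest:

```python
from airbyte_cdk.sources.file_based.schema_helpers import (
    TYPE_PYTHON_MAPPING,
    ComparableType,
    get_comparable_type,
)

# Two files disagree about a column's type; the merge keeps the wider one.
t1 = {"type": "integer"}
t2 = {"type": "number"}

wider = max(
    [t1, t2],
    key=lambda t: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[t["type"]][0])),
)
assert wider == {"type": "number"}  # NUMBER is assumed to outrank INTEGER
```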
@@ -252,14 +250,14 @@ def type_mapping_to_jsonschema( json_mapping = _parse_json_input(input_schema) or {} for col_name, type_name in json_mapping.items(): - col_name, type_name = col_name.strip(), type_name.strip() + col_name, type_name = col_name.strip(), type_name.strip() # noqa: PLW2901 if not (col_name and type_name): raise ConfigValidationError( FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA, details=f"Invalid input schema; expected mapping in the format column_name: type, got {input_schema}.", ) - _json_schema_type = TYPE_PYTHON_MAPPING.get(type_name.casefold()) + _json_schema_type = TYPE_PYTHON_MAPPING.get(type_name.casefold()) # noqa: RUF052 if not _json_schema_type: raise ConfigValidationError( diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/__init__.py b/airbyte_cdk/sources/file_based/schema_validation_policies/__init__.py index e687bd5b3..e044bd709 100644 --- a/airbyte_cdk/sources/file_based/schema_validation_policies/__init__.py +++ b/airbyte_cdk/sources/file_based/schema_validation_policies/__init__.py @@ -8,6 +8,7 @@ WaitForDiscoverPolicy, ) + __all__ = [ "DEFAULT_SCHEMA_VALIDATION_POLICIES", "AbstractSchemaValidationPolicy", diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py b/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py index 139511a98..be3d47005 100644 --- a/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py +++ b/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py @@ -3,7 +3,8 @@ # from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any class AbstractSchemaValidationPolicy(ABC): @@ -12,9 +13,9 @@ class AbstractSchemaValidationPolicy(ABC): @abstractmethod def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, record: Mapping[str, Any], schema: Mapping[str, Any] | None ) -> bool: """ Return True if the record passes the user's validation policy. """ - raise NotImplementedError() + raise NotImplementedError diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py b/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py index 261b0fabd..0e4821e29 100644 --- a/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py +++ b/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py @@ -2,7 +2,8 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import ValidationPolicy from airbyte_cdk.sources.file_based.exceptions import ( @@ -17,7 +18,9 @@ class EmitRecordPolicy(AbstractSchemaValidationPolicy): name = "emit_record" def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, + record: Mapping[str, Any], # noqa: ARG002 + schema: Mapping[str, Any] | None, # noqa: ARG002 ) -> bool: return True @@ -26,7 +29,7 @@ class SkipRecordPolicy(AbstractSchemaValidationPolicy): name = "skip_record" def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, record: Mapping[str, Any], schema: Mapping[str, Any] | None ) -> bool: return schema is not None and conforms_to_schema(record, schema) @@ -36,7 +39,7 @@ class WaitForDiscoverPolicy(AbstractSchemaValidationPolicy): validate_schema_before_sync = True def record_passes_validation_policy( - self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]] + self, record: Mapping[str, Any], schema: Mapping[str, Any] | None ) -> bool: if schema is None or not conforms_to_schema(record, schema): raise StopSyncPerValidationPolicy( diff --git a/airbyte_cdk/sources/file_based/stream/__init__.py b/airbyte_cdk/sources/file_based/stream/__init__.py index 4b5c4bc2e..07e5544f1 100644 --- a/airbyte_cdk/sources/file_based/stream/__init__.py +++ b/airbyte_cdk/sources/file_based/stream/__init__.py @@ -1,4 +1,7 @@ -from airbyte_cdk.sources.file_based.stream.abstract_file_based_stream import AbstractFileBasedStream +from airbyte_cdk.sources.file_based.stream.abstract_file_based_stream import ( + AbstractFileBasedStream, +) from airbyte_cdk.sources.file_based.stream.default_file_based_stream import DefaultFileBasedStream + __all__ = ["AbstractFileBasedStream", "DefaultFileBasedStream"] diff --git a/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py b/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py index ef258b34d..b69461531 100644 --- a/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py +++ b/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py @@ -3,8 +3,9 @@ # from abc import abstractmethod -from functools import cache, cached_property, lru_cache -from typing import Any, Dict, Iterable, List, Mapping, Optional, Type +from collections.abc import Iterable, Mapping +from functools import cache, cached_property +from typing import Any from typing_extensions import deprecated @@ -50,14 +51,14 @@ class AbstractFileBasedStream(Stream): by the stream. """ - def __init__( + def __init__( # noqa: ANN204, PLR0913, PLR0917 self, config: FileBasedStreamConfig, - catalog_schema: Optional[Mapping[str, Any]], + catalog_schema: Mapping[str, Any] | None, stream_reader: AbstractFileBasedStreamReader, availability_strategy: AbstractFileBasedAvailabilityStrategy, discovery_policy: AbstractDiscoveryPolicy, - parsers: Dict[Type[Any], FileTypeParser], + parsers: dict[type[Any], FileTypeParser], validation_policy: AbstractSchemaValidationPolicy, errors_collector: FileBasedErrorsCollector, cursor: AbstractFileBasedCursor, @@ -77,8 +78,8 @@ def __init__( @abstractmethod def primary_key(self) -> PrimaryKeyType: ... - @cache - def list_files(self) -> List[RemoteFile]: + @cache # noqa: B019 + def list_files(self) -> list[RemoteFile]: """ List all files that belong to the stream. 
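`list_files` above keeps its `@cache` decorator and silences B019 with a `# noqa`. That rule exists because `functools.cache` on an instance method stores its entries on the class-level wrapper, keyed by `self`, so every cached stream instance stays referenced for the lifetime of the process. A standalone sketch of that behavior with a hypothetical class, not CDK code:

```python
from functools import cache


class ToyStream:
    def __init__(self, name: str) -> None:
        self.name = name

    @cache  # B019: the wrapper (and its cache) lives on the class and holds a reference to each `self`
    def list_files(self) -> list[str]:
        print(f"listing files for {self.name}")  # executed once per instance
        return [f"{self.name}/part-0001.csv"]


stream = ToyStream("orders")
stream.list_files()  # performs the listing
stream.list_files()  # served from the cache; `stream` is now retained by ToyStream.list_files
```

For a file-based stream the trade-off is presumably acceptable (one listing per short-lived sync process), which would explain why the warning is suppressed rather than the caching removed.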
@@ -97,10 +98,10 @@ def get_files(self) -> Iterable[RemoteFile]: def read_records( self, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[StreamSlice] = None, - stream_state: Optional[Mapping[str, Any]] = None, + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_slice: StreamSlice | None = None, + stream_state: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Iterable[Mapping[str, Any] | AirbyteMessage]: """ Yield all records from all remote files in `list_files_for_this_sync`. @@ -123,10 +124,10 @@ def read_records_from_slice( def stream_slices( self, *, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_state: Mapping[str, Any] | None = None, # noqa: ARG002 + ) -> Iterable[Mapping[str, Any] | None]: """ This method acts as an adapter between the generic Stream interface and the file-based's stream since file-based streams manage their own states. @@ -134,7 +135,7 @@ def stream_slices( return self.compute_slices() @abstractmethod - def compute_slices(self) -> Iterable[Optional[StreamSlice]]: + def compute_slices(self) -> Iterable[StreamSlice | None]: """ Return a list of slices that will be used to read files in the current sync. :return: The slices to use for the current sync. @@ -142,7 +143,7 @@ def compute_slices(self) -> Iterable[Optional[StreamSlice]]: ... @abstractmethod - @lru_cache(maxsize=None) + @cache # noqa: B019 def get_json_schema(self) -> Mapping[str, Any]: """ Return the JSON Schema for a stream. @@ -150,7 +151,7 @@ def get_json_schema(self) -> Mapping[str, Any]: ... @abstractmethod - def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: + def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]: """ Infer the schema for files in the stream. """ @@ -160,7 +161,7 @@ def get_parser(self) -> FileTypeParser: try: return self._parsers[type(self.config.format)] except KeyError: - raise UndefinedParserError( + raise UndefinedParserError( # noqa: B904 FileBasedSourceError.UNDEFINED_PARSER, stream=self.name, format=type(self.config.format), @@ -171,12 +172,11 @@ def record_passes_validation_policy(self, record: Mapping[str, Any]) -> bool: return self.validation_policy.record_passes_validation_policy( record=record, schema=self.catalog_schema ) - else: - raise RecordParseError( - FileBasedSourceError.UNDEFINED_VALIDATION_POLICY, - stream=self.name, - validation_policy=self.config.validation_policy, - ) + raise RecordParseError( + FileBasedSourceError.UNDEFINED_VALIDATION_POLICY, + stream=self.name, + validation_policy=self.config.validation_policy, + ) @cached_property @deprecated("Deprecated as of CDK version 3.7.0.") @@ -187,7 +187,7 @@ def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy: def name(self) -> str: return self.config.name - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: """ This is a temporary hack. Because file-based, declarative, and concurrent have _slightly_ different cursor implementations the file-based cursor isn't compatible with the cursor-based iteration flow in core.py top-level CDK. 
By setting this to diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py index fb0efc82c..1f18e3314 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py @@ -4,8 +4,9 @@ import copy import logging -from functools import cache, lru_cache -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union +from collections.abc import Iterable, Mapping, MutableMapping +from functools import cache +from typing import TYPE_CHECKING, Any from typing_extensions import deprecated @@ -46,6 +47,7 @@ from airbyte_cdk.sources.utils.schema_helpers import InternalConfig from airbyte_cdk.sources.utils.slice_logger import SliceLogger + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream.concurrent.cursor import ( AbstractConcurrentFileBasedCursor, @@ -67,7 +69,7 @@ def create_from_stream( stream: AbstractFileBasedStream, source: AbstractSource, logger: logging.Logger, - state: Optional[MutableMapping[str, Any]], + state: MutableMapping[str, Any] | None, cursor: "AbstractConcurrentFileBasedCursor", ) -> "FileBasedStreamFacade": """ @@ -75,7 +77,7 @@ def create_from_stream( """ pk = get_primary_key_from_stream(stream.primary_key) cursor_field = get_cursor_field_from_stream(stream) - stream._cursor = cursor + stream._cursor = cursor # noqa: SLF001 if not source.message_repository: raise ValueError( @@ -107,10 +109,10 @@ def create_from_stream( stream, cursor, logger=logger, - slice_logger=source._slice_logger, + slice_logger=source._slice_logger, # noqa: SLF001 ) - def __init__( + def __init__( # noqa: ANN204 self, stream: DefaultStream, legacy_stream: AbstractFileBasedStream, @@ -131,11 +133,10 @@ def __init__( self.validation_policy = legacy_stream.validation_policy @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: if self._abstract_stream.cursor_field is None: return [] - else: - return self._abstract_stream.cursor_field + return self._abstract_stream.cursor_field @property def name(self) -> str: @@ -150,7 +151,7 @@ def supports_incremental(self) -> bool: def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy: return self._legacy_stream.availability_strategy - @lru_cache(maxsize=None) + @cache # noqa: B019 def get_json_schema(self) -> Mapping[str, Any]: return self._abstract_stream.get_json_schema() @@ -170,10 +171,10 @@ def get_files(self) -> Iterable[RemoteFile]: def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping[str, Any]]: yield from self._legacy_stream.read_records_from_slice(stream_slice) # type: ignore[misc] # Only Mapping[str, Any] is expected for legacy streams, not AirbyteMessage - def compute_slices(self) -> Iterable[Optional[StreamSlice]]: + def compute_slices(self) -> Iterable[StreamSlice | None]: return self._legacy_stream.compute_slices() - def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: + def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]: return self._legacy_stream.infer_schema(files) def get_underlying_stream(self) -> DefaultStream: @@ -181,21 +182,21 @@ def get_underlying_stream(self) -> DefaultStream: def read( self, - configured_stream: ConfiguredAirbyteStream, - logger: logging.Logger, - slice_logger: SliceLogger, - stream_state: MutableMapping[str, Any], - state_manager: ConnectorStateManager, - internal_config: InternalConfig, + 
configured_stream: ConfiguredAirbyteStream, # noqa: ARG002 + logger: logging.Logger, # noqa: ARG002 + slice_logger: SliceLogger, # noqa: ARG002 + stream_state: MutableMapping[str, Any], # noqa: ARG002 + state_manager: ConnectorStateManager, # noqa: ARG002 + internal_config: InternalConfig, # noqa: ARG002 ) -> Iterable[StreamData]: yield from self._read_records() def read_records( self, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002 + stream_state: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Iterable[StreamData]: try: yield from self._read_records() @@ -211,7 +212,7 @@ def read_records( level=Level.ERROR, message=f"Cursor State at time of exception: {state}" ), ) - raise exc + raise exc # noqa: TRY201 def _read_records(self) -> Iterable[StreamData]: for partition in self._abstract_stream.generate_partitions(): @@ -222,14 +223,14 @@ def _read_records(self) -> Iterable[StreamData]: class FileBasedStreamPartition(Partition): - def __init__( + def __init__( # noqa: ANN204 self, stream: AbstractFileBasedStream, - _slice: Optional[Mapping[str, Any]], + _slice: Mapping[str, Any] | None, message_repository: MessageRepository, sync_mode: SyncMode, - cursor_field: Optional[List[str]], - state: Optional[MutableMapping[str, Any]], + cursor_field: list[str] | None, + state: MutableMapping[str, Any] | None, ): self._stream = stream self._slice = _slice @@ -265,7 +266,7 @@ def read(self) -> Iterable[Record]: else record_data.record.data ) if not record_message_data: - raise ExceptionWithDisplayMessage("A record without data was found") + raise ExceptionWithDisplayMessage("A record without data was found") # noqa: TRY301 else: yield Record( data=record_message_data, @@ -279,14 +280,14 @@ def read(self) -> Iterable[Record]: if display_message: raise ExceptionWithDisplayMessage(display_message) from e else: - raise e + raise e # noqa: TRY201 - def to_slice(self) -> Optional[Mapping[str, Any]]: + def to_slice(self) -> Mapping[str, Any] | None: if self._slice is None: return None - assert ( - len(self._slice["files"]) == 1 - ), f"Expected 1 file per partition but got {len(self._slice['files'])} for stream {self.stream_name()}" + assert len(self._slice["files"]) == 1, ( + f"Expected 1 file per partition but got {len(self._slice['files'])} for stream {self.stream_name()}" + ) file = self._slice["files"][0] return {"files": [file]} @@ -297,16 +298,14 @@ def __hash__(self) -> int: raise ValueError( f"Slices for file-based streams should be of length 1, but got {len(self._slice['files'])}. This is unexpected. Please contact Support." 
) - else: - s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}" + s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}" return hash((self._stream.name, s)) - else: - return hash(self._stream.name) + return hash(self._stream.name) def stream_name(self) -> str: return self._stream.name - @cache + @cache # noqa: B019 def _use_file_transfer(self) -> bool: return hasattr(self._stream, "use_file_transfer") and self._stream.use_file_transfer @@ -315,13 +314,13 @@ def __repr__(self) -> str: class FileBasedStreamPartitionGenerator(PartitionGenerator): - def __init__( + def __init__( # noqa: ANN204 self, stream: AbstractFileBasedStream, message_repository: MessageRepository, sync_mode: SyncMode, - cursor_field: Optional[List[str]], - state: Optional[MutableMapping[str, Any]], + cursor_field: list[str] | None, + state: MutableMapping[str, Any] | None, cursor: "AbstractConcurrentFileBasedCursor", ): self._stream = stream @@ -338,7 +337,7 @@ def generate(self) -> Iterable[FileBasedStreamPartition]: ): if _slice is not None: for file in _slice.get("files", []): - pending_partitions.append( + pending_partitions.append( # noqa: PERF401 FileBasedStreamPartition( self._stream, {"files": [copy.deepcopy(file)]}, diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/__init__.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/__init__.py index 089cae0ad..6498e5c8c 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/__init__.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/__init__.py @@ -2,6 +2,7 @@ from .file_based_concurrent_cursor import FileBasedConcurrentCursor from .file_based_final_state_cursor import FileBasedFinalStateCursor + __all__ = [ "AbstractConcurrentFileBasedCursor", "FileBasedConcurrentCursor", diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py index 5c30fda4a..988134600 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py @@ -4,8 +4,9 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, MutableMapping from datetime import datetime -from typing import TYPE_CHECKING, Any, Iterable, List, MutableMapping +from typing import TYPE_CHECKING, Any from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor @@ -14,12 +15,13 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.types import Record + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition class AbstractConcurrentFileBasedCursor(Cursor, AbstractFileBasedCursor, ABC): - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 pass @property @@ -33,7 +35,7 @@ def observe(self, record: Record) -> None: ... def close_partition(self, partition: Partition) -> None: ... @abstractmethod - def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None: ... 
+ def set_pending_partitions(self, partitions: list["FileBasedStreamPartition"]) -> None: ... @abstractmethod def add_file(self, file: RemoteFile) -> None: ... diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py index a70169197..e143dda89 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py @@ -3,9 +3,10 @@ # import logging +from collections.abc import Iterable, MutableMapping from datetime import datetime, timedelta from threading import RLock -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, MutableMapping, Optional, Tuple +from typing import TYPE_CHECKING, Any from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager @@ -21,6 +22,7 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.types import Record + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition @@ -34,14 +36,14 @@ class FileBasedConcurrentCursor(AbstractConcurrentFileBasedCursor): ) DEFAULT_MAX_HISTORY_SIZE = 10_000 DATE_TIME_FORMAT = DefaultFileBasedCursor.DATE_TIME_FORMAT - zero_value = datetime.min + zero_value = datetime.min # noqa: DTZ901 zero_cursor_value = f"0001-01-01T00:00:00.000000Z_{_NULL_FILE}" def __init__( self, stream_config: FileBasedStreamConfig, stream_name: str, - stream_namespace: Optional[str], + stream_namespace: str | None, stream_state: MutableMapping[str, Any], message_repository: MessageRepository, connector_state_manager: ConnectorStateManager, @@ -60,7 +62,7 @@ def __init__( ) self._state_lock = RLock() self._pending_files_lock = RLock() - self._pending_files: Optional[Dict[str, RemoteFile]] = None + self._pending_files: dict[str, RemoteFile] | None = None self._file_to_datetime_history = stream_state.get("history", {}) if stream_state else {} self._prev_cursor_value = self._compute_prev_sync_cursor(stream_state) self._sync_start = self._compute_start_time() @@ -72,28 +74,28 @@ def state(self) -> MutableMapping[str, Any]: def observe(self, record: Record) -> None: pass - def close_partition(self, partition: Partition) -> None: + def close_partition(self, partition: Partition) -> None: # noqa: ARG002 with self._pending_files_lock: if self._pending_files is None: raise RuntimeError( "Expected pending partitions to be set but it was not. This is unexpected. Please contact Support." ) - def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None: + def set_pending_partitions(self, partitions: list["FileBasedStreamPartition"]) -> None: with self._pending_files_lock: self._pending_files = {} for partition in partitions: - _slice = partition.to_slice() + _slice = partition.to_slice() # noqa: RUF052 if _slice is None: continue for file in _slice["files"]: - if file.uri in self._pending_files.keys(): + if file.uri in self._pending_files.keys(): # noqa: SIM118 raise RuntimeError( f"Already found file {_slice} in pending files. This is unexpected. Please contact Support." 
) self._pending_files.update({file.uri: file}) - def _compute_prev_sync_cursor(self, value: Optional[StreamState]) -> Tuple[datetime, str]: + def _compute_prev_sync_cursor(self, value: StreamState | None) -> tuple[datetime, str]: if not value: return self.zero_value, "" prev_cursor_str = value.get(self._cursor_field.cursor_field_key) or self.zero_cursor_value @@ -112,23 +114,23 @@ def _compute_prev_sync_cursor(self, value: Optional[StreamState]) -> Tuple[datet cursor_dt, cursor_uri = cursor_str.split("_", 1) return datetime.strptime(cursor_dt, self.DATE_TIME_FORMAT), cursor_uri - def _get_cursor_key_from_file(self, file: Optional[RemoteFile]) -> str: + def _get_cursor_key_from_file(self, file: RemoteFile | None) -> str: if file: return f"{datetime.strftime(file.last_modified, self.DATE_TIME_FORMAT)}_{file.uri}" return self.zero_cursor_value - def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]: + def _compute_earliest_file_in_history(self) -> RemoteFile | None: with self._state_lock: if self._file_to_datetime_history: filename, last_modified = min( - self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0]) + self._file_to_datetime_history.items(), + key=lambda f: (f[1], f[0]), # noqa: FURB118 ) return RemoteFile( uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT), ) - else: - return None + return None def add_file(self, file: RemoteFile) -> None: """ @@ -139,7 +141,7 @@ def add_file(self, file: RemoteFile) -> None: raise RuntimeError( "Expected pending partitions to be set but it was not. This is unexpected. Please contact Support." ) - with self._pending_files_lock: + with self._pending_files_lock: # noqa: SIM117 with self._state_lock: if file.uri not in self._pending_files: self._message_repository.emit_message( @@ -162,7 +164,7 @@ def add_file(self, file: RemoteFile) -> None: if oldest_file: del self._file_to_datetime_history[oldest_file.uri] else: - raise Exception( + raise Exception( # noqa: TRY002 "The history is full but there is no files in the history. This should never happen and might be indicative of a bug in the CDK." ) self.emit_state_message() @@ -181,7 +183,7 @@ def emit_state_message(self) -> None: self._message_repository.emit_message(state_message) def _get_new_cursor_value(self) -> str: - with self._pending_files_lock: + with self._pending_files_lock: # noqa: SIM117 with self._state_lock: if self._pending_files: # If there are partitions that haven't been synced, we don't know whether the files that have been synced @@ -189,31 +191,29 @@ def _get_new_cursor_value(self) -> str: # To avoid missing files, we only increment the cursor up to the oldest pending file, because we know # that all older files have been synced. return self._get_cursor_key_from_file(self._compute_earliest_pending_file()) - elif self._file_to_datetime_history: + if self._file_to_datetime_history: # If all partitions have been synced, we know that the sync is up-to-date and so can advance # the cursor to the newest file in history. 
return self._get_cursor_key_from_file(self._compute_latest_file_in_history()) - else: - return f"{self.zero_value.strftime(self.DATE_TIME_FORMAT)}_" + return f"{self.zero_value.strftime(self.DATE_TIME_FORMAT)}_" - def _compute_earliest_pending_file(self) -> Optional[RemoteFile]: + def _compute_earliest_pending_file(self) -> RemoteFile | None: if self._pending_files: return min(self._pending_files.values(), key=lambda x: x.last_modified) - else: - return None + return None - def _compute_latest_file_in_history(self) -> Optional[RemoteFile]: + def _compute_latest_file_in_history(self) -> RemoteFile | None: with self._state_lock: if self._file_to_datetime_history: filename, last_modified = max( - self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0]) + self._file_to_datetime_history.items(), + key=lambda f: (f[1], f[0]), # noqa: FURB118 ) return RemoteFile( uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT), ) - else: - return None + return None def get_files_to_sync( self, all_files: Iterable[RemoteFile], logger: logging.Logger @@ -235,7 +235,7 @@ def get_files_to_sync( if self._should_sync_file(f, logger): yield f - def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: + def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: # noqa: ARG002 with self._state_lock: if file.uri in self._file_to_datetime_history: # If the file's uri is in the history, we should sync the file if it has been modified since it was synced @@ -253,23 +253,20 @@ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: ) ) return False - else: - return file.last_modified > updated_at_from_history + return file.last_modified > updated_at_from_history prev_cursor_timestamp, prev_cursor_uri = self._prev_cursor_value if self._is_history_full(): if file.last_modified > prev_cursor_timestamp: # If the history is partial and the file's datetime is strictly greater than the cursor, we should sync it return True - elif file.last_modified == prev_cursor_timestamp: + if file.last_modified == prev_cursor_timestamp: # If the history is partial and the file's datetime is equal to the earliest file in the history, # we should sync it if its uri is greater than or equal to the cursor value. return file.uri > prev_cursor_uri - else: - return file.last_modified >= self._sync_start - else: - # The file is not in the history and the history is complete. We know we need to sync the file - return True + return file.last_modified >= self._sync_start + # The file is not in the history and the history is complete. 
We know we need to sync the file + return True def _is_history_full(self) -> bool: """ @@ -284,14 +281,13 @@ def _is_history_full(self) -> bool: def _compute_start_time(self) -> datetime: if not self._file_to_datetime_history: - return datetime.min - else: - earliest = min(self._file_to_datetime_history.values()) - earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT) - if self._is_history_full(): - time_window = datetime.now() - self._time_window_if_history_is_full - earliest_dt = min(earliest_dt, time_window) - return earliest_dt + return datetime.min # noqa: DTZ901 + earliest = min(self._file_to_datetime_history.values()) + earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT) + if self._is_history_full(): + time_window = datetime.now() - self._time_window_if_history_is_full + earliest_dt = min(earliest_dt, time_window) + return earliest_dt def get_start_time(self) -> datetime: return self._sync_start diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py index e219292d1..04c308b3e 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py @@ -3,8 +3,9 @@ # import logging +from collections.abc import Iterable, MutableMapping from datetime import datetime -from typing import TYPE_CHECKING, Any, Iterable, List, MutableMapping, Optional +from typing import TYPE_CHECKING, Any from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig @@ -18,6 +19,7 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.types import Record + if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition @@ -25,12 +27,12 @@ class FileBasedFinalStateCursor(AbstractConcurrentFileBasedCursor): """Cursor that is used to guarantee at least one state message is emitted for a concurrent file-based stream.""" - def __init__( + def __init__( # noqa: ANN204 self, stream_config: FileBasedStreamConfig, message_repository: MessageRepository, - stream_namespace: Optional[str], - **kwargs: Any, + stream_namespace: str | None, + **kwargs: Any, # noqa: ANN401, ARG002 ): self._stream_name = stream_config.name self._stream_namespace = stream_namespace @@ -50,25 +52,27 @@ def observe(self, record: Record) -> None: def close_partition(self, partition: Partition) -> None: pass - def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None: + def set_pending_partitions(self, partitions: list["FileBasedStreamPartition"]) -> None: pass def add_file(self, file: RemoteFile) -> None: pass def get_files_to_sync( - self, all_files: Iterable[RemoteFile], logger: logging.Logger + self, + all_files: Iterable[RemoteFile], + logger: logging.Logger, # noqa: ARG002 ) -> Iterable[RemoteFile]: return all_files def get_state(self) -> MutableMapping[str, Any]: return {} - def set_initial_state(self, value: StreamState) -> None: + def set_initial_state(self, value: StreamState) -> None: # noqa: ARG002 return None def get_start_time(self) -> datetime: - return datetime.min + return datetime.min # noqa: DTZ901 def emit_state_message(self) -> None: pass diff --git 
a/airbyte_cdk/sources/file_based/stream/cursor/__init__.py b/airbyte_cdk/sources/file_based/stream/cursor/__init__.py index c1bf15a5d..34d4899fc 100644 --- a/airbyte_cdk/sources/file_based/stream/cursor/__init__.py +++ b/airbyte_cdk/sources/file_based/stream/cursor/__init__.py @@ -1,4 +1,5 @@ from .abstract_file_based_cursor import AbstractFileBasedCursor from .default_file_based_cursor import DefaultFileBasedCursor + __all__ = ["AbstractFileBasedCursor", "DefaultFileBasedCursor"] diff --git a/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py index 4a5eadb4e..e1094df5b 100644 --- a/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py @@ -4,8 +4,9 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, MutableMapping from datetime import datetime -from typing import Any, Iterable, MutableMapping +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.remote_file import RemoteFile @@ -18,7 +19,7 @@ class AbstractFileBasedCursor(ABC): """ @abstractmethod - def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any): + def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any): # noqa: ANN204, ANN401 """ Common interface for all cursors. """ diff --git a/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py index 08ad8c3ae..fa876a24a 100644 --- a/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py +++ b/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py @@ -3,8 +3,9 @@ # import logging +from collections.abc import Iterable, MutableMapping from datetime import datetime, timedelta -from typing import Any, Iterable, MutableMapping, Optional +from typing import Any from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig from airbyte_cdk.sources.file_based.remote_file import RemoteFile @@ -20,7 +21,7 @@ class DefaultFileBasedCursor(AbstractFileBasedCursor): DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" CURSOR_FIELD = "_ab_source_file_last_modified" - def __init__(self, stream_config: FileBasedStreamConfig, **_: Any): + def __init__(self, stream_config: FileBasedStreamConfig, **_: Any): # noqa: ANN204, ANN401 super().__init__(stream_config) # type: ignore [safe-super] self._file_to_datetime_history: MutableMapping[str, str] = {} self._time_window_if_history_is_full = timedelta( @@ -34,7 +35,7 @@ def __init__(self, stream_config: FileBasedStreamConfig, **_: Any): ) self._start_time = self._compute_start_time() - self._initial_earliest_file_in_history: Optional[RemoteFile] = None + self._initial_earliest_file_in_history: RemoteFile | None = None def set_initial_state(self, value: StreamState) -> None: self._file_to_datetime_history = value.get("history", {}) @@ -51,7 +52,7 @@ def add_file(self, file: RemoteFile) -> None: if oldest_file: del self._file_to_datetime_history[oldest_file.uri] else: - raise Exception( + raise Exception( # noqa: TRY002 "The history is full but there is no files in the history. This should never happen and might be indicative of a bug in the CDK." 
) @@ -59,7 +60,7 @@ def get_state(self) -> StreamState: state = {"history": self._file_to_datetime_history, self.CURSOR_FIELD: self._get_cursor()} return state - def _get_cursor(self) -> Optional[str]: + def _get_cursor(self) -> str | None: """ Returns the cursor value. @@ -68,7 +69,8 @@ def _get_cursor(self) -> Optional[str]: """ if self._file_to_datetime_history.items(): filename, timestamp = max( - self._file_to_datetime_history.items(), key=lambda x: (x[1], x[0]) + self._file_to_datetime_history.items(), + key=lambda x: (x[1], x[0]), # noqa: FURB118 ) return f"{timestamp}_{filename}" return None @@ -79,7 +81,7 @@ def _is_history_full(self) -> bool: """ return len(self._file_to_datetime_history) >= self.DEFAULT_MAX_HISTORY_SIZE - def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: + def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: # noqa: PLR0911 if file.uri in self._file_to_datetime_history: # If the file's uri is in the history, we should sync the file if it has been modified since it was synced updated_at_from_history = datetime.strptime( @@ -99,16 +101,14 @@ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: # If the history is partial and the file's datetime is strictly greater than the earliest file in the history, # we should sync it return True - elif file.last_modified == self._initial_earliest_file_in_history.last_modified: + if file.last_modified == self._initial_earliest_file_in_history.last_modified: # If the history is partial and the file's datetime is equal to the earliest file in the history, # we should sync it if its uri is strictly greater than the earliest file in the history return file.uri > self._initial_earliest_file_in_history.uri - else: - # Otherwise, only sync the file if it has been modified since the start of the time window - return file.last_modified >= self.get_start_time() - else: - # The file is not in the history and the history is complete. We know we need to sync the file - return True + # Otherwise, only sync the file if it has been modified since the start of the time window + return file.last_modified >= self.get_start_time() + # The file is not in the history and the history is complete. 
We know we need to sync the file + return True def get_files_to_sync( self, all_files: Iterable[RemoteFile], logger: logging.Logger @@ -126,24 +126,23 @@ def get_files_to_sync( def get_start_time(self) -> datetime: return self._start_time - def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]: + def _compute_earliest_file_in_history(self) -> RemoteFile | None: if self._file_to_datetime_history: filename, last_modified = min( - self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0]) + self._file_to_datetime_history.items(), + key=lambda f: (f[1], f[0]), # noqa: FURB118 ) return RemoteFile( uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT) ) - else: - return None + return None def _compute_start_time(self) -> datetime: if not self._file_to_datetime_history: - return datetime.min - else: - earliest = min(self._file_to_datetime_history.values()) - earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT) - if self._is_history_full(): - time_window = datetime.now() - self._time_window_if_history_is_full - earliest_dt = min(earliest_dt, time_window) - return earliest_dt + return datetime.min # noqa: DTZ901 + earliest = min(self._file_to_datetime_history.values()) + earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT) + if self._is_history_full(): + time_window = datetime.now() - self._time_window_if_history_is_full + earliest_dt = min(earliest_dt, time_window) + return earliest_dt diff --git a/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py b/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py index 604322549..6d414d5ad 100644 --- a/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py +++ b/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py @@ -6,10 +6,11 @@ import itertools import traceback from collections import defaultdict +from collections.abc import Iterable, Mapping, MutableMapping from copy import deepcopy from functools import cache from os import path -from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union +from typing import Any from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, FailureType, Level from airbyte_cdk.models import Type as MessageType @@ -53,11 +54,11 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin): ab_file_name_col = "_ab_source_file_url" modified = "modified" source_file_url = "source_file_url" - airbyte_columns = [ab_last_mod_col, ab_file_name_col] + airbyte_columns = [ab_last_mod_col, ab_file_name_col] # noqa: RUF012 use_file_transfer = False preserve_directory_structure = True - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs: Any): # noqa: ANN204, ANN401 if self.FILE_TRANSFER_KW in kwargs: self.use_file_transfer = kwargs.pop(self.FILE_TRANSFER_KW, False) if self.PRESERVE_DIRECTORY_STRUCTURE_KW in kwargs: @@ -76,7 +77,7 @@ def state(self, value: MutableMapping[str, Any]) -> None: self._cursor.set_initial_state(value) @property # type: ignore # mypy complains wrong type, but AbstractFileBasedCursor is parent of file-based cursors - def cursor(self) -> Optional[AbstractFileBasedCursor]: + def cursor(self) -> AbstractFileBasedCursor | None: return self._cursor @cursor.setter @@ -94,8 +95,8 @@ def primary_key(self) -> PrimaryKeyType: ) def _filter_schema_invalid_properties( - self, configured_catalog_json_schema: Dict[str, Any] - ) -> Dict[str, Any]: + self, configured_catalog_json_schema: dict[str, Any] + ) -> dict[str, Any]: if 
self.use_file_transfer: return { "type": "object", @@ -105,22 +106,21 @@ def _filter_schema_invalid_properties( self.ab_file_name_col: {"type": "string"}, }, } - else: - return super()._filter_schema_invalid_properties(configured_catalog_json_schema) + return super()._filter_schema_invalid_properties(configured_catalog_json_schema) def _duplicated_files_names( - self, slices: List[dict[str, List[RemoteFile]]] - ) -> List[dict[str, List[str]]]: - seen_file_names: Dict[str, List[str]] = defaultdict(list) + self, slices: list[dict[str, list[RemoteFile]]] + ) -> list[dict[str, list[str]]]: + seen_file_names: dict[str, list[str]] = defaultdict(list) for file_slice in slices: for file_found in file_slice[self.FILES_KEY]: - file_name = path.basename(file_found.uri) + file_name = path.basename(file_found.uri) # noqa: PTH119 seen_file_names[file_name].append(file_found.uri) return [ {file_name: paths} for file_name, paths in seen_file_names.items() if len(paths) > 1 ] - def compute_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: + def compute_slices(self) -> Iterable[Mapping[str, Any] | None]: # Sort files by last_modified, uri and return them grouped by last_modified all_files = self.list_files() files_to_read = self._cursor.get_files_to_sync(all_files, self.logger) @@ -174,7 +174,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte try: if self.use_file_transfer: self.logger.info(f"{self.name}: {file} file-based syncing") - # todo: complete here the code to not rely on local parser + # TODO: complete here the code to not rely on local parser file_transfer = FileTransfer() for record in file_transfer.get_file( self.config, file, self.stream_reader, self.logger @@ -183,7 +183,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte if not self.record_passes_validation_policy(record): n_skipped += 1 continue - record = self.transform_record_for_file_transfer(record, file) + record = self.transform_record_for_file_transfer(record, file) # noqa: PLW2901 yield stream_data_to_airbyte_message( self.name, record, is_file_transfer_message=True ) @@ -193,11 +193,11 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte ): line_no += 1 if self.config.schemaless: - record = {"data": record} + record = {"data": record} # noqa: PLW2901 elif not self.record_passes_validation_policy(record): n_skipped += 1 continue - record = self.transform_record(record, file, file_datetime_string) + record = self.transform_record(record, file, file_datetime_string) # noqa: PLW2901 yield stream_data_to_airbyte_message(self.name, record) self._cursor.add_file(file) @@ -227,7 +227,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte except AirbyteTracedException as exc: # Re-raise the exception to stop the whole sync immediately as this is a fatal error - raise exc + raise exc # noqa: TRY201 except Exception: yield AirbyteMessage( @@ -250,14 +250,14 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte ) @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: """ Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. :return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor. 
""" return self.ab_last_mod_col - @cache + @cache # noqa: B019 def get_json_schema(self) -> JsonSchema: extra_fields = { self.ab_last_mod_col: {"type": "string"}, @@ -266,14 +266,14 @@ def get_json_schema(self) -> JsonSchema: try: schema = self._get_raw_json_schema() except InvalidSchemaError as config_exception: - raise AirbyteTracedException( + raise AirbyteTracedException( # noqa: B904 internal_message="Please check the logged errors for more information.", message=FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value, exception=AirbyteTracedException(exception=config_exception), failure_type=FailureType.config_error, ) except AirbyteTracedException as ate: - raise ate + raise ate # noqa: TRY201 except Exception as exc: raise SchemaInferenceError( FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name @@ -284,22 +284,21 @@ def get_json_schema(self) -> JsonSchema: def _get_raw_json_schema(self) -> JsonSchema: if self.use_file_transfer: return file_transfer_schema - elif self.config.input_schema: + if self.config.input_schema: return self.config.get_input_schema() # type: ignore - elif self.config.schemaless: + if self.config.schemaless: return schemaless_schema - else: - files = self.list_files() - first_n_files = len(files) - - if self.config.recent_n_files_to_read_for_schema_discovery: - self.logger.info( - msg=( - f"Only first {self.config.recent_n_files_to_read_for_schema_discovery} files will be used to infer schema " - f"for stream {self.name} due to limitation in config." - ) + files = self.list_files() + first_n_files = len(files) + + if self.config.recent_n_files_to_read_for_schema_discovery: + self.logger.info( + msg=( + f"Only first {self.config.recent_n_files_to_read_for_schema_discovery} files will be used to infer schema " + f"for stream {self.name} due to limitation in config." ) - first_n_files = self.config.recent_n_files_to_read_for_schema_discovery + ) + first_n_files = self.config.recent_n_files_to_read_for_schema_discovery if first_n_files == 0: self.logger.warning( @@ -341,7 +340,7 @@ def get_files(self) -> Iterable[RemoteFile]: self.config.globs or [], self.config.legacy_prefix, self.logger ) - def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: + def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]: loop = asyncio.get_event_loop() schema = loop.run_until_complete(self._infer_schema(files)) # as infer schema returns a Mapping that is assumed to be immutable, we need to create a deepcopy to avoid modifying the reference @@ -354,7 +353,7 @@ def _fill_nulls(schema: Mapping[str, Any]) -> Mapping[str, Any]: if k == "type": if isinstance(v, list): if "null" not in v: - schema[k] = ["null"] + v + schema[k] = ["null"] + v # noqa: RUF005 elif v != "null": schema[k] = ["null", v] else: @@ -364,7 +363,7 @@ def _fill_nulls(schema: Mapping[str, Any]) -> Mapping[str, Any]: DefaultFileBasedStream._fill_nulls(item) return schema - async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: + async def _infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]: """ Infer the schema for a stream. @@ -372,7 +371,7 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: Dispatch on file type. 
""" base_schema: SchemaType = {} - pending_tasks: Set[asyncio.tasks.Task[SchemaType]] = set() + pending_tasks: set[asyncio.tasks.Task[SchemaType]] = set() n_started, n_files = 0, len(files) files_iterator = iter(files) @@ -391,7 +390,7 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]: try: base_schema = merge_schemas(base_schema, task.result()) except AirbyteTracedException as ate: - raise ate + raise ate # noqa: TRY201 except Exception as exc: self.logger.error( f"An error occurred inferring the schema. \n {traceback.format_exc()}", @@ -406,7 +405,7 @@ async def _infer_file_schema(self, file: RemoteFile) -> SchemaType: self.config, file, self.stream_reader, self.logger ) except AirbyteTracedException as ate: - raise ate + raise ate # noqa: TRY201 except Exception as exc: raise SchemaInferenceError( FileBasedSourceError.SCHEMA_INFERENCE_ERROR, diff --git a/airbyte_cdk/sources/file_based/types.py b/airbyte_cdk/sources/file_based/types.py index b83bf37a3..fa05978f7 100644 --- a/airbyte_cdk/sources/file_based/types.py +++ b/airbyte_cdk/sources/file_based/types.py @@ -1,10 +1,12 @@ -# +# # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # from __future__ import annotations -from typing import Any, Mapping, MutableMapping +from collections.abc import Mapping, MutableMapping +from typing import Any + StreamSlice = Mapping[str, Any] StreamState = MutableMapping[str, Any] diff --git a/airbyte_cdk/sources/http_logger.py b/airbyte_cdk/sources/http_logger.py index 33ccc68ac..91b6078cf 100644 --- a/airbyte_cdk/sources/http_logger.py +++ b/airbyte_cdk/sources/http_logger.py @@ -2,7 +2,6 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Optional, Union import requests @@ -13,8 +12,8 @@ def format_http_message( response: requests.Response, title: str, description: str, - stream_name: Optional[str], - is_auxiliary: bool | None = None, + stream_name: str | None, + is_auxiliary: bool | None = None, # noqa: FBT001 ) -> LogMessage: request = response.request log_message = { @@ -48,5 +47,5 @@ def format_http_message( return log_message # type: ignore [return-value] # got "dict[str, object]", expected "dict[str, JsonType]" -def _normalize_body_string(body_str: Optional[Union[str, bytes]]) -> Optional[str]: - return body_str.decode() if isinstance(body_str, (bytes, bytearray)) else body_str +def _normalize_body_string(body_str: str | bytes | None) -> str | None: + return body_str.decode() if isinstance(body_str, (bytes, bytearray)) else body_str # noqa: UP038 diff --git a/airbyte_cdk/sources/message/__init__.py b/airbyte_cdk/sources/message/__init__.py index 31c484ab5..fc43482f4 100644 --- a/airbyte_cdk/sources/message/__init__.py +++ b/airbyte_cdk/sources/message/__init__.py @@ -10,6 +10,7 @@ NoopMessageRepository, ) + __all__ = [ "InMemoryMessageRepository", "LogAppenderMessageRepositoryDecorator", diff --git a/airbyte_cdk/sources/message/repository.py b/airbyte_cdk/sources/message/repository.py index 2fc156e8c..952b0b882 100644 --- a/airbyte_cdk/sources/message/repository.py +++ b/airbyte_cdk/sources/message/repository.py @@ -6,12 +6,13 @@ import logging from abc import ABC, abstractmethod from collections import deque -from typing import Callable, Deque, Iterable, List, Optional +from collections.abc import Callable, Iterable from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.utils.types import JsonType from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets + 
_LOGGER = logging.getLogger("MessageRepository") _SUPPORTED_MESSAGE_TYPES = {Type.CONTROL, Type.LOG} LogMessage = dict[str, JsonType] @@ -45,7 +46,7 @@ def _is_severe_enough(threshold: Level, level: Level) -> bool: class MessageRepository(ABC): @abstractmethod def emit_message(self, message: AirbyteMessage) -> None: - raise NotImplementedError() + raise NotImplementedError @abstractmethod def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: @@ -53,11 +54,11 @@ def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) Computing messages can be resource consuming. This method is specialized for logging because we want to allow for lazy evaluation if the log level is less severe than what is configured """ - raise NotImplementedError() + raise NotImplementedError @abstractmethod def consume_queue(self) -> Iterable[AirbyteMessage]: - raise NotImplementedError() + raise NotImplementedError class NoopMessageRepository(MessageRepository): @@ -73,7 +74,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: class InMemoryMessageRepository(MessageRepository): def __init__(self, log_level: Level = Level.INFO) -> None: - self._message_queue: Deque[AirbyteMessage] = deque() + self._message_queue: deque[AirbyteMessage] = deque() self._log_level = log_level def emit_message(self, message: AirbyteMessage) -> None: @@ -96,7 +97,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: class LogAppenderMessageRepositoryDecorator(MessageRepository): - def __init__( + def __init__( # noqa: ANN204 self, dict_to_append: LogMessage, decorated: MessageRepository, @@ -119,7 +120,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: return self._decorated.consume_queue() def _append_second_to_first( - self, first: LogMessage, second: LogMessage, path: Optional[List[str]] = None + self, first: LogMessage, second: LogMessage, path: list[str] | None = None ) -> LogMessage: if path is None: path = [] @@ -127,10 +128,10 @@ def _append_second_to_first( for key in second: if key in first: if isinstance(first[key], dict) and isinstance(second[key], dict): - self._append_second_to_first(first[key], second[key], path + [str(key)]) # type: ignore # type is verified above + self._append_second_to_first(first[key], second[key], path + [str(key)]) # type: ignore # type is verified above # noqa: RUF005 else: if first[key] != second[key]: - _LOGGER.warning("Conflict at %s" % ".".join(path + [str(key)])) + _LOGGER.warning("Conflict at %s" % ".".join(path + [str(key)])) # noqa: UP031, RUF005 first[key] = second[key] else: first[key] = second[key] diff --git a/airbyte_cdk/sources/source.py b/airbyte_cdk/sources/source.py index 2958d82ca..8336b7d1b 100644 --- a/airbyte_cdk/sources/source.py +++ b/airbyte_cdk/sources/source.py @@ -5,7 +5,8 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Generic, Iterable, List, Mapping, Optional, TypeVar +from collections.abc import Iterable, Mapping +from typing import Any, Generic, TypeVar from airbyte_cdk.connector import BaseConnector, DefaultConnectorMixin, TConfig from airbyte_cdk.models import ( @@ -17,6 +18,7 @@ ConfiguredAirbyteCatalogSerializer, ) + TState = TypeVar("TState") TCatalog = TypeVar("TCatalog") @@ -38,7 +40,7 @@ def read( logger: logging.Logger, config: TConfig, catalog: TCatalog, - state: Optional[TState] = None, + state: TState | None = None, ) -> Iterable[AirbyteMessage]: """ Returns a generator of the AirbyteMessages generated by reading the source with the given configuration, 
catalog, and state. @@ -54,12 +56,12 @@ def discover(self, logger: logging.Logger, config: TConfig) -> AirbyteCatalog: class Source( DefaultConnectorMixin, - BaseSource[Mapping[str, Any], List[AirbyteStateMessage], ConfiguredAirbyteCatalog], + BaseSource[Mapping[str, Any], list[AirbyteStateMessage], ConfiguredAirbyteCatalog], ABC, ): # can be overridden to change an input state. @classmethod - def read_state(cls, state_path: str) -> List[AirbyteStateMessage]: + def read_state(cls, state_path: str) -> list[AirbyteStateMessage]: """ Retrieves the input state of a sync by reading from the specified JSON file. Incoming state can be deserialized into either a JSON object for legacy state input or as a list of AirbyteStateMessages for the per-stream state format. Regardless of the @@ -69,7 +71,7 @@ def read_state(cls, state_path: str) -> List[AirbyteStateMessage]: """ parsed_state_messages = [] if state_path: - state_obj = BaseConnector._read_json_file(state_path) + state_obj = BaseConnector._read_json_file(state_path) # noqa: SLF001 if state_obj: for state in state_obj: # type: ignore # `isinstance(state_obj, List)` ensures that this is a list parsed_message = AirbyteStateMessageSerializer.load(state) diff --git a/airbyte_cdk/sources/streams/__init__.py b/airbyte_cdk/sources/streams/__init__.py index dc735b617..eec23592e 100644 --- a/airbyte_cdk/sources/streams/__init__.py +++ b/airbyte_cdk/sources/streams/__init__.py @@ -5,4 +5,5 @@ # Initialize Streams Package from .core import NO_CURSOR_STATE_KEY, CheckpointMixin, IncrementalMixin, Stream + __all__ = ["NO_CURSOR_STATE_KEY", "IncrementalMixin", "CheckpointMixin", "Stream"] diff --git a/airbyte_cdk/sources/streams/availability_strategy.py b/airbyte_cdk/sources/streams/availability_strategy.py index 312ddae19..efdeb5837 100644 --- a/airbyte_cdk/sources/streams/availability_strategy.py +++ b/airbyte_cdk/sources/streams/availability_strategy.py @@ -5,11 +5,13 @@ import logging import typing from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional, Tuple +from collections.abc import Mapping +from typing import Any, Optional from airbyte_cdk.models import SyncMode from airbyte_cdk.sources.streams.core import Stream, StreamData + if typing.TYPE_CHECKING: from airbyte_cdk.sources import Source @@ -22,7 +24,7 @@ class AvailabilityStrategy(ABC): @abstractmethod def check_availability( self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None - ) -> Tuple[bool, Optional[str]]: + ) -> tuple[bool, str | None]: """ Checks stream availability. @@ -36,7 +38,7 @@ def check_availability( """ @staticmethod - def get_first_stream_slice(stream: Stream) -> Optional[Mapping[str, Any]]: + def get_first_stream_slice(stream: Stream) -> Mapping[str, Any] | None: """ Gets the first stream_slice from a given stream's stream_slices. :param stream: stream @@ -55,7 +57,7 @@ def get_first_stream_slice(stream: Stream) -> Optional[Mapping[str, Any]]: @staticmethod def get_first_record_for_slice( - stream: Stream, stream_slice: Optional[Mapping[str, Any]] + stream: Stream, stream_slice: Mapping[str, Any] | None ) -> StreamData: """ Gets the first record for a stream_slice of a stream. 
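To make the reworked tuple[bool, str | None] contract concrete, the sketch below shows how a subclass could combine the two static helpers above. The class name, the StopIteration handling, and the error wording are illustrative assumptions, not code from this changeset.

import logging
from typing import TYPE_CHECKING, Optional

from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.core import Stream

if TYPE_CHECKING:
    from airbyte_cdk.sources import Source


class FirstRecordAvailabilityStrategy(AvailabilityStrategy):
    """Hypothetical strategy: a stream counts as available if one record can be read."""

    def check_availability(
        self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None
    ) -> tuple[bool, str | None]:
        try:
            first_slice = self.get_first_stream_slice(stream)
            self.get_first_record_for_slice(stream, first_slice)
        except StopIteration:
            # Assumes the helpers surface an empty stream as StopIteration; an empty
            # stream is still reachable, so report it as available.
            logger.debug("Stream %s yielded no slices or records", stream.name)
            return True, None
        except Exception as exc:
            return False, f"Unable to read from stream {stream.name}: {exc!r}"
        return True, None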
diff --git a/airbyte_cdk/sources/streams/call_rate.py b/airbyte_cdk/sources/streams/call_rate.py index 81ebac78e..bc1c62a7a 100644 --- a/airbyte_cdk/sources/streams/call_rate.py +++ b/airbyte_cdk/sources/streams/call_rate.py @@ -7,9 +7,10 @@ import datetime import logging import time +from collections.abc import Mapping from datetime import timedelta from threading import RLock -from typing import TYPE_CHECKING, Any, Mapping, Optional +from typing import TYPE_CHECKING, Any from urllib import parse import requests @@ -18,6 +19,7 @@ from pyrate_limiter import Rate as PyRateRate from pyrate_limiter.exceptions import BucketFullException + # prevents mypy from complaining about missing session attributes in LimiterMixin if TYPE_CHECKING: MIXIN_BASE = requests.Session @@ -36,7 +38,7 @@ class Rate: class CallRateLimitHit(Exception): - def __init__(self, error: str, item: Any, weight: int, rate: str, time_to_wait: timedelta): + def __init__(self, error: str, item: Any, weight: int, rate: str, time_to_wait: timedelta): # noqa: ANN204, ANN401 """Constructor :param error: error message @@ -58,7 +60,7 @@ class AbstractCallRatePolicy(abc.ABC): """ @abc.abstractmethod - def matches(self, request: Any) -> bool: + def matches(self, request: Any) -> bool: # noqa: ANN401 """Tells if this policy matches specific request and should apply to it :param request: @@ -66,7 +68,7 @@ def matches(self, request: Any) -> bool: """ @abc.abstractmethod - def try_acquire(self, request: Any, weight: int) -> None: + def try_acquire(self, request: Any, weight: int) -> None: # noqa: ANN401 """Try to acquire request :param request: a request object representing a single call to API @@ -75,9 +77,7 @@ def try_acquire(self, request: Any, weight: int) -> None: """ @abc.abstractmethod - def update( - self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] - ) -> None: + def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None: """Update call rate counting with current values :param available_calls: @@ -89,7 +89,7 @@ class RequestMatcher(abc.ABC): """Callable that help to match a request object with call rate policies.""" @abc.abstractmethod - def __call__(self, request: Any) -> bool: + def __call__(self, request: Any) -> bool: # noqa: ANN401 """ :param request: @@ -100,12 +100,12 @@ def __call__(self, request: Any) -> bool: class HttpRequestMatcher(RequestMatcher): """Simple implementation of RequestMatcher for http requests case""" - def __init__( + def __init__( # noqa: ANN204 self, - method: Optional[str] = None, - url: Optional[str] = None, - params: Optional[Mapping[str, Any]] = None, - headers: Optional[Mapping[str, Any]] = None, + method: str | None = None, + url: str | None = None, + params: Mapping[str, Any] | None = None, + headers: Mapping[str, Any] | None = None, ): """Constructor @@ -129,7 +129,7 @@ def _match_dict(obj: Mapping[str, Any], pattern: Mapping[str, Any]) -> bool: """ return pattern.items() <= obj.items() - def __call__(self, request: Any) -> bool: + def __call__(self, request: Any) -> bool: # noqa: ANN401 """ :param request: @@ -142,7 +142,7 @@ def __call__(self, request: Any) -> bool: else: return False - if self._method is not None: + if self._method is not None: # noqa: SIM102 if prepared_request.method != self._method: return False if self._url is not None and prepared_request.url is not None: @@ -154,17 +154,17 @@ def __call__(self, request: Any) -> bool: params = dict(parse.parse_qsl(str(parsed_url.query))) if not self._match_dict(params, 
self._params): return False - if self._headers is not None: + if self._headers is not None: # noqa: SIM102 if not self._match_dict(prepared_request.headers, self._headers): return False return True class BaseCallRatePolicy(AbstractCallRatePolicy, abc.ABC): - def __init__(self, matchers: list[RequestMatcher]): + def __init__(self, matchers: list[RequestMatcher]): # noqa: ANN204 self._matchers = matchers - def matches(self, request: Any) -> bool: + def matches(self, request: Any) -> bool: # noqa: ANN401 """Tell if this policy matches specific request and should apply to it :param request: @@ -200,17 +200,15 @@ class UnlimitedCallRatePolicy(BaseCallRatePolicy): The code above will limit all calls to /some/method except calls that have header sandbox=True """ - def try_acquire(self, request: Any, weight: int) -> None: + def try_acquire(self, request: Any, weight: int) -> None: # noqa: ANN401 """Do nothing""" - def update( - self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] - ) -> None: + def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None: """Do nothing""" class FixedWindowCallRatePolicy(BaseCallRatePolicy): - def __init__( + def __init__( # noqa: ANN204 self, next_reset_ts: datetime.datetime, period: timedelta, @@ -232,7 +230,7 @@ def __init__( self._lock = RLock() super().__init__(matchers=matchers) - def try_acquire(self, request: Any, weight: int) -> None: + def try_acquire(self, request: Any, weight: int) -> None: # noqa: ANN401 if weight > self._call_limit: raise ValueError("Weight can not exceed the call limit") if not self.matches(request): @@ -257,9 +255,7 @@ def try_acquire(self, request: Any, weight: int) -> None: self._calls_num += weight - def update( - self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] - ) -> None: + def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None: """Update call rate counters, by default, only reacts to decreasing updates of available_calls and changes to call_reset_ts. We ignore updates with available_calls > current_available_calls to support call rate limits that are lower than API limits. @@ -290,7 +286,7 @@ def _update_current_window(self) -> None: now = datetime.datetime.now() if now > self._next_reset_ts: logger.debug("started new window, %s calls available now", self._call_limit) - self._next_reset_ts = self._next_reset_ts + self._offset + self._next_reset_ts = self._next_reset_ts + self._offset # noqa: PLR6104 self._calls_num = 0 @@ -302,7 +298,7 @@ class MovingWindowCallRatePolicy(BaseCallRatePolicy): This strategy requires saving of timestamps of all requests within a window. 
""" - def __init__(self, rates: list[Rate], matchers: list[RequestMatcher]): + def __init__(self, rates: list[Rate], matchers: list[RequestMatcher]): # noqa: ANN204 """Constructor :param rates: list of rates, the order is important and must be ascending @@ -319,7 +315,7 @@ def __init__(self, rates: list[Rate], matchers: list[RequestMatcher]): self._limiter = Limiter(self._bucket) super().__init__(matchers=matchers) - def try_acquire(self, request: Any, weight: int) -> None: + def try_acquire(self, request: Any, weight: int) -> None: # noqa: ANN401 if not self.matches(request): raise ValueError("Request does not match the policy") @@ -333,7 +329,7 @@ def try_acquire(self, request: Any, weight: int) -> None: time_to_wait = self._bucket.waiting(item) assert isinstance(time_to_wait, int) - raise CallRateLimitHit( + raise CallRateLimitHit( # noqa: B904 error=str(exc.meta_info["error"]), item=request, weight=int(exc.meta_info["weight"]), @@ -341,16 +337,14 @@ def try_acquire(self, request: Any, weight: int) -> None: time_to_wait=timedelta(milliseconds=time_to_wait), ) - def update( - self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] - ) -> None: + def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None: """Adjust call bucket to reflect the state of the API server :param available_calls: :param call_reset_ts: :return: """ - if ( + if ( # noqa: SIM102 available_calls is not None and call_reset_ts is None ): # we do our best to sync buckets with API if available_calls == 0: @@ -376,7 +370,11 @@ class AbstractAPIBudget(abc.ABC): @abc.abstractmethod def acquire_call( - self, request: Any, block: bool = True, timeout: Optional[float] = None + self, + request: Any, # noqa: ANN401 + *, + block: bool = True, + timeout: float | None = None, ) -> None: """Try to get a call from budget, will block by default @@ -387,11 +385,11 @@ def acquire_call( """ @abc.abstractmethod - def get_matching_policy(self, request: Any) -> Optional[AbstractCallRatePolicy]: + def get_matching_policy(self, request: Any) -> AbstractCallRatePolicy | None: # noqa: ANN401 """Find matching call rate policy for specific request""" @abc.abstractmethod - def update_from_response(self, request: Any, response: Any) -> None: + def update_from_response(self, request: Any, response: Any) -> None: # noqa: ANN401 """Update budget information based on response from API :param request: the initial request that triggered this response @@ -415,14 +413,18 @@ def __init__( self._policies = policies self._maximum_attempts_to_acquire = maximum_attempts_to_acquire - def get_matching_policy(self, request: Any) -> Optional[AbstractCallRatePolicy]: + def get_matching_policy(self, request: Any) -> AbstractCallRatePolicy | None: # noqa: ANN401 for policy in self._policies: if policy.matches(request): return policy return None def acquire_call( - self, request: Any, block: bool = True, timeout: Optional[float] = None + self, + request: Any, # noqa: ANN401 + *, + block: bool = True, + timeout: float | None = None, ) -> None: """Try to get a call from budget, will block by default. Matchers will be called sequentially in the same order they were added. 
@@ -440,7 +442,7 @@ def acquire_call( elif self._policies: logger.info("no policies matched with requests, allow call by default") - def update_from_response(self, request: Any, response: Any) -> None: + def update_from_response(self, request: Any, response: Any) -> None: # noqa: ANN401 """Update budget information based on response from API :param request: the initial request that triggered this response @@ -449,7 +451,12 @@ def update_from_response(self, request: Any, response: Any) -> None: pass def _do_acquire( - self, request: Any, policy: AbstractCallRatePolicy, block: bool, timeout: Optional[float] + self, + request: Any, # noqa: ANN401 + policy: AbstractCallRatePolicy, + *, + block: bool, + timeout: float | None, ) -> None: """Internal method to try to acquire a call credit @@ -460,10 +467,10 @@ def _do_acquire( """ last_exception = None # sometimes we spend all budget before a second attempt, so we have few more here - for attempt in range(1, self._maximum_attempts_to_acquire): + for attempt in range(1, self._maximum_attempts_to_acquire): # noqa: B007 try: policy.try_acquire(request, weight=1) - return + return # noqa: TRY300 except CallRateLimitHit as exc: last_exception = exc if block: @@ -492,12 +499,12 @@ def _do_acquire( class HttpAPIBudget(APIBudget): """Implementation of AbstractAPIBudget for HTTP""" - def __init__( + def __init__( # noqa: ANN204 self, ratelimit_reset_header: str = "ratelimit-reset", ratelimit_remaining_header: str = "ratelimit-remaining", status_codes_for_ratelimit_hit: tuple[int] = (429,), - **kwargs: Any, + **kwargs: Any, # noqa: ANN401 ): """Constructor @@ -510,7 +517,7 @@ def __init__( self._status_codes_for_ratelimit_hit = status_codes_for_ratelimit_hit super().__init__(**kwargs) - def update_from_response(self, request: Any, response: Any) -> None: + def update_from_response(self, request: Any, response: Any) -> None: # noqa: ANN401 policy = self.get_matching_policy(request) if not policy: return @@ -520,16 +527,14 @@ def update_from_response(self, request: Any, response: Any) -> None: reset_ts = self.get_reset_ts_from_response(response) policy.update(available_calls=available_calls, call_reset_ts=reset_ts) - def get_reset_ts_from_response( - self, response: requests.Response - ) -> Optional[datetime.datetime]: + def get_reset_ts_from_response(self, response: requests.Response) -> datetime.datetime | None: if response.headers.get(self._ratelimit_reset_header): - return datetime.datetime.fromtimestamp( + return datetime.datetime.fromtimestamp( # noqa: DTZ006 int(response.headers[self._ratelimit_reset_header]) ) return None - def get_calls_left_from_response(self, response: requests.Response) -> Optional[int]: + def get_calls_left_from_response(self, response: requests.Response) -> int | None: if response.headers.get(self._ratelimit_remaining_header): return int(response.headers[self._ratelimit_remaining_header]) @@ -542,15 +547,15 @@ def get_calls_left_from_response(self, response: requests.Response) -> Optional[ class LimiterMixin(MIXIN_BASE): """Mixin class that adds rate-limiting behavior to requests.""" - def __init__( + def __init__( # noqa: ANN204 self, api_budget: AbstractAPIBudget, - **kwargs: Any, + **kwargs: Any, # noqa: ANN401 ): self._api_budget = api_budget super().__init__(**kwargs) # type: ignore # Base Session doesn't take any kwargs - def send(self, request: requests.PreparedRequest, **kwargs: Any) -> requests.Response: + def send(self, request: requests.PreparedRequest, **kwargs: Any) -> requests.Response: # noqa: ANN401 """Send a 
request with rate-limiting.""" self._api_budget.acquire_call(request) response = super().send(request, **kwargs) diff --git a/airbyte_cdk/sources/streams/checkpoint/__init__.py b/airbyte_cdk/sources/streams/checkpoint/__init__.py index ae4e0e46f..866df6875 100644 --- a/airbyte_cdk/sources/streams/checkpoint/__init__.py +++ b/airbyte_cdk/sources/streams/checkpoint/__init__.py @@ -13,6 +13,7 @@ from .cursor import Cursor from .resumable_full_refresh_cursor import ResumableFullRefreshCursor + __all__ = [ "CheckpointMode", "CheckpointReader", diff --git a/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py b/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py index 6e4ef98d7..20ff03708 100644 --- a/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py +++ b/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py @@ -1,12 +1,12 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping from enum import Enum -from typing import Any, Iterable, Mapping, Optional - -from airbyte_cdk.sources.types import StreamSlice +from typing import Any from .cursor import Cursor +from airbyte_cdk.sources.types import StreamSlice class CheckpointMode(Enum): @@ -25,7 +25,7 @@ class CheckpointReader(ABC): """ @abstractmethod - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: """ Returns the next slice that will be used to fetch the next group of records. Returning None indicates that the reader has finished iterating over all slices. @@ -41,7 +41,7 @@ def observe(self, new_state: Mapping[str, Any]) -> None: """ @abstractmethod - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: """ Retrieves the current state value of the stream. The connector does not emit state messages if the checkpoint value is None. """ @@ -53,18 +53,18 @@ class IncrementalCheckpointReader(CheckpointReader): before syncing data. """ - def __init__( - self, stream_state: Mapping[str, Any], stream_slices: Iterable[Optional[Mapping[str, Any]]] + def __init__( # noqa: ANN204 + self, stream_state: Mapping[str, Any], stream_slices: Iterable[Mapping[str, Any] | None] ): - self._state: Optional[Mapping[str, Any]] = stream_state + self._state: Mapping[str, Any] | None = stream_state self._stream_slices = iter(stream_slices) self._has_slices = False - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: try: next_slice = next(self._stream_slices) self._has_slices = True - return next_slice + return next_slice # noqa: TRY300 except StopIteration: # This is used to avoid sending a duplicate state message at the end of a sync since the stream has already # emitted state at the end of each slice. If we want to avoid this extra complexity, we can also just accept @@ -76,7 +76,7 @@ def next(self) -> Optional[Mapping[str, Any]]: def observe(self, new_state: Mapping[str, Any]) -> None: self._state = new_state - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: return self._state @@ -89,25 +89,25 @@ class CursorBasedCheckpointReader(CheckpointReader): that belongs to the Concurrent CDK. 
""" - def __init__( + def __init__( # noqa: ANN204 self, cursor: Cursor, - stream_slices: Iterable[Optional[Mapping[str, Any]]], - read_state_from_cursor: bool = False, + stream_slices: Iterable[Mapping[str, Any] | None], + read_state_from_cursor: bool = False, # noqa: FBT001, FBT002 ): self._cursor = cursor self._stream_slices = iter(stream_slices) # read_state_from_cursor is used to delineate that partitions should determine when to stop syncing dynamically according # to the value of the state at runtime. This currently only applies to streams that use resumable full refresh. self._read_state_from_cursor = read_state_from_cursor - self._current_slice: Optional[StreamSlice] = None + self._current_slice: StreamSlice | None = None self._finished_sync = False - self._previous_state: Optional[Mapping[str, Any]] = None + self._previous_state: Mapping[str, Any] | None = None - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: try: self.current_slice = self._find_next_slice() - return self.current_slice + return self.current_slice # noqa: TRY300 except StopIteration: self._finished_sync = True return None @@ -117,14 +117,13 @@ def observe(self, new_state: Mapping[str, Any]) -> None: # while processing records pass - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: # This is used to avoid sending a duplicate state messages new_state = self._cursor.get_stream_state() if new_state != self._previous_state: self._previous_state = new_state return new_state - else: - return None + return None def _find_next_slice(self) -> StreamSlice: """ @@ -165,36 +164,34 @@ def _find_next_slice(self) -> StreamSlice: partition=next_slice.partition, extra_fields=next_slice.extra_fields, ) - else: - state_for_slice = self._cursor.select_state(self.current_slice) - if state_for_slice == FULL_REFRESH_COMPLETE_STATE: - # If the current slice is is complete, move to the next slice and skip the next slices that already - # have the terminal complete value indicating that a previous attempt was successfully read. - # Dummy initialization for mypy since we'll iterate at least once to get the next slice - next_candidate_slice = StreamSlice(cursor_slice={}, partition={}) - has_more = True - while has_more: - next_candidate_slice = self.read_and_convert_slice() - state_for_slice = self._cursor.select_state(next_candidate_slice) - has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE - return StreamSlice( - cursor_slice=state_for_slice or {}, - partition=next_candidate_slice.partition, - extra_fields=next_candidate_slice.extra_fields, - ) - # The reader continues to process the current partition if it's state is still in progress + state_for_slice = self._cursor.select_state(self.current_slice) + if state_for_slice == FULL_REFRESH_COMPLETE_STATE: + # If the current slice is is complete, move to the next slice and skip the next slices that already + # have the terminal complete value indicating that a previous attempt was successfully read. 
+ # Dummy initialization for mypy since we'll iterate at least once to get the next slice + next_candidate_slice = StreamSlice(cursor_slice={}, partition={}) + has_more = True + while has_more: + next_candidate_slice = self.read_and_convert_slice() + state_for_slice = self._cursor.select_state(next_candidate_slice) + has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE return StreamSlice( cursor_slice=state_for_slice or {}, - partition=self.current_slice.partition, - extra_fields=self.current_slice.extra_fields, + partition=next_candidate_slice.partition, + extra_fields=next_candidate_slice.extra_fields, ) - else: - # Unlike RFR cursors that iterate dynamically according to how stream state is updated, most cursors operate - # on a fixed set of slices determined before reading records. They just iterate to the next slice - return self.read_and_convert_slice() + # The reader continues to process the current partition if it's state is still in progress + return StreamSlice( + cursor_slice=state_for_slice or {}, + partition=self.current_slice.partition, + extra_fields=self.current_slice.extra_fields, + ) + # Unlike RFR cursors that iterate dynamically according to how stream state is updated, most cursors operate + # on a fixed set of slices determined before reading records. They just iterate to the next slice + return self.read_and_convert_slice() @property - def current_slice(self) -> Optional[StreamSlice]: + def current_slice(self) -> StreamSlice | None: return self._current_slice @current_slice.setter @@ -204,7 +201,7 @@ def current_slice(self, value: StreamSlice) -> None: def read_and_convert_slice(self) -> StreamSlice: next_slice = next(self._stream_slices) if not isinstance(next_slice, StreamSlice): - raise ValueError( + raise ValueError( # noqa: TRY004 f"{self.current_slice} should be of type StreamSlice. This is likely a bug in the CDK, please contact Airbyte support" ) return next_slice @@ -231,11 +228,11 @@ class LegacyCursorBasedCheckpointReader(CursorBasedCheckpointReader): } """ - def __init__( + def __init__( # noqa: ANN204 self, cursor: Cursor, - stream_slices: Iterable[Optional[Mapping[str, Any]]], - read_state_from_cursor: bool = False, + stream_slices: Iterable[Mapping[str, Any] | None], + read_state_from_cursor: bool = False, # noqa: FBT001, FBT002 ): super().__init__( cursor=cursor, @@ -243,13 +240,13 @@ def __init__( read_state_from_cursor=read_state_from_cursor, ) - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: try: self.current_slice = self._find_next_slice() if "partition" in dict(self.current_slice): raise ValueError("Stream is configured to use invalid stream slice key 'partition'") - elif "cursor_slice" in dict(self.current_slice): + if "cursor_slice" in dict(self.current_slice): raise ValueError( "Stream is configured to use invalid stream slice key 'cursor_slice'" ) @@ -268,7 +265,7 @@ def next(self) -> Optional[Mapping[str, Any]]: def read_and_convert_slice(self) -> StreamSlice: next_mapping_slice = next(self._stream_slices) if not isinstance(next_mapping_slice, Mapping): - raise ValueError( + raise ValueError( # noqa: TRY004 f"{self.current_slice} should be of type Mapping. This is likely a bug in the CDK, please contact Airbyte support" ) @@ -287,25 +284,24 @@ class ResumableFullRefreshCheckpointReader(CheckpointReader): fetching more pages or stopping the sync. 
""" - def __init__(self, stream_state: Mapping[str, Any]): + def __init__(self, stream_state: Mapping[str, Any]): # noqa: ANN204 # The first attempt of an RFR stream has an empty {} incoming state, but should still make a first attempt to read records # from the first page in next(). self._first_page = bool(stream_state == {}) self._state: Mapping[str, Any] = stream_state - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: if self._first_page: self._first_page = False return self._state - elif self._state == FULL_REFRESH_COMPLETE_STATE: + if self._state == FULL_REFRESH_COMPLETE_STATE: return None - else: - return self._state + return self._state def observe(self, new_state: Mapping[str, Any]) -> None: self._state = new_state - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: return self._state or {} @@ -315,11 +311,11 @@ class FullRefreshCheckpointReader(CheckpointReader): is not capable of managing state. At the end of a sync, a final state message is emitted to signal completion. """ - def __init__(self, stream_slices: Iterable[Optional[Mapping[str, Any]]]): + def __init__(self, stream_slices: Iterable[Mapping[str, Any] | None]): # noqa: ANN204 self._stream_slices = iter(stream_slices) self._final_checkpoint = False - def next(self) -> Optional[Mapping[str, Any]]: + def next(self) -> Mapping[str, Any] | None: try: return next(self._stream_slices) except StopIteration: @@ -329,7 +325,7 @@ def next(self) -> Optional[Mapping[str, Any]]: def observe(self, new_state: Mapping[str, Any]) -> None: pass - def get_checkpoint(self) -> Optional[Mapping[str, Any]]: + def get_checkpoint(self) -> Mapping[str, Any] | None: if self._final_checkpoint: return {"__ab_no_cursor_state_message": True} return None diff --git a/airbyte_cdk/sources/streams/checkpoint/cursor.py b/airbyte_cdk/sources/streams/checkpoint/cursor.py index 6d758bf4e..1f57b0482 100644 --- a/airbyte_cdk/sources/streams/checkpoint/cursor.py +++ b/airbyte_cdk/sources/streams/checkpoint/cursor.py @@ -3,7 +3,7 @@ # from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from airbyte_cdk.sources.types import Record, StreamSlice, StreamState @@ -23,7 +23,7 @@ def set_initial_state(self, stream_state: StreamState) -> None: :param stream_state: The state of the stream as returned by get_stream_state """ - def observe(self, stream_slice: StreamSlice, record: Record) -> None: + def observe(self, stream_slice: StreamSlice, record: Record) -> None: # noqa: B027 """ Register a record with the cursor; the cursor instance can then use it to manage the state of the in-progress stream read. @@ -34,7 +34,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: pass @abstractmethod - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401 """ Update state based on the stream slice. Note that `stream_slice.cursor_slice` and `most_recent_record.associated_slice` are expected to be the same but we make it explicit here that `stream_slice` should be leveraged to update the state. 
We do not pass in the @@ -69,7 +69,7 @@ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: """ @abstractmethod - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: """ Get the state value of a specific stream_slice. For incremental or resumable full refresh cursors which only manage state in a single dimension this is the entire state object. For per-partition cursors used by substreams, this returns the state of diff --git a/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py b/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py index e0dee4a92..3acfee945 100644 --- a/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py +++ b/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py @@ -1,7 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. import json -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any class PerPartitionKeySerializer: @@ -13,10 +14,10 @@ class PerPartitionKeySerializer: """ @staticmethod - def to_partition_key(to_serialize: Any) -> str: + def to_partition_key(to_serialize: Any) -> str: # noqa: ANN401 # separators have changed in Python 3.4. To avoid being impacted by further change, we explicitly specify our own value return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True) @staticmethod - def to_partition(to_deserialize: Any) -> Mapping[str, Any]: + def to_partition(to_deserialize: Any) -> Mapping[str, Any]: # noqa: ANN401 return json.loads(to_deserialize) # type: ignore # The partition is known to be a dict, but the type hint is Any diff --git a/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py b/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py index 86abd253f..23fa9f31a 100644 --- a/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py +++ b/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. from dataclasses import dataclass -from typing import Any, Optional +from typing import Any from airbyte_cdk.sources.streams.checkpoint import Cursor from airbyte_cdk.sources.types import Record, StreamSlice, StreamState @@ -30,22 +30,22 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: """ pass - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002 self._cursor = stream_slice.cursor_slice - def should_be_synced(self, record: Record) -> bool: + def should_be_synced(self, record: Record) -> bool: # noqa: ARG002 """ Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages that don't have filterable bounds. We should always return them. """ return True - def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: + def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: # noqa: ARG002 """ RFR record don't have ordering to be compared between one another. 
""" return False - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # noqa: ARG002 # A top-level RFR cursor only manages the state of a single partition return self._cursor diff --git a/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py b/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py index 9966959f0..75bf976dd 100644 --- a/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py +++ b/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py @@ -1,7 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass -from typing import Any, Mapping, MutableMapping, Optional +from typing import Any from airbyte_cdk.models import FailureType from airbyte_cdk.sources.streams.checkpoint import Cursor @@ -11,6 +12,7 @@ from airbyte_cdk.sources.types import Record, StreamSlice, StreamState from airbyte_cdk.utils import AirbyteTracedException + FULL_REFRESH_COMPLETE_STATE: Mapping[str, Any] = {"__ab_full_refresh_sync_complete": True} @@ -76,26 +78,26 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None: """ pass - def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: + def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002 self._per_partition_state[self._to_partition_key(stream_slice.partition)] = { "partition": stream_slice.partition, "cursor": FULL_REFRESH_COMPLETE_STATE, } - def should_be_synced(self, record: Record) -> bool: + def should_be_synced(self, record: Record) -> bool: # noqa: ARG002 """ Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages that don't have filterable bounds. We should always return them. """ return True - def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: + def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: # noqa: ARG002 """ RFR record don't have ordering to be compared between one another. """ return False - def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: if not stream_slice: raise ValueError("A partition needs to be provided in order to extract a state") diff --git a/airbyte_cdk/sources/streams/concurrent/abstract_stream.py b/airbyte_cdk/sources/streams/concurrent/abstract_stream.py index 26e6f09d4..4e5d321d8 100644 --- a/airbyte_cdk/sources/streams/concurrent/abstract_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/abstract_stream.py @@ -3,7 +3,8 @@ # from abc import ABC, abstractmethod -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any from typing_extensions import deprecated @@ -58,7 +59,7 @@ def name(self) -> str: @property @abstractmethod - def cursor_field(self) -> Optional[str]: + def cursor_field(self) -> str | None: """ Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. :return: The name of the field used as a cursor. Nested cursor fields are not supported. 
diff --git a/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py b/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py index 18cacbc50..e67a6a461 100644 --- a/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py +++ b/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py @@ -1,14 +1,15 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage + StreamType = TypeVar("StreamType") -class AbstractStreamFacade(Generic[StreamType], ABC): +class AbstractStreamFacade(Generic[StreamType], ABC): # noqa: PYI059 @abstractmethod def get_underlying_stream(self) -> StreamType: """ @@ -21,7 +22,7 @@ def source_defined_cursor(self) -> bool: # Streams must be aware of their cursor at instantiation time return True - def get_error_display_message(self, exception: BaseException) -> Optional[str]: + def get_error_display_message(self, exception: BaseException) -> str | None: """ Retrieves the user-friendly display message that corresponds to an exception. This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. @@ -33,5 +34,4 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]: """ if isinstance(exception, ExceptionWithDisplayMessage): return exception.display_message - else: - return None + return None diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py index f304bfb21..135d29f86 100644 --- a/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -5,8 +5,9 @@ import copy import json import logging -from functools import lru_cache -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from collections.abc import Iterable, Mapping, MutableMapping +from functools import cache +from typing import Any, Optional from typing_extensions import deprecated @@ -45,6 +46,7 @@ from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_cdk.utils.slice_hasher import SliceHasher + """ This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream """ @@ -68,7 +70,7 @@ def create_from_stream( stream: Stream, source: AbstractSource, logger: logging.Logger, - state: Optional[MutableMapping[str, Any]], + state: MutableMapping[str, Any] | None, cursor: Cursor, ) -> Stream: """ @@ -109,7 +111,7 @@ def create_from_stream( ), stream, cursor, - slice_logger=source._slice_logger, + slice_logger=source._slice_logger, # noqa: SLF001 logger=logger, ) @@ -124,7 +126,7 @@ def state(self, value: Mapping[str, Any]) -> None: if "state" in dir(self._legacy_stream): self._legacy_stream.state = value # type: ignore # validating `state` is attribute of stream using `if` above - def __init__( + def __init__( # noqa: ANN204 self, stream: DefaultStream, legacy_stream: Stream, @@ -143,21 +145,21 @@ def __init__( def read( self, - configured_stream: ConfiguredAirbyteStream, - logger: logging.Logger, - slice_logger: SliceLogger, - stream_state: MutableMapping[str, Any], - state_manager: ConnectorStateManager, - internal_config: InternalConfig, + configured_stream: ConfiguredAirbyteStream, # noqa: ARG002 + logger: logging.Logger, # noqa: ARG002 + slice_logger: 
SliceLogger, # noqa: ARG002 + stream_state: MutableMapping[str, Any], # noqa: ARG002 + state_manager: ConnectorStateManager, # noqa: ARG002 + internal_config: InternalConfig, # noqa: ARG002 ) -> Iterable[StreamData]: yield from self._read_records() def read_records( self, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002 + stream_state: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Iterable[StreamData]: try: yield from self._read_records() @@ -173,7 +175,7 @@ def read_records( level=Level.ERROR, message=f"Cursor State at time of exception: {state}" ), ) - raise exc + raise exc # noqa: TRY201 def _read_records(self) -> Iterable[StreamData]: for partition in self._abstract_stream.generate_partitions(): @@ -187,22 +189,21 @@ def name(self) -> str: return self._abstract_stream.name @property - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: # This method is not expected to be called directly. It is only implemented for backward compatibility with the old interface return self.as_airbyte_stream().source_defined_primary_key # type: ignore # source_defined_primary_key is known to be an Optional[List[List[str]]] @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: if self._abstract_stream.cursor_field is None: return [] - else: - return self._abstract_stream.cursor_field + return self._abstract_stream.cursor_field @property - def cursor(self) -> Optional[Cursor]: # type: ignore[override] # StreamFaced expects to use only airbyte_cdk.sources.streams.concurrent.cursor.Cursor + def cursor(self) -> Cursor | None: # type: ignore[override] # StreamFaced expects to use only airbyte_cdk.sources.streams.concurrent.cursor.Cursor return self._cursor - @lru_cache(maxsize=None) + @cache # noqa: B019 def get_json_schema(self) -> Mapping[str, Any]: return self._abstract_stream.get_json_schema() @@ -211,8 +212,10 @@ def supports_incremental(self) -> bool: return self._legacy_stream.supports_incremental def check_availability( - self, logger: logging.Logger, source: Optional["Source"] = None - ) -> Tuple[bool, Optional[str]]: + self, + logger: logging.Logger, # noqa: ARG002 + source: Optional["Source"] = None, # noqa: ARG002 + ) -> tuple[bool, str | None]: """ Verifies the stream is available. Delegates to the underlying AbstractStream and ignores the parameters :param logger: (ignored) @@ -233,7 +236,7 @@ def get_underlying_stream(self) -> DefaultStream: class SliceEncoder(json.JSONEncoder): - def default(self, obj: Any) -> Any: + def default(self, obj: Any) -> Any: # noqa: ANN401 if hasattr(obj, "__json_serializable__"): return obj.__json_serializable__() @@ -251,14 +254,14 @@ class StreamPartition(Partition): In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time. 
""" - def __init__( + def __init__( # noqa: ANN204 self, stream: Stream, - _slice: Optional[Mapping[str, Any]], + _slice: Mapping[str, Any] | None, message_repository: MessageRepository, sync_mode: SyncMode, - cursor_field: Optional[List[str]], - state: Optional[MutableMapping[str, Any]], + cursor_field: list[str] | None, + state: MutableMapping[str, Any] | None, ): """ :param stream: The stream to delegate to @@ -309,9 +312,9 @@ def read(self) -> Iterable[Record]: if display_message: raise ExceptionWithDisplayMessage(display_message) from e else: - raise e + raise e # noqa: TRY201 - def to_slice(self) -> Optional[Mapping[str, Any]]: + def to_slice(self) -> Mapping[str, Any] | None: return self._slice def __hash__(self) -> int: @@ -332,13 +335,13 @@ class StreamPartitionGenerator(PartitionGenerator): In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time. """ - def __init__( + def __init__( # noqa: ANN204 self, stream: Stream, message_repository: MessageRepository, sync_mode: SyncMode, - cursor_field: Optional[List[str]], - state: Optional[MutableMapping[str, Any]], + cursor_field: list[str] | None, + state: MutableMapping[str, Any] | None, ): """ :param stream: The stream to delegate to @@ -369,12 +372,15 @@ def generate(self) -> Iterable[Partition]: category=ExperimentalClassWarning, ) class AvailabilityStrategyFacade(AvailabilityStrategy): - def __init__(self, abstract_availability_strategy: AbstractAvailabilityStrategy): + def __init__(self, abstract_availability_strategy: AbstractAvailabilityStrategy): # noqa: ANN204 self._abstract_availability_strategy = abstract_availability_strategy def check_availability( - self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None - ) -> Tuple[bool, Optional[str]]: + self, + stream: Stream, # noqa: ARG002 + logger: logging.Logger, + source: Optional["Source"] = None, # noqa: ARG002 + ) -> tuple[bool, str | None]: """ Checks stream availability. diff --git a/airbyte_cdk/sources/streams/concurrent/availability_strategy.py b/airbyte_cdk/sources/streams/concurrent/availability_strategy.py index 118a7d0bb..9722a1b27 100644 --- a/airbyte_cdk/sources/streams/concurrent/availability_strategy.py +++ b/airbyte_cdk/sources/streams/concurrent/availability_strategy.py @@ -4,7 +4,6 @@ import logging from abc import ABC, abstractmethod -from typing import Optional from typing_extensions import deprecated @@ -19,7 +18,7 @@ def is_available(self) -> bool: """ @abstractmethod - def message(self) -> Optional[str]: + def message(self) -> str | None: """ :return: A message describing why the stream is not available. If the stream is available, this should return None. """ @@ -29,18 +28,18 @@ class StreamAvailable(StreamAvailability): def is_available(self) -> bool: return True - def message(self) -> Optional[str]: + def message(self) -> str | None: return None class StreamUnavailable(StreamAvailability): - def __init__(self, message: str): + def __init__(self, message: str): # noqa: ANN204 self._message = message def is_available(self) -> bool: return False - def message(self) -> Optional[str]: + def message(self) -> str | None: return self._message @@ -84,7 +83,7 @@ class AlwaysAvailableAvailabilityStrategy(AbstractAvailabilityStrategy): without disrupting existing functionality. 
""" - def check_availability(self, logger: logging.Logger) -> StreamAvailability: + def check_availability(self, logger: logging.Logger) -> StreamAvailability: # noqa: ARG002 """ Checks stream availability. diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index cbce82a94..3900388e2 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -5,17 +5,10 @@ import functools import logging from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Mapping, MutableMapping from typing import ( Any, - Callable, - Iterable, - List, - Mapping, - MutableMapping, - Optional, Protocol, - Tuple, - Union, ) from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager @@ -28,10 +21,11 @@ ) from airbyte_cdk.sources.types import Record, StreamSlice + LOGGER = logging.getLogger("airbyte") -def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any: +def _extract_value(mapping: Mapping[str, Any], path: list[str]) -> Any: # noqa: ANN401 return functools.reduce(lambda a, b: a[b], path, mapping) @@ -86,14 +80,14 @@ def observe(self, record: Record) -> None: """ Indicate to the cursor that the record has been emitted """ - raise NotImplementedError() + raise NotImplementedError @abstractmethod def close_partition(self, partition: Partition) -> None: """ Indicate to the cursor that the partition has been successfully processed """ - raise NotImplementedError() + raise NotImplementedError @abstractmethod def ensure_at_least_one_state_emitted(self) -> None: @@ -101,7 +95,7 @@ def ensure_at_least_one_state_emitted(self) -> None: State messages are emitted when a partition is closed. However, the platform expects at least one state to be emitted per sync per stream. Hence, if no partitions are generated, this method needs to be called. 
""" - raise NotImplementedError() + raise NotImplementedError def stream_slices(self) -> Iterable[StreamSlice]: """ @@ -117,7 +111,7 @@ class FinalStateCursor(Cursor): def __init__( self, stream_name: str, - stream_namespace: Optional[str], + stream_namespace: str | None, message_repository: MessageRepository, ) -> None: self._stream_name = stream_name @@ -157,21 +151,21 @@ class ConcurrentCursor(Cursor): _START_BOUNDARY = 0 _END_BOUNDARY = 1 - def __init__( + def __init__( # noqa: PLR0913, PLR0917 self, stream_name: str, - stream_namespace: Optional[str], - stream_state: Any, + stream_namespace: str | None, + stream_state: Any, # noqa: ANN401 message_repository: MessageRepository, connector_state_manager: ConnectorStateManager, connector_state_converter: AbstractStreamStateConverter, cursor_field: CursorField, - slice_boundary_fields: Optional[Tuple[str, str]], - start: Optional[CursorValueType], + slice_boundary_fields: tuple[str, str] | None, + start: CursorValueType | None, end_provider: Callable[[], CursorValueType], - lookback_window: Optional[GapType] = None, - slice_range: Optional[GapType] = None, - cursor_granularity: Optional[GapType] = None, + lookback_window: GapType | None = None, + slice_range: GapType | None = None, + cursor_granularity: GapType | None = None, ) -> None: self._stream_name = stream_name self._stream_namespace = stream_namespace @@ -187,7 +181,7 @@ def __init__( self._lookback_window = lookback_window self._slice_range = slice_range self._most_recent_cursor_value_per_partition: MutableMapping[ - Union[StreamSlice, Mapping[str, Any], None], Any + StreamSlice | Mapping[str, Any] | None, Any ] = {} self._has_closed_at_least_one_slice = False self._cursor_granularity = cursor_granularity @@ -203,9 +197,9 @@ def cursor_field(self) -> CursorField: return self._cursor_field @property - def _slice_boundary_fields_wrapper(self) -> Tuple[str, str]: + def _slice_boundary_fields_wrapper(self) -> tuple[str, str]: return ( - self._slice_boundary_fields + self._slice_boundary_fields # noqa: FURB110 if self._slice_boundary_fields else ( self._connector_state_converter.START_KEY, @@ -215,7 +209,7 @@ def _slice_boundary_fields_wrapper(self) -> Tuple[str, str]: def _get_concurrent_state( self, state: MutableMapping[str, Any] - ) -> Tuple[CursorValueType, MutableMapping[str, Any]]: + ) -> tuple[CursorValueType, MutableMapping[str, Any]]: if self._connector_state_converter.is_state_message_compatible(state): return ( self._start or self._connector_state_converter.zero_value, @@ -237,7 +231,7 @@ def observe(self, record: Record) -> None: except ValueError: self._log_for_record_without_cursor_value() - def _extract_cursor_value(self, record: Record) -> Any: + def _extract_cursor_value(self, record: Record) -> Any: # noqa: ANN401 return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record)) def close_partition(self, partition: Partition) -> None: @@ -260,17 +254,15 @@ def _add_slice_to_state(self, partition: Partition) -> None: raise RuntimeError( f"The state for stream {self._stream_name} should have at least one slice to delineate the sync start time, but no slices are present. This is unexpected. Please contact Support." 
) - self.state["slices"].append( - { - self._connector_state_converter.START_KEY: self._extract_from_slice( - partition, self._slice_boundary_fields[self._START_BOUNDARY] - ), - self._connector_state_converter.END_KEY: self._extract_from_slice( - partition, self._slice_boundary_fields[self._END_BOUNDARY] - ), - self._connector_state_converter.MOST_RECENT_RECORD_KEY: most_recent_cursor_value, - } - ) + self.state["slices"].append({ + self._connector_state_converter.START_KEY: self._extract_from_slice( + partition, self._slice_boundary_fields[self._START_BOUNDARY] + ), + self._connector_state_converter.END_KEY: self._extract_from_slice( + partition, self._slice_boundary_fields[self._END_BOUNDARY] + ), + self._connector_state_converter.MOST_RECENT_RECORD_KEY: most_recent_cursor_value, + }) elif most_recent_cursor_value: if self._has_closed_at_least_one_slice: # If we track state value using records cursor field, we can only do that if there is one partition. This is because we save @@ -288,13 +280,11 @@ def _add_slice_to_state(self, partition: Partition) -> None: "expected. Please contact the Airbyte team." ) - self.state["slices"].append( - { - self._connector_state_converter.START_KEY: self.start, - self._connector_state_converter.END_KEY: most_recent_cursor_value, - self._connector_state_converter.MOST_RECENT_RECORD_KEY: most_recent_cursor_value, - } - ) + self.state["slices"].append({ + self._connector_state_converter.START_KEY: self.start, + self._connector_state_converter.END_KEY: most_recent_cursor_value, + self._connector_state_converter.MOST_RECENT_RECORD_KEY: most_recent_cursor_value, + }) def _emit_state_message(self) -> None: self._connector_state_manager.update_state_for_stream( @@ -314,9 +304,9 @@ def _merge_partitions(self) -> None: def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType: try: - _slice = partition.to_slice() + _slice = partition.to_slice() # noqa: RUF052 if not _slice: - raise KeyError(f"Could not find key `{key}` in empty slice") + raise KeyError(f"Could not find key `{key}` in empty slice") # noqa: TRY301 return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a CursorValueType except KeyError as exception: raise KeyError( @@ -348,7 +338,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: yield from self._split_per_slice_range( self._start, self.state["slices"][0][self._connector_state_converter.START_KEY], - False, + False, # noqa: FBT003 ) if len(self.state["slices"]) == 1: @@ -357,7 +347,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: self.state["slices"][0][self._connector_state_converter.END_KEY] ), self._end_provider(), - True, + True, # noqa: FBT003 ) elif len(self.state["slices"]) > 1: for i in range(len(self.state["slices"]) - 1): @@ -366,20 +356,20 @@ def stream_slices(self) -> Iterable[StreamSlice]: self.state["slices"][i][self._connector_state_converter.END_KEY] + self._cursor_granularity, self.state["slices"][i + 1][self._connector_state_converter.START_KEY], - False, + False, # noqa: FBT003 ) else: yield from self._split_per_slice_range( self.state["slices"][i][self._connector_state_converter.END_KEY], self.state["slices"][i + 1][self._connector_state_converter.START_KEY], - False, + False, # noqa: FBT003 ) yield from self._split_per_slice_range( self._calculate_lower_boundary_of_last_slice( self.state["slices"][-1][self._connector_state_converter.END_KEY] ), self._end_provider(), - True, + True, # noqa: FBT003 ) else: raise 
ValueError("Expected at least one slice") @@ -398,7 +388,10 @@ def _calculate_lower_boundary_of_last_slice( return lower_boundary def _split_per_slice_range( - self, lower: CursorValueType, upper: CursorValueType, upper_is_end: bool + self, + lower: CursorValueType, + upper: CursorValueType, + upper_is_end: bool, # noqa: FBT001 ) -> Iterable[StreamSlice]: if lower >= upper: return diff --git a/airbyte_cdk/sources/streams/concurrent/default_stream.py b/airbyte_cdk/sources/streams/concurrent/default_stream.py index 7679a1eb6..79aa79404 100644 --- a/airbyte_cdk/sources/streams/concurrent/default_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/default_stream.py @@ -2,9 +2,10 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from functools import lru_cache +from collections.abc import Iterable, Mapping +from functools import cache from logging import Logger -from typing import Any, Iterable, List, Mapping, Optional +from typing import Any from airbyte_cdk.models import AirbyteStream, SyncMode from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -18,17 +19,17 @@ class DefaultStream(AbstractStream): - def __init__( + def __init__( # noqa: PLR0913, PLR0917 self, partition_generator: PartitionGenerator, name: str, json_schema: Mapping[str, Any], availability_strategy: AbstractAvailabilityStrategy, - primary_key: List[str], - cursor_field: Optional[str], + primary_key: list[str], + cursor_field: str | None, logger: Logger, cursor: Cursor, - namespace: Optional[str] = None, + namespace: str | None = None, ) -> None: self._stream_partition_generator = partition_generator self._name = name @@ -48,17 +49,17 @@ def name(self) -> str: return self._name @property - def namespace(self) -> Optional[str]: + def namespace(self) -> str | None: return self._namespace def check_availability(self) -> StreamAvailability: return self._availability_strategy.check_availability(self._logger) @property - def cursor_field(self) -> Optional[str]: + def cursor_field(self) -> str | None: return self._cursor_field - @lru_cache(maxsize=None) + @cache # noqa: B019 def get_json_schema(self) -> Mapping[str, Any]: return self._json_schema diff --git a/airbyte_cdk/sources/streams/concurrent/exceptions.py b/airbyte_cdk/sources/streams/concurrent/exceptions.py index a0cf699a4..691ba077b 100644 --- a/airbyte_cdk/sources/streams/concurrent/exceptions.py +++ b/airbyte_cdk/sources/streams/concurrent/exceptions.py @@ -10,7 +10,7 @@ class ExceptionWithDisplayMessage(Exception): Exception that can be used to display a custom message to the user. """ - def __init__(self, display_message: str, **kwargs: Any): + def __init__(self, display_message: str, **kwargs: Any): # noqa: ANN204, ANN401 super().__init__(**kwargs) self.display_message = display_message diff --git a/airbyte_cdk/sources/streams/concurrent/helpers.py b/airbyte_cdk/sources/streams/concurrent/helpers.py index 5e2edf055..8ea0e3201 100644 --- a/airbyte_cdk/sources/streams/concurrent/helpers.py +++ b/airbyte_cdk/sources/streams/concurrent/helpers.py @@ -1,18 +1,17 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
-from typing import List, Optional, Union from airbyte_cdk.sources.streams import Stream def get_primary_key_from_stream( - stream_primary_key: Optional[Union[str, List[str], List[List[str]]]], -) -> List[str]: + stream_primary_key: str | list[str] | list[list[str]] | None, +) -> list[str]: if stream_primary_key is None: return [] - elif isinstance(stream_primary_key, str): + if isinstance(stream_primary_key, str): return [stream_primary_key] - elif isinstance(stream_primary_key, list): + if isinstance(stream_primary_key, list): are_all_elements_str = all(isinstance(k, str) for k in stream_primary_key) are_all_elements_list_of_size_one = all( isinstance(k, list) and len(k) == 1 for k in stream_primary_key @@ -20,23 +19,19 @@ def get_primary_key_from_stream( if are_all_elements_str: return stream_primary_key # type: ignore # We verified all items in the list are strings - elif are_all_elements_list_of_size_one: - return list(map(lambda x: x[0], stream_primary_key)) - else: - raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}") - else: - raise ValueError(f"Invalid type for primary key: {stream_primary_key}") + if are_all_elements_list_of_size_one: + return list(map(lambda x: x[0], stream_primary_key)) # noqa: C417, FURB118 + raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}") + raise ValueError(f"Invalid type for primary key: {stream_primary_key}") -def get_cursor_field_from_stream(stream: Stream) -> Optional[str]: +def get_cursor_field_from_stream(stream: Stream) -> str | None: if isinstance(stream.cursor_field, list): if len(stream.cursor_field) > 1: raise ValueError( f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}" ) - elif len(stream.cursor_field) == 0: + if len(stream.cursor_field) == 0: return None - else: - return stream.cursor_field[0] - else: - return stream.cursor_field + return stream.cursor_field[0] + return stream.cursor_field diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/partition.py b/airbyte_cdk/sources/streams/concurrent/partitions/partition.py index 8391a5a2b..b3f1cda50 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/partition.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/partition.py @@ -3,7 +3,8 @@ # from abc import ABC, abstractmethod -from typing import Any, Iterable, Mapping, Optional +from collections.abc import Iterable, Mapping +from typing import Any from airbyte_cdk.sources.types import Record @@ -22,7 +23,7 @@ def read(self) -> Iterable[Record]: pass @abstractmethod - def to_slice(self) -> Optional[Mapping[str, Any]]: + def to_slice(self) -> Mapping[str, Any] | None: """ Converts the partition to a slice that can be serialized and deserialized. 
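(Aside, not part of the diff: the `@cache  # noqa: B019` substitutions above, in StreamFacade.get_json_schema and DefaultStream.get_json_schema, lean on the fact that functools.cache is simply lru_cache(maxsize=None) under a shorter name. B019 is suppressed because caching a bound method keeps a reference to `self` alive for the life of the process. The SchemaHolder class below is a hypothetical, minimal sketch of that trade-off, not code from the CDK.)

from functools import cache


class SchemaHolder:
    """Hypothetical stand-in for a stream that memoizes its JSON schema."""

    # Equivalent to @lru_cache(maxsize=None). Ruff's B019 warns that the cache
    # holds a reference to `self`, so instances are never garbage collected;
    # that is acceptable here because streams live for the whole sync anyway.
    @cache
    def get_json_schema(self) -> dict:
        return {"type": "object", "properties": {}}


holder = SchemaHolder()
# The second call is served from the cache, so the exact same dict comes back.
assert holder.get_json_schema() is holder.get_json_schema()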
diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py b/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py index eff978564..f4bd77bd9 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py @@ -3,7 +3,7 @@ # from abc import ABC, abstractmethod -from typing import Iterable +from collections.abc import Iterable from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py b/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py index 98ac04ed7..e10fada49 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. from abc import ABC, abstractmethod -from typing import Iterable +from collections.abc import Iterable from airbyte_cdk.sources.types import StreamSlice diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte_cdk/sources/streams/concurrent/partitions/types.py index 77644c6b9..c645f3b46 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/types.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/types.py @@ -1,8 +1,8 @@ -# +# # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, Union +from typing import Union from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( PartitionGenerationCompletedSentinel, @@ -11,20 +11,20 @@ from airbyte_cdk.sources.types import Record -class PartitionCompleteSentinel: +class PartitionCompleteSentinel: # noqa: PLW1641 """ A sentinel object indicating all records for a partition were produced. Includes a pointer to the partition that was processed. 
""" - def __init__(self, partition: Partition, is_successful: bool = True): + def __init__(self, partition: Partition, is_successful: bool = True): # noqa: ANN204, FBT001, FBT002 """ :param partition: The partition that was processed """ self.partition = partition self.is_successful = is_successful - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, PartitionCompleteSentinel): return self.partition == other.partition return False @@ -33,6 +33,6 @@ def __eq__(self, other: Any) -> bool: """ Typedef representing the items that can be added to the ThreadBasedConcurrentStream """ -QueueItem = Union[ +QueueItem = Union[ # noqa: UP007 Record, Partition, PartitionCompleteSentinel, PartitionGenerationCompletedSentinel, Exception ] diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py index 987915317..78787964b 100644 --- a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py +++ b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py @@ -3,8 +3,10 @@ # from abc import ABC, abstractmethod +from collections.abc import MutableMapping from enum import Enum -from typing import TYPE_CHECKING, Any, List, MutableMapping, Optional, Tuple +from typing import TYPE_CHECKING, Any + if TYPE_CHECKING: from airbyte_cdk.sources.streams.concurrent.cursor import CursorField @@ -20,14 +22,14 @@ class AbstractStreamStateConverter(ABC): MOST_RECENT_RECORD_KEY = "most_recent_cursor_value" @abstractmethod - def _from_state_message(self, value: Any) -> Any: + def _from_state_message(self, value: Any) -> Any: # noqa: ANN401 pass @abstractmethod - def _to_state_message(self, value: Any) -> Any: + def _to_state_message(self, value: Any) -> Any: # noqa: ANN401 pass - def __init__(self, is_sequential_state: bool = True): + def __init__(self, is_sequential_state: bool = True): # noqa: ANN204, FBT001, FBT002 self._is_sequential_state = is_sequential_state def convert_to_state_message( @@ -43,14 +45,13 @@ def convert_to_state_message( legacy_state = stream_state.get("legacy", {}) latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", [])) if latest_complete_time is not None: - legacy_state.update( - {cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)} - ) + legacy_state.update({ + cursor_field.cursor_field_key: self._to_state_message(latest_complete_time) + }) return legacy_state or {} - else: - return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range) + return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range) - def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> Any: + def _get_latest_complete_time(self, slices: list[MutableMapping[str, Any]]) -> Any: # noqa: ANN401 """ Get the latest time before which all records have been processed. """ @@ -102,8 +103,8 @@ def convert_from_sequential_state( self, cursor_field: "CursorField", # to deprecate as it is only needed for sequential state stream_state: MutableMapping[str, Any], - start: Optional[Any], - ) -> Tuple[Any, MutableMapping[str, Any]]: + start: Any | None, # noqa: ANN401 + ) -> tuple[Any, MutableMapping[str, Any]]: """ Convert the state message to the format required by the ConcurrentCursor. @@ -118,22 +119,22 @@ def convert_from_sequential_state( ... 
@abstractmethod - def increment(self, value: Any) -> Any: + def increment(self, value: Any) -> Any: # noqa: ANN401 """ Increment a timestamp by a single unit. """ ... @abstractmethod - def output_format(self, value: Any) -> Any: + def output_format(self, value: Any) -> Any: # noqa: ANN401 """ Convert the cursor value type to a JSON valid type. """ ... def merge_intervals( - self, intervals: List[MutableMapping[str, Any]] - ) -> List[MutableMapping[str, Any]]: + self, intervals: list[MutableMapping[str, Any]] + ) -> list[MutableMapping[str, Any]]: """ Compute and return a list of merged intervals. @@ -144,7 +145,8 @@ def merge_intervals( return [] sorted_intervals = sorted( - intervals, key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY]) + intervals, + key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY]), # noqa: FURB118 ) merged_intervals = [sorted_intervals[0]] @@ -170,7 +172,7 @@ def merge_intervals( return merged_intervals @abstractmethod - def parse_value(self, value: Any) -> Any: + def parse_value(self, value: Any) -> Any: # noqa: ANN401 """ Parse the value of the cursor field into a comparable value. """ @@ -178,4 +180,4 @@ def parse_value(self, value: Any) -> Any: @property @abstractmethod - def zero_value(self) -> Any: ... + def zero_value(self) -> Any: ... # noqa: ANN401 diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py index 3f53a9234..d3a47d89f 100644 --- a/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py +++ b/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py @@ -3,13 +3,14 @@ # from abc import abstractmethod +from collections.abc import Callable, MutableMapping from datetime import datetime, timedelta, timezone -from typing import Any, Callable, List, MutableMapping, Optional, Tuple +from typing import Any import pendulum from pendulum.datetime import DateTime -# FIXME We would eventually like the Concurrent package do be agnostic of the declarative package. However, this is a breaking change and +# FIXME We would eventually like the Concurrent package do be agnostic of the declarative package. However, this is a breaking change and # noqa: FIX001, TD001, TD004 # the goal in the short term is only to fix the issue we are seeing for source-declarative-manifest. from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser from airbyte_cdk.sources.streams.concurrent.cursor import CursorField @@ -20,15 +21,15 @@ class DateTimeStreamStateConverter(AbstractStreamStateConverter): - def _from_state_message(self, value: Any) -> Any: + def _from_state_message(self, value: Any) -> Any: # noqa: ANN401 return self.parse_timestamp(value) - def _to_state_message(self, value: Any) -> Any: + def _to_state_message(self, value: Any) -> Any: # noqa: ANN401 return self.output_format(value) @property @abstractmethod - def _zero_value(self) -> Any: ... + def _zero_value(self) -> Any: ... # noqa: ANN401 @property def zero_value(self) -> datetime: @@ -42,26 +43,26 @@ def get_end_provider(cls) -> Callable[[], datetime]: def increment(self, timestamp: datetime) -> datetime: ... @abstractmethod - def parse_timestamp(self, timestamp: Any) -> datetime: ... + def parse_timestamp(self, timestamp: Any) -> datetime: ... # noqa: ANN401 @abstractmethod - def output_format(self, timestamp: datetime) -> Any: ... 
+ def output_format(self, timestamp: datetime) -> Any: ... # noqa: ANN401 - def parse_value(self, value: Any) -> Any: + def parse_value(self, value: Any) -> Any: # noqa: ANN401 """ Parse the value of the cursor field into a comparable value. """ return self.parse_timestamp(value) - def _compare_intervals(self, end_time: Any, start_time: Any) -> bool: + def _compare_intervals(self, end_time: Any, start_time: Any) -> bool: # noqa: ANN401 return bool(self.increment(end_time) >= start_time) def convert_from_sequential_state( self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], - start: Optional[datetime], - ) -> Tuple[datetime, MutableMapping[str, Any]]: + start: datetime | None, + ) -> tuple[datetime, MutableMapping[str, Any]]: """ Convert the state message to the format required by the ConcurrentCursor. @@ -99,7 +100,7 @@ def _get_sync_start( self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], - start: Optional[datetime], + start: datetime | None, ) -> datetime: sync_start = start if start is not None else self.zero_value prev_sync_low_water_mark = ( @@ -109,8 +110,7 @@ def _get_sync_start( ) if prev_sync_low_water_mark and prev_sync_low_water_mark >= sync_start: return prev_sync_low_water_mark - else: - return sync_start + return sync_start class EpochValueConcurrentStreamStateConverter(DateTimeStreamStateConverter): @@ -138,7 +138,7 @@ def output_format(self, timestamp: datetime) -> int: def parse_timestamp(self, timestamp: int) -> datetime: dt_object = pendulum.from_timestamp(timestamp) if not isinstance(dt_object, DateTime): - raise ValueError( + raise ValueError( # noqa: TRY004 f"DateTime object was expected but got {type(dt_object)} from pendulum.parse({timestamp})" ) return dt_object @@ -160,8 +160,11 @@ class IsoMillisConcurrentStreamStateConverter(DateTimeStreamStateConverter): _zero_value = "0001-01-01T00:00:00.000Z" - def __init__( - self, is_sequential_state: bool = True, cursor_granularity: Optional[timedelta] = None + def __init__( # noqa: ANN204 + self, + *, + is_sequential_state: bool = True, + cursor_granularity: timedelta | None = None, ): super().__init__(is_sequential_state=is_sequential_state) self._cursor_granularity = cursor_granularity or timedelta(milliseconds=1) @@ -169,13 +172,13 @@ def __init__( def increment(self, timestamp: datetime) -> datetime: return timestamp + self._cursor_granularity - def output_format(self, timestamp: datetime) -> Any: + def output_format(self, timestamp: datetime) -> Any: # noqa: ANN401 return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" def parse_timestamp(self, timestamp: str) -> datetime: dt_object = pendulum.parse(timestamp) if not isinstance(dt_object, DateTime): - raise ValueError( + raise ValueError( # noqa: TRY004 f"DateTime object was expected but got {type(dt_object)} from pendulum.parse({timestamp})" ) return dt_object @@ -187,18 +190,18 @@ class CustomFormatConcurrentStreamStateConverter(IsoMillisConcurrentStreamStateC incoming state in any valid datetime format via Pendulum. 
""" - def __init__( + def __init__( # noqa: ANN204 self, datetime_format: str, - input_datetime_formats: Optional[List[str]] = None, - is_sequential_state: bool = True, - cursor_granularity: Optional[timedelta] = None, + input_datetime_formats: list[str] | None = None, + is_sequential_state: bool = True, # noqa: FBT001, FBT002 + cursor_granularity: timedelta | None = None, ): super().__init__( is_sequential_state=is_sequential_state, cursor_granularity=cursor_granularity ) self._datetime_format = datetime_format - self._input_datetime_formats = input_datetime_formats if input_datetime_formats else [] + self._input_datetime_formats = input_datetime_formats if input_datetime_formats else [] # noqa: FURB110 self._input_datetime_formats += [self._datetime_format] self._parser = DatetimeParser() diff --git a/airbyte_cdk/sources/streams/core.py b/airbyte_cdk/sources/streams/core.py index a9aa8550a..dc814a2a7 100644 --- a/airbyte_cdk/sources/streams/core.py +++ b/airbyte_cdk/sources/streams/core.py @@ -6,13 +6,13 @@ import itertools import logging from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator, Mapping, MutableMapping from dataclasses import dataclass -from functools import cached_property, lru_cache -from typing import Any, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Union +from functools import cache, cached_property +from typing import Any, Union from typing_extensions import deprecated -import airbyte_cdk.sources.utils.casing as casing from airbyte_cdk.models import ( AirbyteMessage, AirbyteStream, @@ -32,16 +32,18 @@ ResumableFullRefreshCheckpointReader, ) from airbyte_cdk.sources.types import StreamSlice +from airbyte_cdk.sources.utils import casing # list of all possible HTTP methods which can be used for sending of request bodies from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, ResourceSchemaLoader from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer + # A stream's read method can return one of the following types: # Mapping[str, Any]: The content of an AirbyteRecordMessage # AirbyteMessage: An AirbyteMessage. Could be of any type -StreamData = Union[Mapping[str, Any], AirbyteMessage] +StreamData = Union[Mapping[str, Any], AirbyteMessage] # noqa: UP007 JsonSchema = Mapping[str, Any] @@ -53,8 +55,7 @@ def package_name_from_class(cls: object) -> str: module = inspect.getmodule(cls) if module is not None: return module.__name__.split(".")[0] - else: - raise ValueError(f"Could not find package name for class {cls}") + raise ValueError(f"Could not find package name for class {cls}") class CheckpointMixin(ABC): @@ -115,12 +116,12 @@ class StreamClassification: has_multiple_slices: bool -class Stream(ABC): +class Stream(ABC): # noqa: PLR0904 """ Base abstract class for an Airbyte Stream. Makes no assumption of the Stream's underlying transport protocol. 
""" - _configured_json_schema: Optional[Dict[str, Any]] = None + _configured_json_schema: dict[str, Any] | None = None _exit_on_rate_limit: bool = False # Use self.logger in subclasses to log any messages @@ -131,7 +132,7 @@ def logger(self) -> logging.Logger: # TypeTransformer object to perform output data transformation transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform) - cursor: Optional[Cursor] = None + cursor: Cursor | None = None has_multiple_slices = False @@ -142,7 +143,7 @@ def name(self) -> str: """ return casing.camel_to_snake(self.__class__.__name__) - def get_error_display_message(self, exception: BaseException) -> Optional[str]: + def get_error_display_message(self, exception: BaseException) -> str | None: # noqa: ARG002 """ Retrieves the user-friendly display message that corresponds to an exception. This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. @@ -160,7 +161,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o logger: logging.Logger, slice_logger: SliceLogger, stream_state: MutableMapping[str, Any], - state_manager, + state_manager, # noqa: ANN001 internal_config: InternalConfig, ) -> Iterable[StreamData]: sync_mode = configured_stream.sync_mode @@ -171,7 +172,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o # opposed to the incoming stream_state value. Because some connectors like ones using the file-based CDK modify # state before setting the value on the Stream attribute, the most up-to-date state is derived from Stream.state # instead of the stream_state parameter. This does not apply to legacy connectors using get_updated_state(). - try: + try: # noqa: SIM105 stream_state = self.state # type: ignore # we know the field might not exist... except AttributeError: pass @@ -188,7 +189,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o if slice_logger.should_log_slice_message(logger): yield slice_logger.create_slice_log_message(next_slice) records = self.read_records( - sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior + sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior stream_slice=next_slice, stream_state=stream_state, cursor_field=cursor_field or None, @@ -252,7 +253,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o airbyte_state_message = self._checkpoint_state(checkpoint, state_manager=state_manager) yield airbyte_state_message - def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterable[StreamData]: + def read_only_records(self, state: Mapping[str, Any] | None = None) -> Iterable[StreamData]: """ Helper method that performs a read on a stream with an optional state and emits records. 
If the parent stream supports incremental, this operation does not update the stream's internal state (if it uses the modern state setter/getter) @@ -284,15 +285,15 @@ def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterab def read_records( self, sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + cursor_field: list[str] | None = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: """ This method should be overridden by subclasses to read records based on the inputs """ - @lru_cache(maxsize=None) + @cache # noqa: B019 def get_json_schema(self) -> Mapping[str, Any]: """ :return: A dict of the JSON schema representing this stream. @@ -300,7 +301,7 @@ def get_json_schema(self) -> Mapping[str, Any]: The default implementation of this method looks for a JSONSchema file with the same name as this stream's "name" property. Override as needed. """ - # TODO show an example of using pydantic to define the JSON schema, or reading an OpenAPI spec + # TODO show an example of using pydantic to define the JSON schema, or reading an OpenAPI spec # noqa: TD004 return ResourceSchemaLoader(package_name_from_class(self.__class__)).get_schema(self.name) def as_airbyte_stream(self) -> AirbyteStream: @@ -348,19 +349,18 @@ def is_resumable(self) -> bool: # to structure stream state in a very specific way. We also can't check for issubclass(HttpSubStream) because # not all substreams implement the interface and it would be a circular dependency so we use parent as a surrogate return False - elif hasattr(type(self), "state") and getattr(type(self), "state").fset is not None: + if hasattr(type(self), "state") and type(self).state.fset is not None: # Modern case where a stream manages state using getter/setter return True - else: - # Legacy case where the CDK manages state via the get_updated_state() method. This is determined by checking if - # the stream's get_updated_state() differs from the Stream class and therefore has been overridden - return type(self).get_updated_state != Stream.get_updated_state + # Legacy case where the CDK manages state via the get_updated_state() method. This is determined by checking if + # the stream's get_updated_state() differs from the Stream class and therefore has been overridden + return type(self).get_updated_state != Stream.get_updated_state - def _wrapped_cursor_field(self) -> List[str]: + def _wrapped_cursor_field(self) -> list[str]: return [self.cursor_field] if isinstance(self.cursor_field, str) else self.cursor_field @property - def cursor_field(self) -> Union[str, List[str]]: + def cursor_field(self) -> str | list[str]: """ Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field. :return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor. @@ -368,7 +368,7 @@ def cursor_field(self) -> Union[str, List[str]]: return [] @property - def namespace(self) -> Optional[str]: + def namespace(self) -> str | None: """ Override to return the namespace of this stream, e.g. the Postgres schema which this stream will emit records for. :return: A string containing the name of the namespace. 
@@ -394,7 +394,7 @@ def exit_on_rate_limit(self, value: bool) -> None: @property @abstractmethod - def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: + def primary_key(self) -> str | list[str] | list[list[str]] | None: """ :return: string if single primary key, list of strings if composite primary key, list of list of strings if composite primary key consisting of nested fields. If the stream has no primary keys, return None. @@ -403,10 +403,10 @@ def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: def stream_slices( self, *, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_state: Mapping[str, Any] | None = None, # noqa: ARG002 + ) -> Iterable[Mapping[str, Any] | None]: """ Override to define the slices for this stream. See the stream slicing section of the docs for more information. @@ -418,7 +418,7 @@ def stream_slices( yield StreamSlice(partition={}, cursor_slice={}) @property - def state_checkpoint_interval(self) -> Optional[int]: + def state_checkpoint_interval(self) -> int | None: """ Decides how often to checkpoint state (i.e: emit a STATE message). E.g: if this returns a value of 100, then state is persisted after reading 100 records, then 200, 300, etc.. A good default value is 1000 although your mileage may vary depending on the underlying data source. @@ -438,7 +438,9 @@ def state_checkpoint_interval(self) -> Optional[int]: # "Please use explicit state property instead, see `IncrementalMixin` docs." # ) def get_updated_state( - self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any] + self, + current_stream_state: MutableMapping[str, Any], # noqa: ARG002 + latest_record: Mapping[str, Any], # noqa: ARG002 ) -> MutableMapping[str, Any]: """DEPRECATED. Please use explicit state property instead, see `IncrementalMixin` docs. @@ -455,7 +457,7 @@ def get_updated_state( """ return {} - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: """ A Cursor is an interface that a stream can implement to manage how its internal state is read and updated while reading records. Historically, Python connectors had no concept of a cursor to manage state. 
Python streams need @@ -465,14 +467,14 @@ def get_cursor(self) -> Optional[Cursor]: def _get_checkpoint_reader( self, - logger: logging.Logger, - cursor_field: Optional[List[str]], + logger: logging.Logger, # noqa: ARG002 + cursor_field: list[str] | None, sync_mode: SyncMode, stream_state: MutableMapping[str, Any], ) -> CheckpointReader: mappings_or_slices = self.stream_slices( cursor_field=cursor_field, - sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior + sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior stream_state=stream_state, ) @@ -505,35 +507,33 @@ def _get_checkpoint_reader( return LegacyCursorBasedCheckpointReader( stream_slices=slices_iterable_copy, cursor=cursor, read_state_from_cursor=True ) - elif cursor: + if cursor: return CursorBasedCheckpointReader( stream_slices=slices_iterable_copy, cursor=cursor, read_state_from_cursor=checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH, ) - elif checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH: + if checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH: # Resumable full refresh readers rely on the stream state dynamically being updated during pagination and does # not iterate over a static set of slices. return ResumableFullRefreshCheckpointReader(stream_state=stream_state) - elif checkpoint_mode == CheckpointMode.INCREMENTAL: + if checkpoint_mode == CheckpointMode.INCREMENTAL: return IncrementalCheckpointReader( stream_slices=slices_iterable_copy, stream_state=stream_state ) - else: - return FullRefreshCheckpointReader(stream_slices=slices_iterable_copy) + return FullRefreshCheckpointReader(stream_slices=slices_iterable_copy) @property def _checkpoint_mode(self) -> CheckpointMode: if self.is_resumable and len(self._wrapped_cursor_field()) > 0: return CheckpointMode.INCREMENTAL - elif self.is_resumable: + if self.is_resumable: return CheckpointMode.RESUMABLE_FULL_REFRESH - else: - return CheckpointMode.FULL_REFRESH + return CheckpointMode.FULL_REFRESH @staticmethod def _classify_stream( - mappings_or_slices: Iterator[Optional[Union[Mapping[str, Any], StreamSlice]]], + mappings_or_slices: Iterator[Mapping[str, Any] | StreamSlice | None], ) -> StreamClassification: """ This is a bit of a crazy solution, but also the only way we can detect certain attributes about the stream since Python @@ -601,8 +601,8 @@ def log_stream_sync_configuration(self) -> None: @staticmethod def _wrapped_primary_key( - keys: Optional[Union[str, List[str], List[List[str]]]], - ) -> Optional[List[List[str]]]: + keys: str | list[str] | list[list[str]] | None, + ) -> list[list[str]] | None: """ :return: wrap the primary_key property in a list of list of strings required by the Airbyte Stream object. """ @@ -611,7 +611,7 @@ def _wrapped_primary_key( if isinstance(keys, str): return [[keys]] - elif isinstance(keys, list): + if isinstance(keys, list): wrapped_keys = [] for component in keys: if isinstance(component, str): @@ -619,13 +619,12 @@ def _wrapped_primary_key( elif isinstance(component, list): wrapped_keys.append(component) else: - raise ValueError(f"Element must be either list or str. Got: {type(component)}") + raise ValueError(f"Element must be either list or str. Got: {type(component)}") # noqa: TRY004 return wrapped_keys - else: - raise ValueError(f"Element must be either list or str. Got: {type(keys)}") + raise ValueError(f"Element must be either list or str. 
Got: {type(keys)}") def _observe_state( - self, checkpoint_reader: CheckpointReader, stream_state: Optional[Mapping[str, Any]] = None + self, checkpoint_reader: CheckpointReader, stream_state: Mapping[str, Any] | None = None ) -> None: """ Convenience method that attempts to read the Stream's state using the recommended way of connector's managing their @@ -652,15 +651,15 @@ def _observe_state( def _checkpoint_state( # type: ignore # ignoring typing for ConnectorStateManager because of circular dependencies self, stream_state: Mapping[str, Any], - state_manager, + state_manager, # noqa: ANN001 ) -> AirbyteMessage: - # todo: This can be consolidated into one ConnectorStateManager.update_and_create_state_message() method, but I want + # TODO: This can be consolidated into one ConnectorStateManager.update_and_create_state_message() method, but I want # to reduce changes right now and this would span concurrent as well state_manager.update_state_for_stream(self.name, self.namespace, stream_state) return state_manager.create_state_message(self.name, self.namespace) # type: ignore [no-any-return] @property - def configured_json_schema(self) -> Optional[Dict[str, Any]]: + def configured_json_schema(self) -> dict[str, Any] | None: """ This property is set from the read method. @@ -669,12 +668,12 @@ def configured_json_schema(self) -> Optional[Dict[str, Any]]: return self._configured_json_schema @configured_json_schema.setter - def configured_json_schema(self, json_schema: Dict[str, Any]) -> None: + def configured_json_schema(self, json_schema: dict[str, Any]) -> None: self._configured_json_schema = self._filter_schema_invalid_properties(json_schema) def _filter_schema_invalid_properties( - self, configured_catalog_json_schema: Dict[str, Any] - ) -> Dict[str, Any]: + self, configured_catalog_json_schema: dict[str, Any] + ) -> dict[str, Any]: """ Filters the properties in json_schema that are not present in the stream schema. Configured Schemas can have very old fields, so we need to housekeeping ourselves. diff --git a/airbyte_cdk/sources/streams/http/__init__.py b/airbyte_cdk/sources/streams/http/__init__.py index 74804614c..70342f5d0 100644 --- a/airbyte_cdk/sources/streams/http/__init__.py +++ b/airbyte_cdk/sources/streams/http/__init__.py @@ -7,4 +7,5 @@ from .http import HttpStream, HttpSubStream from .http_client import HttpClient + __all__ = ["HttpClient", "HttpStream", "HttpSubStream", "UserDefinedBackoffException"] diff --git a/airbyte_cdk/sources/streams/http/availability_strategy.py b/airbyte_cdk/sources/streams/http/availability_strategy.py index 494fcf151..8adcc89e1 100644 --- a/airbyte_cdk/sources/streams/http/availability_strategy.py +++ b/airbyte_cdk/sources/streams/http/availability_strategy.py @@ -4,20 +4,24 @@ import logging import typing -from typing import Optional, Tuple +from typing import Optional from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy from airbyte_cdk.utils.traced_exception import AirbyteTracedException + if typing.TYPE_CHECKING: from airbyte_cdk.sources import Source class HttpAvailabilityStrategy(AvailabilityStrategy): def check_availability( - self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None - ) -> Tuple[bool, Optional[str]]: + self, + stream: Stream, + logger: logging.Logger, + source: Optional["Source"] = None, # noqa: ARG002 + ) -> tuple[bool, str | None]: """ Check stream availability by attempting to read the first record of the stream. 
@@ -30,7 +34,7 @@ def check_availability( for some reason and the str should describe what went wrong and how to resolve the unavailability, if possible. """ - reason: Optional[str] + reason: str | None try: # Some streams need a stream slice to read records (e.g. if they have a SubstreamPartitionRouter) # Streams that don't need a stream slice will return `None` as their first stream slice. @@ -46,7 +50,7 @@ def check_availability( try: self.get_first_record_for_slice(stream, stream_slice) - return True, None + return True, None # noqa: TRY300 except StopIteration: logger.info(f"Successfully connected to stream {stream.name}, but got 0 records.") return True, None diff --git a/airbyte_cdk/sources/streams/http/error_handlers/__init__.py b/airbyte_cdk/sources/streams/http/error_handlers/__init__.py index 1f97d5cc7..9e65db1d6 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/__init__.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/__init__.py @@ -10,6 +10,7 @@ from .json_error_message_parser import JsonErrorMessageParser from .response_models import ErrorResolution, ResponseAction + __all__ = [ "BackoffStrategy", "DefaultBackoffStrategy", diff --git a/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py b/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py index 6ed821791..c30992627 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py @@ -3,7 +3,6 @@ # from abc import ABC, abstractmethod -from typing import Optional, Union import requests @@ -12,9 +11,9 @@ class BackoffStrategy(ABC): @abstractmethod def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], + response_or_exception: requests.Response | requests.RequestException | None, attempt_count: int, - ) -> Optional[float]: + ) -> float | None: """ Override this method to dynamically determine backoff time e.g: by reading the X-Retry-After header. diff --git a/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py b/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py index 2c3e10ad7..d8ae1f3ee 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py @@ -1,8 +1,6 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. -from typing import Optional, Union - import requests from .backoff_strategy import BackoffStrategy @@ -11,7 +9,7 @@ class DefaultBackoffStrategy(BackoffStrategy): def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], - attempt_count: int, - ) -> Optional[float]: + response_or_exception: requests.Response | requests.RequestException | None, # noqa: ARG002 + attempt_count: int, # noqa: ARG002 + ) -> float | None: return None diff --git a/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py b/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py index 74840e2d2..7720f3eff 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py @@ -2,7 +2,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# -from typing import Mapping, Type, Union +from collections.abc import Mapping from requests.exceptions import InvalidSchema, InvalidURL, RequestException @@ -12,7 +12,8 @@ ResponseAction, ) -DEFAULT_ERROR_MAPPING: Mapping[Union[int, str, Type[Exception]], ErrorResolution] = { + +DEFAULT_ERROR_MAPPING: Mapping[int | str | type[Exception], ErrorResolution] = { InvalidSchema: ErrorResolution( response_action=ResponseAction.FAIL, failure_type=FailureType.config_error, diff --git a/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py b/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py index b231e72e0..7af046202 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py @@ -1,7 +1,6 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. from abc import ABC, abstractmethod -from typing import Optional, Union import requests @@ -15,7 +14,7 @@ class ErrorHandler(ABC): @property @abstractmethod - def max_retries(self) -> Optional[int]: + def max_retries(self) -> int | None: """ The maximum number of retries to attempt before giving up. """ @@ -23,16 +22,14 @@ def max_retries(self) -> Optional[int]: @property @abstractmethod - def max_time(self) -> Optional[int]: + def max_time(self) -> int | None: """ The maximum amount of time in seconds to retry before giving up. """ pass @abstractmethod - def interpret_response( - self, response: Optional[Union[requests.Response, Exception]] - ) -> ErrorResolution: + def interpret_response(self, response: requests.Response | Exception | None) -> ErrorResolution: """ Interpret the response or exception and return the corresponding response action, failure type, and error message. diff --git a/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py b/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py index 966fe93a1..d5d413f3d 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py @@ -3,14 +3,13 @@ # from abc import ABC, abstractmethod -from typing import Optional import requests class ErrorMessageParser(ABC): @abstractmethod - def parse_response_error_message(self, response: requests.Response) -> Optional[str]: + def parse_response_error_message(self, response: requests.Response) -> str | None: """ Parse error message from response. 
:param response: response received for the request diff --git a/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py b/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py index 18daca3de..9bd47df66 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py @@ -3,8 +3,8 @@ # import logging +from collections.abc import Mapping from datetime import timedelta -from typing import Mapping, Optional, Union import requests @@ -23,7 +23,7 @@ class HttpStatusErrorHandler(ErrorHandler): def __init__( self, logger: logging.Logger, - error_mapping: Optional[Mapping[Union[int, str, type[Exception]], ErrorResolution]] = None, + error_mapping: Mapping[int | str | type[Exception], ErrorResolution] | None = None, max_retries: int = 5, max_time: timedelta = timedelta(seconds=600), ) -> None: @@ -38,15 +38,15 @@ def __init__( self._max_time = int(max_time.total_seconds()) @property - def max_retries(self) -> Optional[int]: + def max_retries(self) -> int | None: return self._max_retries @property - def max_time(self) -> Optional[int]: + def max_time(self) -> int | None: return self._max_time - def interpret_response( - self, response_or_exception: Optional[Union[requests.Response, Exception]] = None + def interpret_response( # noqa: PLR0911 + self, response_or_exception: requests.Response | Exception | None = None ) -> ErrorResolution: """ Interpret the response and return the corresponding response action, failure type, and error message. @@ -56,23 +56,20 @@ def interpret_response( """ if isinstance(response_or_exception, Exception): - mapped_error: Optional[ErrorResolution] = self._error_mapping.get( + mapped_error: ErrorResolution | None = self._error_mapping.get( response_or_exception.__class__ ) if mapped_error is not None: return mapped_error - else: - self._logger.error( - f"Unexpected exception in error handler: {response_or_exception}" - ) - return ErrorResolution( - response_action=ResponseAction.RETRY, - failure_type=FailureType.system_error, - error_message=f"Unexpected exception in error handler: {response_or_exception}", - ) + self._logger.error(f"Unexpected exception in error handler: {response_or_exception}") + return ErrorResolution( + response_action=ResponseAction.RETRY, + failure_type=FailureType.system_error, + error_message=f"Unexpected exception in error handler: {response_or_exception}", + ) - elif isinstance(response_or_exception, requests.Response): + if isinstance(response_or_exception, requests.Response): if response_or_exception.status_code is None: self._logger.error("Response does not include an HTTP status code.") return ErrorResolution( @@ -94,17 +91,15 @@ def interpret_response( if mapped_error is not None: return mapped_error - else: - self._logger.warning(f"Unexpected HTTP Status Code in error handler: '{error_key}'") - return ErrorResolution( - response_action=ResponseAction.RETRY, - failure_type=FailureType.system_error, - error_message=f"Unexpected HTTP Status Code in error handler: {error_key}", - ) - else: - self._logger.error(f"Received unexpected response type: {type(response_or_exception)}") + self._logger.warning(f"Unexpected HTTP Status Code in error handler: '{error_key}'") return ErrorResolution( - response_action=ResponseAction.FAIL, + response_action=ResponseAction.RETRY, failure_type=FailureType.system_error, - error_message=f"Received unexpected response type: {type(response_or_exception)}", + 
error_message=f"Unexpected HTTP Status Code in error handler: {error_key}", ) + self._logger.error(f"Received unexpected response type: {type(response_or_exception)}") + return ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.system_error, + error_message=f"Received unexpected response type: {type(response_or_exception)}", + ) diff --git a/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py b/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py index 7c58280c7..d41e3e70d 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py @@ -2,7 +2,6 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Optional import requests @@ -11,13 +10,13 @@ class JsonErrorMessageParser(ErrorMessageParser): - def _try_get_error(self, value: Optional[JsonType]) -> Optional[str]: + def _try_get_error(self, value: JsonType | None) -> str | None: if isinstance(value, str): return value - elif isinstance(value, list): + if isinstance(value, list): errors_in_value = [self._try_get_error(v) for v in value] return ", ".join(v for v in errors_in_value if v is not None) - elif isinstance(value, dict): + if isinstance(value, dict): new_value = ( value.get("message") or value.get("messages") @@ -35,7 +34,7 @@ def _try_get_error(self, value: Optional[JsonType]) -> Optional[str]: return self._try_get_error(new_value) return None - def parse_response_error_message(self, response: requests.Response) -> Optional[str]: + def parse_response_error_message(self, response: requests.Response) -> str | None: """ Parses the raw response object from a failed request into a user-friendly error message. diff --git a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py index e882b89bd..8a06c230a 100644 --- a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py +++ b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional, Union import requests from requests import HTTPError @@ -21,13 +20,13 @@ class ResponseAction(Enum): @dataclass class ErrorResolution: - response_action: Optional[ResponseAction] = None - failure_type: Optional[FailureType] = None - error_message: Optional[str] = None + response_action: ResponseAction | None = None + failure_type: FailureType | None = None + error_message: str | None = None def _format_exception_error_message(exception: Exception) -> str: - return f"{type(exception).__name__}: {str(exception)}" + return f"{type(exception).__name__}: {exception!s}" def _format_response_error_message(response: requests.Response) -> str: @@ -35,7 +34,7 @@ def _format_response_error_message(response: requests.Response) -> str: response.raise_for_status() except HTTPError as exception: return filter_secrets( - f"Response was not ok: `{str(exception)}`. Response content is: {response.text}" + f"Response was not ok: `{exception!s}`. Response content is: {response.text}" ) # We purposefully do not add the response.content because the response is "ok" so there might be sensitive information in the payload. 
# Feel free the @@ -43,7 +42,7 @@ def _format_response_error_message(response: requests.Response) -> str: def create_fallback_error_resolution( - response_or_exception: Optional[Union[requests.Response, Exception]], + response_or_exception: requests.Response | Exception | None, ) -> ErrorResolution: if response_or_exception is None: # We do not expect this case to happen but if it does, it would be good to understand the cause and improve the error message diff --git a/airbyte_cdk/sources/streams/http/exceptions.py b/airbyte_cdk/sources/streams/http/exceptions.py index ee4687626..a3cf57689 100644 --- a/airbyte_cdk/sources/streams/http/exceptions.py +++ b/airbyte_cdk/sources/streams/http/exceptions.py @@ -3,16 +3,14 @@ # -from typing import Optional, Union - import requests class BaseBackoffException(requests.exceptions.HTTPError): - def __init__( + def __init__( # noqa: ANN204 self, request: requests.PreparedRequest, - response: Optional[Union[requests.Response, Exception]], + response: requests.Response | Exception | None, error_message: str = "", ): if isinstance(response, requests.Response): @@ -37,11 +35,11 @@ class UserDefinedBackoffException(BaseBackoffException): An exception that exposes how long it attempted to backoff """ - def __init__( + def __init__( # noqa: ANN204 self, - backoff: Union[int, float], + backoff: int | float, # noqa: PYI041 request: requests.PreparedRequest, - response: Optional[Union[requests.Response, Exception]], + response: requests.Response | Exception | None, error_message: str = "", ): """ diff --git a/airbyte_cdk/sources/streams/http/http.py b/airbyte_cdk/sources/streams/http/http.py index 40eab27a3..a3522e26f 100644 --- a/airbyte_cdk/sources/streams/http/http.py +++ b/airbyte_cdk/sources/streams/http/http.py @@ -1,11 +1,12 @@ -# +# # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # import logging from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Mapping, MutableMapping from datetime import timedelta -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any from urllib.parse import urljoin import requests @@ -37,22 +38,23 @@ from airbyte_cdk.sources.types import Record, StreamSlice from airbyte_cdk.sources.utils.types import JsonType + # list of all possible HTTP methods which can be used for sending of request bodies BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH") -class HttpStream(Stream, CheckpointMixin, ABC): +class HttpStream(Stream, CheckpointMixin, ABC): # noqa: PLR0904 """ Base abstract class for an Airbyte Stream using the HTTP protocol. Basic building block for users building an Airbyte source for a HTTP API. """ source_defined_cursor = True # Most HTTP streams use a source defined cursor (i.e: the user can't configure it like on a SQL table) - page_size: Optional[int] = ( + page_size: int | None = ( None # Use this variable to define page size for API http requests with pagination support ) - def __init__( - self, authenticator: Optional[AuthBase] = None, api_budget: Optional[APIBudget] = None + def __init__( # noqa: ANN204 + self, authenticator: AuthBase | None = None, api_budget: APIBudget | None = None ): self._exit_on_rate_limit: bool = False self._http_client = HttpClient( @@ -135,7 +137,7 @@ def raise_on_http_errors(self) -> bool: "Deprecated as of CDK version 3.0.0. " "You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead." 
) - def max_retries(self) -> Union[int, None]: + def max_retries(self) -> int | None: """ Override if needed. Specifies maximum amount of retries for backoff policy. Return None for no limit. """ @@ -146,7 +148,7 @@ def max_retries(self) -> Union[int, None]: "Deprecated as of CDK version 3.0.0. " "You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead." ) - def max_time(self) -> Union[int, None]: + def max_time(self) -> int | None: """ Override if needed. Specifies maximum total waiting time (in seconds) for backoff policy. Return None for no limit. """ @@ -164,7 +166,7 @@ def retry_factor(self) -> float: return 5 @abstractmethod - def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]: + def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None: """ Override this method to define a pagination strategy. @@ -177,9 +179,9 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, def path( self, *, - stream_state: Optional[Mapping[str, Any]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> str: """ Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity" @@ -187,9 +189,9 @@ def path( def request_params( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> MutableMapping[str, Any]: """ Override this method to define the query parameters that should be set on an outgoing HTTP request given the inputs. @@ -200,9 +202,9 @@ def request_params( def request_headers( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: """ Override to return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method. @@ -211,10 +213,10 @@ def request_headers( def request_body_data( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Union[Mapping[str, Any], str]]: + stream_state: Mapping[str, Any] | None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 + ) -> Mapping[str, Any] | str | None: """ Override when creating POST/PUT/PATCH requests to populate the body of the request with a non-JSON payload. 
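The hunks above migrate the HttpStream request hooks (path, request_params, request_headers, request_body_data) to the `Mapping[str, Any] | None` parameter style. A minimal sketch of a subclass written against these signatures follows; the endpoint, field names, and pagination scheme are illustrative only.

from collections.abc import Iterable, Mapping, MutableMapping
from typing import Any

import requests

from airbyte_cdk.sources.streams.http import HttpStream


class Employees(HttpStream):
    # Hypothetical API; url_base and primary_key satisfy the abstract stream properties.
    url_base = "https://api.example.com/v1/"
    primary_key = "id"

    def path(
        self,
        *,
        stream_state: Mapping[str, Any] | None = None,
        stream_slice: Mapping[str, Any] | None = None,
        next_page_token: Mapping[str, Any] | None = None,
    ) -> str:
        return "employees"

    def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None:
        # Stop paginating when the (assumed) cursor field is absent.
        cursor = response.json().get("next_cursor")
        return {"cursor": cursor} if cursor else None

    def request_params(
        self,
        stream_state: Mapping[str, Any] | None,
        stream_slice: Mapping[str, Any] | None = None,
        next_page_token: Mapping[str, Any] | None = None,
    ) -> MutableMapping[str, Any]:
        # Pass the pagination cursor straight through as query parameters.
        return dict(next_page_token or {})

    def parse_response(
        self,
        response: requests.Response,
        *,
        stream_state: Mapping[str, Any],
        stream_slice: Mapping[str, Any] | None = None,
        next_page_token: Mapping[str, Any] | None = None,
    ) -> Iterable[Mapping[str, Any]]:
        # Assumes the payload nests records under a "data" key.
        yield from response.json().get("data", [])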
@@ -228,10 +230,10 @@ def request_body_data( def request_body_json( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: + stream_state: Mapping[str, Any] | None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 + ) -> Mapping[str, Any] | None: """ Override when creating POST/PUT/PATCH requests to populate the body of the request with a JSON payload. @@ -241,9 +243,9 @@ def request_body_json( def request_kwargs( self, - stream_state: Optional[Mapping[str, Any]], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_state: Mapping[str, Any] | None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002 + next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002 ) -> Mapping[str, Any]: """ Override to return a mapping of keyword arguments to be used when creating the HTTP request. @@ -258,8 +260,8 @@ def parse_response( response: requests.Response, *, stream_state: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, ) -> Iterable[Mapping[str, Any]]: """ Parses the raw response object into a list of records. @@ -271,7 +273,7 @@ def parse_response( :return: An iterable containing the parsed response """ - def get_backoff_strategy(self) -> Optional[Union[BackoffStrategy, List[BackoffStrategy]]]: + def get_backoff_strategy(self) -> BackoffStrategy | list[BackoffStrategy] | None: """ Used to initialize Adapter to avoid breaking changes. If Stream has a `backoff_time` method implementation, we know this stream uses old (pre-HTTPClient) backoff handlers and thus an adapter is needed. @@ -281,10 +283,9 @@ def get_backoff_strategy(self) -> Optional[Union[BackoffStrategy, List[BackoffSt """ if hasattr(self, "backoff_time"): return HttpStreamAdapterBackoffStrategy(self) - else: - return None + return None - def get_error_handler(self) -> Optional[ErrorHandler]: + def get_error_handler(self) -> ErrorHandler | None: """ Used to initialize Adapter to avoid breaking changes. If Stream has a `should_retry` method implementation, we know this stream uses old (pre-HTTPClient) error handlers and thus an adapter is needed. @@ -300,15 +301,14 @@ def get_error_handler(self) -> Optional[ErrorHandler]: max_time=timedelta(seconds=self.max_time or 0), ) return error_handler - else: - return None + return None @classmethod def _join_url(cls, url_base: str, path: str) -> str: return urljoin(url_base, path) @classmethod - def parse_response_error_message(cls, response: requests.Response) -> Optional[str]: + def parse_response_error_message(cls, response: requests.Response) -> str | None: """ Parses the raw response object from a failed request into a user-friendly error message. By default, this method tries to grab the error message from JSON responses by following common API patterns. Override to parse differently. 
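With the changes above, get_backoff_strategy and get_error_handler return `BackoffStrategy | list[BackoffStrategy] | None` and `ErrorHandler | None`, and only fall back to the legacy adapters when the deprecated backoff_time/should_retry overrides exist. The sketch below supplies new-style hooks on the Employees example from the previous note; RetryAfterBackoffStrategy is a hypothetical class, and the import paths are assumed to mirror the module layout referenced in this diff.

import logging
from datetime import timedelta

import requests

from airbyte_cdk.sources.streams.http.error_handlers import (
    BackoffStrategy,
    ErrorHandler,
    HttpStatusErrorHandler,
)


class RetryAfterBackoffStrategy(BackoffStrategy):
    """Honor a Retry-After header when the API provides one (illustrative only)."""

    def backoff_time(
        self,
        response_or_exception: requests.Response | requests.RequestException | None,
        attempt_count: int,
    ) -> float | None:
        if isinstance(response_or_exception, requests.Response):
            retry_after = response_or_exception.headers.get("Retry-After")
            return float(retry_after) if retry_after else None
        return None


class EmployeesWithRetries(Employees):
    def get_backoff_strategy(self) -> BackoffStrategy | list[BackoffStrategy] | None:
        return RetryAfterBackoffStrategy()

    def get_error_handler(self) -> ErrorHandler | None:
        return HttpStatusErrorHandler(
            logger=logging.getLogger("airbyte"),
            max_retries=3,
            max_time=timedelta(minutes=5),
        )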
@@ -318,13 +318,13 @@ def parse_response_error_message(cls, response: requests.Response) -> Optional[s """ # default logic to grab error from common fields - def _try_get_error(value: Optional[JsonType]) -> Optional[str]: + def _try_get_error(value: JsonType | None) -> str | None: if isinstance(value, str): return value - elif isinstance(value, list): + if isinstance(value, list): errors_in_value = [_try_get_error(v) for v in value] return ", ".join(v for v in errors_in_value if v is not None) - elif isinstance(value, dict): + if isinstance(value, dict): new_value = ( value.get("message") or value.get("messages") @@ -343,7 +343,7 @@ def _try_get_error(value: Optional[JsonType]) -> Optional[str]: except requests.exceptions.JSONDecodeError: return None - def get_error_display_message(self, exception: BaseException) -> Optional[str]: + def get_error_display_message(self, exception: BaseException) -> str | None: """ Retrieves the user-friendly display message that corresponds to an exception. This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage. @@ -360,15 +360,15 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]: def read_records( self, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: # A cursor_field indicates this is an incremental stream which offers better checkpointing than RFR enabled via the cursor if self.cursor_field or not isinstance(self.get_cursor(), ResumableFullRefreshCursor): yield from self._read_pages( - lambda req, res, state, _slice: self.parse_response( + lambda req, res, state, _slice: self.parse_response( # noqa: ARG005 res, stream_slice=_slice, stream_state=state ), stream_slice, @@ -376,7 +376,7 @@ def read_records( ) else: yield from self._read_single_page( - lambda req, res, state, _slice: self.parse_response( + lambda req, res, state, _slice: self.parse_response( # noqa: ARG005 res, stream_slice=_slice, stream_state=state ), stream_slice, @@ -397,7 +397,7 @@ def state(self, value: MutableMapping[str, Any]) -> None: cursor.set_initial_state(value) self._state = value - def get_cursor(self) -> Optional[Cursor]: + def get_cursor(self) -> Cursor | None: # I don't love that this is semi-stateful but not sure what else to do. We don't know exactly what type of cursor to # instantiate when creating the class. We can make a few assumptions like if there is a cursor_field which implies # incremental, but we don't know until runtime if this is a substream. 
Ideally, a stream should explicitly define @@ -406,8 +406,7 @@ def get_cursor(self) -> Optional[Cursor]: if self.has_multiple_slices and isinstance(self.cursor, ResumableFullRefreshCursor): self.cursor = SubstreamResumableFullRefreshCursor() return self.cursor - else: - return self.cursor + return self.cursor def _read_pages( self, @@ -416,12 +415,12 @@ def _read_pages( requests.PreparedRequest, requests.Response, Mapping[str, Any], - Optional[Mapping[str, Any]], + Mapping[str, Any] | None, ], Iterable[StreamData], ], - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: partition, _, _ = self._extract_slice_fields(stream_slice=stream_slice) @@ -452,12 +451,12 @@ def _read_single_page( requests.PreparedRequest, requests.Response, Mapping[str, Any], - Optional[Mapping[str, Any]], + Mapping[str, Any] | None, ], Iterable[StreamData], ], - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, ) -> Iterable[StreamData]: partition, cursor_slice, remaining_slice = self._extract_slice_fields( stream_slice=stream_slice @@ -481,7 +480,7 @@ def _read_single_page( @staticmethod def _extract_slice_fields( - stream_slice: Optional[Mapping[str, Any]], + stream_slice: Mapping[str, Any] | None, ) -> tuple[Mapping[str, Any], Mapping[str, Any], Mapping[str, Any]]: if not stream_slice: return {}, {}, {} @@ -499,16 +498,16 @@ def _extract_slice_fields( remaining = { key: val for key, val in stream_slice.items() - if key != "partition" and key != "cursor_slice" + if key != "partition" and key != "cursor_slice" # noqa: PLR1714 } return partition, cursor_slice, remaining def _fetch_next_page( self, - stream_slice: Optional[Mapping[str, Any]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Tuple[requests.PreparedRequest, requests.Response]: + stream_slice: Mapping[str, Any] | None = None, + stream_state: Mapping[str, Any] | None = None, + next_page_token: Mapping[str, Any] | None = None, + ) -> tuple[requests.PreparedRequest, requests.Response]: request, response = self._http_client.send_request( http_method=self.http_method, url=self._join_url( @@ -551,7 +550,7 @@ def _fetch_next_page( return request, response - def get_log_formatter(self) -> Optional[Callable[[requests.Response], Any]]: + def get_log_formatter(self) -> Callable[[requests.Response], Any] | None: """ :return Optional[Callable[[requests.Response], Any]]: Function that will be used in logging inside HttpClient @@ -560,7 +559,7 @@ def get_log_formatter(self) -> Optional[Callable[[requests.Response], Any]]: class HttpSubStream(HttpStream, ABC): - def __init__(self, parent: HttpStream, **kwargs: Any): + def __init__(self, parent: HttpStream, **kwargs: Any): # noqa: ANN204, ANN401 """ :param parent: should be the instance of HttpStream class """ @@ -584,21 +583,21 @@ def __init__(self, parent: HttpStream, **kwargs: Any): def stream_slices( self, - sync_mode: SyncMode, - cursor_field: Optional[List[str]] = None, - stream_state: Optional[Mapping[str, Any]] = None, - ) -> Iterable[Optional[Mapping[str, Any]]]: + sync_mode: SyncMode, # noqa: ARG002 + cursor_field: list[str] | None = None, # noqa: ARG002 + stream_state: Mapping[str, Any] | None = None, + ) -> 
Iterable[Mapping[str, Any] | None]: # read_stateless() assumes the parent is not concurrent. This is currently okay since the concurrent CDK does # not support either substreams or RFR, but something that needs to be considered once we do for parent_record in self.parent.read_only_records(stream_state): # Skip non-records (eg AirbyteLogMessage) if isinstance(parent_record, AirbyteMessage): if parent_record.type == MessageType.RECORD: - parent_record = parent_record.record.data # type: ignore [assignment, union-attr] # Incorrect type for assignment + parent_record = parent_record.record.data # type: ignore [assignment, union-attr] # Incorrect type for assignment # noqa: PLW2901 else: continue elif isinstance(parent_record, Record): - parent_record = parent_record.data + parent_record = parent_record.data # noqa: PLW2901 yield {"parent": parent_record} @@ -607,15 +606,15 @@ def stream_slices( "You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead." ) class HttpStreamAdapterBackoffStrategy(BackoffStrategy): - def __init__(self, stream: HttpStream): + def __init__(self, stream: HttpStream): # noqa: ANN204 self.stream = stream def backoff_time( self, - response_or_exception: Optional[Union[requests.Response, requests.RequestException]], - attempt_count: int, - ) -> Optional[float]: - return self.stream.backoff_time(response_or_exception) # type: ignore # noqa # HttpStream.backoff_time has been deprecated + response_or_exception: requests.Response | requests.RequestException | None, + attempt_count: int, # noqa: ARG002 + ) -> float | None: + return self.stream.backoff_time(response_or_exception) # type: ignore # HttpStream.backoff_time has been deprecated @deprecated( @@ -623,19 +622,19 @@ def backoff_time( "You should set error_handler explicitly in HttpStream.get_error_handler() instead." ) class HttpStreamAdapterHttpStatusErrorHandler(HttpStatusErrorHandler): - def __init__(self, stream: HttpStream, **kwargs): # type: ignore # noqa + def __init__(self, stream: HttpStream, **kwargs) -> None: # type: ignore self.stream = stream super().__init__(**kwargs) - def interpret_response( - self, response_or_exception: Optional[Union[requests.Response, Exception]] = None + def interpret_response( # noqa: PLR0911 + self, response_or_exception: requests.Response | Exception | None = None ) -> ErrorResolution: if isinstance(response_or_exception, Exception): return super().interpret_response(response_or_exception) - elif isinstance(response_or_exception, requests.Response): - should_retry = self.stream.should_retry(response_or_exception) # type: ignore # noqa + if isinstance(response_or_exception, requests.Response): + should_retry = self.stream.should_retry(response_or_exception) # type: ignore if should_retry: - if response_or_exception.status_code == 429: + if response_or_exception.status_code == 429: # noqa: PLR2004 return ErrorResolution( response_action=ResponseAction.RATE_LIMITED, failure_type=FailureType.transient_error, @@ -646,29 +645,26 @@ def interpret_response( failure_type=FailureType.transient_error, error_message=f"Response status code: {response_or_exception.status_code}. 
Retrying...", ) - else: - if response_or_exception.ok: - return ErrorResolution( - response_action=ResponseAction.SUCCESS, - failure_type=None, - error_message=None, - ) - if self.stream.raise_on_http_errors: - return ErrorResolution( - response_action=ResponseAction.FAIL, - failure_type=FailureType.transient_error, - error_message=f"Response status code: {response_or_exception.status_code}. Unexpected error. Failed.", - ) - else: - return ErrorResolution( - response_action=ResponseAction.IGNORE, - failure_type=FailureType.transient_error, - error_message=f"Response status code: {response_or_exception.status_code}. Ignoring...", - ) - else: - self._logger.error(f"Received unexpected response type: {type(response_or_exception)}") + if response_or_exception.ok: + return ErrorResolution( + response_action=ResponseAction.SUCCESS, + failure_type=None, + error_message=None, + ) + if self.stream.raise_on_http_errors: + return ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.transient_error, + error_message=f"Response status code: {response_or_exception.status_code}. Unexpected error. Failed.", + ) return ErrorResolution( - response_action=ResponseAction.FAIL, - failure_type=FailureType.system_error, - error_message=f"Received unexpected response type: {type(response_or_exception)}", + response_action=ResponseAction.IGNORE, + failure_type=FailureType.transient_error, + error_message=f"Response status code: {response_or_exception.status_code}. Ignoring...", ) + self._logger.error(f"Received unexpected response type: {type(response_or_exception)}") + return ErrorResolution( + response_action=ResponseAction.FAIL, + failure_type=FailureType.system_error, + error_message=f"Received unexpected response type: {type(response_or_exception)}", + ) diff --git a/airbyte_cdk/sources/streams/http/http_client.py b/airbyte_cdk/sources/streams/http/http_client.py index c4fa86866..bfca42a71 100644 --- a/airbyte_cdk/sources/streams/http/http_client.py +++ b/airbyte_cdk/sources/streams/http/http_client.py @@ -5,8 +5,9 @@ import logging import os import urllib +from collections.abc import Callable, Mapping from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any import orjson import requests @@ -53,6 +54,7 @@ ) from airbyte_cdk.utils.traced_exception import AirbyteTracedException + BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH") @@ -68,7 +70,7 @@ class MessageRepresentationAirbyteTracedErrors(AirbyteTracedException): def __str__(self) -> str: if self.message: return self.message - elif self.internal_message: + if self.internal_message: return self.internal_message return "" @@ -76,21 +78,21 @@ def __str__(self) -> str: class HttpClient: _DEFAULT_MAX_RETRY: int = 5 _DEFAULT_MAX_TIME: int = 60 * 10 - _ACTIONS_TO_RETRY_ON = {ResponseAction.RETRY, ResponseAction.RATE_LIMITED} + _ACTIONS_TO_RETRY_ON = {ResponseAction.RETRY, ResponseAction.RATE_LIMITED} # noqa: RUF012 - def __init__( + def __init__( # noqa: ANN204, PLR0913, PLR0917 self, name: str, logger: logging.Logger, - error_handler: Optional[ErrorHandler] = None, - api_budget: Optional[APIBudget] = None, - session: Optional[Union[requests.Session, requests_cache.CachedSession]] = None, - authenticator: Optional[AuthBase] = None, - use_cache: bool = False, - backoff_strategy: Optional[Union[BackoffStrategy, List[BackoffStrategy]]] = None, - error_message_parser: Optional[ErrorMessageParser] = None, - disable_retries: bool = False, - message_repository: 
Optional[MessageRepository] = None, + error_handler: ErrorHandler | None = None, + api_budget: APIBudget | None = None, + session: requests.Session | requests_cache.CachedSession | None = None, + authenticator: AuthBase | None = None, + use_cache: bool = False, # noqa: FBT001, FBT002 + backoff_strategy: BackoffStrategy | list[BackoffStrategy] | None = None, + error_message_parser: ErrorMessageParser | None = None, + disable_retries: bool = False, # noqa: FBT001, FBT002 + message_repository: MessageRepository | None = None, ): self._name = name self._api_budget: APIBudget = api_budget or APIBudget(policies=[]) @@ -117,7 +119,7 @@ def __init__( else: self._backoff_strategies = [DefaultBackoffStrategy()] self._error_message_parser = error_message_parser or JsonErrorMessageParser() - self._request_attempt_count: Dict[requests.PreparedRequest, int] = {} + self._request_attempt_count: dict[requests.PreparedRequest, int] = {} self._disable_retries = disable_retries self._message_repository = message_repository @@ -155,8 +157,7 @@ def _request_session(self) -> requests.Session: return CachedLimiterSession( sqlite_path, backend=backend, api_budget=self._api_budget, match_headers=True ) - else: - return LimiterSession(api_budget=self._api_budget) + return LimiterSession(api_budget=self._api_budget) def clear_cache(self) -> None: """ @@ -165,9 +166,7 @@ def clear_cache(self) -> None: if isinstance(self._session, requests_cache.CachedSession): self._session.cache.clear() # type: ignore # cache.clear is not typed - def _dedupe_query_params( - self, url: str, params: Optional[Mapping[str, str]] - ) -> Mapping[str, str]: + def _dedupe_query_params(self, url: str, params: Mapping[str, str] | None) -> Mapping[str, str]: """ Remove query parameters from params mapping if they are already encoded in the URL. :param url: URL with @@ -180,7 +179,7 @@ def _dedupe_query_params( query_dict = {k: v[0] for k, v in urllib.parse.parse_qs(query_string).items()} duplicate_keys_with_same_value = { - k for k in query_dict.keys() if str(params.get(k)) == str(query_dict[k]) + k for k in query_dict if str(params.get(k)) == str(query_dict[k]) } return {k: v for k, v in params.items() if k not in duplicate_keys_with_same_value} @@ -188,11 +187,11 @@ def _create_prepared_request( self, http_method: str, url: str, - dedupe_query_params: bool = False, - headers: Optional[Mapping[str, str]] = None, - params: Optional[Mapping[str, str]] = None, - json: Optional[Mapping[str, Any]] = None, - data: Optional[Union[str, Mapping[str, Any]]] = None, + dedupe_query_params: bool = False, # noqa: FBT001, FBT002 + headers: Mapping[str, str] | None = None, + params: Mapping[str, str] | None = None, + json: Mapping[str, Any] | None = None, + data: str | Mapping[str, Any] | None = None, ) -> requests.PreparedRequest: if dedupe_query_params: query_params = self._dedupe_query_params(url, params) @@ -204,7 +203,7 @@ def _create_prepared_request( raise RequestBodyException( "At the same time only one of the 'request_body_data' and 'request_body_json' functions can return data" ) - elif json: + if json: args["json"] = json elif data: args["data"] = data @@ -220,7 +219,7 @@ def _max_retries(self) -> int: Determines the max retries based on the provided error handler. 
""" max_retries = None - if self._disable_retries: + if self._disable_retries: # noqa: SIM108 max_retries = 0 else: max_retries = self._error_handler.max_retries @@ -241,8 +240,8 @@ def _send_with_retry( self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any], - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - exit_on_rate_limit: Optional[bool] = False, + log_formatter: Callable[[requests.Response], Any] | None = None, + exit_on_rate_limit: bool | None = False, # noqa: FBT001, FBT002 ) -> requests.Response: """ Sends a request with retry logic. @@ -280,8 +279,8 @@ def _send( self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any], - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - exit_on_rate_limit: Optional[bool] = False, + log_formatter: Callable[[requests.Response], Any] | None = None, + exit_on_rate_limit: bool | None = False, # noqa: FBT001, FBT002 ) -> requests.Response: if request not in self._request_attempt_count: self._request_attempt_count[request] = 1 @@ -295,8 +294,8 @@ def _send( extra={"headers": request.headers, "url": request.url, "request_body": request.body}, ) - response: Optional[requests.Response] = None - exc: Optional[requests.RequestException] = None + response: requests.Response | None = None + exc: requests.RequestException | None = None try: response = self._session.send(request, **request_kwargs) @@ -347,7 +346,7 @@ def _send( return response # type: ignore # will either return a valid response of type requests.Response or raise an exception - def _get_response_body(self, response: requests.Response) -> Optional[JsonType]: + def _get_response_body(self, response: requests.Response) -> JsonType | None: """ Extracts and returns the body of an HTTP response. @@ -383,11 +382,11 @@ def _evict_key(self, prepared_request: requests.PreparedRequest) -> None: def _handle_error_resolution( self, - response: Optional[requests.Response], - exc: Optional[requests.RequestException], + response: requests.Response | None, + exc: requests.RequestException | None, request: requests.PreparedRequest, error_resolution: ErrorResolution, - exit_on_rate_limit: Optional[bool] = False, + exit_on_rate_limit: bool | None = False, # noqa: FBT001, FBT002 ) -> None: if error_resolution.response_action not in self._ACTIONS_TO_RETRY_ON: self._evict_key(request) @@ -413,7 +412,7 @@ def _handle_error_resolution( if error_resolution.response_action == ResponseAction.FAIL: if response is not None: filtered_response_message = filter_secrets( - f"Request (body): '{str(request.body)}'. Response (body): '{self._get_response_body(response)}'. Response (headers): '{response.headers}'." + f"Request (body): '{request.body!s}'. Response (body): '{self._get_response_body(response)}'. Response (headers): '{response.headers}'." ) error_message = f"'{request.method}' request to '{request.url}' failed with status code '{response.status_code}' and error message: '{self._error_message_parser.parse_response_error_message(response)}'. 
{filtered_response_message}" else: @@ -430,7 +429,7 @@ def _handle_error_resolution( failure_type=error_resolution.failure_type, ) - elif error_resolution.response_action == ResponseAction.IGNORE: + if error_resolution.response_action == ResponseAction.IGNORE: if response is not None: log_message = f"Ignoring response for '{request.method}' request to '{request.url}' with response code '{response.status_code}'" else: @@ -440,7 +439,7 @@ def _handle_error_resolution( # TODO: Consider dynamic retry count depending on subsequent error codes elif ( - error_resolution.response_action == ResponseAction.RETRY + error_resolution.response_action == ResponseAction.RETRY # noqa: PLR1714 or error_resolution.response_action == ResponseAction.RATE_LIMITED ): user_defined_backoff_time = None @@ -470,7 +469,7 @@ def _handle_error_resolution( error_message=error_message, ) - elif retry_endlessly: + if retry_endlessly: raise RateLimitBackoffException( request=request, response=(response if response is not None else exc), @@ -488,25 +487,25 @@ def _handle_error_resolution( response.raise_for_status() except requests.HTTPError as e: self._logger.error(response.text) - raise e + raise e # noqa: TRY201 @property def name(self) -> str: return self._name - def send_request( + def send_request( # noqa: PLR0913, PLR0917 self, http_method: str, url: str, request_kwargs: Mapping[str, Any], - headers: Optional[Mapping[str, str]] = None, - params: Optional[Mapping[str, str]] = None, - json: Optional[Mapping[str, Any]] = None, - data: Optional[Union[str, Mapping[str, Any]]] = None, - dedupe_query_params: bool = False, - log_formatter: Optional[Callable[[requests.Response], Any]] = None, - exit_on_rate_limit: Optional[bool] = False, - ) -> Tuple[requests.PreparedRequest, requests.Response]: + headers: Mapping[str, str] | None = None, + params: Mapping[str, str] | None = None, + json: Mapping[str, Any] | None = None, + data: str | Mapping[str, Any] | None = None, + dedupe_query_params: bool = False, # noqa: FBT001, FBT002 + log_formatter: Callable[[requests.Response], Any] | None = None, + exit_on_rate_limit: bool | None = False, # noqa: FBT001, FBT002 + ) -> tuple[requests.PreparedRequest, requests.Response]: """ Prepares and sends request and return request and response objects. 
""" diff --git a/airbyte_cdk/sources/streams/http/rate_limiting.py b/airbyte_cdk/sources/streams/http/rate_limiting.py index 926a7ad56..33a55d9e1 100644 --- a/airbyte_cdk/sources/streams/http/rate_limiting.py +++ b/airbyte_cdk/sources/streams/http/rate_limiting.py @@ -5,7 +5,8 @@ import logging import sys import time -from typing import Any, Callable, Mapping, Optional +from collections.abc import Callable, Mapping +from typing import Any import backoff from requests import PreparedRequest, RequestException, Response, codes, exceptions @@ -16,6 +17,7 @@ UserDefinedBackoffException, ) + TRANSIENT_EXCEPTIONS = ( DefaultBackoffException, exceptions.ConnectTimeout, @@ -31,7 +33,10 @@ def default_backoff_handler( - max_tries: Optional[int], factor: float, max_time: Optional[int] = None, **kwargs: Any + max_tries: int | None, + factor: float, + max_time: int | None = None, + **kwargs: Any, # noqa: ANN401 ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() @@ -40,17 +45,17 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None: f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" ) logger.info( - f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." + f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ) def should_give_up(exc: Exception) -> bool: # If a non-rate-limiting related 4XX error makes it this far, it means it was unexpected and probably consistent, so we shouldn't back off - if isinstance(exc, RequestException): + if isinstance(exc, RequestException): # noqa: SIM102 if exc.response is not None: give_up: bool = ( exc.response is not None and exc.response.status_code != codes.too_many_requests - and 400 <= exc.response.status_code < 500 + and 400 <= exc.response.status_code < 500 # noqa: PLR2004 ) if give_up: logger.info(f"Giving up for returned HTTP status: {exc.response.status_code!r}") @@ -72,7 +77,9 @@ def should_give_up(exc: Exception) -> bool: def http_client_default_backoff_handler( - max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any + max_tries: int | None, + max_time: int | None = None, + **kwargs: Any, # noqa: ANN401 ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() @@ -81,10 +88,10 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None: f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" ) logger.info( - f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." + f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." 
) - def should_give_up(exc: Exception) -> bool: + def should_give_up(exc: Exception) -> bool: # noqa: ARG001 # If made it here, the ResponseAction was RETRY and therefore should not give up return False @@ -101,9 +108,11 @@ def should_give_up(exc: Exception) -> bool: def user_defined_backoff_handler( - max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any + max_tries: int | None, + max_time: int | None = None, + **kwargs: Any, # noqa: ANN401 ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: - def sleep_on_ratelimit(details: Mapping[str, Any]) -> None: + def sleep_on_ratelimit(details: Mapping[str, Any]) -> None: # noqa: ARG001 _, exc, _ = sys.exc_info() if isinstance(exc, UserDefinedBackoffException): if exc.response: @@ -137,7 +146,7 @@ def log_give_up(details: Mapping[str, Any]) -> None: def rate_limit_default_backoff_handler( - **kwargs: Any, + **kwargs: Any, # noqa: ANN401 ) -> Callable[[SendRequestCallableType], SendRequestCallableType]: def log_retry_attempt(details: Mapping[str, Any]) -> None: _, exc, _ = sys.exc_info() @@ -146,7 +155,7 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None: f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}" ) logger.info( - f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." + f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ) return backoff.on_exception( # type: ignore # Decorator function returns a function with a different signature than the input function, so mypy can't infer the type of the returned function diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py b/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py index 307f91f40..e792ba0b4 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py @@ -5,6 +5,7 @@ from .oauth import Oauth2Authenticator, SingleUseRefreshTokenOauth2Authenticator from .token import BasicHttpAuthenticator, MultipleTokenAuthenticator, TokenAuthenticator + __all__ = [ "Oauth2Authenticator", "SingleUseRefreshTokenOauth2Authenticator", diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 753d79269..44b86940c 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -4,27 +4,28 @@ import logging from abc import abstractmethod +from collections.abc import Mapping, MutableMapping from json import JSONDecodeError -from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any import backoff import pendulum import requests from requests.auth import AuthBase +from ..exceptions import DefaultBackoffException # noqa: TID252 from airbyte_cdk.models import FailureType, Level from airbyte_cdk.sources.http_logger import format_http_message from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository from airbyte_cdk.utils import AirbyteTracedException from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets -from ..exceptions import DefaultBackoffException logger = logging.getLogger("airbyte") _NOOP_MESSAGE_REPOSITORY = NoopMessageRepository() -class AbstractOauth2Authenticator(AuthBase): +class 
AbstractOauth2Authenticator(AuthBase): # noqa: PLR0904 """ Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator is designed to generically perform the refresh flow without regard to how config fields are get/set by @@ -35,9 +36,9 @@ class AbstractOauth2Authenticator(AuthBase): def __init__( self, - refresh_token_error_status_codes: Tuple[int, ...] = (), + refresh_token_error_status_codes: tuple[int, ...] = (), refresh_token_error_key: str = "", - refresh_token_error_values: Tuple[str, ...] = (), + refresh_token_error_values: tuple[str, ...] = (), ) -> None: """ If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, @@ -104,7 +105,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None: """ headers = self.get_refresh_request_headers() - return headers if headers else None + return headers if headers else None # noqa: FURB110 def _wrap_refresh_token_exception( self, exception: requests.exceptions.RequestException @@ -130,7 +131,7 @@ def _wrap_refresh_token_exception( ), max_time=300, ) - def _get_refresh_access_token_response(self) -> Any: + def _get_refresh_access_token_response(self) -> Any: # noqa: ANN401 try: response = requests.request( method="POST", @@ -144,30 +145,29 @@ def _get_refresh_access_token_response(self) -> Any: # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen... access_key = response_json.get(self.get_access_token_name()) if not access_key: - raise Exception( + raise Exception( # noqa: TRY002, TRY301 "Token refresh API response was missing access token {self.get_access_token_name()}" ) add_to_secrets(access_key) self._log_response(response) return response_json - else: - # log the response even if the request failed for troubleshooting purposes - self._log_response(response) - response.raise_for_status() + # log the response even if the request failed for troubleshooting purposes + self._log_response(response) + response.raise_for_status() except requests.exceptions.RequestException as e: - if e.response is not None: - if e.response.status_code == 429 or e.response.status_code >= 500: - raise DefaultBackoffException(request=e.response.request, response=e.response) + if e.response is not None: # noqa: SIM102 + if e.response.status_code == 429 or e.response.status_code >= 500: # noqa: PLR2004 + raise DefaultBackoffException(request=e.response.request, response=e.response) # noqa: B904 if self._wrap_refresh_token_exception(e): message = "Refresh token is invalid or expired. Please re-authenticate from Sources//Settings." 
- raise AirbyteTracedException( + raise AirbyteTracedException( # noqa: B904 internal_message=message, message=message, failure_type=FailureType.config_error ) raise except Exception as e: - raise Exception(f"Error while refreshing access token: {e}") from e + raise Exception(f"Error while refreshing access token: {e}") from e # noqa: TRY002 - def refresh_access_token(self) -> Tuple[str, Union[str, int]]: + def refresh_access_token(self) -> tuple[str, str | int]: """ Returns the refresh token and its expiration datetime @@ -179,7 +179,7 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]: self.get_expires_in_name() ] - def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateTime: + def _parse_token_expiration_date(self, value: str | int) -> pendulum.DateTime: """ Return the expiration datetime of the refresh token @@ -192,8 +192,7 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateT f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required." ) return pendulum.from_format(str(value), self.token_expiry_date_format) - else: - return pendulum.now().add(seconds=int(float(value))) + return pendulum.now().add(seconds=int(float(value))) @property def token_expiry_is_time_of_expiration(self) -> bool: @@ -204,7 +203,7 @@ def token_expiry_is_time_of_expiration(self) -> bool: return False @property - def token_expiry_date_format(self) -> Optional[str]: + def token_expiry_date_format(self) -> str | None: """ Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires """ @@ -212,7 +211,7 @@ def token_expiry_date_format(self) -> Optional[str]: return None @abstractmethod - def get_token_refresh_endpoint(self) -> Optional[str]: + def get_token_refresh_endpoint(self) -> str | None: """Returns the endpoint to refresh the access token""" @abstractmethod @@ -236,11 +235,11 @@ def get_refresh_token_name(self) -> str: """The refresh token name to authenticate""" @abstractmethod - def get_refresh_token(self) -> Optional[str]: + def get_refresh_token(self) -> str | None: """The token used to refresh the access token when it expires""" @abstractmethod - def get_scopes(self) -> List[str]: + def get_scopes(self) -> list[str]: """List of requested scopes""" @abstractmethod @@ -248,7 +247,7 @@ def get_token_expiry_date(self) -> pendulum.DateTime: """Expiration date of the access token""" @abstractmethod - def set_token_expiry_date(self, value: Union[str, int]) -> None: + def set_token_expiry_date(self, value: str | int) -> None: """Setter for access token expiration date""" @abstractmethod @@ -286,7 +285,7 @@ def access_token(self, value: str) -> str: """Setter for the access token""" @property - def _message_repository(self) -> Optional[MessageRepository]: + def _message_repository(self) -> MessageRepository | None: """ The implementation can define a message_repository if it wants debugging logs for HTTP requests """ diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py index ffcc8e851..9afab38b4 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py @@ -3,7 +3,8 @@ # from abc import abstractmethod -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import requests from requests.auth 
import AuthBase @@ -12,7 +13,7 @@ class AbstractHeaderAuthenticator(AuthBase): """Abstract class for an header-based authenticators that add a header to outgoing HTTP requests.""" - def __call__(self, request: requests.PreparedRequest) -> Any: + def __call__(self, request: requests.PreparedRequest) -> Any: # noqa: ANN401 """Attach the HTTP headers required to authenticate on the HTTP request""" request.headers.update(self.get_auth_header()) return request diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index f244e6508..8333c56e8 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -2,7 +2,8 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Mapping, Sequence +from typing import Any import dpath import pendulum @@ -24,7 +25,7 @@ class Oauth2Authenticator(AbstractOauth2Authenticator): If a connector_config is provided any mutation of it's value in the scope of this class will emit AirbyteControlConnectorConfigMessage. """ - def __init__( + def __init__( # noqa: ANN204, PLR0913, PLR0917 self, token_refresh_endpoint: str, client_id: str, @@ -33,7 +34,7 @@ def __init__( client_id_name: str = "client_id", client_secret_name: str = "client_secret", refresh_token_name: str = "refresh_token", - scopes: List[str] | None = None, + scopes: list[str] | None = None, token_expiry_date: pendulum.DateTime | None = None, token_expiry_date_format: str | None = None, access_token_name: str = "access_token", @@ -42,10 +43,10 @@ def __init__( refresh_request_headers: Mapping[str, Any] | None = None, grant_type_name: str = "grant_type", grant_type: str = "refresh_token", - token_expiry_is_time_of_expiration: bool = False, - refresh_token_error_status_codes: Tuple[int, ...] = (), + token_expiry_is_time_of_expiration: bool = False, # noqa: FBT001, FBT002 + refresh_token_error_status_codes: tuple[int, ...] = (), refresh_token_error_key: str = "", - refresh_token_error_values: Tuple[str, ...] = (), + refresh_token_error_values: tuple[str, ...] = (), ): self._token_refresh_endpoint = token_refresh_endpoint self._client_secret_name = client_secret_name @@ -115,7 +116,7 @@ def get_grant_type(self) -> str: def get_token_expiry_date(self) -> pendulum.DateTime: return self._token_expiry_date - def set_token_expiry_date(self, value: Union[str, int]) -> None: + def set_token_expiry_date(self, value: str | int) -> None: self._token_expiry_date = self._parse_token_expiration_date(value) @property @@ -123,7 +124,7 @@ def token_expiry_is_time_of_expiration(self) -> bool: return self._token_expiry_is_time_of_expiration @property - def token_expiry_date_format(self) -> Optional[str]: + def token_expiry_date_format(self) -> str | None: return self._token_expiry_date_format @property @@ -145,11 +146,11 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator): client_secret_config_path, refresh_token_config_path constructor arguments. 
""" - def __init__( + def __init__( # noqa: ANN204, PLR0913, PLR0917 self, connector_config: Mapping[str, Any], token_refresh_endpoint: str, - scopes: List[str] | None = None, + scopes: list[str] | None = None, access_token_name: str = "access_token", expires_in_name: str = "expires_in", refresh_token_name: str = "refresh_token", @@ -158,18 +159,18 @@ def __init__( grant_type_name: str = "grant_type", grant_type: str = "refresh_token", client_id_name: str = "client_id", - client_id: Optional[str] = None, + client_id: str | None = None, client_secret_name: str = "client_secret", - client_secret: Optional[str] = None, + client_secret: str | None = None, access_token_config_path: Sequence[str] = ("credentials", "access_token"), refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"), token_expiry_date_config_path: Sequence[str] = ("credentials", "token_expiry_date"), - token_expiry_date_format: Optional[str] = None, - message_repository: MessageRepository = NoopMessageRepository(), - token_expiry_is_time_of_expiration: bool = False, - refresh_token_error_status_codes: Tuple[int, ...] = (), + token_expiry_date_format: str | None = None, + message_repository: MessageRepository = NoopMessageRepository(), # noqa: B008 + token_expiry_is_time_of_expiration: bool = False, # noqa: FBT001, FBT002 + refresh_token_error_status_codes: tuple[int, ...] = (), refresh_token_error_key: str = "", - refresh_token_error_values: Tuple[str, ...] = (), + refresh_token_error_values: tuple[str, ...] = (), ): """ Args: @@ -282,7 +283,7 @@ def get_token_expiry_date(self) -> pendulum.DateTime: self._token_expiry_date_config_path, default="", ) - return pendulum.now().subtract(days=1) if expiry_date == "" else pendulum.parse(expiry_date) # type: ignore [arg-type, return-value, no-untyped-call] + return pendulum.now().subtract(days=1) if expiry_date == "" else pendulum.parse(expiry_date) # type: ignore [arg-type, return-value, no-untyped-call] # noqa: PLC1901 def set_token_expiry_date( # type: ignore[override] self, @@ -305,8 +306,7 @@ def get_new_token_expiry_date( ) -> pendulum.DateTime: if token_expiry_date_format: return pendulum.from_format(access_token_expires_in, token_expiry_date_format) - else: - return pendulum.now("UTC").add(seconds=int(access_token_expires_in)) + return pendulum.now("UTC").add(seconds=int(access_token_expires_in)) def get_access_token(self) -> str: """Retrieve new access and refresh token if the access token has expired. 
@@ -324,7 +324,7 @@ def get_access_token(self) -> str: self.access_token = new_access_token self.set_refresh_token(new_refresh_token) self.set_token_expiry_date(new_token_expiry_date) - # FIXME emit_configuration_as_airbyte_control_message as been deprecated in favor of package airbyte_cdk.sources.message + # FIXME emit_configuration_as_airbyte_control_message as been deprecated in favor of package airbyte_cdk.sources.message # noqa: FIX001, TD001, TD004 # Usually, a class shouldn't care about the implementation details but to keep backward compatibility where we print the # message directly in the console, this is needed if not isinstance(self._message_repository, NoopMessageRepository): @@ -337,7 +337,7 @@ def get_access_token(self) -> str: def refresh_access_token( # type: ignore[override] # Signature doesn't match base class self, - ) -> Tuple[str, str, str]: + ) -> tuple[str, str, str]: response_json = self._get_refresh_access_token_response() return ( response_json[self.get_access_token_name()], diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/token.py b/airbyte_cdk/sources/streams/http/requests_native_auth/token.py index eec7fd0c5..bb7a79e60 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/token.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/token.py @@ -1,10 +1,9 @@ -# +# # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # import base64 from itertools import cycle -from typing import List from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import ( AbstractHeaderAuthenticator, @@ -26,8 +25,8 @@ def auth_header(self) -> str: def token(self) -> str: return f"{self._auth_method} {next(self._tokens_iter)}" - def __init__( - self, tokens: List[str], auth_method: str = "Bearer", auth_header: str = "Authorization" + def __init__( # noqa: ANN204 + self, tokens: list[str], auth_method: str = "Bearer", auth_header: str = "Authorization" ): self._auth_method = auth_method self._auth_header = auth_header @@ -49,7 +48,7 @@ def auth_header(self) -> str: def token(self) -> str: return f"{self._auth_method} {self._token}" - def __init__(self, token: str, auth_method: str = "Bearer", auth_header: str = "Authorization"): + def __init__(self, token: str, auth_method: str = "Bearer", auth_header: str = "Authorization"): # noqa: ANN204 self._auth_header = auth_header self._auth_method = auth_method self._token = token @@ -69,14 +68,14 @@ def auth_header(self) -> str: def token(self) -> str: return f"{self._auth_method} {self._token}" - def __init__( + def __init__( # noqa: ANN204 self, username: str, password: str = "", auth_method: str = "Basic", auth_header: str = "Authorization", ): - auth_string = f"{username}:{password}".encode("utf8") + auth_string = f"{username}:{password}".encode() b64_encoded = base64.b64encode(auth_string).decode("utf8") self._auth_header = auth_header self._auth_method = auth_method diff --git a/airbyte_cdk/sources/types.py b/airbyte_cdk/sources/types.py index 3c466ccd8..21bf39b17 100644 --- a/airbyte_cdk/sources/types.py +++ b/airbyte_cdk/sources/types.py @@ -1,28 +1,30 @@ -# +# # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # from __future__ import annotations -from typing import Any, ItemsView, Iterator, KeysView, List, Mapping, Optional, ValuesView +from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView +from typing import Any import orjson + # A FieldPointer designates a path to a field inside a mapping. 
For example, retrieving ["k1", "k1.2"] in the object {"k1" :{"k1.2": # "hello"}] returns "hello" -FieldPointer = List[str] +FieldPointer = list[str] Config = Mapping[str, Any] ConnectionDefinition = Mapping[str, Any] StreamState = Mapping[str, Any] -class Record(Mapping[str, Any]): - def __init__( +class Record(Mapping[str, Any]): # noqa: PLW1641 + def __init__( # noqa: ANN204 self, data: Mapping[str, Any], stream_name: str, - associated_slice: Optional[StreamSlice] = None, - is_file_transfer_message: bool = False, + associated_slice: StreamSlice | None = None, + is_file_transfer_message: bool = False, # noqa: FBT001, FBT002 ): self._data = data self._associated_slice = associated_slice @@ -34,19 +36,19 @@ def data(self) -> Mapping[str, Any]: return self._data @property - def associated_slice(self) -> Optional[StreamSlice]: + def associated_slice(self) -> StreamSlice | None: return self._associated_slice def __repr__(self) -> str: return repr(self._data) - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> Any: # noqa: ANN401 return self._data[key] def __len__(self) -> int: return len(self._data) - def __iter__(self) -> Any: + def __iter__(self) -> Any: # noqa: ANN401 return iter(self._data) def __contains__(self, item: object) -> bool: @@ -68,7 +70,7 @@ def __init__( *, partition: Mapping[str, Any], cursor_slice: Mapping[str, Any], - extra_fields: Optional[Mapping[str, Any]] = None, + extra_fields: Mapping[str, Any] | None = None, ) -> None: """ :param partition: The partition keys representing a unique partition in the stream. @@ -109,10 +111,10 @@ def extra_fields(self) -> Mapping[str, Any]: def __repr__(self) -> str: return repr(self._stream_slice) - def __setitem__(self, key: str, value: Any) -> None: + def __setitem__(self, key: str, value: Any) -> None: # noqa: ANN401 raise ValueError("StreamSlice is immutable") - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> Any: # noqa: ANN401 return self._stream_slice[key] def __len__(self) -> int: @@ -121,7 +123,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: return iter(self._stream_slice) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: Any) -> bool: # noqa: ANN401 return item in self._stream_slice def keys(self) -> KeysView[str]: @@ -133,10 +135,10 @@ def items(self) -> ItemsView[str, Any]: def values(self) -> ValuesView[Any]: return self._stream_slice.values() - def get(self, key: str, default: Any = None) -> Optional[Any]: + def get(self, key: str, default: Any = None) -> Any | None: # noqa: ANN401 return self._stream_slice.get(key, default) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, dict): return self._stream_slice == other if isinstance(other, StreamSlice): @@ -144,10 +146,10 @@ def __eq__(self, other: Any) -> bool: return self._partition == other._partition and self._cursor_slice == other._cursor_slice return False - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __json_serializable__(self) -> Any: + def __json_serializable__(self) -> Any: # noqa: ANN401, PLW3201 return self._stream_slice def __hash__(self) -> int: diff --git a/airbyte_cdk/sources/utils/__init__.py b/airbyte_cdk/sources/utils/__init__.py index b609a6c7a..c941b3045 100644 --- a/airbyte_cdk/sources/utils/__init__.py +++ b/airbyte_cdk/sources/utils/__init__.py @@ -1,7 +1,3 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# - -# Initialize Utils Package - -__all__ = ["record_helper"] diff --git a/airbyte_cdk/sources/utils/casing.py b/airbyte_cdk/sources/utils/casing.py index 806e077ae..c78d5e91e 100644 --- a/airbyte_cdk/sources/utils/casing.py +++ b/airbyte_cdk/sources/utils/casing.py @@ -8,5 +8,5 @@ # https://stackoverflow.com/a/1176023 def camel_to_snake(s: str) -> str: - s = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s).lower() + s = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s) # noqa: RUF039 + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s).lower() # noqa: RUF039 diff --git a/airbyte_cdk/sources/utils/record_helper.py b/airbyte_cdk/sources/utils/record_helper.py index 3d2cbcecf..1a56a18cd 100644 --- a/airbyte_cdk/sources/utils/record_helper.py +++ b/airbyte_cdk/sources/utils/record_helper.py @@ -2,8 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # import time +from collections.abc import Mapping from collections.abc import Mapping as ABCMapping -from typing import Any, Mapping, Optional +from typing import Any from airbyte_cdk.models import ( AirbyteLogMessage, @@ -20,9 +21,9 @@ def stream_data_to_airbyte_message( stream_name: str, data_or_message: StreamData, - transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform), - schema: Optional[Mapping[str, Any]] = None, - is_file_transfer_message: bool = False, + transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform), # noqa: B008 + schema: Mapping[str, Any] | None = None, + is_file_transfer_message: bool = False, # noqa: FBT001, FBT002 ) -> AirbyteMessage: if schema is None: schema = {} diff --git a/airbyte_cdk/sources/utils/schema_helpers.py b/airbyte_cdk/sources/utils/schema_helpers.py index f15578238..241e008d6 100644 --- a/airbyte_cdk/sources/utils/schema_helpers.py +++ b/airbyte_cdk/sources/utils/schema_helpers.py @@ -7,7 +7,8 @@ import json import os import pkgutil -from typing import Any, ClassVar, Dict, List, Mapping, MutableMapping, Optional, Tuple +from collections.abc import Mapping, MutableMapping +from typing import Any, ClassVar import jsonref from jsonschema import RefResolver, validate @@ -25,21 +26,20 @@ class JsonFileLoader: pointing to shared_schema.json file instead of shared/shared_schema.json """ - def __init__(self, uri_base: str, shared: str): + def __init__(self, uri_base: str, shared: str): # noqa: ANN204 self.shared = shared self.uri_base = uri_base - def __call__(self, uri: str) -> Dict[str, Any]: + def __call__(self, uri: str) -> dict[str, Any]: uri = uri.replace(self.uri_base, f"{self.uri_base}/{self.shared}/") - with open(uri) as f: + with open(uri) as f: # noqa: PLW1514, PTH123 data = json.load(f) if isinstance(data, dict): return data - else: - raise ValueError(f"Expected to read a dictionary from {uri}. Got: {data}") + raise ValueError(f"Expected to read a dictionary from {uri}. Got: {data}") -def resolve_ref_links(obj: Any) -> Any: +def resolve_ref_links(obj: Any) -> Any: # noqa: ANN401 """ Scan resolved schema and convert jsonref.JsonRef object to JSON serializable dict. @@ -53,17 +53,15 @@ def resolve_ref_links(obj: Any) -> Any: if isinstance(obj, dict): obj.pop("definitions", None) return obj - else: - raise ValueError(f"Expected obj to be a dict. Got {obj}") - elif isinstance(obj, dict): + raise ValueError(f"Expected obj to be a dict. 
Got {obj}") + if isinstance(obj, dict): return {k: resolve_ref_links(v) for k, v in obj.items()} - elif isinstance(obj, list): + if isinstance(obj, list): return [resolve_ref_links(item) for item in obj] - else: - return obj + return obj -def _expand_refs(schema: Any, ref_resolver: Optional[RefResolver] = None) -> None: +def _expand_refs(schema: Any, ref_resolver: RefResolver | None = None) -> None: # noqa: ANN401 """Internal function to iterate over schema and replace all occurrences of $ref with their definitions. Recursive. :param schema: schema that will be patched @@ -80,14 +78,14 @@ def _expand_refs(schema: Any, ref_resolver: Optional[RefResolver] = None) -> Non ) # expand refs in definitions as well schema.update(definition) else: - for key, value in schema.items(): + for key, value in schema.items(): # noqa: B007, PERF102 _expand_refs(value, ref_resolver=ref_resolver) - elif isinstance(schema, List): + elif isinstance(schema, list): for value in schema: _expand_refs(value, ref_resolver=ref_resolver) -def expand_refs(schema: Any) -> None: +def expand_refs(schema: Any) -> None: # noqa: ANN401 """Iterate over schema and replace all occurrences of $ref with their definitions. :param schema: schema that will be patched @@ -96,7 +94,7 @@ def expand_refs(schema: Any) -> None: schema.pop("definitions", None) # remove definitions created by $ref -def rename_key(schema: Any, old_key: str, new_key: str) -> None: +def rename_key(schema: Any, old_key: str, new_key: str) -> None: # noqa: ANN401 """Iterate over nested dictionary and replace one key with another. Used to replace anyOf with oneOf. Recursive." :param schema: schema that will be patched @@ -106,7 +104,7 @@ def rename_key(schema: Any, old_key: str, new_key: str) -> None: if not isinstance(schema, MutableMapping): return - for key, value in schema.items(): + for key, value in schema.items(): # noqa: B007, PERF102 rename_key(value, old_key, new_key) if old_key in schema: schema[new_key] = schema.pop(old_key) @@ -115,7 +113,7 @@ def rename_key(schema: Any, old_key: str, new_key: str) -> None: class ResourceSchemaLoader: """JSONSchema loader from package resources""" - def __init__(self, package_name: str): + def __init__(self, package_name: str): # noqa: ANN204 self.package_name = package_name def get_schema(self, name: str) -> dict[str, Any]: @@ -134,7 +132,7 @@ def get_schema(self, name: str) -> dict[str, Any]: schema_filename = f"schemas/{name}.json" raw_file = pkgutil.get_data(self.package_name, schema_filename) if not raw_file: - raise IOError(f"Cannot find file {schema_filename}") + raise OSError(f"Cannot find file {schema_filename}") try: raw_schema = json.loads(raw_file) except ValueError as err: @@ -152,7 +150,7 @@ def _resolve_schema_references(self, raw_schema: dict[str, Any]) -> dict[str, An package = importlib.import_module(self.package_name) if package.__file__: - base = os.path.dirname(package.__file__) + "/" + base = os.path.dirname(package.__file__) + "/" # noqa: PTH120 else: raise ValueError(f"Package {package} does not have a valid __file__ field") resolved = jsonref.JsonRef.replace_refs( @@ -161,8 +159,7 @@ def _resolve_schema_references(self, raw_schema: dict[str, Any]) -> dict[str, An resolved = resolve_ref_links(resolved) if isinstance(resolved, dict): return resolved - else: - raise ValueError(f"Expected resolved to be a dict. Got {resolved}") + raise ValueError(f"Expected resolved to be a dict. 
Got {resolved}") def check_config_against_spec_or_exit( @@ -191,7 +188,7 @@ class InternalConfig(BaseModel): limit: int = Field(None, alias="_limit") page_size: int = Field(None, alias="_page_size") - def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401 kwargs["by_alias"] = True kwargs["exclude_unset"] = True return super().dict(*args, **kwargs) @@ -202,13 +199,13 @@ def is_limit_reached(self, records_counter: int) -> bool: :param records_counter - number of records already red :return True if limit reached, False otherwise """ - if self.limit: + if self.limit: # noqa: SIM102 if records_counter >= self.limit: return True return False -def split_config(config: Mapping[str, Any]) -> Tuple[dict[str, Any], InternalConfig]: +def split_config(config: Mapping[str, Any]) -> tuple[dict[str, Any], InternalConfig]: """ Break config map object into 2 instances: first is a dict with user defined configuration and second is internal config that contains private keys for diff --git a/airbyte_cdk/sources/utils/slice_logger.py b/airbyte_cdk/sources/utils/slice_logger.py index ee802a7a6..f394ef72f 100644 --- a/airbyte_cdk/sources/utils/slice_logger.py +++ b/airbyte_cdk/sources/utils/slice_logger.py @@ -5,7 +5,8 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level from airbyte_cdk.models import Type as MessageType @@ -19,7 +20,7 @@ class SliceLogger(ABC): SLICE_LOG_PREFIX = "slice:" - def create_slice_log_message(self, _slice: Optional[Mapping[str, Any]]) -> AirbyteMessage: + def create_slice_log_message(self, _slice: Mapping[str, Any] | None) -> AirbyteMessage: """ Mapping is an interface that can be implemented in various ways. However, json.dumps will just do a `str()` if the slice is a class implementing Mapping. Therefore, we want to cast this as a dict before passing this to json.dump @@ -53,5 +54,5 @@ def should_log_slice_message(self, logger: logging.Logger) -> bool: class AlwaysLogSliceLogger(SliceLogger): - def should_log_slice_message(self, logger: logging.Logger) -> bool: + def should_log_slice_message(self, logger: logging.Logger) -> bool: # noqa: ARG002 return True diff --git a/airbyte_cdk/sources/utils/transform.py b/airbyte_cdk/sources/utils/transform.py index 05c299560..66898751e 100644 --- a/airbyte_cdk/sources/utils/transform.py +++ b/airbyte_cdk/sources/utils/transform.py @@ -3,12 +3,14 @@ # import logging +from collections.abc import Callable, Generator, Mapping from distutils.util import strtobool from enum import Flag, auto -from typing import Any, Callable, Dict, Generator, Mapping, Optional, cast +from typing import Any, cast from jsonschema import Draft7Validator, RefResolver, ValidationError, Validator, validators + MAX_NESTING_DEPTH = 3 json_to_python_simple = { "string": str, @@ -17,7 +19,7 @@ "boolean": bool, "null": type(None), } -json_to_python = {**json_to_python_simple, **{"object": dict, "array": list}} +json_to_python = {**json_to_python_simple, "object": dict, "array": list} python_to_json = {v: k for k, v in json_to_python.items()} logger = logging.getLogger("airbyte") @@ -47,15 +49,15 @@ class TypeTransformer: Class for transforming object before output. 
""" - _custom_normalizer: Optional[Callable[[Any, Dict[str, Any]], Any]] = None + _custom_normalizer: Callable[[Any, dict[str, Any]], Any] | None = None - def __init__(self, config: TransformConfig): + def __init__(self, config: TransformConfig): # noqa: ANN204 """ Initialize TypeTransformer instance. :param config Transform config that would be applied to object """ if TransformConfig.NoTransform in config and config != TransformConfig.NoTransform: - raise Exception("NoTransform option cannot be combined with other flags.") + raise Exception("NoTransform option cannot be combined with other flags.") # noqa: TRY002 self._config = config all_validators = { key: self.__get_normalizer(key, orig_validator) @@ -67,7 +69,7 @@ def __init__(self, config: TransformConfig): meta_schema=Draft7Validator.META_SCHEMA, validators=all_validators ) - def registerCustomTransform( + def registerCustomTransform( # noqa: N802 self, normalization_callback: Callable[[Any, dict[str, Any]], Any] ) -> Callable[[Any, dict[str, Any]], Any]: """ @@ -79,13 +81,13 @@ def registerCustomTransform( :return Same callback, this is useful for using registerCustomTransform function as decorator. """ if TransformConfig.CustomSchemaNormalization not in self._config: - raise Exception( + raise Exception( # noqa: TRY002 "Please set TransformConfig.CustomSchemaNormalization config before registering custom normalizer" ) self._custom_normalizer = normalization_callback return normalization_callback - def __normalize(self, original_item: Any, subschema: Dict[str, Any]) -> Any: + def __normalize(self, original_item: Any, subschema: dict[str, Any]) -> Any: # noqa: ANN401 """ Applies different transform function to object's field according to config. :param original_item original value of field. @@ -100,7 +102,7 @@ def __normalize(self, original_item: Any, subschema: Dict[str, Any]) -> Any: return original_item @staticmethod - def default_convert(original_item: Any, subschema: Dict[str, Any]) -> Any: + def default_convert(original_item: Any, subschema: dict[str, Any]) -> Any: # noqa: ANN401, PLR0911 """ Default transform function that is used when TransformConfig.DefaultSchemaNormalization flag set. :param original_item original value of field. @@ -123,15 +125,15 @@ def default_convert(original_item: Any, subschema: Dict[str, Any]) -> Any: try: if target_type == "string": return str(original_item) - elif target_type == "number": + if target_type == "number": return float(original_item) - elif target_type == "integer": + if target_type == "integer": return int(original_item) - elif target_type == "boolean": + if target_type == "boolean": if isinstance(original_item, str): return strtobool(original_item) == 1 return bool(original_item) - elif target_type == "array": + if target_type == "array": item_types = set(subschema.get("items", {}).get("type", set())) if ( item_types.issubset(json_to_python_simple) @@ -155,9 +157,9 @@ def __get_normalizer( def normalizator( validator_instance: Validator, - property_value: Any, - instance: Any, - schema: Dict[str, Any], + property_value: Any, # noqa: ANN401 + instance: Any, # noqa: ANN401 + schema: dict[str, Any], ) -> Generator[Any, Any, None]: """ Jsonschema validator callable it uses for validating instance. 
We @@ -174,10 +176,10 @@ def normalizator( def resolve(subschema: dict[str, Any]) -> dict[str, Any]: if "$ref" in subschema: _, resolved = cast( - RefResolver, + RefResolver, # noqa: TC006 validator_instance.resolver, ).resolve(subschema["$ref"]) - return cast(dict[str, Any], resolved) + return cast(dict[str, Any], resolved) # noqa: TC006 return subschema # Transform object and array values before running json schema type checking for each element. @@ -186,7 +188,7 @@ def resolve(subschema: dict[str, Any]) -> dict[str, Any]: if schema_key == "properties" and isinstance(instance, dict): for k, subschema in property_value.items(): if k in instance: - subschema = resolve(subschema) + subschema = resolve(subschema) # noqa: PLW2901 instance[k] = self.__normalize(instance[k], subschema) # Recursively normalize every item of the "instance" sub-array, # if "instance" is an incorrect type - skip recursive normalization of "instance" @@ -207,7 +209,7 @@ def resolve(subschema: dict[str, Any]) -> dict[str, Any]: def transform( self, - record: Dict[str, Any], + record: dict[str, Any], schema: Mapping[str, Any], ) -> None: """ @@ -234,7 +236,7 @@ def get_error_message(self, e: ValidationError) -> str: return f"Failed to transform value from type '{type_structure}' to type '{e.validator_value}' at path: '{field_path}'" - def _get_type_structure(self, input_data: Any, current_depth: int = 0) -> Any: + def _get_type_structure(self, input_data: Any, current_depth: int = 0) -> Any: # noqa: ANN401 """ Get the structure of a given input data for use in error message construction. """ @@ -252,5 +254,4 @@ def _get_type_structure(self, input_data: Any, current_depth: int = 0) -> Any: for key, field_value in input_data.items() } - else: - return python_to_json[type(input_data)] + return python_to_json[type(input_data)] diff --git a/airbyte_cdk/sources/utils/types.py b/airbyte_cdk/sources/utils/types.py index 9dc5e253b..5a6f9b2a6 100644 --- a/airbyte_cdk/sources/utils/types.py +++ b/airbyte_cdk/sources/utils/types.py @@ -1,7 +1,8 @@ -# +# # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # from typing import Union -JsonType = Union[dict[str, "JsonType"], list["JsonType"], str, int, float, bool, None] + +JsonType = Union[dict[str, "JsonType"], list["JsonType"], str, int, float, bool, None] # noqa: UP007 diff --git a/airbyte_cdk/sql/_util/hashing.py b/airbyte_cdk/sql/_util/hashing.py index 781305c48..a88798c7c 100644 --- a/airbyte_cdk/sql/_util/hashing.py +++ b/airbyte_cdk/sql/_util/hashing.py @@ -6,6 +6,7 @@ import hashlib from collections.abc import Mapping + HASH_SEED = "Airbyte:" """Additional seed for randomizing one-way hashed strings.""" diff --git a/airbyte_cdk/sql/_util/name_normalizers.py b/airbyte_cdk/sql/_util/name_normalizers.py index 9311432d7..a479128bd 100644 --- a/airbyte_cdk/sql/_util/name_normalizers.py +++ b/airbyte_cdk/sql/_util/name_normalizers.py @@ -10,6 +10,7 @@ from airbyte_cdk.sql import exceptions as exc + if TYPE_CHECKING: from collections.abc import Iterable @@ -69,7 +70,7 @@ def normalize(name: str) -> str: result = name # Replace all non-alphanumeric characters with underscores. - result = re.sub("[^A-Za-z0-9]", "_", result.lower()) + result = re.sub("[^A-Za-z0-9]", "_", result.lower()) # noqa: RUF039 # Check if name starts with a number and prepend "_" if it does. 
if result and result[0].isdigit(): diff --git a/airbyte_cdk/sql/constants.py b/airbyte_cdk/sql/constants.py index 2f7de7817..b499d31f6 100644 --- a/airbyte_cdk/sql/constants.py +++ b/airbyte_cdk/sql/constants.py @@ -3,6 +3,7 @@ from __future__ import annotations + DEBUG_MODE = False # Set to True to enable additional debug logging. AB_EXTRACTED_AT_COLUMN = "_airbyte_extracted_at" diff --git a/airbyte_cdk/sql/exceptions.py b/airbyte_cdk/sql/exceptions.py index 963dc4696..797486052 100644 --- a/airbyte_cdk/sql/exceptions.py +++ b/airbyte_cdk/sql/exceptions.py @@ -44,6 +44,7 @@ from textwrap import indent from typing import Any + NEW_ISSUE_URL = "https://github.com/airbytehq/airbyte/issues/new/choose" DOCS_URL_BASE = "https://https://docs.airbyte.com/" DOCS_URL = f"{DOCS_URL_BASE}/airbyte.html" diff --git a/airbyte_cdk/sql/secrets.py b/airbyte_cdk/sql/secrets.py index c2508682d..766400b36 100644 --- a/airbyte_cdk/sql/secrets.py +++ b/airbyte_cdk/sql/secrets.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. # noqa: A005 """Base classes and methods for working with secrets in Airbyte.""" from __future__ import annotations @@ -10,6 +10,7 @@ from airbyte_cdk.sql import exceptions as exc + if TYPE_CHECKING: from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, ValidationInfo from pydantic.json_schema import JsonSchemaValue @@ -65,7 +66,7 @@ def __bool__(self) -> bool: """ return True - def parse_json(self) -> Any: + def parse_json(self) -> Any: # noqa: ANN401 """Parse the secret string as JSON.""" try: return json.loads(self) @@ -96,7 +97,7 @@ def validate( return cls(v) @classmethod - def __get_pydantic_core_schema__( # noqa: PLW3201 # Pydantic dunder + def __get_pydantic_core_schema__( # Pydantic dunder # noqa: PLW3201 cls, source_type: Any, # noqa: ANN401 # Must allow `Any` to match Pydantic signature handler: GetCoreSchemaHandler, @@ -107,15 +108,17 @@ def __get_pydantic_core_schema__( # noqa: PLW3201 # Pydantic dunder ) @classmethod - def __get_pydantic_json_schema__( # noqa: PLW3201 # Pydantic dunder method - cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + def __get_pydantic_json_schema__( # Pydantic dunder method # noqa: PLW3201 + cls, + core_schema_: core_schema.CoreSchema, + handler: GetJsonSchemaHandler, ) -> JsonSchemaValue: """Return a modified JSON schema for the secret string. - `writeOnly=True` is the official way to prevent secrets from being exposed inadvertently. - `Format=password` is a popular and readable convention to indicate the field is sensitive. 
""" - _ = _core_schema, handler # Unused + _ = core_schema_, handler # Unused return { "type": "string", "format": "password", diff --git a/airbyte_cdk/sql/shared/__init__.py b/airbyte_cdk/sql/shared/__init__.py index d9156b9d0..bbc4260d7 100644 --- a/airbyte_cdk/sql/shared/__init__.py +++ b/airbyte_cdk/sql/shared/__init__.py @@ -10,6 +10,7 @@ from airbyte_cdk.sql.shared.sql_processor import SqlProcessorBase + __all__ = [ "SqlProcessorBase", ] diff --git a/airbyte_cdk/sql/shared/catalog_providers.py b/airbyte_cdk/sql/shared/catalog_providers.py index 80713a35a..93021162c 100644 --- a/airbyte_cdk/sql/shared/catalog_providers.py +++ b/airbyte_cdk/sql/shared/catalog_providers.py @@ -10,10 +10,11 @@ from typing import TYPE_CHECKING, Any, cast, final -from airbyte_cdk.models import ConfiguredAirbyteCatalog +from airbyte_cdk.models import ConfiguredAirbyteCatalog # noqa: TC001 from airbyte_cdk.sql import exceptions as exc from airbyte_cdk.sql._util.name_normalizers import LowerCaseNormalizer + if TYPE_CHECKING: from airbyte_cdk.models import ConfiguredAirbyteStream @@ -40,7 +41,7 @@ def __init__( self._catalog: ConfiguredAirbyteCatalog = self.validate_catalog(configured_catalog) @staticmethod - def validate_catalog(catalog: ConfiguredAirbyteCatalog) -> Any: + def validate_catalog(catalog: ConfiguredAirbyteCatalog) -> Any: # noqa: ANN401 """Validate the catalog to ensure it is valid. This requires ensuring that `generationId` and `minGenerationId` are both set. If @@ -107,14 +108,14 @@ def get_stream_json_schema( stream_name: str, ) -> dict[str, Any]: """Return the column definitions for the given stream.""" - return cast(dict[str, Any], self.get_configured_stream_info(stream_name).stream.json_schema) + return cast(dict[str, Any], self.get_configured_stream_info(stream_name).stream.json_schema) # noqa: TC006 def get_stream_properties( self, stream_name: str, ) -> dict[str, dict[str, Any]]: """Return the names of the top-level properties for the given stream.""" - return cast(dict[str, Any], self.get_stream_json_schema(stream_name)["properties"]) + return cast(dict[str, Any], self.get_stream_json_schema(stream_name)["properties"]) # noqa: TC006 def get_primary_keys( self, diff --git a/airbyte_cdk/sql/shared/sql_processor.py b/airbyte_cdk/sql/shared/sql_processor.py index 5fd0a5e46..f85d1f1be 100644 --- a/airbyte_cdk/sql/shared/sql_processor.py +++ b/airbyte_cdk/sql/shared/sql_processor.py @@ -13,12 +13,13 @@ import pandas as pd import sqlalchemy import ulid -from airbyte_protocol_dataclasses.models import AirbyteStateMessage from pandas import Index from pydantic import BaseModel, Field from sqlalchemy import Column, Table, and_, create_engine, insert, null, select, text, update from sqlalchemy.exc import ProgrammingError, SQLAlchemyError +from airbyte_protocol_dataclasses.models import AirbyteStateMessage # noqa: TC001 + from airbyte_cdk.sql import exceptions as exc from airbyte_cdk.sql._util.hashing import one_way_hash from airbyte_cdk.sql._util.name_normalizers import LowerCaseNormalizer @@ -31,6 +32,7 @@ from airbyte_cdk.sql.secrets import SecretString from airbyte_cdk.sql.types import SQLTypeConverter + if TYPE_CHECKING: from collections.abc import Generator @@ -79,13 +81,11 @@ def config_hash(self) -> str | None: """ return one_way_hash( SecretString( - ":".join( - [ - str(self.get_sql_alchemy_url()), - self.schema_name or "", - self.table_prefix or "", - ] - ) + ":".join([ + str(self.get_sql_alchemy_url()), + self.schema_name or "", + self.table_prefix or "", + ]) ) ) @@ -112,7 +112,7 @@ 
def get_vendor_client(self) -> object: ) -class SqlProcessorBase(abc.ABC): +class SqlProcessorBase(abc.ABC): # noqa: B024 """A base class to be used for SQL Caches.""" type_converter_class: type[SQLTypeConverter] = SQLTypeConverter @@ -326,9 +326,9 @@ def _ensure_schema_exists( if DEBUG_MODE: found_schemas = schemas_list - assert ( - schema_name in found_schemas - ), f"Schema {schema_name} was not created. Found: {found_schemas}" + assert schema_name in found_schemas, ( + f"Schema {schema_name} was not created. Found: {found_schemas}" + ) def _quote_identifier(self, identifier: str) -> str: """Return the given identifier, quoted.""" @@ -617,10 +617,10 @@ def _append_temp_table_to_final_table( self._execute_sql( f""" INSERT INTO {self._fully_qualified(final_table_name)} ( - {f',{nl} '.join(columns)} + {f",{nl} ".join(columns)} ) SELECT - {f',{nl} '.join(columns)} + {f",{nl} ".join(columns)} FROM {self._fully_qualified(temp_table_name)} """, ) @@ -643,15 +643,11 @@ def _swap_temp_table_with_final_table( _ = stream_name deletion_name = f"{final_table_name}_deleteme" - commands = "\n".join( - [ - f"ALTER TABLE {self._fully_qualified(final_table_name)} RENAME " - f"TO {deletion_name};", - f"ALTER TABLE {self._fully_qualified(temp_table_name)} RENAME " - f"TO {final_table_name};", - f"DROP TABLE {self._fully_qualified(deletion_name)};", - ] - ) + commands = "\n".join([ + f"ALTER TABLE {self._fully_qualified(final_table_name)} RENAME TO {deletion_name};", + f"ALTER TABLE {self._fully_qualified(temp_table_name)} RENAME TO {final_table_name};", + f"DROP TABLE {self._fully_qualified(deletion_name)};", + ]) self._execute_sql(commands) def _merge_temp_table_to_final_table( @@ -686,10 +682,10 @@ def _merge_temp_table_to_final_table( {set_clause} WHEN NOT MATCHED THEN INSERT ( - {f',{nl} '.join(columns)} + {f",{nl} ".join(columns)} ) VALUES ( - tmp.{f',{nl} tmp.'.join(columns)} + tmp.{f",{nl} tmp.".join(columns)} ); """, ) diff --git a/airbyte_cdk/sql/types.py b/airbyte_cdk/sql/types.py index bb6fa1cb7..dd60487d1 100644 --- a/airbyte_cdk/sql/types.py +++ b/airbyte_cdk/sql/types.py @@ -1,4 +1,4 @@ -# noqa: A005 # Allow shadowing the built-in 'types' module +# Allow shadowing the built-in 'types' module # noqa: A005 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. """Type conversion methods for SQL Caches.""" @@ -9,6 +9,7 @@ import sqlalchemy + # Compare to documentation here: https://docs.airbyte.com/understanding-airbyte/supported-data-types CONVERSION_MAP = { "string": sqlalchemy.types.VARCHAR, @@ -39,7 +40,7 @@ def _get_airbyte_type( # noqa: PLR0911 # Too many return statements Subtype is only used for array types. Otherwise, subtype will return None. """ - airbyte_type = cast(str, json_schema_property_def.get("airbyte_type", None)) + airbyte_type = cast(str, json_schema_property_def.get("airbyte_type", None)) # noqa: TC006 if airbyte_type: return airbyte_type, None @@ -122,7 +123,7 @@ def get_json_type(cls) -> sqlalchemy.types.TypeEngine[Any]: def to_sql_type( # noqa: PLR0911 # Too many return statements self, json_schema_property_def: dict[str, str | dict[str, Any] | list[Any]], - ) -> Any: + ) -> Any: # noqa: ANN401 """Convert a value to a SQL type.""" try: airbyte_type, _ = _get_airbyte_type(json_schema_property_def) diff --git a/airbyte_cdk/test/catalog_builder.py b/airbyte_cdk/test/catalog_builder.py index b1bf4341c..0cd844b07 100644 --- a/airbyte_cdk/test/catalog_builder.py +++ b/airbyte_cdk/test/catalog_builder.py @@ -1,6 +1,6 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
-from typing import Any, Dict, List, Union, overload +from typing import Any, overload from airbyte_cdk.models import ( ConfiguredAirbyteCatalog, @@ -12,7 +12,7 @@ class ConfiguredAirbyteStreamBuilder: def __init__(self) -> None: - self._stream: Dict[str, Any] = { + self._stream: dict[str, Any] = { "stream": { "name": "any name", "json_schema": {}, @@ -32,12 +32,12 @@ def with_sync_mode(self, sync_mode: SyncMode) -> "ConfiguredAirbyteStreamBuilder self._stream["sync_mode"] = sync_mode.name return self - def with_primary_key(self, pk: List[List[str]]) -> "ConfiguredAirbyteStreamBuilder": + def with_primary_key(self, pk: list[list[str]]) -> "ConfiguredAirbyteStreamBuilder": self._stream["primary_key"] = pk self._stream["stream"]["source_defined_primary_key"] = pk # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any] return self - def with_json_schema(self, json_schema: Dict[str, Any]) -> "ConfiguredAirbyteStreamBuilder": + def with_json_schema(self, json_schema: dict[str, Any]) -> "ConfiguredAirbyteStreamBuilder": self._stream["stream"]["json_schema"] = json_schema return self @@ -47,7 +47,7 @@ def build(self) -> ConfiguredAirbyteStream: class CatalogBuilder: def __init__(self) -> None: - self._streams: List[ConfiguredAirbyteStreamBuilder] = [] + self._streams: list[ConfiguredAirbyteStreamBuilder] = [] @overload def with_stream(self, name: ConfiguredAirbyteStreamBuilder) -> "CatalogBuilder": ... @@ -57,8 +57,8 @@ def with_stream(self, name: str, sync_mode: SyncMode) -> "CatalogBuilder": ... def with_stream( self, - name: Union[str, ConfiguredAirbyteStreamBuilder], - sync_mode: Union[SyncMode, None] = None, + name: str | ConfiguredAirbyteStreamBuilder, + sync_mode: SyncMode | None = None, ) -> "CatalogBuilder": # As we are introducing a fully fledge ConfiguredAirbyteStreamBuilder, we would like to deprecate the previous interface # with_stream(str, SyncMode) @@ -77,5 +77,5 @@ def with_stream( def build(self) -> ConfiguredAirbyteCatalog: return ConfiguredAirbyteCatalog( - streams=list(map(lambda builder: builder.build(), self._streams)) + streams=list(map(lambda builder: builder.build(), self._streams)) # noqa: C417 ) diff --git a/airbyte_cdk/test/entrypoint_wrapper.py b/airbyte_cdk/test/entrypoint_wrapper.py index f8e85bfb0..c7fd15e5a 100644 --- a/airbyte_cdk/test/entrypoint_wrapper.py +++ b/airbyte_cdk/test/entrypoint_wrapper.py @@ -19,9 +19,10 @@ import re import tempfile import traceback +from collections.abc import Mapping from io import StringIO from pathlib import Path -from typing import Any, List, Mapping, Optional, Union +from typing import Any import orjson from pydantic import ValidationError as V2ValidationError @@ -47,7 +48,7 @@ class EntrypointOutput: - def __init__(self, messages: List[str], uncaught_exception: Optional[BaseException] = None): + def __init__(self, messages: list[str], uncaught_exception: BaseException | None = None): # noqa: ANN204 try: self._messages = [self._parse_message(message) for message in messages] except V2ValidationError as exception: @@ -71,38 +72,38 @@ def _parse_message(message: str) -> AirbyteMessage: ) @property - def records_and_state_messages(self) -> List[AirbyteMessage]: + def records_and_state_messages(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.RECORD, Type.STATE]) @property - def records(self) -> List[AirbyteMessage]: + def records(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.RECORD]) @property - def state_messages(self) -> List[AirbyteMessage]: + def 
state_messages(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.STATE]) @property - def most_recent_state(self) -> Any: + def most_recent_state(self) -> Any: # noqa: ANN401 state_messages = self._get_message_by_types([Type.STATE]) if not state_messages: raise ValueError("Can't provide most recent state as there are no state messages") return state_messages[-1].state.stream # type: ignore[union-attr] # state has `stream` @property - def logs(self) -> List[AirbyteMessage]: + def logs(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.LOG]) @property - def trace_messages(self) -> List[AirbyteMessage]: + def trace_messages(self) -> list[AirbyteMessage]: return self._get_message_by_types([Type.TRACE]) @property - def analytics_messages(self) -> List[AirbyteMessage]: + def analytics_messages(self) -> list[AirbyteMessage]: return self._get_trace_message_by_trace_type(TraceType.ANALYTICS) @property - def errors(self) -> List[AirbyteMessage]: + def errors(self) -> list[AirbyteMessage]: return self._get_trace_message_by_trace_type(TraceType.ERROR) @property @@ -112,8 +113,8 @@ def catalog(self) -> AirbyteMessage: raise ValueError(f"Expected exactly one catalog but got {len(catalog)}") return catalog[0] - def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]: - status_messages = map( + def get_stream_statuses(self, stream_name: str) -> list[AirbyteStreamStatus]: + status_messages = map( # noqa: C417 lambda message: message.trace.stream_status.status, # type: ignore filter( lambda message: message.trace.stream_status.stream_descriptor.name == stream_name, # type: ignore # callable; trace has `stream_status` @@ -122,10 +123,10 @@ def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]: ) return list(status_messages) - def _get_message_by_types(self, message_types: List[Type]) -> List[AirbyteMessage]: + def _get_message_by_types(self, message_types: list[Type]) -> list[AirbyteMessage]: return [message for message in self._messages if message.type in message_types] - def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]: + def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> list[AirbyteMessage]: return [ message for message in self._get_message_by_types([Type.TRACE]) @@ -149,7 +150,9 @@ def is_not_in_logs(self, pattern: str) -> bool: def _run_command( - source: Source, args: List[str], expecting_exception: bool = False + source: Source, + args: list[str], + expecting_exception: bool = False, # noqa: FBT001, FBT002 ) -> EntrypointOutput: log_capture_buffer = StringIO() stream_handler = logging.StreamHandler(log_capture_buffer) @@ -165,7 +168,7 @@ def _run_command( uncaught_exception = None try: for message in source_entrypoint.run(parsed_args): - messages.append(message) + messages.append(message) # noqa: PERF402 except Exception as exception: if not expecting_exception: print("Printing unexpected error from entrypoint_wrapper") @@ -182,7 +185,7 @@ def _run_command( def discover( source: Source, config: Mapping[str, Any], - expecting_exception: bool = False, + expecting_exception: bool = False, # noqa: FBT001, FBT002 ) -> EntrypointOutput: """ config must be json serializable @@ -203,8 +206,8 @@ def read( source: Source, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, - state: Optional[List[AirbyteStateMessage]] = None, - expecting_exception: bool = False, + state: list[AirbyteStateMessage] | None = None, + expecting_exception: bool = False, # noqa: 
FBT001, FBT002 ) -> EntrypointOutput: """ config and state must be json serializable @@ -227,21 +230,19 @@ def read( catalog_file, ] if state is not None: - args.extend( - [ - "--state", - make_file( - tmp_directory_path / "state.json", - f"[{','.join([orjson.dumps(AirbyteStateMessageSerializer.dump(stream_state)).decode() for stream_state in state])}]", - ), - ] - ) + args.extend([ + "--state", + make_file( + tmp_directory_path / "state.json", + f"[{','.join([orjson.dumps(AirbyteStateMessageSerializer.dump(stream_state)).decode() for stream_state in state])}]", + ), + ]) return _run_command(source, args, expecting_exception) def make_file( - path: Path, file_contents: Optional[Union[str, Mapping[str, Any], List[Mapping[str, Any]]]] + path: Path, file_contents: str | Mapping[str, Any] | list[Mapping[str, Any]] | None ) -> str: if isinstance(file_contents, str): path.write_text(file_contents) diff --git a/airbyte_cdk/test/mock_http/__init__.py b/airbyte_cdk/test/mock_http/__init__.py index fdd454d2a..dd8887251 100644 --- a/airbyte_cdk/test/mock_http/__init__.py +++ b/airbyte_cdk/test/mock_http/__init__.py @@ -3,4 +3,5 @@ from airbyte_cdk.test.mock_http.request import HttpRequest from airbyte_cdk.test.mock_http.response import HttpResponse + __all__ = ["HttpMocker", "HttpRequest", "HttpRequestMatcher", "HttpResponse"] diff --git a/airbyte_cdk/test/mock_http/matcher.py b/airbyte_cdk/test/mock_http/matcher.py index d07cec3ec..df600526f 100644 --- a/airbyte_cdk/test/mock_http/matcher.py +++ b/airbyte_cdk/test/mock_http/matcher.py @@ -1,11 +1,10 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. -from typing import Any from airbyte_cdk.test.mock_http.request import HttpRequest -class HttpRequestMatcher: - def __init__(self, request: HttpRequest, minimum_number_of_expected_match: int): +class HttpRequestMatcher: # noqa: PLW1641 + def __init__(self, request: HttpRequest, minimum_number_of_expected_match: int): # noqa: ANN204 self._request_to_match = request self._minimum_number_of_expected_match = minimum_number_of_expected_match self._actual_number_of_matches = 0 @@ -35,7 +34,7 @@ def __str__(self) -> str: f"actual_number_of_matches={self._actual_number_of_matches})" ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, HttpRequestMatcher): return self._request_to_match == other._request_to_match return False diff --git a/airbyte_cdk/test/mock_http/mocker.py b/airbyte_cdk/test/mock_http/mocker.py index cd1b1f9a7..b14a55472 100644 --- a/airbyte_cdk/test/mock_http/mocker.py +++ b/airbyte_cdk/test/mock_http/mocker.py @@ -2,9 +2,9 @@ import contextlib import functools +from collections.abc import Callable from enum import Enum from types import TracebackType -from typing import Callable, List, Optional, Union import requests_mock @@ -39,7 +39,7 @@ class HttpMocker(contextlib.ContextDecorator): def __init__(self) -> None: self._mocker = requests_mock.Mocker() - self._matchers: List[HttpRequestMatcher] = [] + self._matchers: list[HttpRequestMatcher] = [] def __enter__(self) -> "HttpMocker": self._mocker.__enter__() @@ -47,9 +47,9 @@ def __enter__(self) -> "HttpMocker": def __exit__( self, - exc_type: Optional[BaseException], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: BaseException | None, # noqa: PYI036 + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: self._mocker.__exit__(exc_type, exc_val, exc_tb) @@ -62,7 +62,7 @@ def _mock_request_method( self, method: 
SupportedHttpMethods, request: HttpRequest, - responses: Union[HttpResponse, List[HttpResponse]], + responses: HttpResponse | list[HttpResponse], ) -> None: if isinstance(responses, HttpResponse): responses = [responses] @@ -85,22 +85,16 @@ def _mock_request_method( ], ) - def get(self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]]) -> None: + def get(self, request: HttpRequest, responses: HttpResponse | list[HttpResponse]) -> None: self._mock_request_method(SupportedHttpMethods.GET, request, responses) - def patch( - self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] - ) -> None: + def patch(self, request: HttpRequest, responses: HttpResponse | list[HttpResponse]) -> None: self._mock_request_method(SupportedHttpMethods.PATCH, request, responses) - def post( - self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] - ) -> None: + def post(self, request: HttpRequest, responses: HttpResponse | list[HttpResponse]) -> None: self._mock_request_method(SupportedHttpMethods.POST, request, responses) - def delete( - self, request: HttpRequest, responses: Union[HttpResponse, List[HttpResponse]] - ) -> None: + def delete(self, request: HttpRequest, responses: HttpResponse | list[HttpResponse]) -> None: self._mock_request_method(SupportedHttpMethods.DELETE, request, responses) @staticmethod @@ -131,9 +125,9 @@ def assert_number_of_calls(self, request: HttpRequest, number_of_calls: int) -> assert corresponding_matchers[0].actual_number_of_matches == number_of_calls # trying to type that using callables provides the error `incompatible with return type "_F" in supertype "ContextDecorator"` - def __call__(self, f): # type: ignore + def __call__(self, f): # type: ignore # noqa: ANN001, ANN204 @functools.wraps(f) - def wrapper(*args, **kwargs): # type: ignore # this is a very generic wrapper that does not need to be typed + def wrapper(*args, **kwargs): # type: ignore # this is a very generic wrapper that does not need to be typed # noqa: ANN002, ANN202 with self: assertion_error = None @@ -142,7 +136,7 @@ def wrapper(*args, **kwargs): # type: ignore # this is a very generic wrapper result = f(*args, **kwargs) except requests_mock.NoMockAddress as no_mock_exception: matchers_as_string = "\n\t".join( - map(lambda matcher: str(matcher.request), self._matchers) + map(lambda matcher: str(matcher.request), self._matchers) # noqa: C417 ) raise ValueError( f"No matcher matches {no_mock_exception.args[0]} with headers `{no_mock_exception.request.headers}` " diff --git a/airbyte_cdk/test/mock_http/request.py b/airbyte_cdk/test/mock_http/request.py index 7209513d8..dbcee4487 100644 --- a/airbyte_cdk/test/mock_http/request.py +++ b/airbyte_cdk/test/mock_http/request.py @@ -1,9 +1,11 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
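Because the HttpMocker changes above are typing-only, a small usage sketch may help confirm nothing behavioral moved; the endpoint and payload are hypothetical, and the snippet assumes, as today, that calls issued with the `requests` library inside the mocker context are intercepted.

import requests

from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse

request = HttpRequest(url="https://api.example.com/v1/items", query_params={"page": "1"})

with HttpMocker() as http_mocker:
    # Register the expected call and its canned response.
    http_mocker.get(request, HttpResponse(body='{"items": []}', status_code=200))

    # The code under test would normally issue this call.
    response = requests.get("https://api.example.com/v1/items", params={"page": "1"})

    assert response.json() == {"items": []}
    http_mocker.assert_number_of_calls(request, 1)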
import json -from typing import Any, List, Mapping, Optional, Union +from collections.abc import Mapping +from typing import Any from urllib.parse import parse_qs, urlencode, urlparse + ANY_QUERY_PARAMS = "any query_parameters" @@ -11,13 +13,13 @@ def _is_subdict(small: Mapping[str, str], big: Mapping[str, str]) -> bool: return dict(big, **small) == big -class HttpRequest: +class HttpRequest: # noqa: PLW1641 def __init__( self, url: str, - query_params: Optional[Union[str, Mapping[str, Union[str, List[str]]]]] = None, - headers: Optional[Mapping[str, str]] = None, - body: Optional[Union[str, bytes, Mapping[str, Any]]] = None, + query_params: str | Mapping[str, str | list[str]] | None = None, + headers: Mapping[str, str] | None = None, + body: str | bytes | Mapping[str, Any] | None = None, ) -> None: self._parsed_url = urlparse(url) self._query_params = query_params @@ -32,12 +34,12 @@ def __init__( self._body = body @staticmethod - def _encode_qs(query_params: Union[str, Mapping[str, Union[str, List[str]]]]) -> str: + def _encode_qs(query_params: str | Mapping[str, str | list[str]]) -> str: if isinstance(query_params, str): return query_params return urlencode(query_params, doseq=True) - def matches(self, other: Any) -> bool: + def matches(self, other: Any) -> bool: # noqa: ANN401 """ If the body of any request is a Mapping, we compare as Mappings which means that the order is not important. If the body is a string, encoding ISO-8859-1 will be assumed @@ -45,41 +47,41 @@ def matches(self, other: Any) -> bool: """ if isinstance(other, HttpRequest): # if `other` is a mapping, we match as an object and formatting is not considers - if isinstance(self._body, Mapping) or isinstance(other._body, Mapping): - body_match = self._to_mapping(self._body) == self._to_mapping(other._body) + if isinstance(self._body, Mapping) or isinstance(other._body, Mapping): # noqa: SLF001 + body_match = self._to_mapping(self._body) == self._to_mapping(other._body) # noqa: SLF001 else: - body_match = self._to_bytes(self._body) == self._to_bytes(other._body) + body_match = self._to_bytes(self._body) == self._to_bytes(other._body) # noqa: SLF001 return ( - self._parsed_url.scheme == other._parsed_url.scheme - and self._parsed_url.hostname == other._parsed_url.hostname - and self._parsed_url.path == other._parsed_url.path + self._parsed_url.scheme == other._parsed_url.scheme # noqa: SLF001 + and self._parsed_url.hostname == other._parsed_url.hostname # noqa: SLF001 + and self._parsed_url.path == other._parsed_url.path # noqa: SLF001 and ( - ANY_QUERY_PARAMS in (self._query_params, other._query_params) - or parse_qs(self._parsed_url.query) == parse_qs(other._parsed_url.query) + ANY_QUERY_PARAMS in (self._query_params, other._query_params) # noqa: SLF001 + or parse_qs(self._parsed_url.query) == parse_qs(other._parsed_url.query) # noqa: SLF001 ) - and _is_subdict(other._headers, self._headers) + and _is_subdict(other._headers, self._headers) # noqa: SLF001 and body_match ) return False @staticmethod def _to_mapping( - body: Optional[Union[str, bytes, Mapping[str, Any]]], - ) -> Optional[Mapping[str, Any]]: + body: str | bytes | Mapping[str, Any] | None, + ) -> Mapping[str, Any] | None: if isinstance(body, Mapping): return body - elif isinstance(body, bytes): + if isinstance(body, bytes): return json.loads(body.decode()) # type: ignore # assumes return type of Mapping[str, Any] - elif isinstance(body, str): + if isinstance(body, str): return json.loads(body) # type: ignore # assumes return type of Mapping[str, Any] return 
None @staticmethod - def _to_bytes(body: Optional[Union[str, bytes]]) -> bytes: + def _to_bytes(body: str | bytes | None) -> bytes: if isinstance(body, bytes): return body - elif isinstance(body, str): + if isinstance(body, str): # `ISO-8859-1` is the default encoding used by requests return body.encode("ISO-8859-1") return b"" @@ -92,7 +94,7 @@ def __repr__(self) -> str: f"HttpRequest(request={self._parsed_url}, headers={self._headers}, body={self._body!r})" ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, HttpRequest): return ( self._parsed_url == other._parsed_url diff --git a/airbyte_cdk/test/mock_http/response.py b/airbyte_cdk/test/mock_http/response.py index 848be55a0..8a216e0cc 100644 --- a/airbyte_cdk/test/mock_http/response.py +++ b/airbyte_cdk/test/mock_http/response.py @@ -1,11 +1,11 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from collections.abc import Mapping from types import MappingProxyType -from typing import Mapping class HttpResponse: - def __init__( + def __init__( # noqa: ANN204 self, body: str, status_code: int = 200, headers: Mapping[str, str] = MappingProxyType({}) ): self._body = body diff --git a/airbyte_cdk/test/mock_http/response_builder.py b/airbyte_cdk/test/mock_http/response_builder.py index 41766af1b..016649c13 100644 --- a/airbyte_cdk/test/mock_http/response_builder.py +++ b/airbyte_cdk/test/mock_http/response_builder.py @@ -4,24 +4,24 @@ import json from abc import ABC, abstractmethod from pathlib import Path as FilePath -from typing import Any, Dict, List, Optional, Union +from typing import Any from airbyte_cdk.test.mock_http.response import HttpResponse from airbyte_cdk.test.utils.data import get_unit_test_folder -def _extract(path: List[str], response_template: Dict[str, Any]) -> Any: +def _extract(path: list[str], response_template: dict[str, Any]) -> Any: # noqa: ANN401 return functools.reduce(lambda a, b: a[b], path, response_template) -def _replace_value(dictionary: Dict[str, Any], path: List[str], value: Any) -> None: +def _replace_value(dictionary: dict[str, Any], path: list[str], value: Any) -> None: # noqa: ANN401 current = dictionary for key in path[:-1]: current = current[key] current[path[-1]] = value -def _write(dictionary: Dict[str, Any], path: List[str], value: Any) -> None: +def _write(dictionary: dict[str, Any], path: list[str], value: Any) -> None: # noqa: ANN401 current = dictionary for key in path[:-1]: current = current.setdefault(key, {}) @@ -30,28 +30,28 @@ def _write(dictionary: Dict[str, Any], path: List[str], value: Any) -> None: class Path(ABC): @abstractmethod - def write(self, template: Dict[str, Any], value: Any) -> None: + def write(self, template: dict[str, Any], value: Any) -> None: # noqa: ANN401 pass @abstractmethod - def update(self, template: Dict[str, Any], value: Any) -> None: + def update(self, template: dict[str, Any], value: Any) -> None: # noqa: ANN401 pass - def extract(self, template: Dict[str, Any]) -> Any: + def extract(self, template: dict[str, Any]) -> Any: # noqa: ANN401, B027 pass class FieldPath(Path): - def __init__(self, field: str): + def __init__(self, field: str): # noqa: ANN204 self._path = [field] - def write(self, template: Dict[str, Any], value: Any) -> None: + def write(self, template: dict[str, Any], value: Any) -> None: # noqa: ANN401 _write(template, self._path, value) - def update(self, template: Dict[str, Any], value: Any) -> None: + def update(self, template: dict[str, Any], value: Any) -> None: # noqa: ANN401 
_replace_value(template, self._path, value) - def extract(self, template: Dict[str, Any]) -> Any: + def extract(self, template: dict[str, Any]) -> Any: # noqa: ANN401 return _extract(self._path, template) def __str__(self) -> str: @@ -59,16 +59,16 @@ def __str__(self) -> str: class NestedPath(Path): - def __init__(self, path: List[str]): + def __init__(self, path: list[str]): # noqa: ANN204 self._path = path - def write(self, template: Dict[str, Any], value: Any) -> None: + def write(self, template: dict[str, Any], value: Any) -> None: # noqa: ANN401 _write(template, self._path, value) - def update(self, template: Dict[str, Any], value: Any) -> None: + def update(self, template: dict[str, Any], value: Any) -> None: # noqa: ANN401 _replace_value(template, self._path, value) - def extract(self, template: Dict[str, Any]) -> Any: + def extract(self, template: dict[str, Any]) -> Any: # noqa: ANN401 return _extract(self._path, template) def __str__(self) -> str: @@ -77,25 +77,25 @@ def __str__(self) -> str: class PaginationStrategy(ABC): @abstractmethod - def update(self, response: Dict[str, Any]) -> None: + def update(self, response: dict[str, Any]) -> None: pass class FieldUpdatePaginationStrategy(PaginationStrategy): - def __init__(self, path: Path, value: Any): + def __init__(self, path: Path, value: Any): # noqa: ANN204, ANN401 self._path = path self._value = value - def update(self, response: Dict[str, Any]) -> None: + def update(self, response: dict[str, Any]) -> None: self._path.update(response, self._value) class RecordBuilder: - def __init__( + def __init__( # noqa: ANN204 self, - template: Dict[str, Any], - id_path: Optional[Path], - cursor_path: Optional[Union[FieldPath, NestedPath]], + template: dict[str, Any], + id_path: Path | None, + cursor_path: FieldPath | NestedPath | None, ): self._record = template self._id_path = id_path @@ -111,7 +111,7 @@ def _validate_template(self) -> None: for field_name, field_path in paths_to_validate: self._validate_field(field_name, field_path) - def _validate_field(self, field_name: str, path: Optional[Path]) -> None: + def _validate_field(self, field_name: str, path: Path | None) -> None: try: if path and not path.extract(self._record): raise ValueError( @@ -122,19 +122,19 @@ def _validate_field(self, field_name: str, path: Optional[Path]) -> None: f"{field_name} `{path}` was provided but it is not part of the template `{self._record}`" ) from exception - def with_id(self, identifier: Any) -> "RecordBuilder": + def with_id(self, identifier: Any) -> "RecordBuilder": # noqa: ANN401 self._set_field("id", self._id_path, identifier) return self - def with_cursor(self, cursor_value: Any) -> "RecordBuilder": + def with_cursor(self, cursor_value: Any) -> "RecordBuilder": # noqa: ANN401 self._set_field("cursor", self._cursor_path, cursor_value) return self - def with_field(self, path: Path, value: Any) -> "RecordBuilder": + def with_field(self, path: Path, value: Any) -> "RecordBuilder": # noqa: ANN401 path.write(self._record, value) return self - def _set_field(self, field_name: str, path: Optional[Path], value: Any) -> None: + def _set_field(self, field_name: str, path: Path | None, value: Any) -> None: # noqa: ANN401 if not path: raise ValueError( f"{field_name}_path was not provided and hence, the record {field_name} can't be modified. 
Please provide `id_field` while " @@ -142,19 +142,19 @@ def _set_field(self, field_name: str, path: Optional[Path], value: Any) -> None: ) path.update(self._record, value) - def build(self) -> Dict[str, Any]: + def build(self) -> dict[str, Any]: return self._record class HttpResponseBuilder: - def __init__( + def __init__( # noqa: ANN204 self, - template: Dict[str, Any], - records_path: Union[FieldPath, NestedPath], - pagination_strategy: Optional[PaginationStrategy], + template: dict[str, Any], + records_path: FieldPath | NestedPath, + pagination_strategy: PaginationStrategy | None, ): self._response = template - self._records: List[RecordBuilder] = [] + self._records: list[RecordBuilder] = [] self._records_path = records_path self._pagination_strategy = pagination_strategy self._status_code = 200 @@ -182,11 +182,11 @@ def build(self) -> HttpResponse: def _get_unit_test_folder(execution_folder: str) -> FilePath: - # FIXME: This function should be removed after the next CDK release to avoid breaking amazon-seller-partner test code. + # FIXME: This function should be removed after the next CDK release to avoid breaking amazon-seller-partner test code. # noqa: FIX001, TD001 return get_unit_test_folder(execution_folder) -def find_template(resource: str, execution_folder: str) -> Dict[str, Any]: +def find_template(resource: str, execution_folder: str) -> dict[str, Any]: response_template_filepath = str( get_unit_test_folder(execution_folder) / "resource" @@ -194,15 +194,15 @@ def find_template(resource: str, execution_folder: str) -> Dict[str, Any]: / "response" / f"{resource}.json" ) - with open(response_template_filepath, "r") as template_file: + with open(response_template_filepath) as template_file: # noqa: PLW1514, PTH123 return json.load(template_file) # type: ignore # we assume the dev correctly set up the resource file def create_record_builder( - response_template: Dict[str, Any], - records_path: Union[FieldPath, NestedPath], - record_id_path: Optional[Path] = None, - record_cursor_path: Optional[Union[FieldPath, NestedPath]] = None, + response_template: dict[str, Any], + records_path: FieldPath | NestedPath, + record_id_path: Path | None = None, + record_cursor_path: FieldPath | NestedPath | None = None, ) -> RecordBuilder: """ This will use the first record define at `records_path` as a template for the records. If more records are defined, they will be ignored @@ -216,14 +216,14 @@ def create_record_builder( ) return RecordBuilder(record_template, record_id_path, record_cursor_path) except (IndexError, KeyError): - raise ValueError( + raise ValueError( # noqa: B904 f"Error while extracting records at path `{records_path}` from response template `{response_template}`" ) def create_response_builder( - response_template: Dict[str, Any], - records_path: Union[FieldPath, NestedPath], - pagination_strategy: Optional[PaginationStrategy] = None, + response_template: dict[str, Any], + records_path: FieldPath | NestedPath, + pagination_strategy: PaginationStrategy | None = None, ) -> HttpResponseBuilder: return HttpResponseBuilder(response_template, records_path, pagination_strategy) diff --git a/airbyte_cdk/test/state_builder.py b/airbyte_cdk/test/state_builder.py index a1315cf4e..a262960ae 100644 --- a/airbyte_cdk/test/state_builder.py +++ b/airbyte_cdk/test/state_builder.py @@ -1,6 +1,6 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
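For reviewers less familiar with the response-builder helpers retyped above, a minimal sketch; the template dict is hypothetical (it would normally come from `find_template`), and records are usually attached to the response via a `with_record` helper that the hunk above elides.

from airbyte_cdk.test.mock_http.response_builder import (
    FieldPath,
    create_record_builder,
    create_response_builder,
)

template = {"items": [{"id": "1", "updated_at": "2024-01-01T00:00:00Z"}]}

# Build a record from the first entry under `items`, overriding id and cursor.
record = (
    create_record_builder(
        response_template=template,
        records_path=FieldPath("items"),
        record_id_path=FieldPath("id"),
        record_cursor_path=FieldPath("updated_at"),
    )
    .with_id("42")
    .with_cursor("2024-02-01T00:00:00Z")
    .build()
)
assert record["id"] == "42"

# Build an HttpResponse (status 200 by default) from the same template.
response = create_response_builder(
    response_template=template,
    records_path=FieldPath("items"),
).build()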
-from typing import Any, List +from typing import Any from airbyte_cdk.models import ( AirbyteStateBlob, @@ -13,9 +13,9 @@ class StateBuilder: def __init__(self) -> None: - self._state: List[AirbyteStateMessage] = [] + self._state: list[AirbyteStateMessage] = [] - def with_stream_state(self, stream_name: str, state: Any) -> "StateBuilder": + def with_stream_state(self, stream_name: str, state: Any) -> "StateBuilder": # noqa: ANN401 self._state.append( AirbyteStateMessage( type=AirbyteStateType.STREAM, @@ -23,11 +23,11 @@ def with_stream_state(self, stream_name: str, state: Any) -> "StateBuilder": stream_state=state if isinstance(state, AirbyteStateBlob) else AirbyteStateBlob(state), - stream_descriptor=StreamDescriptor(**{"name": stream_name}), + stream_descriptor=StreamDescriptor(name=stream_name), ), ) ) return self - def build(self) -> List[AirbyteStateMessage]: + def build(self) -> list[AirbyteStateMessage]: return self._state diff --git a/airbyte_cdk/test/utils/data.py b/airbyte_cdk/test/utils/data.py index 6aaeb8394..fb862f92c 100644 --- a/airbyte_cdk/test/utils/data.py +++ b/airbyte_cdk/test/utils/data.py @@ -6,7 +6,7 @@ def get_unit_test_folder(execution_folder: str) -> FilePath: path = FilePath(execution_folder) while path.name != "unit_tests": - if path.name == path.root or path.name == path.drive: + if path.name == path.root or path.name == path.drive: # noqa: PLR1714 raise ValueError( f"Could not find `unit_tests` folder as a parent of {execution_folder}" ) @@ -19,6 +19,6 @@ def read_resource_file_contents(resource: str, test_location: str) -> str: file_path = str( get_unit_test_folder(test_location) / "resource" / "http" / "response" / f"{resource}" ) - with open(file_path) as f: + with open(file_path) as f: # noqa: FURB101, PLW1514, PTH123 response = f.read() return response diff --git a/airbyte_cdk/test/utils/http_mocking.py b/airbyte_cdk/test/utils/http_mocking.py index 7fd1419fc..327803728 100644 --- a/airbyte_cdk/test/utils/http_mocking.py +++ b/airbyte_cdk/test/utils/http_mocking.py @@ -1,7 +1,8 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. import re -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from requests_mock import Mocker diff --git a/airbyte_cdk/test/utils/manifest_only_fixtures.py b/airbyte_cdk/test/utils/manifest_only_fixtures.py index 28015d05b..1bf063baf 100644 --- a/airbyte_cdk/test/utils/manifest_only_fixtures.py +++ b/airbyte_cdk/test/utils/manifest_only_fixtures.py @@ -7,6 +7,7 @@ import pytest + # The following fixtures are used to load a manifest-only connector's components module and manifest file. # They can be accessed from any test file in the connector's unit_tests directory by importing them as follows: diff --git a/airbyte_cdk/test/utils/reading.py b/airbyte_cdk/test/utils/reading.py index 2d89cb870..18d62a4b6 100644 --- a/airbyte_cdk/test/utils/reading.py +++ b/airbyte_cdk/test/utils/reading.py @@ -1,6 +1,7 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
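A short sketch of the StateBuilder surface kept by the hunk above; the stream name and cursor value are made up.

from airbyte_cdk.test.state_builder import StateBuilder

# Produces a list[AirbyteStateMessage] suitable for entrypoint_wrapper.read(...).
state = (
    StateBuilder()
    .with_stream_state("orders", {"updated_at": "2024-01-01T00:00:00Z"})
    .build()
)
assert len(state) == 1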
-from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any from airbyte_cdk import AbstractSource from airbyte_cdk.models import AirbyteStateMessage, ConfiguredAirbyteCatalog, SyncMode @@ -18,9 +19,9 @@ def read_records( config: Mapping[str, Any], stream_name: str, sync_mode: SyncMode, - state: Optional[List[AirbyteStateMessage]] = None, - expecting_exception: bool = False, + state: list[AirbyteStateMessage] | None = None, + expecting_exception: bool = False, # noqa: FBT001, FBT002 ) -> EntrypointOutput: """Read records from a stream.""" - _catalog = catalog(stream_name, sync_mode) + _catalog = catalog(stream_name, sync_mode) # noqa: RUF052 return read(source, config, _catalog, state, expecting_exception) diff --git a/airbyte_cdk/utils/__init__.py b/airbyte_cdk/utils/__init__.py index dbfb641dd..f6d44525d 100644 --- a/airbyte_cdk/utils/__init__.py +++ b/airbyte_cdk/utils/__init__.py @@ -7,4 +7,5 @@ from .schema_inferrer import SchemaInferrer from .traced_exception import AirbyteTracedException + __all__ = ["AirbyteTracedException", "SchemaInferrer", "is_cloud_environment", "PrintBuffer"] diff --git a/airbyte_cdk/utils/airbyte_secrets_utils.py b/airbyte_cdk/utils/airbyte_secrets_utils.py index bb5a6be59..366407b86 100644 --- a/airbyte_cdk/utils/airbyte_secrets_utils.py +++ b/airbyte_cdk/utils/airbyte_secrets_utils.py @@ -2,15 +2,16 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, List, Mapping +from collections.abc import Mapping +from typing import Any import dpath -def get_secret_paths(spec: Mapping[str, Any]) -> List[List[str]]: +def get_secret_paths(spec: Mapping[str, Any]) -> list[list[str]]: paths = [] - def traverse_schema(schema_item: Any, path: List[str]) -> None: + def traverse_schema(schema_item: Any, path: list[str]) -> None: # noqa: ANN401 """ schema_item can be any property or value in the originally input jsonschema, depending on how far down the recursion stack we go path is the path to that schema item in the original input @@ -27,10 +28,9 @@ def traverse_schema(schema_item: Any, path: List[str]) -> None: elif isinstance(schema_item, list): for i in schema_item: traverse_schema(i, path) - else: - if path[-1] == "airbyte_secret" and schema_item is True: - filtered_path = [p for p in path[:-1] if p not in ["properties", "oneOf"]] - paths.append(filtered_path) + elif path[-1] == "airbyte_secret" and schema_item is True: + filtered_path = [p for p in path[:-1] if p not in ["properties", "oneOf"]] + paths.append(filtered_path) traverse_schema(spec, []) return paths @@ -38,7 +38,7 @@ def traverse_schema(schema_item: Any, path: List[str]) -> None: def get_secrets( connection_specification: Mapping[str, Any], config: Mapping[str, Any] -) -> List[Any]: +) -> list[Any]: """ Get a list of secret values from the source config based on the source specification :type connection_specification: the connection_specification field of an AirbyteSpecification i.e the JSONSchema definition @@ -46,7 +46,7 @@ def get_secrets( secret_paths = get_secret_paths(connection_specification.get("properties", {})) result = [] for path in secret_paths: - try: + try: # noqa: SIM105 result.append(dpath.get(config, path)) # type: ignore # dpath expect MutableMapping but doesn't need it except KeyError: # Since we try to get paths to all known secrets in the spec, in the case of oneOfs, some secret fields may not be present @@ -55,10 +55,10 @@ def get_secrets( return result -__SECRETS_FROM_CONFIG: List[str] = [] 
+__SECRETS_FROM_CONFIG: list[str] = [] -def update_secrets(secrets: List[str]) -> None: +def update_secrets(secrets: list[str]) -> None: """Update the list of secrets to be replaced""" global __SECRETS_FROM_CONFIG __SECRETS_FROM_CONFIG = secrets @@ -66,13 +66,13 @@ def update_secrets(secrets: List[str]) -> None: def add_to_secrets(secret: str) -> None: """Add to the list of secrets to be replaced""" - global __SECRETS_FROM_CONFIG + global __SECRETS_FROM_CONFIG # noqa: PLW0602 __SECRETS_FROM_CONFIG.append(secret) def filter_secrets(string: str) -> str: """Filter secrets from a string by replacing them with ****""" - # TODO this should perform a maximal match for each secret. if "x" and "xk" are both secret values, and this method is called twice on + # TODO this should perform a maximal match for each secret. if "x" and "xk" are both secret values, and this method is called twice on # noqa: TD004 # the input "xk", then depending on call order it might only obfuscate "*k". This is a bug. for secret in __SECRETS_FROM_CONFIG: if secret: diff --git a/airbyte_cdk/utils/analytics_message.py b/airbyte_cdk/utils/analytics_message.py index 82a074913..e5d7cfcd5 100644 --- a/airbyte_cdk/utils/analytics_message.py +++ b/airbyte_cdk/utils/analytics_message.py @@ -1,7 +1,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. import time -from typing import Any, Optional +from typing import Any from airbyte_cdk.models import ( AirbyteAnalyticsTraceMessage, @@ -12,7 +12,7 @@ ) -def create_analytics_message(type: str, value: Optional[Any]) -> AirbyteMessage: +def create_analytics_message(type: str, value: Any | None) -> AirbyteMessage: # noqa: ANN401, A002 return AirbyteMessage( type=Type.TRACE, trace=AirbyteTraceMessage( diff --git a/airbyte_cdk/utils/datetime_format_inferrer.py b/airbyte_cdk/utils/datetime_format_inferrer.py index 28eaefa31..71c36fa07 100644 --- a/airbyte_cdk/utils/datetime_format_inferrer.py +++ b/airbyte_cdk/utils/datetime_format_inferrer.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, Dict, Optional +from typing import Any from airbyte_cdk.models import AirbyteRecordMessage from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser @@ -15,7 +15,7 @@ class DatetimeFormatInferrer: def __init__(self) -> None: self._parser = DatetimeParser() - self._datetime_candidates: Optional[Dict[str, str]] = None + self._datetime_candidates: dict[str, str] | None = None self._formats = [ "%Y-%m-%d", "%Y-%m-%d %H:%M:%S", @@ -34,12 +34,12 @@ def __init__(self) -> None: range(1_000_000_000_000, 2_000_000_000_000), ] - def _can_be_datetime(self, value: Any) -> bool: + def _can_be_datetime(self, value: Any) -> bool: # noqa: ANN401 """Checks if the value can be a datetime. This is the case if the value is a string or an integer between 1_000_000_000 and 2_000_000_000 for seconds or between 1_000_000_000_000 and 2_000_000_000_000 for milliseconds. 
This is separate from the format check for performance reasons""" - if isinstance(value, (str, int)): + if isinstance(value, (str, int)): # noqa: UP038 try: value_as_int = int(value) for timestamp_range in self._timestamp_heuristic_ranges: @@ -50,11 +50,11 @@ def _can_be_datetime(self, value: Any) -> bool: return True return False - def _matches_format(self, value: Any, format: str) -> bool: + def _matches_format(self, value: Any, format: str) -> bool: # noqa: ANN401, A002 """Checks if the value matches the format""" try: self._parser.parse(value, format) - return True + return True # noqa: TRY300 except ValueError: return False @@ -64,7 +64,7 @@ def _initialize(self, record: AirbyteRecordMessage) -> None: for field_name, field_value in record.data.items(): if not self._can_be_datetime(field_value): continue - for format in self._formats: + for format in self._formats: # noqa: A001 if self._matches_format(field_value, format): self._datetime_candidates[field_name] = format break @@ -86,7 +86,7 @@ def accumulate(self, record: AirbyteRecordMessage) -> None: """Analyzes the record and updates the internal state of candidate datetime fields""" self._initialize(record) if self._datetime_candidates is None else self._validate(record) - def get_inferred_datetime_formats(self) -> Dict[str, str]: + def get_inferred_datetime_formats(self) -> dict[str, str]: """ Returns the list of candidate datetime fields - the keys are the field names and the values are the inferred datetime formats. For these fields the format was consistent across all visited records. diff --git a/airbyte_cdk/utils/event_timing.py b/airbyte_cdk/utils/event_timing.py index 3f489c096..70c195cfb 100644 --- a/airbyte_cdk/utils/event_timing.py +++ b/airbyte_cdk/utils/event_timing.py @@ -5,9 +5,11 @@ import datetime import logging import time +from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any, Generator, Literal, Optional +from typing import Any, Literal + logger = logging.getLogger("airbyte") @@ -60,7 +62,7 @@ def report(self, order_by: Literal["name", "duration"] = "name") -> str: class Event: name: str start: float = field(default_factory=time.perf_counter_ns) - end: Optional[float] = field(default=None) + end: float | None = field(default=None) @property def duration(self) -> float: diff --git a/airbyte_cdk/utils/is_cloud_environment.py b/airbyte_cdk/utils/is_cloud_environment.py index 25b1eee87..0c4e45a9b 100644 --- a/airbyte_cdk/utils/is_cloud_environment.py +++ b/airbyte_cdk/utils/is_cloud_environment.py @@ -4,6 +4,7 @@ import os + CLOUD_DEPLOYMENT_MODE = "cloud" diff --git a/airbyte_cdk/utils/mapping_helpers.py b/airbyte_cdk/utils/mapping_helpers.py index 469fb5e0a..e2676f377 100644 --- a/airbyte_cdk/utils/mapping_helpers.py +++ b/airbyte_cdk/utils/mapping_helpers.py @@ -3,12 +3,13 @@ # -from typing import Any, List, Mapping, Optional, Set, Union +from collections.abc import Mapping +from typing import Any def combine_mappings( - mappings: List[Optional[Union[Mapping[str, Any], str]]], -) -> Union[Mapping[str, Any], str]: + mappings: list[Mapping[str, Any] | str | None], +) -> Mapping[str, Any] | str: """ Combine multiple mappings into a single mapping. If any of the mappings are a string, return that string. 
Raise errors in the following cases: @@ -16,7 +17,7 @@ def combine_mappings( * If there are multiple string mappings * If there are multiple mappings containing keys and one of them is a string """ - all_keys: List[Set[str]] = [] + all_keys: list[set[str]] = [] for part in mappings: if part is None: continue diff --git a/airbyte_cdk/utils/oneof_option_config.py b/airbyte_cdk/utils/oneof_option_config.py index 17ebf0511..7ff931065 100644 --- a/airbyte_cdk/utils/oneof_option_config.py +++ b/airbyte_cdk/utils/oneof_option_config.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, Dict +from typing import Any class OneOfOptionConfig: @@ -26,7 +26,7 @@ class Config(OneOfOptionConfig): """ @staticmethod - def schema_extra(schema: Dict[str, Any], model: Any) -> None: + def schema_extra(schema: dict[str, Any], model: Any) -> None: # noqa: ANN401 if hasattr(model.Config, "description"): schema["description"] = model.Config.description if hasattr(model.Config, "discriminator"): diff --git a/airbyte_cdk/utils/print_buffer.py b/airbyte_cdk/utils/print_buffer.py index ae5a2020c..710297566 100644 --- a/airbyte_cdk/utils/print_buffer.py +++ b/airbyte_cdk/utils/print_buffer.py @@ -5,7 +5,6 @@ from io import StringIO from threading import RLock from types import TracebackType -from typing import Optional class PrintBuffer: @@ -37,7 +36,7 @@ class PrintBuffer: Exits the runtime context and restores the original stdout and stderr. """ - def __init__(self, flush_interval: float = 0.1): + def __init__(self, flush_interval: float = 0.1): # noqa: ANN204 self.buffer = StringIO() self.flush_interval = flush_interval self.last_flush_time = time.monotonic() @@ -67,9 +66,9 @@ def __enter__(self) -> "PrintBuffer": def __exit__( self, - exc_type: Optional[BaseException], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: BaseException | None, # noqa: PYI036 + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: self.flush() sys.stdout, sys.stderr = self.old_stdout, self.old_stderr diff --git a/airbyte_cdk/utils/schema_inferrer.py b/airbyte_cdk/utils/schema_inferrer.py index f3c6b2fae..54e703cd1 100644 --- a/airbyte_cdk/utils/schema_inferrer.py +++ b/airbyte_cdk/utils/schema_inferrer.py @@ -3,7 +3,8 @@ # from collections import defaultdict -from typing import Any, Dict, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any from genson import SchemaBuilder, SchemaNode from genson.schema.strategies.object import Object @@ -11,6 +12,7 @@ from airbyte_cdk.models import AirbyteRecordMessage + # schema keywords _TYPE = "type" _NULL_TYPE = "null" @@ -31,7 +33,7 @@ class NoRequiredObj(Object): """ def to_schema(self) -> Mapping[str, Any]: - schema: Dict[str, Any] = super(NoRequiredObj, self).to_schema() + schema: dict[str, Any] = super(NoRequiredObj, self).to_schema() # noqa: UP008 schema.pop("required", None) return schema @@ -41,7 +43,7 @@ class IntegerToNumber(Number): This class has the regular Number behaviour, but it will never emit an integer type. 
""" - def __init__(self, node_class: SchemaNode): + def __init__(self, node_class: SchemaNode): # noqa: ANN204 super().__init__(node_class) self._type = "number" @@ -51,21 +53,21 @@ class NoRequiredSchemaBuilder(SchemaBuilder): # This type is inferred from the genson lib, but there is no alias provided for it - creating it here for type safety -InferredSchema = Dict[str, Any] +InferredSchema = dict[str, Any] class SchemaValidationException(Exception): @classmethod def merge_exceptions( - cls, exceptions: List["SchemaValidationException"] + cls, exceptions: list["SchemaValidationException"] ) -> "SchemaValidationException": # We assume the schema is the same for all SchemaValidationException return SchemaValidationException( exceptions[0].schema, - [x for exception in exceptions for x in exception._validation_errors], + [x for exception in exceptions for x in exception._validation_errors], # noqa: SLF001 ) - def __init__(self, schema: InferredSchema, validation_errors: List[Exception]): + def __init__(self, schema: InferredSchema, validation_errors: list[Exception]): # noqa: ANN204 self._schema = schema self._validation_errors = validation_errors @@ -74,8 +76,8 @@ def schema(self) -> InferredSchema: return self._schema @property - def validation_errors(self) -> List[str]: - return list(map(lambda error: str(error), self._validation_errors)) + def validation_errors(self) -> list[str]: + return list(map(lambda error: str(error), self._validation_errors)) # noqa: C417 class SchemaInferrer: @@ -88,10 +90,10 @@ class SchemaInferrer: """ - stream_to_builder: Dict[str, SchemaBuilder] + stream_to_builder: dict[str, SchemaBuilder] def __init__( - self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None + self, pk: list[list[str]] | None = None, cursor_field: list[list[str]] | None = None ) -> None: self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder) self._pk = [] if pk is None else pk @@ -104,15 +106,14 @@ def accumulate(self, record: AirbyteRecordMessage) -> None: def _null_type_in_any_of(self, node: InferredSchema) -> bool: if _ANY_OF in node: return {_TYPE: _NULL_TYPE} in node[_ANY_OF] - else: - return False + return False def _remove_type_from_any_of(self, node: InferredSchema) -> None: if _ANY_OF in node: node.pop(_TYPE, None) def _clean_any_of(self, node: InferredSchema) -> None: - if len(node[_ANY_OF]) == 2 and self._null_type_in_any_of(node): + if len(node[_ANY_OF]) == 2 and self._null_type_in_any_of(node): # noqa: PLR2004 real_type = ( node[_ANY_OF][1] if node[_ANY_OF][0][_TYPE] == _NULL_TYPE else node[_ANY_OF][0] ) @@ -120,7 +121,7 @@ def _clean_any_of(self, node: InferredSchema) -> None: node[_TYPE] = [node[_TYPE], _NULL_TYPE] node.pop(_ANY_OF) # populate `type` for `anyOf` if it's not present to pass all other checks - elif len(node[_ANY_OF]) == 2 and not self._null_type_in_any_of(node): + elif len(node[_ANY_OF]) == 2 and not self._null_type_in_any_of(node): # noqa: PLR2004 node[_TYPE] = [_NULL_TYPE] def _clean_properties(self, node: InferredSchema) -> None: @@ -184,11 +185,11 @@ def _add_required_properties(self, node: InferredSchema) -> InferredSchema: return node - def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List[str]]) -> None: + def _add_fields_as_required(self, node: InferredSchema, composite_key: list[list[str]]) -> None: """ Take a list of nested keys (this list represents a composite key) and travel the schema to mark every node as required. 
""" - errors: List[Exception] = [] + errors: list[Exception] = [] for path in composite_key: try: @@ -200,7 +201,7 @@ def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List raise SchemaValidationException(node, errors) def _add_field_as_required( - self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None + self, node: InferredSchema, path: list[str], traveled_path: list[str] | None = None ) -> None: """ Take a nested key and travel the schema to mark every node as required. @@ -247,7 +248,7 @@ def _add_field_as_required( traveled_path.append(next_node) self._add_field_as_required(node[_PROPERTIES][next_node], path[1:], traveled_path) - def _is_leaf(self, path: List[str]) -> bool: + def _is_leaf(self, path: list[str]) -> bool: return len(path) == 0 def _remove_null_from_type(self, node: InferredSchema) -> None: @@ -257,7 +258,7 @@ def _remove_null_from_type(self, node: InferredSchema) -> None: if len(node[_TYPE]) == 1: node[_TYPE] = node[_TYPE][0] - def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]: + def get_stream_schema(self, stream_name: str) -> InferredSchema | None: """ Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name. """ diff --git a/airbyte_cdk/utils/slice_hasher.py b/airbyte_cdk/utils/slice_hasher.py index 7f46dd768..3e5e0a66d 100644 --- a/airbyte_cdk/utils/slice_hasher.py +++ b/airbyte_cdk/utils/slice_hasher.py @@ -1,10 +1,11 @@ import hashlib import json -from typing import Any, Final, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Final class SliceEncoder(json.JSONEncoder): - def default(self, obj: Any) -> Any: + def default(self, obj: Any) -> Any: # noqa: ANN401 if hasattr(obj, "__json_serializable__"): return obj.__json_serializable__() @@ -16,13 +17,13 @@ class SliceHasher: _ENCODING: Final = "utf-8" @classmethod - def hash(cls, stream_name: str, stream_slice: Optional[Mapping[str, Any]] = None) -> int: + def hash(cls, stream_name: str, stream_slice: Mapping[str, Any] | None = None) -> int: if stream_slice: try: s = json.dumps(stream_slice, sort_keys=True, cls=SliceEncoder) hash_input = f"{stream_name}:{s}".encode(cls._ENCODING) except TypeError as e: - raise ValueError(f"Failed to serialize stream slice: {e}") + raise ValueError(f"Failed to serialize stream slice: {e}") # noqa: B904 else: hash_input = stream_name.encode(cls._ENCODING) diff --git a/airbyte_cdk/utils/stream_status_utils.py b/airbyte_cdk/utils/stream_status_utils.py index 49c07f49c..00f91a229 100644 --- a/airbyte_cdk/utils/stream_status_utils.py +++ b/airbyte_cdk/utils/stream_status_utils.py @@ -4,7 +4,6 @@ from datetime import datetime -from typing import List, Optional, Union from airbyte_cdk.models import ( AirbyteMessage, @@ -20,9 +19,9 @@ def as_airbyte_message( - stream: Union[AirbyteStream, StreamDescriptor], + stream: AirbyteStream | StreamDescriptor, current_status: AirbyteStreamStatus, - reasons: Optional[List[AirbyteStreamStatusReason]] = None, + reasons: list[AirbyteStreamStatusReason] | None = None, ) -> AirbyteMessage: """ Builds an AirbyteStreamStatusTraceMessage for the provided stream diff --git a/airbyte_cdk/utils/traced_exception.py b/airbyte_cdk/utils/traced_exception.py index 59dbab2a5..f2d028d11 100644 --- a/airbyte_cdk/utils/traced_exception.py +++ b/airbyte_cdk/utils/traced_exception.py @@ -3,7 +3,7 @@ # import time import traceback -from typing import Any, Optional +from typing import Any import 
orjson @@ -27,13 +27,13 @@ class AirbyteTracedException(Exception): An exception that should be emitted as an AirbyteTraceMessage """ - def __init__( + def __init__( # noqa: ANN204 self, - internal_message: Optional[str] = None, - message: Optional[str] = None, + internal_message: str | None = None, + message: str | None = None, failure_type: FailureType = FailureType.system_error, - exception: Optional[BaseException] = None, - stream_descriptor: Optional[StreamDescriptor] = None, + exception: BaseException | None = None, + stream_descriptor: StreamDescriptor | None = None, ): """ :param internal_message: the internal error that caused the failure @@ -50,7 +50,7 @@ def __init__( super().__init__(internal_message) def as_airbyte_message( - self, stream_descriptor: Optional[StreamDescriptor] = None + self, stream_descriptor: StreamDescriptor | None = None ) -> AirbyteMessage: """ Builds an AirbyteTraceMessage from the exception @@ -80,7 +80,7 @@ def as_airbyte_message( return AirbyteMessage(type=MessageType.TRACE, trace=trace_message) - def as_connection_status_message(self) -> Optional[AirbyteMessage]: + def as_connection_status_message(self) -> AirbyteMessage | None: if self.failure_type == FailureType.config_error: return AirbyteMessage( type=MessageType.CONNECTION_STATUS, @@ -103,9 +103,9 @@ def emit_message(self) -> None: def from_exception( cls, exc: BaseException, - stream_descriptor: Optional[StreamDescriptor] = None, - *args: Any, - **kwargs: Any, + stream_descriptor: StreamDescriptor | None = None, + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 ) -> "AirbyteTracedException": """ Helper to create an AirbyteTracedException from an existing exception @@ -116,12 +116,12 @@ def from_exception( internal_message=str(exc), exception=exc, stream_descriptor=stream_descriptor, - *args, + *args, # noqa: B026 **kwargs, ) # type: ignore # ignoring because of args and kwargs def as_sanitized_airbyte_message( - self, stream_descriptor: Optional[StreamDescriptor] = None + self, stream_descriptor: StreamDescriptor | None = None ) -> AirbyteMessage: """ Builds an AirbyteTraceMessage from the exception and sanitizes any secrets from the message body diff --git a/bin/generate_component_manifest_files.py b/bin/generate_component_manifest_files.py index 43f9b568e..87cbe028d 100755 --- a/bin/generate_component_manifest_files.py +++ b/bin/generate_component_manifest_files.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
# noqa: EXE002 import sys from glob import glob @@ -7,6 +7,7 @@ import anyio import dagger + PYTHON_IMAGE = "python:3.10" LOCAL_YAML_DIR_PATH = "airbyte_cdk/sources/declarative" LOCAL_OUTPUT_DIR_PATH = "airbyte_cdk/sources/declarative/models" @@ -18,7 +19,7 @@ def get_all_yaml_files_without_ext() -> list[str]: - return [Path(f).stem for f in glob(f"{LOCAL_YAML_DIR_PATH}/*.yaml")] + return [Path(f).stem for f in glob(f"{LOCAL_YAML_DIR_PATH}/*.yaml")] # noqa: PTH207 def generate_init_module_content() -> str: @@ -28,7 +29,7 @@ def generate_init_module_content() -> str: return header -async def post_process_codegen(codegen_container: dagger.Container): +async def post_process_codegen(codegen_container: dagger.Container): # noqa: ANN201 codegen_container = codegen_container.with_exec( ["mkdir", "/generated_post_processed"], use_entrypoint=True ) @@ -47,7 +48,7 @@ async def post_process_codegen(codegen_container: dagger.Container): return codegen_container -async def main(): +async def main(): # noqa: ANN201 init_module_content = generate_init_module_content() async with dagger.Connection(dagger.Config(log_output=sys.stderr)) as dagger_client: diff --git a/docs/generate.py b/docs/generate.py index 585897715..1340efd76 100644 --- a/docs/generate.py +++ b/docs/generate.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python3 # noqa: EXE001 # Copyright (c) 2023 Airbyte, Inc., all rights reserved. """Generate docs for all public modules in the Airbyte CDK and save them to docs/generated. @@ -49,14 +49,14 @@ def run() -> None: for file_name in files: if not file_name.endswith(".py"): continue - if file_name in ["py.typed"]: + if file_name in ["py.typed"]: # noqa: FURB171 continue if file_name.startswith((".", "_")): continue - print(f"Found module file: {'|'.join([parent_dir, file_name])}") + print(f"Found module file: {'|'.join([parent_dir, file_name])}") # noqa: FLY002 module = ( - cast(str, ".".join([parent_dir, file_name])).replace("/", ".").removesuffix(".py") + cast(str, ".".join([parent_dir, file_name])).replace("/", ".").removesuffix(".py") # noqa: FLY002, TC006 ) public_modules.append(module)
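
Taken together, the hunks above apply one mechanical pattern: `typing.List`/`Dict`/`Optional`/`Union` annotations become PEP 585 builtin generics and PEP 604 unions, abstract containers such as `Mapping` and `Generator` are imported from `collections.abc`, and the findings that are deliberately kept are silenced with targeted `# noqa` codes rather than blanket ignores. A minimal before/after sketch of that style, using a hypothetical helper that is not part of the CDK and assuming a Python 3.10+ runtime:

    from collections.abc import Mapping  # previously: from typing import Mapping
    from typing import Any


    def summarize_state(config: Mapping[str, Any], cursor: str | None = None) -> list[dict[str, Any]]:
        # previously: cursor: Optional[str], return type List[Dict[str, Any]]
        if cursor is None:
            return []
        return [{"cursor": cursor, "config_keys": list(config)}]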
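
The test helpers touched here keep their public surface; only the annotations change. A short usage sketch of `StateBuilder` as it reads after this patch (the stream name and cursor value are invented for illustration):

    from airbyte_cdk.test.state_builder import StateBuilder

    state = (
        StateBuilder()
        .with_stream_state("orders", {"updated_at": "2024-01-01T00:00:00Z"})
        .build()
    )
    # build() now returns list[AirbyteStateMessage]; each with_stream_state call appends a
    # STREAM-typed message whose descriptor is built via StreamDescriptor(name=stream_name).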
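
The secrets helpers likewise keep the same call pattern; the TODO retained above notes that masking is not a maximal match when one secret is a prefix of another (the "x" vs "xk" case). A rough sketch of the intended flow, with a made-up secret value:

    from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets, update_secrets

    update_secrets(["my-api-token"])  # hypothetical secret value taken from the config
    masked = filter_secrets("Authorization: Bearer my-api-token")
    # Per the docstring, registered secrets are replaced with **** in the returned string;
    # with overlapping secrets, replacement order decides which match wins (the open bug).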
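
Finally, `SchemaInferrer` keeps its two-step flow, accumulate records and then ask for the stream's schema, with the return type now spelled `InferredSchema | None`. A sketch under the assumption that the single record shown is representative (stream name and payload are invented):

    from airbyte_cdk.models import AirbyteRecordMessage
    from airbyte_cdk.utils.schema_inferrer import SchemaInferrer

    inferrer = SchemaInferrer(pk=[["id"]])  # pk and cursor_field are lists of nested key paths
    inferrer.accumulate(
        AirbyteRecordMessage(stream="users", data={"id": 1, "email": "a@example.com"}, emitted_at=0)
    )
    schema = inferrer.get_stream_schema("users")
    # InferredSchema | None; per the composite-key handling above, "id" is marked required.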