diff --git a/airbyte_cdk/__init__.py b/airbyte_cdk/__init__.py
index 262d162cc..df6e4c873 100644
--- a/airbyte_cdk/__init__.py
+++ b/airbyte_cdk/__init__.py
@@ -176,7 +176,7 @@
InternalConfig,
ResourceSchemaLoader,
check_config_against_spec_or_exit,
- expand_refs,
+ expand_refs, # noqa: F401
split_config,
)
from .sources.utils.transform import TransformConfig, TypeTransformer
@@ -187,6 +187,7 @@
from .utils.spec_schema_transformations import resolve_refs
from .utils.stream_status_utils import as_airbyte_message
+
__all__ = [
# Availability strategy
"AvailabilityStrategy",
@@ -200,7 +201,6 @@
"ConcurrentSourceAdapter",
"Cursor",
"CursorField",
- "DEFAULT_CONCURRENCY",
"EpochValueConcurrentStreamStateConverter",
"FinalStateCursor",
"IsoMillisConcurrentStreamStateConverter",
@@ -258,7 +258,6 @@
"RequestOption",
"RequestOptionType",
"Requester",
- "ResponseStatus",
"SimpleRetriever",
"SinglePartitionRouter",
"StopConditionPaginationStrategyDecorator",
@@ -276,13 +275,11 @@
"DefaultBackoffException",
"default_backoff_handler",
"HttpAPIBudget",
- "HttpAuthenticator",
"HttpRequestMatcher",
"HttpStream",
"HttpSubStream",
"LimiterSession",
"MovingWindowCallRatePolicy",
- "MultipleTokenAuthenticator",
"Oauth2Authenticator",
"Rate",
"SingleUseRefreshTokenOauth2Authenticator",
@@ -317,7 +314,6 @@
# Stream
"IncrementalMixin",
"Stream",
- "StreamData",
"package_name_from_class",
# Utils
"AirbyteTracedException",
@@ -354,5 +350,5 @@
third_choice=_dunamai.Version.from_any_vcs,
fallback=_dunamai.Version("0.0.0+dev"),
).serialize()
-except:
+except: # noqa: E722
__version__ = "0.0.0+dev"
diff --git a/airbyte_cdk/cli/source_declarative_manifest/__init__.py b/airbyte_cdk/cli/source_declarative_manifest/__init__.py
index 0ea86fa7b..ea87aca1d 100644
--- a/airbyte_cdk/cli/source_declarative_manifest/__init__.py
+++ b/airbyte_cdk/cli/source_declarative_manifest/__init__.py
@@ -1,5 +1,7 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from airbyte_cdk.cli.source_declarative_manifest._run import run
+
__all__ = [
"run",
]
diff --git a/airbyte_cdk/cli/source_declarative_manifest/_run.py b/airbyte_cdk/cli/source_declarative_manifest/_run.py
index 5def00602..44cf69016 100644
--- a/airbyte_cdk/cli/source_declarative_manifest/_run.py
+++ b/airbyte_cdk/cli/source_declarative_manifest/_run.py
@@ -43,7 +43,7 @@
ConcurrentDeclarativeSource,
)
from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource
-from airbyte_cdk.sources.source import TState
+from airbyte_cdk.sources.source import TState # noqa: TC001
class SourceLocalYaml(YamlDeclarativeSource):
@@ -56,7 +56,7 @@ def __init__(
catalog: ConfiguredAirbyteCatalog | None,
config: Mapping[str, Any] | None,
state: TState,
- **kwargs: Any,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> None:
"""
HACK!
@@ -77,7 +77,7 @@ def __init__(
)
-def _is_local_manifest_command(args: list[str]) -> bool:
+def _is_local_manifest_command(args: list[str]) -> bool: # noqa: ARG001
# Check for a local manifest.yaml file
return Path("/airbyte/integration_code/source_declarative_manifest/manifest.yaml").exists()
@@ -111,7 +111,7 @@ def _get_local_yaml_source(args: list[str]) -> SourceLocalYaml:
)
).decode()
)
- raise error
+ raise error # noqa: TRY201
def handle_local_manifest_command(args: list[str]) -> None:
@@ -167,12 +167,12 @@ def create_declarative_source(
state: list[AirbyteStateMessage]
config, catalog, state = _parse_inputs_into_config_catalog_state(args)
if config is None or "__injected_declarative_manifest" not in config:
- raise ValueError(
+ raise ValueError( # noqa: TRY301
"Invalid config: `__injected_declarative_manifest` should be provided at the root "
f"of the config but config only has keys: {list(config.keys() if config else [])}"
)
if not isinstance(config["__injected_declarative_manifest"], dict):
- raise ValueError(
+ raise ValueError( # noqa: TRY004, TRY301
"Invalid config: `__injected_declarative_manifest` should be a dictionary, "
f"but got type: {type(config['__injected_declarative_manifest'])}"
)
@@ -181,7 +181,7 @@ def create_declarative_source(
config=config,
catalog=catalog,
state=state,
- source_config=cast(dict[str, Any], config["__injected_declarative_manifest"]),
+ source_config=cast(dict[str, Any], config["__injected_declarative_manifest"]), # noqa: TC006
)
except Exception as error:
print(
@@ -201,7 +201,7 @@ def create_declarative_source(
)
).decode()
)
- raise error
+ raise error # noqa: TRY201
def _parse_inputs_into_config_catalog_state(
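The `raise error  # noqa: TRY201` markers above keep the explicit re-raise after logging; the spelling TRY201 prefers is a bare `raise` inside the handler, which re-raises the active exception with its original traceback. A small sketch of that alternative:

    def parse_or_log(text: str) -> int:
        try:
            return int(text)
        except ValueError as error:
            print(f"failed to parse {text!r}: {error}")
            raise  # bare re-raise preserves the original traceback (what TRY201 prefers)
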
diff --git a/airbyte_cdk/config_observation.py b/airbyte_cdk/config_observation.py
index ae85e8277..1233786a1 100644
--- a/airbyte_cdk/config_observation.py
+++ b/airbyte_cdk/config_observation.py
@@ -7,8 +7,9 @@
)
import time
+from collections.abc import MutableMapping
from copy import copy
-from typing import Any, List, MutableMapping
+from typing import Any
import orjson
@@ -27,7 +28,7 @@ def __init__(
self,
non_observed_mapping: MutableMapping[Any, Any],
observer: ConfigObserver,
- update_on_unchanged_value: bool = True,
+ update_on_unchanged_value: bool = True, # noqa: FBT001, FBT002
) -> None:
non_observed_mapping = copy(non_observed_mapping)
self.observer = observer
@@ -38,13 +39,13 @@ def __init__(
non_observed_mapping[item] = ObservedDict(value, observer)
# Observe nested list of dicts
- if isinstance(value, List):
+ if isinstance(value, list):
for i, sub_value in enumerate(value):
if isinstance(sub_value, MutableMapping):
value[i] = ObservedDict(sub_value, observer)
super().__init__(non_observed_mapping)
- def __setitem__(self, item: Any, value: Any) -> None:
+ def __setitem__(self, item: Any, value: Any) -> None: # noqa: ANN401
"""Override dict.__setitem__ by:
1. Observing the new value if it is a dict
2. Call observer update if the new value is different from the previous one
@@ -52,11 +53,11 @@ def __setitem__(self, item: Any, value: Any) -> None:
previous_value = self.get(item)
if isinstance(value, MutableMapping):
value = ObservedDict(value, self.observer)
- if isinstance(value, List):
+ if isinstance(value, list):
for i, sub_value in enumerate(value):
if isinstance(sub_value, MutableMapping):
value[i] = ObservedDict(sub_value, self.observer)
- super(ObservedDict, self).__setitem__(item, value)
+ super(ObservedDict, self).__setitem__(item, value) # noqa: UP008
if self.update_on_unchanged_value or value != previous_value:
self.observer.update()
@@ -77,7 +78,7 @@ def observe_connector_config(
non_observed_connector_config: MutableMapping[str, Any],
) -> ObservedDict:
if isinstance(non_observed_connector_config, ObservedDict):
- raise ValueError("This connector configuration is already observed")
+ raise ValueError("This connector configuration is already observed") # noqa: TRY004
connector_config_observer = ConfigObserver()
observed_connector_config = ObservedDict(
non_observed_connector_config, connector_config_observer
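The `ObservedDict` changes above swap `typing.List` checks for the builtin `list` and keep the observer callback firing on every write, including writes into nested dicts. A rough standalone sketch of that observe-on-write idea, using a stand-in observer rather than the CDK's ConfigObserver:

    from collections.abc import MutableMapping
    from typing import Any


    class PrintObserver:
        """Stand-in observer: just reports that the config changed."""

        def update(self) -> None:
            print("config changed")


    class WatchedDict(dict):
        """Minimal observe-on-write dict: wraps nested mappings and notifies on __setitem__."""

        def __init__(self, mapping: MutableMapping[Any, Any], observer: PrintObserver) -> None:
            self.observer = observer
            super().__init__(
                {k: WatchedDict(v, observer) if isinstance(v, MutableMapping) else v
                 for k, v in mapping.items()}
            )

        def __setitem__(self, key: Any, value: Any) -> None:
            if isinstance(value, MutableMapping):
                value = WatchedDict(value, self.observer)
            super().__setitem__(key, value)
            self.observer.update()


    config = WatchedDict({"credentials": {"token": "abc"}}, PrintObserver())
    config["credentials"]["token"] = "def"  # prints "config changed"
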
diff --git a/airbyte_cdk/connector.py b/airbyte_cdk/connector.py
index 342ecee2d..e534e5d49 100644
--- a/airbyte_cdk/connector.py
+++ b/airbyte_cdk/connector.py
@@ -8,7 +8,8 @@
import os
import pkgutil
from abc import ABC, abstractmethod
-from typing import Any, Generic, Mapping, Optional, Protocol, TypeVar
+from collections.abc import Mapping
+from typing import Any, Generic, Protocol, TypeVar
import yaml
@@ -19,7 +20,7 @@
)
-def load_optional_package_file(package: str, filename: str) -> Optional[bytes]:
+def load_optional_package_file(package: str, filename: str) -> bytes | None:
"""Gets a resource from a package, returning None if it does not exist"""
try:
return pkgutil.get_data(package, filename)
@@ -45,29 +46,28 @@ def read_config(config_path: str) -> Mapping[str, Any]:
config = BaseConnector._read_json_file(config_path)
if isinstance(config, Mapping):
return config
- else:
- raise ValueError(
- f"The content of {config_path} is not an object and therefore is not a valid config. Please ensure the file represent a config."
- )
+ raise ValueError(
+ f"The content of {config_path} is not an object and therefore is not a valid config. Please ensure the file represent a config."
+ )
@staticmethod
- def _read_json_file(file_path: str) -> Any:
- with open(file_path, "r") as file:
+ def _read_json_file(file_path: str) -> Any: # noqa: ANN401
+ with open(file_path) as file: # noqa: FURB101, PLW1514, PTH123
contents = file.read()
try:
return json.loads(contents)
except json.JSONDecodeError as error:
- raise ValueError(
+ raise ValueError( # noqa: B904
f"Could not read json file {file_path}: {error}. Please ensure that it is a valid JSON."
)
@staticmethod
def write_config(config: TConfig, config_path: str) -> None:
- with open(config_path, "w") as fh:
+ with open(config_path, "w") as fh: # noqa: FURB103, PLW1514, PTH123
fh.write(json.dumps(config))
- def spec(self, logger: logging.Logger) -> ConnectorSpecification:
+ def spec(self, logger: logging.Logger) -> ConnectorSpecification: # noqa: ARG002
"""
Returns the spec for this integration. The spec is a JSON-Schema object describing the required configurations (e.g., username and password)
required to run this integration. By default, this will be loaded from a "spec.yaml" or a "spec.json" in the package root.
@@ -89,7 +89,7 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification:
try:
spec_obj = json.loads(json_spec)
except json.JSONDecodeError as error:
- raise ValueError(
+ raise ValueError( # noqa: B904
f"Could not read json spec file: {error}. Please ensure that it is a valid JSON."
)
else:
@@ -115,7 +115,7 @@ class DefaultConnectorMixin:
def configure(
self: _WriteConfigProtocol, config: Mapping[str, Any], temp_dir: str
) -> Mapping[str, Any]:
- config_path = os.path.join(temp_dir, "config.json")
+ config_path = os.path.join(temp_dir, "config.json") # noqa: PTH118
self.write_config(config, config_path)
return config
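Several hunks in this file add `# noqa: B904` to a `raise ValueError(...)` inside an `except` block rather than adopting the rule's suggestion. For reference, the form B904 asks for is `raise ... from error`, which chains the original exception onto the new one so neither traceback is lost; a small sketch:

    import json


    def read_json_or_fail(text: str) -> dict:
        try:
            return json.loads(text)
        except json.JSONDecodeError as error:
            # What B904 suggests: keep the original exception as the explicit cause.
            raise ValueError(f"Could not parse JSON: {error}") from error
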
diff --git a/airbyte_cdk/connector_builder/connector_builder_handler.py b/airbyte_cdk/connector_builder/connector_builder_handler.py
index b2a728570..028c97a66 100644
--- a/airbyte_cdk/connector_builder/connector_builder_handler.py
+++ b/airbyte_cdk/connector_builder/connector_builder_handler.py
@@ -3,8 +3,9 @@
#
import dataclasses
+from collections.abc import Mapping
from datetime import datetime
-from typing import Any, List, Mapping
+from typing import Any
from airbyte_cdk.connector_builder.message_grouper import MessageGrouper
from airbyte_cdk.models import (
@@ -23,6 +24,7 @@
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
DEFAULT_MAXIMUM_RECORDS = 100
@@ -69,7 +71,7 @@ def read_stream(
source: DeclarativeSource,
config: Mapping[str, Any],
configured_catalog: ConfiguredAirbyteCatalog,
- state: List[AirbyteStateMessage],
+ state: list[AirbyteStateMessage],
limits: TestReadLimits,
) -> AirbyteMessage:
try:
@@ -90,7 +92,7 @@ def read_stream(
error = AirbyteTracedException.from_exception(
exc,
message=filter_secrets(
- f"Error reading stream with config={config} and catalog={configured_catalog}: {str(exc)}"
+ f"Error reading stream with config={config} and catalog={configured_catalog}: {exc!s}"
),
)
return error.as_airbyte_message()
@@ -108,7 +110,7 @@ def resolve_manifest(source: ManifestDeclarativeSource) -> AirbyteMessage:
)
except Exception as exc:
error = AirbyteTracedException.from_exception(
- exc, message=f"Error resolving manifest: {str(exc)}"
+ exc, message=f"Error resolving manifest: {exc!s}"
)
return error.as_airbyte_message()
diff --git a/airbyte_cdk/connector_builder/main.py b/airbyte_cdk/connector_builder/main.py
index e122cee8c..453c6fffb 100644
--- a/airbyte_cdk/connector_builder/main.py
+++ b/airbyte_cdk/connector_builder/main.py
@@ -4,7 +4,8 @@
import sys
-from typing import Any, List, Mapping, Optional, Tuple
+from collections.abc import Mapping
+from typing import Any
import orjson
@@ -30,8 +31,8 @@
def get_config_and_catalog_from_args(
- args: List[str],
-) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]:
+ args: list[str],
+) -> tuple[str, Mapping[str, Any], ConfiguredAirbyteCatalog | None, Any]:
# TODO: Add functionality for the `debug` logger.
# Currently, no one `debug` level log will be displayed during `read` a stream for a connector created through `connector-builder`.
parsed_args = AirbyteEntrypoint.parse_args(args)
@@ -70,22 +71,21 @@ def handle_connector_builder_request(
source: ManifestDeclarativeSource,
command: str,
config: Mapping[str, Any],
- catalog: Optional[ConfiguredAirbyteCatalog],
- state: List[AirbyteStateMessage],
+ catalog: ConfiguredAirbyteCatalog | None,
+ state: list[AirbyteStateMessage],
limits: TestReadLimits,
) -> AirbyteMessage:
if command == "resolve_manifest":
return resolve_manifest(source)
- elif command == "test_read":
- assert (
- catalog is not None
- ), "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None."
+ if command == "test_read":
+ assert catalog is not None, (
+ "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None."
+ )
return read_stream(source, config, catalog, state, limits)
- else:
- raise ValueError(f"Unrecognized command {command}.")
+ raise ValueError(f"Unrecognized command {command}.")
-def handle_request(args: List[str]) -> str:
+def handle_request(args: list[str]) -> str:
command, config, catalog, state = get_config_and_catalog_from_args(args)
limits = get_limits(config)
source = create_source(config, limits)
@@ -101,7 +101,7 @@ def handle_request(args: List[str]) -> str:
print(handle_request(sys.argv[1:]))
except Exception as exc:
error = AirbyteTracedException.from_exception(
- exc, message=f"Error handling request: {str(exc)}"
+ exc, message=f"Error handling request: {exc!s}"
)
m = error.as_airbyte_message()
print(orjson.dumps(AirbyteMessageSerializer.dump(m)).decode())
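The `handle_connector_builder_request` rewrite above replaces an if/elif/else chain with early returns, so the trailing `raise` needs no `else`; the f-strings also switch from `{str(exc)}` to the equivalent `{exc!s}` conversion. The control-flow shape in isolation, with placeholder command names:

    def dispatch(command: str) -> str:
        # Early returns let the unknown-command case fall through to the raise.
        if command == "resolve_manifest":
            return "resolved"
        if command == "test_read":
            return "read"
        raise ValueError(f"Unrecognized command {command}.")
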
diff --git a/airbyte_cdk/connector_builder/message_grouper.py b/airbyte_cdk/connector_builder/message_grouper.py
index ce43afab8..e2bf457ae 100644
--- a/airbyte_cdk/connector_builder/message_grouper.py
+++ b/airbyte_cdk/connector_builder/message_grouper.py
@@ -4,9 +4,10 @@
import json
import logging
+from collections.abc import Iterable, Iterator, Mapping
from copy import deepcopy
from json import JSONDecodeError
-from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Union
+from typing import Any
from airbyte_cdk.connector_builder.models import (
AuxiliaryRequest,
@@ -40,14 +41,14 @@
class MessageGrouper:
logger = logging.getLogger("airbyte.connector-builder")
- def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: int = 1000):
+ def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: int = 1000): # noqa: ANN204
self._max_pages_per_slice = max_pages_per_slice
self._max_slices = max_slices
self._max_record_limit = max_record_limit
def _pk_to_nested_and_composite_field(
- self, field: Optional[Union[str, List[str], List[List[str]]]]
- ) -> List[List[str]]:
+ self, field: str | list[str] | list[list[str]] | None
+ ) -> list[list[str]]:
if not field:
return [[]]
@@ -61,8 +62,8 @@ def _pk_to_nested_and_composite_field(
return field # type: ignore # the type of field is expected to be List[List[str]] here
def _cursor_field_to_nested_and_composite_field(
- self, field: Union[str, List[str]]
- ) -> List[List[str]]:
+ self, field: str | list[str]
+ ) -> list[list[str]]:
if not field:
return [[]]
@@ -80,8 +81,8 @@ def get_message_groups(
source: DeclarativeSource,
config: Mapping[str, Any],
configured_catalog: ConfiguredAirbyteCatalog,
- state: List[AirbyteStateMessage],
- record_limit: Optional[int] = None,
+ state: list[AirbyteStateMessage],
+ record_limit: int | None = None,
) -> StreamRead:
if record_limit is not None and not (1 <= record_limit <= self._max_record_limit):
raise ValueError(
@@ -113,20 +114,16 @@ def get_message_groups(
):
if isinstance(message_group, AirbyteLogMessage):
log_messages.append(
- LogMessage(
- **{"message": message_group.message, "level": message_group.level.value}
- )
+ LogMessage(message=message_group.message, level=message_group.level.value)
)
elif isinstance(message_group, AirbyteTraceMessage):
if message_group.type == TraceType.ERROR:
log_messages.append(
LogMessage(
- **{
- "message": message_group.error.message,
- "level": "ERROR",
- "internal_message": message_group.error.internal_message,
- "stacktrace": message_group.error.stack_trace,
- }
+ message=message_group.error.message,
+ level="ERROR",
+ internal_message=message_group.error.internal_message,
+ stacktrace=message_group.error.stack_trace,
)
)
elif isinstance(message_group, AirbyteControlMessage):
@@ -140,7 +137,7 @@ def get_message_groups(
elif isinstance(message_group, StreamReadSlices):
slices.append(message_group)
else:
- raise ValueError(f"Unknown message group type: {type(message_group)}")
+ raise ValueError(f"Unknown message group type: {type(message_group)}") # noqa: TRY004
try:
# The connector builder currently only supports reading from a single stream at a time
@@ -148,7 +145,7 @@ def get_message_groups(
schema = schema_inferrer.get_stream_schema(configured_stream.stream.name)
except SchemaValidationException as exception:
for validation_error in exception.validation_errors:
- log_messages.append(LogMessage(validation_error, "ERROR"))
+ log_messages.append(LogMessage(validation_error, "ERROR")) # noqa: PERF401
schema = exception.schema
return StreamRead(
@@ -163,20 +160,18 @@ def get_message_groups(
inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(),
)
- def _get_message_groups(
+ def _get_message_groups( # noqa: PLR0912, PLR0915
self,
messages: Iterator[AirbyteMessage],
schema_inferrer: SchemaInferrer,
datetime_format_inferrer: DatetimeFormatInferrer,
limit: int,
) -> Iterable[
- Union[
- StreamReadPages,
- AirbyteControlMessage,
- AirbyteLogMessage,
- AirbyteTraceMessage,
- AuxiliaryRequest,
- ]
+ StreamReadPages
+ | AirbyteControlMessage
+ | AirbyteLogMessage
+ | AirbyteTraceMessage
+ | AuxiliaryRequest
]:
"""
Message groups are partitioned according to when request log messages are received. Subsequent response log messages
@@ -195,12 +190,12 @@ def _get_message_groups(
"""
records_count = 0
at_least_one_page_in_group = False
- current_page_records: List[Mapping[str, Any]] = []
- current_slice_descriptor: Optional[Dict[str, Any]] = None
- current_slice_pages: List[StreamReadPages] = []
- current_page_request: Optional[HttpRequest] = None
- current_page_response: Optional[HttpResponse] = None
- latest_state_message: Optional[Dict[str, Any]] = None
+ current_page_records: list[Mapping[str, Any]] = []
+ current_slice_descriptor: dict[str, Any] | None = None
+ current_slice_pages: list[StreamReadPages] = []
+ current_page_request: HttpRequest | None = None
+ current_page_response: HttpResponse | None = None
+ latest_state_message: dict[str, Any] | None = None
while records_count < limit and (message := next(messages, None)):
json_object = self._parse_json(message.log) if message.type == MessageType.LOG else None
@@ -208,7 +203,7 @@ def _get_message_groups(
raise ValueError(
f"Expected log message to be a dict, got {json_object} of type {type(json_object)}"
)
- json_message: Optional[Dict[str, JsonType]] = json_object
+ json_message: dict[str, JsonType] | None = json_object
if self._need_to_close_page(at_least_one_page_in_group, message, json_message):
self._close_page(
current_page_request,
@@ -285,25 +280,24 @@ def _get_message_groups(
yield message.control
elif message.type == MessageType.STATE:
latest_state_message = message.state # type: ignore[assignment]
- else:
- if current_page_request or current_page_response or current_page_records:
- self._close_page(
- current_page_request,
- current_page_response,
- current_slice_pages,
- current_page_records,
- )
- yield StreamReadSlices(
- pages=current_slice_pages,
- slice_descriptor=current_slice_descriptor,
- state=[latest_state_message] if latest_state_message else [],
- )
+ if current_page_request or current_page_response or current_page_records:
+ self._close_page(
+ current_page_request,
+ current_page_response,
+ current_slice_pages,
+ current_page_records,
+ )
+ yield StreamReadSlices(
+ pages=current_slice_pages,
+ slice_descriptor=current_slice_descriptor,
+ state=[latest_state_message] if latest_state_message else [],
+ )
@staticmethod
def _need_to_close_page(
- at_least_one_page_in_group: bool,
+ at_least_one_page_in_group: bool, # noqa: FBT001
message: AirbyteMessage,
- json_message: Optional[Dict[str, Any]],
+ json_message: dict[str, Any] | None,
) -> bool:
return (
at_least_one_page_in_group
@@ -315,20 +309,19 @@ def _need_to_close_page(
)
@staticmethod
- def _is_page_http_request(json_message: Optional[Dict[str, Any]]) -> bool:
+ def _is_page_http_request(json_message: dict[str, Any] | None) -> bool:
if not json_message:
return False
- else:
- return MessageGrouper._is_http_log(
- json_message
- ) and not MessageGrouper._is_auxiliary_http_request(json_message)
+ return MessageGrouper._is_http_log(
+ json_message
+ ) and not MessageGrouper._is_auxiliary_http_request(json_message)
@staticmethod
- def _is_http_log(message: Dict[str, JsonType]) -> bool:
- return bool(message.get("http", False))
+ def _is_http_log(message: dict[str, JsonType]) -> bool:
+ return bool(message.get("http"))
@staticmethod
- def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool:
+ def _is_auxiliary_http_request(message: dict[str, Any] | None) -> bool:
"""
An auxiliary request is a request that is performed but does not directly lead to records for the specific stream being queried.
A couple of examples are:
@@ -343,10 +336,10 @@ def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool:
@staticmethod
def _close_page(
- current_page_request: Optional[HttpRequest],
- current_page_response: Optional[HttpResponse],
- current_slice_pages: List[StreamReadPages],
- current_page_records: List[Mapping[str, Any]],
+ current_page_request: HttpRequest | None,
+ current_page_response: HttpResponse | None,
+ current_slice_pages: list[StreamReadPages],
+ current_page_records: list[Mapping[str, Any]],
) -> None:
"""
Close a page when parsing message groups
@@ -365,7 +358,7 @@ def _read_stream(
source: DeclarativeSource,
config: Mapping[str, Any],
configured_catalog: ConfiguredAirbyteCatalog,
- state: List[AirbyteStateMessage],
+ state: list[AirbyteStateMessage],
) -> Iterator[AirbyteMessage]:
# the generator can raise an exception
# iterate over the generated messages. if next raise an exception, catch it and yield it as an AirbyteLogMessage
@@ -398,12 +391,12 @@ def _parse_json(log_message: AirbyteLogMessage) -> JsonType:
# protocol change is worked on.
try:
json_object: JsonType = json.loads(log_message.message)
- return json_object
+ return json_object # noqa: TRY300
except JSONDecodeError:
return None
@staticmethod
- def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpRequest:
+ def _create_request_from_log_message(json_http_message: dict[str, Any]) -> HttpRequest:
url = json_http_message.get("url", {}).get("full", "")
request = json_http_message.get("http", {}).get("request", {})
return HttpRequest(
@@ -414,14 +407,14 @@ def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpR
)
@staticmethod
- def _create_response_from_log_message(json_http_message: Dict[str, Any]) -> HttpResponse:
+ def _create_response_from_log_message(json_http_message: dict[str, Any]) -> HttpResponse:
response = json_http_message.get("http", {}).get("response", {})
body = response.get("body", {}).get("content", "")
return HttpResponse(
status=response.get("status_code"), body=body, headers=response.get("headers")
)
- def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool:
+ def _has_reached_limit(self, slices: list[StreamReadSlices]) -> bool:
if len(slices) >= self._max_slices:
return True
@@ -436,13 +429,13 @@ def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool:
return True
return False
- def _parse_slice_description(self, log_message: str) -> Dict[str, Any]:
+ def _parse_slice_description(self, log_message: str) -> dict[str, Any]:
return json.loads(log_message.replace(SliceLogger.SLICE_LOG_PREFIX, "", 1)) # type: ignore
@staticmethod
- def _clean_config(config: Dict[str, Any]) -> Dict[str, Any]:
+ def _clean_config(config: dict[str, Any]) -> dict[str, Any]:
cleaned_config = deepcopy(config)
- for key in config.keys():
+ for key in config:
if key.startswith("__"):
del cleaned_config[key]
return cleaned_config
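One readability change above replaces `LogMessage(**{...})` dict unpacking with plain keyword arguments; both spellings build the same object. A tiny dataclass check:

    from dataclasses import dataclass


    @dataclass
    class Log:
        message: str
        level: str


    a = Log(**{"message": "boom", "level": "ERROR"})  # old style: unpack a dict literal
    b = Log(message="boom", level="ERROR")            # new style: direct keyword arguments
    assert a == b
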
diff --git a/airbyte_cdk/connector_builder/models.py b/airbyte_cdk/connector_builder/models.py
index 50eb8eb95..f4bf28cdd 100644
--- a/airbyte_cdk/connector_builder/models.py
+++ b/airbyte_cdk/connector_builder/models.py
@@ -3,44 +3,44 @@
#
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any
@dataclass
class HttpResponse:
status: int
- body: Optional[str] = None
- headers: Optional[Dict[str, Any]] = None
+ body: str | None = None
+ headers: dict[str, Any] | None = None
@dataclass
class HttpRequest:
url: str
- headers: Optional[Dict[str, Any]]
+ headers: dict[str, Any] | None
http_method: str
- body: Optional[str] = None
+ body: str | None = None
@dataclass
class StreamReadPages:
- records: List[object]
- request: Optional[HttpRequest] = None
- response: Optional[HttpResponse] = None
+ records: list[object]
+ request: HttpRequest | None = None
+ response: HttpResponse | None = None
@dataclass
class StreamReadSlices:
- pages: List[StreamReadPages]
- slice_descriptor: Optional[Dict[str, Any]]
- state: Optional[List[Dict[str, Any]]] = None
+ pages: list[StreamReadPages]
+ slice_descriptor: dict[str, Any] | None
+ state: list[dict[str, Any]] | None = None
@dataclass
class LogMessage:
message: str
level: str
- internal_message: Optional[str] = None
- stacktrace: Optional[str] = None
+ internal_message: str | None = None
+ stacktrace: str | None = None
@dataclass
@@ -52,20 +52,20 @@ class AuxiliaryRequest:
@dataclass
-class StreamRead(object):
- logs: List[LogMessage]
- slices: List[StreamReadSlices]
+class StreamRead:
+ logs: list[LogMessage]
+ slices: list[StreamReadSlices]
test_read_limit_reached: bool
- auxiliary_requests: List[AuxiliaryRequest]
- inferred_schema: Optional[Dict[str, Any]]
- inferred_datetime_formats: Optional[Dict[str, str]]
- latest_config_update: Optional[Dict[str, Any]]
+ auxiliary_requests: list[AuxiliaryRequest]
+ inferred_schema: dict[str, Any] | None
+ inferred_datetime_formats: dict[str, str] | None
+ latest_config_update: dict[str, Any] | None
@dataclass
class StreamReadRequestBody:
- manifest: Dict[str, Any]
+ manifest: dict[str, Any]
stream: str
- config: Dict[str, Any]
- state: Optional[Dict[str, Any]]
- record_limit: Optional[int]
+ config: dict[str, Any]
+ state: dict[str, Any] | None
+ record_limit: int | None
diff --git a/airbyte_cdk/destinations/__init__.py b/airbyte_cdk/destinations/__init__.py
index 3a641025b..6b8bbe8f0 100644
--- a/airbyte_cdk/destinations/__init__.py
+++ b/airbyte_cdk/destinations/__init__.py
@@ -3,6 +3,7 @@
from .destination import Destination
+
__all__ = [
"Destination",
]
diff --git a/airbyte_cdk/destinations/destination.py b/airbyte_cdk/destinations/destination.py
index 547f96684..10113b3a0 100644
--- a/airbyte_cdk/destinations/destination.py
+++ b/airbyte_cdk/destinations/destination.py
@@ -7,7 +7,8 @@
import logging
import sys
from abc import ABC, abstractmethod
-from typing import Any, Iterable, List, Mapping
+from collections.abc import Iterable, Mapping
+from typing import Any
import orjson
@@ -23,11 +24,12 @@
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
logger = logging.getLogger("airbyte")
class Destination(Connector, ABC):
- VALID_CMDS = {"spec", "check", "write"}
+ VALID_CMDS = {"spec", "check", "write"} # noqa: RUF012
@abstractmethod
def write(
@@ -59,7 +61,7 @@ def _run_write(
input_stream: io.TextIOWrapper,
) -> Iterable[AirbyteMessage]:
catalog = ConfiguredAirbyteCatalogSerializer.load(
- orjson.loads(open(configured_catalog_path).read())
+ orjson.loads(open(configured_catalog_path).read()) # noqa: PLW1514, PTH123, SIM115
)
input_messages = self._parse_input_stream(input_stream)
logger.info("Begin writing to the destination...")
@@ -68,7 +70,7 @@ def _run_write(
)
logger.info("Writing complete.")
- def parse_args(self, args: List[str]) -> argparse.Namespace:
+ def parse_args(self, args: list[str]) -> argparse.Namespace:
"""
:param args: commandline arguments
:return:
@@ -107,18 +109,18 @@ def parse_args(self, args: List[str]) -> argparse.Namespace:
parsed_args = main_parser.parse_args(args)
cmd = parsed_args.command
if not cmd:
- raise Exception("No command entered. ")
- elif cmd not in ["spec", "check", "write"]:
+ raise Exception("No command entered. ") # noqa: TRY002
+ if cmd not in ["spec", "check", "write"]:
# This is technically dead code since parse_args() would fail if this was the case
# But it's non-obvious enough to warrant placing it here anyways
- raise Exception(f"Unknown command entered: {cmd}")
+ raise Exception(f"Unknown command entered: {cmd}") # noqa: TRY002
return parsed_args
def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]:
cmd = parsed_args.command
if cmd not in self.VALID_CMDS:
- raise Exception(f"Unrecognized command: {cmd}")
+ raise Exception(f"Unrecognized command: {cmd}") # noqa: TRY002
spec = self.spec(logger)
if cmd == "spec":
@@ -133,7 +135,7 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]:
if connection_status and cmd == "check":
yield connection_status
return
- raise traced_exc
+ raise traced_exc # noqa: TRY201
if cmd == "check":
yield self._run_check(config=config)
@@ -146,7 +148,7 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]:
input_stream=wrapped_stdin,
)
- def run(self, args: List[str]) -> None:
+ def run(self, args: list[str]) -> None:
init_uncaught_exception_handler(logger)
parsed_args = self.parse_args(args)
output_messages = self.run_cmd(parsed_args)
diff --git a/airbyte_cdk/destinations/vector_db_based/__init__.py b/airbyte_cdk/destinations/vector_db_based/__init__.py
index 86ae207f6..696bd2bf5 100644
--- a/airbyte_cdk/destinations/vector_db_based/__init__.py
+++ b/airbyte_cdk/destinations/vector_db_based/__init__.py
@@ -16,8 +16,8 @@
from .indexer import Indexer
from .writer import Writer
+
__all__ = [
- "AzureOpenAIEmbedder",
"AzureOpenAIEmbeddingConfigModel",
"Chunk",
"CohereEmbedder",
@@ -26,10 +26,8 @@
"Embedder",
"FakeEmbedder",
"FakeEmbeddingConfigModel",
- "FromFieldEmbedder",
"FromFieldEmbeddingConfigModel",
"Indexer",
- "OpenAICompatibleEmbedder",
"OpenAICompatibleEmbeddingConfigModel",
"OpenAIEmbedder",
"OpenAIEmbeddingConfigModel",
diff --git a/airbyte_cdk/destinations/vector_db_based/config.py b/airbyte_cdk/destinations/vector_db_based/config.py
index c7c40ecaa..e2e412139 100644
--- a/airbyte_cdk/destinations/vector_db_based/config.py
+++ b/airbyte_cdk/destinations/vector_db_based/config.py
@@ -2,7 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Literal, Union
import dpath
from pydantic.v1 import BaseModel, Field
@@ -13,7 +13,7 @@
class SeparatorSplitterConfigModel(BaseModel):
mode: Literal["separator"] = Field("separator", const=True)
- separators: List[str] = Field(
+ separators: list[str] = Field(
default=['"\\n\\n"', '"\\n"', '" "', '""'],
title="Separators",
description='List of separator strings to split text fields by. The separator itself needs to be wrapped in double quotes, e.g. to split by the dot character, use ".". To split by a newline, use "\\n".',
@@ -77,7 +77,7 @@ class Config(OneOfOptionConfig):
discriminator = "mode"
-TextSplitterConfigModel = Union[
+TextSplitterConfigModel = Union[ # noqa: UP007
SeparatorSplitterConfigModel, MarkdownHeaderSplitterConfigModel, CodeSplitterConfigModel
]
@@ -102,14 +102,14 @@ class ProcessingConfigModel(BaseModel):
description="Size of overlap between chunks in tokens to store in vector store to better capture relevant context",
default=0,
)
- text_fields: Optional[List[str]] = Field(
+ text_fields: list[str] | None = Field(
default=[],
title="Text fields to embed",
description="List of fields in the record that should be used to calculate the embedding. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered text fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `names` fields in all entries of the `users` array.",
always_show=True,
examples=["text", "user.name", "users.*.name"],
)
- metadata_fields: Optional[List[str]] = Field(
+ metadata_fields: list[str] | None = Field(
default=[],
title="Fields to store as metadata",
description="List of fields in the record that should be stored as metadata. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered metadata fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `names` fields in all entries of the `users` array. When specifying nested paths, all matching values are flattened into an array set to a field named by the path.",
@@ -123,14 +123,14 @@ class ProcessingConfigModel(BaseModel):
type="object",
description="Split text fields into chunks based on the specified method.",
)
- field_name_mappings: Optional[List[FieldNameMappingConfigModel]] = Field(
+ field_name_mappings: list[FieldNameMappingConfigModel] | None = Field(
default=[],
title="Field name mappings",
description="List of fields to rename. Not applicable for nested fields, but can be used to rename fields already flattened via dot notation.",
)
class Config:
- schema_extra = {"group": "processing"}
+ schema_extra = {"group": "processing"} # noqa: RUF012
class OpenAIEmbeddingConfigModel(BaseModel):
@@ -251,13 +251,13 @@ class VectorDBConfigModel(BaseModel):
Processing, embedding and advanced configuration are provided by this base class, while the indexing configuration is provided by the destination connector in the sub class.
"""
- embedding: Union[
- OpenAIEmbeddingConfigModel,
- CohereEmbeddingConfigModel,
- FakeEmbeddingConfigModel,
- AzureOpenAIEmbeddingConfigModel,
- OpenAICompatibleEmbeddingConfigModel,
- ] = Field(
+ embedding: (
+ OpenAIEmbeddingConfigModel
+ | CohereEmbeddingConfigModel
+ | FakeEmbeddingConfigModel
+ | AzureOpenAIEmbeddingConfigModel
+ | OpenAICompatibleEmbeddingConfigModel
+ ) = Field(
...,
title="Embedding",
description="Embedding configuration",
@@ -275,7 +275,7 @@ class VectorDBConfigModel(BaseModel):
class Config:
title = "Destination Config"
- schema_extra = {
+ schema_extra = { # noqa: RUF012
"groups": [
{"id": "processing", "title": "Processing"},
{"id": "embedding", "title": "Embedding"},
@@ -285,14 +285,14 @@ class Config:
}
@staticmethod
- def remove_discriminator(schema: Dict[str, Any]) -> None:
+ def remove_discriminator(schema: dict[str, Any]) -> None:
"""pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references"""
dpath.delete(schema, "properties/**/discriminator")
@classmethod
- def schema(cls, by_alias: bool = True, ref_template: str = "") -> Dict[str, Any]:
+ def schema(cls, by_alias: bool = True, ref_template: str = "") -> dict[str, Any]: # noqa: FBT001, FBT002, ARG003
"""we're overriding the schema classmethod to enable some post-processing"""
- schema: Dict[str, Any] = super().schema()
+ schema: dict[str, Any] = super().schema()
schema = resolve_refs(schema)
cls.remove_discriminator(schema)
return schema
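The `embedding` field above trades `Union[...]` for the PEP 604 `A | B` spelling, which on Python 3.10+ evaluates to an equivalent union object at runtime. A quick equivalence check with throwaway classes:

    from typing import Union


    class A: ...
    class B: ...


    assert (A | B) == Union[A, B]  # PEP 604 unions compare equal to typing.Union (Python 3.10+)
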
diff --git a/airbyte_cdk/destinations/vector_db_based/document_processor.py b/airbyte_cdk/destinations/vector_db_based/document_processor.py
index c007bf9e2..fdd8a2244 100644
--- a/airbyte_cdk/destinations/vector_db_based/document_processor.py
+++ b/airbyte_cdk/destinations/vector_db_based/document_processor.py
@@ -4,8 +4,9 @@
import json
import logging
+from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from typing import Any
import dpath
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
@@ -26,6 +27,7 @@
)
from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType
+
METADATA_STREAM_FIELD = "_ab_stream"
METADATA_RECORD_ID_FIELD = "_ab_record_id"
@@ -34,10 +36,10 @@
@dataclass
class Chunk:
- page_content: Optional[str]
- metadata: Dict[str, Any]
+ page_content: str | None
+ metadata: dict[str, Any]
record: AirbyteRecordMessage
- embedding: Optional[List[float]] = None
+ embedding: list[float] | None = None
headers_to_split_on = [
@@ -69,7 +71,7 @@ class DocumentProcessor:
streams: Mapping[str, ConfiguredAirbyteStream]
@staticmethod
- def check_config(config: ProcessingConfigModel) -> Optional[str]:
+ def check_config(config: ProcessingConfigModel) -> str | None:
if config.text_splitter is not None and config.text_splitter.mode == "separator":
for s in config.text_splitter.separators:
try:
@@ -80,11 +82,11 @@ def check_config(config: ProcessingConfigModel) -> Optional[str]:
return f"Invalid separator: {s}. Separator needs to be a valid JSON string using double quotes."
return None
- def _get_text_splitter(
+ def _get_text_splitter( # noqa: RET503
self,
chunk_size: int,
chunk_overlap: int,
- splitter_config: Optional[TextSplitterConfigModel],
+ splitter_config: TextSplitterConfigModel | None,
) -> RecursiveCharacterTextSplitter:
if splitter_config is None:
splitter_config = SeparatorSplitterConfigModel(mode="separator")
@@ -105,7 +107,7 @@ def _get_text_splitter(
keep_separator=True,
disallowed_special=(),
)
- if splitter_config.mode == "code":
+ if splitter_config.mode == "code": # noqa: RET503, RUF100
return RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
@@ -115,7 +117,7 @@ def _get_text_splitter(
disallowed_special=(),
)
- def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog):
+ def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog): # noqa: ANN204
self.streams = {
create_stream_identifier(stream.stream): stream for stream in catalog.streams
}
@@ -128,13 +130,13 @@ def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCata
self.field_name_mappings = config.field_name_mappings
self.logger = logging.getLogger("airbyte.document_processor")
- def process(self, record: AirbyteRecordMessage) -> Tuple[List[Chunk], Optional[str]]:
+ def process(self, record: AirbyteRecordMessage) -> tuple[list[Chunk], str | None]:
"""
Generate documents from records.
:param records: List of AirbyteRecordMessages
:return: Tuple of (List of document chunks, record id to delete if a stream is in dedup mode to avoid stale documents in the vector store)
"""
- if CDC_DELETED_FIELD in record.data and record.data[CDC_DELETED_FIELD]:
+ if record.data.get(CDC_DELETED_FIELD):
return [], self._extract_primary_key(record)
doc = self._generate_document(record)
if doc is None:
@@ -153,13 +155,13 @@ def process(self, record: AirbyteRecordMessage) -> Tuple[List[Chunk], Optional[s
for chunk_document in self._split_document(doc)
]
id_to_delete = (
- doc.metadata[METADATA_RECORD_ID_FIELD]
+ doc.metadata[METADATA_RECORD_ID_FIELD] # noqa: SIM401
if METADATA_RECORD_ID_FIELD in doc.metadata
else None
)
return chunks, id_to_delete
- def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document]:
+ def _generate_document(self, record: AirbyteRecordMessage) -> Document | None:
relevant_fields = self._extract_relevant_fields(record, self.text_fields)
if len(relevant_fields) == 0:
return None
@@ -168,8 +170,8 @@ def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document]
return Document(page_content=text, metadata=metadata)
def _extract_relevant_fields(
- self, record: AirbyteRecordMessage, fields: Optional[List[str]]
- ) -> Dict[str, Any]:
+ self, record: AirbyteRecordMessage, fields: list[str] | None
+ ) -> dict[str, Any]:
relevant_fields = {}
if fields and len(fields) > 0:
for field in fields:
@@ -180,7 +182,7 @@ def _extract_relevant_fields(
relevant_fields = record.data
return self._remap_field_names(relevant_fields)
- def _extract_metadata(self, record: AirbyteRecordMessage) -> Dict[str, Any]:
+ def _extract_metadata(self, record: AirbyteRecordMessage) -> dict[str, Any]:
metadata = self._extract_relevant_fields(record, self.metadata_fields)
metadata[METADATA_STREAM_FIELD] = create_stream_identifier(record)
primary_key = self._extract_primary_key(record)
@@ -188,7 +190,7 @@ def _extract_metadata(self, record: AirbyteRecordMessage) -> Dict[str, Any]:
metadata[METADATA_RECORD_ID_FIELD] = primary_key
return metadata
- def _extract_primary_key(self, record: AirbyteRecordMessage) -> Optional[str]:
+ def _extract_primary_key(self, record: AirbyteRecordMessage) -> str | None:
stream_identifier = create_stream_identifier(record)
current_stream: ConfiguredAirbyteStream = self.streams[stream_identifier]
# if the sync mode is deduping, use the primary key to upsert existing records instead of appending new ones
@@ -207,11 +209,11 @@ def _extract_primary_key(self, record: AirbyteRecordMessage) -> Optional[str]:
stringified_primary_key = "_".join(primary_key)
return f"{stream_identifier}_{stringified_primary_key}"
- def _split_document(self, doc: Document) -> List[Document]:
- chunks: List[Document] = self.splitter.split_documents([doc])
+ def _split_document(self, doc: Document) -> list[Document]:
+ chunks: list[Document] = self.splitter.split_documents([doc])
return chunks
- def _remap_field_names(self, fields: Dict[str, Any]) -> Dict[str, Any]:
+ def _remap_field_names(self, fields: dict[str, Any]) -> dict[str, Any]:
if not self.field_name_mappings:
return fields
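The `process()` change above collapses `FIELD in data and data[FIELD]` into `data.get(FIELD)`, which is falsy both when the key is missing and when it is present but empty. A minimal check (the constant's value is assumed here):

    CDC_DELETED_FIELD = "_ab_cdc_deleted_at"  # assumed value of the CDK constant

    record = {"id": 1}
    assert not record.get(CDC_DELETED_FIELD)            # key missing -> None -> falsy
    record[CDC_DELETED_FIELD] = None
    assert not record.get(CDC_DELETED_FIELD)            # present but empty -> still falsy
    record[CDC_DELETED_FIELD] = "2024-01-01T00:00:00Z"
    assert record.get(CDC_DELETED_FIELD)                # deletion marker set -> truthy
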
diff --git a/airbyte_cdk/destinations/vector_db_based/embedder.py b/airbyte_cdk/destinations/vector_db_based/embedder.py
index 6889c8e16..5d491d7d5 100644
--- a/airbyte_cdk/destinations/vector_db_based/embedder.py
+++ b/airbyte_cdk/destinations/vector_db_based/embedder.py
@@ -5,7 +5,7 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import List, Optional, Union, cast
+from typing import cast
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.fake import FakeEmbeddings
@@ -41,15 +41,15 @@ class Embedder(ABC):
The CDK defines basic embedders that should be supported in each destination. It is possible to implement custom embedders for special destinations if needed.
"""
- def __init__(self) -> None:
+ def __init__(self) -> None: # noqa: B027
pass
@abstractmethod
- def check(self) -> Optional[str]:
+ def check(self) -> str | None:
pass
@abstractmethod
- def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
+ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]:
"""
Embed the text of each chunk and return the resulting embedding vectors.
If a chunk cannot be embedded or is configured to not be embedded, return None for that chunk.
@@ -68,19 +68,19 @@ def embedding_dimensions(self) -> int:
class BaseOpenAIEmbedder(Embedder):
- def __init__(self, embeddings: OpenAIEmbeddings, chunk_size: int):
+ def __init__(self, embeddings: OpenAIEmbeddings, chunk_size: int): # noqa: ANN204
super().__init__()
self.embeddings = embeddings
self.chunk_size = chunk_size
- def check(self) -> Optional[str]:
+ def check(self) -> str | None:
try:
self.embeddings.embed_query("test")
except Exception as e:
return format_exception(e)
return None
- def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
+ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]:
"""
Embed the text of each chunk and return the resulting embedding vectors.
@@ -91,7 +91,7 @@ def embed_documents(self, documents: List[Document]) -> List[Optional[List[float
# Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of documents that can be embedded at once without exhausting the limit in a single request
embedding_batch_size = OPEN_AI_TOKEN_LIMIT // self.chunk_size
batches = create_chunks(documents, batch_size=embedding_batch_size)
- embeddings: List[Optional[List[float]]] = []
+ embeddings: list[list[float] | None] = []
for batch in batches:
embeddings.extend(
self.embeddings.embed_documents([chunk.page_content for chunk in batch])
@@ -105,7 +105,7 @@ def embedding_dimensions(self) -> int:
class OpenAIEmbedder(BaseOpenAIEmbedder):
- def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int):
+ def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int): # noqa: ANN204
super().__init__(
OpenAIEmbeddings( # type: ignore [call-arg]
openai_api_key=config.openai_key, max_retries=15, disallowed_special=()
@@ -115,7 +115,7 @@ def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int):
class AzureOpenAIEmbedder(BaseOpenAIEmbedder):
- def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int):
+ def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int): # noqa: ANN204
# Azure OpenAI API has — as of 20230927 — a limit of 16 documents per request
super().__init__(
OpenAIEmbeddings( # type: ignore [call-arg]
@@ -136,23 +136,23 @@ def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int):
class CohereEmbedder(Embedder):
- def __init__(self, config: CohereEmbeddingConfigModel):
+ def __init__(self, config: CohereEmbeddingConfigModel): # noqa: ANN204
super().__init__()
# Client is set internally
self.embeddings = CohereEmbeddings(
cohere_api_key=config.cohere_key, model="embed-english-light-v2.0"
) # type: ignore
- def check(self) -> Optional[str]:
+ def check(self) -> str | None:
try:
self.embeddings.embed_query("test")
except Exception as e:
return format_exception(e)
return None
- def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
+ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]:
return cast(
- List[Optional[List[float]]],
+ list[list[float] | None], # noqa: TC006
self.embeddings.embed_documents([document.page_content for document in documents]),
)
@@ -163,20 +163,20 @@ def embedding_dimensions(self) -> int:
class FakeEmbedder(Embedder):
- def __init__(self, config: FakeEmbeddingConfigModel):
+ def __init__(self, config: FakeEmbeddingConfigModel): # noqa: ANN204, ARG002
super().__init__()
self.embeddings = FakeEmbeddings(size=OPEN_AI_VECTOR_SIZE)
- def check(self) -> Optional[str]:
+ def check(self) -> str | None:
try:
self.embeddings.embed_query("test")
except Exception as e:
return format_exception(e)
return None
- def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
+ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]:
return cast(
- List[Optional[List[float]]],
+ list[list[float] | None], # noqa: TC006
self.embeddings.embed_documents([document.page_content for document in documents]),
)
@@ -190,7 +190,7 @@ def embedding_dimensions(self) -> int:
class OpenAICompatibleEmbedder(Embedder):
- def __init__(self, config: OpenAICompatibleEmbeddingConfigModel):
+ def __init__(self, config: OpenAICompatibleEmbeddingConfigModel): # noqa: ANN204
super().__init__()
self.config = config
# Client is set internally
@@ -203,7 +203,7 @@ def __init__(self, config: OpenAICompatibleEmbeddingConfigModel):
disallowed_special=(),
) # type: ignore
- def check(self) -> Optional[str]:
+ def check(self) -> str | None:
deployment_mode = os.environ.get("DEPLOYMENT_MODE", "")
if (
deployment_mode.casefold() == CLOUD_DEPLOYMENT_MODE
@@ -217,9 +217,9 @@ def check(self) -> Optional[str]:
return format_exception(e)
return None
- def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
+ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]:
return cast(
- List[Optional[List[float]]],
+ list[list[float] | None], # noqa: TC006
self.embeddings.embed_documents([document.page_content for document in documents]),
)
@@ -230,19 +230,19 @@ def embedding_dimensions(self) -> int:
class FromFieldEmbedder(Embedder):
- def __init__(self, config: FromFieldEmbeddingConfigModel):
+ def __init__(self, config: FromFieldEmbeddingConfigModel): # noqa: ANN204
super().__init__()
self.config = config
- def check(self) -> Optional[str]:
+ def check(self) -> str | None:
return None
- def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
+ def embed_documents(self, documents: list[Document]) -> list[list[float] | None]:
"""
From each chunk, pull the embedding from the field specified in the config.
Check that the field exists, is a list of numbers and is the correct size. If not, raise an AirbyteTracedException explaining the problem.
"""
- embeddings: List[Optional[List[float]]] = []
+ embeddings: list[list[float] | None] = []
for document in documents:
data = document.record.data
if self.config.field_name not in data:
@@ -252,7 +252,7 @@ def embed_documents(self, documents: List[Document]) -> List[Optional[List[float
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
)
field = data[self.config.field_name]
- if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field):
+ if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field): # noqa: UP038
raise AirbyteTracedException(
internal_message="Embedding vector field not a list of numbers",
failure_type=FailureType.config_error,
@@ -284,20 +284,17 @@ def embedding_dimensions(self) -> int:
def create_from_config(
- embedding_config: Union[
- AzureOpenAIEmbeddingConfigModel,
- CohereEmbeddingConfigModel,
- FakeEmbeddingConfigModel,
- FromFieldEmbeddingConfigModel,
- OpenAIEmbeddingConfigModel,
- OpenAICompatibleEmbeddingConfigModel,
- ],
+ embedding_config: AzureOpenAIEmbeddingConfigModel
+ | CohereEmbeddingConfigModel
+ | FakeEmbeddingConfigModel
+ | FromFieldEmbeddingConfigModel
+ | OpenAIEmbeddingConfigModel
+ | OpenAICompatibleEmbeddingConfigModel,
processing_config: ProcessingConfigModel,
) -> Embedder:
- if embedding_config.mode == "azure_openai" or embedding_config.mode == "openai":
+ if embedding_config.mode == "azure_openai" or embedding_config.mode == "openai": # noqa: PLR1714
return cast(
- Embedder,
+ Embedder, # noqa: TC006
embedder_map[embedding_config.mode](embedding_config, processing_config.chunk_size),
)
- else:
- return cast(Embedder, embedder_map[embedding_config.mode](embedding_config))
+ return cast(Embedder, embedder_map[embedding_config.mode](embedding_config)) # noqa: TC006
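`create_from_config` above dispatches on `embedding_config.mode` through `embedder_map` and drops the redundant `else` after the first `return`. The same mode-keyed dispatch reduced to a toy registry (names here are placeholders, not the CDK's):

    class FakeEmbedder:
        def __init__(self, config: object) -> None:
            self.config = config


    embedder_registry = {"fake": FakeEmbedder}  # placeholder map keyed by mode


    def create_from_mode(mode: str, config: object) -> FakeEmbedder:
        if mode not in embedder_registry:
            raise ValueError(f"Unknown embedding mode: {mode}")
        return embedder_registry[mode](config)


    assert isinstance(create_from_mode("fake", object()), FakeEmbedder)
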
diff --git a/airbyte_cdk/destinations/vector_db_based/indexer.py b/airbyte_cdk/destinations/vector_db_based/indexer.py
index c49f576a6..74804b75c 100644
--- a/airbyte_cdk/destinations/vector_db_based/indexer.py
+++ b/airbyte_cdk/destinations/vector_db_based/indexer.py
@@ -4,7 +4,8 @@
import itertools
from abc import ABC, abstractmethod
-from typing import Any, Generator, Iterable, List, Optional, Tuple, TypeVar
+from collections.abc import Generator, Iterable
+from typing import Any, TypeVar
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog
@@ -18,11 +19,11 @@ class Indexer(ABC):
In a destination connector, implement a custom indexer by extending this class and implementing the abstract methods.
"""
- def __init__(self, config: Any):
+ def __init__(self, config: Any): # noqa: ANN204, ANN401
self.config = config
pass
- def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None:
+ def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None: # noqa: B027
"""
Run before the sync starts. This method should be used to make sure all records in the destination that belong to streams with a destination mode of overwrite are deleted.
@@ -31,14 +32,14 @@ def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None:
"""
pass
- def post_sync(self) -> List[AirbyteMessage]:
+ def post_sync(self) -> list[AirbyteMessage]:
"""
Run after the sync finishes. This method should be used to perform any cleanup operations and can return a list of AirbyteMessages to be logged.
"""
return []
@abstractmethod
- def index(self, document_chunks: List[Chunk], namespace: str, stream: str) -> None:
+ def index(self, document_chunks: list[Chunk], namespace: str, stream: str) -> None:
"""
Index a list of document chunks.
@@ -48,7 +49,7 @@ def index(self, document_chunks: List[Chunk], namespace: str, stream: str) -> No
pass
@abstractmethod
- def delete(self, delete_ids: List[str], namespace: str, stream: str) -> None:
+ def delete(self, delete_ids: list[str], namespace: str, stream: str) -> None:
"""
Delete document chunks belonging to certain record ids.
@@ -59,7 +60,7 @@ def delete(self, delete_ids: List[str], namespace: str, stream: str) -> None:
pass
@abstractmethod
- def check(self) -> Optional[str]:
+ def check(self) -> str | None:
"""
Check if the indexer is configured correctly. This method should be used to check if the indexer is configured correctly and return an error message if it is not.
"""
@@ -69,7 +70,7 @@ def check(self) -> Optional[str]:
T = TypeVar("T")
-def chunks(iterable: Iterable[T], batch_size: int) -> Generator[Tuple[T, ...], None, None]:
+def chunks(iterable: Iterable[T], batch_size: int) -> Generator[tuple[T, ...], None, None]:
"""A helper function to break an iterable into chunks of size batch_size."""
it = iter(iterable)
chunk = tuple(itertools.islice(it, batch_size))
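`chunks()` above batches any iterable with `itertools.islice`, so the final batch is simply shorter when the input does not divide evenly. A compact re-implementation of the same idea plus a usage check:

    import itertools
    from collections.abc import Generator, Iterable
    from typing import TypeVar

    T = TypeVar("T")


    def chunks(iterable: Iterable[T], batch_size: int) -> Generator[tuple[T, ...], None, None]:
        it = iter(iterable)
        while batch := tuple(itertools.islice(it, batch_size)):
            yield batch


    assert list(chunks(range(7), 3)) == [(0, 1, 2), (3, 4, 5), (6,)]
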
diff --git a/airbyte_cdk/destinations/vector_db_based/test_utils.py b/airbyte_cdk/destinations/vector_db_based/test_utils.py
index a2f3d3d83..8f29b386a 100644
--- a/airbyte_cdk/destinations/vector_db_based/test_utils.py
+++ b/airbyte_cdk/destinations/vector_db_based/test_utils.py
@@ -4,7 +4,7 @@
import json
import unittest
-from typing import Any, Dict
+from typing import Any
from airbyte_cdk.models import (
AirbyteMessage,
@@ -47,7 +47,7 @@ def _get_configured_catalog(
return ConfiguredAirbyteCatalog(streams=[overwrite_stream])
- def _state(self, data: Dict[str, Any]) -> AirbyteMessage:
+ def _state(self, data: dict[str, Any]) -> AirbyteMessage:
return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=data))
def _record(self, stream: str, str_value: str, int_value: int) -> AirbyteMessage:
@@ -59,5 +59,5 @@ def _record(self, stream: str, str_value: str, int_value: int) -> AirbyteMessage
)
def setUp(self) -> None:
- with open("secrets/config.json", "r") as f:
+ with open("secrets/config.json") as f: # noqa: FURB101, PLW1514, PTH123
self.config = json.loads(f.read())
diff --git a/airbyte_cdk/destinations/vector_db_based/utils.py b/airbyte_cdk/destinations/vector_db_based/utils.py
index dbb1f4714..0c77dc2df 100644
--- a/airbyte_cdk/destinations/vector_db_based/utils.py
+++ b/airbyte_cdk/destinations/vector_db_based/utils.py
@@ -4,7 +4,8 @@
import itertools
import traceback
-from typing import Any, Iterable, Iterator, Tuple, Union
+from collections.abc import Iterable, Iterator
+from typing import Any
from airbyte_cdk.models import AirbyteRecordMessage, AirbyteStream
@@ -17,7 +18,7 @@ def format_exception(exception: Exception) -> str:
)
-def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[Tuple[Any, ...]]:
+def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[tuple[Any, ...]]:
"""A helper function to break an iterable into chunks of size batch_size."""
it = iter(iterable)
chunk = tuple(itertools.islice(it, batch_size))
@@ -26,10 +27,7 @@ def create_chunks(iterable: Iterable[Any], batch_size: int) -> Iterator[Tuple[An
chunk = tuple(itertools.islice(it, batch_size))
-def create_stream_identifier(stream: Union[AirbyteStream, AirbyteRecordMessage]) -> str:
+def create_stream_identifier(stream: AirbyteStream | AirbyteRecordMessage) -> str:
if isinstance(stream, AirbyteStream):
return str(stream.name if stream.namespace is None else f"{stream.namespace}_{stream.name}")
- else:
- return str(
- stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}"
- )
+ return str(stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}")
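The rewritten `create_stream_identifier` boils down to one rule: prefix the stream name with the namespace when a namespace is set. A hedged sketch of that rule using plain strings instead of the protocol dataclasses:

def stream_identifier(name: str, namespace: str | None) -> str:
    # Mirrors the branch above: "namespace_name" when a namespace exists, otherwise just the name.
    return name if namespace is None else f"{namespace}_{name}"

assert stream_identifier("users", None) == "users"
assert stream_identifier("users", "public") == "public_users"
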
diff --git a/airbyte_cdk/destinations/vector_db_based/writer.py b/airbyte_cdk/destinations/vector_db_based/writer.py
index 45c7c7326..14f13cbdd 100644
--- a/airbyte_cdk/destinations/vector_db_based/writer.py
+++ b/airbyte_cdk/destinations/vector_db_based/writer.py
@@ -4,7 +4,7 @@
from collections import defaultdict
-from typing import Dict, Iterable, List, Tuple
+from collections.abc import Iterable
from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk, DocumentProcessor
@@ -32,7 +32,7 @@ def __init__(
indexer: Indexer,
embedder: Embedder,
batch_size: int,
- omit_raw_text: bool,
+ omit_raw_text: bool, # noqa: FBT001
) -> None:
self.processing_config = processing_config
self.indexer = indexer
@@ -42,8 +42,8 @@ def __init__(
self._init_batch()
def _init_batch(self) -> None:
- self.chunks: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list)
- self.ids_to_delete: Dict[Tuple[str, str], List[str]] = defaultdict(list)
+ self.chunks: dict[tuple[str, str], list[Chunk]] = defaultdict(list)
+ self.ids_to_delete: dict[tuple[str, str], list[str]] = defaultdict(list)
self.number_of_chunks = 0
def _convert_to_document(self, chunk: Chunk) -> Document:
@@ -59,9 +59,9 @@ def _process_batch(self) -> None:
self.indexer.delete(ids, namespace, stream)
for (namespace, stream), chunks in self.chunks.items():
- embeddings = self.embedder.embed_documents(
- [self._convert_to_document(chunk) for chunk in chunks]
- )
+ embeddings = self.embedder.embed_documents([
+ self._convert_to_document(chunk) for chunk in chunks
+ ])
for i, document in enumerate(chunks):
document.embedding = embeddings[i]
if self.omit_raw_text:
@@ -84,14 +84,14 @@ def write(
elif message.type == Type.RECORD:
record_chunks, record_id_to_delete = self.processor.process(message.record)
self.chunks[
- ( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]"
+ ( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]" # noqa: RUF031
message.record.namespace, # type: ignore [union-attr] # record not None
message.record.stream, # type: ignore [union-attr] # record not None
)
].extend(record_chunks)
if record_id_to_delete is not None:
self.ids_to_delete[
- ( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]"
+ ( # type: ignore [index] # expected "tuple[str, str]", got "tuple[str | Any | None, str | Any]" # noqa: RUF031
message.record.namespace, # type: ignore [union-attr] # record not None
message.record.stream, # type: ignore [union-attr] # record not None
)
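For context, the writer above accumulates chunks and deletion ids per `(namespace, stream)` key and flushes one batch per key. A minimal sketch of that grouping pattern, using plain dicts in place of `Chunk` and the record messages (all names here are illustrative):

from collections import defaultdict

records = [
    ("public", "users", {"id": 1}),
    ("public", "users", {"id": 2}),
    ("sales", "orders", {"id": 7}),
]

chunks: dict[tuple[str, str], list[dict]] = defaultdict(list)
for namespace, stream, payload in records:
    chunks[(namespace, stream)].append(payload)

# One flush per (namespace, stream) key, mirroring _process_batch above.
for (namespace, stream), grouped in chunks.items():
    print(f"indexing {len(grouped)} chunks into {namespace}/{stream}")
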
diff --git a/airbyte_cdk/entrypoint.py b/airbyte_cdk/entrypoint.py
index a5052a575..092025f7d 100644
--- a/airbyte_cdk/entrypoint.py
+++ b/airbyte_cdk/entrypoint.py
@@ -12,8 +12,9 @@
import sys
import tempfile
from collections import defaultdict
+from collections.abc import Iterable, Mapping
from functools import wraps
-from typing import Any, DefaultDict, Iterable, List, Mapping, Optional
+from typing import Any
from urllib.parse import urlparse
import orjson
@@ -43,6 +44,7 @@
from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
logger = init_logger("airbyte")
VALID_URL_SCHEMES = ["https"]
@@ -50,8 +52,8 @@
_HAS_LOGGED_FOR_SERIALIZATION_ERROR = False
-class AirbyteEntrypoint(object):
- def __init__(self, source: Source):
+class AirbyteEntrypoint:
+ def __init__(self, source: Source): # noqa: ANN204
init_uncaught_exception_handler(logger)
# Deployment mode is read when instantiating the entrypoint because it is the common path shared by syncs and connector builder test requests
@@ -62,7 +64,7 @@ def __init__(self, source: Source):
self.logger = logging.getLogger(f"airbyte.{getattr(source, 'name', '')}")
@staticmethod
- def parse_args(args: List[str]) -> argparse.Namespace:
+ def parse_args(args: list[str]) -> argparse.Namespace:
# set up parent parsers
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument(
@@ -120,7 +122,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
def run(self, parsed_args: argparse.Namespace) -> Iterable[str]:
cmd = parsed_args.command
if not cmd:
- raise Exception("No command passed")
+ raise Exception("No command passed") # noqa: TRY002
if hasattr(parsed_args, "debug") and parsed_args.debug:
self.logger.setLevel(logging.DEBUG)
@@ -173,7 +175,7 @@ def run(self, parsed_args: argparse.Namespace) -> Iterable[str]:
self.read(source_spec, config, config_catalog, state),
)
else:
- raise Exception("Unexpected command " + cmd)
+ raise Exception("Unexpected command " + cmd) # noqa: TRY002
finally:
yield from [
self.airbyte_message_to_string(queued_message)
@@ -191,7 +193,7 @@ def check(
# The platform uses the exit code to surface unexpected failures so we raise the exception if the failure type not a config error
# If the failure is not exceptional, we'll emit a failed connection status message and return
if traced_exc.failure_type != FailureType.config_error:
- raise traced_exc
+ raise traced_exc # noqa: TRY201
if connection_status:
yield from self._emit_queued_messages(self.source)
yield connection_status
@@ -204,7 +206,7 @@ def check(
# The platform uses the exit code to surface unexpected failures so we raise the exception if the failure type not a config error
# If the failure is not exceptional, we'll emit a failed connection status message and return
if traced_exc.failure_type != FailureType.config_error:
- raise traced_exc
+ raise traced_exc # noqa: TRY201
else:
yield AirbyteMessage(
type=Type.CONNECTION_STATUS,
@@ -233,14 +235,18 @@ def discover(
yield AirbyteMessage(type=Type.CATALOG, catalog=catalog)
def read(
- self, source_spec: ConnectorSpecification, config: TConfig, catalog: Any, state: list[Any]
+ self,
+ source_spec: ConnectorSpecification,
+ config: TConfig,
+ catalog: Any, # noqa: ANN401
+ state: list[Any],
) -> Iterable[AirbyteMessage]:
self.set_up_secret_filter(config, source_spec.connectionSpecification)
if self.source.check_config_against_spec:
self.validate_connection(source_spec, config)
# The Airbyte protocol dictates that counts be expressed as float/double to better protect against integer overflows
- stream_message_counter: DefaultDict[HashableStreamDescriptor, float] = defaultdict(float)
+ stream_message_counter: defaultdict[HashableStreamDescriptor, float] = defaultdict(float)
for message in self.source.read(self.logger, config, catalog, state):
yield self.handle_record_counts(message, stream_message_counter)
for message in self._emit_queued_messages(self.source):
@@ -248,7 +254,7 @@ def read(
@staticmethod
def handle_record_counts(
- message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, float]
+ message: AirbyteMessage, stream_message_count: defaultdict[HashableStreamDescriptor, float]
) -> AirbyteMessage:
match message.type:
case Type.RECORD:
@@ -306,21 +312,21 @@ def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str:
return json.dumps(serialized_message)
@classmethod
- def extract_state(cls, args: List[str]) -> Optional[Any]:
+ def extract_state(cls, args: list[str]) -> Any | None: # noqa: ANN401
parsed_args = cls.parse_args(args)
if hasattr(parsed_args, "state"):
return parsed_args.state
return None
@classmethod
- def extract_catalog(cls, args: List[str]) -> Optional[Any]:
+ def extract_catalog(cls, args: list[str]) -> Any | None: # noqa: ANN401
parsed_args = cls.parse_args(args)
if hasattr(parsed_args, "catalog"):
return parsed_args.catalog
return None
@classmethod
- def extract_config(cls, args: List[str]) -> Optional[Any]:
+ def extract_config(cls, args: list[str]) -> Any | None: # noqa: ANN401
parsed_args = cls.parse_args(args)
if hasattr(parsed_args, "config"):
return parsed_args.config
@@ -332,7 +338,7 @@ def _emit_queued_messages(self, source: Source) -> Iterable[AirbyteMessage]:
return
-def launch(source: Source, args: List[str]) -> None:
+def launch(source: Source, args: list[str]) -> None:
source_entrypoint = AirbyteEntrypoint(source)
parsed_args = source_entrypoint.parse_args(args)
# temporarily removes the PrintBuffer because we're seeing weird print behavior for concurrent syncs
@@ -351,12 +357,12 @@ def _init_internal_request_filter() -> None:
wrapped_fn = Session.send
@wraps(wrapped_fn)
- def filtered_send(self: Any, request: PreparedRequest, **kwargs: Any) -> Response:
+ def filtered_send(self: Any, request: PreparedRequest, **kwargs: Any) -> Response: # noqa: ANN401
parsed_url = urlparse(request.url)
if parsed_url.scheme not in VALID_URL_SCHEMES:
raise requests.exceptions.InvalidSchema(
- "Invalid Protocol Scheme: The endpoint that data is being requested from is using an invalid or insecure "
+ "Invalid Protocol Scheme: The endpoint that data is being requested from is using an invalid or insecure " # noqa: ISC003
+ f"protocol {parsed_url.scheme!r}. Valid protocol schemes: {','.join(VALID_URL_SCHEMES)}"
)
@@ -377,7 +383,7 @@ def filtered_send(self: Any, request: PreparedRequest, **kwargs: Any) -> Respons
# This is a special case where the developer specifies an IP address string that is not formatted correctly like trailing
# whitespace which will fail the socket IP lookup. This only happens when using IP addresses and not text hostnames.
# Knowing that this is a request using the requests library, we will mock the exception without calling the lib
- raise requests.exceptions.InvalidURL(f"Invalid URL {parsed_url}: {exception}")
+ raise requests.exceptions.InvalidURL(f"Invalid URL {parsed_url}: {exception}") # noqa: B904
return wrapped_fn(self, request, **kwargs)
@@ -409,6 +415,6 @@ def main() -> None:
source = impl()
if not isinstance(source, Source):
- raise Exception("Source implementation provided does not implement Source class!")
+ raise Exception("Source implementation provided does not implement Source class!") # noqa: TRY002, TRY004
launch(source, sys.argv[1:])
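The `_init_internal_request_filter` changes above keep the existing approach of monkey-patching `requests.Session.send` so any call over a non-https scheme fails before a socket is opened. A stripped-down sketch of the same wrapping technique (without the IP-address special case or the CDK's exact error wording):

from functools import wraps
from urllib.parse import urlparse

import requests
from requests import PreparedRequest, Response, Session

VALID_URL_SCHEMES = ["https"]

def install_scheme_filter() -> None:
    wrapped_fn = Session.send

    @wraps(wrapped_fn)
    def filtered_send(self: Session, request: PreparedRequest, **kwargs) -> Response:
        # Reject requests whose scheme is not on the allow-list.
        scheme = urlparse(request.url).scheme
        if scheme not in VALID_URL_SCHEMES:
            raise requests.exceptions.InvalidSchema(
                f"Invalid protocol scheme {scheme!r}; allowed schemes: {','.join(VALID_URL_SCHEMES)}"
            )
        return wrapped_fn(self, request, **kwargs)

    Session.send = filtered_send  # type: ignore[method-assign]
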
diff --git a/airbyte_cdk/exception_handler.py b/airbyte_cdk/exception_handler.py
index 84aa39ba1..78170375b 100644
--- a/airbyte_cdk/exception_handler.py
+++ b/airbyte_cdk/exception_handler.py
@@ -4,8 +4,9 @@
import logging
import sys
+from collections.abc import Mapping
from types import TracebackType
-from typing import Any, List, Mapping, Optional
+from typing import Any
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
@@ -28,8 +29,8 @@ def init_uncaught_exception_handler(logger: logging.Logger) -> None:
def hook_fn(
exception_type: type[BaseException],
exception_value: BaseException,
- traceback_: Optional[TracebackType],
- ) -> Any:
+ traceback_: TracebackType | None,
+ ) -> Any: # noqa: ANN401
# For developer ergonomics, we want to see the stack trace in the logs when we do a ctrl-c
if issubclass(exception_type, KeyboardInterrupt):
sys.__excepthook__(exception_type, exception_value, traceback_)
@@ -45,12 +46,10 @@ def hook_fn(
sys.excepthook = hook_fn
-def generate_failed_streams_error_message(stream_failures: Mapping[str, List[Exception]]) -> str:
- failures = "\n".join(
- [
- f"{stream}: {filter_secrets(exception.__repr__())}"
- for stream, exceptions in stream_failures.items()
- for exception in exceptions
- ]
- )
+def generate_failed_streams_error_message(stream_failures: Mapping[str, list[Exception]]) -> str:
+ failures = "\n".join([
+ f"{stream}: {filter_secrets(exception.__repr__())}" # noqa: PLC2801
+ for stream, exceptions in stream_failures.items()
+ for exception in exceptions
+ ])
return f"During the sync, the following streams did not sync successfully: {failures}"
diff --git a/airbyte_cdk/logger.py b/airbyte_cdk/logger.py
index 78061b605..75f29c7ae 100644
--- a/airbyte_cdk/logger.py
+++ b/airbyte_cdk/logger.py
@@ -5,7 +5,8 @@
import json
import logging
import logging.config
-from typing import Any, Callable, Mapping, Optional, Tuple
+from collections.abc import Callable, Mapping
+from typing import Any
import orjson
@@ -18,6 +19,7 @@
)
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
+
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
@@ -37,7 +39,7 @@
}
-def init_logger(name: Optional[str] = None) -> logging.Logger:
+def init_logger(name: str | None = None) -> logging.Logger:
"""Initial set up of logger"""
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
@@ -57,7 +59,7 @@ class AirbyteLogFormatter(logging.Formatter):
"""Output log records using AirbyteMessage"""
# Transforming Python log levels to Airbyte protocol log levels
- level_mapping = {
+ level_mapping = { # noqa: RUF012
logging.FATAL: Level.FATAL,
logging.ERROR: Level.ERROR,
logging.WARNING: Level.WARN,
@@ -72,13 +74,12 @@ def format(self, record: logging.LogRecord) -> str:
extras = self.extract_extra_args_from_record(record)
debug_dict = {"type": "DEBUG", "message": record.getMessage(), "data": extras}
return filter_secrets(json.dumps(debug_dict))
- else:
- message = super().format(record)
- message = filter_secrets(message)
- log_message = AirbyteMessage(
- type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message)
- )
- return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode()
+ message = super().format(record)
+ message = filter_secrets(message)
+ log_message = AirbyteMessage(
+ type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message)
+ )
+ return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode()
@staticmethod
def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, Any]:
@@ -91,7 +92,7 @@ def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, An
return {k: str(getattr(record, k)) for k in extra_keys if hasattr(record, k)}
-def log_by_prefix(msg: str, default_level: str) -> Tuple[int, str]:
+def log_by_prefix(msg: str, default_level: str) -> tuple[int, str]:
"""Custom method, which takes log level from first word of message"""
valid_log_types = ["FATAL", "ERROR", "WARN", "INFO", "DEBUG", "TRACE"]
split_line = msg.split()
diff --git a/airbyte_cdk/models/__init__.py b/airbyte_cdk/models/__init__.py
index 3fa24be49..1d4fc1f38 100644
--- a/airbyte_cdk/models/__init__.py
+++ b/airbyte_cdk/models/__init__.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+#
# The earlier versions of airbyte-cdk (0.28.0<=) had the airbyte_protocol python classes
# declared inline in the airbyte-cdk code. However, somewhere around Feb 2023 the
# Airbyte Protocol moved to its own repo/PyPi package, called airbyte-protocol-models.
@@ -6,66 +8,66 @@
# to make the airbyte_protocol python classes available to the airbyte-cdk consumer as part
# of airbyte-cdk rather than a standalone package.
from .airbyte_protocol import (
- AdvancedAuth,
- AirbyteAnalyticsTraceMessage,
- AirbyteCatalog,
- AirbyteConnectionStatus,
- AirbyteControlConnectorConfigMessage,
- AirbyteControlMessage,
- AirbyteErrorTraceMessage,
- AirbyteEstimateTraceMessage,
- AirbyteGlobalState,
- AirbyteLogMessage,
- AirbyteMessage,
- AirbyteProtocol,
- AirbyteRecordMessage,
- AirbyteStateBlob,
- AirbyteStateMessage,
- AirbyteStateStats,
- AirbyteStateType,
- AirbyteStream,
- AirbyteStreamState,
- AirbyteStreamStatus,
- AirbyteStreamStatusReason,
- AirbyteStreamStatusReasonType,
- AirbyteStreamStatusTraceMessage,
- AirbyteTraceMessage,
- AuthFlowType,
- ConfiguredAirbyteCatalog,
- ConfiguredAirbyteStream,
- ConnectorSpecification,
- DestinationSyncMode,
- EstimateType,
- FailureType,
- Level,
- OAuthConfigSpecification,
- OauthConnectorInputSpecification,
- OrchestratorType,
- State,
- Status,
- StreamDescriptor,
- SyncMode,
- TraceType,
- Type,
+ AdvancedAuth, # noqa: F401
+ AirbyteAnalyticsTraceMessage, # noqa: F401
+ AirbyteCatalog, # noqa: F401
+ AirbyteConnectionStatus, # noqa: F401
+ AirbyteControlConnectorConfigMessage, # noqa: F401
+ AirbyteControlMessage, # noqa: F401
+ AirbyteErrorTraceMessage, # noqa: F401
+ AirbyteEstimateTraceMessage, # noqa: F401
+ AirbyteGlobalState, # noqa: F401
+ AirbyteLogMessage, # noqa: F401
+ AirbyteMessage, # noqa: F401
+ AirbyteProtocol, # noqa: F401
+ AirbyteRecordMessage, # noqa: F401
+ AirbyteStateBlob, # noqa: F401
+ AirbyteStateMessage, # noqa: F401
+ AirbyteStateStats, # noqa: F401
+ AirbyteStateType, # noqa: F401
+ AirbyteStream, # noqa: F401
+ AirbyteStreamState, # noqa: F401
+ AirbyteStreamStatus, # noqa: F401
+ AirbyteStreamStatusReason, # noqa: F401
+ AirbyteStreamStatusReasonType, # noqa: F401
+ AirbyteStreamStatusTraceMessage, # noqa: F401
+ AirbyteTraceMessage, # noqa: F401
+ AuthFlowType, # noqa: F401
+ ConfiguredAirbyteCatalog, # noqa: F401
+ ConfiguredAirbyteStream, # noqa: F401
+ ConnectorSpecification, # noqa: F401
+ DestinationSyncMode, # noqa: F401
+ EstimateType, # noqa: F401
+ FailureType, # noqa: F401
+ Level, # noqa: F401
+ OAuthConfigSpecification, # noqa: F401
+ OauthConnectorInputSpecification, # noqa: F401
+ OrchestratorType, # noqa: F401
+ State, # noqa: F401
+ Status, # noqa: F401
+ StreamDescriptor, # noqa: F401
+ SyncMode, # noqa: F401
+ TraceType, # noqa: F401
+ Type, # noqa: F401
)
from .airbyte_protocol_serializers import (
- AirbyteMessageSerializer,
- AirbyteStateMessageSerializer,
- AirbyteStreamStateSerializer,
- ConfiguredAirbyteCatalogSerializer,
- ConfiguredAirbyteStreamSerializer,
- ConnectorSpecificationSerializer,
+ AirbyteMessageSerializer, # noqa: F401
+ AirbyteStateMessageSerializer, # noqa: F401
+ AirbyteStreamStateSerializer, # noqa: F401
+ ConfiguredAirbyteCatalogSerializer, # noqa: F401
+ ConfiguredAirbyteStreamSerializer, # noqa: F401
+ ConnectorSpecificationSerializer, # noqa: F401
)
from .well_known_types import (
- BinaryData,
- Boolean,
- Date,
- Integer,
- Model,
- Number,
- String,
- TimestampWithoutTimezone,
- TimestampWithTimezone,
- TimeWithoutTimezone,
- TimeWithTimezone,
+ BinaryData, # noqa: F401
+ Boolean, # noqa: F401
+ Date, # noqa: F401
+ Integer, # noqa: F401
+ Model, # noqa: F401
+ Number, # noqa: F401
+ String, # noqa: F401
+ TimestampWithoutTimezone, # noqa: F401
+ TimestampWithTimezone, # noqa: F401
+ TimeWithoutTimezone, # noqa: F401
+ TimeWithTimezone, # noqa: F401
)
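The per-name `# noqa: F401` annotations above silence Ruff's unused-import rule for each re-export individually. An equivalent, arguably more conventional pattern is to list the re-exported names in `__all__`, which documents the public surface and satisfies F401 in one place; a hedged sketch with a handful of names:

from .airbyte_protocol import AirbyteMessage, AirbyteRecordMessage, Type

__all__ = [
    "AirbyteMessage",
    "AirbyteRecordMessage",
    "Type",
]
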
diff --git a/airbyte_cdk/models/airbyte_protocol.py b/airbyte_cdk/models/airbyte_protocol.py
index 2528f7d7e..592cd138b 100644
--- a/airbyte_cdk/models/airbyte_protocol.py
+++ b/airbyte_cdk/models/airbyte_protocol.py
@@ -2,19 +2,22 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Annotated, Any, Dict, List, Mapping, Optional, Union
+from typing import Annotated, Any
-from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*'
from serpyco_rs.metadata import Alias
+from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*'
+
from airbyte_cdk.models.file_transfer_record_message import AirbyteFileTransferRecordMessage
+
# ruff: noqa: F405 # ignore fuzzy import issues with 'import *'
@dataclass
-class AirbyteStateBlob:
+class AirbyteStateBlob: # noqa: PLW1641
"""
A dataclass that dynamically sets attributes based on provided keyword arguments and positional arguments.
Used to "mimic" pydantic Basemodel with ConfigDict(extra='allow') option.
@@ -37,7 +40,7 @@ class AirbyteStateBlob:
kwargs: InitVar[Mapping[str, Any]]
- def __init__(self, *args: Any, **kwargs: Any) -> None:
+ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
# Set any attribute passed in through kwargs
for arg in args:
self.__dict__.update(arg)
@@ -56,35 +59,35 @@ def __eq__(self, other: object) -> bool:
@dataclass
class AirbyteStreamState:
stream_descriptor: StreamDescriptor # type: ignore [name-defined]
- stream_state: Optional[AirbyteStateBlob] = None
+ stream_state: AirbyteStateBlob | None = None
@dataclass
class AirbyteGlobalState:
- stream_states: List[AirbyteStreamState]
- shared_state: Optional[AirbyteStateBlob] = None
+ stream_states: list[AirbyteStreamState]
+ shared_state: AirbyteStateBlob | None = None
@dataclass
class AirbyteStateMessage:
- type: Optional[AirbyteStateType] = None # type: ignore [name-defined]
- stream: Optional[AirbyteStreamState] = None
+ type: AirbyteStateType | None = None # type: ignore [name-defined]
+ stream: AirbyteStreamState | None = None
global_: Annotated[AirbyteGlobalState | None, Alias("global")] = (
None # "global" is a reserved keyword in python ⇒ Alias is used for (de-)serialization
)
- data: Optional[Dict[str, Any]] = None
- sourceStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined]
- destinationStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined]
+ data: dict[str, Any] | None = None
+ sourceStats: AirbyteStateStats | None = None # type: ignore [name-defined] # noqa: N815
+ destinationStats: AirbyteStateStats | None = None # type: ignore [name-defined] # noqa: N815
@dataclass
class AirbyteMessage:
type: Type # type: ignore [name-defined]
- log: Optional[AirbyteLogMessage] = None # type: ignore [name-defined]
- spec: Optional[ConnectorSpecification] = None # type: ignore [name-defined]
- connectionStatus: Optional[AirbyteConnectionStatus] = None # type: ignore [name-defined]
- catalog: Optional[AirbyteCatalog] = None # type: ignore [name-defined]
- record: Optional[Union[AirbyteFileTransferRecordMessage, AirbyteRecordMessage]] = None # type: ignore [name-defined]
- state: Optional[AirbyteStateMessage] = None
- trace: Optional[AirbyteTraceMessage] = None # type: ignore [name-defined]
- control: Optional[AirbyteControlMessage] = None # type: ignore [name-defined]
+ log: AirbyteLogMessage | None = None # type: ignore [name-defined]
+ spec: ConnectorSpecification | None = None # type: ignore [name-defined]
+ connectionStatus: AirbyteConnectionStatus | None = None # type: ignore [name-defined] # noqa: N815
+ catalog: AirbyteCatalog | None = None # type: ignore [name-defined]
+ record: AirbyteFileTransferRecordMessage | AirbyteRecordMessage | None = None # type: ignore [name-defined]
+ state: AirbyteStateMessage | None = None
+ trace: AirbyteTraceMessage | None = None # type: ignore [name-defined]
+ control: AirbyteControlMessage | None = None # type: ignore [name-defined]
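`AirbyteStateBlob` imitates a pydantic model with `extra="allow"` by copying any mappings and keyword arguments straight into `__dict__`. A minimal stand-alone sketch of that pattern (not the CDK class itself):

from typing import Any

class Blob:
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Each positional argument is expected to be a mapping of attribute names to values.
        for arg in args:
            self.__dict__.update(arg)
        # Keyword arguments become attributes as well.
        self.__dict__.update(kwargs)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Blob) and self.__dict__ == other.__dict__

b = Blob({"cursor": "2024-01-01"}, page=3)
assert b.cursor == "2024-01-01" and b.page == 3
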
diff --git a/airbyte_cdk/models/airbyte_protocol_serializers.py b/airbyte_cdk/models/airbyte_protocol_serializers.py
index 129556acc..6ce15d130 100644
--- a/airbyte_cdk/models/airbyte_protocol_serializers.py
+++ b/airbyte_cdk/models/airbyte_protocol_serializers.py
@@ -1,5 +1,5 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
-from typing import Any, Dict
+from typing import Any
from serpyco_rs import CustomType, Serializer
@@ -14,19 +14,19 @@
)
-class AirbyteStateBlobType(CustomType[AirbyteStateBlob, Dict[str, Any]]):
- def serialize(self, value: AirbyteStateBlob) -> Dict[str, Any]:
+class AirbyteStateBlobType(CustomType[AirbyteStateBlob, dict[str, Any]]):
+ def serialize(self, value: AirbyteStateBlob) -> dict[str, Any]:
# cant use orjson.dumps() directly because private attributes are excluded, e.g. "__ab_full_refresh_sync_complete"
return {k: v for k, v in value.__dict__.items()}
- def deserialize(self, value: Dict[str, Any]) -> AirbyteStateBlob:
+ def deserialize(self, value: dict[str, Any]) -> AirbyteStateBlob:
return AirbyteStateBlob(value)
- def get_json_schema(self) -> Dict[str, Any]:
+ def get_json_schema(self) -> dict[str, Any]:
return {"type": "object"}
-def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, Dict[str, Any]] | None:
+def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, dict[str, Any]] | None:
return AirbyteStateBlobType() if t is AirbyteStateBlob else None
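The resolver above is what lets serpyco-rs substitute the custom (de)serialization for `AirbyteStateBlob` wherever the type occurs inside the protocol dataclasses. It is presumably wired in when each `Serializer` is constructed; a hedged sketch, assuming `Serializer` accepts a `custom_type_resolver` keyword (as the resolver's name suggests):

from serpyco_rs import Serializer

from airbyte_cdk.models import AirbyteStateMessage

# Hypothetical wiring; the real serializer definitions in this module are not shown in the hunk.
AirbyteStateMessageSerializer = Serializer(
    AirbyteStateMessage,
    omit_none=True,
    custom_type_resolver=custom_type_resolver,
)
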
diff --git a/airbyte_cdk/models/file_transfer_record_message.py b/airbyte_cdk/models/file_transfer_record_message.py
index dcc1b7a92..8bde8b408 100644
--- a/airbyte_cdk/models/file_transfer_record_message.py
+++ b/airbyte_cdk/models/file_transfer_record_message.py
@@ -1,13 +1,13 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any
@dataclass
class AirbyteFileTransferRecordMessage:
stream: str
- file: Dict[str, Any]
+ file: dict[str, Any]
emitted_at: int
- namespace: Optional[str] = None
- data: Optional[Dict[str, Any]] = None
+ namespace: str | None = None
+ data: dict[str, Any] | None = None
diff --git a/airbyte_cdk/sources/__init__.py b/airbyte_cdk/sources/__init__.py
index a6560a503..162ac01f2 100644
--- a/airbyte_cdk/sources/__init__.py
+++ b/airbyte_cdk/sources/__init__.py
@@ -8,6 +8,7 @@
from .config import BaseConfig
from .source import Source
+
# As part of the CDK sources, we do not control what the APIs return and it is possible that a key is empty.
# Reasons why we are doing this at the airbyte_cdk level:
# * As of today, all the use cases should allow for empty keys
diff --git a/airbyte_cdk/sources/abstract_source.py b/airbyte_cdk/sources/abstract_source.py
index ab9ee48b8..5eef75933 100644
--- a/airbyte_cdk/sources/abstract_source.py
+++ b/airbyte_cdk/sources/abstract_source.py
@@ -4,17 +4,9 @@
import logging
from abc import ABC, abstractmethod
+from collections.abc import Iterable, Iterator, Mapping, MutableMapping
from typing import (
Any,
- Dict,
- Iterable,
- Iterator,
- List,
- Mapping,
- MutableMapping,
- Optional,
- Tuple,
- Union,
)
from airbyte_cdk.exception_handler import generate_failed_streams_error_message
@@ -46,6 +38,7 @@
)
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
_default_message_repository = InMemoryMessageRepository()
@@ -58,7 +51,7 @@ class AbstractSource(Source, ABC):
@abstractmethod
def check_connection(
self, logger: logging.Logger, config: Mapping[str, Any]
- ) -> Tuple[bool, Optional[Any]]:
+ ) -> tuple[bool, Any | None]:
"""
:param logger: source logger
:param config: The user-provided configuration as specified by the source's spec.
@@ -71,7 +64,7 @@ def check_connection(
"""
@abstractmethod
- def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+ def streams(self, config: Mapping[str, Any]) -> list[Stream]:
"""
:param config: The user-provided configuration as specified by the source's spec.
Any stream construction related operation should happen here.
@@ -79,10 +72,10 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
"""
# Stream name to instance map for applying output object transformation
- _stream_to_instance_map: Dict[str, Stream] = {}
+ _stream_to_instance_map: dict[str, Stream] = {} # noqa: RUF012
_slice_logger: SliceLogger = DebugSliceLogger()
- def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog:
+ def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: # noqa: ARG002
"""Implements the Discover operation from the Airbyte Specification.
See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/#discover.
"""
@@ -98,17 +91,17 @@ def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCon
return AirbyteConnectionStatus(status=Status.FAILED, message=repr(error))
return AirbyteConnectionStatus(status=Status.SUCCEEDED)
- def read(
+ def read( # noqa: PLR0915
self,
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
- state: Optional[List[AirbyteStateMessage]] = None,
+ state: list[AirbyteStateMessage] | None = None,
) -> Iterator[AirbyteMessage]:
"""Implements the Read operation from the Airbyte Specification. See https://docs.airbyte.com/understanding-airbyte/airbyte-protocol/."""
logger.info(f"Starting syncing {self.name}")
config, internal_config = split_config(config)
- # TODO assert all streams exist in the connector
+ # TODO assert all streams exist in the connector # noqa: TD004
# get the streams once in case the connector needs to make any queries to generate them
stream_instances = {s.name: s for s in self.streams(config)}
state_manager = ConnectorStateManager(state=state)
@@ -137,7 +130,7 @@ def read(
# Use configured_stream as stream_instance to support references in error handling.
stream_instance = configured_stream.stream
- raise AirbyteTracedException(
+ raise AirbyteTracedException( # noqa: TRY301
message="A stream listed in your configuration was not found in the source. Please check the logs for more "
"details.",
internal_message=error_message,
@@ -198,9 +191,9 @@ def read(
logger.info(timer.report())
if len(stream_name_to_exception) > 0:
- error_message = generate_failed_streams_error_message(
- {key: [value] for key, value in stream_name_to_exception.items()}
- )
+ error_message = generate_failed_streams_error_message({
+ key: [value] for key, value in stream_name_to_exception.items()
+ })
logger.info(error_message)
# We still raise at least one exception when a stream raises an exception because the platform currently relies
# on a non-zero exit code to determine if a sync attempt has failed. We also raise the exception as a config_error
@@ -212,7 +205,7 @@ def read(
@staticmethod
def _serialize_exception(
- stream_descriptor: StreamDescriptor, e: Exception, stream_instance: Optional[Stream] = None
+ stream_descriptor: StreamDescriptor, e: Exception, stream_instance: Stream | None = None
) -> AirbyteTracedException:
display_message = stream_instance.get_error_display_message(e) if stream_instance else None
if display_message:
@@ -294,7 +287,7 @@ def _emit_queued_messages(self) -> Iterable[AirbyteMessage]:
return
def _get_message(
- self, record_data_or_message: Union[StreamData, AirbyteMessage], stream: Stream
+ self, record_data_or_message: StreamData | AirbyteMessage, stream: Stream
) -> AirbyteMessage:
"""
Converts the input to an AirbyteMessage if it is a StreamData. Returns the input as is if it is already an AirbyteMessage
@@ -311,7 +304,7 @@ def _get_message(
)
@property
- def message_repository(self) -> Union[None, MessageRepository]:
+ def message_repository(self) -> None | MessageRepository: # noqa: RUF036
return _default_message_repository
@property
diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
index f57db7e14..688c6beda 100644
--- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
+++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
@@ -2,7 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
-from typing import Dict, Iterable, List, Optional, Set
+from collections.abc import Iterable
from airbyte_cdk.exception_handler import generate_failed_streams_error_message
from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus, FailureType, StreamDescriptor
@@ -28,9 +28,9 @@
class ConcurrentReadProcessor:
- def __init__(
+ def __init__( # noqa: ANN204
self,
- stream_instances_to_read_from: List[AbstractStream],
+ stream_instances_to_read_from: list[AbstractStream],
partition_enqueuer: PartitionEnqueuer,
thread_pool_manager: ThreadPoolManager,
logger: logging.Logger,
@@ -50,20 +50,20 @@ def __init__(
"""
self._stream_name_to_instance = {s.name: s for s in stream_instances_to_read_from}
self._record_counter = {}
- self._streams_to_running_partitions: Dict[str, Set[Partition]] = {}
+ self._streams_to_running_partitions: dict[str, set[Partition]] = {}
for stream in stream_instances_to_read_from:
self._streams_to_running_partitions[stream.name] = set()
self._record_counter[stream.name] = 0
self._thread_pool_manager = thread_pool_manager
self._partition_enqueuer = partition_enqueuer
self._stream_instances_to_start_partition_generation = stream_instances_to_read_from
- self._streams_currently_generating_partitions: List[str] = []
+ self._streams_currently_generating_partitions: list[str] = []
self._logger = logger
self._slice_logger = slice_logger
self._message_repository = message_repository
self._partition_reader = partition_reader
- self._streams_done: Set[str] = set()
- self._exceptions_per_stream_name: dict[str, List[Exception]] = {}
+ self._streams_done: set[str] = set()
+ self._exceptions_per_stream_name: dict[str, list[Exception]] = {}
def on_partition_generation_completed(
self, sentinel: PartitionGenerationCompletedSentinel
@@ -186,7 +186,7 @@ def on_exception(self, exception: StreamThreadException) -> Iterable[AirbyteMess
def _flag_exception(self, stream_name: str, exception: Exception) -> None:
self._exceptions_per_stream_name.setdefault(stream_name, []).append(exception)
- def start_next_partition_generator(self) -> Optional[AirbyteMessage]:
+ def start_next_partition_generator(self) -> AirbyteMessage | None:
"""
Start the next partition generator.
1. Pop the next stream to read from
@@ -204,8 +204,7 @@ def start_next_partition_generator(self) -> Optional[AirbyteMessage]:
stream.as_airbyte_stream(),
AirbyteStreamStatus.STARTED,
)
- else:
- return None
+ return None
def is_done(self) -> bool:
"""
@@ -215,12 +214,10 @@ def is_done(self) -> bool:
2. There are no more streams to read from
3. All partitions for all streams are closed
"""
- is_done = all(
- [
- self._is_stream_done(stream_name)
- for stream_name in self._stream_name_to_instance.keys()
- ]
- )
+ is_done = all([ # noqa: C419
+ self._is_stream_done(stream_name)
+ for stream_name in self._stream_name_to_instance.keys() # noqa: SIM118
+ ])
if is_done and self._exceptions_per_stream_name:
error_message = generate_failed_streams_error_message(self._exceptions_per_stream_name)
self._logger.info(error_message)
diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py
index bc7d97cdd..59f847cb9 100644
--- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py
+++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py
@@ -3,8 +3,8 @@
#
import concurrent
import logging
+from collections.abc import Iterable, Iterator
from queue import Queue
-from typing import Iterable, Iterator, List
from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
@@ -49,9 +49,9 @@ def create(
too_many_generator = (
not is_single_threaded and initial_number_of_partitions_to_generate >= num_workers
)
- assert (
- not too_many_generator
- ), "It is required to have more workers than threads generating partitions"
+ assert not too_many_generator, (
+ "It is required to have more workers than threads generating partitions"
+ )
threadpool = ThreadPoolManager(
concurrent.futures.ThreadPoolExecutor(
max_workers=num_workers, thread_name_prefix="workerpool"
@@ -71,8 +71,8 @@ def __init__(
self,
threadpool: ThreadPoolManager,
logger: logging.Logger,
- slice_logger: SliceLogger = DebugSliceLogger(),
- message_repository: MessageRepository = InMemoryMessageRepository(),
+ slice_logger: SliceLogger = DebugSliceLogger(), # noqa: B008
+ message_repository: MessageRepository = InMemoryMessageRepository(), # noqa: B008
initial_number_partitions_to_generate: int = 1,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
) -> None:
@@ -93,7 +93,7 @@ def __init__(
def read(
self,
- streams: List[AbstractStream],
+ streams: list[AbstractStream],
) -> Iterator[AirbyteMessage]:
self._logger.info("Starting syncing")
@@ -162,4 +162,4 @@ def _handle_item(
elif isinstance(queue_item, Record):
yield from concurrent_stream_processor.on_record(queue_item)
else:
- raise ValueError(f"Unknown queue item type: {type(queue_item)}")
+ raise ValueError(f"Unknown queue item type: {type(queue_item)}") # noqa: TRY004
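The assertion in `create` simply requires at least one worker to remain free for reading partitions while others generate them. A worked illustration of the condition with hypothetical values:

num_workers = 4
initial_number_of_partitions_to_generate = 4
is_single_threaded = False

too_many_generator = (
    not is_single_threaded and initial_number_of_partitions_to_generate >= num_workers
)
# Four generators would saturate the four-worker pool, so create() would refuse this configuration.
assert too_many_generator
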
diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py
index c150dc956..91b7eb58d 100644
--- a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py
+++ b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py
@@ -4,8 +4,9 @@
import logging
from abc import ABC
+from collections.abc import Callable, Iterator, Mapping, MutableMapping
from datetime import timedelta
-from typing import Any, Callable, Iterator, List, Mapping, MutableMapping, Optional, Tuple
+from typing import Any
from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.sources import AbstractSource
@@ -27,11 +28,12 @@
AbstractStreamStateConverter,
)
+
DEFAULT_LOOKBACK_SECONDS = 0
class ConcurrentSourceAdapter(AbstractSource, ABC):
- def __init__(self, concurrent_source: ConcurrentSource, **kwargs: Any) -> None:
+ def __init__(self, concurrent_source: ConcurrentSource, **kwargs: Any) -> None: # noqa: ANN401
"""
ConcurrentSourceAdapter is a Source that wraps a concurrent source and exposes it as a regular source.
@@ -47,7 +49,7 @@ def read(
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
- state: Optional[List[AirbyteStateMessage]] = None,
+ state: list[AirbyteStateMessage] | None = None,
) -> Iterator[AirbyteMessage]:
abstract_streams = self._select_abstract_streams(config, catalog)
concurrent_stream_names = {stream.name for stream in abstract_streams}
@@ -65,13 +67,13 @@ def read(
def _select_abstract_streams(
self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog
- ) -> List[AbstractStream]:
+ ) -> list[AbstractStream]:
"""
Selects streams that can be processed concurrently and returns their abstract representations.
"""
all_streams = self.streams(config)
stream_name_to_instance: Mapping[str, Stream] = {s.name: s for s in all_streams}
- abstract_streams: List[AbstractStream] = []
+ abstract_streams: list[AbstractStream] = []
for configured_stream in configured_catalog.streams:
stream_instance = stream_name_to_instance.get(configured_stream.stream.name)
if not stream_instance:
@@ -86,7 +88,7 @@ def convert_to_concurrent_stream(
logger: logging.Logger,
stream: Stream,
state_manager: ConnectorStateManager,
- cursor: Optional[Cursor] = None,
+ cursor: Cursor | None = None,
) -> Stream:
"""
Prepares a stream for concurrent processing by initializing or assigning a cursor,
@@ -113,12 +115,12 @@ def initialize_cursor(
stream: Stream,
state_manager: ConnectorStateManager,
converter: AbstractStreamStateConverter,
- slice_boundary_fields: Optional[Tuple[str, str]],
- start: Optional[CursorValueType],
+ slice_boundary_fields: tuple[str, str] | None,
+ start: CursorValueType | None,
end_provider: Callable[[], CursorValueType],
- lookback_window: Optional[GapType] = None,
- slice_range: Optional[GapType] = None,
- ) -> Optional[ConcurrentCursor]:
+ lookback_window: GapType | None = None,
+ slice_range: GapType | None = None,
+ ) -> ConcurrentCursor | None:
lookback_window = lookback_window or timedelta(seconds=DEFAULT_LOOKBACK_SECONDS)
cursor_field_name = stream.cursor_field
diff --git a/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py b/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py
index b6643042b..4cdbd221f 100644
--- a/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py
+++ b/airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py
@@ -1,24 +1,23 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Any
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
-class PartitionGenerationCompletedSentinel:
+class PartitionGenerationCompletedSentinel: # noqa: PLW1641
"""
A sentinel object indicating all partitions for a stream were produced.
Includes a pointer to the stream that was processed.
"""
- def __init__(self, stream: AbstractStream):
+ def __init__(self, stream: AbstractStream): # noqa: ANN204
"""
:param stream: The stream that was processed
"""
self.stream = stream
- def __eq__(self, other: Any) -> bool:
+ def __eq__(self, other: object) -> bool:
if isinstance(other, PartitionGenerationCompletedSentinel):
return self.stream == other.stream
return False
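`PLW1641` flags classes that define `__eq__` without `__hash__` (Python then sets `__hash__` to None, making instances unhashable). The sentinel suppresses the rule instead; had hashability been needed, the complementary `__hash__` would look roughly like this sketch:

class Sentinel:
    def __init__(self, stream: str) -> None:
        self.stream = stream

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Sentinel) and self.stream == other.stream

    def __hash__(self) -> int:
        # Hash on the same field used for equality so the eq/hash contract holds.
        return hash(self.stream)
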
diff --git a/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py b/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py
index c865bef59..da034575d 100644
--- a/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py
+++ b/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py
@@ -1,10 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
-from typing import Any
-
-class StreamThreadException(Exception):
- def __init__(self, exception: Exception, stream_name: str):
+class StreamThreadException(Exception): # noqa: PLW1641
+ def __init__(self, exception: Exception, stream_name: str): # noqa: ANN204
self._exception = exception
self._stream_name = stream_name
@@ -19,7 +17,7 @@ def exception(self) -> Exception:
def __str__(self) -> str:
return f"Exception while syncing stream {self._stream_name}: {self._exception}"
- def __eq__(self, other: Any) -> bool:
+ def __eq__(self, other: object) -> bool:
if isinstance(other, StreamThreadException):
return self._exception == other._exception and self._stream_name == other._stream_name
return False
diff --git a/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py b/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py
index 59f8a1f0b..556d0117a 100644
--- a/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py
+++ b/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py
@@ -3,8 +3,9 @@
#
import logging
import threading
+from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
-from typing import Any, Callable, List, Optional
+from typing import Any
class ThreadPoolManager:
@@ -14,7 +15,7 @@ class ThreadPoolManager:
DEFAULT_MAX_QUEUE_SIZE = 10_000
- def __init__(
+ def __init__( # noqa: ANN204
self,
threadpool: ThreadPoolExecutor,
logger: logging.Logger,
@@ -28,9 +29,9 @@ def __init__(
self._threadpool = threadpool
self._logger = logger
self._max_concurrent_tasks = max_concurrent_tasks
- self._futures: List[Future[Any]] = []
+ self._futures: list[Future[Any]] = []
self._lock = threading.Lock()
- self._most_recently_seen_exception: Optional[Exception] = None
+ self._most_recently_seen_exception: Exception | None = None
self._logging_threshold = max_concurrent_tasks * 2
@@ -42,10 +43,10 @@ def prune_to_validate_has_reached_futures_limit(self) -> bool:
)
return len(self._futures) >= self._max_concurrent_tasks
- def submit(self, function: Callable[..., Any], *args: Any) -> None:
+ def submit(self, function: Callable[..., Any], *args: Any) -> None: # noqa: ANN401
self._futures.append(self._threadpool.submit(function, *args))
- def _prune_futures(self, futures: List[Future[Any]]) -> None:
+ def _prune_futures(self, futures: list[Future[Any]]) -> None:
"""
Take a list in input and remove the futures that are completed. If a future has an exception, it'll raise and kill the stream
operation.
@@ -79,7 +80,7 @@ def _shutdown(self) -> None:
self._threadpool.shutdown(wait=False, cancel_futures=True)
def is_done(self) -> bool:
- return all([f.done() for f in self._futures])
+ return all([f.done() for f in self._futures]) # noqa: C419
def check_for_errors_and_shutdown(self) -> None:
"""
diff --git a/airbyte_cdk/sources/config.py b/airbyte_cdk/sources/config.py
index ea91b17f3..d0ce9c896 100644
--- a/airbyte_cdk/sources/config.py
+++ b/airbyte_cdk/sources/config.py
@@ -2,7 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Any, Dict
+from typing import Any
from pydantic.v1 import BaseModel
@@ -18,7 +18,7 @@ class BaseConfig(BaseModel):
"""
@classmethod
- def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+ def schema(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401
"""We're overriding the schema classmethod to enable some post-processing"""
schema = super().schema(*args, **kwargs)
rename_key(schema, old_key="anyOf", new_key="oneOf") # UI supports only oneOf
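The `schema` override post-processes pydantic's generated JSON schema so the UI-friendly `oneOf` replaces `anyOf`. The CDK's real `rename_key` helper is not shown here; a hypothetical recursive version, for illustration only, could look like this:

from typing import Any

def rename_key(schema: Any, old_key: str, new_key: str) -> None:
    """Recursively rename old_key to new_key throughout a nested JSON-schema dict (illustrative only)."""
    if not isinstance(schema, dict):
        return
    for value in schema.values():
        if isinstance(value, dict):
            rename_key(value, old_key, new_key)
        elif isinstance(value, list):
            for item in value:
                rename_key(item, old_key, new_key)
    if old_key in schema:
        schema[new_key] = schema.pop(old_key)
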
diff --git a/airbyte_cdk/sources/connector_state_manager.py b/airbyte_cdk/sources/connector_state_manager.py
index 914374a55..50fb16542 100644
--- a/airbyte_cdk/sources/connector_state_manager.py
+++ b/airbyte_cdk/sources/connector_state_manager.py
@@ -3,8 +3,9 @@
#
import copy
+from collections.abc import Mapping, MutableMapping
from dataclasses import dataclass
-from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union, cast
+from typing import Any, cast
from airbyte_cdk.models import (
AirbyteMessage,
@@ -15,7 +16,7 @@
StreamDescriptor,
)
from airbyte_cdk.models import Type as MessageType
-from airbyte_cdk.models.airbyte_protocol import AirbyteGlobalState, AirbyteStateBlob
+from airbyte_cdk.models.airbyte_protocol import AirbyteGlobalState, AirbyteStateBlob # noqa: F811
@dataclass(frozen=True)
@@ -26,7 +27,7 @@ class HashableStreamDescriptor:
"""
name: str
- namespace: Optional[str] = None
+ namespace: str | None = None
class ConnectorStateManager:
@@ -35,7 +36,7 @@ class ConnectorStateManager:
interface. It also provides methods to extract and update state
"""
- def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
+ def __init__(self, state: list[AirbyteStateMessage] | None = None): # noqa: ANN204
shared_state, per_stream_states = self._extract_from_state_message(state)
# We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are
@@ -50,9 +51,7 @@ def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
)
self.per_stream_states = per_stream_states
- def get_stream_state(
- self, stream_name: str, namespace: Optional[str]
- ) -> MutableMapping[str, Any]:
+ def get_stream_state(self, stream_name: str, namespace: str | None) -> MutableMapping[str, Any]:
"""
Retrieves the state of a given stream based on its descriptor (name + namespace).
:param stream_name: Name of the stream being fetched
@@ -67,7 +66,7 @@ def get_stream_state(
return {}
def update_state_for_stream(
- self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]
+ self, stream_name: str, namespace: str | None, value: Mapping[str, Any]
) -> None:
"""
Overwrites the state blob of a specific stream based on the provided stream name and optional namespace
@@ -78,7 +77,7 @@ def update_state_for_stream(
stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
self.per_stream_states[stream_descriptor] = AirbyteStateBlob(value)
- def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
+ def create_state_message(self, stream_name: str, namespace: str | None) -> AirbyteMessage:
"""
Generates an AirbyteMessage using the current per-stream state of a specified stream
:param stream_name: The name of the stream for the message that is being created
@@ -102,10 +101,10 @@ def create_state_message(self, stream_name: str, namespace: Optional[str]) -> Ai
@classmethod
def _extract_from_state_message(
cls,
- state: Optional[List[AirbyteStateMessage]],
- ) -> Tuple[
- Optional[AirbyteStateBlob],
- MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]],
+ state: list[AirbyteStateMessage] | None,
+ ) -> tuple[
+ AirbyteStateBlob | None,
+ MutableMapping[HashableStreamDescriptor, AirbyteStateBlob | None],
]:
"""
Takes an incoming list of state messages or a global state message and extracts state attributes according to
@@ -120,10 +119,11 @@ def _extract_from_state_message(
if is_global:
# We already validate that this is a global state message, not None:
- global_state = cast(AirbyteGlobalState, state[0].global_)
+ global_state = cast(AirbyteGlobalState, state[0].global_) # noqa: TC006
# global_state has shared_state, also not None:
shared_state: AirbyteStateBlob = cast(
- AirbyteStateBlob, copy.deepcopy(global_state.shared_state, {})
+ "AirbyteStateBlob",
+ copy.deepcopy(global_state.shared_state, {}),
)
streams = {
HashableStreamDescriptor(
@@ -133,22 +133,21 @@ def _extract_from_state_message(
for per_stream_state in global_state.stream_states # type: ignore[union-attr] # global_state has shared_state
}
return shared_state, streams
- else:
- streams = {
- HashableStreamDescriptor(
- name=per_stream_state.stream.stream_descriptor.name, # type: ignore[union-attr] # stream has stream_descriptor
- namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # stream has stream_descriptor
- ): per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state
- for per_stream_state in state
- if per_stream_state.type == AirbyteStateType.STREAM
- and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True
- }
- return None, streams
+ streams = {
+ HashableStreamDescriptor(
+ name=per_stream_state.stream.stream_descriptor.name, # type: ignore[union-attr] # stream has stream_descriptor
+ namespace=per_stream_state.stream.stream_descriptor.namespace, # type: ignore[union-attr] # stream has stream_descriptor
+ ): per_stream_state.stream.stream_state # type: ignore[union-attr] # stream has stream_state
+ for per_stream_state in state
+ if per_stream_state.type == AirbyteStateType.STREAM
+ and hasattr(per_stream_state, "stream") # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True
+ }
+ return None, streams
@staticmethod
- def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
+ def _is_global_state(state: list[AirbyteStateMessage] | MutableMapping[str, Any]) -> bool:
return (
- isinstance(state, List)
+ isinstance(state, list)
and len(state) == 1
and isinstance(state[0], AirbyteStateMessage)
and state[0].type == AirbyteStateType.GLOBAL
@@ -156,6 +155,6 @@ def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str,
@staticmethod
def _is_per_stream_state(
- state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]],
+ state: list[AirbyteStateMessage] | MutableMapping[str, Any],
) -> bool:
- return isinstance(state, List)
+ return isinstance(state, list)
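State is keyed on the frozen `HashableStreamDescriptor` dataclass so `(name, namespace)` pairs work directly as dictionary keys. A small hedged illustration with a stand-in dataclass:

from dataclasses import dataclass

@dataclass(frozen=True)
class Descriptor:  # stand-in for HashableStreamDescriptor
    name: str
    namespace: str | None = None

per_stream_states: dict[Descriptor, dict] = {}
per_stream_states[Descriptor("users", "public")] = {"cursor": "2024-01-01"}

# Frozen dataclasses are hashable, so lookups by (name, namespace) just work.
assert per_stream_states[Descriptor("users", "public")]["cursor"] == "2024-01-01"
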
diff --git a/airbyte_cdk/sources/declarative/async_job/job.py b/airbyte_cdk/sources/declarative/async_job/job.py
index b075b61e2..7283a34e4 100644
--- a/airbyte_cdk/sources/declarative/async_job/job.py
+++ b/airbyte_cdk/sources/declarative/async_job/job.py
@@ -2,13 +2,11 @@
from datetime import timedelta
-from typing import Optional
+from .status import AsyncJobStatus
from airbyte_cdk.sources.declarative.async_job.timer import Timer
from airbyte_cdk.sources.types import StreamSlice
-from .status import AsyncJobStatus
-
class AsyncJob:
"""
@@ -19,13 +17,13 @@ class AsyncJob:
"""
def __init__(
- self, api_job_id: str, job_parameters: StreamSlice, timeout: Optional[timedelta] = None
+ self, api_job_id: str, job_parameters: StreamSlice, timeout: timedelta | None = None
) -> None:
self._api_job_id = api_job_id
self._job_parameters = job_parameters
self._status = AsyncJobStatus.RUNNING
- timeout = timeout if timeout else timedelta(minutes=60)
+ timeout = timeout if timeout else timedelta(minutes=60) # noqa: FURB110
self._timer = Timer(timeout)
self._timer.start()
diff --git a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
index 3938b8c07..bb9232eff 100644
--- a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
+++ b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
@@ -5,18 +5,11 @@
import time
import traceback
import uuid
+from collections.abc import Generator, Iterable, Mapping
from datetime import timedelta
from typing import (
Any,
- Generator,
Generic,
- Iterable,
- List,
- Mapping,
- Optional,
- Set,
- Tuple,
- Type,
TypeVar,
)
@@ -34,6 +27,7 @@
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
LOGGER = logging.getLogger("airbyte")
_NO_TIMEOUT = timedelta.max
_API_SIDE_RUNNING_STATUS = {AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT}
@@ -46,30 +40,30 @@ class AsyncPartition:
_MAX_NUMBER_OF_ATTEMPTS = 3
- def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None:
- self._attempts_per_job = {job: 1 for job in jobs}
+ def __init__(self, jobs: list[AsyncJob], stream_slice: StreamSlice) -> None:
+ self._attempts_per_job = {job: 1 for job in jobs} # noqa: C420
self._stream_slice = stream_slice
def has_reached_max_attempt(self) -> bool:
return any(
- map(
+ map( # noqa: C417
lambda attempt_count: attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS,
self._attempts_per_job.values(),
)
)
- def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> None:
+ def replace_job(self, job_to_replace: AsyncJob, new_jobs: list[AsyncJob]) -> None:
current_attempt_count = self._attempts_per_job.pop(job_to_replace, None)
if current_attempt_count is None:
raise ValueError("Could not find job to replace")
- elif current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS:
+ if current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS:
raise ValueError(f"Max attempt reached for job in partition {self._stream_slice}")
new_attempt_count = current_attempt_count + 1
for job in new_jobs:
self._attempts_per_job[job] = new_attempt_count
- def should_split(self, job: AsyncJob) -> bool:
+ def should_split(self, job: AsyncJob) -> bool: # noqa: ARG002
"""
Not used right now but once we support job split, we should split based on the number of attempts
"""
@@ -88,20 +82,19 @@ def status(self) -> AsyncJobStatus:
"""
Given different job statuses, the priority is: FAILED, TIMED_OUT, RUNNING. Else, it means everything is completed.
"""
- statuses = set(map(lambda job: job.status(), self.jobs))
+ statuses = set(map(lambda job: job.status(), self.jobs)) # noqa: C417
if statuses == {AsyncJobStatus.COMPLETED}:
return AsyncJobStatus.COMPLETED
- elif AsyncJobStatus.FAILED in statuses:
+ if AsyncJobStatus.FAILED in statuses:
return AsyncJobStatus.FAILED
- elif AsyncJobStatus.TIMED_OUT in statuses:
+ if AsyncJobStatus.TIMED_OUT in statuses:
return AsyncJobStatus.TIMED_OUT
- else:
- return AsyncJobStatus.RUNNING
+ return AsyncJobStatus.RUNNING
def __repr__(self) -> str:
return f"AsyncPartition(stream_slice={self._stream_slice}, attempt_per_job={self._attempts_per_job})"
- def __json_serializable__(self) -> Any:
+ def __json_serializable__(self) -> Any: # noqa: ANN401, PLW3201
return self._stream_slice
@@ -111,7 +104,7 @@ def __json_serializable__(self) -> Any:
class LookaheadIterator(Generic[T]):
def __init__(self, iterable: Iterable[T]) -> None:
self._iterator = iter(iterable)
- self._buffer: List[T] = []
+ self._buffer: list[T] = []
def __iter__(self) -> "LookaheadIterator[T]":
return self
@@ -119,8 +112,7 @@ def __iter__(self) -> "LookaheadIterator[T]":
def __next__(self) -> T:
if self._buffer:
return self._buffer.pop()
- else:
- return next(self._iterator)
+ return next(self._iterator)
def has_next(self) -> bool:
if self._buffer:
@@ -134,18 +126,18 @@ def has_next(self) -> bool:
return True
def add_at_the_beginning(self, item: T) -> None:
- self._buffer = [item] + self._buffer
+ self._buffer = [item] + self._buffer # noqa: RUF005
class AsyncJobOrchestrator:
_WAIT_TIME_BETWEEN_STATUS_UPDATE_IN_SECONDS = 5
- _KNOWN_JOB_STATUSES = {
+ _KNOWN_JOB_STATUSES = { # noqa: RUF012
AsyncJobStatus.COMPLETED,
AsyncJobStatus.FAILED,
AsyncJobStatus.RUNNING,
AsyncJobStatus.TIMED_OUT,
}
- _RUNNING_ON_API_SIDE_STATUS = {AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT}
+ _RUNNING_ON_API_SIDE_STATUS = {AsyncJobStatus.RUNNING, AsyncJobStatus.TIMED_OUT} # noqa: RUF012
def __init__(
self,
@@ -153,8 +145,8 @@ def __init__(
slices: Iterable[StreamSlice],
job_tracker: JobTracker,
message_repository: MessageRepository,
- exceptions_to_break_on: Iterable[Type[Exception]] = tuple(),
- has_bulk_parent: bool = False,
+ exceptions_to_break_on: Iterable[type[Exception]] = tuple(), # noqa: C408
+ has_bulk_parent: bool = False, # noqa: FBT001, FBT002
) -> None:
"""
If the stream slices provided as a parameters relies on a async job streams that relies on the same JobTracker, `has_bulk_parent`
@@ -170,13 +162,13 @@ def __init__(
self._job_repository: AsyncJobRepository = job_repository
self._slice_iterator = LookaheadIterator(slices)
- self._running_partitions: List[AsyncPartition] = []
+ self._running_partitions: list[AsyncPartition] = []
self._job_tracker = job_tracker
self._message_repository = message_repository
- self._exceptions_to_break_on: Tuple[Type[Exception], ...] = tuple(exceptions_to_break_on)
+ self._exceptions_to_break_on: tuple[type[Exception], ...] = tuple(exceptions_to_break_on)
self._has_bulk_parent = has_bulk_parent
- self._non_breaking_exceptions: List[Exception] = []
+ self._non_breaking_exceptions: list[Exception] = []
def _replace_failed_jobs(self, partition: AsyncPartition) -> None:
failed_status_jobs = (AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT)
@@ -225,7 +217,7 @@ def _start_jobs(self) -> None:
"Waiting before creating more jobs as the limit of concurrent jobs has been reached. Will try again later..."
)
- def _start_job(self, _slice: StreamSlice, previous_job_id: Optional[str] = None) -> AsyncJob:
+ def _start_job(self, _slice: StreamSlice, previous_job_id: str | None = None) -> AsyncJob:
if previous_job_id:
id_to_replace = previous_job_id
lazy_log(LOGGER, logging.DEBUG, lambda: f"Attempting to replace job {id_to_replace}...")
@@ -235,16 +227,19 @@ def _start_job(self, _slice: StreamSlice, previous_job_id: Optional[str] = None)
try:
job = self._job_repository.start(_slice)
self._job_tracker.add_job(id_to_replace, job.api_job_id())
- return job
+ return job # noqa: TRY300
except Exception as exception:
LOGGER.warning(f"Exception has occurred during job creation: {exception}")
if self._is_breaking_exception(exception):
self._job_tracker.remove_job(id_to_replace)
- raise exception
+ raise exception # noqa: TRY201
return self._keep_api_budget_with_failed_job(_slice, exception, id_to_replace)
def _keep_api_budget_with_failed_job(
- self, _slice: StreamSlice, exception: Exception, intent: str
+ self,
+ slice_: StreamSlice,
+ exception: Exception,
+ intent: str,
) -> AsyncJob:
"""
We have a mechanism to retry jobs. It is used when a job status is FAILED or TIMED_OUT. The easiest way to retry is to have this job
@@ -252,7 +247,7 @@ def _keep_api_budget_with_failed_job(
retrying jobs that couldn't be started.
"""
LOGGER.warning(
- f"Could not start job for slice {_slice}. Job will be flagged as failed and retried if max number of attempts not reached: {exception}"
+ f"Could not start job for slice {slice_}. Job will be flagged as failed and retried if max number of attempts not reached: {exception}"
)
traced_exception = (
exception
@@ -262,7 +257,7 @@ def _keep_api_budget_with_failed_job(
# Even though we're not sure this will break the stream, we will emit here for simplicity's sake. If we wanted to be more accurate,
# we would keep the exceptions in-memory until we know that we have reached the max attempt.
self._message_repository.emit_message(traced_exception.as_airbyte_message())
- job = self._create_failed_job(_slice)
+ job = self._create_failed_job(slice_)
self._job_tracker.add_job(intent, job.api_job_id())
return job
@@ -271,7 +266,7 @@ def _create_failed_job(self, stream_slice: StreamSlice) -> AsyncJob:
job.update_status(AsyncJobStatus.FAILED)
return job
- def _get_running_jobs(self) -> Set[AsyncJob]:
+ def _get_running_jobs(self) -> set[AsyncJob]:
"""
Returns a set of running AsyncJob objects.
@@ -324,7 +319,7 @@ def _process_completed_partition(self, partition: AsyncPartition) -> None:
Args:
partition (AsyncPartition): The completed partition to process.
"""
- job_ids = list(map(lambda job: job.api_job_id(), {job for job in partition.jobs}))
+ job_ids = list(map(lambda job: job.api_job_id(), {job for job in partition.jobs})) # noqa: C417
LOGGER.info(
f"The following jobs for stream slice {partition.stream_slice} have been completed: {job_ids}."
)
@@ -346,7 +341,7 @@ def _process_running_partitions_and_yield_completed_ones(
Raises:
Any: Any exception raised during processing.
"""
- current_running_partitions: List[AsyncPartition] = []
+ current_running_partitions: list[AsyncPartition] = []
for partition in self._running_partitions:
match partition.status:
case AsyncJobStatus.COMPLETED:
@@ -384,7 +379,7 @@ def _stop_timed_out_jobs(self, partition: AsyncPartition) -> None:
# we don't free allocation here because it is expected to retry the job
self._abort_job(job, free_job_allocation=False)
- def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
+ def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None: # noqa: FBT001, FBT002
try:
self._job_repository.abort(job)
if free_job_allocation:
@@ -442,7 +437,7 @@ def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]:
f"Caught exception that stops the processing of the jobs: {exception}"
)
self._abort_all_running_jobs()
- raise exception
+ raise exception # noqa: TRY201
self._non_breaking_exceptions.append(exception)
@@ -454,12 +449,10 @@ def create_and_get_completed_partitions(self) -> Iterable[AsyncPartition]:
# call of `create_and_get_completed_partitions` knows that there was an issue with some partitions and the sync is incomplete.
raise AirbyteTracedException(
message="",
- internal_message="\n".join(
- [
- filter_secrets(exception.__repr__())
- for exception in self._non_breaking_exceptions
- ]
- ),
+ internal_message="\n".join([
+ filter_secrets(exception.__repr__()) # noqa: PLC2801
+ for exception in self._non_breaking_exceptions
+ ]),
failure_type=FailureType.config_error,
)
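
For readers skimming the hunks above, a minimal usage sketch of the peek-and-push-back behavior exposed by `LookaheadIterator` may help. This is illustrative only and not part of the patch; the import path and the fact that `has_next()` buffers the peeked item (its body is not fully shown in the hunk) are assumptions.

    # Illustrative sketch only, not part of the patch. Assumes has_next()
    # buffers the peeked item and that the class lives in the
    # job_orchestrator module of the async_job package.
    from airbyte_cdk.sources.declarative.async_job.job_orchestrator import LookaheadIterator

    it = LookaheadIterator([1, 2, 3])
    first = next(it)                # -> 1, consumed from the underlying iterator
    it.add_at_the_beginning(first)  # push it back, e.g. when no job slot is available
    assert next(it) == 1            # the pushed-back item is yielded again
    assert it.has_next()            # peeks 2 into the buffer without consuming it
    assert next(it) == 2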
diff --git a/airbyte_cdk/sources/declarative/async_job/job_tracker.py b/airbyte_cdk/sources/declarative/async_job/job_tracker.py
index b47fc4cad..25c3c570e 100644
--- a/airbyte_cdk/sources/declarative/async_job/job_tracker.py
+++ b/airbyte_cdk/sources/declarative/async_job/job_tracker.py
@@ -3,10 +3,10 @@
import logging
import threading
import uuid
-from typing import Set
from airbyte_cdk.logger import lazy_log
+
LOGGER = logging.getLogger("airbyte")
@@ -15,8 +15,8 @@ class ConcurrentJobLimitReached(Exception):
class JobTracker:
- def __init__(self, limit: int):
- self._jobs: Set[str] = set()
+ def __init__(self, limit: int): # noqa: ANN204
+ self._jobs: set[str] = set()
self._limit = limit
self._lock = threading.Lock()
@@ -31,7 +31,7 @@ def try_to_get_intent(self) -> str:
raise ConcurrentJobLimitReached(
"Can't allocate more jobs right now: limit already reached"
)
- intent = f"intent_{str(uuid.uuid4())}"
+ intent = f"intent_{uuid.uuid4()!s}"
lazy_log(
LOGGER,
logging.DEBUG,
diff --git a/airbyte_cdk/sources/declarative/async_job/repository.py b/airbyte_cdk/sources/declarative/async_job/repository.py
index 21581ec4f..df46748b4 100644
--- a/airbyte_cdk/sources/declarative/async_job/repository.py
+++ b/airbyte_cdk/sources/declarative/async_job/repository.py
@@ -1,7 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from abc import abstractmethod
-from typing import Any, Iterable, Mapping, Set
+from collections.abc import Iterable, Mapping
+from typing import Any
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
from airbyte_cdk.sources.types import StreamSlice
@@ -13,7 +14,7 @@ def start(self, stream_slice: StreamSlice) -> AsyncJob:
pass
@abstractmethod
- def update_jobs_status(self, jobs: Set[AsyncJob]) -> None:
+ def update_jobs_status(self, jobs: set[AsyncJob]) -> None:
pass
@abstractmethod
diff --git a/airbyte_cdk/sources/declarative/async_job/status.py b/airbyte_cdk/sources/declarative/async_job/status.py
index 586e79889..abd7bb4b9 100644
--- a/airbyte_cdk/sources/declarative/async_job/status.py
+++ b/airbyte_cdk/sources/declarative/async_job/status.py
@@ -3,6 +3,7 @@
from enum import Enum
+
_TERMINAL = True
@@ -12,7 +13,7 @@ class AsyncJobStatus(Enum):
FAILED = ("FAILED", _TERMINAL)
TIMED_OUT = ("TIMED_OUT", _TERMINAL)
- def __init__(self, value: str, is_terminal: bool) -> None:
+ def __init__(self, value: str, is_terminal: bool) -> None: # noqa: FBT001
self._value = value
self._is_terminal = is_terminal
diff --git a/airbyte_cdk/sources/declarative/async_job/timer.py b/airbyte_cdk/sources/declarative/async_job/timer.py
index c4e5a9a1d..a1686bc04 100644
--- a/airbyte_cdk/sources/declarative/async_job/timer.py
+++ b/airbyte_cdk/sources/declarative/async_job/timer.py
@@ -1,12 +1,11 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from datetime import datetime, timedelta, timezone
-from typing import Optional
class Timer:
def __init__(self, timeout: timedelta) -> None:
- self._start_datetime: Optional[datetime] = None
- self._end_datetime: Optional[datetime] = None
+ self._start_datetime: datetime | None = None
+ self._end_datetime: datetime | None = None
self._timeout = timeout
def start(self) -> None:
@@ -21,7 +20,7 @@ def is_started(self) -> bool:
return self._start_datetime is not None
@property
- def elapsed_time(self) -> Optional[timedelta]:
+ def elapsed_time(self) -> timedelta | None:
if not self._start_datetime:
return None
diff --git a/airbyte_cdk/sources/declarative/auth/__init__.py b/airbyte_cdk/sources/declarative/auth/__init__.py
index 810437810..a0b61072d 100644
--- a/airbyte_cdk/sources/declarative/auth/__init__.py
+++ b/airbyte_cdk/sources/declarative/auth/__init__.py
@@ -5,4 +5,5 @@
from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
from airbyte_cdk.sources.declarative.auth.oauth import DeclarativeOauth2Authenticator
+
__all__ = ["DeclarativeOauth2Authenticator", "JwtAuthenticator"]
diff --git a/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py b/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py
index b749718fa..5a7949eb2 100644
--- a/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py
+++ b/airbyte_cdk/sources/declarative/auth/declarative_authenticator.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Union
+from typing import Any
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import (
AbstractHeaderAuthenticator,
@@ -20,7 +21,7 @@ def get_request_params(self) -> Mapping[str, Any]:
"""HTTP request parameter to add to the requests"""
return {}
- def get_request_body_data(self) -> Union[Mapping[str, Any], str]:
+ def get_request_body_data(self) -> Mapping[str, Any] | str:
"""Form-encoded body data to set on the requests"""
return {}
diff --git a/airbyte_cdk/sources/declarative/auth/jwt.py b/airbyte_cdk/sources/declarative/auth/jwt.py
index d7dd59282..005c336a5 100644
--- a/airbyte_cdk/sources/declarative/auth/jwt.py
+++ b/airbyte_cdk/sources/declarative/auth/jwt.py
@@ -3,9 +3,10 @@
#
import base64
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
from datetime import datetime
-from typing import Any, Mapping, Optional, Union
+from typing import Any
import jwt
@@ -15,7 +16,7 @@
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
-class JwtAlgorithm(str):
+class JwtAlgorithm(str): # noqa: SLOT000
"""
Enum for supported JWT algorithms
"""
@@ -60,19 +61,19 @@ class JwtAuthenticator(DeclarativeAuthenticator):
config: Mapping[str, Any]
parameters: InitVar[Mapping[str, Any]]
- secret_key: Union[InterpolatedString, str]
- algorithm: Union[str, JwtAlgorithm]
- token_duration: Optional[int]
- base64_encode_secret_key: Optional[Union[InterpolatedBoolean, str, bool]] = False
- header_prefix: Optional[Union[InterpolatedString, str]] = None
- kid: Optional[Union[InterpolatedString, str]] = None
- typ: Optional[Union[InterpolatedString, str]] = None
- cty: Optional[Union[InterpolatedString, str]] = None
- iss: Optional[Union[InterpolatedString, str]] = None
- sub: Optional[Union[InterpolatedString, str]] = None
- aud: Optional[Union[InterpolatedString, str]] = None
- additional_jwt_headers: Optional[Mapping[str, Any]] = None
- additional_jwt_payload: Optional[Mapping[str, Any]] = None
+ secret_key: InterpolatedString | str
+ algorithm: str | JwtAlgorithm
+ token_duration: int | None
+ base64_encode_secret_key: InterpolatedBoolean | str | bool | None = False
+ header_prefix: InterpolatedString | str | None = None
+ kid: InterpolatedString | str | None = None
+ typ: InterpolatedString | str | None = None
+ cty: InterpolatedString | str | None = None
+ iss: InterpolatedString | str | None = None
+ sub: InterpolatedString | str | None = None
+ aud: InterpolatedString | str | None = None
+ additional_jwt_headers: Mapping[str, Any] | None = None
+ additional_jwt_payload: Mapping[str, Any] | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters)
@@ -158,7 +159,7 @@ def _get_secret_key(self) -> str:
else secret_key
)
- def _get_signed_token(self) -> Union[str, Any]:
+ def _get_signed_token(self) -> str | Any: # noqa: ANN401
"""
Signs the JWT using the provided secret key and algorithm and the generated headers and payload. For additional information on PyJWT see: https://pyjwt.readthedocs.io/en/stable/
"""
@@ -170,9 +171,9 @@ def _get_signed_token(self) -> Union[str, Any]:
headers=self._get_jwt_headers(),
)
except Exception as e:
- raise ValueError(f"Failed to sign token: {e}")
+ raise ValueError(f"Failed to sign token: {e}") # noqa: B904
- def _get_header_prefix(self) -> Union[str, None]:
+ def _get_header_prefix(self) -> str | None:
"""
Returns the header prefix to be used when attaching the token to the request.
"""
diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py
index 36508fd7e..a7dda3921 100644
--- a/airbyte_cdk/sources/declarative/auth/oauth.py
+++ b/airbyte_cdk/sources/declarative/auth/oauth.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, List, Mapping, Optional, Union
+from typing import Any
import pendulum
@@ -44,33 +45,33 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
"""
- client_id: Union[InterpolatedString, str]
- client_secret: Union[InterpolatedString, str]
+ client_id: InterpolatedString | str
+ client_secret: InterpolatedString | str
config: Mapping[str, Any]
parameters: InitVar[Mapping[str, Any]]
- token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None
- refresh_token: Optional[Union[InterpolatedString, str]] = None
- scopes: Optional[List[str]] = None
- token_expiry_date: Optional[Union[InterpolatedString, str]] = None
- _token_expiry_date: Optional[pendulum.DateTime] = field(init=False, repr=False, default=None)
- token_expiry_date_format: Optional[str] = None
+ token_refresh_endpoint: InterpolatedString | str | None = None
+ refresh_token: InterpolatedString | str | None = None
+ scopes: list[str] | None = None
+ token_expiry_date: InterpolatedString | str | None = None
+ _token_expiry_date: pendulum.DateTime | None = field(init=False, repr=False, default=None)
+ token_expiry_date_format: str | None = None
token_expiry_is_time_of_expiration: bool = False
- access_token_name: Union[InterpolatedString, str] = "access_token"
- access_token_value: Optional[Union[InterpolatedString, str]] = None
- client_id_name: Union[InterpolatedString, str] = "client_id"
- client_secret_name: Union[InterpolatedString, str] = "client_secret"
- expires_in_name: Union[InterpolatedString, str] = "expires_in"
- refresh_token_name: Union[InterpolatedString, str] = "refresh_token"
- refresh_request_body: Optional[Mapping[str, Any]] = None
- refresh_request_headers: Optional[Mapping[str, Any]] = None
- grant_type_name: Union[InterpolatedString, str] = "grant_type"
- grant_type: Union[InterpolatedString, str] = "refresh_token"
- message_repository: MessageRepository = NoopMessageRepository()
+ access_token_name: InterpolatedString | str = "access_token"
+ access_token_value: InterpolatedString | str | None = None
+ client_id_name: InterpolatedString | str = "client_id"
+ client_secret_name: InterpolatedString | str = "client_secret"
+ expires_in_name: InterpolatedString | str = "expires_in"
+ refresh_token_name: InterpolatedString | str = "refresh_token"
+ refresh_request_body: Mapping[str, Any] | None = None
+ refresh_request_headers: Mapping[str, Any] | None = None
+ grant_type_name: InterpolatedString | str = "grant_type"
+ grant_type: InterpolatedString | str = "refresh_token"
+ message_repository: MessageRepository = NoopMessageRepository() # noqa: RUF009
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__init__()
if self.token_refresh_endpoint is not None:
- self._token_refresh_endpoint: Optional[InterpolatedString] = InterpolatedString.create(
+ self._token_refresh_endpoint: InterpolatedString | None = InterpolatedString.create(
self.token_refresh_endpoint, parameters=parameters
)
else:
@@ -85,7 +86,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.refresh_token_name, parameters=parameters
)
if self.refresh_token is not None:
- self._refresh_token: Optional[InterpolatedString] = InterpolatedString.create(
+ self._refresh_token: InterpolatedString | None = InterpolatedString.create(
self.refresh_token, parameters=parameters
)
else:
@@ -122,7 +123,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
else:
self._access_token_value = None
- self._access_token: Optional[str] = (
+ self._access_token: str | None = (
self._access_token_value if self.access_token_value else None
)
@@ -131,7 +132,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
"OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`"
)
- def get_token_refresh_endpoint(self) -> Optional[str]:
+ def get_token_refresh_endpoint(self) -> str | None:
if self._token_refresh_endpoint is not None:
refresh_token_endpoint: str = self._token_refresh_endpoint.eval(self.config)
if not refresh_token_endpoint:
@@ -162,10 +163,10 @@ def get_client_secret(self) -> str:
def get_refresh_token_name(self) -> str:
return self._refresh_token_name.eval(self.config) # type: ignore # eval returns a string in this context
- def get_refresh_token(self) -> Optional[str]:
+ def get_refresh_token(self) -> str | None:
return None if self._refresh_token is None else str(self._refresh_token.eval(self.config))
- def get_scopes(self) -> List[str]:
+ def get_scopes(self) -> list[str]:
return self.scopes or []
def get_access_token_name(self) -> str:
@@ -189,7 +190,7 @@ def get_refresh_request_headers(self) -> Mapping[str, Any]:
def get_token_expiry_date(self) -> pendulum.DateTime:
return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks
- def set_token_expiry_date(self, value: Union[str, int]) -> None:
+ def set_token_expiry_date(self, value: str | int) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)
@property
@@ -218,5 +219,5 @@ class DeclarativeSingleUseRefreshTokenOauth2Authenticator(
Declarative version of SingleUseRefreshTokenOauth2Authenticator which can be used in declarative connectors.
"""
- def __init__(self, *args: Any, **kwargs: Any) -> None:
+ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
super().__init__(*args, **kwargs)
diff --git a/airbyte_cdk/sources/declarative/auth/selective_authenticator.py b/airbyte_cdk/sources/declarative/auth/selective_authenticator.py
index 3a84150bf..fdd0b4deb 100644
--- a/airbyte_cdk/sources/declarative/auth/selective_authenticator.py
+++ b/airbyte_cdk/sources/declarative/auth/selective_authenticator.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Any, List, Mapping
+from typing import Any
import dpath
@@ -16,16 +17,16 @@ class SelectiveAuthenticator(DeclarativeAuthenticator):
config: Mapping[str, Any]
authenticators: Mapping[str, DeclarativeAuthenticator]
- authenticator_selection_path: List[str]
+ authenticator_selection_path: list[str]
# returns "DeclarativeAuthenticator", but must return a subtype of "SelectiveAuthenticator"
def __new__( # type: ignore[misc]
cls,
config: Mapping[str, Any],
authenticators: Mapping[str, DeclarativeAuthenticator],
- authenticator_selection_path: List[str],
- *arg: Any,
- **kwargs: Any,
+ authenticator_selection_path: list[str],
+ *arg: Any, # noqa: ANN401, ARG003
+ **kwargs: Any, # noqa: ANN401, ARG003
) -> DeclarativeAuthenticator:
try:
selected_key = str(
diff --git a/airbyte_cdk/sources/declarative/auth/token.py b/airbyte_cdk/sources/declarative/auth/token.py
index 12fb899b9..0dbf0cf2a 100644
--- a/airbyte_cdk/sources/declarative/auth/token.py
+++ b/airbyte_cdk/sources/declarative/auth/token.py
@@ -1,11 +1,12 @@
-#
+# # noqa: A005
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import base64
import logging
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Union
+from typing import Any
import requests
from cachetools import TTLCache, cached
@@ -68,7 +69,7 @@ def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, A
def get_request_params(self) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.request_parameter)
- def get_request_body_data(self) -> Union[Mapping[str, Any], str]:
+ def get_request_body_data(self) -> Mapping[str, Any] | str:
return self._get_request_options(RequestOptionType.body_data)
def get_request_body_json(self) -> Mapping[str, Any]:
@@ -118,10 +119,10 @@ class BasicHttpAuthenticator(DeclarativeAuthenticator):
parameters (Mapping[str, Any]): Additional runtime parameters to be used for string interpolation
"""
- username: Union[InterpolatedString, str]
+ username: InterpolatedString | str
config: Config
parameters: InitVar[Mapping[str, Any]]
- password: Union[InterpolatedString, str] = ""
+ password: InterpolatedString | str = ""
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._username = InterpolatedString.create(self.username, parameters=parameters)
@@ -134,7 +135,7 @@ def auth_header(self) -> str:
@property
def token(self) -> str:
auth_string = (
- f"{self._username.eval(self.config)}:{self._password.eval(self.config)}".encode("utf8")
+ f"{self._username.eval(self.config)}:{self._password.eval(self.config)}".encode()
)
b64_encoded = base64.b64encode(auth_string).decode("utf8")
return f"Basic {b64_encoded}"
@@ -148,7 +149,7 @@ def token(self) -> str:
i.e. by adding another item the cache would exceed its maximum size, the cache must choose which item(s) to discard
ttl=86400 means that cached token will live for 86400 seconds (one day)
"""
-cacheSessionTokenAuthenticator: TTLCache[str, str] = TTLCache(maxsize=1000, ttl=86400)
+cacheSessionTokenAuthenticator: TTLCache[str, str] = TTLCache(maxsize=1000, ttl=86400) # noqa: N816
@cached(cacheSessionTokenAuthenticator)
@@ -201,16 +202,16 @@ class LegacySessionTokenAuthenticator(DeclarativeAuthenticator):
validate_session_url (Union[InterpolatedString, str]): Url to validate passed session token
"""
- api_url: Union[InterpolatedString, str]
- header: Union[InterpolatedString, str]
- session_token: Union[InterpolatedString, str]
- session_token_response_key: Union[InterpolatedString, str]
- username: Union[InterpolatedString, str]
+ api_url: InterpolatedString | str
+ header: InterpolatedString | str
+ session_token: InterpolatedString | str
+ session_token_response_key: InterpolatedString | str
+ username: InterpolatedString | str
config: Config
parameters: InitVar[Mapping[str, Any]]
- login_url: Union[InterpolatedString, str]
- validate_session_url: Union[InterpolatedString, str]
- password: Union[InterpolatedString, str] = ""
+ login_url: InterpolatedString | str
+ validate_session_url: InterpolatedString | str
+ password: InterpolatedString | str = ""
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._username = InterpolatedString.create(self.username, parameters=parameters)
@@ -234,7 +235,7 @@ def auth_header(self) -> str:
@property
def token(self) -> str:
- if self._session_token.eval(self.config):
+ if self._session_token.eval(self.config): # noqa: SIM102
if self.is_valid_session_token():
return str(self._session_token.eval(self.config))
if self._password.eval(self.config) and self._username.eval(self.config):
@@ -259,14 +260,12 @@ def is_valid_session_token(self) -> bool:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code == requests.codes["unauthorized"]:
- self.logger.info(f"Unable to connect by session token from config due to {str(e)}")
+ self.logger.info(f"Unable to connect by session token from config due to {e!s}")
return False
- else:
- raise ConnectionError(f"Error while validating session token: {e}")
+ raise ConnectionError(f"Error while validating session token: {e}") # noqa: B904
if response.ok:
self.logger.info("Connection check for source is successful.")
return True
- else:
- raise ConnectionError(
- f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}"
- )
+ raise ConnectionError(
+ f"Failed to retrieve new session token, response code {response.status_code} because {response.reason}"
+ )
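
For reference, the `token` property of `BasicHttpAuthenticator` shown above boils down to standard HTTP Basic encoding. A small stand-alone sketch with placeholder credentials:

    import base64

    # Illustrative sketch of the header value built by the token property above,
    # using placeholder credentials.
    username, password = "user", "s3cret"
    auth_string = f"{username}:{password}".encode()
    token = f"Basic {base64.b64encode(auth_string).decode('utf8')}"
    print(token)  # Basic dXNlcjpzM2NyZXQ=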
diff --git a/airbyte_cdk/sources/declarative/auth/token_provider.py b/airbyte_cdk/sources/declarative/auth/token_provider.py
index ed933bc59..1b1a9eabc 100644
--- a/airbyte_cdk/sources/declarative/auth/token_provider.py
+++ b/airbyte_cdk/sources/declarative/auth/token_provider.py
@@ -5,8 +5,9 @@
import datetime
from abc import abstractmethod
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, List, Mapping, Optional, Union
+from typing import Any
import dpath
import pendulum
@@ -32,14 +33,14 @@ def get_token(self) -> str:
@dataclass
class SessionTokenProvider(TokenProvider):
login_requester: Requester
- session_token_path: List[str]
- expiration_duration: Optional[Union[datetime.timedelta, Duration]]
+ session_token_path: list[str]
+ expiration_duration: datetime.timedelta | Duration | None
parameters: InitVar[Mapping[str, Any]]
- message_repository: MessageRepository = NoopMessageRepository()
+ message_repository: MessageRepository = NoopMessageRepository() # noqa: RUF009
decoder: Decoder = field(default_factory=lambda: JsonDecoder(parameters={}))
- _next_expiration_time: Optional[DateTime] = None
- _token: Optional[str] = None
+ _next_expiration_time: DateTime | None = None
+ _token: str | None = None
def get_token(self) -> str:
self._refresh_if_necessary()
@@ -72,7 +73,7 @@ def _refresh(self) -> None:
@dataclass
class InterpolatedStringTokenProvider(TokenProvider):
config: Config
- api_token: Union[InterpolatedString, str]
+ api_token: InterpolatedString | str
parameters: Mapping[str, Any]
def __post_init__(self) -> None:
diff --git a/airbyte_cdk/sources/declarative/checks/__init__.py b/airbyte_cdk/sources/declarative/checks/__init__.py
index 6362e0613..41f6d832a 100644
--- a/airbyte_cdk/sources/declarative/checks/__init__.py
+++ b/airbyte_cdk/sources/declarative/checks/__init__.py
@@ -2,7 +2,7 @@
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
#
-from typing import Mapping
+from collections.abc import Mapping
from pydantic.v1 import BaseModel
@@ -16,6 +16,7 @@
CheckStream as CheckStreamModel,
)
+
COMPONENTS_CHECKER_TYPE_MAPPING: Mapping[str, type[BaseModel]] = {
"CheckStream": CheckStreamModel,
"CheckDynamicStream": CheckDynamicStreamModel,
diff --git a/airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py b/airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py
index 75807c400..0abaa3b15 100644
--- a/airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py
+++ b/airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py
@@ -4,8 +4,9 @@
import logging
import traceback
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, List, Mapping, Tuple
+from typing import Any
from airbyte_cdk import AbstractSource
from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker
@@ -29,7 +30,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def check_connection(
self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any]
- ) -> Tuple[bool, Any]:
+ ) -> tuple[bool, Any]:
streams = source.streams(config=config)
if len(streams) == 0:
return False, f"No streams to connect to from source {source}"
diff --git a/airbyte_cdk/sources/declarative/checks/check_stream.py b/airbyte_cdk/sources/declarative/checks/check_stream.py
index c45159ec9..2f65675f8 100644
--- a/airbyte_cdk/sources/declarative/checks/check_stream.py
+++ b/airbyte_cdk/sources/declarative/checks/check_stream.py
@@ -4,8 +4,9 @@
import logging
import traceback
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, List, Mapping, Tuple
+from typing import Any
from airbyte_cdk import AbstractSource
from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker
@@ -21,7 +22,7 @@ class CheckStream(ConnectionChecker):
stream_names (List[str]): names of streams to check
"""
- stream_names: List[str]
+ stream_names: list[str]
parameters: InitVar[Mapping[str, Any]]
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
@@ -29,13 +30,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def check_connection(
self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any]
- ) -> Tuple[bool, Any]:
+ ) -> tuple[bool, Any]:
streams = source.streams(config=config)
stream_name_to_stream = {s.name: s for s in streams}
if len(streams) == 0:
return False, f"No streams to connect to from source {source}"
for stream_name in self.stream_names:
- if stream_name not in stream_name_to_stream.keys():
+ if stream_name not in stream_name_to_stream:
raise ValueError(
f"{stream_name} is not part of the catalog. Expected one of {stream_name_to_stream.keys()}."
)
diff --git a/airbyte_cdk/sources/declarative/checks/connection_checker.py b/airbyte_cdk/sources/declarative/checks/connection_checker.py
index fd1d1bba2..ee1b10783 100644
--- a/airbyte_cdk/sources/declarative/checks/connection_checker.py
+++ b/airbyte_cdk/sources/declarative/checks/connection_checker.py
@@ -4,7 +4,8 @@
import logging
from abc import ABC, abstractmethod
-from typing import Any, Mapping, Tuple
+from collections.abc import Mapping
+from typing import Any
from airbyte_cdk import AbstractSource
@@ -17,7 +18,7 @@ class ConnectionChecker(ABC):
@abstractmethod
def check_connection(
self, source: AbstractSource, logger: logging.Logger, config: Mapping[str, Any]
- ) -> Tuple[bool, Any]:
+ ) -> tuple[bool, Any]:
"""
Tests if the input configuration can be used to successfully connect to the integration e.g: if a provided Stripe API token can be used to connect
to the Stripe API.
diff --git a/airbyte_cdk/sources/declarative/concurrency_level/__init__.py b/airbyte_cdk/sources/declarative/concurrency_level/__init__.py
index 6c55c15c9..669500b8c 100644
--- a/airbyte_cdk/sources/declarative/concurrency_level/__init__.py
+++ b/airbyte_cdk/sources/declarative/concurrency_level/__init__.py
@@ -4,4 +4,5 @@
from airbyte_cdk.sources.declarative.concurrency_level.concurrency_level import ConcurrencyLevel
+
__all__ = ["ConcurrencyLevel"]
diff --git a/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py b/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py
index f5cd24f00..0f6f36af2 100644
--- a/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py
+++ b/airbyte_cdk/sources/declarative/concurrency_level/concurrency_level.py
@@ -2,8 +2,9 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.types import Config
@@ -19,14 +20,14 @@ class ConcurrencyLevel:
max_concurrency (Optional[int]): The maximum number of worker threads to use when the default_concurrency is exceeded
"""
- default_concurrency: Union[int, str]
- max_concurrency: Optional[int]
+ default_concurrency: int | str
+ max_concurrency: int | None
config: Config
parameters: InitVar[Mapping[str, Any]]
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if isinstance(self.default_concurrency, int):
- self._default_concurrency: Union[int, InterpolatedString] = self.default_concurrency
+ self._default_concurrency: int | InterpolatedString = self.default_concurrency
elif "config" in self.default_concurrency and not self.max_concurrency:
raise ValueError(
"ConcurrencyLevel requires that max_concurrency be defined if the default_concurrency can be used-specified"
@@ -40,11 +41,10 @@ def get_concurrency_level(self) -> int:
if isinstance(self._default_concurrency, InterpolatedString):
evaluated_default_concurrency = self._default_concurrency.eval(config=self.config)
if not isinstance(evaluated_default_concurrency, int):
- raise ValueError("default_concurrency did not evaluate to an integer")
+ raise ValueError("default_concurrency did not evaluate to an integer") # noqa: TRY004
return (
min(evaluated_default_concurrency, self.max_concurrency)
if self.max_concurrency
else evaluated_default_concurrency
)
- else:
- return self._default_concurrency
+ return self._default_concurrency
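
The `get_concurrency_level` hunk above reduces to a simple capping rule; a minimal sketch of that rule, written independently of the CDK class:

    # Illustrative only: the evaluated default is capped by max_concurrency
    # when one is configured, otherwise it is used as-is.
    def cap_concurrency(default_concurrency: int, max_concurrency: int | None) -> int:
        return (
            min(default_concurrency, max_concurrency)
            if max_concurrency
            else default_concurrency
        )

    assert cap_concurrency(20, 4) == 4
    assert cap_concurrency(3, None) == 3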
diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
index 5db0b0909..ab1bead26 100644
--- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -3,7 +3,8 @@
#
import logging
-from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple
+from collections.abc import Iterator, Mapping
+from typing import Any, Generic
from airbyte_cdk.models import (
AirbyteCatalog,
@@ -57,14 +58,14 @@ class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]):
def __init__(
self,
- catalog: Optional[ConfiguredAirbyteCatalog],
- config: Optional[Mapping[str, Any]],
+ catalog: ConfiguredAirbyteCatalog | None, # noqa: ARG002
+ config: Mapping[str, Any] | None,
state: TState,
source_config: ConnectionDefinition,
- debug: bool = False,
- emit_connector_builder_messages: bool = False,
- component_factory: Optional[ModelToComponentFactory] = None,
- **kwargs: Any,
+ debug: bool = False, # noqa: FBT001, FBT002
+ emit_connector_builder_messages: bool = False, # noqa: FBT001, FBT002
+ component_factory: ModelToComponentFactory | None = None,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> None:
# To reduce the complexity of the concurrent framework, we are not enabling RFR with synthetic
# cursors. We do this by no longer automatically instantiating RFR cursors when converting
@@ -83,7 +84,7 @@ def __init__(
component_factory=component_factory,
)
- # todo: We could remove state from initialization. Now that streams are grouped during the read(), a source
+ # TODO: We could remove state from initialization. Now that streams are grouped during the read(), a source
# no longer needs to store the original incoming state. But maybe there's an edge case?
self._state = state
@@ -120,7 +121,7 @@ def read(
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
- state: Optional[List[AirbyteStateMessage]] = None,
+ state: list[AirbyteStateMessage] | None = None,
) -> Iterator[AirbyteMessage]:
concurrent_streams, _ = self._group_streams(config=config)
@@ -128,7 +129,7 @@ def read(
# the concurrent streams must be saved so that they can be removed from the catalog before starting
# synchronous streams
if len(concurrent_streams) > 0:
- concurrent_stream_names = set(
+ concurrent_stream_names = set( # noqa: C403
[concurrent_stream.name for concurrent_stream in concurrent_streams]
)
@@ -151,7 +152,7 @@ def read(
yield from super().read(logger, config, filtered_catalog, state)
- def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog:
+ def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog: # noqa: ARG002
concurrent_streams, synchronous_streams = self._group_streams(config=config)
return AirbyteCatalog(
streams=[
@@ -159,7 +160,7 @@ def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> Airbyte
]
)
- def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+ def streams(self, config: Mapping[str, Any]) -> list[Stream]:
"""
The `streams` method is used as part of the AbstractSource in the following cases:
* ConcurrentDeclarativeSource.check -> ManifestDeclarativeSource.check -> AbstractSource.check -> DeclarativeSource.check_connection -> CheckStream.check_connection -> streams
@@ -172,9 +173,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
def _group_streams(
self, config: Mapping[str, Any]
- ) -> Tuple[List[AbstractStream], List[Stream]]:
- concurrent_streams: List[AbstractStream] = []
- synchronous_streams: List[Stream] = []
+ ) -> tuple[list[AbstractStream], list[Stream]]:
+ concurrent_streams: list[AbstractStream] = []
+ synchronous_streams: list[Stream] = []
state_manager = ConnectorStateManager(state=self._state) # type: ignore # state is always in the form of List[AirbyteStateMessage]. The ConnectorStateManager should use generics, but this can be done later
@@ -268,7 +269,7 @@ def _group_streams(
if hasattr(cursor, "cursor_field")
and hasattr(
cursor.cursor_field, "cursor_field_key"
- ) # FIXME this will need to be updated once we do the per partition
+ ) # FIXME this will need to be updated once we do the per partition # noqa: FIX001, TD001, TD004
else None,
logger=self.logger,
cursor=cursor,
@@ -347,13 +348,13 @@ def _stream_supports_concurrent_partition_processing(
declarative_stream.retriever.requester, HttpRequester
):
http_requester = declarative_stream.retriever.requester
- if "stream_state" in http_requester._path.string:
+ if "stream_state" in http_requester._path.string: # noqa: SLF001
self.logger.warning(
f"Low-code stream '{declarative_stream.name}' uses interpolation of stream_state in the HttpRequester which is not thread-safe. Defaulting to synchronous processing"
)
return False
- request_options_provider = http_requester._request_options_provider
+ request_options_provider = http_requester._request_options_provider # noqa: SLF001
if request_options_provider.request_options_contain_stream_state():
self.logger.warning(
f"Low-code stream '{declarative_stream.name}' uses interpolation of stream_state in the HttpRequester which is not thread-safe. Defaulting to synchronous processing"
@@ -397,10 +398,10 @@ def _stream_supports_concurrent_partition_processing(
@staticmethod
def _select_streams(
- streams: List[AbstractStream], configured_catalog: ConfiguredAirbyteCatalog
- ) -> List[AbstractStream]:
+ streams: list[AbstractStream], configured_catalog: ConfiguredAirbyteCatalog
+ ) -> list[AbstractStream]:
stream_name_to_instance: Mapping[str, AbstractStream] = {s.name: s for s in streams}
- abstract_streams: List[AbstractStream] = []
+ abstract_streams: list[AbstractStream] = []
for configured_stream in configured_catalog.streams:
stream_instance = stream_name_to_instance.get(configured_stream.stream.name)
if stream_instance:
diff --git a/airbyte_cdk/sources/declarative/datetime/__init__.py b/airbyte_cdk/sources/declarative/datetime/__init__.py
index bf1f13e1e..dbf4fb5b8 100644
--- a/airbyte_cdk/sources/declarative/datetime/__init__.py
+++ b/airbyte_cdk/sources/declarative/datetime/__init__.py
@@ -4,4 +4,5 @@
from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime
+
__all__ = ["MinMaxDatetime"]
diff --git a/airbyte_cdk/sources/declarative/datetime/datetime_parser.py b/airbyte_cdk/sources/declarative/datetime/datetime_parser.py
index 93122e29c..bafd3d8bb 100644
--- a/airbyte_cdk/sources/declarative/datetime/datetime_parser.py
+++ b/airbyte_cdk/sources/declarative/datetime/datetime_parser.py
@@ -3,7 +3,6 @@
#
import datetime
-from typing import Union
class DatetimeParser:
@@ -18,7 +17,7 @@ class DatetimeParser:
_UNIX_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
- def parse(self, date: Union[str, int], format: str) -> datetime.datetime:
+ def parse(self, date: str | int, format: str) -> datetime.datetime: # noqa: A002
# "%s" is a valid (but unreliable) directive for formatting, but not for parsing
# It is defined as
# The number of seconds since the Epoch, 1970-01-01 00:00:00+0000 (UTC). https://man7.org/linux/man-pages/man3/strptime.3.html
@@ -27,9 +26,9 @@ def parse(self, date: Union[str, int], format: str) -> datetime.datetime:
# See https://stackoverflow.com/a/4974930
if format == "%s":
return datetime.datetime.fromtimestamp(int(date), tz=datetime.timezone.utc)
- elif format == "%s_as_float":
+ if format == "%s_as_float":
return datetime.datetime.fromtimestamp(float(date), tz=datetime.timezone.utc)
- elif format == "%ms":
+ if format == "%ms":
return self._UNIX_EPOCH + datetime.timedelta(milliseconds=int(date))
parsed_datetime = datetime.datetime.strptime(str(date), format)
@@ -37,7 +36,7 @@ def parse(self, date: Union[str, int], format: str) -> datetime.datetime:
return parsed_datetime.replace(tzinfo=datetime.timezone.utc)
return parsed_datetime
- def format(self, dt: datetime.datetime, format: str) -> str:
+ def format(self, dt: datetime.datetime, format: str) -> str: # noqa: A002
# strftime("%s") is unreliable because it ignores the time zone information and assumes the time zone of the system it's running on
# It's safer to use the timestamp() method than the %s directive
# See https://stackoverflow.com/a/4974930
@@ -48,8 +47,7 @@ def format(self, dt: datetime.datetime, format: str) -> str:
if format == "%ms":
# timestamp() returns a float representing the number of seconds since the Unix epoch
return str(int(dt.timestamp() * 1000))
- else:
- return dt.strftime(format)
+ return dt.strftime(format)
def _is_naive(self, dt: datetime.datetime) -> bool:
return dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None
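
To make the special-cased epoch directives in `DatetimeParser` above concrete, here is a stand-alone stdlib sketch of the equivalences they rely on (the dates are chosen arbitrarily for illustration):

    import datetime

    # Illustrative only: the "%s" and "%ms" directives are handled via
    # timestamps rather than strptime/strftime, as in the hunks above.
    epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
    dt = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)

    # "%s": whole seconds since the Unix epoch
    assert datetime.datetime.fromtimestamp(int(dt.timestamp()), tz=datetime.timezone.utc) == dt
    # "%ms": milliseconds since the Unix epoch, added to the epoch as a timedelta
    assert epoch + datetime.timedelta(milliseconds=int(dt.timestamp() * 1000)) == dt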
diff --git a/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py b/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py
index eb407db44..cb1c65140 100644
--- a/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py
+++ b/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py
@@ -3,8 +3,9 @@
#
import datetime as dt
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Mapping, Optional, Union
+from typing import Any, Union
from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
@@ -28,14 +29,14 @@ class MinMaxDatetime:
max_datetime (Union[InterpolatedString, str]): Represents the maximum allowed datetime value.
"""
- datetime: Union[InterpolatedString, str]
+ datetime: InterpolatedString | str
parameters: InitVar[Mapping[str, Any]]
# datetime_format is a unique case where we inherit it from the parent if it is not specified before using the default value
# which is why we need dedicated getter/setter methods and private dataclass field
datetime_format: str
_datetime_format: str = field(init=False, repr=False, default="")
- min_datetime: Union[InterpolatedString, str] = ""
- max_datetime: Union[InterpolatedString, str] = ""
+ min_datetime: InterpolatedString | str = ""
+ max_datetime: InterpolatedString | str = ""
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.datetime = InterpolatedString.create(self.datetime, parameters=parameters or {})
@@ -104,15 +105,14 @@ def datetime_format(self, value: str) -> None:
def create(
cls,
interpolated_string_or_min_max_datetime: Union[InterpolatedString, str, "MinMaxDatetime"],
- parameters: Optional[Mapping[str, Any]] = None,
+ parameters: Mapping[str, Any] | None = None,
) -> "MinMaxDatetime":
if parameters is None:
parameters = {}
- if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance(
+ if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance( # noqa: SIM101
interpolated_string_or_min_max_datetime, str
):
return MinMaxDatetime( # type: ignore [call-arg]
datetime=interpolated_string_or_min_max_datetime, parameters=parameters
)
- else:
- return interpolated_string_or_min_max_datetime
+ return interpolated_string_or_min_max_datetime
diff --git a/airbyte_cdk/sources/declarative/declarative_source.py b/airbyte_cdk/sources/declarative/declarative_source.py
index 77bf427a1..769407d3a 100644
--- a/airbyte_cdk/sources/declarative/declarative_source.py
+++ b/airbyte_cdk/sources/declarative/declarative_source.py
@@ -4,7 +4,8 @@
import logging
from abc import abstractmethod
-from typing import Any, Mapping, Tuple
+from collections.abc import Mapping
+from typing import Any
from airbyte_cdk.sources.abstract_source import AbstractSource
from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker
@@ -22,7 +23,7 @@ def connection_checker(self) -> ConnectionChecker:
def check_connection(
self, logger: logging.Logger, config: Mapping[str, Any]
- ) -> Tuple[bool, Any]:
+ ) -> tuple[bool, Any]:
"""
:param logger: The source logger
:param config: The user-provided configuration as specified by the source's spec.
diff --git a/airbyte_cdk/sources/declarative/declarative_stream.py b/airbyte_cdk/sources/declarative/declarative_stream.py
index 12cdd3337..ac11ea870 100644
--- a/airbyte_cdk/sources/declarative/declarative_stream.py
+++ b/airbyte_cdk/sources/declarative/declarative_stream.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
+from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
+from typing import Any
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.declarative.incremental import (
@@ -46,12 +47,12 @@ class DeclarativeStream(Stream):
config: Config
parameters: InitVar[Mapping[str, Any]]
name: str
- primary_key: Optional[Union[str, List[str], List[List[str]]]]
- state_migrations: List[StateMigration] = field(repr=True, default_factory=list)
- schema_loader: Optional[SchemaLoader] = None
+ primary_key: str | list[str] | list[list[str]] | None
+ state_migrations: list[StateMigration] = field(repr=True, default_factory=list)
+ schema_loader: SchemaLoader | None = None
_name: str = field(init=False, repr=False, default="")
_primary_key: str = field(init=False, repr=False, default="")
- stream_cursor_field: Optional[Union[InterpolatedString, str]] = None
+ stream_cursor_field: InterpolatedString | str | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._stream_cursor_field = (
@@ -60,13 +61,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
else self.stream_cursor_field
)
self._schema_loader = (
- self.schema_loader
+ self.schema_loader # noqa: FURB110
if self.schema_loader
else DefaultSchemaLoader(config=self.config, parameters=parameters)
)
@property # type: ignore
- def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+ def primary_key(self) -> str | list[str] | list[list[str]] | None:
return self._primary_key
@primary_key.setter
@@ -109,18 +110,20 @@ def state(self, value: MutableMapping[str, Any]) -> None:
self.retriever.state = state
def get_updated_state(
- self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]
+ self,
+ current_stream_state: MutableMapping[str, Any], # noqa: ARG002
+ latest_record: Mapping[str, Any], # noqa: ARG002
) -> MutableMapping[str, Any]:
return self.state
@property
- def cursor_field(self) -> Union[str, List[str]]:
+ def cursor_field(self) -> str | list[str]:
"""
Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field.
:return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor.
"""
cursor = self._stream_cursor_field.eval(self.config) # type: ignore # _stream_cursor_field is always cast to interpolated string
- return cursor if cursor else []
+ return cursor if cursor else [] # noqa: FURB110
@property
def is_resumable(self) -> bool:
@@ -130,10 +133,10 @@ def is_resumable(self) -> bool:
def read_records(
self,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_slice: Optional[Mapping[str, Any]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None,
+ stream_state: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Iterable[Mapping[str, Any]]:
"""
:param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
@@ -146,7 +149,7 @@ def read_records(
# empty slice which seems to make sense.
stream_slice = StreamSlice(partition={}, cursor_slice={})
if not isinstance(stream_slice, StreamSlice):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"DeclarativeStream does not support stream_slices that are not StreamSlice. Got {stream_slice}"
)
yield from self.retriever.read_records(self.get_json_schema(), stream_slice) # type: ignore # records are of the correct type
@@ -163,10 +166,10 @@ def get_json_schema(self) -> Mapping[str, Any]: # type: ignore
def stream_slices(
self,
*,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
- ) -> Iterable[Optional[StreamSlice]]:
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_state: Mapping[str, Any] | None = None, # noqa: ARG002
+ ) -> Iterable[StreamSlice | None]:
"""
Override to define the slices for this stream. See the stream slicing section of the docs for more information.
@@ -178,7 +181,7 @@ def stream_slices(
return self.retriever.stream_slices()
@property
- def state_checkpoint_interval(self) -> Optional[int]:
+ def state_checkpoint_interval(self) -> int | None:
"""
We explicitly disable checkpointing here. There are a couple reasons for that and not all are documented here but:
* In the case where records are not ordered, the granularity of what is ordered is the slice. Therefore, we will only update the
@@ -188,7 +191,7 @@ def state_checkpoint_interval(self) -> Optional[int]:
"""
return None
- def get_cursor(self) -> Optional[Cursor]:
+ def get_cursor(self) -> Cursor | None:
if self.retriever and isinstance(self.retriever, SimpleRetriever):
return self.retriever.cursor
return None
@@ -196,7 +199,7 @@ def get_cursor(self) -> Optional[Cursor]:
def _get_checkpoint_reader(
self,
logger: logging.Logger,
- cursor_field: Optional[List[str]],
+ cursor_field: list[str] | None,
sync_mode: SyncMode,
stream_state: MutableMapping[str, Any],
) -> CheckpointReader:
@@ -212,14 +215,14 @@ def _get_checkpoint_reader(
"""
mappings_or_slices = self.stream_slices(
cursor_field=cursor_field,
- sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior
+ sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior
stream_state=stream_state,
)
cursor = self.get_cursor()
checkpoint_mode = self._checkpoint_mode
- if isinstance(
+ if isinstance( # noqa: UP038
cursor, (GlobalSubstreamCursor, PerPartitionCursor, PerPartitionWithGlobalCursor)
):
self.has_multiple_slices = True
diff --git a/airbyte_cdk/sources/declarative/decoders/__init__.py b/airbyte_cdk/sources/declarative/decoders/__init__.py
index 45eaf5599..e81195349 100644
--- a/airbyte_cdk/sources/declarative/decoders/__init__.py
+++ b/airbyte_cdk/sources/declarative/decoders/__init__.py
@@ -4,9 +4,9 @@
from airbyte_cdk.sources.declarative.decoders.composite_raw_decoder import (
CompositeRawDecoder,
- GzipParser,
+ GzipParser, # noqa: F401
JsonParser,
- Parser,
+ Parser, # noqa: F401
)
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
from airbyte_cdk.sources.declarative.decoders.json_decoder import (
@@ -22,6 +22,7 @@
from airbyte_cdk.sources.declarative.decoders.xml_decoder import XmlDecoder
from airbyte_cdk.sources.declarative.decoders.zipfile_decoder import ZipfileDecoder
+
__all__ = [
"Decoder",
"CompositeRawDecoder",
diff --git a/airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py b/airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py
index 4d670db11..340b46848 100644
--- a/airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py
+++ b/airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py
@@ -1,11 +1,13 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
import csv
import gzip
import json
import logging
from abc import ABC, abstractmethod
+from collections.abc import Generator, MutableMapping
from dataclasses import dataclass
from io import BufferedIOBase, TextIOWrapper
-from typing import Any, Generator, MutableMapping, Optional
+from typing import Any
import orjson
import requests
@@ -14,6 +16,7 @@
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
from airbyte_cdk.utils import AirbyteTracedException
+
logger = logging.getLogger("airbyte")
@@ -59,7 +62,7 @@ def parse(self, data: BufferedIOBase) -> Generator[MutableMapping[str, Any], Non
if body_json is None:
raise AirbyteTracedException(
message="Response JSON data failed to be parsed. See logs for more information.",
- internal_message=f"Response JSON data failed to be parsed.",
+ internal_message="Response JSON data failed to be parsed.",
failure_type=FailureType.system_error,
)
@@ -68,7 +71,7 @@ def parse(self, data: BufferedIOBase) -> Generator[MutableMapping[str, Any], Non
else:
yield from [body_json]
- def _parse_orjson(self, raw_data: bytes) -> Optional[Any]:
+ def _parse_orjson(self, raw_data: bytes) -> Any | None: # noqa: ANN401
try:
return orjson.loads(raw_data.decode(self.encoding))
except Exception as exc:
@@ -77,7 +80,7 @@ def _parse_orjson(self, raw_data: bytes) -> Optional[Any]:
)
return None
- def _parse_json(self, raw_data: bytes) -> Optional[Any]:
+ def _parse_json(self, raw_data: bytes) -> Any | None: # noqa: ANN401
try:
return json.loads(raw_data.decode(self.encoding))
except Exception as exc:
@@ -87,7 +90,7 @@ def _parse_json(self, raw_data: bytes) -> Optional[Any]:
@dataclass
class JsonLineParser(Parser):
- encoding: Optional[str] = "utf-8"
+ encoding: str | None = "utf-8"
def parse(
self,
@@ -103,8 +106,8 @@ def parse(
@dataclass
class CsvParser(Parser):
# TODO: migrate implementation to reuse file-based classes
- encoding: Optional[str] = "utf-8"
- delimiter: Optional[str] = ","
+ encoding: str | None = "utf-8"
+ delimiter: str | None = ","
def parse(
self,
diff --git a/airbyte_cdk/sources/declarative/decoders/decoder.py b/airbyte_cdk/sources/declarative/decoders/decoder.py
index 5fa9dc8f6..e6b6aa9bd 100644
--- a/airbyte_cdk/sources/declarative/decoders/decoder.py
+++ b/airbyte_cdk/sources/declarative/decoders/decoder.py
@@ -3,8 +3,9 @@
#
from abc import abstractmethod
+from collections.abc import Generator, MutableMapping
from dataclasses import dataclass
-from typing import Any, Generator, MutableMapping
+from typing import Any
import requests
diff --git a/airbyte_cdk/sources/declarative/decoders/json_decoder.py b/airbyte_cdk/sources/declarative/decoders/json_decoder.py
index cab572ef4..e47400758 100644
--- a/airbyte_cdk/sources/declarative/decoders/json_decoder.py
+++ b/airbyte_cdk/sources/declarative/decoders/json_decoder.py
@@ -3,15 +3,17 @@
#
import codecs
import logging
+from collections.abc import Generator, Mapping, MutableMapping
from dataclasses import InitVar, dataclass
from gzip import decompress
-from typing import Any, Generator, List, Mapping, MutableMapping, Optional
+from typing import Any
import orjson
import requests
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
+
logger = logging.getLogger("airbyte")
@@ -43,7 +45,7 @@ def decode(
@staticmethod
def parse_body_json(
- body_json: MutableMapping[str, Any] | List[MutableMapping[str, Any]],
+ body_json: MutableMapping[str, Any] | list[MutableMapping[str, Any]],
) -> Generator[MutableMapping[str, Any], None, None]:
if not isinstance(body_json, list):
body_json = [body_json]
@@ -85,7 +87,7 @@ def is_stream_response(self) -> bool:
def decode(
self, response: requests.Response
) -> Generator[MutableMapping[str, Any], None, None]:
- # TODO???: set delimiter? usually it is `\n` but maybe it would be useful to set optional?
+ # TODO???: set delimiter? usually it is `\n` but maybe it would be useful to set optional? # noqa: TD004
# https://github.com/airbytehq/airbyte-internal-issues/issues/8436
for record in response.iter_lines():
yield orjson.loads(record)
@@ -93,14 +95,14 @@ def decode(
@dataclass
class GzipJsonDecoder(JsonDecoder):
- encoding: Optional[str]
+ encoding: str | None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if self.encoding:
try:
codecs.lookup(self.encoding)
except LookupError:
- raise ValueError(
+ raise ValueError( # noqa: B904
f"Invalid encoding '{self.encoding}'. Please check provided encoding"
)
diff --git a/airbyte_cdk/sources/declarative/decoders/noop_decoder.py b/airbyte_cdk/sources/declarative/decoders/noop_decoder.py
index cf0bc56eb..145303539 100644
--- a/airbyte_cdk/sources/declarative/decoders/noop_decoder.py
+++ b/airbyte_cdk/sources/declarative/decoders/noop_decoder.py
@@ -1,12 +1,14 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
import logging
-from typing import Any, Generator, Mapping
+from collections.abc import Generator, Mapping
+from typing import Any
import requests
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
+
logger = logging.getLogger("airbyte")
@@ -16,6 +18,6 @@ def is_stream_response(self) -> bool:
def decode( # type: ignore[override] # Signature doesn't match base class
self,
- response: requests.Response,
+ response: requests.Response, # noqa: ARG002
) -> Generator[Mapping[str, Any], None, None]:
yield from [{}]
diff --git a/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py b/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py
index e5a152711..6928e0d53 100644
--- a/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py
+++ b/airbyte_cdk/sources/declarative/decoders/pagination_decoder_decorator.py
@@ -3,13 +3,15 @@
#
import logging
+from collections.abc import Generator, MutableMapping
from dataclasses import dataclass
-from typing import Any, Generator, MutableMapping
+from typing import Any
import requests
from airbyte_cdk.sources.declarative.decoders import Decoder
+
logger = logging.getLogger("airbyte")
@@ -19,7 +21,7 @@ class PaginationDecoderDecorator(Decoder):
Decoder to wrap other decoders when instantiating a DefaultPaginator in order to bypass decoding if the response is streamed.
"""
- def __init__(self, decoder: Decoder):
+ def __init__(self, decoder: Decoder): # noqa: ANN204
self._decoder = decoder
@property
diff --git a/airbyte_cdk/sources/declarative/decoders/xml_decoder.py b/airbyte_cdk/sources/declarative/decoders/xml_decoder.py
index 0786c3202..471ac0abf 100644
--- a/airbyte_cdk/sources/declarative/decoders/xml_decoder.py
+++ b/airbyte_cdk/sources/declarative/decoders/xml_decoder.py
@@ -3,8 +3,9 @@
#
import logging
+from collections.abc import Generator, Mapping, MutableMapping
from dataclasses import InitVar, dataclass
-from typing import Any, Generator, Mapping, MutableMapping
+from typing import Any
from xml.parsers.expat import ExpatError
import requests
@@ -12,6 +13,7 @@
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
+
logger = logging.getLogger("airbyte")
diff --git a/airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py b/airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py
index a937a1e4d..9510134cb 100644
--- a/airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py
+++ b/airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py
@@ -4,11 +4,11 @@
import logging
import zipfile
+from collections.abc import Generator, MutableMapping
from dataclasses import dataclass
from io import BytesIO
-from typing import Any, Generator, MutableMapping
+from typing import Any
-import orjson
import requests
from airbyte_cdk.models import FailureType
@@ -18,6 +18,7 @@
)
from airbyte_cdk.utils import AirbyteTracedException
+
logger = logging.getLogger("airbyte")
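The decoder hunks above, like the rest of this diff, swap `typing.Optional`/`typing.Union` and the old `typing` collection aliases for PEP 604 unions and `collections.abc` imports. A minimal sketch of the before/after shape, using a hypothetical line-decoding helper rather than any class from this repository:

# Hypothetical example only: illustrates the typing changes applied across this diff,
# i.e. `Optional[X]` -> `X | None` and `typing.Generator`/`Mapping` -> `collections.abc`.
from collections.abc import Generator, MutableMapping
from typing import Any
import json


def decode_lines(
    raw: bytes,
    encoding: str | None = "utf-8",
) -> Generator[MutableMapping[str, Any], None, None]:
    """Yield one mapping per JSON line; same behavior as the old Optional[...] signature."""
    for line in raw.decode(encoding or "utf-8").splitlines():
        if line.strip():
            yield json.loads(line)


if __name__ == "__main__":
    print(list(decode_lines(b'{"a": 1}\n{"a": 2}\n')))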
diff --git a/airbyte_cdk/sources/declarative/extractors/__init__.py b/airbyte_cdk/sources/declarative/extractors/__init__.py
index 8f1d18d12..9cef0d121 100644
--- a/airbyte_cdk/sources/declarative/extractors/__init__.py
+++ b/airbyte_cdk/sources/declarative/extractors/__init__.py
@@ -11,6 +11,7 @@
)
from airbyte_cdk.sources.declarative.extractors.type_transformer import TypeTransformer
+
__all__ = [
"TypeTransformer",
"HttpSelector",
diff --git a/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py b/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py
index 9c97773e3..d1575e4b6 100644
--- a/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py
+++ b/airbyte_cdk/sources/declarative/extractors/dpath_extractor.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Iterable, Mapping, MutableMapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Iterable, List, Mapping, MutableMapping, Union
+from typing import Any
import dpath
import requests
@@ -53,7 +54,7 @@ class DpathExtractor(RecordExtractor):
decoder (Decoder): The decoder responsible to transfom the response in a Mapping
"""
- field_path: List[Union[InterpolatedString, str]]
+ field_path: list[InterpolatedString | str]
config: Config
parameters: InitVar[Mapping[str, Any]]
decoder: Decoder = field(default_factory=lambda: JsonDecoder(parameters={}))
diff --git a/airbyte_cdk/sources/declarative/extractors/http_selector.py b/airbyte_cdk/sources/declarative/extractors/http_selector.py
index 846071125..8db96c796 100644
--- a/airbyte_cdk/sources/declarative/extractors/http_selector.py
+++ b/airbyte_cdk/sources/declarative/extractors/http_selector.py
@@ -3,7 +3,8 @@
#
from abc import abstractmethod
-from typing import Any, Iterable, Mapping, Optional
+from collections.abc import Iterable, Mapping
+from typing import Any
import requests
@@ -22,8 +23,8 @@ def select_records(
response: requests.Response,
stream_state: StreamState,
records_schema: Mapping[str, Any],
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Iterable[Record]:
"""
Selects records from the response
diff --git a/airbyte_cdk/sources/declarative/extractors/record_extractor.py b/airbyte_cdk/sources/declarative/extractors/record_extractor.py
index 5de6a84a7..6890950af 100644
--- a/airbyte_cdk/sources/declarative/extractors/record_extractor.py
+++ b/airbyte_cdk/sources/declarative/extractors/record_extractor.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
from abc import abstractmethod
+from collections.abc import Iterable, Mapping
from dataclasses import dataclass
-from typing import Any, Iterable, Mapping
+from typing import Any
import requests
diff --git a/airbyte_cdk/sources/declarative/extractors/record_filter.py b/airbyte_cdk/sources/declarative/extractors/record_filter.py
index b744c9796..5cdca8618 100644
--- a/airbyte_cdk/sources/declarative/extractors/record_filter.py
+++ b/airbyte_cdk/sources/declarative/extractors/record_filter.py
@@ -1,8 +1,9 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Iterable, Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Iterable, Mapping, Optional, Union
+from typing import Any
from airbyte_cdk.sources.declarative.incremental import (
DatetimeBasedCursor,
@@ -35,8 +36,8 @@ def filter_records(
self,
records: Iterable[Mapping[str, Any]],
stream_state: StreamState,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Iterable[Mapping[str, Any]]:
kwargs = {
"stream_state": stream_state,
@@ -57,11 +58,11 @@ class ClientSideIncrementalRecordFilterDecorator(RecordFilter):
:param PerPartitionCursor per_partition_cursor: Optional Cursor used for mapping cursor value in nested stream_state
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
date_time_based_cursor: DatetimeBasedCursor,
- substream_cursor: Optional[Union[PerPartitionWithGlobalCursor, GlobalSubstreamCursor]],
- **kwargs: Any,
+ substream_cursor: PerPartitionWithGlobalCursor | GlobalSubstreamCursor | None,
+ **kwargs: Any, # noqa: ANN401
):
super().__init__(**kwargs)
self._date_time_based_cursor = date_time_based_cursor
@@ -71,8 +72,8 @@ def filter_records(
self,
records: Iterable[Mapping[str, Any]],
stream_state: StreamState,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Iterable[Mapping[str, Any]]:
records = (
record
diff --git a/airbyte_cdk/sources/declarative/extractors/record_selector.py b/airbyte_cdk/sources/declarative/extractors/record_selector.py
index f29b8a75a..92905fa6c 100644
--- a/airbyte_cdk/sources/declarative/extractors/record_selector.py
+++ b/airbyte_cdk/sources/declarative/extractors/record_selector.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Iterable, Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Iterable, List, Mapping, Optional, Union
+from typing import Any
import requests
@@ -14,7 +15,6 @@
TypeTransformer as DeclarativeTypeTransformer,
)
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
-from airbyte_cdk.sources.declarative.models import SchemaNormalization
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
from airbyte_cdk.sources.utils.transform import TypeTransformer
@@ -36,11 +36,11 @@ class RecordSelector(HttpSelector):
extractor: RecordExtractor
config: Config
parameters: InitVar[Mapping[str, Any]]
- schema_normalization: Union[TypeTransformer, DeclarativeTypeTransformer]
+ schema_normalization: TypeTransformer | DeclarativeTypeTransformer
name: str
- _name: Union[InterpolatedString, str] = field(init=False, repr=False, default="")
- record_filter: Optional[RecordFilter] = None
- transformations: List[RecordTransformation] = field(default_factory=lambda: [])
+ _name: InterpolatedString | str = field(init=False, repr=False, default="")
+ record_filter: RecordFilter | None = None
+ transformations: list[RecordTransformation] = field(default_factory=list)
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._parameters = parameters
@@ -71,8 +71,8 @@ def select_records(
response: requests.Response,
stream_state: StreamState,
records_schema: Mapping[str, Any],
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Iterable[Record]:
"""
Selects records from the response
@@ -93,8 +93,8 @@ def filter_and_transform(
all_data: Iterable[Mapping[str, Any]],
stream_state: StreamState,
records_schema: Mapping[str, Any],
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Iterable[Record]:
"""
There is an issue with the selector as of 2024-08-30: it does technology-agnostic processing like filtering, transformation and
@@ -111,7 +111,7 @@ def filter_and_transform(
yield Record(data=data, stream_name=self.name, associated_slice=stream_slice)
def _normalize_by_schema(
- self, records: Iterable[Mapping[str, Any]], schema: Optional[Mapping[str, Any]]
+ self, records: Iterable[Mapping[str, Any]], schema: Mapping[str, Any] | None
) -> Iterable[Mapping[str, Any]]:
if schema:
# record has type Mapping[str, Any], but dict[str, Any] expected
@@ -126,8 +126,8 @@ def _filter(
self,
records: Iterable[Mapping[str, Any]],
stream_state: StreamState,
- stream_slice: Optional[StreamSlice],
- next_page_token: Optional[Mapping[str, Any]],
+ stream_slice: StreamSlice | None,
+ next_page_token: Mapping[str, Any] | None,
) -> Iterable[Mapping[str, Any]]:
if self.record_filter:
yield from self.record_filter.filter_records(
@@ -143,7 +143,7 @@ def _transform(
self,
records: Iterable[Mapping[str, Any]],
stream_state: StreamState,
- stream_slice: Optional[StreamSlice] = None,
+ stream_slice: StreamSlice | None = None,
) -> Iterable[Mapping[str, Any]]:
for record in records:
for transformation in self.transformations:
diff --git a/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py b/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py
index 0215ddb45..0ee34d540 100644
--- a/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py
+++ b/airbyte_cdk/sources/declarative/extractors/response_to_file_extractor.py
@@ -5,9 +5,10 @@
import os
import uuid
import zlib
+from collections.abc import Iterable, Mapping
from contextlib import closing
from dataclasses import InitVar, dataclass
-from typing import Any, Dict, Iterable, Mapping, Optional, Tuple
+from typing import Any
import pandas as pd
import requests
@@ -15,6 +16,7 @@
from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor
+
EMPTY_STR: str = ""
DEFAULT_ENCODING: str = "utf-8"
DOWNLOAD_CHUNK_SIZE: int = 1024 * 10
@@ -35,7 +37,7 @@ class ResponseToFileExtractor(RecordExtractor):
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.logger = logging.getLogger("airbyte")
- def _get_response_encoding(self, headers: Dict[str, Any]) -> str:
+ def _get_response_encoding(self, headers: dict[str, Any]) -> str:
"""
Get the encoding of the response based on the provided headers. This method is heavily inspired by the requests library
implementation.
@@ -78,7 +80,7 @@ def _filter_null_bytes(self, b: bytes) -> bytes:
)
return res
- def _save_to_file(self, response: requests.Response) -> Tuple[str, str]:
+ def _save_to_file(self, response: requests.Response) -> tuple[str, str]:
"""
Saves the binary data from the given response to a temporary file and returns the filepath and response encoding.
@@ -96,7 +98,7 @@ def _save_to_file(self, response: requests.Response) -> Tuple[str, str]:
needs_decompression = True # we will assume at first that the response is compressed and change the flag if not
tmp_file = str(uuid.uuid4())
- with closing(response) as response, open(tmp_file, "wb") as data_file:
+ with closing(response) as response, open(tmp_file, "wb") as data_file: # noqa: PTH123, PLR1704
response_encoding = self._get_response_encoding(dict(response.headers or {}))
for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
try:
@@ -110,12 +112,11 @@ def _save_to_file(self, response: requests.Response) -> Tuple[str, str]:
needs_decompression = False
# check the file exists
- if os.path.isfile(tmp_file):
+ if os.path.isfile(tmp_file): # noqa: PTH113
return tmp_file, response_encoding
- else:
- raise ValueError(
- f"The IO/Error occured while verifying binary data. Tmp file {tmp_file} doesn't exist."
- )
+ raise ValueError(
+            f"The IO/Error occurred while verifying binary data. Tmp file {tmp_file} doesn't exist."
+ )
def _read_with_chunks(
self, path: str, file_encoding: str, chunk_size: int = 100
@@ -136,25 +137,25 @@ def _read_with_chunks(
"""
try:
- with open(path, "r", encoding=file_encoding) as data:
+ with open(path, encoding=file_encoding) as data: # noqa: PTH123
chunks = pd.read_csv(
data, chunksize=chunk_size, iterator=True, dialect="unix", dtype=object
)
for chunk in chunks:
- chunk = chunk.replace({nan: None}).to_dict(orient="records")
- for row in chunk:
+ chunk = chunk.replace({nan: None}).to_dict(orient="records") # noqa: PLW2901
+ for row in chunk: # noqa: UP028
yield row
except pd.errors.EmptyDataError as e:
self.logger.info(f"Empty data received. {e}")
yield from []
- except IOError as ioe:
- raise ValueError(f"The IO/Error occured while reading tmp data. Called: {path}", ioe)
+ except OSError as ioe:
+            raise ValueError(f"The IO/Error occurred while reading tmp data. Called: {path}", ioe)  # noqa: B904
finally:
# remove binary tmp file, after data is read
- os.remove(path)
+ os.remove(path) # noqa: PTH107
def extract_records(
- self, response: Optional[requests.Response] = None
+ self, response: requests.Response | None = None
) -> Iterable[Mapping[str, Any]]:
"""
Extracts records from the given response by:
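The `_read_with_chunks` hunk above streams the downloaded file through pandas in fixed-size chunks, mapping NaN to None before yielding rows. A rough standalone sketch of that pattern follows; the function name and defaults are illustrative, not the extractor's API:

# Standalone sketch of the chunked-read pattern used in _read_with_chunks above.
# The path argument and chunk size are placeholders, not values taken from the CDK.
from collections.abc import Iterable, Mapping
from typing import Any

import pandas as pd
from numpy import nan


def read_csv_in_chunks(
    path: str, encoding: str = "utf-8", chunk_size: int = 100
) -> Iterable[Mapping[str, Any]]:
    with open(path, encoding=encoding) as data:
        chunks = pd.read_csv(data, chunksize=chunk_size, iterator=True, dialect="unix", dtype=object)
        for chunk in chunks:
            # Replace NaN with None so downstream JSON records stay null-safe.
            yield from chunk.replace({nan: None}).to_dict(orient="records")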
diff --git a/airbyte_cdk/sources/declarative/extractors/type_transformer.py b/airbyte_cdk/sources/declarative/extractors/type_transformer.py
index fe307684f..9060bed35 100644
--- a/airbyte_cdk/sources/declarative/extractors/type_transformer.py
+++ b/airbyte_cdk/sources/declarative/extractors/type_transformer.py
@@ -3,8 +3,9 @@
#
from abc import ABC, abstractmethod
+from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Any, Dict, Mapping
+from typing import Any
@dataclass
@@ -35,7 +36,7 @@ class TypeTransformer(ABC):
@abstractmethod
def transform(
self,
- record: Dict[str, Any],
+ record: dict[str, Any],
schema: Mapping[str, Any],
) -> None:
"""
diff --git a/airbyte_cdk/sources/declarative/incremental/__init__.py b/airbyte_cdk/sources/declarative/incremental/__init__.py
index 7ce54a07a..66f48438a 100644
--- a/airbyte_cdk/sources/declarative/incremental/__init__.py
+++ b/airbyte_cdk/sources/declarative/incremental/__init__.py
@@ -19,6 +19,7 @@
ResumableFullRefreshCursor,
)
+
__all__ = [
"CursorFactory",
"DatetimeBasedCursor",
diff --git a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py
index d6d329aec..a38be1613 100644
--- a/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py
+++ b/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py
@@ -3,9 +3,10 @@
#
import datetime
+from collections.abc import Callable, Iterable, Mapping, MutableMapping
from dataclasses import InitVar, dataclass, field
from datetime import timedelta
-from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union
+from typing import Any
from isodate import Duration, duration_isoformat, parse_duration
@@ -52,28 +53,28 @@ class DatetimeBasedCursor(DeclarativeCursor):
lookback_window (Optional[InterpolatedString]): how many days before start_datetime to read data for (ISO8601 duration)
"""
- start_datetime: Union[MinMaxDatetime, str]
- cursor_field: Union[InterpolatedString, str]
+ start_datetime: MinMaxDatetime | str
+ cursor_field: InterpolatedString | str
datetime_format: str
config: Config
parameters: InitVar[Mapping[str, Any]]
- _highest_observed_cursor_field_value: Optional[str] = field(
+ _highest_observed_cursor_field_value: str | None = field(
repr=False, default=None
) # tracks the latest observed datetime, which may not be safe to emit in the case of out-of-order records
- _cursor: Optional[str] = field(
+ _cursor: str | None = field(
repr=False, default=None
) # tracks the latest observed datetime that is appropriate to emit as stream state
- end_datetime: Optional[Union[MinMaxDatetime, str]] = None
- step: Optional[Union[InterpolatedString, str]] = None
- cursor_granularity: Optional[str] = None
- start_time_option: Optional[RequestOption] = None
- end_time_option: Optional[RequestOption] = None
- partition_field_start: Optional[str] = None
- partition_field_end: Optional[str] = None
- lookback_window: Optional[Union[InterpolatedString, str]] = None
- message_repository: Optional[MessageRepository] = None
- is_compare_strictly: Optional[bool] = False
- cursor_datetime_formats: List[str] = field(default_factory=lambda: [])
+ end_datetime: MinMaxDatetime | str | None = None
+ step: InterpolatedString | str | None = None
+ cursor_granularity: str | None = None
+ start_time_option: RequestOption | None = None
+ end_time_option: RequestOption | None = None
+ partition_field_start: str | None = None
+ partition_field_end: str | None = None
+ lookback_window: InterpolatedString | str | None = None
+ message_repository: MessageRepository | None = None
+ is_compare_strictly: bool | None = False
+ cursor_datetime_formats: list[str] = field(default_factory=list)
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if (self.step and not self.cursor_granularity) or (
@@ -166,12 +167,12 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
):
self._highest_observed_cursor_field_value = record_cursor_value
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002
if stream_slice.partition:
raise ValueError(
f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}."
)
- cursor_value_str_by_cursor_value_datetime = dict(
+ cursor_value_str_by_cursor_value_datetime = dict( # noqa: C417
map(
# we need to ensure the cursor value is preserved as is in the state else the CATs might complain of something like
# 2023-01-04T17:30:19.000Z' <= '2023-01-04T17:30:19.000000Z'
@@ -202,7 +203,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
start_datetime = self._calculate_earliest_possible_value(self.select_best_end_datetime())
return self._partition_daterange(start_datetime, end_datetime, self._step)
- def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
+ def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # noqa: ARG002
# Datetime based cursors operate over slices made up of datetime ranges. Stream state is based on the progress
# through each slice and does not belong to a specific slice. We just return stream state as it is.
return self.get_stream_state()
@@ -254,8 +255,8 @@ def _partition_daterange(
self,
start: datetime.datetime,
end: datetime.datetime,
- step: Union[datetime.timedelta, Duration],
- ) -> List[StreamSlice]:
+ step: datetime.timedelta | Duration,
+ ) -> list[StreamSlice]:
start_field = self._partition_field_start.eval(self.config)
end_field = self._partition_field_end.eval(self.config)
dates = []
@@ -303,7 +304,7 @@ def _get_date(
return comparator(cursor_date, default_date)
def parse_date(self, date: str) -> datetime.datetime:
- for datetime_format in self.cursor_datetime_formats + [self.datetime_format]:
+ for datetime_format in self.cursor_datetime_formats + [self.datetime_format]: # noqa: RUF005
try:
return self._parser.parse(date, datetime_format)
except ValueError:
@@ -311,7 +312,7 @@ def parse_date(self, date: str) -> datetime.datetime:
raise ValueError(f"No format in {self.cursor_datetime_formats} matching {date}")
@classmethod
- def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta, Duration]:
+ def _parse_timedelta(cls, time_str: str | None) -> datetime.timedelta | Duration:
"""
:return Parses an ISO 8601 durations into datetime.timedelta or Duration objects.
"""
@@ -322,36 +323,36 @@ def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta,
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.request_parameter, stream_slice)
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.header, stream_slice)
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_data, stream_slice)
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_json, stream_slice)
@@ -360,7 +361,7 @@ def request_kwargs(self) -> Mapping[str, Any]:
return {}
def _get_request_options(
- self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]
+ self, option_type: RequestOptionType, stream_slice: StreamSlice | None
) -> Mapping[str, Any]:
options: MutableMapping[str, Any] = {}
if not stream_slice:
@@ -395,8 +396,8 @@ def should_be_synced(self, record: Record) -> bool:
def _is_within_daterange_boundaries(
self,
record: Record,
- start_datetime_boundary: Union[datetime.datetime, str],
- end_datetime_boundary: Union[datetime.datetime, str],
+ start_datetime_boundary: datetime.datetime | str,
+ end_datetime_boundary: datetime.datetime | str,
) -> bool:
cursor_field = self.cursor_field.eval(self.config) # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
record_cursor_value = record.get(cursor_field)
@@ -429,10 +430,9 @@ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
second_cursor_value = second.get(cursor_field)
if first_cursor_value and second_cursor_value:
return self.parse_date(first_cursor_value) >= self.parse_date(second_cursor_value)
- elif first_cursor_value:
+ if first_cursor_value: # noqa: SIM103
return True
- else:
- return False
+ return False
def set_runtime_lookback_window(self, lookback_window_in_seconds: int) -> None:
"""
diff --git a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py
index 3b3636236..a27436d0c 100644
--- a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py
+++ b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py
@@ -4,18 +4,20 @@
import threading
import time
-from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union
+from collections.abc import Callable, Iterable, Mapping
+from typing import Any, TypeVar
from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
+
T = TypeVar("T")
def iterate_with_last_flag_and_state(
- generator: Iterable[T], get_stream_state_func: Callable[[], Optional[Mapping[str, StreamState]]]
+ generator: Iterable[T], get_stream_state_func: Callable[[], Mapping[str, StreamState] | None]
) -> Iterable[tuple[T, bool, Any]]:
"""
Iterates over the given generator, yielding tuples containing the element, a flag
@@ -53,16 +55,15 @@ class Timer:
"""
def __init__(self) -> None:
- self._start: Optional[int] = None
+ self._start: int | None = None
def start(self) -> None:
self._start = time.perf_counter_ns()
def finish(self) -> int:
if self._start:
- return ((time.perf_counter_ns() - self._start) / 1e9).__ceil__()
- else:
- raise RuntimeError("Global substream cursor timer not started")
+ return ((time.perf_counter_ns() - self._start) / 1e9).__ceil__() # noqa: PLC2801
+ raise RuntimeError("Global substream cursor timer not started")
class GlobalSubstreamCursor(DeclarativeCursor):
@@ -79,7 +80,7 @@ class GlobalSubstreamCursor(DeclarativeCursor):
- When using the `incremental_dependency` option, the sync will progress through parent records, preventing the sync from getting infinitely stuck. However, it is crucial to understand the requirements for both the `global_substream_cursor` and `incremental_dependency` options to avoid data loss.
"""
- def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: PartitionRouter):
+ def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: PartitionRouter): # noqa: ANN204
self._stream_cursor = stream_cursor
self._partition_router = partition_router
self._timer = Timer()
@@ -88,10 +89,10 @@ def __init__(self, stream_cursor: DatetimeBasedCursor, partition_router: Partiti
0
) # Start with 0, indicating no slices being tracked
self._all_slices_yielded = False
- self._lookback_window: Optional[int] = None
- self._current_partition: Optional[Mapping[str, Any]] = None
+ self._lookback_window: int | None = None
+ self._current_partition: Mapping[str, Any] | None = None
self._last_slice: bool = False
- self._parent_state: Optional[Mapping[str, Any]] = None
+ self._parent_state: Mapping[str, Any] | None = None
def start_slices_generation(self) -> None:
self._timer.start()
@@ -118,7 +119,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
)
self.start_slices_generation()
- for slice, last, state in iterate_with_last_flag_and_state(
+ for slice, last, state in iterate_with_last_flag_and_state( # noqa: A001
slice_generator, self._partition_router.get_stream_state
):
self._parent_state = state
@@ -134,7 +135,7 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str
yield from slice_generator
- def register_slice(self, last: bool) -> None:
+ def register_slice(self, last: bool) -> None: # noqa: FBT001
"""
Tracks the processing of a stream slice.
@@ -213,7 +214,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), record
)
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401
"""
Close the current stream slice.
@@ -227,7 +228,7 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
"""
with self._lock:
self._slice_semaphore.acquire()
- if self._all_slices_yielded and self._slice_semaphore._value == 0:
+ if self._all_slices_yielded and self._slice_semaphore._value == 0: # noqa: SLF001
self._lookback_window = self._timer.finish()
self._stream_cursor.close_slice(
StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args
@@ -244,16 +245,16 @@ def get_stream_state(self) -> StreamState:
return state
- def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
+ def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # noqa: ARG002
# stream_slice is ignored as cursor is global
return self._stream_cursor.get_stream_state()
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
if stream_slice:
return self._partition_router.get_request_params( # type: ignore # this always returns a mapping
@@ -265,15 +266,14 @@ def get_request_params(
stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
next_page_token=next_page_token,
)
- else:
- raise ValueError("A partition needs to be provided in order to get request params")
+ raise ValueError("A partition needs to be provided in order to get request params")
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
if stream_slice:
return self._partition_router.get_request_headers( # type: ignore # this always returns a mapping
@@ -285,16 +285,15 @@ def get_request_headers(
stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
next_page_token=next_page_token,
)
- else:
- raise ValueError("A partition needs to be provided in order to get request headers")
+ raise ValueError("A partition needs to be provided in order to get request headers")
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
if stream_slice:
return self._partition_router.get_request_body_data( # type: ignore # this always returns a mapping
stream_state=stream_state,
@@ -305,15 +304,14 @@ def get_request_body_data(
stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
next_page_token=next_page_token,
)
- else:
- raise ValueError("A partition needs to be provided in order to get request body data")
+ raise ValueError("A partition needs to be provided in order to get request body data")
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
if stream_slice:
return self._partition_router.get_request_body_json( # type: ignore # this always returns a mapping
@@ -325,8 +323,7 @@ def get_request_body_json(
stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
next_page_token=next_page_token,
)
- else:
- raise ValueError("A partition needs to be provided in order to get request body json")
+ raise ValueError("A partition needs to be provided in order to get request body json")
def should_be_synced(self, record: Record) -> bool:
return self._stream_cursor.should_be_synced(self._convert_record_to_cursor_record(record))
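`iterate_with_last_flag_and_state`, shown above, yields each element of a generator together with a flag marking the final element and the parent state captured at that point; this is how the global cursor knows when every slice has been emitted. A small look-ahead sketch of that idea, with a placeholder state getter:

# Sketch of the look-ahead pattern behind iterate_with_last_flag_and_state: buffer one
# element so the previous one can be emitted with last=False, then flush the final
# element with last=True. The state getter below is a stand-in, not a CDK API.
from collections.abc import Callable, Iterable
from typing import Any, TypeVar

T = TypeVar("T")


def with_last_flag_and_state(
    generator: Iterable[T],
    get_state: Callable[[], Any],
) -> Iterable[tuple[T, bool, Any]]:
    iterator = iter(generator)
    try:
        previous = next(iterator)
    except StopIteration:
        return
    for current in iterator:
        yield previous, False, get_state()
        previous = current
    yield previous, True, get_state()


if __name__ == "__main__":
    print(list(with_last_flag_and_state(["a", "b", "c"], lambda: {"cursor": "2024-01-01"})))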
diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py
index 8241f7761..9d65ba5fb 100644
--- a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py
+++ b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py
@@ -4,7 +4,8 @@
import logging
from collections import OrderedDict
-from typing import Any, Callable, Iterable, Mapping, Optional, Union
+from collections.abc import Callable, Iterable, Mapping
+from typing import Any
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
@@ -13,11 +14,12 @@
)
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
+
logger = logging.getLogger("airbyte")
class CursorFactory:
- def __init__(self, create_function: Callable[[], DeclarativeCursor]):
+ def __init__(self, create_function: Callable[[], DeclarativeCursor]): # noqa: ANN204
self._create_function = create_function
def create(self) -> DeclarativeCursor:
@@ -49,7 +51,7 @@ class PerPartitionCursor(DeclarativeCursor):
_VALUE = 1
_state_to_migrate_from: Mapping[str, Any] = {}
- def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter):
+ def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter): # noqa: ANN204
self._cursor_factory = cursor_factory
self._partition_router = partition_router
# The dict is ordered to ensure that once the maximum number of partitions is reached,
@@ -70,7 +72,7 @@ def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[Str
cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
if not cursor:
partition_state = (
- self._state_to_migrate_from
+ self._state_to_migrate_from # noqa: FURB110
if self._state_to_migrate_from
else self._NO_CURSOR_STATE
)
@@ -154,14 +156,14 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), record
)
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401
try:
self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].close_slice(
StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args
)
except KeyError as exception:
- raise ValueError(
- f"Partition {str(exception)} could not be found in current state based on the record. This is unexpected because "
+ raise ValueError( # noqa: B904
+ f"Partition {exception!s} could not be found in current state based on the record. This is unexpected because "
f"we should only update state for partitions that were emitted during `stream_slices`"
)
@@ -170,12 +172,10 @@ def get_stream_state(self) -> StreamState:
for partition_tuple, cursor in self._cursor_per_partition.items():
cursor_state = cursor.get_stream_state()
if cursor_state:
- states.append(
- {
- "partition": self._to_dict(partition_tuple),
- "cursor": cursor_state,
- }
- )
+ states.append({
+ "partition": self._to_dict(partition_tuple),
+ "cursor": cursor_state,
+ })
state: dict[str, Any] = {"states": states}
parent_state = self._partition_router.get_stream_state()
@@ -183,7 +183,7 @@ def get_stream_state(self) -> StreamState:
state["parent_state"] = parent_state
return state
- def _get_state_for_partition(self, partition: Mapping[str, Any]) -> Optional[StreamState]:
+ def _get_state_for_partition(self, partition: Mapping[str, Any]) -> StreamState | None:
cursor = self._cursor_per_partition.get(self._to_partition_key(partition))
if cursor:
return cursor.get_stream_state()
@@ -200,7 +200,7 @@ def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
def _to_dict(self, partition_key: str) -> Mapping[str, Any]:
return self._partition_serializer.to_partition(partition_key)
- def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
+ def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None:
if not stream_slice:
raise ValueError("A partition needs to be provided in order to extract a state")
@@ -209,7 +209,7 @@ def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[S
return self._get_state_for_partition(stream_slice.partition)
- def _create_cursor(self, cursor_state: Any) -> DeclarativeCursor:
+ def _create_cursor(self, cursor_state: Any) -> DeclarativeCursor: # noqa: ANN401
cursor = self._cursor_factory.create()
cursor.set_initial_state(cursor_state)
return cursor
@@ -217,9 +217,9 @@ def _create_cursor(self, cursor_state: Any) -> DeclarativeCursor:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
if stream_slice:
return self._partition_router.get_request_params( # type: ignore # this always returns a mapping
@@ -233,15 +233,14 @@ def get_request_params(
stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
next_page_token=next_page_token,
)
- else:
- raise ValueError("A partition needs to be provided in order to get request params")
+ raise ValueError("A partition needs to be provided in order to get request params")
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
if stream_slice:
return self._partition_router.get_request_headers( # type: ignore # this always returns a mapping
@@ -255,16 +254,15 @@ def get_request_headers(
stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
next_page_token=next_page_token,
)
- else:
- raise ValueError("A partition needs to be provided in order to get request headers")
+ raise ValueError("A partition needs to be provided in order to get request headers")
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
if stream_slice:
return self._partition_router.get_request_body_data( # type: ignore # this always returns a mapping
stream_state=stream_state,
@@ -277,15 +275,14 @@ def get_request_body_data(
stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
next_page_token=next_page_token,
)
- else:
- raise ValueError("A partition needs to be provided in order to get request body data")
+ raise ValueError("A partition needs to be provided in order to get request body data")
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
if stream_slice:
return self._partition_router.get_request_body_json( # type: ignore # this always returns a mapping
@@ -299,8 +296,7 @@ def get_request_body_json(
stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
next_page_token=next_page_token,
)
- else:
- raise ValueError("A partition needs to be provided in order to get request body json")
+ raise ValueError("A partition needs to be provided in order to get request body json")
def should_be_synced(self, record: Record) -> bool:
return self._get_cursor(record).should_be_synced(
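PerPartitionCursor, per the hunks above, keys each partition's cursor by a serialized form of the partition and keeps them in an ordered dict so the oldest entries can be dropped once a maximum is reached. A sketch of that bookkeeping; the limit of three partitions and the eviction rule are arbitrary examples, not the CDK's actual behavior:

# Sketch of per-partition state bookkeeping: key each partition by a stable JSON
# serialization and keep states in an OrderedDict so the oldest entry can be dropped
# once a maximum partition count is reached. MAX_PARTITIONS is an arbitrary example.
import json
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any

MAX_PARTITIONS = 3
_state_per_partition: OrderedDict[str, Mapping[str, Any]] = OrderedDict()


def to_partition_key(partition: Mapping[str, Any]) -> str:
    return json.dumps(dict(partition), sort_keys=True)


def observe(partition: Mapping[str, Any], cursor_state: Mapping[str, Any]) -> None:
    key = to_partition_key(partition)
    _state_per_partition[key] = cursor_state
    _state_per_partition.move_to_end(key)
    if len(_state_per_partition) > MAX_PARTITIONS:
        _state_per_partition.popitem(last=False)  # evict the least recently updated partition


for index in range(5):
    observe({"parent_id": index}, {"updated_at": f"2024-01-0{index + 1}"})
print(list(_state_per_partition))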
diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py b/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py
index defa2d897..7a86a7b9b 100644
--- a/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py
+++ b/airbyte_cdk/sources/declarative/incremental/per_partition_with_global.py
@@ -1,7 +1,8 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Any, Iterable, Mapping, MutableMapping, Optional, Union
+from collections.abc import Iterable, Mapping, MutableMapping
+from typing import Any
from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
@@ -66,7 +67,7 @@ class PerPartitionWithGlobalCursor(DeclarativeCursor):
Suitable for streams where the number of partitions may vary significantly, requiring dynamic switching between per-partition and global state management to ensure data consistency and efficient synchronization.
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
cursor_factory: CursorFactory,
partition_router: PartitionRouter,
@@ -76,11 +77,11 @@ def __init__(
self._per_partition_cursor = PerPartitionCursor(cursor_factory, partition_router)
self._global_cursor = GlobalSubstreamCursor(stream_cursor, partition_router)
self._use_global_cursor = False
- self._current_partition: Optional[Mapping[str, Any]] = None
+ self._current_partition: Mapping[str, Any] | None = None
self._last_slice: bool = False
- self._parent_state: Optional[Mapping[str, Any]] = None
+ self._parent_state: Mapping[str, Any] | None = None
- def _get_active_cursor(self) -> Union[PerPartitionCursor, GlobalSubstreamCursor]:
+ def _get_active_cursor(self) -> PerPartitionCursor | GlobalSubstreamCursor:
return self._global_cursor if self._use_global_cursor else self._per_partition_cursor
def stream_slices(self) -> Iterable[StreamSlice]:
@@ -92,7 +93,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
):
# Generate slices for the current cursor and handle the last slice using the flag
self._parent_state = parent_state
- for slice, is_last_slice, _ in iterate_with_last_flag_and_state(
+ for slice, is_last_slice, _ in iterate_with_last_flag_and_state( # noqa: A001
self._get_active_cursor().generate_slices_from_partition(partition=partition),
lambda: None,
):
@@ -120,7 +121,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
self._per_partition_cursor.observe(stream_slice, record)
self._global_cursor.observe(stream_slice, record)
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401
if not self._use_global_cursor:
self._per_partition_cursor.close_slice(stream_slice, *args)
self._global_cursor.close_slice(stream_slice, *args)
@@ -138,15 +139,15 @@ def get_stream_state(self) -> StreamState:
return final_state
- def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
+ def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None:
return self._get_active_cursor().select_state(stream_slice)
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._get_active_cursor().get_request_params(
stream_state=stream_state,
@@ -157,9 +158,9 @@ def get_request_params(
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._get_active_cursor().get_request_headers(
stream_state=stream_state,
@@ -170,10 +171,10 @@ def get_request_headers(
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
return self._get_active_cursor().get_request_body_data(
stream_state=stream_state,
stream_slice=stream_slice,
@@ -183,9 +184,9 @@ def get_request_body_data(
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._get_active_cursor().get_request_body_json(
stream_state=stream_state,
diff --git a/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py b/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py
index a0b4665f1..82d1d70d4 100644
--- a/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py
+++ b/airbyte_cdk/sources/declarative/incremental/resumable_full_refresh_cursor.py
@@ -1,7 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+from collections.abc import Iterable, Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Iterable, Mapping, Optional
+from typing import Any
from airbyte_cdk.sources.declarative.incremental import DeclarativeCursor
from airbyte_cdk.sources.declarative.types import Record, StreamSlice, StreamState
@@ -27,7 +28,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
"""
pass
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002
# The ResumableFullRefreshCursor doesn't support nested streams yet so receiving a partition is unexpected
if stream_slice.partition:
raise ValueError(
@@ -35,20 +36,20 @@ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
)
self._cursor = stream_slice.cursor_slice
- def should_be_synced(self, record: Record) -> bool:
+ def should_be_synced(self, record: Record) -> bool: # noqa: ARG002
"""
Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages
that don't have filterable bounds. We should always return them.
"""
return True
- def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
+ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: # noqa: ARG002
"""
RFR record don't have ordering to be compared between one another.
"""
return False
- def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
+ def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # noqa: ARG002
# A top-level RFR cursor only manages the state of a single partition
return self._cursor
@@ -65,36 +66,36 @@ def stream_slices(self) -> Iterable[StreamSlice]:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
@@ -109,7 +110,7 @@ class ChildPartitionResumableFullRefreshCursor(ResumableFullRefreshCursor):
Check the `close_slice` method overide for more info about the actual behaviour of this cursor.
"""
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002
"""
Once the current slice has finished syncing:
- paginator returns None
diff --git a/airbyte_cdk/sources/declarative/interpolation/__init__.py b/airbyte_cdk/sources/declarative/interpolation/__init__.py
index d721b99f1..082e7ffb2 100644
--- a/airbyte_cdk/sources/declarative/interpolation/__init__.py
+++ b/airbyte_cdk/sources/declarative/interpolation/__init__.py
@@ -6,4 +6,5 @@
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
+
__all__ = ["InterpolatedBoolean", "InterpolatedMapping", "InterpolatedString"]
diff --git a/airbyte_cdk/sources/declarative/interpolation/filters.py b/airbyte_cdk/sources/declarative/interpolation/filters.py
index 52d76cab6..aaf7cc4a6 100644
--- a/airbyte_cdk/sources/declarative/interpolation/filters.py
+++ b/airbyte_cdk/sources/declarative/interpolation/filters.py
@@ -5,10 +5,10 @@
import hashlib
import json
import re
-from typing import Any, Optional
+from typing import Any
-def hash(value: Any, hash_type: str = "md5", salt: Optional[str] = None) -> str:
+def hash(value: Any, hash_type: str = "md5", salt: str | None = None) -> str: # noqa: ANN401, A001
"""
Implementation of a custom Jinja2 hash filter
Hash type defaults to 'md5' if one is not specified.
@@ -49,7 +49,7 @@ def hash(value: Any, hash_type: str = "md5", salt: Optional[str] = None) -> str:
hash_obj.update(str(salt).encode("utf-8"))
computed_hash: str = hash_obj.hexdigest()
else:
- raise AttributeError("No hashing function named {hname}".format(hname=hash_type))
+ raise AttributeError(f"No hashing function named {hash_type}")
return computed_hash
@@ -92,7 +92,7 @@ def base64decode(value: str) -> str:
return base64.b64decode(value.encode("utf-8")).decode()
-def string(value: Any) -> str:
+def string(value: Any) -> str: # noqa: ANN401
"""
Converts the input value to a string.
If the value is already a string, it is returned as is.
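The `hash` filter above digests the stringified value with hashlib, optionally salted, for use inside Jinja expressions. A minimal sketch of wiring such a filter into a plain Jinja2 environment; the environment setup here is illustrative, since the CDK registers its filters through its own JinjaInterpolation class:

# Minimal sketch: a salted hashlib filter registered on a plain Jinja2 environment.
# The Environment wiring is illustrative only; the filter mirrors the one shown above.
import hashlib
from typing import Any

from jinja2 import Environment


def hash_value(value: Any, hash_type: str = "md5", salt: str | None = None) -> str:
    hash_func = getattr(hashlib, hash_type, None)
    if hash_func is None:
        raise AttributeError(f"No hashing function named {hash_type}")
    hash_obj = hash_func()
    hash_obj.update(str(value).encode("utf-8"))
    if salt:
        hash_obj.update(str(salt).encode("utf-8"))
    return hash_obj.hexdigest()


env = Environment()
env.filters["hash"] = hash_value
print(env.from_string("{{ 'secret' | hash('sha256') }}").render())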
diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py
index 78569b350..8ebb12d3b 100644
--- a/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py
+++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_boolean.py
@@ -2,13 +2,15 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Final, List, Mapping
+from typing import Any, Final
from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
from airbyte_cdk.sources.types import Config
-FALSE_VALUES: Final[List[Any]] = [
+
+FALSE_VALUES: Final[list[Any]] = [
"False",
"false",
"{}",
@@ -33,7 +35,7 @@ class InterpolatedBoolean:
Attributes:
condition (str): The string representing the condition to evaluate to a boolean
- """
+ """ # noqa: B021
condition: str
parameters: InitVar[Mapping[str, Any]]
@@ -42,7 +44,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._interpolation = JinjaInterpolation()
self._parameters = parameters
- def eval(self, config: Config, **additional_parameters: Any) -> bool:
+ def eval(self, config: Config, **additional_parameters: Any) -> bool: # noqa: ANN401
"""
Interpolates the predicate condition string using the config and other optional arguments passed as parameter.
@@ -52,15 +54,14 @@ def eval(self, config: Config, **additional_parameters: Any) -> bool:
"""
if isinstance(self.condition, bool):
return self.condition
- else:
- evaluated = self._interpolation.eval(
- self.condition,
- config,
- self._default,
- parameters=self._parameters,
- **additional_parameters,
- )
- if evaluated in FALSE_VALUES:
- return False
- # The presence of a value is generally regarded as truthy, so we treat it as such
- return True
+ evaluated = self._interpolation.eval(
+ self.condition,
+ config,
+ self._default,
+ parameters=self._parameters,
+ **additional_parameters,
+ )
+ if evaluated in FALSE_VALUES: # noqa: SIM103
+ return False
+ # The presence of a value is generally regarded as truthy, so we treat it as such
+ return True
diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py
index 11b2dac97..fa8919ec5 100644
--- a/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py
+++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_mapping.py
@@ -3,8 +3,9 @@
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Dict, Mapping, Optional
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
from airbyte_cdk.sources.types import Config
@@ -22,11 +23,11 @@ class InterpolatedMapping:
mapping: Mapping[str, str]
parameters: InitVar[Mapping[str, Any]]
- def __post_init__(self, parameters: Optional[Mapping[str, Any]]) -> None:
+ def __post_init__(self, parameters: Mapping[str, Any] | None) -> None:
self._interpolation = JinjaInterpolation()
self._parameters = parameters
- def eval(self, config: Config, **additional_parameters: Any) -> Dict[str, Any]:
+ def eval(self, config: Config, **additional_parameters: Any) -> dict[str, Any]: # noqa: ANN401
"""
Wrapper around a Mapping[str, str] that allows for both keys and values to be interpolated.
@@ -47,10 +48,9 @@ def eval(self, config: Config, **additional_parameters: Any) -> Dict[str, Any]:
for name, value in self.mapping.items()
}
- def _eval(self, value: str, config: Config, **kwargs: Any) -> Any:
+ def _eval(self, value: str, config: Config, **kwargs: Any) -> Any: # noqa: ANN401
# The values in self._mapping can be of Any type
# We only want to interpolate them if they are strings
if isinstance(value, str):
return self._interpolation.eval(value, config, parameters=self._parameters, **kwargs)
- else:
- return value
+ return value
diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py
index 82454919e..3c656d501 100644
--- a/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py
+++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_nested_mapping.py
@@ -3,16 +3,18 @@
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any, Union
from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
from airbyte_cdk.sources.types import Config
-NestedMappingEntry = Union[
+
+NestedMappingEntry = Union[ # noqa: UP007
dict[str, "NestedMapping"], list["NestedMapping"], str, int, float, bool, None
]
-NestedMapping = Union[dict[str, NestedMappingEntry], str]
+NestedMapping = Union[dict[str, NestedMappingEntry], str] # noqa: UP007
@dataclass
@@ -27,26 +29,28 @@ class InterpolatedNestedMapping:
mapping: NestedMapping
parameters: InitVar[Mapping[str, Any]]
- def __post_init__(self, parameters: Optional[Mapping[str, Any]]) -> None:
+ def __post_init__(self, parameters: Mapping[str, Any] | None) -> None:
self._interpolation = JinjaInterpolation()
self._parameters = parameters
- def eval(self, config: Config, **additional_parameters: Any) -> Any:
+ def eval(self, config: Config, **additional_parameters: Any) -> Any: # noqa: ANN401
return self._eval(self.mapping, config, **additional_parameters)
def _eval(
- self, value: Union[NestedMapping, NestedMappingEntry], config: Config, **kwargs: Any
- ) -> Any:
+ self,
+ value: NestedMapping | NestedMappingEntry,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401
+ ) -> Any: # noqa: ANN401
# Recursively interpolate dictionaries and lists
if isinstance(value, str):
return self._interpolation.eval(value, config, parameters=self._parameters, **kwargs)
- elif isinstance(value, dict):
+ if isinstance(value, dict):
interpolated_dict = {
self._eval(k, config, **kwargs): self._eval(v, config, **kwargs)
for k, v in value.items()
}
return {k: v for k, v in interpolated_dict.items() if v is not None}
- elif isinstance(value, list):
+ if isinstance(value, list):
return [self._eval(v, config, **kwargs) for v in value]
- else:
- return value
+ return value
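A minimal sketch of the recursion pattern used by `_eval` above: strings are interpolated, dicts and lists are walked recursively, and dict entries whose interpolated value is None are dropped. `render` is a hypothetical stand-in for the Jinja evaluation.

from typing import Any


def render(value: str, config: dict[str, Any]) -> Any:
    # Stand-in for JinjaInterpolation.eval: resolve a single "{{ key }}" placeholder.
    return config.get(value[2:-2].strip()) if value.startswith("{{") else value


def eval_nested(value: Any, config: dict[str, Any]) -> Any:
    if isinstance(value, str):
        return render(value, config)
    if isinstance(value, dict):
        interpolated = {
            eval_nested(k, config): eval_nested(v, config) for k, v in value.items()
        }
        return {k: v for k, v in interpolated.items() if v is not None}
    if isinstance(value, list):
        return [eval_nested(v, config) for v in value]
    return value


print(eval_nested({"limit": "{{ page_size }}", "missing": "{{ nope }}"}, {"page_size": 50}))
# -> {'limit': 50}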
diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py b/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py
index 542fa8068..55eb72fdf 100644
--- a/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py
+++ b/airbyte_cdk/sources/declarative/interpolation/interpolated_string.py
@@ -2,15 +2,16 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any, Union
from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
from airbyte_cdk.sources.types import Config
@dataclass
-class InterpolatedString:
+class InterpolatedString: # noqa: PLW1641
"""
Wrapper around a raw string to be interpolated with the Jinja2 templating engine
@@ -22,7 +23,7 @@ class InterpolatedString:
string: str
parameters: InitVar[Mapping[str, Any]]
- default: Optional[str] = None
+ default: str | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.default = self.default or self.string
@@ -32,7 +33,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
# This allows for optimization, but we do not know it yet at this stage
self._is_plain_string = None
- def eval(self, config: Config, **kwargs: Any) -> Any:
+ def eval(self, config: Config, **kwargs: Any) -> Any: # noqa: ANN401
"""
        Interpolates the input string using the config and other optional arguments passed as parameters.
@@ -54,7 +55,7 @@ def eval(self, config: Config, **kwargs: Any) -> Any:
self.string, config, self.default, parameters=self._parameters, **kwargs
)
- def __eq__(self, other: Any) -> bool:
+ def __eq__(self, other: object) -> bool:
if not isinstance(other, InterpolatedString):
return False
return self.string == other.string and self.default == other.default
@@ -75,5 +76,4 @@ def create(
"""
if isinstance(string_or_interpolated, str):
return InterpolatedString(string=string_or_interpolated, parameters=parameters)
- else:
- return string_or_interpolated
+ return string_or_interpolated
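Sketch of the `create()` convenience above: a raw string gets wrapped, while an object that is already interpolated is passed through untouched. `Interpolated` is a hypothetical stand-in for InterpolatedString; `parameters` is kept only for signature parity.

from dataclasses import dataclass
from typing import Any


@dataclass
class Interpolated:
    string: str
    default: str | None = None


def create(string_or_interpolated: str | Interpolated, parameters: dict[str, Any]) -> Interpolated:
    # Mirrors the branch above: wrap plain strings only, pass interpolated objects through.
    if isinstance(string_or_interpolated, str):
        return Interpolated(string=string_or_interpolated)
    return string_or_interpolated


assert create("{{ config['base'] }}", {}) == Interpolated(string="{{ config['base'] }}")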
diff --git a/airbyte_cdk/sources/declarative/interpolation/interpolation.py b/airbyte_cdk/sources/declarative/interpolation/interpolation.py
index 5af61905e..2158996f9 100644
--- a/airbyte_cdk/sources/declarative/interpolation/interpolation.py
+++ b/airbyte_cdk/sources/declarative/interpolation/interpolation.py
@@ -3,7 +3,7 @@
#
from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any
from airbyte_cdk.sources.types import Config
@@ -18,9 +18,9 @@ def eval(
self,
input_str: str,
config: Config,
- default: Optional[str] = None,
- **additional_options: Any,
- ) -> Any:
+ default: str | None = None,
+ **additional_options: Any, # noqa: ANN401
+ ) -> Any: # noqa: ANN401
"""
        Interpolates the input string using the config and additional options passed as parameters.
diff --git a/airbyte_cdk/sources/declarative/interpolation/jinja.py b/airbyte_cdk/sources/declarative/interpolation/jinja.py
index 3bb9b0c20..96aeb6005 100644
--- a/airbyte_cdk/sources/declarative/interpolation/jinja.py
+++ b/airbyte_cdk/sources/declarative/interpolation/jinja.py
@@ -3,8 +3,9 @@
#
import ast
+from collections.abc import Mapping
from functools import cache
-from typing import Any, Mapping, Optional, Set, Tuple, Type
+from typing import Any
from jinja2 import meta
from jinja2.environment import Template
@@ -24,8 +25,8 @@ class StreamPartitionAccessEnvironment(SandboxedEnvironment):
parameter
"""
- def is_safe_attribute(self, obj: Any, attr: str, value: Any) -> bool:
- if attr in ["_partition"]:
+ def is_safe_attribute(self, obj: Any, attr: str, value: Any) -> bool: # noqa: ANN401
+ if attr in ["_partition"]: # noqa: FURB171
return True
return super().is_safe_attribute(obj, attr, value) # type: ignore # for some reason, mypy says 'Returning Any from function declared to return "bool"'
@@ -80,10 +81,10 @@ def eval(
self,
input_str: str,
config: Config,
- default: Optional[str] = None,
- valid_types: Optional[Tuple[Type[Any]]] = None,
- **additional_parameters: Any,
- ) -> Any:
+ default: str | None = None,
+ valid_types: tuple[type[Any]] | None = None,
+ **additional_parameters: Any, # noqa: ANN401
+ ) -> Any: # noqa: ANN401
context = {"config": config, **additional_parameters}
for alias, equivalent in _ALIASES.items():
@@ -92,7 +93,7 @@ def eval(
raise ValueError(
f"Found reserved keyword {alias} in interpolation context. This is unexpected and indicative of a bug in the CDK."
)
- elif equivalent in context:
+ if equivalent in context:
context[alias] = context[equivalent]
try:
@@ -102,14 +103,14 @@ def eval(
return self._literal_eval(result, valid_types)
else:
# If input is not a string, return it as is
- raise Exception(f"Expected a string, got {input_str}")
+ raise Exception(f"Expected a string, got {input_str}") # noqa: TRY002, TRY004
except UndefinedError:
pass
# If result is empty or resulted in an undefined error, evaluate and return the default string
return self._literal_eval(self._eval(default, context), valid_types)
- def _literal_eval(self, result: Optional[str], valid_types: Optional[Tuple[Type[Any]]]) -> Any:
+ def _literal_eval(self, result: str | None, valid_types: tuple[type[Any]] | None) -> Any: # noqa: ANN401
try:
evaluated = ast.literal_eval(result) # type: ignore # literal_eval is able to handle None
except (ValueError, SyntaxError):
@@ -118,7 +119,7 @@ def _literal_eval(self, result: Optional[str], valid_types: Optional[Tuple[Type[
return evaluated
return result
- def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]:
+ def _eval(self, s: str | None, context: Mapping[str, Any]) -> str | None:
try:
undeclared = self._find_undeclared_variables(s)
undeclared_not_in_context = {var for var in undeclared if var not in context}
@@ -132,15 +133,15 @@ def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]:
# It can be returned as is
return s
- @cache
- def _find_undeclared_variables(self, s: Optional[str]) -> Set[str]:
+ @cache # noqa: B019
+ def _find_undeclared_variables(self, s: str | None) -> set[str]:
"""
Find undeclared variables and cache them
"""
ast = _ENVIRONMENT.parse(s) # type: ignore # parse is able to handle None
return meta.find_undeclared_variables(ast)
- @cache
+ @cache # noqa: B019
def _compile(self, s: str) -> Template:
"""
We must cache the Jinja Template ourselves because we're using `from_string` instead of a template loader
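Context for the `# noqa: B019` markers above: `functools.cache` on an instance method keys the cache on `self`, so cached templates keep the interpolation object alive for the life of the process, which is accepted here. A minimal illustration of the pattern, not the CDK class:

from functools import cache


class TemplateCompiler:
    @cache  # noqa: B019 -- cache entries include `self`, so instances are never collected
    def compile(self, source: str) -> str:
        return source.upper()  # stand-in for compiling a Jinja Template via from_string


compiler = TemplateCompiler()
assert compiler.compile("{{ config['x'] }}") is compiler.compile("{{ config['x'] }}")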
diff --git a/airbyte_cdk/sources/declarative/interpolation/macros.py b/airbyte_cdk/sources/declarative/interpolation/macros.py
index 1ca5b31f0..a208a1438 100644
--- a/airbyte_cdk/sources/declarative/interpolation/macros.py
+++ b/airbyte_cdk/sources/declarative/interpolation/macros.py
@@ -5,13 +5,13 @@
import builtins
import datetime
import typing
-from typing import Optional, Union
import isodate
import pytz
from dateutil import parser
from isodate import parse_duration
+
"""
This file contains macros that can be evaluated by a `JinjaInterpolation` object
"""
@@ -47,7 +47,7 @@ def today_with_timezone(timezone: str) -> datetime.date:
return datetime.datetime.now(tz=pytz.timezone(timezone)).date()
-def timestamp(dt: Union[float, str]) -> Union[int, float]:
+def timestamp(dt: float | str) -> int | float:
"""
Converts a number or a string to a timestamp
@@ -60,10 +60,9 @@ def timestamp(dt: Union[float, str]) -> Union[int, float]:
:param dt: datetime to convert to timestamp
:return: unix timestamp
"""
- if isinstance(dt, (int, float)):
+ if isinstance(dt, (int, float)): # noqa: UP038
return int(dt)
- else:
- return _str_to_datetime(dt).astimezone(pytz.utc).timestamp()
+ return _str_to_datetime(dt).astimezone(pytz.utc).timestamp()
def _str_to_datetime(s: str) -> datetime.datetime:
@@ -74,7 +73,7 @@ def _str_to_datetime(s: str) -> datetime.datetime:
return parsed_date.astimezone(pytz.utc)
-def max(*args: typing.Any) -> typing.Any:
+def max(*args: typing.Any) -> typing.Any: # noqa: ANN401, A001
"""
    Returns the biggest object of an iterable, or of two or more arguments.
@@ -94,7 +93,7 @@ def max(*args: typing.Any) -> typing.Any:
return builtins.max(*args)
-def min(*args: typing.Any) -> typing.Any:
+def min(*args: typing.Any) -> typing.Any: # noqa: ANN401, A001
"""
    Returns the smallest object of an iterable, or of two or more arguments.
@@ -114,7 +113,7 @@ def min(*args: typing.Any) -> typing.Any:
return builtins.min(*args)
-def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str:
+def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str: # noqa: A002
"""
Returns datetime of now() + num_days
@@ -129,7 +128,7 @@ def day_delta(num_days: int, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> str:
).strftime(format)
-def duration(datestring: str) -> Union[datetime.timedelta, isodate.Duration]:
+def duration(datestring: str) -> datetime.timedelta | isodate.Duration:
"""
Converts ISO8601 duration to datetime.timedelta
@@ -140,7 +139,9 @@ def duration(datestring: str) -> Union[datetime.timedelta, isodate.Duration]:
def format_datetime(
- dt: Union[str, datetime.datetime], format: str, input_format: Optional[str] = None
+ dt: str | datetime.datetime,
+ format: str, # noqa: A002
+ input_format: str | None = None,
) -> str:
"""
Converts datetime to another format
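A quick, hedged illustration of the macro semantics documented above (`timestamp` and `duration`), using the same third-party libraries the hunk imports; the values are examples only, and `ts` is a simplified stand-in for the macro.

import pytz
from dateutil import parser
from isodate import parse_duration


def ts(dt: float | str) -> int | float:
    # Numbers pass through as epoch seconds; strings are parsed and converted to UTC.
    if isinstance(dt, (int, float)):
        return int(dt)
    parsed = parser.parse(dt)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=pytz.utc)
    return parsed.astimezone(pytz.utc).timestamp()


print(ts(1658505815.223))          # -> 1658505815
print(ts("2022-01-01T00:00:00Z"))  # -> 1640995200.0
print(parse_duration("P1D"))       # -> 1 day, 0:00:00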
diff --git a/airbyte_cdk/sources/declarative/manifest_declarative_source.py b/airbyte_cdk/sources/declarative/manifest_declarative_source.py
index deef5a3be..a58ffed72 100644
--- a/airbyte_cdk/sources/declarative/manifest_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/manifest_declarative_source.py
@@ -5,10 +5,11 @@
import json
import logging
import pkgutil
+from collections.abc import Iterator, Mapping
from copy import deepcopy
from importlib import metadata
from types import ModuleType
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Set
+from typing import Any
import yaml
from jsonschema.exceptions import ValidationError
@@ -26,9 +27,6 @@
from airbyte_cdk.sources.declarative.checks import COMPONENTS_CHECKER_TYPE_MAPPING
from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
-from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
- CheckStream as CheckStreamModel,
-)
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
DeclarativeStream as DeclarativeStreamModel,
)
@@ -60,14 +58,14 @@
class ManifestDeclarativeSource(DeclarativeSource):
"""Declarative source defined by a manifest of low-code components that define source connector behavior"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
source_config: ConnectionDefinition,
*,
config: Mapping[str, Any] | None = None,
debug: bool = False,
emit_connector_builder_messages: bool = False,
- component_factory: Optional[ModelToComponentFactory] = None,
+ component_factory: ModelToComponentFactory | None = None,
):
"""
Args:
@@ -94,7 +92,7 @@ def __init__(
self._debug = debug
self._emit_connector_builder_messages = emit_connector_builder_messages
self._constructor = (
- component_factory
+ component_factory # noqa: FURB110
if component_factory
else ModelToComponentFactory(emit_connector_builder_messages)
)
@@ -121,17 +119,16 @@ def connection_checker(self) -> ConnectionChecker:
check_stream = self._constructor.create_component(
COMPONENTS_CHECKER_TYPE_MAPPING[check["type"]],
check,
- dict(),
+ dict(), # noqa: C408
emit_connector_builder_messages=self._emit_connector_builder_messages,
)
if isinstance(check_stream, ConnectionChecker):
return check_stream
- else:
- raise ValueError(
- f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}"
- )
+ raise ValueError(
+ f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}"
+ )
- def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+ def streams(self, config: Mapping[str, Any]) -> list[Stream]:
self._emit_manifest_debug_message(
extra_args={"source_name": self.name, "parsed_config": json.dumps(self._source_config)}
)
@@ -154,8 +151,8 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
@staticmethod
def _initialize_cache_for_parent_streams(
- stream_configs: List[Dict[str, Any]],
- ) -> List[Dict[str, Any]]:
+ stream_configs: list[dict[str, Any]],
+ ) -> list[dict[str, Any]]:
parent_streams = set()
def update_with_cache_parent_configs(parent_configs: list[dict[str, Any]]) -> None:
@@ -204,10 +201,9 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification:
if spec:
if "type" not in spec:
spec["type"] = "Spec"
- spec_component = self._constructor.create_component(SpecModel, spec, dict())
+ spec_component = self._constructor.create_component(SpecModel, spec, dict()) # noqa: C408
return spec_component.generate_spec()
- else:
- return super().spec(logger)
+ return super().spec(logger)
def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
self._configure_logger_level(logger)
@@ -218,7 +214,7 @@ def read(
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
- state: Optional[List[AirbyteStateMessage]] = None,
+ state: list[AirbyteStateMessage] | None = None,
) -> Iterator[AirbyteMessage]:
self._configure_logger_level(logger)
yield from super().read(logger, config, catalog, state)
@@ -247,7 +243,7 @@ def _validate_source(self) -> None:
"Failed to read manifest component json schema required for validation"
)
except FileNotFoundError as e:
- raise FileNotFoundError(
+ raise FileNotFoundError( # noqa: B904
f"Failed to read manifest component json schema required for validation: {e}"
)
@@ -314,9 +310,9 @@ def _parse_version(
# No exception
return parsed_version
- def _stream_configs(self, manifest: Mapping[str, Any]) -> List[Dict[str, Any]]:
+ def _stream_configs(self, manifest: Mapping[str, Any]) -> list[dict[str, Any]]:
# This has a warning flag for static, but after we finish part 4 we'll replace manifest with self._source_config
- stream_configs: List[Dict[str, Any]] = manifest.get("streams", [])
+ stream_configs: list[dict[str, Any]] = manifest.get("streams", [])
for s in stream_configs:
if "type" not in s:
s["type"] = "DeclarativeStream"
@@ -324,10 +320,10 @@ def _stream_configs(self, manifest: Mapping[str, Any]) -> List[Dict[str, Any]]:
def _dynamic_stream_configs(
self, manifest: Mapping[str, Any], config: Mapping[str, Any]
- ) -> List[Dict[str, Any]]:
- dynamic_stream_definitions: List[Dict[str, Any]] = manifest.get("dynamic_streams", [])
- dynamic_stream_configs: List[Dict[str, Any]] = []
- seen_dynamic_streams: Set[str] = set()
+ ) -> list[dict[str, Any]]:
+ dynamic_stream_definitions: list[dict[str, Any]] = manifest.get("dynamic_streams", [])
+ dynamic_stream_configs: list[dict[str, Any]] = []
+ seen_dynamic_streams: set[str] = set()
for dynamic_definition in dynamic_stream_definitions:
components_resolver_config = dynamic_definition["components_resolver"]
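Sketch of the defaulting behaviour in `_stream_configs` above: any stream entry in the manifest that omits "type" is treated as a DeclarativeStream. Standalone approximation with a made-up manifest.

from typing import Any


def stream_configs(manifest: dict[str, Any]) -> list[dict[str, Any]]:
    configs: list[dict[str, Any]] = manifest.get("streams", [])
    for s in configs:
        if "type" not in s:
            s["type"] = "DeclarativeStream"
    return configs


print(stream_configs({"streams": [{"name": "users"}]}))
# -> [{'name': 'users', 'type': 'DeclarativeStream'}]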
diff --git a/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py b/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py
index 830646fe9..03c3ed249 100644
--- a/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py
+++ b/airbyte_cdk/sources/declarative/migrations/legacy_to_per_partition_state_migration.py
@@ -1,6 +1,7 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
@@ -33,7 +34,7 @@ class LegacyToPerPartitionStateMigration(StateMigration):
}
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
partition_router: SubstreamPartitionRouter,
cursor: CustomIncrementalSync | DatetimeBasedCursor,
@@ -79,7 +80,7 @@ def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
}
"""
if stream_state:
- for key, value in stream_state.items():
+ for key, value in stream_state.items(): # noqa: B007, PERF102
if isinstance(value, dict):
keys = list(value.keys())
if len(keys) != 1:
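Rough sketch of the legacy per-partition state shape that `should_migrate` above is probing for: each partition value maps to a dict with exactly one cursor key. Illustrative only; the real check has more branches than the hunk shows.

legacy_state = {  # hypothetical example state, not taken from a real connector
    "13506132": {"last_changed": "2022-12-27T08:34:39+00:00"},
    "14351124": {"last_changed": "2022-12-27T08:35:39+00:00"},
}


def looks_like_legacy(stream_state: dict) -> bool:
    return bool(stream_state) and all(
        isinstance(value, dict) and len(value) == 1 for value in stream_state.values()
    )


assert looks_like_legacy(legacy_state)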
diff --git a/airbyte_cdk/sources/declarative/migrations/state_migration.py b/airbyte_cdk/sources/declarative/migrations/state_migration.py
index 9cf7f3cfe..49bf5aa57 100644
--- a/airbyte_cdk/sources/declarative/migrations/state_migration.py
+++ b/airbyte_cdk/sources/declarative/migrations/state_migration.py
@@ -1,7 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from abc import abstractmethod
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any
class StateMigration:
diff --git a/airbyte_cdk/sources/declarative/models/__init__.py b/airbyte_cdk/sources/declarative/models/__init__.py
index 81f2e2f33..7fcd706df 100644
--- a/airbyte_cdk/sources/declarative/models/__init__.py
+++ b/airbyte_cdk/sources/declarative/models/__init__.py
@@ -1,2 +1,2 @@
# generated by bin/generate_component_manifest_files.py
-from .declarative_component_schema import *
+from .declarative_component_schema import * # noqa: F403
diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py
index fa4a00d18..38a272a69 100644
--- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py
+++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py
@@ -4,7 +4,7 @@
from __future__ import annotations
from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Literal
from pydantic.v1 import BaseModel, Extra, Field
@@ -22,13 +22,13 @@ class BasicHttpAuthenticator(BaseModel):
examples=["{{ config['username'] }}", "{{ config['api_key'] }}"],
title="Username",
)
- password: Optional[str] = Field(
+ password: str | None = Field(
"",
description="The password that will be combined with the username, base64 encoded and used to make requests. Fill it in the user inputs.",
examples=["{{ config['password'] }}", ""],
title="Password",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class BearerAuthenticator(BaseModel):
@@ -39,12 +39,12 @@ class BearerAuthenticator(BaseModel):
examples=["{{ config['api_key'] }}", "{{ config['token'] }}"],
title="Bearer Token",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CheckStream(BaseModel):
type: Literal["CheckStream"]
- stream_names: List[str] = Field(
+ stream_names: list[str] = Field(
...,
description="Names of the streams to try reading from when running a check operation.",
examples=[["users"], ["users", "contacts"]],
@@ -62,31 +62,31 @@ class CheckDynamicStream(BaseModel):
class ConcurrencyLevel(BaseModel):
- type: Optional[Literal["ConcurrencyLevel"]] = None
- default_concurrency: Union[int, str] = Field(
+ type: Literal["ConcurrencyLevel"] | None = None
+ default_concurrency: int | str = Field(
...,
description="The amount of concurrency that will applied during a sync. This value can be hardcoded or user-defined in the config if different users have varying volume thresholds in the target API.",
examples=[10, "{{ config['num_workers'] or 10 }}"],
title="Default Concurrency",
)
- max_concurrency: Optional[int] = Field(
+ max_concurrency: int | None = Field(
None,
description="The maximum level of concurrency that will be used during a sync. This becomes a required field when the default_concurrency derives from the config, because it serves as a safeguard against a user-defined threshold that is too high.",
examples=[20, 100],
title="Max Concurrency",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class ConstantBackoffStrategy(BaseModel):
type: Literal["ConstantBackoffStrategy"]
- backoff_time_in_seconds: Union[float, str] = Field(
+ backoff_time_in_seconds: float | str = Field(
...,
description="Backoff time in seconds.",
examples=[30, 30.5, "{{ config['backoff_time'] }}"],
title="Backoff Time",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CursorPagination(BaseModel):
@@ -101,13 +101,13 @@ class CursorPagination(BaseModel):
],
title="Cursor Value",
)
- page_size: Optional[int] = Field(
+ page_size: int | None = Field(
None,
description="The number of records to include in each pages.",
examples=[100],
title="Page Size",
)
- stop_condition: Optional[str] = Field(
+ stop_condition: str | None = Field(
None,
description="Template string evaluating when to stop paginating.",
examples=[
@@ -116,7 +116,7 @@ class CursorPagination(BaseModel):
],
title="Stop Condition",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomAuthenticator(BaseModel):
@@ -130,7 +130,7 @@ class Config:
examples=["source_railz.components.ShortLivedTokenAuthenticator"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomBackoffStrategy(BaseModel):
@@ -144,7 +144,7 @@ class Config:
examples=["source_railz.components.MyCustomBackoffStrategy"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomErrorHandler(BaseModel):
@@ -158,7 +158,7 @@ class Config:
examples=["source_railz.components.MyCustomErrorHandler"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomIncrementalSync(BaseModel):
@@ -176,7 +176,7 @@ class Config:
...,
description="The location of the value on a record that will be used as a bookmark during sync.",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomPaginationStrategy(BaseModel):
@@ -190,7 +190,7 @@ class Config:
examples=["source_railz.components.MyCustomPaginationStrategy"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomRecordExtractor(BaseModel):
@@ -204,7 +204,7 @@ class Config:
examples=["source_railz.components.MyCustomRecordExtractor"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomRecordFilter(BaseModel):
@@ -218,7 +218,7 @@ class Config:
examples=["source_railz.components.MyCustomCustomRecordFilter"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomRequester(BaseModel):
@@ -232,7 +232,7 @@ class Config:
examples=["source_railz.components.MyCustomRecordExtractor"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomRetriever(BaseModel):
@@ -246,7 +246,7 @@ class Config:
examples=["source_railz.components.MyCustomRetriever"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomPartitionRouter(BaseModel):
@@ -260,7 +260,7 @@ class Config:
examples=["source_railz.components.MyCustomPartitionRouter"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomSchemaLoader(BaseModel):
@@ -274,7 +274,7 @@ class Config:
examples=["source_railz.components.MyCustomSchemaLoader"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomSchemaNormalization(BaseModel):
@@ -290,7 +290,7 @@ class Config:
],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomStateMigration(BaseModel):
@@ -304,7 +304,7 @@ class Config:
examples=["source_railz.components.MyCustomStateMigration"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class CustomTransformation(BaseModel):
@@ -318,14 +318,14 @@ class Config:
examples=["source_railz.components.MyCustomTransformation"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class LegacyToPerPartitionStateMigration(BaseModel):
class Config:
extra = Extra.allow
- type: Optional[Literal["LegacyToPerPartitionStateMigration"]] = None
+ type: Literal["LegacyToPerPartitionStateMigration"] | None = None
class Algorithm(Enum):
@@ -349,19 +349,19 @@ class JwtHeaders(BaseModel):
class Config:
extra = Extra.forbid
- kid: Optional[str] = Field(
+ kid: str | None = Field(
None,
description="Private key ID for user account.",
examples=["{{ config['kid'] }}"],
title="Key Identifier",
)
- typ: Optional[str] = Field(
+ typ: str | None = Field(
"JWT",
description="The media type of the complete JWT.",
examples=["JWT"],
title="Type",
)
- cty: Optional[str] = Field(
+ cty: str | None = Field(
None,
description="Content type of JWT header.",
examples=["JWT"],
@@ -373,18 +373,18 @@ class JwtPayload(BaseModel):
class Config:
extra = Extra.forbid
- iss: Optional[str] = Field(
+ iss: str | None = Field(
None,
description="The user/principal that issued the JWT. Commonly a value unique to the user.",
examples=["{{ config['iss'] }}"],
title="Issuer",
)
- sub: Optional[str] = Field(
+ sub: str | None = Field(
None,
description="The subject of the JWT. Commonly defined by the API.",
title="Subject",
)
- aud: Optional[str] = Field(
+ aud: str | None = Field(
None,
description="The recipient that the JWT is intended for. Commonly defined by the API.",
examples=["appstoreconnect-v1"],
@@ -399,8 +399,8 @@ class JwtAuthenticator(BaseModel):
description="Secret used to sign the JSON web token.",
examples=["{{ config['secret_key'] }}"],
)
- base64_encode_secret_key: Optional[bool] = Field(
- False,
+ base64_encode_secret_key: bool | None = Field(
+ False, # noqa: FBT003
description='When set to true, the secret key will be base64 encoded prior to being encoded as part of the JWT. Only set to "true" when required by the API.',
)
algorithm: Algorithm = Field(
@@ -408,79 +408,79 @@ class JwtAuthenticator(BaseModel):
description="Algorithm used to sign the JSON web token.",
examples=["ES256", "HS256", "RS256", "{{ config['algorithm'] }}"],
)
- token_duration: Optional[int] = Field(
+ token_duration: int | None = Field(
1200,
description="The amount of time in seconds a JWT token can be valid after being issued.",
examples=[1200, 3600],
title="Token Duration",
)
- header_prefix: Optional[str] = Field(
+ header_prefix: str | None = Field(
None,
description="The prefix to be used within the Authentication header.",
examples=["Bearer", "Basic"],
title="Header Prefix",
)
- jwt_headers: Optional[JwtHeaders] = Field(
+ jwt_headers: JwtHeaders | None = Field(
None,
description="JWT headers used when signing JSON web token.",
title="JWT Headers",
)
- additional_jwt_headers: Optional[Dict[str, Any]] = Field(
+ additional_jwt_headers: dict[str, Any] | None = Field(
None,
description="Additional headers to be included with the JWT headers object.",
title="Additional JWT Headers",
)
- jwt_payload: Optional[JwtPayload] = Field(
+ jwt_payload: JwtPayload | None = Field(
None,
description="JWT Payload used when signing JSON web token.",
title="JWT Payload",
)
- additional_jwt_payload: Optional[Dict[str, Any]] = Field(
+ additional_jwt_payload: dict[str, Any] | None = Field(
None,
description="Additional properties to be added to the JWT payload.",
title="Additional JWT Payload Properties",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class RefreshTokenUpdater(BaseModel):
- refresh_token_name: Optional[str] = Field(
+ refresh_token_name: str | None = Field(
"refresh_token",
description="The name of the property which contains the updated refresh token in the response from the token refresh endpoint.",
examples=["refresh_token"],
title="Refresh Token Property Name",
)
- access_token_config_path: Optional[List[str]] = Field(
+ access_token_config_path: list[str] | None = Field(
["credentials", "access_token"],
description="Config path to the access token. Make sure the field actually exists in the config.",
examples=[["credentials", "access_token"], ["access_token"]],
title="Config Path To Access Token",
)
- refresh_token_config_path: Optional[List[str]] = Field(
+ refresh_token_config_path: list[str] | None = Field(
["credentials", "refresh_token"],
description="Config path to the access token. Make sure the field actually exists in the config.",
examples=[["credentials", "refresh_token"], ["refresh_token"]],
title="Config Path To Refresh Token",
)
- token_expiry_date_config_path: Optional[List[str]] = Field(
+ token_expiry_date_config_path: list[str] | None = Field(
["credentials", "token_expiry_date"],
description="Config path to the expiry date. Make sure actually exists in the config.",
examples=[["credentials", "token_expiry_date"]],
title="Config Path To Expiry Date",
)
- refresh_token_error_status_codes: Optional[List[int]] = Field(
+ refresh_token_error_status_codes: list[int] | None = Field(
[],
description="Status Codes to Identify refresh token error in response (Refresh Token Error Key and Refresh Token Error Values should be also specified). Responses with one of the error status code and containing an error value will be flagged as a config error",
examples=[[400, 500]],
title="Refresh Token Error Status Codes",
)
- refresh_token_error_key: Optional[str] = Field(
+ refresh_token_error_key: str | None = Field(
"",
description="Key to Identify refresh token error in response (Refresh Token Error Status Codes and Refresh Token Error Values should be also specified).",
examples=["error"],
title="Refresh Token Error Key",
)
- refresh_token_error_values: Optional[List[str]] = Field(
+ refresh_token_error_values: list[str] | None = Field(
[],
        description='List of values to check for an exception during the token refresh process. Used to check whether the error found in the response matches the key from the Refresh Token Error Key field (e.g. response={"error": "invalid_grant"}). Only responses with one of the error status codes and containing an error value will be flagged as a config error',
examples=[["invalid_grant", "invalid_permissions"]],
@@ -490,7 +490,7 @@ class RefreshTokenUpdater(BaseModel):
class OAuthAuthenticator(BaseModel):
type: Literal["OAuthAuthenticator"]
- client_id_name: Optional[str] = Field(
+ client_id_name: str | None = Field(
"client_id",
description="The name of the property to use to refresh the `access_token`.",
examples=["custom_app_id"],
@@ -502,7 +502,7 @@ class OAuthAuthenticator(BaseModel):
examples=["{{ config['client_id }}", "{{ config['credentials']['client_id }}"],
title="Client ID",
)
- client_secret_name: Optional[str] = Field(
+ client_secret_name: str | None = Field(
"client_secret",
description="The name of the property to use to refresh the `access_token`.",
examples=["custom_app_secret"],
@@ -517,13 +517,13 @@ class OAuthAuthenticator(BaseModel):
],
title="Client Secret",
)
- refresh_token_name: Optional[str] = Field(
+ refresh_token_name: str | None = Field(
"refresh_token",
description="The name of the property to use to refresh the `access_token`.",
examples=["custom_app_refresh_value"],
title="Refresh Token Property Name",
)
- refresh_token: Optional[str] = Field(
+ refresh_token: str | None = Field(
None,
description="Credential artifact used to get a new access token.",
examples=[
@@ -532,43 +532,43 @@ class OAuthAuthenticator(BaseModel):
],
title="Refresh Token",
)
- token_refresh_endpoint: Optional[str] = Field(
+ token_refresh_endpoint: str | None = Field(
None,
description="The full URL to call to obtain a new access token.",
examples=["https://connect.squareup.com/oauth2/token"],
title="Token Refresh Endpoint",
)
- access_token_name: Optional[str] = Field(
+ access_token_name: str | None = Field(
"access_token",
description="The name of the property which contains the access token in the response from the token refresh endpoint.",
examples=["access_token"],
title="Access Token Property Name",
)
- access_token_value: Optional[str] = Field(
+ access_token_value: str | None = Field(
None,
description="The value of the access_token to bypass the token refreshing using `refresh_token`.",
examples=["secret_access_token_value"],
title="Access Token Value",
)
- expires_in_name: Optional[str] = Field(
+ expires_in_name: str | None = Field(
"expires_in",
description="The name of the property which contains the expiry date in the response from the token refresh endpoint.",
examples=["expires_in"],
title="Token Expiry Property Name",
)
- grant_type_name: Optional[str] = Field(
+ grant_type_name: str | None = Field(
"grant_type",
description="The name of the property to use to refresh the `access_token`.",
examples=["custom_grant_type"],
title="Grant Type Property Name",
)
- grant_type: Optional[str] = Field(
+ grant_type: str | None = Field(
"refresh_token",
description="Specifies the OAuth2 grant type. If set to refresh_token, the refresh_token needs to be provided as well. For client_credentials, only client id and secret are required. Other grant types are not officially supported.",
examples=["refresh_token", "client_credentials"],
title="Grant Type",
)
- refresh_request_body: Optional[Dict[str, Any]] = Field(
+ refresh_request_body: dict[str, Any] | None = Field(
None,
description="Body of the request sent to get a new access token.",
examples=[
@@ -580,7 +580,7 @@ class OAuthAuthenticator(BaseModel):
],
title="Refresh Request Body",
)
- refresh_request_headers: Optional[Dict[str, Any]] = Field(
+ refresh_request_headers: dict[str, Any] | None = Field(
None,
description="Headers of the request sent to get a new access token.",
examples=[
@@ -591,35 +591,35 @@ class OAuthAuthenticator(BaseModel):
],
title="Refresh Request Headers",
)
- scopes: Optional[List[str]] = Field(
+ scopes: list[str] | None = Field(
None,
description="List of scopes that should be granted to the access token.",
examples=[["crm.list.read", "crm.objects.contacts.read", "crm.schema.contacts.read"]],
title="Scopes",
)
- token_expiry_date: Optional[str] = Field(
+ token_expiry_date: str | None = Field(
None,
description="The access token expiry date.",
examples=["2023-04-06T07:12:10.421833+00:00", 1680842386],
title="Token Expiry Date",
)
- token_expiry_date_format: Optional[str] = Field(
+ token_expiry_date_format: str | None = Field(
None,
description="The format of the time to expiration datetime. Provide it if the time is returned as a date-time string instead of seconds.",
examples=["%Y-%m-%d %H:%M:%S.%f+00:00"],
title="Token Expiry Date Format",
)
- refresh_token_updater: Optional[RefreshTokenUpdater] = Field(
+ refresh_token_updater: RefreshTokenUpdater | None = Field(
None,
description="When the token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.",
title="Token Updater",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class DpathExtractor(BaseModel):
type: Literal["DpathExtractor"]
- field_path: List[str] = Field(
+ field_path: list[str] = Field(
...,
description='List of potentially nested fields describing the full path of the field to extract. Use "*" to extract all values from an array. See more info in the [docs](https://docs.airbyte.com/connector-development/config-based/understanding-the-yaml-file/record-selector).',
examples=[
@@ -630,23 +630,23 @@ class DpathExtractor(BaseModel):
],
title="Field Path",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class ResponseToFileExtractor(BaseModel):
type: Literal["ResponseToFileExtractor"]
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class ExponentialBackoffStrategy(BaseModel):
type: Literal["ExponentialBackoffStrategy"]
- factor: Optional[Union[float, str]] = Field(
+ factor: float | str | None = Field(
5,
description="Multiplicative constant applied on each retry.",
examples=[5, 5.5, "10"],
title="Factor",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class SessionTokenRequestBearerAuthenticator(BaseModel):
@@ -674,37 +674,37 @@ class FailureType(Enum):
class HttpResponseFilter(BaseModel):
type: Literal["HttpResponseFilter"]
- action: Optional[Action] = Field(
+ action: Action | None = Field(
None,
description="Action to execute if a response matches the filter.",
examples=["SUCCESS", "FAIL", "RETRY", "IGNORE", "RATE_LIMITED"],
title="Action",
)
- failure_type: Optional[FailureType] = Field(
+ failure_type: FailureType | None = Field(
None,
description="Failure type of traced exception if a response matches the filter.",
examples=["system_error", "config_error", "transient_error"],
title="Failure Type",
)
- error_message: Optional[str] = Field(
+ error_message: str | None = Field(
None,
description="Error Message to display if the response matches the filter.",
title="Error Message",
)
- error_message_contains: Optional[str] = Field(
+ error_message_contains: str | None = Field(
None,
description="Match the response if its error message contains the substring.",
example=["This API operation is not enabled for this site"],
title="Error Message Substring",
)
- http_codes: Optional[List[int]] = Field(
+ http_codes: list[int] | None = Field(
None,
description="Match the response if its HTTP code is included in this list.",
examples=[[420, 429], [500]],
title="HTTP Codes",
unique_items=True,
)
- predicate: Optional[str] = Field(
+ predicate: str | None = Field(
None,
description="Match the response if the predicate evaluates to true.",
examples=[
@@ -713,39 +713,39 @@ class HttpResponseFilter(BaseModel):
],
title="Predicate",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class TypesMap(BaseModel):
- target_type: Union[str, List[str]]
- current_type: Union[str, List[str]]
- condition: Optional[str]
+ target_type: str | list[str]
+ current_type: str | list[str]
+ condition: str | None
class SchemaTypeIdentifier(BaseModel):
- type: Optional[Literal["SchemaTypeIdentifier"]] = None
- schema_pointer: Optional[List[str]] = Field(
+ type: Literal["SchemaTypeIdentifier"] | None = None
+ schema_pointer: list[str] | None = Field(
[],
description="List of nested fields defining the schema field path to extract. Defaults to [].",
title="Schema Path",
)
- key_pointer: List[str] = Field(
+ key_pointer: list[str] = Field(
...,
description="List of potentially nested fields describing the full path of the field key to extract.",
title="Key Path",
)
- type_pointer: Optional[List[str]] = Field(
+ type_pointer: list[str] | None = Field(
None,
description="List of potentially nested fields describing the full path of the field type to extract.",
title="Type Path",
)
- types_mapping: Optional[List[TypesMap]] = None
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ types_mapping: list[TypesMap] | None = None
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class InlineSchemaLoader(BaseModel):
type: Literal["InlineSchemaLoader"]
- schema_: Optional[Dict[str, Any]] = Field(
+ schema_: dict[str, Any] | None = Field(
None,
alias="schema",
        description='Describes a stream\'s schema. Refer to the Data Types documentation for more details on which types are valid.',
@@ -755,13 +755,13 @@ class InlineSchemaLoader(BaseModel):
class JsonFileSchemaLoader(BaseModel):
type: Literal["JsonFileSchemaLoader"]
- file_path: Optional[str] = Field(
+ file_path: str | None = Field(
None,
description="Path to the JSON file defining the schema. The path is relative to the connector module's root.",
example=["./schemas/users.json"],
title="File Path",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class JsonDecoder(BaseModel):
@@ -774,27 +774,27 @@ class JsonlDecoder(BaseModel):
class KeysToLower(BaseModel):
type: Literal["KeysToLower"]
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class KeysToSnakeCase(BaseModel):
type: Literal["KeysToSnakeCase"]
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class FlattenFields(BaseModel):
type: Literal["FlattenFields"]
- flatten_lists: Optional[bool] = Field(
- True,
+ flatten_lists: bool | None = Field(
+ True, # noqa: FBT003
description="Whether to flatten lists or leave it as is. Default is True.",
title="Flatten Lists",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class DpathFlattenFields(BaseModel):
type: Literal["DpathFlattenFields"]
- field_path: List[str] = Field(
+ field_path: list[str] = Field(
...,
description="A path to field that needs to be flattened.",
examples=[
@@ -803,12 +803,12 @@ class DpathFlattenFields(BaseModel):
],
title="Field Path",
)
- delete_origin_value: Optional[bool] = Field(
- False,
+ delete_origin_value: bool | None = Field(
+ False, # noqa: FBT003
description="Whether to delete the origin value or keep it. Default is False.",
title="Delete Origin Value",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class KeysReplace(BaseModel):
@@ -835,7 +835,7 @@ class KeysReplace(BaseModel):
],
title="New value",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class IterableDecoder(BaseModel):
@@ -857,7 +857,7 @@ class Config:
examples=["source_amazon_ads.components.GzipJsonlDecoder"],
title="Class Name",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class GzipJsonDecoder(BaseModel):
@@ -865,8 +865,8 @@ class Config:
extra = Extra.allow
type: Literal["GzipJsonDecoder"]
- encoding: Optional[str] = "utf-8"
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ encoding: str | None = "utf-8"
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class MinMaxDatetime(BaseModel):
@@ -877,30 +877,30 @@ class MinMaxDatetime(BaseModel):
examples=["2021-01-01", "2021-01-01T00:00:00Z", "{{ config['start_time'] }}"],
title="Datetime",
)
- datetime_format: Optional[str] = Field(
+ datetime_format: str | None = Field(
"",
description='Format of the datetime value. Defaults to "%Y-%m-%dT%H:%M:%S.%f%z" if left empty. Use placeholders starting with "%" to describe the format the API is using. The following placeholders are available:\n * **%s**: Epoch unix timestamp - `1686218963`\n * **%s_as_float**: Epoch unix timestamp in seconds as float with microsecond precision - `1686218963.123456`\n * **%ms**: Epoch unix timestamp - `1686218963123`\n * **%a**: Weekday (abbreviated) - `Sun`\n * **%A**: Weekday (full) - `Sunday`\n * **%w**: Weekday (decimal) - `0` (Sunday), `6` (Saturday)\n * **%d**: Day of the month (zero-padded) - `01`, `02`, ..., `31`\n * **%b**: Month (abbreviated) - `Jan`\n * **%B**: Month (full) - `January`\n * **%m**: Month (zero-padded) - `01`, `02`, ..., `12`\n * **%y**: Year (without century, zero-padded) - `00`, `01`, ..., `99`\n * **%Y**: Year (with century) - `0001`, `0002`, ..., `9999`\n * **%H**: Hour (24-hour, zero-padded) - `00`, `01`, ..., `23`\n * **%I**: Hour (12-hour, zero-padded) - `01`, `02`, ..., `12`\n * **%p**: AM/PM indicator\n * **%M**: Minute (zero-padded) - `00`, `01`, ..., `59`\n * **%S**: Second (zero-padded) - `00`, `01`, ..., `59`\n * **%f**: Microsecond (zero-padded to 6 digits) - `000000`, `000001`, ..., `999999`\n * **%z**: UTC offset - `(empty)`, `+0000`, `-04:00`\n * **%Z**: Time zone name - `(empty)`, `UTC`, `GMT`\n * **%j**: Day of the year (zero-padded) - `001`, `002`, ..., `366`\n * **%U**: Week number of the year (Sunday as first day) - `00`, `01`, ..., `53`\n * **%W**: Week number of the year (Monday as first day) - `00`, `01`, ..., `53`\n * **%c**: Date and time representation - `Tue Aug 16 21:30:00 1988`\n * **%x**: Date representation - `08/16/1988`\n * **%X**: Time representation - `21:30:00`\n * **%%**: Literal \'%\' character\n\n Some placeholders depend on the locale of the underlying system - in most cases this locale is configured as en/US. For more information see the [Python documentation](https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes).\n',
examples=["%Y-%m-%dT%H:%M:%S.%f%z", "%Y-%m-%d", "%s"],
title="Datetime Format",
)
- max_datetime: Optional[str] = Field(
+ max_datetime: str | None = Field(
None,
description="Ceiling applied on the datetime value. Must be formatted with the datetime_format field.",
examples=["2021-01-01T00:00:00Z", "2021-01-01"],
title="Max Datetime",
)
- min_datetime: Optional[str] = Field(
+ min_datetime: str | None = Field(
None,
description="Floor applied on the datetime value. Must be formatted with the datetime_format field.",
examples=["2010-01-01T00:00:00Z", "2010-01-01"],
title="Min Datetime",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class NoAuth(BaseModel):
type: Literal["NoAuth"]
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class NoPagination(BaseModel):
@@ -928,7 +928,7 @@ class Config:
],
title="Consent URL",
)
- scope: Optional[str] = Field(
+ scope: str | None = Field(
None,
description="The DeclarativeOAuth Specific string of the scopes needed to be grant for authenticated user.",
examples=["user:read user:read_orders workspaces:read"],
@@ -942,7 +942,7 @@ class Config:
],
title="Access Token URL",
)
- access_token_headers: Optional[Dict[str, Any]] = Field(
+ access_token_headers: dict[str, Any] | None = Field(
None,
description="The DeclarativeOAuth Specific optional headers to inject while exchanging the `auth_code` to `access_token` during `completeOAuthFlow` step.",
examples=[
@@ -952,7 +952,7 @@ class Config:
],
title="Access Token Headers",
)
- access_token_params: Optional[Dict[str, Any]] = Field(
+ access_token_params: dict[str, Any] | None = Field(
None,
description="The DeclarativeOAuth Specific optional query parameters to inject while exchanging the `auth_code` to `access_token` during `completeOAuthFlow` step.\nWhen this property is provided, the query params will be encoded as `Json` and included in the outgoing API request.",
examples=[
@@ -964,49 +964,49 @@ class Config:
],
title="Access Token Query Params (Json Encoded)",
)
- extract_output: Optional[List[str]] = Field(
+ extract_output: list[str] | None = Field(
None,
description="The DeclarativeOAuth Specific list of strings to indicate which keys should be extracted and returned back to the input config.",
examples=[["access_token", "refresh_token", "other_field"]],
title="Extract Output",
)
- state: Optional[State] = Field(
+ state: State | None = Field(
None,
description="The DeclarativeOAuth Specific object to provide the criteria of how the `state` query param should be constructed,\nincluding length and complexity.",
examples=[{"min": 7, "max": 128}],
title="Configurable State Query Param",
)
- client_id_key: Optional[str] = Field(
+ client_id_key: str | None = Field(
None,
description="The DeclarativeOAuth Specific optional override to provide the custom `client_id` key name, if required by data-provider.",
examples=["my_custom_client_id_key_name"],
title="Client ID Key Override",
)
- client_secret_key: Optional[str] = Field(
+ client_secret_key: str | None = Field(
None,
description="The DeclarativeOAuth Specific optional override to provide the custom `client_secret` key name, if required by data-provider.",
examples=["my_custom_client_secret_key_name"],
title="Client Secret Key Override",
)
- scope_key: Optional[str] = Field(
+ scope_key: str | None = Field(
None,
description="The DeclarativeOAuth Specific optional override to provide the custom `scope` key name, if required by data-provider.",
examples=["my_custom_scope_key_key_name"],
title="Scopes Key Override",
)
- state_key: Optional[str] = Field(
+ state_key: str | None = Field(
None,
description="The DeclarativeOAuth Specific optional override to provide the custom `state` key name, if required by data-provider.",
examples=["my_custom_state_key_key_name"],
title="State Key Override",
)
- auth_code_key: Optional[str] = Field(
+ auth_code_key: str | None = Field(
None,
description="The DeclarativeOAuth Specific optional override to provide the custom `code` key name to something like `auth_code` or `custom_auth_code`, if required by data-provider.",
examples=["my_custom_auth_code_key_name"],
title="Auth Code Key Override",
)
- redirect_uri_key: Optional[str] = Field(
+ redirect_uri_key: str | None = Field(
None,
description="The DeclarativeOAuth Specific optional override to provide the custom `redirect_uri` key name to something like `callback_uri`, if required by data-provider.",
examples=["my_custom_redirect_uri_key_name"],
@@ -1018,7 +1018,7 @@ class OAuthConfigSpecification(BaseModel):
class Config:
extra = Extra.allow
- oauth_user_input_from_connector_config_specification: Optional[Dict[str, Any]] = Field(
+ oauth_user_input_from_connector_config_specification: dict[str, Any] | None = Field(
None,
description="OAuth specific blob. This is a Json Schema used to validate Json configurations used as input to OAuth.\nMust be a valid non-nested JSON that refers to properties from ConnectorSpecification.connectionSpecification\nusing special annotation 'path_in_connector_config'.\nThese are input values the user is entering through the UI to authenticate to the connector, that might also shared\nas inputs for syncing data via the connector.\nExamples:\nif no connector values is shared during oauth flow, oauth_user_input_from_connector_config_specification=[]\nif connector values such as 'app_id' inside the top level are used to generate the API url for the oauth flow,\n oauth_user_input_from_connector_config_specification={\n app_id: {\n type: string\n path_in_connector_config: ['app_id']\n }\n }\nif connector values such as 'info.app_id' nested inside another object are used to generate the API url for the oauth flow,\n oauth_user_input_from_connector_config_specification={\n app_id: {\n type: string\n path_in_connector_config: ['info', 'app_id']\n }\n }",
examples=[
@@ -1032,12 +1032,12 @@ class Config:
],
title="OAuth user input",
)
- oauth_connector_input_specification: Optional[OauthConnectorInputSpecification] = Field(
+ oauth_connector_input_specification: OauthConnectorInputSpecification | None = Field(
None,
description='The DeclarativeOAuth specific blob.\nPertains to the fields defined by the connector relating to the OAuth flow.\n\nInterpolation capabilities:\n- The variables placeholders are declared as `{{my_var}}`.\n- The nested resolution variables like `{{ {{my_nested_var}} }}` is allowed as well.\n\n- The allowed interpolation context is:\n + base64Encoder - encode to `base64`, {{ {{my_var_a}}:{{my_var_b}} | base64Encoder }}\n + base64Decorer - decode from `base64` encoded string, {{ {{my_string_variable_or_string_value}} | base64Decoder }}\n + urlEncoder - encode the input string to URL-like format, {{ https://test.host.com/endpoint | urlEncoder}}\n + urlDecorer - decode the input url-encoded string into text format, {{ urlDecoder:https%3A%2F%2Fairbyte.io | urlDecoder}}\n + codeChallengeS256 - get the `codeChallenge` encoded value to provide additional data-provider specific authorisation values, {{ {{state_value}} | codeChallengeS256 }}\n\nExamples:\n - The TikTok Marketing DeclarativeOAuth spec:\n {\n "oauth_connector_input_specification": {\n "type": "object",\n "additionalProperties": false,\n "properties": {\n "consent_url": "https://ads.tiktok.com/marketing_api/auth?{{client_id_key}}={{client_id_value}}&{{redirect_uri_key}}={{ {{redirect_uri_value}} | urlEncoder}}&{{state_key}}={{state_value}}",\n "access_token_url": "https://business-api.tiktok.com/open_api/v1.3/oauth2/access_token/",\n "access_token_params": {\n "{{ auth_code_key }}": "{{ auth_code_value }}",\n "{{ client_id_key }}": "{{ client_id_value }}",\n "{{ client_secret_key }}": "{{ client_secret_value }}"\n },\n "access_token_headers": {\n "Content-Type": "application/json",\n "Accept": "application/json"\n },\n "extract_output": ["data.access_token"],\n "client_id_key": "app_id",\n "client_secret_key": "secret",\n "auth_code_key": "auth_code"\n }\n }\n }',
title="DeclarativeOAuth Connector Specification",
)
- complete_oauth_output_specification: Optional[Dict[str, Any]] = Field(
+ complete_oauth_output_specification: dict[str, Any] | None = Field(
None,
description="OAuth specific blob. This is a Json Schema used to validate Json configurations produced by the OAuth flows as they are\nreturned by the distant OAuth APIs.\nMust be a valid JSON describing the fields to merge back to `ConnectorSpecification.connectionSpecification`.\nFor each field, a special annotation `path_in_connector_config` can be specified to determine where to merge it,\nExamples:\n complete_oauth_output_specification={\n refresh_token: {\n type: string,\n path_in_connector_config: ['credentials', 'refresh_token']\n }\n }",
examples=[
@@ -1050,13 +1050,13 @@ class Config:
],
title="OAuth output specification",
)
- complete_oauth_server_input_specification: Optional[Dict[str, Any]] = Field(
+ complete_oauth_server_input_specification: dict[str, Any] | None = Field(
None,
description="OAuth specific blob. This is a Json Schema used to validate Json configurations persisted as Airbyte Server configurations.\nMust be a valid non-nested JSON describing additional fields configured by the Airbyte Instance or Workspace Admins to be used by the\nserver when completing an OAuth flow (typically exchanging an auth code for refresh token).\nExamples:\n complete_oauth_server_input_specification={\n client_id: {\n type: string\n },\n client_secret: {\n type: string\n }\n }",
examples=[{"client_id": {"type": "string"}, "client_secret": {"type": "string"}}],
title="OAuth input specification",
)
- complete_oauth_server_output_specification: Optional[Dict[str, Any]] = Field(
+ complete_oauth_server_output_specification: dict[str, Any] | None = Field(
None,
description="OAuth specific blob. This is a Json Schema used to validate Json configurations persisted as Airbyte Server configurations that\nalso need to be merged back into the connector configuration at runtime.\nThis is a subset configuration of `complete_oauth_server_input_specification` that filters fields out to retain only the ones that\nare necessary for the connector to function with OAuth. (some fields could be used during oauth flows but not needed afterwards, therefore\nthey would be listed in the `complete_oauth_server_input_specification` but not `complete_oauth_server_output_specification`)\nMust be a valid non-nested JSON describing additional fields configured by the Airbyte Instance or Workspace Admins to be used by the\nconnector when using OAuth flow APIs.\nThese fields are to be merged back to `ConnectorSpecification.connectionSpecification`.\nFor each field, a special annotation `path_in_connector_config` can be specified to determine where to merge it,\nExamples:\n complete_oauth_server_output_specification={\n client_id: {\n type: string,\n path_in_connector_config: ['credentials', 'client_id']\n },\n client_secret: {\n type: string,\n path_in_connector_config: ['credentials', 'client_secret']\n }\n }",
examples=[
@@ -1077,44 +1077,44 @@ class Config:
class OffsetIncrement(BaseModel):
type: Literal["OffsetIncrement"]
- page_size: Optional[Union[int, str]] = Field(
+ page_size: int | str | None = Field(
None,
description="The number of records to include in each pages.",
examples=[100, "{{ config['page_size'] }}"],
title="Limit",
)
- inject_on_first_request: Optional[bool] = Field(
- False,
+ inject_on_first_request: bool | None = Field(
+ False, # noqa: FBT003
description="Using the `offset` with value `0` during the first request",
title="Inject Offset",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class PageIncrement(BaseModel):
type: Literal["PageIncrement"]
- page_size: Optional[Union[int, str]] = Field(
+ page_size: int | str | None = Field(
None,
description="The number of records to include in each pages.",
examples=[100, "100", "{{ config['page_size'] }}"],
title="Page Size",
)
- start_from_page: Optional[int] = Field(
+ start_from_page: int | None = Field(
0,
description="Index of the first page to request.",
examples=[0, 1],
title="Start From Page",
)
- inject_on_first_request: Optional[bool] = Field(
- False,
+ inject_on_first_request: bool | None = Field(
+ False, # noqa: FBT003
description="Using the `page number` with value defined by `start_from_page` during the first request",
title="Inject Page Number",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
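For reference, a minimal sketch of how the two generated pagination-strategy models above can be populated, assuming they are imported from `airbyte_cdk.sources.declarative.models.declarative_component_schema` as done elsewhere in this diff (values are illustrative only):

from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
    OffsetIncrement,
    PageIncrement,
)

# Offset-based paging: 100 records per request; the offset value 0 is only
# included on the first request when inject_on_first_request is True (default False).
offset_strategy = OffsetIncrement(type="OffsetIncrement", page_size=100)

# Page-number-based paging: 50 records per page, pages numbered from 1.
page_strategy = PageIncrement(type="PageIncrement", page_size=50, start_from_page=1)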
class PrimaryKey(BaseModel):
- __root__: Union[str, List[str], List[List[str]]] = Field(
+ __root__: str | list[str] | list[list[str]] = Field(
...,
description="The stream field to be used to distinguish unique records. Can either be a single field, an array of fields representing a composite key, or an array of arrays representing a composite key where the fields are nested fields.",
examples=["id", ["code", "type"]],
@@ -1124,7 +1124,7 @@ class PrimaryKey(BaseModel):
class RecordFilter(BaseModel):
type: Literal["RecordFilter"]
- condition: Optional[str] = Field(
+ condition: str | None = Field(
"",
description="The predicate to filter a record. Records will be removed if evaluated to False.",
examples=[
@@ -1132,7 +1132,7 @@ class RecordFilter(BaseModel):
"{{ record.status in ['active', 'expired'] }}",
],
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
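A minimal sketch of the RecordFilter model above: the condition is an interpolated predicate, and records that evaluate to False are dropped (the example condition is taken from the schema examples shown above):

from airbyte_cdk.sources.declarative.models.declarative_component_schema import RecordFilter

# Keep only records whose status is one of the listed values.
status_filter = RecordFilter(
    type="RecordFilter",
    condition="{{ record.status in ['active', 'expired'] }}",
)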
class SchemaNormalization(Enum):
@@ -1142,7 +1142,7 @@ class SchemaNormalization(Enum):
class RemoveFields(BaseModel):
type: Literal["RemoveFields"]
- condition: Optional[str] = Field(
+ condition: str | None = Field(
"",
description="The predicate to filter a property by a property value. Property will be removed if it is empty OR expression is evaluated to True.,",
examples=[
@@ -1152,7 +1152,7 @@ class RemoveFields(BaseModel):
"{{ property == 'some_string_to_match' }}",
],
)
- field_pointers: List[List[str]] = Field(
+ field_pointers: list[list[str]] = Field(
...,
description="Array of paths defining the field to remove. Each item is an array whose field describe the path of a field to remove.",
examples=[["tags"], [["content", "html"], ["content", "plain_text"]]],
@@ -1208,7 +1208,7 @@ class LegacySessionTokenAuthenticator(BaseModel):
examples=["session"],
title="Login Path",
)
- session_token: Optional[str] = Field(
+ session_token: str | None = Field(
None,
description="Session token to use if using a pre-defined token. Not needed if authenticating with username + password pair",
example=["{{ config['session_token'] }}"],
@@ -1220,13 +1220,13 @@ class LegacySessionTokenAuthenticator(BaseModel):
examples=["id"],
title="Response Token Response Key",
)
- username: Optional[str] = Field(
+ username: str | None = Field(
None,
description="Username used to authenticate and obtain a session token",
examples=[" {{ config['username'] }}"],
title="Username",
)
- password: Optional[str] = Field(
+ password: str | None = Field(
"",
description="Password used to authenticate and obtain a session token",
examples=["{{ config['password'] }}", ""],
@@ -1238,31 +1238,31 @@ class LegacySessionTokenAuthenticator(BaseModel):
examples=["user/current"],
title="Validate Session Path",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class JsonParser(BaseModel):
type: Literal["JsonParser"]
- encoding: Optional[str] = "utf-8"
+ encoding: str | None = "utf-8"
class JsonLineParser(BaseModel):
type: Literal["JsonLineParser"]
- encoding: Optional[str] = "utf-8"
+ encoding: str | None = "utf-8"
class CsvParser(BaseModel):
type: Literal["CsvParser"]
- encoding: Optional[str] = "utf-8"
- delimiter: Optional[str] = ","
+ encoding: str | None = "utf-8"
+ delimiter: str | None = ","
class AsyncJobStatusMap(BaseModel):
- type: Optional[Literal["AsyncJobStatusMap"]] = None
- running: List[str]
- completed: List[str]
- failed: List[str]
- timeout: List[str]
+ type: Literal["AsyncJobStatusMap"] | None = None
+ running: list[str]
+ completed: list[str]
+ failed: list[str]
+ timeout: list[str]
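A minimal sketch of the AsyncJobStatusMap model above, mapping provider-specific job statuses onto the CDK's async job states (the status strings are made up for illustration):

from airbyte_cdk.sources.declarative.models.declarative_component_schema import AsyncJobStatusMap

status_map = AsyncJobStatusMap(
    type="AsyncJobStatusMap",
    running=["PENDING", "IN_PROGRESS"],
    completed=["DONE"],
    failed=["ERROR"],
    timeout=["TIMED_OUT"],
)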
class ValueType(Enum):
@@ -1280,19 +1280,19 @@ class WaitTimeFromHeader(BaseModel):
examples=["Retry-After"],
title="Response Header Name",
)
- regex: Optional[str] = Field(
+ regex: str | None = Field(
None,
description="Optional regex to apply on the header to extract its value. The regex should define a capture group defining the wait time.",
examples=["([-+]?\\d+)"],
title="Extraction Regex",
)
- max_waiting_time_in_seconds: Optional[float] = Field(
+ max_waiting_time_in_seconds: float | None = Field(
None,
description="Given the value extracted from the header is greater than this value, stop the stream.",
examples=[3600],
title="Max Waiting Time in Seconds",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
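A small, self-contained illustration of the capture-group contract described by the `regex` field above: the first group of the regex is what is read back as the wait time (stdlib only, not the CDK's actual backoff code):

import re

# A raw header value that embeds the wait time among other tokens.
raw_header_value = "backoff=120;reason=throttled"
match = re.search(r"([-+]?\d+)", raw_header_value)
wait_seconds = float(match.group(1)) if match else None  # 120.0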
class WaitUntilTimeFromHeader(BaseModel):
@@ -1303,24 +1303,24 @@ class WaitUntilTimeFromHeader(BaseModel):
examples=["wait_time"],
title="Response Header",
)
- min_wait: Optional[Union[float, str]] = Field(
+ min_wait: float | str | None = Field(
None,
description="Minimum time to wait before retrying.",
examples=[10, "60"],
title="Minimum Wait Time",
)
- regex: Optional[str] = Field(
+ regex: str | None = Field(
None,
description="Optional regex to apply on the header to extract its value. The regex should define a capture group defining the wait time.",
examples=["([-+]?\\d+)"],
title="Extraction Regex",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class ComponentMappingDefinition(BaseModel):
type: Literal["ComponentMappingDefinition"]
- field_path: List[str] = Field(
+ field_path: list[str] = Field(
...,
description="A list of potentially nested fields indicating the full path where value will be added or updated.",
examples=[
@@ -1345,35 +1345,35 @@ class ComponentMappingDefinition(BaseModel):
],
title="Value",
)
- value_type: Optional[ValueType] = Field(
+ value_type: ValueType | None = Field(
None,
description="The expected data type of the value. If omitted, the type will be inferred from the value provided.",
title="Value Type",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class StreamConfig(BaseModel):
type: Literal["StreamConfig"]
- configs_pointer: List[str] = Field(
+ configs_pointer: list[str] = Field(
...,
description="A list of potentially nested fields indicating the full path in source config file where streams configs located.",
examples=[["data"], ["data", "streams"], ["data", "{{ parameters.name }}"]],
title="Configs Pointer",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class ConfigComponentsResolver(BaseModel):
type: Literal["ConfigComponentsResolver"]
stream_config: StreamConfig
- components_mapping: List[ComponentMappingDefinition]
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ components_mapping: list[ComponentMappingDefinition]
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class AddedFieldDefinition(BaseModel):
type: Literal["AddedFieldDefinition"]
- path: List[str] = Field(
+ path: list[str] = Field(
...,
description="List of strings defining the path where to add the value on the record.",
examples=[["segment_id"], ["metadata", "segment_id"]],
@@ -1389,39 +1389,39 @@ class AddedFieldDefinition(BaseModel):
],
title="Value",
)
- value_type: Optional[ValueType] = Field(
+ value_type: ValueType | None = Field(
None,
description="Type of the value. If not specified, the type will be inferred from the value.",
title="Value Type",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class AddFields(BaseModel):
type: Literal["AddFields"]
- fields: List[AddedFieldDefinition] = Field(
+ fields: list[AddedFieldDefinition] = Field(
...,
description="List of transformations (path and corresponding value) that will be added to the record.",
title="Fields",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
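A minimal sketch of the AddFields / AddedFieldDefinition models above. The `path` and `value_type` fields are visible in this hunk; the `value` keyword is an assumption taken from the "Value" title shown above, since its declaration falls outside the hunk:

from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
    AddedFieldDefinition,
    AddFields,
)

add_segment_id = AddFields(
    type="AddFields",
    fields=[
        AddedFieldDefinition(
            type="AddedFieldDefinition",
            path=["metadata", "segment_id"],
            value="{{ record['segment_id'] }}",  # field name assumed, see note above
        )
    ],
)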
class ApiKeyAuthenticator(BaseModel):
type: Literal["ApiKeyAuthenticator"]
- api_token: Optional[str] = Field(
+ api_token: str | None = Field(
None,
description="The API key to inject in the request. Fill it in the user inputs.",
examples=["{{ config['api_key'] }}", "Token token={{ config['api_key'] }}"],
title="API Key",
)
- header: Optional[str] = Field(
+ header: str | None = Field(
None,
description="The name of the HTTP header that will be set to the API key. This setting is deprecated, use inject_into instead. Header and inject_into can not be defined at the same time.",
examples=["Authorization", "Api-Token", "X-Auth-Token"],
title="Header Name",
)
- inject_into: Optional[RequestOption] = Field(
+ inject_into: RequestOption | None = Field(
None,
description="Configure how the API Key will be sent in requests to the source API. Either inject_into or header has to be defined.",
examples=[
@@ -1430,26 +1430,26 @@ class ApiKeyAuthenticator(BaseModel):
],
title="Inject API Key Into Outgoing HTTP Request",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
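A minimal sketch of the ApiKeyAuthenticator model above, using the deprecated `header` shortcut so the example does not depend on the RequestOption model defined elsewhere in this file:

from airbyte_cdk.sources.declarative.models.declarative_component_schema import ApiKeyAuthenticator

api_key_auth = ApiKeyAuthenticator(
    type="ApiKeyAuthenticator",
    api_token="{{ config['api_key'] }}",
    header="X-Auth-Token",  # prefer inject_into in real manifests; header is deprecated
)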
class AuthFlow(BaseModel):
- auth_flow_type: Optional[AuthFlowType] = Field(
+ auth_flow_type: AuthFlowType | None = Field(
None, description="The type of auth to use", title="Auth flow type"
)
- predicate_key: Optional[List[str]] = Field(
+ predicate_key: list[str] | None = Field(
None,
description="JSON path to a field in the connectorSpecification that should exist for the advanced auth to be applicable.",
examples=[["credentials", "auth_type"]],
title="Predicate key",
)
- predicate_value: Optional[str] = Field(
+ predicate_value: str | None = Field(
None,
description="Value of the predicate_key fields for the advanced auth to be applicable.",
examples=["Oauth"],
title="Predicate value",
)
- oauth_config_specification: Optional[OAuthConfigSpecification] = None
+ oauth_config_specification: OAuthConfigSpecification | None = None
class DatetimeBasedCursor(BaseModel):
@@ -1466,129 +1466,128 @@ class DatetimeBasedCursor(BaseModel):
examples=["%Y-%m-%dT%H:%M:%S.%f%z", "%Y-%m-%d", "%s", "%ms", "%s_as_float"],
title="Outgoing Datetime Format",
)
- start_datetime: Union[str, MinMaxDatetime] = Field(
+ start_datetime: str | MinMaxDatetime = Field(
...,
description="The datetime that determines the earliest record that should be synced.",
examples=["2020-01-1T00:00:00Z", "{{ config['start_time'] }}"],
title="Start Datetime",
)
- cursor_datetime_formats: Optional[List[str]] = Field(
+ cursor_datetime_formats: list[str] | None = Field(
None,
description="The possible formats for the cursor field, in order of preference. The first format that matches the cursor field value will be used to parse it. If not provided, the `datetime_format` will be used.",
title="Cursor Datetime Formats",
)
- cursor_granularity: Optional[str] = Field(
+ cursor_granularity: str | None = Field(
None,
description="Smallest increment the datetime_format has (ISO 8601 duration) that is used to ensure the start of a slice does not overlap with the end of the previous one, e.g. for %Y-%m-%d the granularity should be P1D, for %Y-%m-%dT%H:%M:%SZ the granularity should be PT1S. Given this field is provided, `step` needs to be provided as well.",
examples=["PT1S"],
title="Cursor Granularity",
)
- end_datetime: Optional[Union[str, MinMaxDatetime]] = Field(
+ end_datetime: str | MinMaxDatetime | None = Field(
None,
description="The datetime that determines the last record that should be synced. If not provided, `{{ now_utc() }}` will be used.",
examples=["2021-01-1T00:00:00Z", "{{ now_utc() }}", "{{ day_delta(-1) }}"],
title="End Datetime",
)
- end_time_option: Optional[RequestOption] = Field(
+ end_time_option: RequestOption | None = Field(
None,
description="Optionally configures how the end datetime will be sent in requests to the source API.",
title="Inject End Time Into Outgoing HTTP Request",
)
- is_data_feed: Optional[bool] = Field(
+ is_data_feed: bool | None = Field(
None,
description="A data feed API is an API that does not allow filtering and paginates the content from the most recent to the least recent. Given this, the CDK needs to know when to stop paginating and this field will generate a stop condition for pagination.",
title="Whether the target API is formatted as a data feed",
)
- is_client_side_incremental: Optional[bool] = Field(
+ is_client_side_incremental: bool | None = Field(
None,
description="If the target API endpoint does not take cursor values to filter records and returns all records anyway, the connector with this cursor will filter out records locally, and only emit new records from the last sync, hence incremental. This means that all records would be read from the API, but only new records will be emitted to the destination.",
title="Whether the target API does not support filtering and returns all data (the cursor filters records in the client instead of the API side)",
)
- is_compare_strictly: Optional[bool] = Field(
- False,
+ is_compare_strictly: bool | None = Field(
+ False, # noqa: FBT003
description="Set to True if the target API does not accept queries where the start time equal the end time.",
title="Whether to skip requests if the start time equals the end time",
)
- global_substream_cursor: Optional[bool] = Field(
- False,
+ global_substream_cursor: bool | None = Field(
+ False, # noqa: FBT003
description="This setting optimizes performance when the parent stream has thousands of partitions by storing the cursor as a single value rather than per partition. Notably, the substream state is updated only at the end of the sync, which helps prevent data loss in case of a sync failure. See more info in the [docs](https://docs.airbyte.com/connector-development/config-based/understanding-the-yaml-file/incremental-syncs).",
title="Whether to store cursor as one value instead of per partition",
)
- lookback_window: Optional[str] = Field(
+ lookback_window: str | None = Field(
None,
description="Time interval before the start_datetime to read data for, e.g. P1M for looking back one month.",
examples=["P1D", "P{{ config['lookback_days'] }}D"],
title="Lookback Window",
)
- partition_field_end: Optional[str] = Field(
+ partition_field_end: str | None = Field(
None,
description="Name of the partition start time field.",
examples=["ending_time"],
title="Partition Field End",
)
- partition_field_start: Optional[str] = Field(
+ partition_field_start: str | None = Field(
None,
description="Name of the partition end time field.",
examples=["starting_time"],
title="Partition Field Start",
)
- start_time_option: Optional[RequestOption] = Field(
+ start_time_option: RequestOption | None = Field(
None,
description="Optionally configures how the start datetime will be sent in requests to the source API.",
title="Inject Start Time Into Outgoing HTTP Request",
)
- step: Optional[str] = Field(
+ step: str | None = Field(
None,
description="The size of the time window (ISO8601 duration). Given this field is provided, `cursor_granularity` needs to be provided as well.",
examples=["P1W", "{{ config['step_increment'] }}"],
title="Step",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
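The `step` and `cursor_granularity` descriptions above define how time slices are cut: each slice spans `step`, and `cursor_granularity` is the gap that keeps one slice's end from overlapping the next slice's start. A small illustration of that arithmetic (using `isodate`, already a CDK dependency; this is not the CDK's actual slicer code):

from datetime import datetime, timezone
from isodate import parse_duration

start = datetime(2021, 1, 1, tzinfo=timezone.utc)
step = parse_duration("P1W")          # size of each time window
granularity = parse_duration("PT1S")  # smallest increment of the datetime format

first_slice = (start, start + step - granularity)
second_slice = (start + step, start + 2 * step - granularity)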
class DefaultErrorHandler(BaseModel):
type: Literal["DefaultErrorHandler"]
- backoff_strategies: Optional[
- List[
- Union[
- ConstantBackoffStrategy,
- CustomBackoffStrategy,
- ExponentialBackoffStrategy,
- WaitTimeFromHeader,
- WaitUntilTimeFromHeader,
- ]
+ backoff_strategies: (
+ list[
+ ConstantBackoffStrategy
+ | CustomBackoffStrategy
+ | ExponentialBackoffStrategy
+ | WaitTimeFromHeader
+ | WaitUntilTimeFromHeader
]
- ] = Field(
+ | None
+ ) = Field(
None,
description="List of backoff strategies to use to determine how long to wait before retrying a retryable request.",
title="Backoff Strategies",
)
- max_retries: Optional[int] = Field(
+ max_retries: int | None = Field(
5,
description="The maximum number of time to retry a retryable request before giving up and failing.",
examples=[5, 0, 10],
title="Max Retry Count",
)
- response_filters: Optional[List[HttpResponseFilter]] = Field(
+ response_filters: list[HttpResponseFilter] | None = Field(
None,
description="List of response filters to iterate on when deciding how to handle an error. When using an array of multiple filters, the filters will be applied sequentially and the response will be selected if it matches any of the filter's predicate.",
title="Response Filters",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class DefaultPaginator(BaseModel):
type: Literal["DefaultPaginator"]
- pagination_strategy: Union[
- CursorPagination, CustomPaginationStrategy, OffsetIncrement, PageIncrement
- ] = Field(
+ pagination_strategy: (
+ CursorPagination | CustomPaginationStrategy | OffsetIncrement | PageIncrement
+ ) = Field(
...,
description="Strategy defining how records are paginated.",
title="Pagination Strategy",
)
- page_size_option: Optional[RequestOption] = None
- page_token_option: Optional[Union[RequestOption, RequestPath]] = None
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ page_size_option: RequestOption | None = None
+ page_token_option: RequestOption | RequestPath | None = None
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
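A minimal sketch of how the DefaultPaginator model composes with one of the pagination strategies defined earlier; `page_size_option` and `page_token_option` are left at their `None` defaults to avoid depending on RequestOption here:

from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
    DefaultPaginator,
    OffsetIncrement,
)

paginator = DefaultPaginator(
    type="DefaultPaginator",
    pagination_strategy=OffsetIncrement(type="OffsetIncrement", page_size=100),
)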
class SessionTokenRequestApiKeyAuthenticator(BaseModel):
@@ -1612,55 +1611,55 @@ class ListPartitionRouter(BaseModel):
examples=["section", "{{ config['section_key'] }}"],
title="Current Partition Value Identifier",
)
- values: Union[str, List[str]] = Field(
+ values: str | list[str] = Field(
...,
description="The list of attributes being iterated over and used as input for the requests made to the source API.",
examples=[["section_a", "section_b", "section_c"], "{{ config['sections'] }}"],
title="Partition Values",
)
- request_option: Optional[RequestOption] = Field(
+ request_option: RequestOption | None = Field(
None,
description="A request option describing where the list value should be injected into and under what field name if applicable.",
title="Inject Partition Value Into Outgoing HTTP Request",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class RecordSelector(BaseModel):
type: Literal["RecordSelector"]
- extractor: Union[CustomRecordExtractor, DpathExtractor]
- record_filter: Optional[Union[CustomRecordFilter, RecordFilter]] = Field(
+ extractor: CustomRecordExtractor | DpathExtractor
+ record_filter: CustomRecordFilter | RecordFilter | None = Field(
None,
description="Responsible for filtering records to be emitted by the Source.",
title="Record Filter",
)
- schema_normalization: Optional[Union[SchemaNormalization, CustomSchemaNormalization]] = Field(
+ schema_normalization: SchemaNormalization | CustomSchemaNormalization | None = Field(
SchemaNormalization.None_,
description="Responsible for normalization according to the schema.",
title="Schema Normalization",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class GzipParser(BaseModel):
type: Literal["GzipParser"]
- inner_parser: Union[JsonLineParser, CsvParser, JsonParser]
+ inner_parser: JsonLineParser | CsvParser | JsonParser
class Spec(BaseModel):
type: Literal["Spec"]
- connection_specification: Dict[str, Any] = Field(
+ connection_specification: dict[str, Any] = Field(
...,
description="A connection specification describing how a the connector can be configured.",
title="Connection Specification",
)
- documentation_url: Optional[str] = Field(
+ documentation_url: str | None = Field(
None,
description="URL of the connector's documentation page.",
examples=["https://docs.airbyte.com/integrations/sources/dremio"],
title="Documentation URL",
)
- advanced_auth: Optional[AuthFlow] = Field(
+ advanced_auth: AuthFlow | None = Field(
None,
description="Advanced specification for configuring the authentication flow.",
title="Advanced Auth",
@@ -1669,12 +1668,12 @@ class Spec(BaseModel):
class CompositeErrorHandler(BaseModel):
type: Literal["CompositeErrorHandler"]
- error_handlers: List[Union[CompositeErrorHandler, DefaultErrorHandler]] = Field(
+ error_handlers: list[CompositeErrorHandler | DefaultErrorHandler] = Field(
...,
description="List of error handlers to iterate on to determine how to handle a failed response.",
title="Error Handlers",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class ZipfileDecoder(BaseModel):
@@ -1682,7 +1681,7 @@ class Config:
extra = Extra.allow
type: Literal["ZipfileDecoder"]
- parser: Union[GzipParser, JsonParser, JsonLineParser, CsvParser] = Field(
+ parser: GzipParser | JsonParser | JsonLineParser | CsvParser = Field(
...,
description="Parser to parse the decompressed data from the zipfile(s).",
title="Parser",
@@ -1691,7 +1690,7 @@ class Config:
class CompositeRawDecoder(BaseModel):
type: Literal["CompositeRawDecoder"]
- parser: Union[GzipParser, JsonParser, JsonLineParser, CsvParser]
+ parser: GzipParser | JsonParser | JsonLineParser | CsvParser
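A minimal sketch of how the parser models defined earlier compose with GzipParser and CompositeRawDecoder, e.g. for gzip-compressed JSON Lines responses:

from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
    CompositeRawDecoder,
    GzipParser,
    JsonLineParser,
)

decoder = CompositeRawDecoder(
    type="CompositeRawDecoder",
    parser=GzipParser(
        type="GzipParser",
        inner_parser=JsonLineParser(type="JsonLineParser", encoding="utf-8"),
    ),
)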
class DeclarativeSource1(BaseModel):
@@ -1699,22 +1698,22 @@ class Config:
extra = Extra.forbid
type: Literal["DeclarativeSource"]
- check: Union[CheckStream, CheckDynamicStream]
- streams: List[DeclarativeStream]
- dynamic_streams: Optional[List[DynamicDeclarativeStream]] = None
+ check: CheckStream | CheckDynamicStream
+ streams: list[DeclarativeStream]
+ dynamic_streams: list[DynamicDeclarativeStream] | None = None
version: str = Field(
...,
description="The version of the Airbyte CDK used to build and test the source.",
)
- schemas: Optional[Schemas] = None
- definitions: Optional[Dict[str, Any]] = None
- spec: Optional[Spec] = None
- concurrency_level: Optional[ConcurrencyLevel] = None
- metadata: Optional[Dict[str, Any]] = Field(
+ schemas: Schemas | None = None
+ definitions: dict[str, Any] | None = None
+ spec: Spec | None = None
+ concurrency_level: ConcurrencyLevel | None = None
+ metadata: dict[str, Any] | None = Field(
None,
description="For internal Airbyte use only - DO NOT modify manually. Used by consumers of declarative manifests for storing related metadata.",
)
- description: Optional[str] = Field(
+ description: str | None = Field(
None,
description="A description of the connector. It will be presented on the Source documentation page.",
)
@@ -1725,22 +1724,22 @@ class Config:
extra = Extra.forbid
type: Literal["DeclarativeSource"]
- check: Union[CheckStream, CheckDynamicStream]
- streams: Optional[List[DeclarativeStream]] = None
- dynamic_streams: List[DynamicDeclarativeStream]
+ check: CheckStream | CheckDynamicStream
+ streams: list[DeclarativeStream] | None = None
+ dynamic_streams: list[DynamicDeclarativeStream]
version: str = Field(
...,
description="The version of the Airbyte CDK used to build and test the source.",
)
- schemas: Optional[Schemas] = None
- definitions: Optional[Dict[str, Any]] = None
- spec: Optional[Spec] = None
- concurrency_level: Optional[ConcurrencyLevel] = None
- metadata: Optional[Dict[str, Any]] = Field(
+ schemas: Schemas | None = None
+ definitions: dict[str, Any] | None = None
+ spec: Spec | None = None
+ concurrency_level: ConcurrencyLevel | None = None
+ metadata: dict[str, Any] | None = Field(
None,
description="For internal Airbyte use only - DO NOT modify manually. Used by consumers of declarative manifests for storing related metadata.",
)
- description: Optional[str] = Field(
+ description: str | None = Field(
None,
description="A description of the connector. It will be presented on the Source documentation page.",
)
@@ -1750,7 +1749,7 @@ class DeclarativeSource(BaseModel):
class Config:
extra = Extra.forbid
- __root__: Union[DeclarativeSource1, DeclarativeSource2] = Field(
+ __root__: DeclarativeSource1 | DeclarativeSource2 = Field(
...,
description="An API source that extracts data according to its declarative components.",
title="DeclarativeSource",
@@ -1762,25 +1761,23 @@ class Config:
extra = Extra.allow
type: Literal["SelectiveAuthenticator"]
- authenticator_selection_path: List[str] = Field(
+ authenticator_selection_path: list[str] = Field(
...,
description="Path of the field in config with selected authenticator name",
examples=[["auth"], ["auth", "type"]],
title="Authenticator Selection Path",
)
- authenticators: Dict[
+ authenticators: dict[
str,
- Union[
- ApiKeyAuthenticator,
- BasicHttpAuthenticator,
- BearerAuthenticator,
- CustomAuthenticator,
- OAuthAuthenticator,
- JwtAuthenticator,
- NoAuth,
- SessionTokenAuthenticator,
- LegacySessionTokenAuthenticator,
- ],
+ ApiKeyAuthenticator
+ | BasicHttpAuthenticator
+ | BearerAuthenticator
+ | CustomAuthenticator
+ | OAuthAuthenticator
+ | JwtAuthenticator
+ | NoAuth
+ | SessionTokenAuthenticator
+ | LegacySessionTokenAuthenticator,
] = Field(
...,
description="Authenticators to select from.",
@@ -1795,7 +1792,7 @@ class Config:
],
title="Authenticators",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
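A minimal sketch of the SelectiveAuthenticator model above: the value found at `authenticator_selection_path` in the user config picks one of the named authenticators (only ApiKeyAuthenticator is used here to keep the example self-contained):

from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
    ApiKeyAuthenticator,
    SelectiveAuthenticator,
)

selective_auth = SelectiveAuthenticator(
    type="SelectiveAuthenticator",
    authenticator_selection_path=["credentials", "auth_type"],
    authenticators={
        "api_key": ApiKeyAuthenticator(
            type="ApiKeyAuthenticator",
            api_token="{{ config['credentials']['api_key'] }}",
        ),
    },
)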
class DeclarativeStream(BaseModel):
@@ -1803,58 +1800,52 @@ class Config:
extra = Extra.allow
type: Literal["DeclarativeStream"]
- retriever: Union[AsyncRetriever, CustomRetriever, SimpleRetriever] = Field(
+ retriever: AsyncRetriever | CustomRetriever | SimpleRetriever = Field(
...,
description="Component used to coordinate how records are extracted across stream slices and request pages.",
title="Retriever",
)
- incremental_sync: Optional[Union[CustomIncrementalSync, DatetimeBasedCursor]] = Field(
+ incremental_sync: CustomIncrementalSync | DatetimeBasedCursor | None = Field(
None,
description="Component used to fetch data incrementally based on a time field in the data.",
title="Incremental Sync",
)
- name: Optional[str] = Field("", description="The stream name.", example=["Users"], title="Name")
- primary_key: Optional[PrimaryKey] = Field(
+ name: str | None = Field("", description="The stream name.", example=["Users"], title="Name")
+ primary_key: PrimaryKey | None = Field(
"", description="The primary key of the stream.", title="Primary Key"
)
- schema_loader: Optional[
- Union[
- DynamicSchemaLoader,
- InlineSchemaLoader,
- JsonFileSchemaLoader,
- CustomSchemaLoader,
- ]
- ] = Field(
+ schema_loader: (
+ DynamicSchemaLoader | InlineSchemaLoader | JsonFileSchemaLoader | CustomSchemaLoader | None
+ ) = Field(
None,
description="Component used to retrieve the schema for the current stream.",
title="Schema Loader",
)
- transformations: Optional[
- List[
- Union[
- AddFields,
- CustomTransformation,
- RemoveFields,
- KeysToLower,
- KeysToSnakeCase,
- FlattenFields,
- DpathFlattenFields,
- KeysReplace,
- ]
+ transformations: (
+ list[
+ AddFields
+ | CustomTransformation
+ | RemoveFields
+ | KeysToLower
+ | KeysToSnakeCase
+ | FlattenFields
+ | DpathFlattenFields
+ | KeysReplace
]
- ] = Field(
+ | None
+ ) = Field(
None,
description="A list of transformations to be applied to each output record.",
title="Transformations",
)
- state_migrations: Optional[
- List[Union[LegacyToPerPartitionStateMigration, CustomStateMigration]]
- ] = Field(
- [],
- description="Array of state migrations to be applied on the input state",
- title="State Migrations",
+ state_migrations: list[LegacyToPerPartitionStateMigration | CustomStateMigration] | None = (
+ Field(
+ [],
+ description="Array of state migrations to be applied on the input state",
+ title="State Migrations",
+ )
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class SessionTokenAuthenticator(BaseModel):
@@ -1876,29 +1867,29 @@ class SessionTokenAuthenticator(BaseModel):
],
title="Login Requester",
)
- session_token_path: List[str] = Field(
+ session_token_path: list[str] = Field(
...,
description="The path in the response body returned from the login requester to the session token.",
examples=[["access_token"], ["result", "token"]],
title="Session Token Path",
)
- expiration_duration: Optional[str] = Field(
+ expiration_duration: str | None = Field(
None,
description="The duration in ISO 8601 duration notation after which the session token expires, starting from the time it was obtained. Omitting it will result in the session token being refreshed for every request.",
examples=["PT1H", "P1D"],
title="Expiration Duration",
)
- request_authentication: Union[
- SessionTokenRequestApiKeyAuthenticator, SessionTokenRequestBearerAuthenticator
- ] = Field(
+ request_authentication: (
+ SessionTokenRequestApiKeyAuthenticator | SessionTokenRequestBearerAuthenticator
+ ) = Field(
...,
description="Authentication method to use for requests sent to the API, specifying how to inject the session token.",
title="Data Request Authentication",
)
- decoder: Optional[Union[JsonDecoder, XmlDecoder, CompositeRawDecoder]] = Field(
+ decoder: JsonDecoder | XmlDecoder | CompositeRawDecoder | None = Field(
None, description="Component used to decode the response.", title="Decoder"
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class HttpRequester(BaseModel):
@@ -1922,38 +1913,35 @@ class HttpRequester(BaseModel):
],
title="URL Path",
)
- authenticator: Optional[
- Union[
- ApiKeyAuthenticator,
- BasicHttpAuthenticator,
- BearerAuthenticator,
- CustomAuthenticator,
- OAuthAuthenticator,
- JwtAuthenticator,
- NoAuth,
- SessionTokenAuthenticator,
- LegacySessionTokenAuthenticator,
- SelectiveAuthenticator,
- ]
- ] = Field(
+ authenticator: (
+ ApiKeyAuthenticator
+ | BasicHttpAuthenticator
+ | BearerAuthenticator
+ | CustomAuthenticator
+ | OAuthAuthenticator
+ | JwtAuthenticator
+ | NoAuth
+ | SessionTokenAuthenticator
+ | LegacySessionTokenAuthenticator
+ | SelectiveAuthenticator
+ | None
+ ) = Field(
None,
description="Authentication method to use for requests sent to the API.",
title="Authenticator",
)
- error_handler: Optional[
- Union[DefaultErrorHandler, CustomErrorHandler, CompositeErrorHandler]
- ] = Field(
+ error_handler: DefaultErrorHandler | CustomErrorHandler | CompositeErrorHandler | None = Field(
None,
description="Error handler component that defines how to handle errors.",
title="Error Handler",
)
- http_method: Optional[HttpMethod] = Field(
+ http_method: HttpMethod | None = Field(
HttpMethod.GET,
description="The HTTP method used to fetch data from the source (can be GET or POST).",
examples=["GET", "POST"],
title="HTTP Method",
)
- request_body_data: Optional[Union[str, Dict[str, str]]] = Field(
+ request_body_data: str | dict[str, str] | None = Field(
None,
description="Specifies how to populate the body of the request with a non-JSON payload. Plain text will be sent as is, whereas objects will be converted to a urlencoded form.",
examples=[
@@ -1961,7 +1949,7 @@ class HttpRequester(BaseModel):
],
title="Request Body Payload (Non-JSON)",
)
- request_body_json: Optional[Union[str, Dict[str, Any]]] = Field(
+ request_body_json: str | dict[str, Any] | None = Field(
None,
description="Specifies how to populate the body of the request with a JSON payload. Can contain nested objects.",
examples=[
@@ -1971,13 +1959,13 @@ class HttpRequester(BaseModel):
],
title="Request Body JSON Payload",
)
- request_headers: Optional[Union[str, Dict[str, str]]] = Field(
+ request_headers: str | dict[str, str] | None = Field(
None,
description="Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method.",
examples=[{"Output-Format": "JSON"}, {"Version": "{{ config['version'] }}"}],
title="Request Headers",
)
- request_parameters: Optional[Union[str, Dict[str, str]]] = Field(
+ request_parameters: str | dict[str, str] | None = Field(
None,
description="Specifies the query parameters that should be set on an outgoing HTTP request given the inputs.",
examples=[
@@ -1990,41 +1978,40 @@ class HttpRequester(BaseModel):
],
title="Query Parameters",
)
- use_cache: Optional[bool] = Field(
- False,
+ use_cache: bool | None = Field(
+ False, # noqa: FBT003
description="Enables stream requests caching. This field is automatically set by the CDK.",
title="Use Cache",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class DynamicSchemaLoader(BaseModel):
type: Literal["DynamicSchemaLoader"]
- retriever: Union[AsyncRetriever, CustomRetriever, SimpleRetriever] = Field(
+ retriever: AsyncRetriever | CustomRetriever | SimpleRetriever = Field(
...,
description="Component used to coordinate how records are extracted across stream slices and request pages.",
title="Retriever",
)
- schema_transformations: Optional[
- List[
- Union[
- AddFields,
- CustomTransformation,
- RemoveFields,
- KeysToLower,
- KeysToSnakeCase,
- FlattenFields,
- DpathFlattenFields,
- KeysReplace,
- ]
+ schema_transformations: (
+ list[
+ AddFields
+ | CustomTransformation
+ | RemoveFields
+ | KeysToLower
+ | KeysToSnakeCase
+ | FlattenFields
+ | DpathFlattenFields
+ | KeysReplace
]
- ] = Field(
+ | None
+ ) = Field(
None,
description="A list of transformations to be applied to the schema.",
title="Schema Transformations",
)
schema_type_identifier: SchemaTypeIdentifier
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class ParentStreamConfig(BaseModel):
@@ -2044,22 +2031,22 @@ class ParentStreamConfig(BaseModel):
examples=["parent_id", "{{ config['parent_partition_field'] }}"],
title="Current Parent Key Value Identifier",
)
- request_option: Optional[RequestOption] = Field(
+ request_option: RequestOption | None = Field(
None,
description="A request option describing where the parent key value should be injected into and under what field name if applicable.",
title="Request Option",
)
- incremental_dependency: Optional[bool] = Field(
- False,
+ incremental_dependency: bool | None = Field(
+ False, # noqa: FBT003
description="Indicates whether the parent stream should be read incrementally based on updates in the child stream.",
title="Incremental Dependency",
)
- extra_fields: Optional[List[List[str]]] = Field(
+ extra_fields: list[list[str]] | None = Field(
None,
description="Array of field paths to include as additional fields in the stream slice. Each path is an array of strings representing keys to access fields in the respective parent record. Accessible via `stream_slice.extra_fields`. Missing fields are set to `None`.",
title="Extra Fields",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class SimpleRetriever(BaseModel):
@@ -2068,47 +2055,45 @@ class SimpleRetriever(BaseModel):
...,
description="Component that describes how to extract records from a HTTP response.",
)
- requester: Union[CustomRequester, HttpRequester] = Field(
+ requester: CustomRequester | HttpRequester = Field(
...,
description="Requester component that describes how to prepare HTTP requests to send to the source API.",
)
- paginator: Optional[Union[DefaultPaginator, NoPagination]] = Field(
+ paginator: DefaultPaginator | NoPagination | None = Field(
None,
description="Paginator component that describes how to navigate through the API's pages.",
)
- ignore_stream_slicer_parameters_on_paginated_requests: Optional[bool] = Field(
- False,
+ ignore_stream_slicer_parameters_on_paginated_requests: bool | None = Field(
+ False, # noqa: FBT003
description="If true, the partition router and incremental request options will be ignored when paginating requests. Request options set directly on the requester will not be ignored.",
)
- partition_router: Optional[
- Union[
- CustomPartitionRouter,
- ListPartitionRouter,
- SubstreamPartitionRouter,
- List[Union[CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter]],
- ]
- ] = Field(
+ partition_router: (
+ CustomPartitionRouter
+ | ListPartitionRouter
+ | SubstreamPartitionRouter
+ | list[CustomPartitionRouter | ListPartitionRouter | SubstreamPartitionRouter]
+ | None
+ ) = Field(
[],
description="PartitionRouter component that describes how to partition the stream, enabling incremental syncs and checkpointing.",
title="Partition Router",
)
- decoder: Optional[
- Union[
- CustomDecoder,
- JsonDecoder,
- JsonlDecoder,
- IterableDecoder,
- XmlDecoder,
- GzipJsonDecoder,
- CompositeRawDecoder,
- ZipfileDecoder,
- ]
- ] = Field(
+ decoder: (
+ CustomDecoder
+ | JsonDecoder
+ | JsonlDecoder
+ | IterableDecoder
+ | XmlDecoder
+ | GzipJsonDecoder
+ | CompositeRawDecoder
+ | ZipfileDecoder
+ | None
+ ) = Field(
None,
description="Component decoding the response so records can be extracted.",
title="Decoder",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class AsyncRetriever(BaseModel):
@@ -2120,110 +2105,107 @@ class AsyncRetriever(BaseModel):
status_mapping: AsyncJobStatusMap = Field(
..., description="Async Job Status to Airbyte CDK Async Job Status mapping."
)
- status_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field(
+ status_extractor: CustomRecordExtractor | DpathExtractor = Field(
..., description="Responsible for fetching the actual status of the async job."
)
- urls_extractor: Union[CustomRecordExtractor, DpathExtractor] = Field(
+ urls_extractor: CustomRecordExtractor | DpathExtractor = Field(
...,
description="Responsible for fetching the final result `urls` provided by the completed / finished / ready async job.",
)
- download_extractor: Optional[
- Union[CustomRecordExtractor, DpathExtractor, ResponseToFileExtractor]
- ] = Field(None, description="Responsible for fetching the records from provided urls.")
- creation_requester: Union[CustomRequester, HttpRequester] = Field(
+ download_extractor: CustomRecordExtractor | DpathExtractor | ResponseToFileExtractor | None = (
+ Field(None, description="Responsible for fetching the records from provided urls.")
+ )
+ creation_requester: CustomRequester | HttpRequester = Field(
...,
description="Requester component that describes how to prepare HTTP requests to send to the source API to create the async server-side job.",
)
- polling_requester: Union[CustomRequester, HttpRequester] = Field(
+ polling_requester: CustomRequester | HttpRequester = Field(
...,
description="Requester component that describes how to prepare HTTP requests to send to the source API to fetch the status of the running async job.",
)
- url_requester: Optional[Union[CustomRequester, HttpRequester]] = Field(
+ url_requester: CustomRequester | HttpRequester | None = Field(
None,
description="Requester component that describes how to prepare HTTP requests to send to the source API to extract the url from polling response by the completed async job.",
)
- download_requester: Union[CustomRequester, HttpRequester] = Field(
+ download_requester: CustomRequester | HttpRequester = Field(
...,
description="Requester component that describes how to prepare HTTP requests to send to the source API to download the data provided by the completed async job.",
)
- download_paginator: Optional[Union[DefaultPaginator, NoPagination]] = Field(
+ download_paginator: DefaultPaginator | NoPagination | None = Field(
None,
description="Paginator component that describes how to navigate through the API's pages during download.",
)
- abort_requester: Optional[Union[CustomRequester, HttpRequester]] = Field(
+ abort_requester: CustomRequester | HttpRequester | None = Field(
None,
description="Requester component that describes how to prepare HTTP requests to send to the source API to abort a job once it is timed out from the source's perspective.",
)
- delete_requester: Optional[Union[CustomRequester, HttpRequester]] = Field(
+ delete_requester: CustomRequester | HttpRequester | None = Field(
None,
description="Requester component that describes how to prepare HTTP requests to send to the source API to delete a job once the records are extracted.",
)
- partition_router: Optional[
- Union[
- CustomPartitionRouter,
- ListPartitionRouter,
- SubstreamPartitionRouter,
- List[Union[CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter]],
- ]
- ] = Field(
+ partition_router: (
+ CustomPartitionRouter
+ | ListPartitionRouter
+ | SubstreamPartitionRouter
+ | list[CustomPartitionRouter | ListPartitionRouter | SubstreamPartitionRouter]
+ | None
+ ) = Field(
[],
description="PartitionRouter component that describes how to partition the stream, enabling incremental syncs and checkpointing.",
title="Partition Router",
)
- decoder: Optional[
- Union[
- CustomDecoder,
- JsonDecoder,
- JsonlDecoder,
- IterableDecoder,
- XmlDecoder,
- GzipJsonDecoder,
- CompositeRawDecoder,
- ZipfileDecoder,
- ]
- ] = Field(
+ decoder: (
+ CustomDecoder
+ | JsonDecoder
+ | JsonlDecoder
+ | IterableDecoder
+ | XmlDecoder
+ | GzipJsonDecoder
+ | CompositeRawDecoder
+ | ZipfileDecoder
+ | None
+ ) = Field(
None,
description="Component decoding the response so records can be extracted.",
title="Decoder",
)
- download_decoder: Optional[
- Union[
- CustomDecoder,
- JsonDecoder,
- JsonlDecoder,
- IterableDecoder,
- XmlDecoder,
- GzipJsonDecoder,
- CompositeRawDecoder,
- ZipfileDecoder,
- ]
- ] = Field(
+ download_decoder: (
+ CustomDecoder
+ | JsonDecoder
+ | JsonlDecoder
+ | IterableDecoder
+ | XmlDecoder
+ | GzipJsonDecoder
+ | CompositeRawDecoder
+ | ZipfileDecoder
+ | None
+ ) = Field(
None,
description="Component decoding the download response so records can be extracted.",
title="Download Decoder",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class SubstreamPartitionRouter(BaseModel):
type: Literal["SubstreamPartitionRouter"]
- parent_stream_configs: List[ParentStreamConfig] = Field(
+ parent_stream_configs: list[ParentStreamConfig] = Field(
...,
description="Specifies which parent streams are being iterated over and how parent records should be used to partition the child stream data set.",
title="Parent Stream Configs",
)
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class HttpComponentsResolver(BaseModel):
type: Literal["HttpComponentsResolver"]
- retriever: Union[AsyncRetriever, CustomRetriever, SimpleRetriever] = Field(
+ retriever: AsyncRetriever | CustomRetriever | SimpleRetriever = Field(
...,
description="Component used to coordinate how records are extracted across stream slices and request pages.",
title="Retriever",
)
- components_mapping: List[ComponentMappingDefinition]
- parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
+ components_mapping: list[ComponentMappingDefinition]
+ parameters: dict[str, Any] | None = Field(None, alias="$parameters")
class DynamicDeclarativeStream(BaseModel):
@@ -2231,7 +2213,7 @@ class DynamicDeclarativeStream(BaseModel):
stream_template: DeclarativeStream = Field(
..., description="Reference to the stream template.", title="Stream Template"
)
- components_resolver: Union[HttpComponentsResolver, ConfigComponentsResolver] = Field(
+ components_resolver: HttpComponentsResolver | ConfigComponentsResolver = Field(
...,
description="Component resolve and populates stream templates with components values.",
title="Components Resolver",
diff --git a/airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py b/airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py
index 8a6638fad..2b805b50e 100644
--- a/airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py
+++ b/airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py
@@ -5,9 +5,8 @@
import sys
from collections.abc import Mapping
from types import ModuleType
-from typing import Any, cast
+from typing import Any, Literal, cast
-from typing_extensions import Literal
ChecksumType = Literal["md5", "sha256"]
CHECKSUM_FUNCTIONS = {
@@ -107,10 +106,10 @@ def get_registered_components_module(
# Check for `components` or `source_declarative_manifest.components`.
if SDM_COMPONENTS_MODULE_NAME in sys.modules:
- return cast(ModuleType, sys.modules.get(SDM_COMPONENTS_MODULE_NAME))
+ return cast(ModuleType, sys.modules.get(SDM_COMPONENTS_MODULE_NAME)) # noqa: TC006
if COMPONENTS_MODULE_NAME in sys.modules:
- return cast(ModuleType, sys.modules.get(COMPONENTS_MODULE_NAME))
+ return cast(ModuleType, sys.modules.get(COMPONENTS_MODULE_NAME)) # noqa: TC006
# Could not find module 'components' in `sys.modules`
# and INJECTED_COMPONENTS_PY was not provided in config.
diff --git a/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py b/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py
index f2719bb14..73ce6b526 100644
--- a/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py
+++ b/airbyte_cdk/sources/declarative/parsers/manifest_component_transformer.py
@@ -3,8 +3,9 @@
#
import copy
-import typing
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any
+
PARAMETERS_STR = "$parameters"
@@ -149,7 +150,7 @@ def propagate_types_and_parameters(
)
if excluded_parameter:
current_parameters[field_name] = excluded_parameter
- elif isinstance(field_value, typing.List):
+ elif isinstance(field_value, list):
# We exclude propagating a parameter that matches the current field name because that would result in an infinite cycle
excluded_parameter = current_parameters.pop(field_name, None)
for i, element in enumerate(field_value):
diff --git a/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py b/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py
index 045ea9a2c..9fb80d0c4 100644
--- a/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py
+++ b/airbyte_cdk/sources/declarative/parsers/manifest_reference_resolver.py
@@ -3,13 +3,15 @@
#
import re
-from typing import Any, Mapping, Set, Tuple, Union
+from collections.abc import Mapping
+from typing import Any
from airbyte_cdk.sources.declarative.parsers.custom_exceptions import (
CircularReferenceException,
UndefinedReferenceException,
)
+
REF_TAG = "$ref"
@@ -106,7 +108,7 @@ def preprocess_manifest(self, manifest: Mapping[str, Any]) -> Mapping[str, Any]:
"""
return self._evaluate_node(manifest, manifest, set()) # type: ignore[no-any-return]
- def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: Set[Any]) -> Any:
+ def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: set[Any]) -> Any: # noqa: ANN401
if isinstance(node, dict):
evaluated_dict = {
k: self._evaluate_node(v, manifest, visited)
@@ -118,24 +120,21 @@ def _evaluate_node(self, node: Any, manifest: Mapping[str, Any], visited: Set[An
evaluated_ref = self._evaluate_node(node[REF_TAG], manifest, visited)
if not isinstance(evaluated_ref, dict):
return evaluated_ref
- else:
- # The values defined on the component take precedence over the reference values
- return evaluated_ref | evaluated_dict
- else:
- return evaluated_dict
- elif isinstance(node, list):
+ # The values defined on the component take precedence over the reference values
+ return evaluated_ref | evaluated_dict
+ return evaluated_dict
+ if isinstance(node, list):
return [self._evaluate_node(v, manifest, visited) for v in node]
- elif self._is_ref(node):
+ if self._is_ref(node):
if node in visited:
raise CircularReferenceException(node)
visited.add(node)
ret = self._evaluate_node(self._lookup_ref_value(node, manifest), manifest, visited)
visited.remove(node)
return ret
- else:
- return node
+ return node
- def _lookup_ref_value(self, ref: str, manifest: Mapping[str, Any]) -> Any:
+ def _lookup_ref_value(self, ref: str, manifest: Mapping[str, Any]) -> Any: # noqa: ANN401
ref_match = re.match(r"#/(.*)", ref)
if not ref_match:
raise ValueError(f"Invalid reference format {ref}")
@@ -143,10 +142,10 @@ def _lookup_ref_value(self, ref: str, manifest: Mapping[str, Any]) -> Any:
path = ref_match.groups()[0]
return self._read_ref_value(path, manifest)
except (AttributeError, KeyError, IndexError):
- raise UndefinedReferenceException(path, ref)
+ raise UndefinedReferenceException(path, ref) # noqa: B904
@staticmethod
- def _is_ref(node: Any) -> bool:
+ def _is_ref(node: Any) -> bool: # noqa: ANN401
return isinstance(node, str) and node.startswith("#/")
@staticmethod
@@ -154,7 +153,7 @@ def _is_ref_key(key: str) -> bool:
return bool(key == REF_TAG)
@staticmethod
- def _read_ref_value(ref: str, manifest_node: Mapping[str, Any]) -> Any:
+ def _read_ref_value(ref: str, manifest_node: Mapping[str, Any]) -> Any: # noqa: ANN401
"""
Read the value at the referenced location of the manifest.
@@ -185,7 +184,7 @@ def _read_ref_value(ref: str, manifest_node: Mapping[str, Any]) -> Any:
return manifest_node
-def _parse_path(ref: str) -> Tuple[Union[str, int], str]:
+def _parse_path(ref: str) -> tuple[str | int, str]:
"""
Return the next path component, together with the rest of the path.
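To make the `$ref` semantics above concrete, here is a minimal sketch of resolving a reference, assuming the resolver class in this module is exposed as `ManifestReferenceResolver` (manifest keys are illustrative only):

from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import (
    ManifestReferenceResolver,
)

manifest = {
    "definitions": {
        "requester": {"type": "HttpRequester", "url_base": "https://api.example.com"},
    },
    "retriever": {
        "$ref": "#/definitions/requester",
        # Values defined on the component take precedence over the referenced values.
        "url_base": "https://api.example.com/v2",
    },
}

resolved = ManifestReferenceResolver().preprocess_manifest(manifest)
# resolved["retriever"] == {"type": "HttpRequester", "url_base": "https://api.example.com/v2"}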
diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
index c39ae0a68..dac53a066 100644
--- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
+++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
@@ -8,31 +8,23 @@
import importlib
import inspect
import re
-import sys
+from collections.abc import Callable, Mapping, MutableMapping
from functools import partial
from typing import (
Any,
- Callable,
- Dict,
- List,
- Mapping,
- MutableMapping,
- Optional,
- Type,
- Union,
get_args,
get_origin,
get_type_hints,
)
from isodate import parse_duration
-from pydantic.v1 import BaseModel
+from pydantic.v1 import BaseModel # noqa: TC002
from airbyte_cdk.models import FailureType, Level
-from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
+from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager # noqa: TC001
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker
-from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository
+from airbyte_cdk.sources.declarative.async_job.repository import AsyncJobRepository # noqa: TC001
from airbyte_cdk.sources.declarative.async_job.status import AsyncJobStatus
from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator, JwtAuthenticator
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import (
@@ -189,7 +181,7 @@
CustomRetriever as CustomRetrieverModel,
)
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
- CustomSchemaLoader as CustomSchemaLoader,
+ CustomSchemaLoader as CustomSchemaLoader, # noqa: PLC0414
)
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
CustomSchemaNormalization as CustomSchemaNormalizationModel,
@@ -364,10 +356,6 @@
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
ZipfileDecoder as ZipfileDecoderModel,
)
-from airbyte_cdk.sources.declarative.parsers.custom_code_compiler import (
- COMPONENTS_MODULE_NAME,
- SDM_COMPONENTS_MODULE_NAME,
-)
from airbyte_cdk.sources.declarative.partition_routers import (
CartesianProductStreamSlicer,
ListPartitionRouter,
@@ -435,7 +423,7 @@
TypesMap,
)
from airbyte_cdk.sources.declarative.spec import Spec
-from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer
+from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer # noqa: TC001
from airbyte_cdk.sources.declarative.transformations import (
AddFields,
RecordTransformation,
@@ -468,9 +456,10 @@
DateTimeStreamStateConverter,
)
from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction
-from airbyte_cdk.sources.types import Config
+from airbyte_cdk.sources.types import Config # noqa: TC001
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
+
ComponentDefinition = Mapping[str, Any]
SCHEMA_TRANSFORMER_TYPE_MAPPING = {
@@ -479,18 +468,18 @@
}
-class ModelToComponentFactory:
+class ModelToComponentFactory: # noqa: PLR0904
EPOCH_DATETIME_FORMAT = "%s"
- def __init__(
+ def __init__( # noqa: ANN204
self,
- limit_pages_fetched_per_slice: Optional[int] = None,
- limit_slices_fetched: Optional[int] = None,
- emit_connector_builder_messages: bool = False,
- disable_retries: bool = False,
- disable_cache: bool = False,
- disable_resumable_full_refresh: bool = False,
- message_repository: Optional[MessageRepository] = None,
+ limit_pages_fetched_per_slice: int | None = None,
+ limit_slices_fetched: int | None = None,
+ emit_connector_builder_messages: bool = False, # noqa: FBT001, FBT002
+ disable_retries: bool = False, # noqa: FBT001, FBT002
+ disable_cache: bool = False, # noqa: FBT001, FBT002
+ disable_resumable_full_refresh: bool = False, # noqa: FBT001, FBT002
+ message_repository: MessageRepository | None = None,
):
self._init_mappings()
self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice
@@ -504,7 +493,7 @@ def __init__(
)
def _init_mappings(self) -> None:
- self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[Type[BaseModel], Callable[..., Any]] = {
+ self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[type[BaseModel], Callable[..., Any]] = {
AddedFieldDefinitionModel: self.create_added_field_definition,
AddFieldsModel: self.create_add_fields,
ApiKeyAuthenticatorModel: self.create_api_key_authenticator,
@@ -595,11 +584,11 @@ def _init_mappings(self) -> None:
def create_component(
self,
- model_type: Type[BaseModel],
+ model_type: type[BaseModel],
component_definition: ComponentDefinition,
config: Config,
- **kwargs: Any,
- ) -> Any:
+ **kwargs: Any, # noqa: ANN401
+ ) -> Any: # noqa: ANN401
"""
Takes a given Pydantic model type and Mapping representing a component definition and creates a declarative component and
subcomponents which will be used at runtime. This is done by first parsing the mapping into a Pydantic model and then creating
@@ -620,7 +609,7 @@ def create_component(
declarative_component_model = model_type.parse_obj(component_definition)
if not isinstance(declarative_component_model, model_type):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"Expected {model_type.__name__} component, but received {declarative_component_model.__class__.__name__}"
)
@@ -628,7 +617,7 @@ def create_component(
model=declarative_component_model, config=config, **kwargs
)
- def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs: Any) -> Any:
+ def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs: Any) -> Any: # noqa: ANN401
if model.__class__ not in self.PYDANTIC_MODEL_TO_CONSTRUCTOR:
raise ValueError(
f"{model.__class__} with attributes {model} is not a valid component type"
@@ -640,7 +629,9 @@ def _create_component_from_model(self, model: BaseModel, config: Config, **kwarg
@staticmethod
def create_added_field_definition(
- model: AddedFieldDefinitionModel, config: Config, **kwargs: Any
+ model: AddedFieldDefinitionModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> AddedFieldDefinition:
interpolated_value = InterpolatedString.create(
model.value, parameters=model.parameters or {}
@@ -652,7 +643,7 @@ def create_added_field_definition(
parameters=model.parameters or {},
)
- def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any) -> AddFields:
+ def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any) -> AddFields: # noqa: ANN401, ARG002
added_field_definitions = [
self._create_component_from_model(
model=added_field_definition_model,
@@ -666,33 +657,48 @@ def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any
return AddFields(fields=added_field_definitions, parameters=model.parameters or {})
def create_keys_to_lower_transformation(
- self, model: KeysToLowerModel, config: Config, **kwargs: Any
+ self,
+ model: KeysToLowerModel, # noqa: ARG002
+ config: Config, # noqa: ARG002
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> KeysToLowerTransformation:
return KeysToLowerTransformation()
def create_keys_to_snake_transformation(
- self, model: KeysToSnakeCaseModel, config: Config, **kwargs: Any
+ self,
+ model: KeysToSnakeCaseModel, # noqa: ARG002
+ config: Config, # noqa: ARG002
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> KeysToSnakeCaseTransformation:
return KeysToSnakeCaseTransformation()
def create_keys_replace_transformation(
- self, model: KeysReplaceModel, config: Config, **kwargs: Any
+ self,
+ model: KeysReplaceModel,
+ config: Config, # noqa: ARG002
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> KeysReplaceTransformation:
return KeysReplaceTransformation(
old=model.old, new=model.new, parameters=model.parameters or {}
)
def create_flatten_fields(
- self, model: FlattenFieldsModel, config: Config, **kwargs: Any
+ self,
+ model: FlattenFieldsModel,
+ config: Config, # noqa: ARG002
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> FlattenFields:
return FlattenFields(
flatten_lists=model.flatten_lists if model.flatten_lists is not None else True
)
def create_dpath_flatten_fields(
- self, model: DpathFlattenFieldsModel, config: Config, **kwargs: Any
+ self,
+ model: DpathFlattenFieldsModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> DpathFlattenFields:
- model_field_path: List[Union[InterpolatedString, str]] = [x for x in model.field_path]
+ model_field_path: list[InterpolatedString | str] = [x for x in model.field_path]
return DpathFlattenFields(
config=config,
field_path=model_field_path,
@@ -703,7 +709,7 @@ def create_dpath_flatten_fields(
)
@staticmethod
- def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[Type[Any]]:
+ def _json_schema_type_name_to_type(value_type: ValueType | None) -> type[Any] | None:
if not value_type:
return None
names_to_types = {
@@ -718,8 +724,8 @@ def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[
def create_api_key_authenticator(
model: ApiKeyAuthenticatorModel,
config: Config,
- token_provider: Optional[TokenProvider] = None,
- **kwargs: Any,
+ token_provider: TokenProvider | None = None,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> ApiKeyAuthenticator:
if model.inject_into is None and model.header is None:
raise ValueError(
@@ -731,7 +737,7 @@ def create_api_key_authenticator(
"inject_into and header cannot be set both for ApiKeyAuthenticator - remove the deprecated header option"
)
- if token_provider is not None and model.api_token != "":
+ if token_provider is not None and model.api_token != "": # noqa: PLC1901
raise ValueError(
"If token_provider is set, api_token is ignored and has to be set to empty string."
)
@@ -766,20 +772,20 @@ def create_api_key_authenticator(
def create_legacy_to_per_partition_state_migration(
self,
- model: LegacyToPerPartitionStateMigrationModel,
+ model: LegacyToPerPartitionStateMigrationModel, # noqa: ARG002
config: Mapping[str, Any],
declarative_stream: DeclarativeStreamModel,
) -> LegacyToPerPartitionStateMigration:
retriever = declarative_stream.retriever
if not isinstance(retriever, SimpleRetrieverModel):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"LegacyToPerPartitionStateMigrations can only be applied on a DeclarativeStream with a SimpleRetriever. Got {type(retriever)}"
)
partition_router = retriever.partition_router
- if not isinstance(
+ if not isinstance( # noqa: UP038
partition_router, (SubstreamPartitionRouterModel, CustomPartitionRouterModel)
):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"LegacyToPerPartitionStateMigrations can only be applied on a SimpleRetriever with a Substream partition router. Got {type(partition_router)}"
)
if not hasattr(partition_router, "parent_stream_configs"):
@@ -800,8 +806,12 @@ def create_legacy_to_per_partition_state_migration(
)
def create_session_token_authenticator(
- self, model: SessionTokenAuthenticatorModel, config: Config, name: str, **kwargs: Any
- ) -> Union[ApiKeyAuthenticator, BearerAuthenticator]:
+ self,
+ model: SessionTokenAuthenticatorModel,
+ config: Config,
+ name: str,
+ **kwargs: Any, # noqa: ANN401, ARG002
+ ) -> ApiKeyAuthenticator | BearerAuthenticator:
decoder = (
self._create_component_from_model(model=model.decoder, config=config)
if model.decoder
@@ -829,20 +839,21 @@ def create_session_token_authenticator(
config,
token_provider=token_provider,
)
- else:
- return ModelToComponentFactory.create_api_key_authenticator(
- ApiKeyAuthenticatorModel(
- type="ApiKeyAuthenticator",
- api_token="",
- inject_into=model.request_authentication.inject_into,
- ), # type: ignore # $parameters and headers default to None
- config=config,
- token_provider=token_provider,
- )
+ return ModelToComponentFactory.create_api_key_authenticator(
+ ApiKeyAuthenticatorModel(
+ type="ApiKeyAuthenticator",
+ api_token="",
+ inject_into=model.request_authentication.inject_into,
+ ), # type: ignore # $parameters and headers default to None
+ config=config,
+ token_provider=token_provider,
+ )
@staticmethod
def create_basic_http_authenticator(
- model: BasicHttpAuthenticatorModel, config: Config, **kwargs: Any
+ model: BasicHttpAuthenticatorModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> BasicHttpAuthenticator:
return BasicHttpAuthenticator(
password=model.password or "",
@@ -855,10 +866,10 @@ def create_basic_http_authenticator(
def create_bearer_authenticator(
model: BearerAuthenticatorModel,
config: Config,
- token_provider: Optional[TokenProvider] = None,
- **kwargs: Any,
+ token_provider: TokenProvider | None = None,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> BearerAuthenticator:
- if token_provider is not None and model.api_token != "":
+ if token_provider is not None and model.api_token != "": # noqa: PLC1901
raise ValueError(
"If token_provider is set, api_token is ignored and has to be set to empty string."
)
@@ -877,17 +888,22 @@ def create_bearer_authenticator(
)
@staticmethod
- def create_check_stream(model: CheckStreamModel, config: Config, **kwargs: Any) -> CheckStream:
+ def create_check_stream(model: CheckStreamModel, config: Config, **kwargs: Any) -> CheckStream: # noqa: ANN401, ARG004
return CheckStream(stream_names=model.stream_names, parameters={})
@staticmethod
def create_check_dynamic_stream(
- model: CheckDynamicStreamModel, config: Config, **kwargs: Any
+ model: CheckDynamicStreamModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> CheckDynamicStream:
return CheckDynamicStream(stream_count=model.stream_count, parameters={})
def create_composite_error_handler(
- self, model: CompositeErrorHandlerModel, config: Config, **kwargs: Any
+ self,
+ model: CompositeErrorHandlerModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> CompositeErrorHandler:
error_handlers = [
self._create_component_from_model(model=error_handler_model, config=config)
@@ -899,7 +915,9 @@ def create_composite_error_handler(
@staticmethod
def create_concurrency_level(
- model: ConcurrencyLevelModel, config: Config, **kwargs: Any
+ model: ConcurrencyLevelModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> ConcurrencyLevel:
return ConcurrencyLevel(
default_concurrency=model.default_concurrency,
@@ -908,16 +926,16 @@ def create_concurrency_level(
parameters={},
)
- def create_concurrent_cursor_from_datetime_based_cursor(
+ def create_concurrent_cursor_from_datetime_based_cursor( # noqa: PLR0914
self,
state_manager: ConnectorStateManager,
- model_type: Type[BaseModel],
+ model_type: type[BaseModel],
component_definition: ComponentDefinition,
stream_name: str,
- stream_namespace: Optional[str],
+ stream_namespace: str | None,
config: Config,
stream_state: MutableMapping[str, Any],
- **kwargs: Any,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> ConcurrentCursor:
component_type = component_definition.get("type")
if component_definition.get("type") != model_type.__name__:
@@ -928,7 +946,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
datetime_based_cursor_model = model_type.parse_obj(component_definition)
if not isinstance(datetime_based_cursor_model, DatetimeBasedCursorModel):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"Expected {model_type.__name__} component, but received {datetime_based_cursor_model.__class__.__name__}"
)
@@ -982,7 +1000,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
cursor_granularity=cursor_granularity,
)
- start_date_runtime_value: Union[InterpolatedString, str, MinMaxDatetime]
+ start_date_runtime_value: InterpolatedString | str | MinMaxDatetime
if isinstance(datetime_based_cursor_model.start_datetime, MinMaxDatetimeModel):
start_date_runtime_value = self.create_min_max_datetime(
model=datetime_based_cursor_model.start_datetime, config=config
@@ -990,7 +1008,7 @@ def create_concurrent_cursor_from_datetime_based_cursor(
else:
start_date_runtime_value = datetime_based_cursor_model.start_datetime
- end_date_runtime_value: Optional[Union[InterpolatedString, str, MinMaxDatetime]]
+ end_date_runtime_value: InterpolatedString | str | MinMaxDatetime | None
if isinstance(datetime_based_cursor_model.end_datetime, MinMaxDatetimeModel):
end_date_runtime_value = self.create_min_max_datetime(
model=datetime_based_cursor_model.end_datetime, config=config
@@ -1066,7 +1084,9 @@ def create_concurrent_cursor_from_datetime_based_cursor(
@staticmethod
def create_constant_backoff_strategy(
- model: ConstantBackoffStrategyModel, config: Config, **kwargs: Any
+ model: ConstantBackoffStrategyModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> ConstantBackoffStrategy:
return ConstantBackoffStrategy(
backoff_time_in_seconds=model.backoff_time_in_seconds,
@@ -1075,7 +1095,11 @@ def create_constant_backoff_strategy(
)
def create_cursor_pagination(
- self, model: CursorPaginationModel, config: Config, decoder: Decoder, **kwargs: Any
+ self,
+ model: CursorPaginationModel,
+ config: Config,
+ decoder: Decoder,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> CursorPaginationStrategy:
if isinstance(decoder, PaginationDecoderDecorator):
inner_decoder = decoder.decoder
@@ -1099,7 +1123,7 @@ def create_cursor_pagination(
parameters=model.parameters or {},
)
- def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> Any:
+ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> Any: # noqa: ANN401
"""
Generically creates a custom component based on the model type and a class_name reference to the custom Python class being
instantiated. Only the model's additional properties that match the custom class definition are passed to the constructor
@@ -1154,7 +1178,7 @@ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) ->
kwargs = {
class_field: model_args[class_field]
- for class_field in component_fields.keys()
+ for class_field in component_fields.keys() # noqa: SIM118
if class_field in model_args
}
return custom_component_class(**kwargs)
@@ -1162,7 +1186,7 @@ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) ->
@staticmethod
def _get_class_from_fully_qualified_class_name(
full_qualified_class_name: str,
- ) -> Any:
+ ) -> Any: # noqa: ANN401
"""Get a class from its fully qualified name.
If a custom components module is needed, we assume it is already registered - probably
@@ -1194,7 +1218,7 @@ def _get_class_from_fully_qualified_class_name(
) from e
@staticmethod
- def _derive_component_type_from_type_hints(field_type: Any) -> Optional[str]:
+ def _derive_component_type_from_type_hints(field_type: Any) -> str | None: # noqa: ANN401
interface = field_type
while True:
origin = get_origin(interface)
@@ -1211,22 +1235,25 @@ def _derive_component_type_from_type_hints(field_type: Any) -> Optional[str]:
return None
@staticmethod
- def is_builtin_type(cls: Optional[Type[Any]]) -> bool:
+ def is_builtin_type(cls: type[Any] | None) -> bool: # noqa: PLW0211
if not cls:
return False
return cls.__module__ == "builtins"
@staticmethod
- def _extract_missing_parameters(error: TypeError) -> List[str]:
+ def _extract_missing_parameters(error: TypeError) -> list[str]:
parameter_search = re.search(r"keyword-only.*:\s(.*)", str(error))
if parameter_search:
return re.findall(r"\'(.+?)\'", parameter_search.group(1))
- else:
- return []
+ return []
def _create_nested_component(
- self, model: Any, model_field: str, model_value: Any, config: Config
- ) -> Any:
+ self,
+ model: Any, # noqa: ANN401
+ model_field: str, # noqa: ARG002
+ model_value: Any, # noqa: ANN401
+ config: Config,
+ ) -> Any: # noqa: ANN401
type_name = model_value.get("type", None)
if not type_name:
# If no type is specified, we can assume this is a dictionary object which can be returned instead of a subcomponent
@@ -1256,16 +1283,14 @@ def _create_nested_component(
except TypeError as error:
missing_parameters = self._extract_missing_parameters(error)
if missing_parameters:
- raise ValueError(
+ raise ValueError( # noqa: B904
f"Error creating component '{type_name}' with parent custom component {model.class_name}: Please provide "
+ ", ".join(
- (
- f"{type_name}.$parameters.{parameter}"
- for parameter in missing_parameters
- )
+ f"{type_name}.$parameters.{parameter}"
+ for parameter in missing_parameters
)
)
- raise TypeError(
+ raise TypeError( # noqa: B904
f"Error creating component '{type_name}' with parent custom component {model.class_name}: {error}"
)
else:
@@ -1274,18 +1299,21 @@ def _create_nested_component(
)
@staticmethod
- def _is_component(model_value: Any) -> bool:
+ def _is_component(model_value: Any) -> bool: # noqa: ANN401
return isinstance(model_value, dict) and model_value.get("type") is not None
def create_datetime_based_cursor(
- self, model: DatetimeBasedCursorModel, config: Config, **kwargs: Any
+ self,
+ model: DatetimeBasedCursorModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> DatetimeBasedCursor:
- start_datetime: Union[str, MinMaxDatetime] = (
+ start_datetime: str | MinMaxDatetime = (
model.start_datetime
if isinstance(model.start_datetime, str)
else self.create_min_max_datetime(model.start_datetime, config)
)
- end_datetime: Union[str, MinMaxDatetime, None] = None
+ end_datetime: str | MinMaxDatetime | None = None
if model.is_data_feed and model.end_datetime:
raise ValueError("Data feed does not support end_datetime")
if model.is_data_feed and model.is_client_side_incremental:
@@ -1320,7 +1348,7 @@ def create_datetime_based_cursor(
return DatetimeBasedCursor(
cursor_field=model.cursor_field,
- cursor_datetime_formats=model.cursor_datetime_formats
+ cursor_datetime_formats=model.cursor_datetime_formats # noqa: FURB110
if model.cursor_datetime_formats
else [],
cursor_granularity=model.cursor_granularity,
@@ -1340,7 +1368,10 @@ def create_datetime_based_cursor(
)
def create_declarative_stream(
- self, model: DeclarativeStreamModel, config: Config, **kwargs: Any
+ self,
+ model: DeclarativeStreamModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> DeclarativeStream:
# When constructing a declarative stream, we assemble the incremental_sync component and retriever's partition_router field
# components if they exist into a single CartesianProductStreamSlicer. This is then passed back as an argument when constructing the
@@ -1375,7 +1406,7 @@ def create_declarative_stream(
),
"substream_cursor": (
combined_slicers
- if isinstance(
+ if isinstance( # noqa: UP038
combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor)
)
else None
@@ -1418,7 +1449,7 @@ def create_declarative_stream(
transformations = []
if model.transformations:
for transformation_model in model.transformations:
- transformations.append(
+ transformations.append( # noqa: PERF401
self._create_component_from_model(model=transformation_model, config=config)
)
retriever = self._create_component_from_model(
@@ -1465,9 +1496,9 @@ def create_declarative_stream(
def _build_stream_slicer_from_partition_router(
self,
- model: Union[AsyncRetrieverModel, CustomRetrieverModel, SimpleRetrieverModel],
+ model: AsyncRetrieverModel | CustomRetrieverModel | SimpleRetrieverModel,
config: Config,
- ) -> Optional[PartitionRouter]:
+ ) -> PartitionRouter | None:
if (
hasattr(model, "partition_router")
and isinstance(model, SimpleRetrieverModel)
@@ -1483,16 +1514,15 @@ def _build_stream_slicer_from_partition_router(
],
parameters={},
)
- else:
- return self._create_component_from_model(model=stream_slicer_model, config=config) # type: ignore[no-any-return]
- # Will be created PartitionRouter as stream_slicer_model is model.partition_router
+ return self._create_component_from_model(model=stream_slicer_model, config=config) # type: ignore[no-any-return]
+            # A PartitionRouter will be created, since stream_slicer_model is model.partition_router
return None
def _build_resumable_cursor_from_paginator(
self,
- model: Union[AsyncRetrieverModel, CustomRetrieverModel, SimpleRetrieverModel],
- stream_slicer: Optional[StreamSlicer],
- ) -> Optional[StreamSlicer]:
+ model: AsyncRetrieverModel | CustomRetrieverModel | SimpleRetrieverModel,
+ stream_slicer: StreamSlicer | None,
+ ) -> StreamSlicer | None:
if hasattr(model, "paginator") and model.paginator and not stream_slicer:
# For the regular Full-Refresh streams, we use the high level `ResumableFullRefreshCursor`
return ResumableFullRefreshCursor(parameters={})
@@ -1500,7 +1530,7 @@ def _build_resumable_cursor_from_paginator(
def _merge_stream_slicers(
self, model: DeclarativeStreamModel, config: Config
- ) -> Optional[StreamSlicer]:
+ ) -> StreamSlicer | None:
stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config)
if model.incremental_sync and stream_slicer:
@@ -1515,28 +1545,27 @@ def _merge_stream_slicers(
return GlobalSubstreamCursor(
stream_cursor=cursor_component, partition_router=stream_slicer
)
- else:
- cursor_component = self._create_component_from_model(
- model=incremental_sync_model, config=config
- )
- return PerPartitionWithGlobalCursor(
- cursor_factory=CursorFactory(
- lambda: self._create_component_from_model(
- model=incremental_sync_model, config=config
- ),
+ cursor_component = self._create_component_from_model(
+ model=incremental_sync_model, config=config
+ )
+ return PerPartitionWithGlobalCursor(
+ cursor_factory=CursorFactory(
+ lambda: self._create_component_from_model(
+ model=incremental_sync_model, config=config
),
- partition_router=stream_slicer,
- stream_cursor=cursor_component,
- )
- elif model.incremental_sync:
+ ),
+ partition_router=stream_slicer,
+ stream_cursor=cursor_component,
+ )
+ if model.incremental_sync:
return (
self._create_component_from_model(model=model.incremental_sync, config=config)
if model.incremental_sync
else None
)
- elif self._disable_resumable_full_refresh:
+ if self._disable_resumable_full_refresh:
return stream_slicer
- elif stream_slicer:
+ if stream_slicer:
# For the Full-Refresh sub-streams, we use the nested `ChildPartitionResumableFullRefreshCursor`
return PerPartitionCursor(
cursor_factory=CursorFactory(
@@ -1547,19 +1576,22 @@ def _merge_stream_slicers(
return self._build_resumable_cursor_from_paginator(model.retriever, stream_slicer)
def create_default_error_handler(
- self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any
+ self,
+ model: DefaultErrorHandlerModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> DefaultErrorHandler:
backoff_strategies = []
if model.backoff_strategies:
for backoff_strategy_model in model.backoff_strategies:
- backoff_strategies.append(
+ backoff_strategies.append( # noqa: PERF401
self._create_component_from_model(model=backoff_strategy_model, config=config)
)
response_filters = []
if model.response_filters:
for response_filter_model in model.response_filters:
- response_filters.append(
+ response_filters.append( # noqa: PERF401
self._create_component_from_model(model=response_filter_model, config=config)
)
response_filters.append(
@@ -1580,9 +1612,9 @@ def create_default_paginator(
config: Config,
*,
url_base: str,
- decoder: Optional[Decoder] = None,
- cursor_used_for_stop_condition: Optional[DeclarativeCursor] = None,
- ) -> Union[DefaultPaginator, PaginatorTestReadDecorator]:
+ decoder: Decoder | None = None,
+ cursor_used_for_stop_condition: DeclarativeCursor | None = None,
+ ) -> DefaultPaginator | PaginatorTestReadDecorator:
if decoder:
if self._is_supported_decoder_for_pagination(decoder):
decoder_to_use = PaginationDecoderDecorator(decoder=decoder)
@@ -1624,14 +1656,14 @@ def create_dpath_extractor(
self,
model: DpathExtractorModel,
config: Config,
- decoder: Optional[Decoder] = None,
- **kwargs: Any,
+ decoder: Decoder | None = None,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> DpathExtractor:
- if decoder:
+ if decoder: # noqa: SIM108
decoder_to_use = decoder
else:
decoder_to_use = JsonDecoder(parameters={})
- model_field_path: List[Union[InterpolatedString, str]] = [x for x in model.field_path]
+ model_field_path: list[InterpolatedString | str] = [x for x in model.field_path]
return DpathExtractor(
decoder=decoder_to_use,
field_path=model_field_path,
@@ -1642,7 +1674,7 @@ def create_dpath_extractor(
def create_response_to_file_extractor(
self,
model: ResponseToFileExtractorModel,
- **kwargs: Any,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> ResponseToFileExtractor:
return ResponseToFileExtractor(parameters=model.parameters or {})
@@ -1658,7 +1690,7 @@ def create_http_requester(
self,
model: HttpRequesterModel,
config: Config,
- decoder: Decoder = JsonDecoder(parameters={}),
+ decoder: Decoder = JsonDecoder(parameters={}), # noqa: B008
*,
name: str,
) -> HttpRequester:
@@ -1717,9 +1749,11 @@ def create_http_requester(
@staticmethod
def create_http_response_filter(
- model: HttpResponseFilterModel, config: Config, **kwargs: Any
+ model: HttpResponseFilterModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> HttpResponseFilter:
- if model.action:
+ if model.action: # noqa: SIM108
action = ResponseAction(model.action.value)
else:
action = None
@@ -1743,12 +1777,14 @@ def create_http_response_filter(
@staticmethod
def create_inline_schema_loader(
- model: InlineSchemaLoaderModel, config: Config, **kwargs: Any
+ model: InlineSchemaLoaderModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> InlineSchemaLoader:
return InlineSchemaLoader(schema=model.schema_ or {}, parameters={})
@staticmethod
- def create_types_map(model: TypesMapModel, **kwargs: Any) -> TypesMap:
+ def create_types_map(model: TypesMapModel, **kwargs: Any) -> TypesMap: # noqa: ANN401, ARG004
return TypesMap(
target_type=model.target_type,
current_type=model.current_type,
@@ -1756,21 +1792,22 @@ def create_types_map(model: TypesMapModel, **kwargs: Any) -> TypesMap:
)
def create_schema_type_identifier(
- self, model: SchemaTypeIdentifierModel, config: Config, **kwargs: Any
+ self,
+ model: SchemaTypeIdentifierModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> SchemaTypeIdentifier:
types_mapping = []
if model.types_mapping:
- types_mapping.extend(
- [
- self._create_component_from_model(types_map, config=config)
- for types_map in model.types_mapping
- ]
- )
- model_schema_pointer: List[Union[InterpolatedString, str]] = (
+ types_mapping.extend([
+ self._create_component_from_model(types_map, config=config)
+ for types_map in model.types_mapping
+ ])
+ model_schema_pointer: list[InterpolatedString | str] = (
[x for x in model.schema_pointer] if model.schema_pointer else []
)
- model_key_pointer: List[Union[InterpolatedString, str]] = [x for x in model.key_pointer]
- model_type_pointer: Optional[List[Union[InterpolatedString, str]]] = (
+ model_key_pointer: list[InterpolatedString | str] = [x for x in model.key_pointer]
+ model_type_pointer: list[InterpolatedString | str] | None = (
[x for x in model.type_pointer] if model.type_pointer else None
)
@@ -1783,7 +1820,10 @@ def create_schema_type_identifier(
)
def create_dynamic_schema_loader(
- self, model: DynamicSchemaLoaderModel, config: Config, **kwargs: Any
+ self,
+ model: DynamicSchemaLoaderModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> DynamicSchemaLoader:
stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config)
combined_slicers = self._build_resumable_cursor_from_paginator(
@@ -1793,7 +1833,7 @@ def create_dynamic_schema_loader(
schema_transformations = []
if model.schema_transformations:
for transformation_model in model.schema_transformations:
- schema_transformations.append(
+ schema_transformations.append( # noqa: PERF401
self._create_component_from_model(model=transformation_model, config=config)
)
@@ -1817,67 +1857,86 @@ def create_dynamic_schema_loader(
)
@staticmethod
- def create_json_decoder(model: JsonDecoderModel, config: Config, **kwargs: Any) -> JsonDecoder:
+ def create_json_decoder(model: JsonDecoderModel, config: Config, **kwargs: Any) -> JsonDecoder: # noqa: ANN401, ARG004
return JsonDecoder(parameters={})
@staticmethod
- def create_json_parser(model: JsonParserModel, config: Config, **kwargs: Any) -> JsonParser:
- encoding = model.encoding if model.encoding else "utf-8"
+ def create_json_parser(model: JsonParserModel, config: Config, **kwargs: Any) -> JsonParser: # noqa: ANN401, ARG004
+ encoding = model.encoding if model.encoding else "utf-8" # noqa: FURB110
return JsonParser(encoding=encoding)
@staticmethod
def create_jsonl_decoder(
- model: JsonlDecoderModel, config: Config, **kwargs: Any
+ model: JsonlDecoderModel, # noqa: ARG004
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> JsonlDecoder:
return JsonlDecoder(parameters={})
@staticmethod
def create_json_line_parser(
- model: JsonLineParserModel, config: Config, **kwargs: Any
+ model: JsonLineParserModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> JsonLineParser:
return JsonLineParser(encoding=model.encoding)
@staticmethod
def create_iterable_decoder(
- model: IterableDecoderModel, config: Config, **kwargs: Any
+ model: IterableDecoderModel, # noqa: ARG004
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> IterableDecoder:
return IterableDecoder(parameters={})
@staticmethod
- def create_xml_decoder(model: XmlDecoderModel, config: Config, **kwargs: Any) -> XmlDecoder:
+ def create_xml_decoder(model: XmlDecoderModel, config: Config, **kwargs: Any) -> XmlDecoder: # noqa: ANN401, ARG004
return XmlDecoder(parameters={})
@staticmethod
def create_gzipjson_decoder(
- model: GzipJsonDecoderModel, config: Config, **kwargs: Any
+ model: GzipJsonDecoderModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> GzipJsonDecoder:
return GzipJsonDecoder(parameters={}, encoding=model.encoding)
def create_zipfile_decoder(
- self, model: ZipfileDecoderModel, config: Config, **kwargs: Any
+ self,
+ model: ZipfileDecoderModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> ZipfileDecoder:
parser = self._create_component_from_model(model=model.parser, config=config)
return ZipfileDecoder(parser=parser)
def create_gzip_parser(
- self, model: GzipParserModel, config: Config, **kwargs: Any
+ self,
+ model: GzipParserModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> GzipParser:
inner_parser = self._create_component_from_model(model=model.inner_parser, config=config)
return GzipParser(inner_parser=inner_parser)
@staticmethod
- def create_csv_parser(model: CsvParserModel, config: Config, **kwargs: Any) -> CsvParser:
+ def create_csv_parser(model: CsvParserModel, config: Config, **kwargs: Any) -> CsvParser: # noqa: ANN401, ARG004
return CsvParser(encoding=model.encoding, delimiter=model.delimiter)
def create_composite_raw_decoder(
- self, model: CompositeRawDecoderModel, config: Config, **kwargs: Any
+ self,
+ model: CompositeRawDecoderModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> CompositeRawDecoder:
parser = self._create_component_from_model(model=model.parser, config=config)
return CompositeRawDecoder(parser=parser)
@staticmethod
def create_json_file_schema_loader(
- model: JsonFileSchemaLoaderModel, config: Config, **kwargs: Any
+ model: JsonFileSchemaLoaderModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> JsonFileSchemaLoader:
return JsonFileSchemaLoader(
file_path=model.file_path or "", config=config, parameters=model.parameters or {}
@@ -1885,7 +1944,9 @@ def create_json_file_schema_loader(
@staticmethod
def create_jwt_authenticator(
- model: JwtAuthenticatorModel, config: Config, **kwargs: Any
+ model: JwtAuthenticatorModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> JwtAuthenticator:
jwt_headers = model.jwt_headers or JwtHeadersModel(kid=None, typ="JWT", cty=None)
jwt_payload = model.jwt_payload or JwtPayloadModel(iss=None, sub=None, aud=None)
@@ -1909,7 +1970,9 @@ def create_jwt_authenticator(
@staticmethod
def create_list_partition_router(
- model: ListPartitionRouterModel, config: Config, **kwargs: Any
+ model: ListPartitionRouterModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> ListPartitionRouter:
request_option = (
RequestOption(
@@ -1930,7 +1993,9 @@ def create_list_partition_router(
@staticmethod
def create_min_max_datetime(
- model: MinMaxDatetimeModel, config: Config, **kwargs: Any
+ model: MinMaxDatetimeModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> MinMaxDatetime:
return MinMaxDatetime(
datetime=model.datetime,
@@ -1941,17 +2006,22 @@ def create_min_max_datetime(
)
@staticmethod
- def create_no_auth(model: NoAuthModel, config: Config, **kwargs: Any) -> NoAuth:
+ def create_no_auth(model: NoAuthModel, config: Config, **kwargs: Any) -> NoAuth: # noqa: ANN401, ARG004
return NoAuth(parameters=model.parameters or {})
@staticmethod
def create_no_pagination(
- model: NoPaginationModel, config: Config, **kwargs: Any
+ model: NoPaginationModel, # noqa: ARG004
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> NoPagination:
return NoPagination(parameters={})
def create_oauth_authenticator(
- self, model: OAuthAuthenticatorModel, config: Config, **kwargs: Any
+ self,
+ model: OAuthAuthenticatorModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> DeclarativeOauth2Authenticator:
if model.refresh_token_updater:
# ignore type error because fixing it would have a lot of dependencies, revisit later
@@ -2028,7 +2098,11 @@ def create_oauth_authenticator(
)
def create_offset_increment(
- self, model: OffsetIncrementModel, config: Config, decoder: Decoder, **kwargs: Any
+ self,
+ model: OffsetIncrementModel,
+ config: Config,
+ decoder: Decoder,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> OffsetIncrement:
if isinstance(decoder, PaginationDecoderDecorator):
inner_decoder = decoder.decoder
@@ -2053,7 +2127,9 @@ def create_offset_increment(
@staticmethod
def create_page_increment(
- model: PageIncrementModel, config: Config, **kwargs: Any
+ model: PageIncrementModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> PageIncrement:
return PageIncrement(
page_size=model.page_size,
@@ -2064,7 +2140,10 @@ def create_page_increment(
)
def create_parent_stream_config(
- self, model: ParentStreamConfigModel, config: Config, **kwargs: Any
+ self,
+ model: ParentStreamConfigModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> ParentStreamConfig:
declarative_stream = self._create_component_from_model(model.stream, config=config)
request_option = (
@@ -2085,19 +2164,23 @@ def create_parent_stream_config(
@staticmethod
def create_record_filter(
- model: RecordFilterModel, config: Config, **kwargs: Any
+ model: RecordFilterModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> RecordFilter:
return RecordFilter(
condition=model.condition or "", config=config, parameters=model.parameters or {}
)
@staticmethod
- def create_request_path(model: RequestPathModel, config: Config, **kwargs: Any) -> RequestPath:
+ def create_request_path(model: RequestPathModel, config: Config, **kwargs: Any) -> RequestPath: # noqa: ANN401, ARG004
return RequestPath(parameters={})
@staticmethod
def create_request_option(
- model: RequestOptionModel, config: Config, **kwargs: Any
+ model: RequestOptionModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> RequestOption:
inject_into = RequestOptionType(model.inject_into.value)
return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={})
@@ -2108,10 +2191,10 @@ def create_record_selector(
config: Config,
*,
name: str,
- transformations: List[RecordTransformation] | None = None,
+ transformations: list[RecordTransformation] | None = None,
decoder: Decoder | None = None,
- client_side_incremental_sync: Dict[str, Any] | None = None,
- **kwargs: Any,
+ client_side_incremental_sync: dict[str, Any] | None = None,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> RecordSelector:
extractor = self._create_component_from_model(
model=model.extractor, decoder=decoder, config=config
@@ -2148,14 +2231,19 @@ def create_record_selector(
@staticmethod
def create_remove_fields(
- model: RemoveFieldsModel, config: Config, **kwargs: Any
+ model: RemoveFieldsModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> RemoveFields:
return RemoveFields(
field_pointers=model.field_pointers, condition=model.condition or "", parameters={}
)
def create_selective_authenticator(
- self, model: SelectiveAuthenticatorModel, config: Config, **kwargs: Any
+ self,
+ model: SelectiveAuthenticatorModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401
) -> DeclarativeAuthenticator:
authenticators = {
name: self._create_component_from_model(model=auth, config=config)
@@ -2171,7 +2259,11 @@ def create_selective_authenticator(
@staticmethod
def create_legacy_session_token_authenticator(
- model: LegacySessionTokenAuthenticatorModel, config: Config, *, url_base: str, **kwargs: Any
+ model: LegacySessionTokenAuthenticatorModel,
+ config: Config,
+ *,
+ url_base: str,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> LegacySessionTokenAuthenticator:
return LegacySessionTokenAuthenticator(
api_url=url_base,
@@ -2186,18 +2278,18 @@ def create_legacy_session_token_authenticator(
parameters=model.parameters or {},
)
- def create_simple_retriever(
+ def create_simple_retriever( # noqa: PLR0913
self,
model: SimpleRetrieverModel,
config: Config,
*,
name: str,
- primary_key: Optional[Union[str, List[str], List[List[str]]]],
- stream_slicer: Optional[StreamSlicer],
- request_options_provider: Optional[RequestOptionsProvider] = None,
+ primary_key: str | list[str] | list[list[str]] | None,
+ stream_slicer: StreamSlicer | None,
+ request_options_provider: RequestOptionsProvider | None = None,
stop_condition_on_cursor: bool = False,
- client_side_incremental_sync: Optional[Dict[str, Any]] = None,
- transformations: List[RecordTransformation],
+ client_side_incremental_sync: dict[str, Any] | None = None,
+ transformations: list[RecordTransformation],
) -> SimpleRetriever:
decoder = (
self._create_component_from_model(model=model.decoder, config=config)
@@ -2285,7 +2377,10 @@ def create_simple_retriever(
)
def _create_async_job_status_mapping(
- self, model: AsyncJobStatusMapModel, config: Config, **kwargs: Any
+ self,
+ model: AsyncJobStatusMapModel,
+ config: Config, # noqa: ARG002
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> Mapping[str, AsyncJobStatus]:
api_status_to_cdk_status = {}
for cdk_status, api_statuses in model.dict().items():
@@ -2314,19 +2409,20 @@ def _get_async_job_status(self, status: str) -> AsyncJobStatus:
case _:
raise ValueError(f"Unsupported CDK status {status}")
- def create_async_retriever(
+ def create_async_retriever( # noqa: PLR0914
self,
model: AsyncRetrieverModel,
config: Config,
*,
name: str,
- primary_key: Optional[
- Union[str, List[str], List[List[str]]]
- ], # this seems to be needed to match create_simple_retriever
- stream_slicer: Optional[StreamSlicer],
- client_side_incremental_sync: Optional[Dict[str, Any]] = None,
- transformations: List[RecordTransformation],
- **kwargs: Any,
+ primary_key: str # noqa: ARG002
+ | list[str]
+ | list[list[str]]
+ | None, # this seems to be needed to match create_simple_retriever
+ stream_slicer: StreamSlicer | None,
+ client_side_incremental_sync: dict[str, Any] | None = None,
+ transformations: list[RecordTransformation],
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> AsyncRetriever:
decoder = (
self._create_component_from_model(model=model.decoder, config=config)
@@ -2457,10 +2553,10 @@ def create_async_retriever(
job_repository,
stream_slices,
JobTracker(1),
- # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1
+ # FIXME eventually make the number of concurrent jobs in the API configurable. Until then, we limit to 1 # noqa: FIX001, TD001, TD004
self._message_repository,
has_bulk_parent=False,
- # FIXME work would need to be done here in order to detect if a stream as a parent stream that is bulk
+            # FIXME work would need to be done here in order to detect if a stream has a parent stream that is bulk  # noqa: FIX001, TD001, TD004
),
stream_slicer=stream_slicer,
config=config,
@@ -2475,7 +2571,7 @@ def create_async_retriever(
)
@staticmethod
- def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec:
+ def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec: # noqa: ANN401, ARG004
return Spec(
connection_specification=model.connection_specification,
documentation_url=model.documentation_url,
@@ -2484,18 +2580,19 @@ def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec:
)
def create_substream_partition_router(
- self, model: SubstreamPartitionRouterModel, config: Config, **kwargs: Any
+ self,
+ model: SubstreamPartitionRouterModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG002
) -> SubstreamPartitionRouter:
parent_stream_configs = []
if model.parent_stream_configs:
- parent_stream_configs.extend(
- [
- self._create_message_repository_substream_wrapper(
- model=parent_stream_config, config=config
- )
- for parent_stream_config in model.parent_stream_configs
- ]
- )
+ parent_stream_configs.extend([
+ self._create_message_repository_substream_wrapper(
+ model=parent_stream_config, config=config
+ )
+ for parent_stream_config in model.parent_stream_configs
+ ])
return SubstreamPartitionRouter(
parent_stream_configs=parent_stream_configs,
@@ -2505,7 +2602,7 @@ def create_substream_partition_router(
def _create_message_repository_substream_wrapper(
self, model: ParentStreamConfigModel, config: Config
- ) -> Any:
+ ) -> Any: # noqa: ANN401
substream_factory = ModelToComponentFactory(
limit_pages_fetched_per_slice=self._limit_pages_fetched_per_slice,
limit_slices_fetched=self._limit_slices_fetched,
@@ -2518,11 +2615,13 @@ def _create_message_repository_substream_wrapper(
self._evaluate_log_level(self._emit_connector_builder_messages),
),
)
- return substream_factory._create_component_from_model(model=model, config=config)
+ return substream_factory._create_component_from_model(model=model, config=config) # noqa: SLF001
@staticmethod
def create_wait_time_from_header(
- model: WaitTimeFromHeaderModel, config: Config, **kwargs: Any
+ model: WaitTimeFromHeaderModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> WaitTimeFromHeaderBackoffStrategy:
return WaitTimeFromHeaderBackoffStrategy(
header=model.header,
@@ -2536,7 +2635,9 @@ def create_wait_time_from_header(
@staticmethod
def create_wait_until_time_from_header(
- model: WaitUntilTimeFromHeaderModel, config: Config, **kwargs: Any
+ model: WaitUntilTimeFromHeaderModel,
+ config: Config,
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> WaitUntilTimeFromHeaderBackoffStrategy:
return WaitUntilTimeFromHeaderBackoffStrategy(
header=model.header,
@@ -2549,12 +2650,14 @@ def create_wait_until_time_from_header(
def get_message_repository(self) -> MessageRepository:
return self._message_repository
- def _evaluate_log_level(self, emit_connector_builder_messages: bool) -> Level:
+ def _evaluate_log_level(self, emit_connector_builder_messages: bool) -> Level: # noqa: FBT001
return Level.DEBUG if emit_connector_builder_messages else Level.INFO
@staticmethod
def create_components_mapping_definition(
- model: ComponentMappingDefinitionModel, config: Config, **kwargs: Any
+ model: ComponentMappingDefinitionModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> ComponentMappingDefinition:
interpolated_value = InterpolatedString.create(
model.value, parameters=model.parameters or {}
@@ -2572,7 +2675,7 @@ def create_components_mapping_definition(
def create_http_components_resolver(
self, model: HttpComponentsResolverModel, config: Config
- ) -> Any:
+ ) -> Any: # noqa: ANN401
stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config)
combined_slicers = self._build_resumable_cursor_from_paginator(
model.retriever, stream_slicer
@@ -2583,7 +2686,7 @@ def create_http_components_resolver(
config=config,
name="",
primary_key=None,
- stream_slicer=stream_slicer if stream_slicer else combined_slicers,
+ stream_slicer=stream_slicer if stream_slicer else combined_slicers, # noqa: FURB110
transformations=[],
)
@@ -2607,9 +2710,11 @@ def create_http_components_resolver(
@staticmethod
def create_stream_config(
- model: StreamConfigModel, config: Config, **kwargs: Any
+ model: StreamConfigModel,
+ config: Config, # noqa: ARG004
+ **kwargs: Any, # noqa: ANN401, ARG004
) -> StreamConfig:
- model_configs_pointer: List[Union[InterpolatedString, str]] = (
+ model_configs_pointer: list[InterpolatedString | str] = (
[x for x in model.configs_pointer] if model.configs_pointer else []
)
@@ -2620,7 +2725,7 @@ def create_stream_config(
def create_config_components_resolver(
self, model: ConfigComponentsResolverModel, config: Config
- ) -> Any:
+ ) -> Any: # noqa: ANN401
stream_config = self._create_component_from_model(
model.stream_config, config=config, parameters=model.parameters or {}
)
@@ -2650,17 +2755,15 @@ def create_config_components_resolver(
)
def _is_supported_decoder_for_pagination(self, decoder: Decoder) -> bool:
- if isinstance(decoder, (JsonDecoder, XmlDecoder)):
+ if isinstance(decoder, (JsonDecoder, XmlDecoder)): # noqa: UP038
return True
- elif isinstance(decoder, CompositeRawDecoder):
+ if isinstance(decoder, CompositeRawDecoder):
return self._is_supported_parser_for_pagination(decoder.parser)
- else:
- return False
+ return False
def _is_supported_parser_for_pagination(self, parser: Parser) -> bool:
if isinstance(parser, JsonParser):
return True
- elif isinstance(parser, GzipParser):
+ if isinstance(parser, GzipParser):
return isinstance(parser.inner_parser, JsonParser)
- else:
- return False
+ return False
diff --git a/airbyte_cdk/sources/declarative/partition_routers/__init__.py b/airbyte_cdk/sources/declarative/partition_routers/__init__.py
index f35647402..5a6ecfcd0 100644
--- a/airbyte_cdk/sources/declarative/partition_routers/__init__.py
+++ b/airbyte_cdk/sources/declarative/partition_routers/__init__.py
@@ -19,6 +19,7 @@
SubstreamPartitionRouter,
)
+
__all__ = [
"AsyncJobPartitionRouter",
"CartesianProductStreamSlicer",
diff --git a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py
index 0f11820f7..2eccd8fe4 100644
--- a/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py
+++ b/airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py
@@ -1,7 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+from collections.abc import Callable, Iterable, Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Callable, Iterable, Mapping, Optional
+from typing import Any
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import (
@@ -33,7 +34,7 @@ class AsyncJobPartitionRouter(StreamSlicer):
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._job_orchestrator_factory = self.job_orchestrator_factory
- self._job_orchestrator: Optional[AsyncJobOrchestrator] = None
+ self._job_orchestrator: AsyncJobOrchestrator | None = None
self._parameters = parameters
def stream_slices(self) -> Iterable[StreamSlice]:
diff --git a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py
index 8718004bf..47bfc36b2 100644
--- a/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py
+++ b/airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py
@@ -5,9 +5,9 @@
import itertools
import logging
from collections import ChainMap
-from collections.abc import Callable
+from collections.abc import Callable, Iterable, Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Iterable, List, Mapping, Optional
+from typing import Any
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
@@ -31,7 +31,7 @@ def check_for_substream_in_slicers(
if isinstance(slicer, SubstreamPartitionRouter):
log_warning("Parent state handling is not supported for CartesianProductStreamSlicer.")
return
- elif isinstance(slicer, CartesianProductStreamSlicer):
+ if isinstance(slicer, CartesianProductStreamSlicer):
# Recursively check sub-slicers within CartesianProductStreamSlicer
check_for_substream_in_slicers(slicer.stream_slicers, log_warning)
@@ -57,7 +57,7 @@ class CartesianProductStreamSlicer(PartitionRouter):
stream_slicers (List[PartitionRouter]): Underlying stream slicers. The RequestOptions (e.g: Request headers, parameters, etc..) returned by this slicer are the combination of the RequestOptions of its input slicers. If there are conflicts e.g: two slicers define the same header or request param, the conflict is resolved by taking the value from the first slicer, where ordering is determined by the order in which slicers were input to this composite slicer.
"""
- stream_slicers: List[PartitionRouter]
+ stream_slicers: list[PartitionRouter]
parameters: InitVar[Mapping[str, Any]]
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
@@ -66,81 +66,73 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return dict(
- ChainMap(
- *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
- s.get_request_params(
- stream_state=stream_state,
- stream_slice=stream_slice,
- next_page_token=next_page_token,
- )
- for s in self.stream_slicers
- ]
- )
+ ChainMap(*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
+ s.get_request_params(
+ stream_state=stream_state,
+ stream_slice=stream_slice,
+ next_page_token=next_page_token,
+ )
+ for s in self.stream_slicers
+ ])
)
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return dict(
- ChainMap(
- *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
- s.get_request_headers(
- stream_state=stream_state,
- stream_slice=stream_slice,
- next_page_token=next_page_token,
- )
- for s in self.stream_slicers
- ]
- )
+ ChainMap(*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
+ s.get_request_headers(
+ stream_state=stream_state,
+ stream_slice=stream_slice,
+ next_page_token=next_page_token,
+ )
+ for s in self.stream_slicers
+ ])
)
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return dict(
- ChainMap(
- *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
- s.get_request_body_data(
- stream_state=stream_state,
- stream_slice=stream_slice,
- next_page_token=next_page_token,
- )
- for s in self.stream_slicers
- ]
- )
+ ChainMap(*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
+ s.get_request_body_data(
+ stream_state=stream_state,
+ stream_slice=stream_slice,
+ next_page_token=next_page_token,
+ )
+ for s in self.stream_slicers
+ ])
)
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return dict(
- ChainMap(
- *[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
- s.get_request_body_json(
- stream_state=stream_state,
- stream_slice=stream_slice,
- next_page_token=next_page_token,
- )
- for s in self.stream_slicers
- ]
- )
+ ChainMap(*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
+ s.get_request_body_json(
+ stream_state=stream_state,
+ stream_slice=stream_slice,
+ next_page_token=next_page_token,
+ )
+ for s in self.stream_slicers
+ ])
)
def stream_slices(self) -> Iterable[StreamSlice]:
@@ -153,7 +145,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
raise ValueError(
f"There should only be a single cursor slice. Found {cursor_slices}"
)
- if cursor_slices:
+ if cursor_slices: # noqa: SIM108
cursor_slice = cursor_slices[0]
else:
cursor_slice = {}
@@ -165,7 +157,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
"""
pass
- def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
+ def get_stream_state(self) -> Mapping[str, StreamState] | None:
"""
Parent stream states are not supported for cartesian product stream slicer
"""
diff --git a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py
index 29b700b04..ed5c127e1 100644
--- a/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py
+++ b/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Iterable, Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Iterable, List, Mapping, Optional, Union
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
@@ -27,11 +28,11 @@ class ListPartitionRouter(PartitionRouter):
request_option (Optional[RequestOption]): The request option to configure the HTTP request
"""
- values: Union[str, List[str]]
- cursor_field: Union[InterpolatedString, str]
+ values: str | list[str]
+ cursor_field: InterpolatedString | str
config: Config
parameters: InitVar[Mapping[str, Any]]
- request_option: Optional[RequestOption] = None
+ request_option: RequestOption | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if isinstance(self.values, str):
@@ -48,36 +49,36 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def get_request_params(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
# Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response
return self._get_request_option(RequestOptionType.request_parameter, stream_slice)
def get_request_headers(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
# Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response
return self._get_request_option(RequestOptionType.header, stream_slice)
def get_request_body_data(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
# Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response
return self._get_request_option(RequestOptionType.body_data, stream_slice)
def get_request_body_json(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
# Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response
return self._get_request_option(RequestOptionType.body_json, stream_slice)
@@ -91,7 +92,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
]
def _get_request_option(
- self, request_option_type: RequestOptionType, stream_slice: Optional[StreamSlice]
+ self, request_option_type: RequestOptionType, stream_slice: StreamSlice | None
) -> Mapping[str, Any]:
if (
self.request_option
@@ -101,10 +102,8 @@ def _get_request_option(
slice_value = stream_slice.get(self._cursor_field.eval(self.config))
if slice_value:
return {self.request_option.field_name.eval(self.config): slice_value} # type: ignore # field_name is always casted to InterpolatedString
- else:
- return {}
- else:
return {}
+ return {}
def set_initial_state(self, stream_state: StreamState) -> None:
"""
@@ -112,7 +111,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
"""
pass
- def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
+ def get_stream_state(self) -> Mapping[str, StreamState] | None:
"""
ListPartitionRouter doesn't have parent streams
"""
diff --git a/airbyte_cdk/sources/declarative/partition_routers/partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/partition_router.py
index 3a9bc3abf..d0a1c062a 100644
--- a/airbyte_cdk/sources/declarative/partition_routers/partition_router.py
+++ b/airbyte_cdk/sources/declarative/partition_routers/partition_router.py
@@ -3,8 +3,8 @@
#
from abc import abstractmethod
+from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Mapping, Optional
from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer
from airbyte_cdk.sources.types import StreamState
@@ -41,7 +41,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
"""
@abstractmethod
- def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
+ def get_stream_state(self) -> Mapping[str, StreamState] | None:
"""
Get the state of the parent streams.
diff --git a/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py
index 32e6a353d..b6f5a3b7f 100644
--- a/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py
+++ b/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Iterable, Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Iterable, Mapping, Optional
+from typing import Any
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.types import StreamSlice, StreamState
@@ -17,33 +18,33 @@ class SinglePartitionRouter(PartitionRouter):
def get_request_params(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
def get_request_headers(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
def get_request_body_data(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
def get_request_body_json(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
@@ -56,7 +57,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
"""
pass
- def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
+ def get_stream_state(self) -> Mapping[str, StreamState] | None:
"""
SinglePartitionRouter doesn't have parent streams
"""
diff --git a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py
index 1c7bb6961..ca0c5e61e 100644
--- a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py
+++ b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py
@@ -3,8 +3,9 @@
#
import copy
import logging
+from collections.abc import Iterable, Mapping
from dataclasses import InitVar, dataclass
-from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any
import dpath
@@ -19,6 +20,7 @@
from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
from airbyte_cdk.utils import AirbyteTracedException
+
if TYPE_CHECKING:
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
@@ -37,14 +39,14 @@ class ParentStreamConfig:
"""
stream: "DeclarativeStream" # Parent streams must be DeclarativeStream because we can't know which part of the stream slice is a partition for regular Stream
- parent_key: Union[InterpolatedString, str]
- partition_field: Union[InterpolatedString, str]
+ parent_key: InterpolatedString | str
+ partition_field: InterpolatedString | str
config: Config
parameters: InitVar[Mapping[str, Any]]
- extra_fields: Optional[Union[List[List[str]], List[List[InterpolatedString]]]] = (
+ extra_fields: list[list[str]] | list[list[InterpolatedString]] | None = (
None # List of field paths (arrays of strings)
)
- request_option: Optional[RequestOption] = None
+ request_option: RequestOption | None = None
incremental_dependency: bool = False
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
@@ -70,7 +72,7 @@ class SubstreamPartitionRouter(PartitionRouter):
parent_stream_configs (List[ParentStreamConfig]): parent streams to iterate over and their config
"""
- parent_stream_configs: List[ParentStreamConfig]
+ parent_stream_configs: list[ParentStreamConfig]
config: Config
parameters: InitVar[Mapping[str, Any]]
@@ -81,42 +83,42 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def get_request_params(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
# Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response
return self._get_request_option(RequestOptionType.request_parameter, stream_slice)
def get_request_headers(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
# Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response
return self._get_request_option(RequestOptionType.header, stream_slice)
def get_request_body_data(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
# Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response
return self._get_request_option(RequestOptionType.body_data, stream_slice)
def get_request_body_json(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
# Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response
return self._get_request_option(RequestOptionType.body_json, stream_slice)
def _get_request_option(
- self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]
+ self, option_type: RequestOptionType, stream_slice: StreamSlice | None
) -> Mapping[str, Any]:
params = {}
if stream_slice:
@@ -128,13 +130,11 @@ def _get_request_option(
key = parent_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string
value = stream_slice.get(key)
if value:
- params.update(
- {
- parent_config.request_option.field_name.eval( # type: ignore [union-attr]
- config=self.config
- ): value
- }
- )
+ params.update({
+ parent_config.request_option.field_name.eval( # type: ignore [union-attr]
+ config=self.config
+ ): value
+ })
return params
def stream_slices(self) -> Iterable[StreamSlice]:
@@ -176,7 +176,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
f"Parent stream {parent_stream.name} returns records of type AirbyteMessage. This SubstreamPartitionRouter is not able to checkpoint incremental parent state."
)
if parent_record.type == MessageType.RECORD:
- parent_record = parent_record.record.data # type: ignore[union-attr, assignment] # record is always a Record
+ parent_record = parent_record.record.data # type: ignore[union-attr, assignment] # record is always a Record # noqa: PLW2901
else:
continue
elif isinstance(parent_record, Record):
@@ -185,7 +185,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
if parent_record.associated_slice
else {}
)
- parent_record = parent_record.data
+ parent_record = parent_record.data # noqa: PLW2901
elif not isinstance(parent_record, Mapping):
# The parent_record should only take the form of a Record, AirbyteMessage, or Mapping. Anything else is invalid
raise AirbyteTracedException(
@@ -214,7 +214,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
def _extract_extra_fields(
self,
parent_record: Mapping[str, Any] | AirbyteMessage,
- extra_fields: Optional[List[List[str]]] = None,
+ extra_fields: list[list[str]] | None = None,
) -> Mapping[str, Any]:
"""
Extracts additional fields specified by their paths from the parent record.
@@ -289,7 +289,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
# If `parent_state` doesn't exist and at least one parent stream has an incremental dependency,
# copy the child state to parent streams with incremental dependencies.
incremental_dependency = any(
- [parent_config.incremental_dependency for parent_config in self.parent_stream_configs]
+ [parent_config.incremental_dependency for parent_config in self.parent_stream_configs] # noqa: C419
)
if not parent_state and not incremental_dependency:
return
@@ -313,7 +313,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
if parent_config.incremental_dependency:
parent_config.stream.state = parent_state.get(parent_config.stream.name, {})
- def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
+ def get_stream_state(self) -> Mapping[str, StreamState] | None:
"""
Get the state of the parent streams.
diff --git a/airbyte_cdk/sources/declarative/requesters/__init__.py b/airbyte_cdk/sources/declarative/requesters/__init__.py
index e5266ea7c..ae6241428 100644
--- a/airbyte_cdk/sources/declarative/requesters/__init__.py
+++ b/airbyte_cdk/sources/declarative/requesters/__init__.py
@@ -6,4 +6,5 @@
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption
from airbyte_cdk.sources.declarative.requesters.requester import Requester
+
__all__ = ["HttpRequester", "RequestOption", "Requester"]
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/__init__.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/__init__.py
index 099aa4286..6cbeaec1a 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/__init__.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/__init__.py
@@ -16,6 +16,7 @@
HttpResponseFilter,
)
+
__all__ = [
"BackoffStrategy",
"CompositeErrorHandler",
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/__init__.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/__init__.py
index 26ecafbde..ce3cc015d 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/__init__.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/__init__.py
@@ -15,6 +15,7 @@
WaitUntilTimeFromHeaderBackoffStrategy,
)
+
__all__ = [
"ConstantBackoffStrategy",
"ExponentialBackoffStrategy",
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py
index 26c7c7673..5c12ada8a 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/constant_backoff_strategy.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any
import requests
@@ -21,7 +22,7 @@ class ConstantBackoffStrategy(BackoffStrategy):
backoff_time_in_seconds (float): time to back off before retrying a retryable request.
"""
- backoff_time_in_seconds: Union[float, InterpolatedString, str]
+ backoff_time_in_seconds: float | InterpolatedString | str
parameters: InitVar[Mapping[str, Any]]
config: Config
@@ -39,7 +40,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def backoff_time(
self,
- response_or_exception: Optional[Union[requests.Response, requests.RequestException]],
- attempt_count: int,
- ) -> Optional[float]:
+ response_or_exception: requests.Response | requests.RequestException | None, # noqa: ARG002
+ attempt_count: int, # noqa: ARG002
+ ) -> float | None:
return self.backoff_time_in_seconds.eval(self.config) # type: ignore # backoff_time_in_seconds is always cast to an interpolated string
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py
index cdd1fe650..475c23782 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/exponential_backoff_strategy.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any
import requests
@@ -23,7 +24,7 @@ class ExponentialBackoffStrategy(BackoffStrategy):
parameters: InitVar[Mapping[str, Any]]
config: Config
- factor: Union[float, InterpolatedString, str] = 5
+ factor: float | InterpolatedString | str = 5
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if not isinstance(self.factor, InterpolatedString):
@@ -39,7 +40,7 @@ def _retry_factor(self) -> float:
def backoff_time(
self,
- response_or_exception: Optional[Union[requests.Response, requests.RequestException]],
+ response_or_exception: requests.Response | requests.RequestException | None, # noqa: ARG002
attempt_count: int,
- ) -> Optional[float]:
+ ) -> float | None:
return self._retry_factor * 2**attempt_count # type: ignore # factor is always cast to an interpolated string
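# The strategy above retries after factor * 2**attempt_count seconds. A quick standalone
# illustration of that arithmetic with the default factor of 5:

factor = 5
for attempt_count in range(5):
    print(f"attempt {attempt_count}: wait {factor * 2**attempt_count}s")  # 5, 10, 20, 40, 80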
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py
index 60103f343..2a2214793 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/header_helper.py
@@ -4,14 +4,13 @@
import numbers
from re import Pattern
-from typing import Optional
import requests
def get_numeric_value_from_header(
- response: requests.Response, header: str, regex: Optional[Pattern[str]]
-) -> Optional[float]:
+ response: requests.Response, header: str, regex: Pattern[str] | None
+) -> float | None:
"""
Extract a header value from the response as a float
:param response: response the extract header value from
@@ -28,13 +27,12 @@ def get_numeric_value_from_header(
if match:
header_value = match.group()
return _as_float(header_value)
- elif isinstance(header_value, numbers.Number):
+ if isinstance(header_value, numbers.Number):
return float(header_value) # type: ignore[arg-type]
- else:
- return None
+ return None
-def _as_float(s: str) -> Optional[float]:
+def _as_float(s: str) -> float | None:
try:
return float(s)
except ValueError:
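# A standalone sketch (not the CDK helper itself) of what get_numeric_value_from_header
# does above: read a header, optionally narrow it with a regex, and return it as a float,
# or None when it is absent or not numeric. The header name and regex are illustrative.

import re


def numeric_header(headers: dict[str, str], header: str, regex: re.Pattern[str] | None = None) -> float | None:
    value = headers.get(header)
    if value is None:
        return None
    if regex:
        match = regex.search(value)
        if not match:
            return None
        value = match.group()
    try:
        return float(value)
    except ValueError:
        return None


print(numeric_header({"Retry-After": "120"}, "Retry-After"))                           # 120.0
print(numeric_header({"Retry-After": "wait=30s"}, "Retry-After", re.compile(r"\d+")))  # 30.0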
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py
index 5cda96a4d..7f42802c9 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_time_from_header_backoff_strategy.py
@@ -3,8 +3,9 @@
#
import re
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any
import requests
@@ -31,11 +32,11 @@ class WaitTimeFromHeaderBackoffStrategy(BackoffStrategy):
max_waiting_time_in_seconds (Optional[float]): if the value extracted from the header is greater than this value, stop the stream
"""
- header: Union[InterpolatedString, str]
+ header: InterpolatedString | str
parameters: InitVar[Mapping[str, Any]]
config: Config
- regex: Optional[Union[InterpolatedString, str]] = None
- max_waiting_time_in_seconds: Optional[float] = None
+ regex: InterpolatedString | str | None = None
+ max_waiting_time_in_seconds: float | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.regex = (
@@ -45,9 +46,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def backoff_time(
self,
- response_or_exception: Optional[Union[requests.Response, requests.RequestException]],
- attempt_count: int,
- ) -> Optional[float]:
+ response_or_exception: requests.Response | requests.RequestException | None,
+ attempt_count: int, # noqa: ARG002
+ ) -> float | None:
header = self.header.eval(config=self.config) # type: ignore # header is always cast to an interpolated string
if self.regex:
evaled_regex = self.regex.eval(self.config) # type: ignore # header is always cast to an interpolated string
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py
index 1220e198f..2bbad2722 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/backoff_strategies/wait_until_time_from_header_backoff_strategy.py
@@ -5,8 +5,9 @@
import numbers
import re
import time
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any
import requests
@@ -32,11 +33,11 @@ class WaitUntilTimeFromHeaderBackoffStrategy(BackoffStrategy):
regex (Optional[str]): optional regex to apply on the header to extract its value
"""
- header: Union[InterpolatedString, str]
+ header: InterpolatedString | str
parameters: InitVar[Mapping[str, Any]]
config: Config
- min_wait: Optional[Union[float, InterpolatedString, str]] = None
- regex: Optional[Union[InterpolatedString, str]] = None
+ min_wait: float | InterpolatedString | str | None = None
+ regex: InterpolatedString | str | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.header = InterpolatedString.create(self.header, parameters=parameters)
@@ -48,9 +49,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def backoff_time(
self,
- response_or_exception: Optional[Union[requests.Response, requests.RequestException]],
- attempt_count: int,
- ) -> Optional[float]:
+ response_or_exception: requests.Response | requests.RequestException | None,
+ attempt_count: int, # noqa: ARG002
+ ) -> float | None:
now = time.time()
header = self.header.eval(self.config) # type: ignore # header is always cast to an interpolated string
if self.regex:
@@ -72,6 +73,6 @@ def backoff_time(
return float(min_wait)
if min_wait:
return float(max(wait_time, min_wait))
- elif wait_time < 0:
+ if wait_time < 0:
return None
return wait_time
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py
index fc4219134..3737b3dd0 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, List, Mapping, Optional, Union
+from typing import Any
import requests
@@ -40,7 +41,7 @@ class CompositeErrorHandler(ErrorHandler):
error_handlers (List[ErrorHandler]): list of error handlers
"""
- error_handlers: List[ErrorHandler]
+ error_handlers: list[ErrorHandler]
parameters: InitVar[Mapping[str, Any]]
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
@@ -48,15 +49,15 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
raise ValueError("CompositeErrorHandler expects at least 1 underlying error handler")
@property
- def max_retries(self) -> Optional[int]:
+ def max_retries(self) -> int | None:
return self.error_handlers[0].max_retries
@property
- def max_time(self) -> Optional[int]:
- return max([error_handler.max_time or 0 for error_handler in self.error_handlers])
+ def max_time(self) -> int | None:
+ return max([error_handler.max_time or 0 for error_handler in self.error_handlers]) # noqa: C419
def interpret_response(
- self, response_or_exception: Optional[Union[requests.Response, Exception]]
+ self, response_or_exception: requests.Response | Exception | None
) -> ErrorResolution:
matched_error_resolution = None
for error_handler in self.error_handlers:
@@ -69,7 +70,7 @@ def interpret_response(
return matched_error_resolution
if (
- matched_error_resolution.response_action == ResponseAction.RETRY
+ matched_error_resolution.response_action == ResponseAction.RETRY # noqa: PLR1714
or matched_error_resolution.response_action == ResponseAction.IGNORE
):
return matched_error_resolution
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py
index b70ceaaeb..3ebe3dc95 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_error_handler.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping, MutableMapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, List, Mapping, MutableMapping, Optional, Union
+from typing import Any
import requests
@@ -95,12 +96,12 @@ class DefaultErrorHandler(ErrorHandler):
parameters: InitVar[Mapping[str, Any]]
config: Config
- response_filters: Optional[List[HttpResponseFilter]] = None
- max_retries: Optional[int] = 5
+ response_filters: list[HttpResponseFilter] | None = None
+ max_retries: int | None = 5
max_time: int = 60 * 10
_max_retries: int = field(init=False, repr=False, default=5)
_max_time: int = field(init=False, repr=False, default=60 * 10)
- backoff_strategies: Optional[List[BackoffStrategy]] = None
+ backoff_strategies: list[BackoffStrategy] | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if not self.response_filters:
@@ -109,7 +110,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._last_request_to_attempt_count: MutableMapping[requests.PreparedRequest, int] = {}
def interpret_response(
- self, response_or_exception: Optional[Union[requests.Response, Exception]]
+ self, response_or_exception: requests.Response | Exception | None
) -> ErrorResolution:
if self.response_filters:
for response_filter in self.response_filters:
@@ -118,7 +119,7 @@ def interpret_response(
)
if matched_error_resolution:
return matched_error_resolution
- if isinstance(response_or_exception, requests.Response):
+ if isinstance(response_or_exception, requests.Response): # noqa: SIM102
if response_or_exception.ok:
return SUCCESS_RESOLUTION
@@ -126,16 +127,16 @@ def interpret_response(
default_response_filter_resolution = default_reponse_filter.matches(response_or_exception)
return (
- default_response_filter_resolution
+ default_response_filter_resolution # noqa: FURB110
if default_response_filter_resolution
else create_fallback_error_resolution(response_or_exception)
)
def backoff_time(
self,
- response_or_exception: Optional[Union[requests.Response, requests.RequestException]],
+ response_or_exception: requests.Response | requests.RequestException | None,
attempt_count: int = 0,
- ) -> Optional[float]:
+ ) -> float | None:
backoff = None
if self.backoff_strategies:
for backoff_strategy in self.backoff_strategies:
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py
index 9943a0d6a..fbde89f9e 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/default_http_response_filter.py
@@ -2,7 +2,6 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
-from typing import Optional, Union
import requests
@@ -20,12 +19,12 @@
class DefaultHttpResponseFilter(HttpResponseFilter):
def matches(
- self, response_or_exception: Optional[Union[requests.Response, Exception]]
- ) -> Optional[ErrorResolution]:
+ self, response_or_exception: requests.Response | Exception | None
+ ) -> ErrorResolution | None:
default_mapped_error_resolution = None
- if isinstance(response_or_exception, (requests.Response, Exception)):
- mapped_key: Union[int, type] = (
+ if isinstance(response_or_exception, (requests.Response, Exception)): # noqa: UP038
+ mapped_key: int | type = (
response_or_exception.status_code
if isinstance(response_or_exception, requests.Response)
else response_or_exception.__class__
@@ -34,7 +33,7 @@ def matches(
default_mapped_error_resolution = DEFAULT_ERROR_MAPPING.get(mapped_key)
return (
- default_mapped_error_resolution
+ default_mapped_error_resolution # noqa: FURB110
if default_mapped_error_resolution
else create_fallback_error_resolution(response_or_exception)
)
diff --git a/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py b/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py
index a2fc80007..34a9273ea 100644
--- a/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py
+++ b/airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Set, Union
+from typing import Any
import requests
@@ -40,12 +41,12 @@ class HttpResponseFilter:
config: Config
parameters: InitVar[Mapping[str, Any]]
- action: Optional[Union[ResponseAction, str]] = None
- failure_type: Optional[Union[FailureType, str]] = None
- http_codes: Optional[Set[int]] = None
- error_message_contains: Optional[str] = None
- predicate: Union[InterpolatedBoolean, str] = ""
- error_message: Union[InterpolatedString, str] = ""
+ action: ResponseAction | str | None = None
+ failure_type: FailureType | str | None = None
+ http_codes: set[int] | None = None
+ error_message_contains: str | None = None
+ predicate: InterpolatedBoolean | str = ""
+ error_message: InterpolatedString | str = ""
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if self.action is not None:
@@ -57,7 +58,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
raise ValueError(
"HttpResponseFilter requires a filter condition if an action is specified"
)
- elif isinstance(self.action, str):
+ if isinstance(self.action, str):
self.action = ResponseAction[self.action]
self.http_codes = self.http_codes or set()
if isinstance(self.predicate, str):
@@ -70,8 +71,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.failure_type = FailureType[self.failure_type]
def matches(
- self, response_or_exception: Optional[Union[requests.Response, Exception]]
- ) -> Optional[ErrorResolution]:
+ self, response_or_exception: requests.Response | Exception | None
+ ) -> ErrorResolution | None:
filter_action = self._matches_filter(response_or_exception)
mapped_key = (
response_or_exception.status_code
@@ -79,7 +80,7 @@ def matches(
else response_or_exception.__class__
)
- if isinstance(mapped_key, (int, Exception)):
+ if isinstance(mapped_key, (int, Exception)): # noqa: UP038
default_mapped_error_resolution = self._match_default_error_mapping(mapped_key)
else:
default_mapped_error_resolution = None
@@ -118,13 +119,13 @@ def matches(
return None
def _match_default_error_mapping(
- self, mapped_key: Union[int, type[Exception]]
- ) -> Optional[ErrorResolution]:
+ self, mapped_key: int | type[Exception]
+ ) -> ErrorResolution | None:
return DEFAULT_ERROR_MAPPING.get(mapped_key)
def _matches_filter(
- self, response_or_exception: Optional[Union[requests.Response, Exception]]
- ) -> Optional[ResponseAction]:
+ self, response_or_exception: requests.Response | Exception | None
+ ) -> ResponseAction | None:
"""
Apply the HTTP filter on the response and return the action to execute if it matches
:param response: The HTTP response to evaluate
@@ -145,7 +146,7 @@ def _safe_response_json(response: requests.Response) -> dict[str, Any]:
except requests.exceptions.JSONDecodeError:
return {}
- def _create_error_message(self, response: requests.Response) -> Optional[str]:
+ def _create_error_message(self, response: requests.Response) -> str | None:
"""
Construct an error message based on the specified message template of the filter.
:param response: The HTTP response which can be used during interpolation
@@ -172,8 +173,5 @@ def _response_matches_predicate(self, response: requests.Response) -> bool:
def _response_contains_error_message(self, response: requests.Response) -> bool:
if not self.error_message_contains:
return False
- else:
- error_message = self._error_message_parser.parse_response_error_message(
- response=response
- )
- return bool(error_message and self.error_message_contains in error_message)
+ error_message = self._error_message_parser.parse_response_error_message(response=response)
+ return bool(error_message and self.error_message_contains in error_message)
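# A standalone sketch (not the CDK class) of the matching logic HttpResponseFilter applies
# above: a response matches when its status code is in the configured set, or when the
# parsed error message contains the configured substring. The filter values are illustrative.

from dataclasses import dataclass, field


@dataclass
class SketchResponseFilter:
    http_codes: set[int] = field(default_factory=set)
    error_message_contains: str | None = None

    def matches(self, status_code: int, error_message: str) -> bool:
        if status_code in self.http_codes:
            return True
        return bool(self.error_message_contains and self.error_message_contains in error_message)


rate_limited = SketchResponseFilter(http_codes={429}, error_message_contains="rate limit")
print(rate_limited.matches(429, ""))                     # True
print(rate_limited.matches(400, "rate limit exceeded"))  # True
print(rate_limited.matches(400, "bad request"))          # False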
diff --git a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py
index fce146fd8..cababa2d2 100644
--- a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py
+++ b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py
@@ -1,9 +1,10 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
import logging
import uuid
+from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from datetime import timedelta
-from typing import Any, Dict, Iterable, Mapping, Optional
+from typing import Any
import requests
from requests import Response
@@ -26,6 +27,7 @@
from airbyte_cdk.sources.types import Record, StreamSlice
from airbyte_cdk.utils import AirbyteTracedException
+
LOGGER = logging.getLogger("airbyte")
@@ -38,23 +40,23 @@ class AsyncHttpJobRepository(AsyncJobRepository):
creation_requester: Requester
polling_requester: Requester
download_retriever: SimpleRetriever
- abort_requester: Optional[Requester]
- delete_requester: Optional[Requester]
+ abort_requester: Requester | None
+ delete_requester: Requester | None
status_extractor: DpathExtractor
status_mapping: Mapping[str, AsyncJobStatus]
urls_extractor: DpathExtractor
- job_timeout: Optional[timedelta] = None
+ job_timeout: timedelta | None = None
record_extractor: RecordExtractor = field(
init=False, repr=False, default_factory=lambda: ResponseToFileExtractor({})
)
- url_requester: Optional[Requester] = (
+ url_requester: Requester | None = (
None # use it in case polling_requester provides some and extra request is needed to obtain list of urls to download from
)
def __post_init__(self) -> None:
- self._create_job_response_by_id: Dict[str, Response] = {}
- self._polling_job_response_by_id: Dict[str, Response] = {}
+ self._create_job_response_by_id: dict[str, Response] = {}
+ self._polling_job_response_by_id: dict[str, Response] = {}
def _get_validated_polling_response(self, stream_slice: StreamSlice) -> requests.Response:
"""
@@ -70,7 +72,7 @@ def _get_validated_polling_response(self, stream_slice: StreamSlice) -> requests
AirbyteTracedException: If the polling request returns an empty response.
"""
- polling_response: Optional[requests.Response] = self.polling_requester.send_request(
+ polling_response: requests.Response | None = self.polling_requester.send_request(
stream_slice=stream_slice
)
if polling_response is None:
@@ -117,7 +119,7 @@ def _start_job_and_validate_response(self, stream_slice: StreamSlice) -> request
AirbyteTracedException: If no response is received from the creation requester.
"""
- response: Optional[requests.Response] = self.creation_requester.send_request(
+ response: requests.Response | None = self.creation_requester.send_request(
stream_slice=stream_slice
)
if not response:
@@ -168,13 +170,13 @@ def update_jobs_status(self, jobs: Iterable[AsyncJob]) -> None:
lazy_log(
LOGGER,
logging.DEBUG,
- lambda: f"Status of job {job.api_job_id()} changed from {job.status()} to {job_status}",
+ lambda: f"Status of job {job.api_job_id()} changed from {job.status()} to {job_status}", # noqa: B023
)
else:
lazy_log(
LOGGER,
logging.DEBUG,
- lambda: f"Status of job {job.api_job_id()} is still {job.status()}",
+ lambda: f"Status of job {job.api_job_id()} is still {job.status()}", # noqa: B023
)
job.update_status(job_status)
@@ -206,7 +208,7 @@ def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]:
elif isinstance(message, AirbyteMessage):
if message.type == Type.RECORD:
yield message.record.data # type: ignore # message.record won't be None here as the message is a record
- elif isinstance(message, (dict, Mapping)):
+ elif isinstance(message, (dict, Mapping)): # noqa: UP038
yield message
else:
raise TypeError(f"Unknown type `{type(message)}` for message")
diff --git a/airbyte_cdk/sources/declarative/requesters/http_requester.py b/airbyte_cdk/sources/declarative/requesters/http_requester.py
index 35d4b0f11..49615b4c1 100644
--- a/airbyte_cdk/sources/declarative/requesters/http_requester.py
+++ b/airbyte_cdk/sources/declarative/requesters/http_requester.py
@@ -4,8 +4,9 @@
import logging
import os
+from collections.abc import Callable, Mapping, MutableMapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Callable, Mapping, MutableMapping, Optional, Union
+from typing import Any
from urllib.parse import urljoin
import requests
@@ -47,16 +48,16 @@ class HttpRequester(Requester):
"""
name: str
- url_base: Union[InterpolatedString, str]
- path: Union[InterpolatedString, str]
+ url_base: InterpolatedString | str
+ path: InterpolatedString | str
config: Config
parameters: InitVar[Mapping[str, Any]]
- authenticator: Optional[DeclarativeAuthenticator] = None
- http_method: Union[str, HttpMethod] = HttpMethod.GET
- request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None
- error_handler: Optional[ErrorHandler] = None
+ authenticator: DeclarativeAuthenticator | None = None
+ http_method: str | HttpMethod = HttpMethod.GET
+ request_options_provider: InterpolatedRequestOptionsProvider | None = None
+ error_handler: ErrorHandler | None = None
disable_retries: bool = False
- message_repository: MessageRepository = NoopMessageRepository()
+ message_repository: MessageRepository = NoopMessageRepository() # noqa: RUF009
use_cache: bool = False
_exit_on_rate_limit: bool = False
stream_response: bool = False
@@ -110,14 +111,14 @@ def get_authenticator(self) -> DeclarativeAuthenticator:
return self._authenticator
def get_url_base(self) -> str:
- return os.path.join(self._url_base.eval(self.config), "")
+ return os.path.join(self._url_base.eval(self.config), "") # noqa: PTH118
def get_path(
self,
*,
- stream_state: Optional[StreamState],
- stream_slice: Optional[StreamSlice],
- next_page_token: Optional[Mapping[str, Any]],
+ stream_state: StreamState | None,
+ stream_slice: StreamSlice | None,
+ next_page_token: Mapping[str, Any] | None,
) -> str:
kwargs = {
"stream_state": stream_state,
@@ -133,9 +134,9 @@ def get_method(self) -> HttpMethod:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> MutableMapping[str, Any]:
return self._request_options_provider.get_request_params(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
@@ -144,9 +145,9 @@ def get_request_params(
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._request_options_provider.get_request_headers(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
@@ -156,10 +157,10 @@ def get_request_headers(
def get_request_body_data( # type: ignore
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
return (
self._request_options_provider.get_request_body_data(
stream_state=stream_state,
@@ -173,10 +174,10 @@ def get_request_body_data( # type: ignore
def get_request_body_json( # type: ignore
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Optional[Mapping[str, Any]]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | None:
return self._request_options_provider.get_request_body_json(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
)
@@ -187,36 +188,34 @@ def logger(self) -> logging.Logger:
def _get_request_options(
self,
- stream_state: Optional[StreamState],
- stream_slice: Optional[StreamSlice],
- next_page_token: Optional[Mapping[str, Any]],
- requester_method: Callable[..., Optional[Union[Mapping[str, Any], str]]],
- auth_options_method: Callable[..., Optional[Union[Mapping[str, Any], str]]],
- extra_options: Optional[Union[Mapping[str, Any], str]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None,
+ stream_slice: StreamSlice | None,
+ next_page_token: Mapping[str, Any] | None,
+ requester_method: Callable[..., Mapping[str, Any] | str | None],
+ auth_options_method: Callable[..., Mapping[str, Any] | str | None],
+ extra_options: Mapping[str, Any] | str | None = None,
+ ) -> Mapping[str, Any] | str:
"""
Get the request_option from the requester, the authenticator and extra_options passed in.
Raise a ValueError if there's a key collision
Return the merged mapping otherwise
"""
- return combine_mappings(
- [
- requester_method(
- stream_state=stream_state,
- stream_slice=stream_slice,
- next_page_token=next_page_token,
- ),
- auth_options_method(),
- extra_options,
- ]
- )
+ return combine_mappings([
+ requester_method(
+ stream_state=stream_state,
+ stream_slice=stream_slice,
+ next_page_token=next_page_token,
+ ),
+ auth_options_method(),
+ extra_options,
+ ])
def _request_headers(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- extra_headers: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ extra_headers: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Specifies request headers.
@@ -231,15 +230,15 @@ def _request_headers(
extra_headers,
)
if isinstance(headers, str):
- raise ValueError("Request headers cannot be a string")
+ raise ValueError("Request headers cannot be a string") # noqa: TRY004
return {str(k): str(v) for k, v in headers.items()}
def _request_params(
self,
- stream_state: Optional[StreamState],
- stream_slice: Optional[StreamSlice],
- next_page_token: Optional[Mapping[str, Any]],
- extra_params: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None,
+ stream_slice: StreamSlice | None,
+ next_page_token: Mapping[str, Any] | None,
+ extra_params: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Specifies the query parameters that should be set on an outgoing HTTP request given the inputs.
@@ -255,11 +254,11 @@ def _request_params(
extra_params,
)
if isinstance(options, str):
- raise ValueError("Request params cannot be a string")
+ raise ValueError("Request params cannot be a string") # noqa: TRY004
for k, v in options.items():
- if isinstance(v, (dict,)):
- raise ValueError(
+ if isinstance(v, (dict,)): # noqa: UP038
+ raise ValueError( # noqa: TRY004
f"Invalid value for `{k}` parameter. The values of request params cannot be an object."
)
@@ -267,11 +266,11 @@ def _request_params(
def _request_body_data(
self,
- stream_state: Optional[StreamState],
- stream_slice: Optional[StreamSlice],
- next_page_token: Optional[Mapping[str, Any]],
- extra_body_data: Optional[Union[Mapping[str, Any], str]] = None,
- ) -> Optional[Union[Mapping[str, Any], str]]:
+ stream_state: StreamState | None,
+ stream_slice: StreamSlice | None,
+ next_page_token: Mapping[str, Any] | None,
+ extra_body_data: Mapping[str, Any] | str | None = None,
+ ) -> Mapping[str, Any] | str | None:
"""
Specifies how to populate the body of the request with a non-JSON payload.
@@ -293,11 +292,11 @@ def _request_body_data(
def _request_body_json(
self,
- stream_state: Optional[StreamState],
- stream_slice: Optional[StreamSlice],
- next_page_token: Optional[Mapping[str, Any]],
- extra_body_json: Optional[Mapping[str, Any]] = None,
- ) -> Optional[Mapping[str, Any]]:
+ stream_state: StreamState | None,
+ stream_slice: StreamSlice | None,
+ next_page_token: Mapping[str, Any] | None,
+ extra_body_json: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | None:
"""
Specifies how to populate the body of the request with a JSON payload.
@@ -313,26 +312,26 @@ def _request_body_json(
extra_body_json,
)
if isinstance(options, str):
- raise ValueError("Request body json cannot be a string")
+ raise ValueError("Request body json cannot be a string") # noqa: TRY004
return options
@classmethod
def _join_url(cls, url_base: str, path: str) -> str:
return urljoin(url_base, path)
- def send_request(
+ def send_request( # noqa: PLR0913, PLR0917
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- path: Optional[str] = None,
- request_headers: Optional[Mapping[str, Any]] = None,
- request_params: Optional[Mapping[str, Any]] = None,
- request_body_data: Optional[Union[Mapping[str, Any], str]] = None,
- request_body_json: Optional[Mapping[str, Any]] = None,
- log_formatter: Optional[Callable[[requests.Response], Any]] = None,
- ) -> Optional[requests.Response]:
- request, response = self._http_client.send_request(
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ path: str | None = None,
+ request_headers: Mapping[str, Any] | None = None,
+ request_params: Mapping[str, Any] | None = None,
+ request_body_data: Mapping[str, Any] | str | None = None,
+ request_body_json: Mapping[str, Any] | None = None,
+ log_formatter: Callable[[requests.Response], Any] | None = None,
+ ) -> requests.Response | None:
+ request, response = self._http_client.send_request( # noqa: F841
http_method=self.get_method().value,
url=self._join_url(
self.get_url_base(),
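# A standalone sketch of the merge _get_request_options above describes: combine the
# options coming from the requester, the authenticator, and any extra options, raising on
# key collisions rather than silently overwriting. The helper name and values are
# illustrative; the CDK's combine_mappings behaves along these lines.

from collections.abc import Mapping
from typing import Any


def merge_options(mappings: list[Mapping[str, Any] | None]) -> Mapping[str, Any]:
    merged: dict[str, Any] = {}
    for mapping in mappings:
        for key, value in (mapping or {}).items():
            if key in merged:
                raise ValueError(f"Duplicate option key: {key}")
            merged[key] = value
    return merged


requester_params = {"page_size": 100}
auth_params = {"access_token": "..."}
print(merge_options([requester_params, auth_params, None]))  # {'page_size': 100, 'access_token': '...'}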
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/__init__.py b/airbyte_cdk/sources/declarative/requesters/paginators/__init__.py
index 3b077ec0c..9c5463752 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/__init__.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/__init__.py
@@ -12,6 +12,7 @@
PaginationStrategy,
)
+
__all__ = [
"DefaultPaginator",
"NoPagination",
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py
index 59255c75b..79fadee6a 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping, MutableMapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Mapping, MutableMapping, Optional, Union
+from typing import Any
import requests
@@ -97,13 +98,13 @@ class DefaultPaginator(Paginator):
pagination_strategy: PaginationStrategy
config: Config
- url_base: Union[InterpolatedString, str]
+ url_base: InterpolatedString | str
parameters: InitVar[Mapping[str, Any]]
decoder: Decoder = field(
default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={}))
)
- page_size_option: Optional[RequestOption] = None
- page_token_option: Optional[Union[RequestPath, RequestOption]] = None
+ page_size_option: RequestOption | None = None
+ page_token_option: RequestPath | RequestOption | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if self.page_size_option and not self.pagination_strategy.get_page_size():
@@ -113,7 +114,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if isinstance(self.url_base, str):
self.url_base = InterpolatedString(string=self.url_base, parameters=parameters)
- def get_initial_token(self) -> Optional[Any]:
+ def get_initial_token(self) -> Any | None: # noqa: ANN401
"""
Return the page token that should be used for the first request of a stream
@@ -126,9 +127,9 @@ def next_page_token(
self,
response: requests.Response,
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any] = None,
- ) -> Optional[Mapping[str, Any]]:
+ last_record: Record | None,
+ last_page_token_value: Any | None = None, # noqa: ANN401
+ ) -> Mapping[str, Any] | None:
next_page_token = self.pagination_strategy.next_page_token(
response=response,
last_page_size=last_page_size,
@@ -137,55 +138,53 @@ def next_page_token(
)
if next_page_token:
return {"next_page_token": next_page_token}
- else:
- return None
+ return None
- def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
+ def path(self, next_page_token: Mapping[str, Any] | None) -> str | None:
token = next_page_token.get("next_page_token") if next_page_token else None
if token and self.page_token_option and isinstance(self.page_token_option, RequestPath):
# Replace url base to only return the path
return str(token).replace(self.url_base.eval(self.config), "") # type: ignore # url_base is casted to a InterpolatedString in __post_init__
- else:
- return None
+ return None
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None,
) -> MutableMapping[str, Any]:
return self._get_request_options(RequestOptionType.request_parameter, next_page_token)
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, str]:
return self._get_request_options(RequestOptionType.header, next_page_token)
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_data, next_page_token)
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_json, next_page_token)
def _get_request_options(
- self, option_type: RequestOptionType, next_page_token: Optional[Mapping[str, Any]]
+ self, option_type: RequestOptionType, next_page_token: Mapping[str, Any] | None
) -> MutableMapping[str, Any]:
options = {}
@@ -228,7 +227,7 @@ def __init__(self, decorated: Paginator, maximum_number_of_pages: int = 5) -> No
self._decorated = decorated
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
- def get_initial_token(self) -> Optional[Any]:
+ def get_initial_token(self) -> Any | None: # noqa: ANN401
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
return self._decorated.get_initial_token()
@@ -236,9 +235,9 @@ def next_page_token(
self,
response: requests.Response,
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any] = None,
- ) -> Optional[Mapping[str, Any]]:
+ last_record: Record | None,
+ last_page_token_value: Any | None = None, # noqa: ANN401
+ ) -> Mapping[str, Any] | None:
if self._page_count >= self._maximum_number_of_pages:
return None
@@ -247,15 +246,15 @@ def next_page_token(
response, last_page_size, last_record, last_page_token_value
)
- def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
+ def path(self, next_page_token: Mapping[str, Any] | None) -> str | None:
return self._decorated.path(next_page_token)
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._decorated.get_request_params(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
@@ -264,9 +263,9 @@ def get_request_params(
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, str]:
return self._decorated.get_request_headers(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
@@ -275,10 +274,10 @@ def get_request_headers(
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
return self._decorated.get_request_body_data(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
)
@@ -286,9 +285,9 @@ def get_request_body_data(
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._decorated.get_request_body_json(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py
index 7de91f5e9..35973b3ec 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/no_pagination.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping, MutableMapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, MutableMapping, Optional, Union
+from typing import Any
import requests
@@ -19,53 +20,53 @@ class NoPagination(Paginator):
parameters: InitVar[Mapping[str, Any]]
- def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
+ def path(self, next_page_token: Mapping[str, Any] | None) -> str | None: # noqa: ARG002
return None
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> MutableMapping[str, Any]:
return {}
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, str]:
return {}
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
+ ) -> Mapping[str, Any] | str:
return {}
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
- def get_initial_token(self) -> Optional[Any]:
+ def get_initial_token(self) -> Any | None: # noqa: ANN401
return None
def next_page_token(
self,
- response: requests.Response,
- last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any],
- ) -> Optional[Mapping[str, Any]]:
+ response: requests.Response, # noqa: ARG002
+ last_page_size: int, # noqa: ARG002
+ last_record: Record | None, # noqa: ARG002
+ last_page_token_value: Any | None, # noqa: ANN401, ARG002
+ ) -> Mapping[str, Any] | None:
return {}
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py
index 8b1fea69b..80b42d41d 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/paginator.py
@@ -3,8 +3,9 @@
#
from abc import ABC, abstractmethod
+from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Any, Mapping, Optional
+from typing import Any
import requests
@@ -24,7 +25,7 @@ class Paginator(ABC, RequestOptionsProvider):
"""
@abstractmethod
- def get_initial_token(self) -> Optional[Any]:
+ def get_initial_token(self) -> Any | None: # noqa: ANN401
"""
Get the page token that should be included in the request to get the first page of records
"""
@@ -34,9 +35,9 @@ def next_page_token(
self,
response: requests.Response,
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any],
- ) -> Optional[Mapping[str, Any]]:
+ last_record: Record | None,
+ last_page_token_value: Any | None, # noqa: ANN401
+ ) -> Mapping[str, Any] | None:
"""
Returns the next_page_token to use to fetch the next page of records.
@@ -49,7 +50,7 @@ def next_page_token(
pass
@abstractmethod
- def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
+ def path(self, next_page_token: Mapping[str, Any] | None) -> str | None:
"""
Returns the URL path to hit to fetch the next page of records
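The pattern applied throughout these files is the same: `Optional[X]`/`Union[X, Y]` become PEP 604 unions and the generic containers move from `typing` to `collections.abc`. A minimal sketch of the resulting style, using a made-up helper (`build_token`) rather than CDK code:

from collections.abc import Mapping
from typing import Any


def build_token(raw: str | None = None) -> Mapping[str, Any] | None:
    # Equivalent to Optional[str] / Optional[Mapping[str, Any]] under the old style.
    # Evaluating `X | None` annotations at runtime needs Python 3.10+, unless
    # `from __future__ import annotations` defers evaluation.
    return {"next_page_token": raw} if raw else None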
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/__init__.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/__init__.py
index c1f9ff105..ae2663571 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/__init__.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/__init__.py
@@ -16,6 +16,7 @@
StopConditionPaginationStrategyDecorator,
)
+
__all__ = [
"CursorPaginationStrategy",
"CursorStopCondition",
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py
index e35c84c7c..34fd6b9d1 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/cursor_pagination_strategy.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Dict, Mapping, Optional, Union
+from typing import Any
import requests
@@ -33,11 +34,11 @@ class CursorPaginationStrategy(PaginationStrategy):
decoder (Decoder): decoder to decode the response
"""
- cursor_value: Union[InterpolatedString, str]
+ cursor_value: InterpolatedString | str
config: Config
parameters: InitVar[Mapping[str, Any]]
- page_size: Optional[int] = None
- stop_condition: Optional[Union[InterpolatedBoolean, str]] = None
+ page_size: int | None = None
+ stop_condition: InterpolatedBoolean | str | None = None
decoder: Decoder = field(
default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={}))
)
@@ -48,14 +49,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
else:
self._cursor_value = self.cursor_value
if isinstance(self.stop_condition, str):
- self._stop_condition: Optional[InterpolatedBoolean] = InterpolatedBoolean(
+ self._stop_condition: InterpolatedBoolean | None = InterpolatedBoolean(
condition=self.stop_condition, parameters=parameters
)
else:
self._stop_condition = self.stop_condition
@property
- def initial_token(self) -> Optional[Any]:
+ def initial_token(self) -> Any | None: # noqa: ANN401
"""
CursorPaginationStrategy does not have an initial value because the next cursor is typically included
in the response of the first request. For Resumable Full Refresh streams that checkpoint the page
@@ -67,14 +68,14 @@ def next_page_token(
self,
response: requests.Response,
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any] = None,
- ) -> Optional[Any]:
+ last_record: Record | None,
+ last_page_token_value: Any | None = None, # noqa: ANN401, ARG002
+ ) -> Any | None: # noqa: ANN401
decoded_response = next(self.decoder.decode(response))
# The default way that link is presented in requests.Response is a string of various links (last, next, etc). This
# is not indexable or useful for parsing the cursor, so we replace it with the link dictionary from response.links
- headers: Dict[str, Any] = dict(response.headers)
+ headers: dict[str, Any] = dict(response.headers)
headers["link"] = response.links
if self._stop_condition:
should_stop = self._stop_condition.eval(
@@ -93,7 +94,7 @@ def next_page_token(
last_record=last_record,
last_page_size=last_page_size,
)
- return token if token else None
+ return token if token else None # noqa: FURB110
- def get_page_size(self) -> Optional[int]:
+ def get_page_size(self) -> int | None:
return self.page_size
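The header substitution above (`headers["link"] = response.links`) works because `requests` already parses the `Link` header into a dict keyed by rel, which is much easier to interpolate against than the raw string. A small offline illustration (the URL is invented; `Response.links` is a real requests property):

import requests

resp = requests.Response()
resp.headers["Link"] = '<https://api.example.com/items?page=2>; rel="next"'
print(resp.headers.get("link"))   # raw string: '<https://...?page=2>; rel="next"' - not indexable
print(resp.links["next"]["url"])  # parsed form: 'https://api.example.com/items?page=2'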
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py
index 512d8143c..b61755257 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/offset_increment.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Mapping, Optional, Union
+from typing import Any
import requests
@@ -44,7 +45,7 @@ class OffsetIncrement(PaginationStrategy):
"""
config: Config
- page_size: Optional[Union[str, int]]
+ page_size: str | int | None
parameters: InitVar[Mapping[str, Any]]
decoder: Decoder = field(
default_factory=lambda: PaginationDecoderDecorator(decoder=JsonDecoder(parameters={}))
@@ -54,14 +55,14 @@ class OffsetIncrement(PaginationStrategy):
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
page_size = str(self.page_size) if isinstance(self.page_size, int) else self.page_size
if page_size:
- self._page_size: Optional[InterpolatedString] = InterpolatedString(
+ self._page_size: InterpolatedString | None = InterpolatedString(
page_size, parameters=parameters
)
else:
self._page_size = None
@property
- def initial_token(self) -> Optional[Any]:
+ def initial_token(self) -> Any | None: # noqa: ANN401
if self.inject_on_first_request:
return 0
return None
@@ -70,9 +71,9 @@ def next_page_token(
self,
response: requests.Response,
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any] = None,
- ) -> Optional[Any]:
+ last_record: Record | None, # noqa: ARG002
+ last_page_token_value: Any | None = None, # noqa: ANN401
+ ) -> Any | None: # noqa: ANN401
decoded_response = next(self.decoder.decode(response))
# Stop paginating when there are fewer records than the page size or the current page has no records
@@ -81,22 +82,20 @@ def next_page_token(
and last_page_size < self._page_size.eval(self.config, response=decoded_response)
) or last_page_size == 0:
return None
- elif last_page_token_value is None:
+ if last_page_token_value is None:
# If the OffsetIncrement strategy does not inject on the first request, the incoming last_page_token_value
# will be None. For this case, we assume that None was the first page and progress to the next offset
return 0 + last_page_size
- elif not isinstance(last_page_token_value, int):
- raise ValueError(
+ if not isinstance(last_page_token_value, int):
+ raise ValueError( # noqa: TRY004
f"Last page token value {last_page_token_value} for OffsetIncrement pagination strategy was not an integer"
)
- else:
- return last_page_token_value + last_page_size
+ return last_page_token_value + last_page_size
- def get_page_size(self) -> Optional[int]:
+ def get_page_size(self) -> int | None:
if self._page_size:
page_size = self._page_size.eval(self.config)
if not isinstance(page_size, int):
- raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}")
+ raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}") # noqa: TRY002
return page_size
- else:
- return None
+ return None
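The flattened branches above implement plain offset arithmetic: stop when the last page was short or empty, treat a missing token as offset zero, otherwise add the last page size. A toy walk-through of that progression (`next_offset` and the page sizes are stand-ins mirroring OffsetIncrement's control flow, not the CDK class):

def next_offset(last_page_size: int, last_offset: int | None, page_size: int = 100) -> int | None:
    if last_page_size < page_size or last_page_size == 0:
        return None            # short or empty page: no more records
    if last_offset is None:
        return last_page_size  # nothing injected on the first request; advance from 0
    return last_offset + last_page_size


token = None
tokens = []
for returned in (100, 100, 37):  # page sizes returned by a hypothetical API
    token = next_offset(returned, token)
    tokens.append(token)
print(tokens)  # [100, 200, None]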
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py
index 2e1643b56..8e2762f42 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/page_increment.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any
import requests
@@ -25,7 +26,7 @@ class PageIncrement(PaginationStrategy):
"""
config: Config
- page_size: Optional[Union[str, int]]
+ page_size: str | int | None
parameters: InitVar[Mapping[str, Any]]
start_from_page: int = 0
inject_on_first_request: bool = False
@@ -36,36 +37,35 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
else:
page_size = InterpolatedString(self.page_size, parameters=parameters).eval(self.config)
if not isinstance(page_size, int):
- raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}")
+ raise Exception(f"{page_size} is of type {type(page_size)}. Expected {int}") # noqa: TRY002
self._page_size = page_size
@property
- def initial_token(self) -> Optional[Any]:
+ def initial_token(self) -> Any | None: # noqa: ANN401
if self.inject_on_first_request:
return self.start_from_page
return None
def next_page_token(
self,
- response: requests.Response,
+ response: requests.Response, # noqa: ARG002
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any],
- ) -> Optional[Any]:
+ last_record: Record | None, # noqa: ARG002
+ last_page_token_value: Any | None, # noqa: ANN401
+ ) -> Any | None: # noqa: ANN401
# Stop paginating when there are fewer records than the page size or the current page has no records
if (self._page_size and last_page_size < self._page_size) or last_page_size == 0:
return None
- elif last_page_token_value is None:
+ if last_page_token_value is None:
# If the PageIncrement strategy does not inject on the first request, the incoming last_page_token_value
# may be None. When this is the case, we assume we've already requested the first page specified by
# start_from_page and must now get the next page
return self.start_from_page + 1
- elif not isinstance(last_page_token_value, int):
- raise ValueError(
+ if not isinstance(last_page_token_value, int):
+ raise ValueError( # noqa: TRY004
f"Last page token value {last_page_token_value} for PageIncrement pagination strategy was not an integer"
)
- else:
- return last_page_token_value + 1
+ return last_page_token_value + 1
- def get_page_size(self) -> Optional[int]:
+ def get_page_size(self) -> int | None:
return self._page_size
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py
index dae02ba13..04945ea28 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py
@@ -4,7 +4,7 @@
from abc import abstractmethod
from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
import requests
@@ -19,7 +19,7 @@ class PaginationStrategy:
@property
@abstractmethod
- def initial_token(self) -> Optional[Any]:
+ def initial_token(self) -> Any | None: # noqa: ANN401
"""
Return the initial value of the token
"""
@@ -29,9 +29,9 @@ def next_page_token(
self,
response: requests.Response,
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any],
- ) -> Optional[Any]:
+ last_record: Record | None,
+ last_page_token_value: Any | None, # noqa: ANN401
+ ) -> Any | None: # noqa: ANN401
"""
:param response: response to process
:param last_page_size: the number of records read from the response
@@ -42,7 +42,7 @@ def next_page_token(
pass
@abstractmethod
- def get_page_size(self) -> Optional[int]:
+ def get_page_size(self) -> int | None:
"""
:return: page size: The number of records to fetch in a page. Returns None if unspecified
"""
diff --git a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py
index 7c89ba552..65a241ab4 100644
--- a/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py
+++ b/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py
@@ -3,7 +3,7 @@
#
from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any
import requests
@@ -23,11 +23,11 @@ def is_met(self, record: Record) -> bool:
:param record: a record used to evaluate the condition
"""
- raise NotImplementedError()
+ raise NotImplementedError
class CursorStopCondition(PaginationStopCondition):
- def __init__(
+ def __init__( # noqa: ANN204
self,
cursor: DeclarativeCursor
| ConcurrentCursor, # migrate to use both old and concurrent versions
@@ -39,7 +39,7 @@ def is_met(self, record: Record) -> bool:
class StopConditionPaginationStrategyDecorator(PaginationStrategy):
- def __init__(self, _delegate: PaginationStrategy, stop_condition: PaginationStopCondition):
+ def __init__(self, _delegate: PaginationStrategy, stop_condition: PaginationStopCondition): # noqa: ANN204
self._delegate = _delegate
self._stop_condition = stop_condition
@@ -47,9 +47,9 @@ def next_page_token(
self,
response: requests.Response,
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any] = None,
- ) -> Optional[Any]:
+ last_record: Record | None,
+ last_page_token_value: Any | None = None, # noqa: ANN401
+ ) -> Any | None: # noqa: ANN401
# We evaluate in reverse order because the assumption is that most of the APIs using data feed structure
# will return records in descending order. In terms of performance/memory, we return the records lazily
if last_record and self._stop_condition.is_met(last_record):
@@ -58,9 +58,9 @@ def next_page_token(
response, last_page_size, last_record, last_page_token_value
)
- def get_page_size(self) -> Optional[int]:
+ def get_page_size(self) -> int | None:
return self._delegate.get_page_size()
@property
- def initial_token(self) -> Optional[Any]:
+ def initial_token(self) -> Any | None: # noqa: ANN401
return self._delegate.initial_token
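StopConditionPaginationStrategyDecorator is a plain wrapper: it consults a stop condition before delegating to the wrapped strategy. A stripped-down sketch of that shape, with toy classes (CountingStrategy, StopOnCursor, StopConditionDecorator) standing in for the CDK types:

from typing import Any


class CountingStrategy:
    """Toy delegate: keeps paging until told otherwise."""

    def next_page_token(self, last_record: dict[str, Any] | None, token: int | None) -> int | None:
        return (token or 0) + 1


class StopOnCursor:
    """Toy stop condition: met once a record crosses a cursor boundary."""

    def __init__(self, boundary: int) -> None:
        self._boundary = boundary

    def is_met(self, record: dict[str, Any]) -> bool:
        return record["updated_at"] >= self._boundary


class StopConditionDecorator:
    def __init__(self, delegate: CountingStrategy, condition: StopOnCursor) -> None:
        self._delegate = delegate
        self._condition = condition

    def next_page_token(self, last_record: dict[str, Any] | None, token: int | None) -> int | None:
        if last_record and self._condition.is_met(last_record):
            return None  # stop paginating; the cursor already covers this record
        return self._delegate.next_page_token(last_record, token)


strategy = StopConditionDecorator(CountingStrategy(), StopOnCursor(boundary=1_700_000_000))
print(strategy.next_page_token({"updated_at": 1_600_000_000}, token=3))  # 4: keep paging
print(strategy.next_page_token({"updated_at": 1_800_000_000}, token=4))  # None: stop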
diff --git a/airbyte_cdk/sources/declarative/requesters/request_option.py b/airbyte_cdk/sources/declarative/requesters/request_option.py
index d13d20566..9f8699a9e 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_option.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_option.py
@@ -2,9 +2,10 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
from enum import Enum
-from typing import Any, Mapping, Union
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
@@ -30,7 +31,7 @@ class RequestOption:
inject_into (RequestOptionType): Describes where in the HTTP request to inject the parameter
"""
- field_name: Union[InterpolatedString, str]
+ field_name: InterpolatedString | str
inject_into: RequestOptionType
parameters: InitVar[Mapping[str, Any]]
diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/__init__.py b/airbyte_cdk/sources/declarative/requesters/request_options/__init__.py
index a63705832..b77f51ee5 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_options/__init__.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_options/__init__.py
@@ -15,6 +15,7 @@
RequestOptionsProvider,
)
+
__all__ = [
"DatetimeBasedRequestOptionsProvider",
"DefaultRequestOptionsProvider",
diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py
index 05e06db71..846b5af16 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_options/datetime_based_request_options_provider.py
@@ -2,8 +2,9 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping, MutableMapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, MutableMapping, Optional, Union
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.requesters.request_option import (
@@ -25,10 +26,10 @@ class DatetimeBasedRequestOptionsProvider(RequestOptionsProvider):
config: Config
parameters: InitVar[Mapping[str, Any]]
- start_time_option: Optional[RequestOption] = None
- end_time_option: Optional[RequestOption] = None
- partition_field_start: Optional[str] = None
- partition_field_end: Optional[str] = None
+ start_time_option: RequestOption | None = None
+ end_time_option: RequestOption | None = None
+ partition_field_start: str | None = None
+ partition_field_end: str | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._partition_field_start = InterpolatedString.create(
@@ -41,41 +42,41 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.request_parameter, stream_slice)
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.header, stream_slice)
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
+ ) -> Mapping[str, Any] | str:
return self._get_request_options(RequestOptionType.body_data, stream_slice)
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return self._get_request_options(RequestOptionType.body_json, stream_slice)
def _get_request_options(
- self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]
+ self, option_type: RequestOptionType, stream_slice: StreamSlice | None
) -> Mapping[str, Any]:
options: MutableMapping[str, Any] = {}
if not stream_slice:
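This provider injects values only when the requested option type matches where `start_time_option`/`end_time_option` are configured to go, pulling the actual datetimes from the current slice. A small illustration of that routing with invented names (`OptionType`, `window_options`, the `since`/`until` fields), not the CDK implementation:

from enum import Enum


class OptionType(Enum):
    request_parameter = "request_parameter"
    body_json = "body_json"


def window_options(option_type: OptionType, stream_slice: dict | None) -> dict:
    # Emit the slice boundaries only for the location each option is configured for.
    if not stream_slice:
        return {}
    configured = {
        "since": (OptionType.request_parameter, stream_slice.get("start_time")),
        "until": (OptionType.request_parameter, stream_slice.get("end_time")),
    }
    return {
        name: value
        for name, (where, value) in configured.items()
        if where is option_type and value
    }


print(window_options(OptionType.request_parameter, {"start_time": "2024-01-01", "end_time": "2024-01-31"}))
# {'since': '2024-01-01', 'until': '2024-01-31'}
print(window_options(OptionType.body_json, {"start_time": "2024-01-01"}))
# {}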
diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py
index 449da977f..f2d04ac0b 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_options/default_request_options_provider.py
@@ -2,8 +2,9 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any
from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import (
RequestOptionsProvider,
@@ -26,35 +27,35 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
+ ) -> Mapping[str, Any] | str:
return {}
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
return {}
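DefaultRequestOptionsProvider implements the full interface but ignores every argument, which is why each parameter carries `# noqa: ARG002` (Ruff's unused-method-argument rule) instead of being renamed or removed: the signature has to keep matching the base class. The same idea in miniature, with a made-up `NoOpOptionsProvider`:

from collections.abc import Mapping
from typing import Any


class NoOpOptionsProvider:
    """Interface-compatible stub: the parameters exist only to satisfy the contract."""

    def get_request_params(
        self,
        *,
        stream_state: Mapping[str, Any] | None = None,  # noqa: ARG002
        stream_slice: Mapping[str, Any] | None = None,  # noqa: ARG002
        next_page_token: Mapping[str, Any] | None = None,  # noqa: ARG002
    ) -> Mapping[str, Any]:
        return {}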
diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py
index 6403417c9..16cfa0fc8 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Mapping, Optional, Union
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation.interpolated_nested_mapping import (
InterpolatedNestedMapping,
@@ -20,14 +21,12 @@ class InterpolatedNestedRequestInputProvider:
"""
parameters: InitVar[Mapping[str, Any]]
- request_inputs: Optional[Union[str, NestedMapping]] = field(default=None)
+ request_inputs: str | NestedMapping | None = field(default=None)
config: Config = field(default_factory=dict)
- _interpolator: Optional[Union[InterpolatedString, InterpolatedNestedMapping]] = field(
- init=False, repr=False, default=None
- )
- _request_inputs: Optional[Union[str, NestedMapping]] = field(
+ _interpolator: InterpolatedString | InterpolatedNestedMapping | None = field(
init=False, repr=False, default=None
)
+ _request_inputs: str | NestedMapping | None = field(init=False, repr=False, default=None)
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._request_inputs = self.request_inputs or {}
@@ -42,9 +41,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def eval_request_inputs(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Returns the request inputs to set on an outgoing HTTP request
diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py
index 0278df351..aabdbbaba 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Mapping, Optional, Tuple, Type, Union
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
@@ -17,14 +18,12 @@ class InterpolatedRequestInputProvider:
"""
parameters: InitVar[Mapping[str, Any]]
- request_inputs: Optional[Union[str, Mapping[str, str]]] = field(default=None)
+ request_inputs: str | Mapping[str, str] | None = field(default=None)
config: Config = field(default_factory=dict)
- _interpolator: Optional[Union[InterpolatedString, InterpolatedMapping]] = field(
- init=False, repr=False, default=None
- )
- _request_inputs: Optional[Union[str, Mapping[str, str]]] = field(
+ _interpolator: InterpolatedString | InterpolatedMapping | None = field(
init=False, repr=False, default=None
)
+ _request_inputs: str | Mapping[str, str] | None = field(init=False, repr=False, default=None)
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._request_inputs = self.request_inputs or {}
@@ -37,11 +36,11 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def eval_request_inputs(
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- valid_key_types: Optional[Tuple[Type[Any]]] = None,
- valid_value_types: Optional[Tuple[Type[Any], ...]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ valid_key_types: tuple[type[Any]] | None = None,
+ valid_value_types: tuple[type[Any], ...] | None = None,
) -> Mapping[str, Any]:
"""
Returns the request inputs to set on an outgoing HTTP request
diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py
index c327b83da..a284a542e 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_options_provider.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping, MutableMapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Mapping, MutableMapping, Optional, Union
+from typing import Any, Union
from typing_extensions import deprecated
@@ -20,7 +21,8 @@
from airbyte_cdk.sources.source import ExperimentalClassWarning
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
-RequestInput = Union[str, Mapping[str, str]]
+
+RequestInput = Union[str, Mapping[str, str]] # noqa: UP007
ValidRequestTypes = (str, list)
@@ -39,10 +41,10 @@ class InterpolatedRequestOptionsProvider(RequestOptionsProvider):
parameters: InitVar[Mapping[str, Any]]
config: Config = field(default_factory=dict)
- request_parameters: Optional[RequestInput] = None
- request_headers: Optional[RequestInput] = None
- request_body_data: Optional[RequestInput] = None
- request_body_json: Optional[NestedMapping] = None
+ request_parameters: RequestInput | None = None
+ request_headers: RequestInput | None = None
+ request_body_data: RequestInput | None = None
+ request_body_json: NestedMapping | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if self.request_parameters is None:
@@ -75,9 +77,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> MutableMapping[str, Any]:
interpolated_value = self._parameter_interpolator.eval_request_inputs(
stream_state,
@@ -93,9 +95,9 @@ def get_request_params(
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._headers_interpolator.eval_request_inputs(
stream_state, stream_slice, next_page_token
@@ -104,10 +106,10 @@ def get_request_headers(
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
return self._body_data_interpolator.eval_request_inputs(
stream_state,
stream_slice,
@@ -119,9 +121,9 @@ def get_request_body_data(
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
return self._body_json_interpolator.eval_request_inputs(
stream_state, stream_slice, next_page_token
@@ -147,18 +149,17 @@ def request_options_contain_stream_state(self) -> bool:
@staticmethod
def _check_if_interpolation_uses_stream_state(
- request_input: Optional[Union[RequestInput, NestedMapping]],
+ request_input: RequestInput | NestedMapping | None,
) -> bool:
if not request_input:
return False
- elif isinstance(request_input, str):
+ if isinstance(request_input, str):
return "stream_state" in request_input
- else:
- for key, val in request_input.items():
- # Covers the case of RequestInput in the form of a string or Mapping[str, str]. It also covers the case
- # of a NestedMapping where the value is a string.
- # Note: Doesn't account for nested mappings for request_body_json, but I don't see stream_state used in that way
- # in our code
- if "stream_state" in key or (isinstance(val, str) and "stream_state" in val):
- return True
+ for key, val in request_input.items():
+ # Covers the case of RequestInput in the form of a string or Mapping[str, str]. It also covers the case
+ # of a NestedMapping where the value is a string.
+ # Note: Doesn't account for nested mappings for request_body_json, but I don't see stream_state used in that way
+ # in our code
+ if "stream_state" in key or (isinstance(val, str) and "stream_state" in val):
+ return True
return False
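The restructured `_check_if_interpolation_uses_stream_state` is a simple substring scan over the configured request inputs. A rough equivalent (`uses_stream_state` is a reimplementation sketch and the sample inputs are invented), useful for seeing which shapes get flagged and that nested mappings still are not traversed:

def uses_stream_state(request_input: str | dict | None) -> bool:
    if not request_input:
        return False
    if isinstance(request_input, str):
        return "stream_state" in request_input
    return any(
        "stream_state" in key or (isinstance(val, str) and "stream_state" in val)
        for key, val in request_input.items()
    )


print(uses_stream_state("{{ stream_state['cursor'] }}"))              # True
print(uses_stream_state({"since": "{{ stream_state.updated_at }}"}))  # True
print(uses_stream_state({"page_size": "100"}))                        # False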
diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py
index f0a94ecb9..f5147d481 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py
@@ -3,8 +3,9 @@
#
from abc import abstractmethod
+from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Any, Mapping, Optional, Union
+from typing import Any
from airbyte_cdk.sources.types import StreamSlice, StreamState
@@ -25,9 +26,9 @@ class RequestOptionsProvider:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Specifies the query parameters that should be set on an outgoing HTTP request given the inputs.
@@ -40,9 +41,9 @@ def get_request_params(
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method."""
@@ -50,10 +51,10 @@ def get_request_headers(
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
"""
Specifies how to populate the body of the request with a non-JSON payload.
@@ -68,9 +69,9 @@ def get_request_body_data(
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Specifies how to populate the body of the request with a JSON payload.
diff --git a/airbyte_cdk/sources/declarative/requesters/request_path.py b/airbyte_cdk/sources/declarative/requesters/request_path.py
index 378ea6220..d6b289f9d 100644
--- a/airbyte_cdk/sources/declarative/requesters/request_path.py
+++ b/airbyte_cdk/sources/declarative/requesters/request_path.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping
+from typing import Any
@dataclass
diff --git a/airbyte_cdk/sources/declarative/requesters/requester.py b/airbyte_cdk/sources/declarative/requesters/requester.py
index 604b2faba..030cb5155 100644
--- a/airbyte_cdk/sources/declarative/requesters/requester.py
+++ b/airbyte_cdk/sources/declarative/requesters/requester.py
@@ -3,8 +3,9 @@
#
from abc import abstractmethod
+from collections.abc import Callable, Mapping, MutableMapping
from enum import Enum
-from typing import Any, Callable, Mapping, MutableMapping, Optional, Union
+from typing import Any
import requests
@@ -44,9 +45,9 @@ def get_url_base(self) -> str:
def get_path(
self,
*,
- stream_state: Optional[StreamState],
- stream_slice: Optional[StreamSlice],
- next_page_token: Optional[Mapping[str, Any]],
+ stream_state: StreamState | None,
+ stream_slice: StreamSlice | None,
+ next_page_token: Mapping[str, Any] | None,
) -> str:
"""
Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity"
@@ -62,9 +63,9 @@ def get_method(self) -> HttpMethod:
def get_request_params(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> MutableMapping[str, Any]:
"""
Specifies the query parameters that should be set on an outgoing HTTP request given the inputs.
@@ -76,9 +77,9 @@ def get_request_params(
def get_request_headers(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method.
@@ -88,10 +89,10 @@ def get_request_headers(
def get_request_body_data(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
"""
Specifies how to populate the body of the request with a non-JSON payload.
@@ -106,9 +107,9 @@ def get_request_body_data(
def get_request_body_json(
self,
*,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Specifies how to populate the body of the request with a JSON payload.
@@ -117,18 +118,18 @@ def get_request_body_json(
"""
@abstractmethod
- def send_request(
+ def send_request( # noqa: PLR0913, PLR0917
self,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- path: Optional[str] = None,
- request_headers: Optional[Mapping[str, Any]] = None,
- request_params: Optional[Mapping[str, Any]] = None,
- request_body_data: Optional[Union[Mapping[str, Any], str]] = None,
- request_body_json: Optional[Mapping[str, Any]] = None,
- log_formatter: Optional[Callable[[requests.Response], Any]] = None,
- ) -> Optional[requests.Response]:
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ path: str | None = None,
+ request_headers: Mapping[str, Any] | None = None,
+ request_params: Mapping[str, Any] | None = None,
+ request_body_data: Mapping[str, Any] | str | None = None,
+ request_body_json: Mapping[str, Any] | None = None,
+ log_formatter: Callable[[requests.Response], Any] | None = None,
+ ) -> requests.Response | None:
"""
Sends a request and returns the response. Might return no response if the error handler chooses to ignore the response or throw an exception in case of an error.
If path is set, the path configured on the requester itself is ignored.
diff --git a/airbyte_cdk/sources/declarative/resolvers/__init__.py b/airbyte_cdk/sources/declarative/resolvers/__init__.py
index dba2f60b8..ba2de5695 100644
--- a/airbyte_cdk/sources/declarative/resolvers/__init__.py
+++ b/airbyte_cdk/sources/declarative/resolvers/__init__.py
@@ -2,7 +2,7 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
-from typing import Mapping
+from collections.abc import Mapping
from pydantic.v1 import BaseModel
@@ -25,6 +25,7 @@
HttpComponentsResolver,
)
+
COMPONENTS_RESOLVER_TYPE_MAPPING: Mapping[str, type[BaseModel]] = {
"HttpComponentsResolver": HttpComponentsResolverModel,
"ConfigComponentsResolver": ConfigComponentsResolverModel,
diff --git a/airbyte_cdk/sources/declarative/resolvers/components_resolver.py b/airbyte_cdk/sources/declarative/resolvers/components_resolver.py
index 5975b3082..6507b99d3 100644
--- a/airbyte_cdk/sources/declarative/resolvers/components_resolver.py
+++ b/airbyte_cdk/sources/declarative/resolvers/components_resolver.py
@@ -3,12 +3,13 @@
#
from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Type, Union
+from typing import Any, Union
from typing_extensions import deprecated
-from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
+from airbyte_cdk.sources.declarative.interpolation import InterpolatedString # noqa: TC001
from airbyte_cdk.sources.source import ExperimentalClassWarning
@@ -18,9 +19,9 @@ class ComponentMappingDefinition:
what field in the stream template should be updated with value, supporting dynamic interpolation
and type enforcement."""
- field_path: List["InterpolatedString"]
+ field_path: list["InterpolatedString"]
value: Union["InterpolatedString", str]
- value_type: Optional[Type[Any]]
+ value_type: type[Any] | None
parameters: InitVar[Mapping[str, Any]]
@@ -30,9 +31,9 @@ class ResolvedComponentMappingDefinition:
what field in the stream template should be updated with value, supporting dynamic interpolation
and type enforcement."""
- field_path: List["InterpolatedString"]
+ field_path: list["InterpolatedString"]
value: "InterpolatedString"
- value_type: Optional[Type[Any]]
+ value_type: type[Any] | None
parameters: InitVar[Mapping[str, Any]]
@@ -45,8 +46,8 @@ class ComponentsResolver(ABC):
@abstractmethod
def resolve_components(
- self, stream_template_config: Dict[str, Any]
- ) -> Iterable[Dict[str, Any]]:
+ self, stream_template_config: dict[str, Any]
+ ) -> Iterable[dict[str, Any]]:
"""
Maps and populates values into a stream template configuration.
:param stream_template_config: The stream template with placeholders for components.
diff --git a/airbyte_cdk/sources/declarative/resolvers/config_components_resolver.py b/airbyte_cdk/sources/declarative/resolvers/config_components_resolver.py
index 0308ea5da..c0b276ab3 100644
--- a/airbyte_cdk/sources/declarative/resolvers/config_components_resolver.py
+++ b/airbyte_cdk/sources/declarative/resolvers/config_components_resolver.py
@@ -2,9 +2,10 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Iterable, Mapping
from copy import deepcopy
from dataclasses import InitVar, dataclass, field
-from typing import Any, Dict, Iterable, List, Mapping, Union
+from typing import Any
import dpath
from typing_extensions import deprecated
@@ -26,7 +27,7 @@ class StreamConfig:
Identifies stream config details for dynamic schema extraction and processing.
"""
- configs_pointer: List[Union[InterpolatedString, str]]
+ configs_pointer: list[InterpolatedString | str]
parameters: InitVar[Mapping[str, Any]]
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
@@ -50,9 +51,9 @@ class ConfigComponentsResolver(ComponentsResolver):
stream_config: StreamConfig
config: Config
- components_mapping: List[ComponentMappingDefinition]
+ components_mapping: list[ComponentMappingDefinition]
parameters: InitVar[Mapping[str, Any]]
- _resolved_components: List[ResolvedComponentMappingDefinition] = field(
+ _resolved_components: list[ResolvedComponentMappingDefinition] = field(
init=False, repr=False, default_factory=list
)
@@ -65,7 +66,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
"""
for component_mapping in self.components_mapping:
- if isinstance(component_mapping.value, (str, InterpolatedString)):
+ if isinstance(component_mapping.value, (str, InterpolatedString)): # noqa: UP038
interpolated_value = (
InterpolatedString.create(component_mapping.value, parameters=parameters)
if isinstance(component_mapping.value, str)
@@ -86,7 +87,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
)
)
else:
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"Expected a string or InterpolatedString for value in mapping: {component_mapping}"
)
@@ -104,8 +105,8 @@ def _stream_config(self) -> Iterable[Mapping[str, Any]]:
return stream_config
def resolve_components(
- self, stream_template_config: Dict[str, Any]
- ) -> Iterable[Dict[str, Any]]:
+ self, stream_template_config: dict[str, Any]
+ ) -> Iterable[dict[str, Any]]:
"""
Resolves components in the stream template configuration by populating values.
diff --git a/airbyte_cdk/sources/declarative/resolvers/http_components_resolver.py b/airbyte_cdk/sources/declarative/resolvers/http_components_resolver.py
index 6e85fc578..8213c9af0 100644
--- a/airbyte_cdk/sources/declarative/resolvers/http_components_resolver.py
+++ b/airbyte_cdk/sources/declarative/resolvers/http_components_resolver.py
@@ -2,9 +2,10 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Iterable, Mapping
from copy import deepcopy
from dataclasses import InitVar, dataclass, field
-from typing import Any, Dict, Iterable, List, Mapping
+from typing import Any
import dpath
from typing_extensions import deprecated
@@ -35,9 +36,9 @@ class HttpComponentsResolver(ComponentsResolver):
retriever: Retriever
config: Config
- components_mapping: List[ComponentMappingDefinition]
+ components_mapping: list[ComponentMappingDefinition]
parameters: InitVar[Mapping[str, Any]]
- _resolved_components: List[ResolvedComponentMappingDefinition] = field(
+ _resolved_components: list[ResolvedComponentMappingDefinition] = field(
init=False, repr=False, default_factory=list
)
@@ -49,7 +50,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
parameters (Mapping[str, Any]): Parameters for interpolation.
"""
for component_mapping in self.components_mapping:
- if isinstance(component_mapping.value, (str, InterpolatedString)):
+ if isinstance(component_mapping.value, (str, InterpolatedString)): # noqa: UP038
interpolated_value = (
InterpolatedString.create(component_mapping.value, parameters=parameters)
if isinstance(component_mapping.value, str)
@@ -70,13 +71,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
)
)
else:
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"Expected a string or InterpolatedString for value in mapping: {component_mapping}"
)
def resolve_components(
- self, stream_template_config: Dict[str, Any]
- ) -> Iterable[Dict[str, Any]]:
+ self, stream_template_config: dict[str, Any]
+ ) -> Iterable[dict[str, Any]]:
"""
Resolves components in the stream template configuration by populating values.
diff --git a/airbyte_cdk/sources/declarative/retrievers/__init__.py b/airbyte_cdk/sources/declarative/retrievers/__init__.py
index 177d141a3..ce6ab5842 100644
--- a/airbyte_cdk/sources/declarative/retrievers/__init__.py
+++ b/airbyte_cdk/sources/declarative/retrievers/__init__.py
@@ -9,4 +9,5 @@
SimpleRetrieverTestReadDecorator,
)
+
__all__ = ["Retriever", "SimpleRetriever", "SimpleRetrieverTestReadDecorator", "AsyncRetriever"]
diff --git a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py
index 1b8860289..f362ce606 100644
--- a/airbyte_cdk/sources/declarative/retrievers/async_retriever.py
+++ b/airbyte_cdk/sources/declarative/retrievers/async_retriever.py
@@ -1,8 +1,9 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+from collections.abc import Iterable, Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Iterable, Mapping, Optional
+from typing import Any
from typing_extensions import deprecated
@@ -58,7 +59,7 @@ def _get_stream_state(self) -> StreamState:
return self.state
def _validate_and_get_stream_slice_partition(
- self, stream_slice: Optional[StreamSlice] = None
+ self, stream_slice: StreamSlice | None = None
) -> AsyncPartition:
"""
Validates the stream_slice argument and returns the partition from it.
@@ -80,13 +81,13 @@ def _validate_and_get_stream_slice_partition(
)
return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices
- def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
+ def stream_slices(self) -> Iterable[StreamSlice | None]:
return self.stream_slicer.stream_slices()
def read_records(
self,
records_schema: Mapping[str, Any],
- stream_slice: Optional[StreamSlice] = None,
+ stream_slice: StreamSlice | None = None,
) -> Iterable[StreamData]:
stream_state: StreamState = self._get_stream_state()
partition: AsyncPartition = self._validate_and_get_stream_slice_partition(stream_slice)
diff --git a/airbyte_cdk/sources/declarative/retrievers/retriever.py b/airbyte_cdk/sources/declarative/retrievers/retriever.py
index 155de5782..f4ba620b0 100644
--- a/airbyte_cdk/sources/declarative/retrievers/retriever.py
+++ b/airbyte_cdk/sources/declarative/retrievers/retriever.py
@@ -3,7 +3,8 @@
#
from abc import abstractmethod
-from typing import Any, Iterable, Mapping, Optional
+from collections.abc import Iterable, Mapping
+from typing import Any
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import StreamSlice
from airbyte_cdk.sources.streams.core import StreamData
@@ -19,7 +20,7 @@ class Retriever:
def read_records(
self,
records_schema: Mapping[str, Any],
- stream_slice: Optional[StreamSlice] = None,
+ stream_slice: StreamSlice | None = None,
) -> Iterable[StreamData]:
"""
Fetch a stream's records from an HTTP API source
@@ -30,7 +31,7 @@ def read_records(
"""
@abstractmethod
- def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
+ def stream_slices(self) -> Iterable[StreamSlice | None]:
"""Returns the stream slices"""
@property
diff --git a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py
index d167a84bc..1413faf23 100644
--- a/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py
+++ b/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py
@@ -3,10 +3,11 @@
#
import json
+from collections.abc import Callable, Iterable, Mapping
from dataclasses import InitVar, dataclass, field
from functools import partial
from itertools import islice
-from typing import Any, Callable, Iterable, List, Mapping, Optional, Set, Tuple, Union
+from typing import Any
import requests
@@ -32,6 +33,7 @@
from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
from airbyte_cdk.utils.mapping_helpers import combine_mappings
+
FULL_REFRESH_SYNC_COMPLETE_KEY = "__ab_full_refresh_sync_complete"
@@ -64,17 +66,17 @@ class SimpleRetriever(Retriever):
config: Config
parameters: InitVar[Mapping[str, Any]]
name: str
- _name: Union[InterpolatedString, str] = field(init=False, repr=False, default="")
- primary_key: Optional[Union[str, List[str], List[List[str]]]]
+ _name: InterpolatedString | str = field(init=False, repr=False, default="")
+ primary_key: str | list[str] | list[list[str]] | None
_primary_key: str = field(init=False, repr=False, default="")
- paginator: Optional[Paginator] = None
+ paginator: Paginator | None = None
stream_slicer: StreamSlicer = field(
default_factory=lambda: SinglePartitionRouter(parameters={})
)
request_option_provider: RequestOptionsProvider = field(
default_factory=lambda: DefaultRequestOptionsProvider(parameters={})
)
- cursor: Optional[DeclarativeCursor] = None
+ cursor: DeclarativeCursor | None = None
ignore_stream_slicer_parameters_on_paginated_requests: bool = False
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
@@ -103,8 +105,10 @@ def name(self, value: str) -> None:
self._name = value
def _get_mapping(
- self, method: Callable[..., Optional[Union[Mapping[str, Any], str]]], **kwargs: Any
- ) -> Tuple[Union[Mapping[str, Any], str], Set[str]]:
+ self,
+ method: Callable[..., Mapping[str, Any] | str | None],
+ **kwargs: Any, # noqa: ANN401
+ ) -> tuple[Mapping[str, Any] | str, set[str]]:
"""
Get mapping from the provided method, and get the keys of the mapping.
If the method returns a string, it will return the string and an empty set.
@@ -116,18 +120,18 @@ def _get_mapping(
def _get_request_options(
self,
- stream_state: Optional[StreamData],
- stream_slice: Optional[StreamSlice],
- next_page_token: Optional[Mapping[str, Any]],
- paginator_method: Callable[..., Optional[Union[Mapping[str, Any], str]]],
- stream_slicer_method: Callable[..., Optional[Union[Mapping[str, Any], str]]],
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamData | None,
+ stream_slice: StreamSlice | None,
+ next_page_token: Mapping[str, Any] | None,
+ paginator_method: Callable[..., Mapping[str, Any] | str | None],
+ stream_slicer_method: Callable[..., Mapping[str, Any] | str | None],
+ ) -> Mapping[str, Any] | str:
"""
Get the request_option from the paginator and the stream slicer.
Raise a ValueError if there's a key collision
Returned merged mapping otherwise
"""
- # FIXME we should eventually remove the usage of stream_state as part of the interpolation
+ # FIXME we should eventually remove the usage of stream_state as part of the interpolation # noqa: FIX001, TD001, TD004
mappings = [
paginator_method(
stream_state=stream_state,
@@ -147,9 +151,9 @@ def _get_request_options(
def _request_headers(
self,
- stream_state: Optional[StreamData] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamData | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Specifies request headers.
@@ -163,14 +167,14 @@ def _request_headers(
self.stream_slicer.get_request_headers,
)
if isinstance(headers, str):
- raise ValueError("Request headers cannot be a string")
+ raise ValueError("Request headers cannot be a string") # noqa: TRY004
return {str(k): str(v) for k, v in headers.items()}
def _request_params(
self,
- stream_state: Optional[StreamData] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: StreamData | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Specifies the query parameters that should be set on an outgoing HTTP request given the inputs.
@@ -185,15 +189,15 @@ def _request_params(
self.request_option_provider.get_request_params,
)
if isinstance(params, str):
- raise ValueError("Request params cannot be a string")
+ raise ValueError("Request params cannot be a string") # noqa: TRY004
return params
def _request_body_data(
self,
- stream_state: Optional[StreamData] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Union[Mapping[str, Any], str]:
+ stream_state: StreamData | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | str:
"""
Specifies how to populate the body of the request with a non-JSON payload.
@@ -213,10 +217,10 @@ def _request_body_data(
def _request_body_json(
self,
- stream_state: Optional[StreamData] = None,
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Optional[Mapping[str, Any]]:
+ stream_state: StreamData | None = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> Mapping[str, Any] | None:
"""
Specifies how to populate the body of the request with a JSON payload.
@@ -230,10 +234,10 @@ def _request_body_json(
self.request_option_provider.get_request_body_json,
)
if isinstance(body_json, str):
- raise ValueError("Request body json cannot be a string")
+ raise ValueError("Request body json cannot be a string") # noqa: TRY004
return body_json
- def _paginator_path(self, next_page_token: Optional[Mapping[str, Any]] = None) -> Optional[str]:
+ def _paginator_path(self, next_page_token: Mapping[str, Any] | None = None) -> str | None:
"""
If the paginator points to a path, follow it, else return nothing so the requester is used.
:param next_page_token:
@@ -243,11 +247,11 @@ def _paginator_path(self, next_page_token: Optional[Mapping[str, Any]] = None) -
def _parse_response(
self,
- response: Optional[requests.Response],
+ response: requests.Response | None,
stream_state: StreamState,
records_schema: Mapping[str, Any],
- stream_slice: Optional[StreamSlice] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_slice: StreamSlice | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Iterable[Record]:
if not response:
yield from []
@@ -261,7 +265,7 @@ def _parse_response(
)
@property # type: ignore
- def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+ def primary_key(self) -> str | list[str] | list[list[str]] | None:
"""The stream's primary key"""
return self._primary_key
@@ -274,9 +278,9 @@ def _next_page_token(
self,
response: requests.Response,
last_page_size: int,
- last_record: Optional[Record],
- last_page_token_value: Optional[Any],
- ) -> Optional[Mapping[str, Any]]:
+ last_record: Record | None,
+ last_page_token_value: Any | None, # noqa: ANN401
+ ) -> Mapping[str, Any] | None:
"""
Specifies a pagination strategy.
@@ -295,8 +299,8 @@ def _fetch_next_page(
self,
stream_state: Mapping[str, Any],
stream_slice: StreamSlice,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Optional[requests.Response]:
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> requests.Response | None:
return self.requester.send_request(
path=self._paginator_path(next_page_token=next_page_token),
stream_state=stream_state,
@@ -327,20 +331,20 @@ def _fetch_next_page(
# This logic is similar to _read_pages in the HttpStream class. When making changes here, consider making changes there as well.
def _read_pages(
self,
- records_generator_fn: Callable[[Optional[requests.Response]], Iterable[Record]],
+ records_generator_fn: Callable[[requests.Response | None], Iterable[Record]],
stream_state: Mapping[str, Any],
stream_slice: StreamSlice,
) -> Iterable[Record]:
pagination_complete = False
initial_token = self._paginator.get_initial_token()
- next_page_token: Optional[Mapping[str, Any]] = (
+ next_page_token: Mapping[str, Any] | None = (
{"next_page_token": initial_token} if initial_token else None
)
while not pagination_complete:
response = self._fetch_next_page(stream_state, stream_slice, next_page_token)
last_page_size = 0
- last_record: Optional[Record] = None
+ last_record: Record | None = None
for record in records_generator_fn(response):
last_page_size += 1
last_record = record
@@ -366,21 +370,21 @@ def _read_pages(
def _read_single_page(
self,
- records_generator_fn: Callable[[Optional[requests.Response]], Iterable[Record]],
+ records_generator_fn: Callable[[requests.Response | None], Iterable[Record]],
stream_state: Mapping[str, Any],
stream_slice: StreamSlice,
) -> Iterable[StreamData]:
initial_token = stream_state.get("next_page_token")
if initial_token is None:
initial_token = self._paginator.get_initial_token()
- next_page_token: Optional[Mapping[str, Any]] = (
+ next_page_token: Mapping[str, Any] | None = (
{"next_page_token": initial_token} if initial_token else None
)
response = self._fetch_next_page(stream_state, stream_slice, next_page_token)
last_page_size = 0
- last_record: Optional[Record] = None
+ last_record: Record | None = None
for record in records_generator_fn(response):
last_page_size += 1
last_record = record
@@ -410,7 +414,7 @@ def _read_single_page(
def read_records(
self,
records_schema: Mapping[str, Any],
- stream_slice: Optional[StreamSlice] = None,
+ stream_slice: StreamSlice | None = None,
) -> Iterable[StreamData]:
"""
Fetch a stream's records from an HTTP API source
@@ -419,13 +423,16 @@ def read_records(
:param stream_slice: The stream slice to read data for
:return: The records read from the API source
"""
- _slice = stream_slice or StreamSlice(partition={}, cursor_slice={}) # None-check
+ slice_ = stream_slice or StreamSlice(
+ partition={},
+ cursor_slice={},
+ )
most_recent_record_from_slice = None
record_generator = partial(
self._parse_records,
stream_state=self.state or {},
- stream_slice=_slice,
+ stream_slice=slice_,
records_schema=records_schema,
)
@@ -438,46 +445,42 @@ def read_records(
if stream_state.get(FULL_REFRESH_SYNC_COMPLETE_KEY):
return
- yield from self._read_single_page(record_generator, stream_state, _slice)
+ yield from self._read_single_page(record_generator, stream_state, slice_)
else:
- for stream_data in self._read_pages(record_generator, self.state, _slice):
- current_record = self._extract_record(stream_data, _slice)
+ for stream_data in self._read_pages(record_generator, self.state, slice_):
+ current_record = self._extract_record(stream_data, slice_)
if self.cursor and current_record:
- self.cursor.observe(_slice, current_record)
+ self.cursor.observe(slice_, current_record)
# Latest record read, not necessarily within slice boundaries.
- # TODO Remove once all custom components implement `observe` method.
+ # TODO Remove once all custom components implement `observe` method. # noqa: TD004
# https://github.com/airbytehq/airbyte-internal-issues/issues/6955
most_recent_record_from_slice = self._get_most_recent_record(
- most_recent_record_from_slice, current_record, _slice
+ most_recent_record_from_slice, current_record, slice_
)
yield stream_data
if self.cursor:
- self.cursor.close_slice(_slice, most_recent_record_from_slice)
+ self.cursor.close_slice(slice_, most_recent_record_from_slice)
return
def _get_most_recent_record(
self,
- current_most_recent: Optional[Record],
- current_record: Optional[Record],
- stream_slice: StreamSlice,
- ) -> Optional[Record]:
+ current_most_recent: Record | None,
+ current_record: Record | None,
+ stream_slice: StreamSlice, # noqa: ARG002
+ ) -> Record | None:
if self.cursor and current_record:
if not current_most_recent:
return current_record
- else:
- return (
- current_most_recent
- if self.cursor.is_greater_than_or_equal(current_most_recent, current_record)
- else current_record
- )
- else:
- return None
+ return (
+ current_most_recent
+ if self.cursor.is_greater_than_or_equal(current_most_recent, current_record)
+ else current_record
+ )
+ return None
- def _extract_record(
- self, stream_data: StreamData, stream_slice: StreamSlice
- ) -> Optional[Record]:
+ def _extract_record(self, stream_data: StreamData, stream_slice: StreamSlice) -> Record | None:
"""
        Because the output of _read_pages is allowed to be StreamData, it can take several forms. Therefore, we need to filter and
        normalize the data to streamline the rest of the process.
@@ -485,11 +488,11 @@ def _extract_record(
if isinstance(stream_data, Record):
# Record is not part of `StreamData` but is the most common implementation of `Mapping[str, Any]` which is part of `StreamData`
return stream_data
- elif isinstance(stream_data, (dict, Mapping)):
+ if isinstance(stream_data, (dict, Mapping)): # noqa: UP038
return Record(
data=dict(stream_data), associated_slice=stream_slice, stream_name=self.name
)
- elif isinstance(stream_data, AirbyteMessage) and stream_data.record:
+ if isinstance(stream_data, AirbyteMessage) and stream_data.record:
return Record(
data=stream_data.record.data, # type:ignore # AirbyteMessage always has record.data
associated_slice=stream_slice,
@@ -498,7 +501,7 @@ def _extract_record(
return None
# stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever
- def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore
+ def stream_slices(self) -> Iterable[StreamSlice | None]: # type: ignore
"""
Specifies the slices for this stream. See the stream slicing section of the docs for more information.
@@ -521,10 +524,10 @@ def state(self, value: StreamState) -> None:
def _parse_records(
self,
- response: Optional[requests.Response],
+ response: requests.Response | None,
stream_state: Mapping[str, Any],
records_schema: Mapping[str, Any],
- stream_slice: Optional[StreamSlice],
+ stream_slice: StreamSlice | None,
) -> Iterable[Record]:
yield from self._parse_response(
response,
@@ -537,7 +540,7 @@ def must_deduplicate_query_params(self) -> bool:
return True
@staticmethod
- def _to_partition_key(to_serialize: Any) -> str:
+ def _to_partition_key(to_serialize: Any) -> str: # noqa: ANN401
        # separators have changed in Python 3.4. To avoid being impacted by further changes, we explicitly specify our own value
return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True)
@@ -559,15 +562,15 @@ def __post_init__(self, options: Mapping[str, Any]) -> None:
)
# stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever
- def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore
+ def stream_slices(self) -> Iterable[StreamSlice | None]: # type: ignore
return islice(super().stream_slices(), self.maximum_number_of_slices)
def _fetch_next_page(
self,
stream_state: Mapping[str, Any],
stream_slice: StreamSlice,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Optional[requests.Response]:
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> requests.Response | None:
return self.requester.send_request(
path=self._paginator_path(next_page_token=next_page_token),
stream_state=stream_state,
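
The recurring change throughout this file (and most files below) is the move from `typing.Optional`/`Union` and capitalized container types to PEP 604 unions and built-in generics, with `Callable`/`Iterable`/`Mapping` imported from `collections.abc`. A minimal sketch of the target style, using a hypothetical helper rather than CDK code:

```python
from collections.abc import Iterable, Mapping
from typing import Any


def fetch_pages(
    params: Mapping[str, Any] | None = None,
    page_tokens: Iterable[str] | None = None,
) -> list[dict[str, Any]]:
    # `Mapping[str, Any] | None` replaces Optional[Mapping[str, Any]], and the
    # built-in `list`/`dict` generics replace typing.List/typing.Dict.
    params = params or {}
    return [{"token": token, **params} for token in (page_tokens or [])]
```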
diff --git a/airbyte_cdk/sources/declarative/schema/__init__.py b/airbyte_cdk/sources/declarative/schema/__init__.py
index b5b6a7d31..e41407845 100644
--- a/airbyte_cdk/sources/declarative/schema/__init__.py
+++ b/airbyte_cdk/sources/declarative/schema/__init__.py
@@ -12,6 +12,7 @@
from airbyte_cdk.sources.declarative.schema.json_file_schema_loader import JsonFileSchemaLoader
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
+
__all__ = [
"JsonFileSchemaLoader",
"DefaultSchemaLoader",
diff --git a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py
index a9b625e7d..7f3ba5611 100644
--- a/airbyte_cdk/sources/declarative/schema/default_schema_loader.py
+++ b/airbyte_cdk/sources/declarative/schema/default_schema_loader.py
@@ -3,8 +3,9 @@
#
import logging
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping
+from typing import Any
from airbyte_cdk.sources.declarative.schema.json_file_schema_loader import JsonFileSchemaLoader
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
@@ -41,7 +42,7 @@ def get_json_schema(self) -> Mapping[str, Any]:
# A slight hack since we don't directly have the stream name. However, when building the default filepath we assume the
        # runtime options store the stream name under 'name', so we'll do the same here
stream_name = self._parameters.get("name", "")
- logging.info(
+ logging.info( # noqa: LOG015
f"Could not find schema for stream {stream_name}, defaulting to the empty schema"
)
return {}
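
The `# noqa: LOG015` above suppresses ruff's warning about logging through the root logger (`logging.info(...)`). If that call were ever reworked, the conventional alternative is a module-level logger; a small sketch of that pattern (not a change made by this diff):

```python
import logging

logger = logging.getLogger(__name__)  # module-scoped logger instead of the root logger


def warn_missing_schema(stream_name: str) -> None:
    # Same fallback message, emitted through the named logger so handlers and
    # levels can be configured per module.
    logger.info("Could not find schema for stream %s, defaulting to the empty schema", stream_name)
```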
diff --git a/airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py b/airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py
index d65890b70..68cb2efaf 100644
--- a/airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py
+++ b/airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py
@@ -3,9 +3,10 @@
#
+from collections.abc import Mapping, MutableMapping
from copy import deepcopy
from dataclasses import InitVar, dataclass, field
-from typing import Any, List, Mapping, MutableMapping, Optional, Union
+from typing import Any
import dpath
from typing_extensions import deprecated
@@ -18,6 +19,7 @@
from airbyte_cdk.sources.source import ExperimentalClassWarning
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
+
AIRBYTE_DATA_TYPES: Mapping[str, Mapping[str, Any]] = {
"string": {"type": ["null", "string"]},
"boolean": {"type": ["null", "boolean"]},
@@ -52,9 +54,9 @@ class TypesMap:
Represents a mapping between a current type and its corresponding target type.
"""
- target_type: Union[List[str], str]
- current_type: Union[List[str], str]
- condition: Optional[str]
+ target_type: list[str] | str
+ current_type: list[str] | str
+ condition: str | None
@deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning)
@@ -64,11 +66,11 @@ class SchemaTypeIdentifier:
Identifies schema details for dynamic schema extraction and processing.
"""
- key_pointer: List[Union[InterpolatedString, str]]
+ key_pointer: list[InterpolatedString | str]
parameters: InitVar[Mapping[str, Any]]
- type_pointer: Optional[List[Union[InterpolatedString, str]]] = None
- types_mapping: Optional[List[TypesMap]] = None
- schema_pointer: Optional[List[Union[InterpolatedString, str]]] = None
+ type_pointer: list[InterpolatedString | str] | None = None
+ types_mapping: list[TypesMap] | None = None
+ schema_pointer: list[InterpolatedString | str] | None = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.schema_pointer = (
@@ -81,8 +83,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
@staticmethod
def _update_pointer(
- pointer: Optional[List[Union[InterpolatedString, str]]], parameters: Mapping[str, Any]
- ) -> Optional[List[Union[InterpolatedString, str]]]:
+ pointer: list[InterpolatedString | str] | None, parameters: Mapping[str, Any]
+ ) -> list[InterpolatedString | str] | None:
return (
[
InterpolatedString.create(path, parameters=parameters)
@@ -106,7 +108,7 @@ class DynamicSchemaLoader(SchemaLoader):
config: Config
parameters: InitVar[Mapping[str, Any]]
schema_type_identifier: SchemaTypeIdentifier
- schema_transformations: List[RecordTransformation] = field(default_factory=lambda: [])
+ schema_transformations: list[RecordTransformation] = field(default_factory=list)
def get_json_schema(self) -> Mapping[str, Any]:
"""
@@ -143,8 +145,8 @@ def get_json_schema(self) -> Mapping[str, Any]:
def _transform(
self,
properties: Mapping[str, Any],
- stream_state: StreamState,
- stream_slice: Optional[StreamSlice] = None,
+ stream_state: StreamState, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
for transformation in self.schema_transformations:
transformation.transform(
@@ -156,21 +158,21 @@ def _transform(
def _get_key(
self,
raw_schema: MutableMapping[str, Any],
- field_key_path: List[Union[InterpolatedString, str]],
+ field_key_path: list[InterpolatedString | str],
) -> str:
"""
Extracts the key field from the schema using the specified path.
"""
field_key = self._extract_data(raw_schema, field_key_path)
if not isinstance(field_key, str):
- raise ValueError(f"Expected key to be a string. Got {field_key}")
+ raise ValueError(f"Expected key to be a string. Got {field_key}") # noqa: TRY004
return field_key
def _get_type(
self,
raw_schema: MutableMapping[str, Any],
- field_type_path: Optional[List[Union[InterpolatedString, str]]],
- ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
+ field_type_path: list[InterpolatedString | str] | None,
+ ) -> Mapping[str, Any] | list[Mapping[str, Any]]:
"""
Determines the JSON Schema type for a field, supporting nullable and combined types.
"""
@@ -182,24 +184,23 @@ def _get_type(
mapped_field_type = self._replace_type_if_not_valid(raw_field_type, raw_schema)
if (
isinstance(mapped_field_type, list)
- and len(mapped_field_type) == 2
+ and len(mapped_field_type) == 2 # noqa: PLR2004
and all(isinstance(item, str) for item in mapped_field_type)
):
first_type = self._get_airbyte_type(mapped_field_type[0])
second_type = self._get_airbyte_type(mapped_field_type[1])
return {"oneOf": [first_type, second_type]}
- elif isinstance(mapped_field_type, str):
+ if isinstance(mapped_field_type, str):
return self._get_airbyte_type(mapped_field_type)
- else:
- raise ValueError(
- f"Invalid data type. Available string or two items list of string. Got {mapped_field_type}."
- )
+ raise ValueError(
+ f"Invalid data type. Available string or two items list of string. Got {mapped_field_type}."
+ )
def _replace_type_if_not_valid(
self,
- field_type: Union[List[str], str],
+ field_type: list[str] | str,
raw_schema: MutableMapping[str, Any],
- ) -> Union[List[str], str]:
+ ) -> list[str] | str:
"""
Replaces a field type if it matches a type mapping in `types_map`.
"""
@@ -228,9 +229,9 @@ def _get_airbyte_type(field_type: str) -> Mapping[str, Any]:
def _extract_data(
self,
body: Mapping[str, Any],
- extraction_path: Optional[List[Union[InterpolatedString, str]]] = None,
- default: Any = None,
- ) -> Any:
+ extraction_path: list[InterpolatedString | str] | None = None,
+ default: Any = None, # noqa: ANN401
+ ) -> Any: # noqa: ANN401
"""
Extracts data from the body based on the provided extraction path.
"""
diff --git a/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py b/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py
index 72a46b7e5..6675badde 100644
--- a/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py
+++ b/airbyte_cdk/sources/declarative/schema/inline_schema_loader.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Dict, Mapping
+from typing import Any
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
@@ -12,7 +13,7 @@
class InlineSchemaLoader(SchemaLoader):
"""Describes a stream's schema"""
- schema: Dict[str, Any]
+ schema: dict[str, Any]
parameters: InitVar[Mapping[str, Any]]
def get_json_schema(self) -> Mapping[str, Any]:
diff --git a/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py b/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py
index af51fe5db..486223800 100644
--- a/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py
+++ b/airbyte_cdk/sources/declarative/schema/json_file_schema_loader.py
@@ -5,8 +5,9 @@
import json
import pkgutil
import sys
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Mapping, Tuple, Union
+from typing import Any
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
@@ -43,7 +44,7 @@ class JsonFileSchemaLoader(ResourceSchemaLoader, SchemaLoader):
config: Config
parameters: InitVar[Mapping[str, Any]]
- file_path: Union[InterpolatedString, str] = field(default="")
+ file_path: InterpolatedString | str = field(default="")
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if not self.file_path:
@@ -51,14 +52,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.file_path = InterpolatedString.create(self.file_path, parameters=parameters)
def get_json_schema(self) -> Mapping[str, Any]:
- # todo: It is worth revisiting if we can replace file_path with just file_name if every schema is in the /schemas directory
+ # TODO: It is worth revisiting if we can replace file_path with just file_name if every schema is in the /schemas directory
# this would require that we find a creative solution to store or retrieve source_name in here since the files are mounted there
json_schema_path = self._get_json_filepath()
resource, schema_path = self.extract_resource_and_schema_path(json_schema_path)
raw_json_file = pkgutil.get_data(resource, schema_path)
if not raw_json_file:
- raise IOError(f"Cannot find file {json_schema_path}")
+ raise OSError(f"Cannot find file {json_schema_path}")
try:
raw_schema = json.loads(raw_json_file)
except ValueError as err:
@@ -66,11 +67,11 @@ def get_json_schema(self) -> Mapping[str, Any]:
self.package_name = resource
return self._resolve_schema_references(raw_schema)
- def _get_json_filepath(self) -> Any:
+ def _get_json_filepath(self) -> Any: # noqa: ANN401
return self.file_path.eval(self.config) # type: ignore # file_path is always cast to an interpolated string
@staticmethod
- def extract_resource_and_schema_path(json_schema_path: str) -> Tuple[str, str]:
+ def extract_resource_and_schema_path(json_schema_path: str) -> tuple[str, str]:
"""
        When the connector is running in a docker container, package_data is accessible from the resource (source_), so we extract
the resource from the first part of the schema path and the remaining path is used to find the schema file. This is a slight
@@ -80,7 +81,7 @@ def extract_resource_and_schema_path(json_schema_path: str) -> Tuple[str, str]:
"""
split_path = json_schema_path.split("/")
- if split_path[0] == "" or split_path[0] == ".":
+ if split_path[0] == "" or split_path[0] == ".": # noqa: PLC1901
split_path = split_path[1:]
if len(split_path) == 0:
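
The static method above splits a schema path such as `source_example/schemas/users.json` into the package name used for `pkgutil.get_data` and the remaining file path. A rough illustration of the expected split (the connector package name is made up):

```python
def split_schema_path(json_schema_path: str) -> tuple[str, str]:
    parts = json_schema_path.split("/")
    # Drop a leading "" (absolute path) or "." segment before splitting.
    if parts and parts[0] in ("", "."):
        parts = parts[1:]
    return parts[0], "/".join(parts[1:])


assert split_schema_path("./source_example/schemas/users.json") == ("source_example", "schemas/users.json")
```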
diff --git a/airbyte_cdk/sources/declarative/schema/schema_loader.py b/airbyte_cdk/sources/declarative/schema/schema_loader.py
index a6beb70ae..fb7f45cb6 100644
--- a/airbyte_cdk/sources/declarative/schema/schema_loader.py
+++ b/airbyte_cdk/sources/declarative/schema/schema_loader.py
@@ -3,8 +3,9 @@
#
from abc import abstractmethod
+from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Any, Mapping
+from typing import Any
@dataclass
diff --git a/airbyte_cdk/sources/declarative/spec/__init__.py b/airbyte_cdk/sources/declarative/spec/__init__.py
index 1c13ed67c..4302b8749 100644
--- a/airbyte_cdk/sources/declarative/spec/__init__.py
+++ b/airbyte_cdk/sources/declarative/spec/__init__.py
@@ -4,4 +4,5 @@
from airbyte_cdk.sources.declarative.spec.spec import Spec
+
__all__ = ["Spec"]
diff --git a/airbyte_cdk/sources/declarative/spec/spec.py b/airbyte_cdk/sources/declarative/spec/spec.py
index 914e99e93..7892e3c45 100644
--- a/airbyte_cdk/sources/declarative/spec/spec.py
+++ b/airbyte_cdk/sources/declarative/spec/spec.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Mapping, Optional
+from typing import Any
from airbyte_cdk.models import (
AdvancedAuth,
@@ -25,8 +26,8 @@ class Spec:
connection_specification: Mapping[str, Any]
parameters: InitVar[Mapping[str, Any]]
- documentation_url: Optional[str] = None
- advanced_auth: Optional[AuthFlow] = None
+ documentation_url: str | None = None
+ advanced_auth: AuthFlow | None = None
def generate_spec(self) -> ConnectorSpecification:
"""
diff --git a/airbyte_cdk/sources/declarative/stream_slicers/__init__.py b/airbyte_cdk/sources/declarative/stream_slicers/__init__.py
index 7bacc3ca8..d2df979f5 100644
--- a/airbyte_cdk/sources/declarative/stream_slicers/__init__.py
+++ b/airbyte_cdk/sources/declarative/stream_slicers/__init__.py
@@ -4,4 +4,5 @@
from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer
+
__all__ = ["StreamSlicer"]
diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
index 91ce28e7a..d6e5dd1b1 100644
--- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
+++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -1,6 +1,7 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
-from typing import Any, Iterable, Mapping, Optional
+from collections.abc import Iterable, Mapping
+from typing import Any
from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.message import MessageRepository
@@ -40,7 +41,7 @@ def create(self, stream_slice: StreamSlice) -> Partition:
class DeclarativePartition(Partition):
- def __init__(
+ def __init__( # noqa: ANN204
self,
stream_name: str,
json_schema: Mapping[str, Any],
@@ -66,7 +67,7 @@ def read(self) -> Iterable[Record]:
else:
self._message_repository.emit_message(stream_data)
- def to_slice(self) -> Optional[Mapping[str, Any]]:
+ def to_slice(self) -> Mapping[str, Any] | None:
return self._stream_slice
def stream_name(self) -> str:
diff --git a/airbyte_cdk/sources/declarative/transformations/__init__.py b/airbyte_cdk/sources/declarative/transformations/__init__.py
index e18712a01..9bd1e758e 100644
--- a/airbyte_cdk/sources/declarative/transformations/__init__.py
+++ b/airbyte_cdk/sources/declarative/transformations/__init__.py
@@ -10,8 +10,10 @@
# so we add the split directive below to tell isort to sort imports while keeping RecordTransformation as the first import
from .transformation import RecordTransformation
+
# isort: split
from .add_fields import AddFields
from .remove_fields import RemoveFields
+
__all__ = ["AddFields", "RecordTransformation", "RemoveFields"]
diff --git a/airbyte_cdk/sources/declarative/transformations/add_fields.py b/airbyte_cdk/sources/declarative/transformations/add_fields.py
index 4c9d5366c..b4c13c7de 100644
--- a/airbyte_cdk/sources/declarative/transformations/add_fields.py
+++ b/airbyte_cdk/sources/declarative/transformations/add_fields.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass, field
-from typing import Any, Dict, List, Mapping, Optional, Type, Union
+from typing import Any
import dpath
@@ -17,8 +18,8 @@ class AddedFieldDefinition:
"""Defines the field to add on a record"""
path: FieldPointer
- value: Union[InterpolatedString, str]
- value_type: Optional[Type[Any]]
+ value: InterpolatedString | str
+ value_type: type[Any] | None
parameters: InitVar[Mapping[str, Any]]
@@ -28,12 +29,12 @@ class ParsedAddFieldDefinition:
path: FieldPointer
value: InterpolatedString
- value_type: Optional[Type[Any]]
+ value_type: type[Any] | None
parameters: InitVar[Mapping[str, Any]]
@dataclass
-class AddFields(RecordTransformation):
+class AddFields(RecordTransformation): # noqa: PLW1641
"""
    Transformation which adds fields to an output record. The path of the added field can be nested. Adding nested fields will create all
necessary parent objects (like mkdir -p). Adding fields to an array will extend the array to that index (filling intermediate
@@ -84,9 +85,9 @@ class AddFields(RecordTransformation):
fields (List[AddedFieldDefinition]): A list of transformations (path and corresponding value) that will be added to the record
"""
- fields: List[AddedFieldDefinition]
+ fields: list[AddedFieldDefinition]
parameters: InitVar[Mapping[str, Any]]
- _parsed_fields: List[ParsedAddFieldDefinition] = field(
+ _parsed_fields: list[ParsedAddFieldDefinition] = field(
init=False, repr=False, default_factory=list
)
@@ -100,15 +101,14 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if not isinstance(add_field.value, InterpolatedString):
if not isinstance(add_field.value, str):
raise f"Expected a string value for the AddFields transformation: {add_field}"
- else:
- self._parsed_fields.append(
- ParsedAddFieldDefinition(
- add_field.path,
- InterpolatedString.create(add_field.value, parameters=parameters),
- value_type=add_field.value_type,
- parameters=parameters,
- )
+ self._parsed_fields.append(
+ ParsedAddFieldDefinition(
+ add_field.path,
+ InterpolatedString.create(add_field.value, parameters=parameters),
+ value_type=add_field.value_type,
+ parameters=parameters,
)
+ )
else:
self._parsed_fields.append(
ParsedAddFieldDefinition(
@@ -121,10 +121,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def transform(
self,
- record: Dict[str, Any],
- config: Optional[Config] = None,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
+ record: dict[str, Any],
+ config: Config | None = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
) -> None:
if config is None:
config = {}
@@ -134,5 +134,5 @@ def transform(
value = parsed_field.value.eval(config, valid_types=valid_types, **kwargs)
dpath.new(record, parsed_field.path, value)
- def __eq__(self, other: Any) -> bool:
+ def __eq__(self, other: object) -> bool:
return bool(self.__dict__ == other.__dict__)
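
`transform` above evaluates each parsed field's interpolated value and writes it into the record with `dpath.new`, which creates any missing parent objects along the path (the "mkdir -p" behaviour mentioned in the docstring). A minimal standalone illustration of that write, independent of the interpolation machinery:

```python
import dpath

record: dict = {"id": 1}
# dpath.new creates the intermediate "metadata" object before setting "source".
dpath.new(record, ["metadata", "source"], "api")
assert record == {"id": 1, "metadata": {"source": "api"}}
```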
diff --git a/airbyte_cdk/sources/declarative/transformations/dpath_flatten_fields.py b/airbyte_cdk/sources/declarative/transformations/dpath_flatten_fields.py
index 73162d848..4fd5efc04 100644
--- a/airbyte_cdk/sources/declarative/transformations/dpath_flatten_fields.py
+++ b/airbyte_cdk/sources/declarative/transformations/dpath_flatten_fields.py
@@ -1,5 +1,6 @@
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any
import dpath
@@ -19,7 +20,7 @@ class DpathFlattenFields(RecordTransformation):
"""
config: Config
- field_path: List[Union[InterpolatedString, str]]
+ field_path: list[InterpolatedString | str]
parameters: InitVar[Mapping[str, Any]]
delete_origin_value: bool = False
@@ -35,10 +36,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def transform(
self,
- record: Dict[str, Any],
- config: Optional[Config] = None,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
+ record: dict[str, Any],
+ config: Config | None = None, # noqa: ARG002
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
) -> None:
path = [path.eval(self.config) for path in self._field_path]
if "*" in path:
diff --git a/airbyte_cdk/sources/declarative/transformations/flatten_fields.py b/airbyte_cdk/sources/declarative/transformations/flatten_fields.py
index 24bfba660..4976918c1 100644
--- a/airbyte_cdk/sources/declarative/transformations/flatten_fields.py
+++ b/airbyte_cdk/sources/declarative/transformations/flatten_fields.py
@@ -3,7 +3,7 @@
#
from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
@@ -15,18 +15,18 @@ class FlattenFields(RecordTransformation):
def transform(
self,
- record: Dict[str, Any],
- config: Optional[Config] = None,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
+ record: dict[str, Any],
+ config: Config | None = None, # noqa: ARG002
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
) -> None:
transformed_record = self.flatten_record(record)
record.clear()
record.update(transformed_record)
- def flatten_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
+ def flatten_record(self, record: dict[str, Any]) -> dict[str, Any]:
stack = [(record, "_")]
- transformed_record: Dict[str, Any] = {}
+ transformed_record: dict[str, Any] = {}
force_with_parent_name = False
while stack:
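
`flatten_record` above walks the record iteratively with an explicit stack rather than recursing. A simplified sketch of the same idea (it flattens nested dicts only and always prefixes with the parent key, whereas the real implementation also handles lists and key collisions):

```python
from typing import Any


def flatten(record: dict[str, Any]) -> dict[str, Any]:
    flat: dict[str, Any] = {}
    stack: list[tuple[dict[str, Any], str]] = [(record, "")]
    while stack:
        current, prefix = stack.pop()
        for key, value in current.items():
            name = f"{prefix}{key}"
            if isinstance(value, dict):
                # Defer nested objects; their keys get the parent key as a prefix.
                stack.append((value, f"{name}."))
            else:
                flat[name] = value
    return flat


assert flatten({"a": 1, "b": {"c": 2}}) == {"a": 1, "b.c": 2}
```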
diff --git a/airbyte_cdk/sources/declarative/transformations/keys_replace_transformation.py b/airbyte_cdk/sources/declarative/transformations/keys_replace_transformation.py
index 8fe0bbffb..00deb5130 100644
--- a/airbyte_cdk/sources/declarative/transformations/keys_replace_transformation.py
+++ b/airbyte_cdk/sources/declarative/transformations/keys_replace_transformation.py
@@ -2,8 +2,9 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Dict, Mapping, Optional
+from typing import Any
from airbyte_cdk import InterpolatedString
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
@@ -34,10 +35,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def transform(
self,
- record: Dict[str, Any],
- config: Optional[Config] = None,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
+ record: dict[str, Any],
+ config: Config | None = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
) -> None:
if config is None:
config = {}
@@ -46,7 +47,7 @@ def transform(
old_key = str(self._old.eval(config, **kwargs))
new_key = str(self._new.eval(config, **kwargs))
- def _transform(data: Dict[str, Any]) -> Dict[str, Any]:
+ def _transform(data: dict[str, Any]) -> dict[str, Any]:
result = {}
for key, value in data.items():
updated_key = key.replace(old_key, new_key)
diff --git a/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py b/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py
index 53db3d49a..cabc6761d 100644
--- a/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py
+++ b/airbyte_cdk/sources/declarative/transformations/keys_to_lower_transformation.py
@@ -3,7 +3,7 @@
#
from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
@@ -13,10 +13,10 @@
class KeysToLowerTransformation(RecordTransformation):
def transform(
self,
- record: Dict[str, Any],
- config: Optional[Config] = None,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
+ record: dict[str, Any],
+ config: Config | None = None, # noqa: ARG002
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
) -> None:
for key in set(record.keys()):
record[key.lower()] = record.pop(key)
diff --git a/airbyte_cdk/sources/declarative/transformations/keys_to_snake_transformation.py b/airbyte_cdk/sources/declarative/transformations/keys_to_snake_transformation.py
index 86e25c399..c6e645706 100644
--- a/airbyte_cdk/sources/declarative/transformations/keys_to_snake_transformation.py
+++ b/airbyte_cdk/sources/declarative/transformations/keys_to_snake_transformation.py
@@ -4,7 +4,7 @@
import re
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any
import unidecode
@@ -20,16 +20,16 @@ class KeysToSnakeCaseTransformation(RecordTransformation):
def transform(
self,
- record: Dict[str, Any],
- config: Optional[Config] = None,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
+ record: dict[str, Any],
+ config: Config | None = None, # noqa: ARG002
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
) -> None:
transformed_record = self._transform_record(record)
record.clear()
record.update(transformed_record)
- def _transform_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
+ def _transform_record(self, record: dict[str, Any]) -> dict[str, Any]:
transformed_record = {}
for key, value in record.items():
transformed_key = self.process_key(key)
@@ -50,19 +50,19 @@ def process_key(self, key: str) -> str:
def normalize_key(self, key: str) -> str:
return unidecode.unidecode(key)
- def tokenize_key(self, key: str) -> List[str]:
+ def tokenize_key(self, key: str) -> list[str]:
tokens = []
for match in self.token_pattern.finditer(key):
token = match.group(0) if match.group("NoToken") is None else ""
tokens.append(token)
return tokens
- def filter_tokens(self, tokens: List[str]) -> List[str]:
- if len(tokens) >= 3:
+ def filter_tokens(self, tokens: list[str]) -> list[str]:
+ if len(tokens) >= 3: # noqa: PLR2004
tokens = tokens[:1] + [t for t in tokens[1:-1] if t] + tokens[-1:]
if tokens and tokens[0].isdigit():
tokens.insert(0, "")
return tokens
- def tokens_to_snake_case(self, tokens: List[str]) -> str:
+ def tokens_to_snake_case(self, tokens: list[str]) -> str:
return "_".join(token.lower() for token in tokens)
diff --git a/airbyte_cdk/sources/declarative/transformations/remove_fields.py b/airbyte_cdk/sources/declarative/transformations/remove_fields.py
index f5d8164df..1aee83bb8 100644
--- a/airbyte_cdk/sources/declarative/transformations/remove_fields.py
+++ b/airbyte_cdk/sources/declarative/transformations/remove_fields.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from dataclasses import InitVar, dataclass
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any
import dpath
import dpath.exceptions
@@ -40,7 +41,7 @@ class RemoveFields(RecordTransformation):
field_pointers (List[FieldPointer]): pointers to the fields that should be removed
"""
- field_pointers: List[FieldPointer]
+ field_pointers: list[FieldPointer]
parameters: InitVar[Mapping[str, Any]]
condition: str = ""
@@ -51,10 +52,10 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
def transform(
self,
- record: Dict[str, Any],
- config: Optional[Config] = None,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
+ record: dict[str, Any],
+ config: Config | None = None,
+ stream_state: StreamState | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None, # noqa: ARG002
) -> None:
"""
:param record: The record to be transformed
@@ -62,7 +63,7 @@ def transform(
"""
for pointer in self.field_pointers:
# the dpath library by default doesn't delete fields from arrays
- try:
+ try: # noqa: SIM105
dpath.delete(
record,
pointer,
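
The `# noqa: SIM105` marks ruff's suggestion to replace the try/except-pass around `dpath.delete` with `contextlib.suppress`; the diff keeps the explicit form. For reference, a sketch of what the suppress-based variant would look like (not a change made here):

```python
import contextlib

import dpath
import dpath.exceptions

record = {"user": {"email": "a@example.com"}, "id": 1}
# dpath.delete raises PathNotFound when the pointer does not match anything;
# suppress() swallows that the same way the explicit try/except does.
with contextlib.suppress(dpath.exceptions.PathNotFound):
    dpath.delete(record, ["user", "email"])
```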
diff --git a/airbyte_cdk/sources/declarative/transformations/transformation.py b/airbyte_cdk/sources/declarative/transformations/transformation.py
index f5b226429..15b67b0b9 100644
--- a/airbyte_cdk/sources/declarative/transformations/transformation.py
+++ b/airbyte_cdk/sources/declarative/transformations/transformation.py
@@ -4,13 +4,13 @@
from abc import abstractmethod
from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
@dataclass
-class RecordTransformation:
+class RecordTransformation: # noqa: PLW1641
"""
Implementations of this class define transformations that can be applied to records of a stream.
"""
@@ -18,10 +18,10 @@ class RecordTransformation:
@abstractmethod
def transform(
self,
- record: Dict[str, Any],
- config: Optional[Config] = None,
- stream_state: Optional[StreamState] = None,
- stream_slice: Optional[StreamSlice] = None,
+ record: dict[str, Any],
+ config: Config | None = None,
+ stream_state: StreamState | None = None,
+ stream_slice: StreamSlice | None = None,
) -> None:
"""
        Transform a record by adding, deleting, or mutating fields directly on the record reference passed as an argument.
diff --git a/airbyte_cdk/sources/declarative/types.py b/airbyte_cdk/sources/declarative/types.py
index a4d0aeb1d..80f297e7f 100644
--- a/airbyte_cdk/sources/declarative/types.py
+++ b/airbyte_cdk/sources/declarative/types.py
@@ -1,4 +1,4 @@
-#
+# # noqa: A005
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
@@ -13,13 +13,14 @@
StreamState,
)
+
# Note: This package originally contained class definitions for low-code CDK types, but we promoted them into the Python CDK.
# We've migrated connectors in the repository to reference the new location, but these assignments are used to retain backwards
# compatibility for sources created by OSS customers or on forks. This can be removed when we start bumping major versions.
-FieldPointer = FieldPointer
-Config = Config
-ConnectionDefinition = ConnectionDefinition
-StreamState = StreamState
-Record = Record
-StreamSlice = StreamSlice
+FieldPointer = FieldPointer # noqa: PLW0127
+Config = Config # noqa: PLW0127
+ConnectionDefinition = ConnectionDefinition # noqa: PLW0127
+StreamState = StreamState # noqa: PLW0127
+Record = Record # noqa: PLW0127
+StreamSlice = StreamSlice # noqa: PLW0127
diff --git a/airbyte_cdk/sources/declarative/yaml_declarative_source.py b/airbyte_cdk/sources/declarative/yaml_declarative_source.py
index 04ccda4cf..b94b387a3 100644
--- a/airbyte_cdk/sources/declarative/yaml_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/yaml_declarative_source.py
@@ -3,7 +3,8 @@
#
import pkgutil
-from typing import Any, List, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any
import yaml
@@ -14,16 +15,16 @@
from airbyte_cdk.sources.types import ConnectionDefinition
-class YamlDeclarativeSource(ConcurrentDeclarativeSource[List[AirbyteStateMessage]]):
+class YamlDeclarativeSource(ConcurrentDeclarativeSource[list[AirbyteStateMessage]]):
"""Declarative source defined by a yaml file"""
def __init__(
self,
path_to_yaml: str,
- debug: bool = False,
- catalog: Optional[ConfiguredAirbyteCatalog] = None,
- config: Optional[Mapping[str, Any]] = None,
- state: Optional[List[AirbyteStateMessage]] = None,
+ debug: bool = False, # noqa: FBT001, FBT002, ARG002
+ catalog: ConfiguredAirbyteCatalog | None = None,
+ config: Mapping[str, Any] | None = None,
+ state: list[AirbyteStateMessage] | None = None,
) -> None:
"""
:param path_to_yaml: Path to the yaml file describing the source
@@ -45,8 +46,7 @@ def _read_and_parse_yaml_file(self, path_to_yaml_file: str) -> ConnectionDefinit
if yaml_config:
decoded_yaml = yaml_config.decode()
return self._parse(decoded_yaml)
- else:
- return {}
+ return {}
def _emit_manifest_debug_message(self, extra_args: dict[str, Any]) -> None:
extra_args["path_to_yaml"] = self._path_to_yaml
diff --git a/airbyte_cdk/sources/embedded/base_integration.py b/airbyte_cdk/sources/embedded/base_integration.py
index 77917b0a1..f16a03ab2 100644
--- a/airbyte_cdk/sources/embedded/base_integration.py
+++ b/airbyte_cdk/sources/embedded/base_integration.py
@@ -3,7 +3,8 @@
#
from abc import ABC, abstractmethod
-from typing import Generic, Iterable, Optional, TypeVar
+from collections.abc import Iterable
+from typing import Generic, TypeVar
from airbyte_cdk.connector import TConfig
from airbyte_cdk.models import AirbyteRecordMessage, AirbyteStateMessage, SyncMode, Type
@@ -16,27 +17,28 @@
from airbyte_cdk.sources.embedded.tools import get_defined_id
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit
+
TOutput = TypeVar("TOutput")
class BaseEmbeddedIntegration(ABC, Generic[TConfig, TOutput]):
- def __init__(self, runner: SourceRunner[TConfig], config: TConfig):
+ def __init__(self, runner: SourceRunner[TConfig], config: TConfig): # noqa: ANN204
check_config_against_spec_or_exit(config, runner.spec())
self.source = runner
self.config = config
- self.last_state: Optional[AirbyteStateMessage] = None
+ self.last_state: AirbyteStateMessage | None = None
@abstractmethod
- def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Optional[TOutput]:
+ def _handle_record(self, record: AirbyteRecordMessage, id: str | None) -> TOutput | None: # noqa: A002
"""
Turn an Airbyte record into the appropriate output type for the integration.
"""
pass
def _load_data(
- self, stream_name: str, state: Optional[AirbyteStateMessage] = None
+ self, stream_name: str, state: AirbyteStateMessage | None = None
) -> Iterable[TOutput]:
catalog = self.source.discover(self.config)
stream = get_stream(catalog, stream_name)
diff --git a/airbyte_cdk/sources/embedded/catalog.py b/airbyte_cdk/sources/embedded/catalog.py
index 62c7a623d..935bfca97 100644
--- a/airbyte_cdk/sources/embedded/catalog.py
+++ b/airbyte_cdk/sources/embedded/catalog.py
@@ -2,7 +2,6 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import List, Optional
from airbyte_cdk.models import (
AirbyteCatalog,
@@ -15,11 +14,11 @@
from airbyte_cdk.sources.embedded.tools import get_first
-def get_stream(catalog: AirbyteCatalog, stream_name: str) -> Optional[AirbyteStream]:
+def get_stream(catalog: AirbyteCatalog, stream_name: str) -> AirbyteStream | None:
return get_first(catalog.streams, lambda s: s.name == stream_name)
-def get_stream_names(catalog: AirbyteCatalog) -> List[str]:
+def get_stream_names(catalog: AirbyteCatalog) -> list[str]:
return [stream.name for stream in catalog.streams]
@@ -27,8 +26,8 @@ def to_configured_stream(
stream: AirbyteStream,
sync_mode: SyncMode = SyncMode.full_refresh,
destination_sync_mode: DestinationSyncMode = DestinationSyncMode.append,
- cursor_field: Optional[List[str]] = None,
- primary_key: Optional[List[List[str]]] = None,
+ cursor_field: list[str] | None = None,
+ primary_key: list[list[str]] | None = None,
) -> ConfiguredAirbyteStream:
return ConfiguredAirbyteStream(
stream=stream,
@@ -40,7 +39,7 @@ def to_configured_stream(
def to_configured_catalog(
- configured_streams: List[ConfiguredAirbyteStream],
+ configured_streams: list[ConfiguredAirbyteStream],
) -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog(streams=configured_streams)
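
The helpers above turn discovered `AirbyteStream` objects into a `ConfiguredAirbyteCatalog`. A short usage sketch chaining them; the stream definition below is made up and stands in for one returned by `source.discover()`:

```python
from airbyte_cdk.models import AirbyteStream, SyncMode
from airbyte_cdk.sources.embedded.catalog import to_configured_catalog, to_configured_stream

# Hypothetical stream definition; a real one would come from the source's catalog.
stream = AirbyteStream(
    name="users",
    json_schema={"type": "object"},
    supported_sync_modes=[SyncMode.full_refresh],
)
# Defaults: full_refresh sync mode, append destination sync mode.
configured_catalog = to_configured_catalog([to_configured_stream(stream)])
```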
diff --git a/airbyte_cdk/sources/embedded/runner.py b/airbyte_cdk/sources/embedded/runner.py
index 43217f156..da45b6e21 100644
--- a/airbyte_cdk/sources/embedded/runner.py
+++ b/airbyte_cdk/sources/embedded/runner.py
@@ -5,7 +5,8 @@
import logging
from abc import ABC, abstractmethod
-from typing import Generic, Iterable, Optional
+from collections.abc import Iterable
+from typing import Generic
from airbyte_cdk.connector import TConfig
from airbyte_cdk.models import (
@@ -32,13 +33,13 @@ def read(
self,
config: TConfig,
catalog: ConfiguredAirbyteCatalog,
- state: Optional[AirbyteStateMessage],
+ state: AirbyteStateMessage | None,
) -> Iterable[AirbyteMessage]:
pass
class CDKRunner(SourceRunner[TConfig]):
- def __init__(self, source: Source, name: str):
+ def __init__(self, source: Source, name: str): # noqa: ANN204
self._source = source
self._logger = logging.getLogger(name)
@@ -52,6 +53,6 @@ def read(
self,
config: TConfig,
catalog: ConfiguredAirbyteCatalog,
- state: Optional[AirbyteStateMessage],
+ state: AirbyteStateMessage | None,
) -> Iterable[AirbyteMessage]:
return self._source.read(self._logger, config, catalog, state=[state] if state else [])
diff --git a/airbyte_cdk/sources/embedded/tools.py b/airbyte_cdk/sources/embedded/tools.py
index 1ddb29b3a..dbfd1de07 100644
--- a/airbyte_cdk/sources/embedded/tools.py
+++ b/airbyte_cdk/sources/embedded/tools.py
@@ -2,7 +2,8 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Any, Callable, Dict, Iterable, Optional
+from collections.abc import Callable, Iterable
+from typing import Any
import dpath
@@ -10,12 +11,13 @@
def get_first(
- iterable: Iterable[Any], predicate: Callable[[Any], bool] = lambda m: True
-) -> Optional[Any]:
+ iterable: Iterable[Any],
+ predicate: Callable[[Any], bool] = lambda m: True, # noqa: ARG005
+) -> Any | None: # noqa: ANN401
return next(filter(predicate, iterable), None)
-def get_defined_id(stream: AirbyteStream, data: Dict[str, Any]) -> Optional[str]:
+def get_defined_id(stream: AirbyteStream, data: dict[str, Any]) -> str | None:
if not stream.source_defined_primary_key:
return None
primary_key = []
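
`get_first` above simply returns the first element matching the predicate, or `None` when nothing matches (the default predicate accepts everything, which is why ruff flags the unused lambda argument with `ARG005`). A brief usage sketch:

```python
from airbyte_cdk.sources.embedded.tools import get_first

numbers = [1, 3, 4, 6]
# With the default predicate the first element is returned; otherwise the first match, or None.
assert get_first(numbers) == 1
assert get_first(numbers, lambda n: n % 2 == 0) == 4
assert get_first(numbers, lambda n: n > 10) is None
```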
diff --git a/airbyte_cdk/sources/file_based/__init__.py b/airbyte_cdk/sources/file_based/__init__.py
index 6ea0ca31e..81bd9d8ec 100644
--- a/airbyte_cdk/sources/file_based/__init__.py
+++ b/airbyte_cdk/sources/file_based/__init__.py
@@ -3,11 +3,12 @@
from .config.file_based_stream_config import FileBasedStreamConfig
from .config.jsonl_format import JsonlFormat
from .exceptions import CustomFileBasedException, ErrorListingFiles, FileBasedSourceError
-from .file_based_source import DEFAULT_CONCURRENCY, FileBasedSource
+from .file_based_source import DEFAULT_CONCURRENCY, FileBasedSource # noqa: F401
from .file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
from .remote_file import RemoteFile
from .stream.cursor import DefaultFileBasedCursor
+
__all__ = [
"AbstractFileBasedSpec",
"AbstractFileBasedStreamReader",
diff --git a/airbyte_cdk/sources/file_based/availability_strategy/__init__.py b/airbyte_cdk/sources/file_based/availability_strategy/__init__.py
index 8134a89e0..2356a9a2d 100644
--- a/airbyte_cdk/sources/file_based/availability_strategy/__init__.py
+++ b/airbyte_cdk/sources/file_based/availability_strategy/__init__.py
@@ -4,6 +4,7 @@
)
from .default_file_based_availability_strategy import DefaultFileBasedAvailabilityStrategy
+
__all__ = [
"AbstractFileBasedAvailabilityStrategy",
"AbstractFileBasedAvailabilityStrategyWrapper",
diff --git a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py
index 12e1740b6..ee83a23d0 100644
--- a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py
+++ b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py
@@ -6,9 +6,9 @@
import logging
from abc import abstractmethod
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING
-from airbyte_cdk.sources import Source
+from airbyte_cdk.sources import Source # noqa: TC001
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
AbstractAvailabilityStrategy,
@@ -16,7 +16,8 @@
StreamAvailable,
StreamUnavailable,
)
-from airbyte_cdk.sources.streams.core import Stream
+from airbyte_cdk.sources.streams.core import Stream # noqa: TC001
+
if TYPE_CHECKING:
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
@@ -28,8 +29,8 @@ def check_availability( # type: ignore[override] # Signature doesn't match bas
self,
stream: Stream,
logger: logging.Logger,
- _: Optional[Source],
- ) -> Tuple[bool, Optional[str]]:
+ _: Source | None,
+ ) -> tuple[bool, str | None]:
"""
Perform a connection check for the stream.
@@ -42,8 +43,8 @@ def check_availability_and_parsability(
self,
stream: AbstractFileBasedStream,
logger: logging.Logger,
- _: Optional[Source],
- ) -> Tuple[bool, Optional[str]]:
+ _: Source | None,
+ ) -> tuple[bool, str | None]:
"""
Performs a connection check for the stream, as well as additional checks that
verify that the connection is working as expected.
@@ -65,9 +66,7 @@ def check_availability(self, logger: logging.Logger) -> StreamAvailability:
return StreamAvailable()
return StreamUnavailable(reason or "")
- def check_availability_and_parsability(
- self, logger: logging.Logger
- ) -> Tuple[bool, Optional[str]]:
+ def check_availability_and_parsability(self, logger: logging.Logger) -> tuple[bool, str | None]:
return self.stream.availability_strategy.check_availability_and_parsability(
self.stream, logger, None
)
diff --git a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py
index c9d416a72..b6304a50a 100644
--- a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py
+++ b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py
@@ -2,14 +2,14 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from __future__ import annotations
+from __future__ import annotations # noqa: I001
import logging
import traceback
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING
from airbyte_cdk import AirbyteTracedException
-from airbyte_cdk.sources import Source
+from airbyte_cdk.sources import Source # noqa: TC001
from airbyte_cdk.sources.file_based.availability_strategy import (
AbstractFileBasedAvailabilityStrategy,
)
@@ -18,10 +18,11 @@
CustomFileBasedException,
FileBasedSourceError,
)
-from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader
-from airbyte_cdk.sources.file_based.remote_file import RemoteFile
+from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader # noqa: TC001
+from airbyte_cdk.sources.file_based.remote_file import RemoteFile # noqa: TC001
from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema
+
if TYPE_CHECKING:
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
@@ -33,9 +34,9 @@ def __init__(self, stream_reader: AbstractFileBasedStreamReader) -> None:
def check_availability( # type: ignore[override] # Signature doesn't match base class
self,
stream: AbstractFileBasedStream,
- logger: logging.Logger,
- _: Optional[Source],
- ) -> Tuple[bool, Optional[str]]:
+ logger: logging.Logger, # noqa: ARG002
+ _: Source | None,
+ ) -> tuple[bool, str | None]:
"""
Perform a connection check for the stream (verify that we can list files from the stream).
@@ -52,8 +53,8 @@ def check_availability_and_parsability(
self,
stream: AbstractFileBasedStream,
logger: logging.Logger,
- _: Optional[Source],
- ) -> Tuple[bool, Optional[str]]:
+ _: Source | None,
+ ) -> tuple[bool, str | None]:
"""
Perform a connection check for the stream.
@@ -77,14 +78,14 @@ def check_availability_and_parsability(
return False, config_check_error_message
try:
file = self._check_list_files(stream)
- if not parser.parser_max_n_files_for_parsability == 0:
+ if parser.parser_max_n_files_for_parsability != 0:
self._check_parse_record(stream, file, logger)
else:
# If the parser is set to not check parsability, we still want to check that we can open the file.
handle = stream.stream_reader.open_file(file, parser.file_read_mode, None, logger)
handle.close()
except AirbyteTracedException as ate:
- raise ate
+ raise ate # noqa: TRY201
except CheckAvailabilityError:
return False, "".join(traceback.format_exc())
@@ -99,7 +100,7 @@ def _check_list_files(self, stream: AbstractFileBasedStream) -> RemoteFile:
try:
file = next(iter(stream.get_files()))
except StopIteration:
- raise CheckAvailabilityError(FileBasedSourceError.EMPTY_STREAM, stream=stream.name)
+ raise CheckAvailabilityError(FileBasedSourceError.EMPTY_STREAM, stream=stream.name) # noqa: B904
except CustomFileBasedException as exc:
raise CheckAvailabilityError(str(exc), stream=stream.name) from exc
except Exception as exc:
@@ -131,14 +132,14 @@ def _check_parse_record(
# we skip the schema validation check.
return
except AirbyteTracedException as ate:
- raise ate
+ raise ate # noqa: TRY201
except Exception as exc:
raise CheckAvailabilityError(
FileBasedSourceError.ERROR_READING_FILE, stream=stream.name, file=file.uri
) from exc
schema = stream.catalog_schema or stream.config.input_schema
- if schema and stream.validation_policy.validate_schema_before_sync:
+ if schema and stream.validation_policy.validate_schema_before_sync: # noqa: SIM102
if not conforms_to_schema(record, schema): # type: ignore
raise CheckAvailabilityError(
FileBasedSourceError.ERROR_VALIDATING_RECORD,
@@ -146,4 +147,4 @@ def _check_parse_record(
file=file.uri,
)
- return None
+ return
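
For reference, the availability check above boils down to: list one file, then either parse a record from it or, when the parser sets parser_max_n_files_for_parsability to 0, simply open and close the file. A minimal stand-alone sketch, assuming duck-typed `parser`, `stream`, and `stream_reader` placeholders that mirror the interfaces visible elsewhere in this diff, not the real CDK classes:

import logging
import traceback

def check_one_file(parser, stream, stream_reader, logger: logging.Logger) -> tuple[bool, str | None]:
    """Illustrative only: list one file, then either parse a record or just open the handle."""
    try:
        file = next(iter(stream.get_files()))
    except StopIteration:
        return False, "No files matched the configured globs."
    try:
        if parser.parser_max_n_files_for_parsability != 0:
            next(iter(parser.parse_records(stream.config, file, stream_reader, logger, None)), None)
        else:
            # Parsability checks disabled: still verify the file can be opened.
            handle = stream_reader.open_file(file, parser.file_read_mode, None, logger)
            handle.close()
    except Exception:
        return False, "".join(traceback.format_exc())
    return True, None
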
diff --git a/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py b/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py
index 626d50fef..26fd8da14 100644
--- a/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py
+++ b/airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py
@@ -4,7 +4,7 @@
import copy
from abc import abstractmethod
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Literal
import dpath
from pydantic.v1 import AnyUrl, BaseModel, Field
@@ -49,7 +49,7 @@ class AbstractFileBasedSpec(BaseModel):
that are needed when users configure a file-based source.
"""
- start_date: Optional[str] = Field(
+ start_date: str | None = Field(
title="Start Date",
description="UTC date and time in the format 2017-01-25T00:00:00.000000Z. Any file modified before this date will not be replicated.",
examples=["2021-01-01T00:00:00.000000Z"],
@@ -59,13 +59,13 @@ class AbstractFileBasedSpec(BaseModel):
order=1,
)
- streams: List[FileBasedStreamConfig] = Field(
+ streams: list[FileBasedStreamConfig] = Field(
title="The list of streams to sync",
description='Each instance of this configuration defines a stream. Use this to define which files belong in the stream, their format, and how they should be parsed and validated. When sending data to warehouse destination such as Snowflake or BigQuery, each stream is a separate table.',
order=10,
)
- delivery_method: Union[DeliverRecords, DeliverRawFiles] = Field(
+ delivery_method: DeliverRecords | DeliverRawFiles = Field(
title="Delivery Method",
discriminator="delivery_type",
type="object",
@@ -84,12 +84,12 @@ def documentation_url(cls) -> AnyUrl:
"""
@classmethod
- def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+ def schema(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401
"""
Generates the mapping comprised of the config fields
"""
schema = super().schema(*args, **kwargs)
- transformed_schema: Dict[str, Any] = copy.deepcopy(schema)
+ transformed_schema: dict[str, Any] = copy.deepcopy(schema)
schema_helpers.expand_refs(transformed_schema)
cls.replace_enum_allOf_and_anyOf(transformed_schema)
cls.remove_discriminator(transformed_schema)
@@ -97,12 +97,12 @@ def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return transformed_schema
@staticmethod
- def remove_discriminator(schema: Dict[str, Any]) -> None:
+ def remove_discriminator(schema: dict[str, Any]) -> None:
"""pydantic adds "discriminator" to the schema for oneOfs, which is not treated right by the platform as we inline all references"""
dpath.delete(schema, "properties/**/discriminator")
@staticmethod
- def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]:
+ def replace_enum_allOf_and_anyOf(schema: dict[str, Any]) -> dict[str, Any]: # noqa: N802
"""
allOfs are not supported by the UI, but pydantic is automatically writing them for enums.
Unpacks the enums under allOf and moves them up a level under the enum key
@@ -112,7 +112,7 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]:
objects_to_check = schema["properties"]["streams"]["items"]["properties"]["format"]
objects_to_check["type"] = "object"
objects_to_check["oneOf"] = objects_to_check.pop("anyOf", [])
- for format in objects_to_check["oneOf"]:
+ for format in objects_to_check["oneOf"]: # noqa: A001
for key in format["properties"]:
object_property = format["properties"][key]
AbstractFileBasedSpec.move_enum_to_root(object_property)
@@ -133,7 +133,7 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]:
csv_format_schemas = list(
filter(
- lambda format: format["properties"]["filetype"]["default"] == "csv",
+ lambda format: format["properties"]["filetype"]["default"] == "csv", # noqa: A006
schema["properties"]["streams"]["items"]["properties"]["format"]["oneOf"],
)
)
@@ -146,7 +146,7 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]:
return schema
@staticmethod
- def move_enum_to_root(object_property: Dict[str, Any]) -> None:
+ def move_enum_to_root(object_property: dict[str, Any]) -> None:
if "allOf" in object_property and "enum" in object_property["allOf"][0]:
object_property["enum"] = object_property["allOf"][0]["enum"]
object_property.pop("allOf")
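
The enum-unpacking step referenced above can be shown in isolation. This is a small self-contained sketch of `move_enum_to_root` exactly as it appears in the hunk, applied to a toy property: pydantic v1 emits enums under an allOf reference, while the UI expects a flat "enum" key, so the nested structure is lifted up one level.

def move_enum_to_root(object_property: dict) -> None:
    if "allOf" in object_property and "enum" in object_property["allOf"][0]:
        object_property["enum"] = object_property["allOf"][0]["enum"]
        object_property.pop("allOf")

prop = {"title": "Filetype", "allOf": [{"enum": ["csv", "jsonl"]}]}
move_enum_to_root(prop)
assert prop == {"title": "Filetype", "enum": ["csv", "jsonl"]}
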
diff --git a/airbyte_cdk/sources/file_based/config/csv_format.py b/airbyte_cdk/sources/file_based/config/csv_format.py
index 1441d8411..9ce100892 100644
--- a/airbyte_cdk/sources/file_based/config/csv_format.py
+++ b/airbyte_cdk/sources/file_based/config/csv_format.py
@@ -4,7 +4,7 @@
import codecs
from enum import Enum
-from typing import Any, Dict, List, Optional, Set, Union
+from typing import Any
from pydantic.v1 import BaseModel, Field, root_validator, validator
from pydantic.v1.error_wrappers import ValidationError
@@ -60,7 +60,7 @@ class Config(OneOfOptionConfig):
CsvHeaderDefinitionType.USER_PROVIDED.value,
const=True,
)
- column_names: List[str] = Field(
+ column_names: list[str] = Field(
title="Column Names",
description="The column names that will be used while emitting the CSV records",
)
@@ -69,7 +69,7 @@ def has_header_row(self) -> bool:
return False
@validator("column_names")
- def validate_column_names(cls, v: List[str]) -> List[str]:
+ def validate_column_names(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError(
"At least one column name needs to be provided when using user provided headers"
@@ -100,12 +100,12 @@ class Config(OneOfOptionConfig):
default='"',
description="The character used for quoting CSV values. To disallow quoting, make this field blank.",
)
- escape_char: Optional[str] = Field(
+ escape_char: str | None = Field(
title="Escape Character",
default=None,
description="The character used for escaping special characters. To disallow escaping, leave this field blank.",
)
- encoding: Optional[str] = Field(
+ encoding: str | None = Field(
default="utf8",
description='The character encoding of the CSV data. Leave blank to default to UTF8. See list of python encodings for allowable options.',
)
@@ -114,7 +114,7 @@ class Config(OneOfOptionConfig):
default=True,
description="Whether two quotes in a quoted CSV value denote a single quote in the data.",
)
- null_values: Set[str] = Field(
+ null_values: set[str] = Field(
title="Null Values",
default=[],
description="A set of case-sensitive strings that should be interpreted as null values. For example, if the value 'NA' should be interpreted as null, enter 'NA' in this field.",
@@ -134,19 +134,17 @@ class Config(OneOfOptionConfig):
default=0,
description="The number of rows to skip after the header row.",
)
- header_definition: Union[CsvHeaderFromCsv, CsvHeaderAutogenerated, CsvHeaderUserProvided] = (
- Field(
- title="CSV Header Definition",
- default=CsvHeaderFromCsv(header_definition_type=CsvHeaderDefinitionType.FROM_CSV.value),
- description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.",
- )
+ header_definition: CsvHeaderFromCsv | CsvHeaderAutogenerated | CsvHeaderUserProvided = Field(
+ title="CSV Header Definition",
+ default=CsvHeaderFromCsv(header_definition_type=CsvHeaderDefinitionType.FROM_CSV.value),
+ description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.",
)
- true_values: Set[str] = Field(
+ true_values: set[str] = Field(
title="True Values",
default=DEFAULT_TRUE_VALUES,
description="A set of case-sensitive strings that should be interpreted as true values.",
)
- false_values: Set[str] = Field(
+ false_values: set[str] = Field(
title="False Values",
default=DEFAULT_FALSE_VALUES,
description="A set of case-sensitive strings that should be interpreted as false values.",
@@ -190,11 +188,11 @@ def validate_encoding(cls, v: str) -> str:
try:
codecs.lookup(v)
except LookupError:
- raise ValueError(f"invalid encoding format: {v}")
+ raise ValueError(f"invalid encoding format: {v}") # noqa: B904
return v
@root_validator
- def validate_optional_args(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ def validate_optional_args(cls, values: dict[str, Any]) -> dict[str, Any]:
definition_type = values.get("header_definition_type")
column_names = values.get("user_provided_column_names")
if definition_type == CsvHeaderDefinitionType.USER_PROVIDED and not column_names:
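
The encoding validator touched above relies only on the stdlib codecs registry. A small sketch of the same check, using `raise ... from None` instead of the `# noqa: B904` suppression:

import codecs

def validate_encoding(v: str) -> str:
    try:
        codecs.lookup(v)  # raises LookupError for unknown codec names
    except LookupError:
        raise ValueError(f"invalid encoding format: {v}") from None
    return v

assert validate_encoding("utf8") == "utf8"  # aliases resolve via the codecs registry
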
diff --git a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py
index eb592a4aa..7b481c9df 100644
--- a/airbyte_cdk/sources/file_based/config/file_based_stream_config.py
+++ b/airbyte_cdk/sources/file_based/config/file_based_stream_config.py
@@ -2,8 +2,9 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
+from collections.abc import Mapping
from enum import Enum
-from typing import Any, List, Mapping, Optional, Union
+from typing import Any
from pydantic.v1 import BaseModel, Field, validator
@@ -16,7 +17,8 @@
from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError
from airbyte_cdk.sources.file_based.schema_helpers import type_mapping_to_jsonschema
-PrimaryKeyType = Optional[Union[str, List[str]]]
+
+PrimaryKeyType = str | list[str] | None
class ValidationPolicy(Enum):
@@ -27,13 +29,13 @@ class ValidationPolicy(Enum):
class FileBasedStreamConfig(BaseModel):
name: str = Field(title="Name", description="The name of the stream.")
- globs: Optional[List[str]] = Field(
+ globs: list[str] | None = Field(
default=["**"],
title="Globs",
description='The pattern used to specify which files should be selected from the file system. For more information on glob pattern matching look here.',
order=1,
)
- legacy_prefix: Optional[str] = Field(
+ legacy_prefix: str | None = Field(
title="Legacy Prefix",
description="The path prefix configured in v3 versions of the S3 connector. This option is deprecated in favor of a single glob.",
airbyte_hidden=True,
@@ -43,11 +45,11 @@ class FileBasedStreamConfig(BaseModel):
description="The name of the validation policy that dictates sync behavior when a record does not adhere to the stream schema.",
default=ValidationPolicy.emit_record,
)
- input_schema: Optional[str] = Field(
+ input_schema: str | None = Field(
title="Input Schema",
description="The schema that will be used to validate records extracted from the file. This will override the stream schema that is auto-detected from incoming files.",
)
- primary_key: Optional[str] = Field(
+ primary_key: str | None = Field(
title="Primary Key",
description="The column or columns (for a composite key) that serves as the unique identifier of a record. If empty, the primary key will default to the parser's default primary key.",
airbyte_hidden=True, # Users can create/modify primary keys in the connection configuration so we shouldn't duplicate it here.
@@ -57,9 +59,9 @@ class FileBasedStreamConfig(BaseModel):
description="When the state history of the file store is full, syncs will only read files that were last modified in the provided day range.",
default=3,
)
- format: Union[
- AvroFormat, CsvFormat, JsonlFormat, ParquetFormat, UnstructuredFormat, ExcelFormat
- ] = Field(
+ format: (
+ AvroFormat | CsvFormat | JsonlFormat | ParquetFormat | UnstructuredFormat | ExcelFormat
+ ) = Field(
title="Format",
description="The configuration options that are used to alter how to read incoming files that deviate from the standard formatting.",
)
@@ -68,7 +70,7 @@ class FileBasedStreamConfig(BaseModel):
description="When enabled, syncs will not validate or structure records against the stream's schema.",
default=False,
)
- recent_n_files_to_read_for_schema_discovery: Optional[int] = Field(
+ recent_n_files_to_read_for_schema_discovery: int | None = Field(
        title="Files To Read For Schema Discovery",
        description="The number of recent files which will be used to discover the schema for this stream.",
default=None,
@@ -76,15 +78,14 @@ class FileBasedStreamConfig(BaseModel):
)
@validator("input_schema", pre=True)
- def validate_input_schema(cls, v: Optional[str]) -> Optional[str]:
+ def validate_input_schema(cls, v: str | None) -> str | None:
if v:
if type_mapping_to_jsonschema(v):
return v
- else:
- raise ConfigValidationError(FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA)
+ raise ConfigValidationError(FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA)
return None
- def get_input_schema(self) -> Optional[Mapping[str, Any]]:
+ def get_input_schema(self) -> Mapping[str, Any] | None:
"""
User defined input_schema is defined as a string in the config. This method takes the string representation
and converts it into a Mapping[str, Any] which is used by file-based CDK components.
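
The `input_schema` field above is a JSON string mapping column names to types, and `type_mapping_to_jsonschema` turns it into a JSON-schema mapping. The helper itself is not shown in this diff, so the version below is a hypothetical, simplified stand-in for illustration only:

import json

_JSON_TYPES = {"string", "integer", "number", "boolean", "object", "array", "null"}

def type_mapping_to_jsonschema(input_schema: str) -> dict | None:
    # Hypothetical simplification: accept only plain "column -> json type" mappings.
    mapping = json.loads(input_schema)
    if not all(t in _JSON_TYPES for t in mapping.values()):
        return None
    return {
        "type": "object",
        "properties": {name: {"type": t} for name, t in mapping.items()},
    }

print(type_mapping_to_jsonschema('{"id": "integer", "name": "string"}'))
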
diff --git a/airbyte_cdk/sources/file_based/config/unstructured_format.py b/airbyte_cdk/sources/file_based/config/unstructured_format.py
index c03540ce6..7f8ae9369 100644
--- a/airbyte_cdk/sources/file_based/config/unstructured_format.py
+++ b/airbyte_cdk/sources/file_based/config/unstructured_format.py
@@ -2,7 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import List, Literal, Optional, Union
+from typing import Literal
from pydantic.v1 import BaseModel, Field
@@ -50,7 +50,7 @@ class APIProcessingConfigModel(BaseModel):
examples=["https://api.unstructured.com"],
)
- parameters: Optional[List[APIParameterConfigModel]] = Field(
+ parameters: list[APIParameterConfigModel] | None = Field(
default=[],
always_show=True,
title="Additional URL Parameters",
@@ -90,10 +90,7 @@ class Config(OneOfOptionConfig):
description="The strategy used to parse documents. `fast` extracts text directly from the document which doesn't work for all files. `ocr_only` is more reliable, but slower. `hi_res` is the most reliable, but requires an API key and a hosted instance of unstructured and can't be used with local mode. See the unstructured.io documentation for more details: https://unstructured-io.github.io/unstructured/core/partition.html#partition-pdf",
)
- processing: Union[
- LocalProcessingConfigModel,
- APIProcessingConfigModel,
- ] = Field(
+ processing: LocalProcessingConfigModel | APIProcessingConfigModel = Field(
default=LocalProcessingConfigModel(mode="local"),
title="Processing",
description="Processing configuration",
diff --git a/airbyte_cdk/sources/file_based/discovery_policy/__init__.py b/airbyte_cdk/sources/file_based/discovery_policy/__init__.py
index 6d0f231a3..7d5bdd887 100644
--- a/airbyte_cdk/sources/file_based/discovery_policy/__init__.py
+++ b/airbyte_cdk/sources/file_based/discovery_policy/__init__.py
@@ -5,4 +5,5 @@
DefaultDiscoveryPolicy,
)
+
__all__ = ["AbstractDiscoveryPolicy", "DefaultDiscoveryPolicy"]
diff --git a/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py b/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py
index f651c2ce1..24a0d9158 100644
--- a/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py
+++ b/airbyte_cdk/sources/file_based/discovery_policy/default_discovery_policy.py
@@ -7,6 +7,7 @@
)
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
+
DEFAULT_N_CONCURRENT_REQUESTS = 10
DEFAULT_MAX_N_FILES_FOR_STREAM_SCHEMA_INFERENCE = 10
diff --git a/airbyte_cdk/sources/file_based/exceptions.py b/airbyte_cdk/sources/file_based/exceptions.py
index b0d38947f..f2af7c2a8 100644
--- a/airbyte_cdk/sources/file_based/exceptions.py
+++ b/airbyte_cdk/sources/file_based/exceptions.py
@@ -3,7 +3,7 @@
#
from enum import Enum
-from typing import Any, List, Union
+from typing import Any
from airbyte_cdk.models import AirbyteMessage, FailureType
from airbyte_cdk.utils import AirbyteTracedException
@@ -43,9 +43,9 @@ class FileBasedErrorsCollector:
The placeholder for all errors collected.
"""
- errors: List[AirbyteMessage] = []
+ errors: list[AirbyteMessage] = [] # noqa: RUF012
- def yield_and_raise_collected(self) -> Any:
+ def yield_and_raise_collected(self) -> Any: # noqa: ANN401
if self.errors:
# emit collected logged messages
yield from self.errors
@@ -63,7 +63,7 @@ def collect(self, logged_error: AirbyteMessage) -> None:
class BaseFileBasedSourceError(Exception):
- def __init__(self, error: Union[FileBasedSourceError, str], **kwargs): # type: ignore # noqa
+ def __init__(self, error: FileBasedSourceError | str, **kwargs) -> None: # type: ignore
if isinstance(error, FileBasedSourceError):
error = FileBasedSourceError(error).value
super().__init__(
@@ -112,7 +112,7 @@ class ErrorListingFiles(BaseFileBasedSourceError):
class DuplicatedFilesError(BaseFileBasedSourceError):
- def __init__(self, duplicated_files_names: List[dict[str, List[str]]], **kwargs: Any):
+ def __init__(self, duplicated_files_names: list[dict[str, list[str]]], **kwargs) -> None:
self._duplicated_files_names = duplicated_files_names
self._stream_name: str = kwargs["stream"]
super().__init__(self._format_duplicate_files_error_message(), **kwargs)
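
The errors collector above follows a yield-then-raise pattern: emit every collected message, then fail the sync if anything was collected. A self-contained sketch with plain strings standing in for AirbyteMessage:

class ErrorsCollector:
    def __init__(self) -> None:
        self.errors: list[str] = []

    def collect(self, logged_error: str) -> None:
        self.errors.append(logged_error)

    def yield_and_raise_collected(self):
        if self.errors:
            yield from self.errors  # emit collected messages first
            raise RuntimeError(f"{len(self.errors)} stream(s) finished with errors")

collector = ErrorsCollector()
collector.collect("stream `users`: 2 records did not conform to the schema")
try:
    for message in collector.yield_and_raise_collected():
        print(message)
except RuntimeError as exc:
    print(exc)
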
diff --git a/airbyte_cdk/sources/file_based/file_based_source.py b/airbyte_cdk/sources/file_based/file_based_source.py
index 0eb90ac24..99fa3bcc1 100644
--- a/airbyte_cdk/sources/file_based/file_based_source.py
+++ b/airbyte_cdk/sources/file_based/file_based_source.py
@@ -6,7 +6,8 @@
import traceback
from abc import ABC
from collections import Counter
-from typing import Any, Iterator, List, Mapping, Optional, Tuple, Type, Union
+from collections.abc import Iterator, Mapping
+from typing import Any
from pydantic.v1.error_wrappers import ValidationError
@@ -63,6 +64,7 @@
from airbyte_cdk.utils.analytics_message import create_analytics_message
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
DEFAULT_CONCURRENCY = 100
MAX_CONCURRENCY = 100
INITIAL_N_PARTITIONS = MAX_CONCURRENCY // 2
@@ -72,21 +74,21 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
# We make each source override the concurrency level to give control over when they are upgraded.
_concurrency_level = None
- def __init__(
+ def __init__( # noqa: ANN204, PLR0913, PLR0917
self,
stream_reader: AbstractFileBasedStreamReader,
- spec_class: Type[AbstractFileBasedSpec],
- catalog: Optional[ConfiguredAirbyteCatalog],
- config: Optional[Mapping[str, Any]],
- state: Optional[List[AirbyteStateMessage]],
- availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None,
- discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
- parsers: Mapping[Type[Any], FileTypeParser] = default_parsers,
+ spec_class: type[AbstractFileBasedSpec],
+ catalog: ConfiguredAirbyteCatalog | None,
+ config: Mapping[str, Any] | None,
+ state: list[AirbyteStateMessage] | None,
+ availability_strategy: AbstractFileBasedAvailabilityStrategy | None = None,
+ discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(), # noqa: B008
+ parsers: Mapping[type[Any], FileTypeParser] = default_parsers,
validation_policies: Mapping[
ValidationPolicy, AbstractSchemaValidationPolicy
] = DEFAULT_SCHEMA_VALIDATION_POLICIES,
- cursor_cls: Type[
- Union[AbstractConcurrentFileBasedCursor, AbstractFileBasedCursor]
+ cursor_cls: type[
+ AbstractConcurrentFileBasedCursor | AbstractFileBasedCursor
] = FileBasedConcurrentCursor,
):
self.stream_reader = stream_reader
@@ -106,7 +108,7 @@ def __init__(
self.cursor_cls = cursor_cls
self.logger = init_logger(f"airbyte.{self.name}")
self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector()
- self._message_repository: Optional[MessageRepository] = None
+ self._message_repository: MessageRepository | None = None
concurrent_source = ConcurrentSource.create(
MAX_CONCURRENCY,
INITIAL_N_PARTITIONS,
@@ -127,7 +129,7 @@ def message_repository(self) -> MessageRepository:
def check_connection(
self, logger: logging.Logger, config: Mapping[str, Any]
- ) -> Tuple[bool, Optional[Any]]:
+ ) -> tuple[bool, Any | None]:
"""
Check that the source can be accessed using the user-provided configuration.
@@ -140,7 +142,7 @@ def check_connection(
try:
streams = self.streams(config)
except Exception as config_exception:
- raise AirbyteTracedException(
+ raise AirbyteTracedException( # noqa: B904
internal_message="Please check the logged errors for more information.",
message=FileBasedSourceError.CONFIG_VALIDATION_ERROR.value,
exception=AirbyteTracedException(exception=config_exception),
@@ -158,7 +160,7 @@ def check_connection(
tracebacks = []
for stream in streams:
if not isinstance(stream, AbstractFileBasedStream):
- raise ValueError(f"Stream {stream} is not a file-based stream.")
+ raise ValueError(f"Stream {stream} is not a file-based stream.") # noqa: TRY004
try:
parsed_config = self._get_parsed_config(config)
availability_method = (
@@ -191,7 +193,7 @@ def check_connection(
message=f"{errors[0]}",
failure_type=FailureType.config_error,
)
- elif len(errors) > 1:
+ if len(errors) > 1:
raise AirbyteTracedException(
internal_message="\n".join(tracebacks),
message=f"{len(errors)} streams with errors: {', '.join(error for error in errors)}",
@@ -200,12 +202,12 @@ def check_connection(
return not bool(errors), (errors or None)
- def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+ def streams(self, config: Mapping[str, Any]) -> list[Stream]:
"""
Return a list of this source's streams.
"""
- if self.catalog:
+ if self.catalog: # noqa: SIM108
state_manager = ConnectorStateManager(state=self.state)
else:
# During `check` operations we don't have a catalog so cannot create a state manager.
@@ -215,7 +217,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
try:
parsed_config = self._get_parsed_config(config)
self.stream_reader.config = parsed_config
- streams: List[Stream] = []
+ streams: list[Stream] = []
for stream_config in parsed_config.streams:
# Like state_manager, `catalog_stream` may be None during `check`
catalog_stream = self._get_stream_from_catalog(stream_config)
@@ -256,9 +258,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
and hasattr(self, "_concurrency_level")
and self._concurrency_level is not None
):
- assert (
- state_manager is not None
- ), "No ConnectorStateManager was created, but it is required for incremental syncs. This is unexpected. Please contact Support."
+ assert state_manager is not None, (
+ "No ConnectorStateManager was created, but it is required for incremental syncs. This is unexpected. Please contact Support."
+ )
cursor = self.cursor_cls(
stream_config,
@@ -289,7 +291,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
)
streams.append(stream)
- return streams
+ return streams # noqa: TRY300
except ValidationError as exc:
raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc
@@ -297,7 +299,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
def _make_default_stream(
self,
stream_config: FileBasedStreamConfig,
- cursor: Optional[AbstractFileBasedCursor],
+ cursor: AbstractFileBasedCursor | None,
parsed_config: AbstractFileBasedSpec,
) -> AbstractFileBasedStream:
return DefaultFileBasedStream(
@@ -316,14 +318,14 @@ def _make_default_stream(
def _get_stream_from_catalog(
self, stream_config: FileBasedStreamConfig
- ) -> Optional[AirbyteStream]:
+ ) -> AirbyteStream | None:
if self.catalog:
for stream in self.catalog.streams or []:
if stream.stream.name == stream_config.name:
return stream.stream
return None
- def _get_sync_mode_from_catalog(self, stream_name: str) -> Optional[SyncMode]:
+ def _get_sync_mode_from_catalog(self, stream_name: str) -> SyncMode | None:
if self.catalog:
for catalog_stream in self.catalog.streams:
if stream_name == catalog_stream.stream.name:
@@ -336,7 +338,7 @@ def read(
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
- state: Optional[List[AirbyteStateMessage]] = None,
+ state: list[AirbyteStateMessage] | None = None,
) -> Iterator[AirbyteMessage]:
yield from super().read(logger, config, catalog, state)
# emit all the errors collected
@@ -348,7 +350,7 @@ def read(
).items():
yield create_analytics_message(f"file-cdk-{parser}-stream-count", count)
- def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification:
+ def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification: # noqa: ANN401, ARG002
"""
Returns the specification describing what fields can be configured by a user when setting up a file-based source.
"""
diff --git a/airbyte_cdk/sources/file_based/file_based_stream_reader.py b/airbyte_cdk/sources/file_based/file_based_stream_reader.py
index 065125621..fcb6e0a31 100644
--- a/airbyte_cdk/sources/file_based/file_based_stream_reader.py
+++ b/airbyte_cdk/sources/file_based/file_based_stream_reader.py
@@ -4,11 +4,12 @@
import logging
from abc import ABC, abstractmethod
+from collections.abc import Iterable
from datetime import datetime
from enum import Enum
from io import IOBase
from os import makedirs, path
-from typing import Any, Dict, Iterable, List, Optional, Set
+from typing import Any
from wcmatch.glob import GLOBSTAR, globmatch
@@ -28,7 +29,7 @@ def __init__(self) -> None:
self._config = None
@property
- def config(self) -> Optional[AbstractFileBasedSpec]:
+ def config(self) -> AbstractFileBasedSpec | None:
return self._config
@config.setter
@@ -47,7 +48,7 @@ def config(self, value: AbstractFileBasedSpec) -> None:
@abstractmethod
def open_file(
- self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger
+ self, file: RemoteFile, mode: FileReadMode, encoding: str | None, logger: logging.Logger
) -> IOBase:
"""
Return a file handle for reading.
@@ -63,8 +64,8 @@ def open_file(
@abstractmethod
def get_matching_files(
self,
- globs: List[str],
- prefix: Optional[str],
+ globs: list[str],
+ prefix: str | None,
logger: logging.Logger,
) -> Iterable[RemoteFile]:
"""
@@ -84,7 +85,7 @@ def get_matching_files(
...
def filter_files_by_globs_and_start_date(
- self, files: List[RemoteFile], globs: List[str]
+ self, files: list[RemoteFile], globs: list[str]
) -> Iterable[RemoteFile]:
"""
Utility method for filtering files based on globs.
@@ -97,7 +98,7 @@ def filter_files_by_globs_and_start_date(
seen = set()
for file in files:
- if self.file_matches_globs(file, globs):
+ if self.file_matches_globs(file, globs): # noqa: SIM102
if file.uri not in seen and (not start_date or file.last_modified >= start_date):
seen.add(file.uri)
yield file
@@ -113,13 +114,13 @@ def file_size(self, file: RemoteFile) -> int:
...
@staticmethod
- def file_matches_globs(file: RemoteFile, globs: List[str]) -> bool:
+ def file_matches_globs(file: RemoteFile, globs: list[str]) -> bool:
# Use the GLOBSTAR flag to enable recursive ** matching
# (https://facelessuser.github.io/wcmatch/wcmatch/#globstar)
return any(globmatch(file.uri, g, flags=GLOBSTAR) for g in globs)
@staticmethod
- def get_prefixes_from_globs(globs: List[str]) -> Set[str]:
+ def get_prefixes_from_globs(globs: list[str]) -> set[str]:
"""
Utility method for extracting prefixes from the globs.
"""
@@ -149,7 +150,7 @@ def preserve_directory_structure(self) -> bool:
@abstractmethod
def get_file(
self, file: RemoteFile, local_directory: str, logger: logging.Logger
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""
This is required for connectors that will support writing to
files. It will handle the logic to download,get,read,acquire or
@@ -170,16 +171,16 @@ def get_file(
"""
...
- def _get_file_transfer_paths(self, file: RemoteFile, local_directory: str) -> List[str]:
+ def _get_file_transfer_paths(self, file: RemoteFile, local_directory: str) -> list[str]:
preserve_directory_structure = self.preserve_directory_structure()
if preserve_directory_structure:
# Remove left slashes from source path format to make relative path for writing locally
file_relative_path = file.uri.lstrip("/")
else:
- file_relative_path = path.basename(file.uri)
- local_file_path = path.join(local_directory, file_relative_path)
+ file_relative_path = path.basename(file.uri) # noqa: PTH119
+ local_file_path = path.join(local_directory, file_relative_path) # noqa: PTH118
# Ensure the local directory exists
- makedirs(path.dirname(local_file_path), exist_ok=True)
- absolute_file_path = path.abspath(local_file_path)
+ makedirs(path.dirname(local_file_path), exist_ok=True) # noqa: PTH103, PTH120
+ absolute_file_path = path.abspath(local_file_path) # noqa: PTH100
return [file_relative_path, local_file_path, absolute_file_path]
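
file_matches_globs above delegates to wcmatch with the GLOBSTAR flag so that `**` matches across directory separators. A quick usage example (requires the wcmatch package):

from wcmatch.glob import GLOBSTAR, globmatch

uris = ["data/2024/01/records.csv", "data/readme.txt"]
globs = ["**/*.csv"]
# GLOBSTAR makes "**" recurse into subdirectories when matching URIs.
matched = [u for u in uris if any(globmatch(u, g, flags=GLOBSTAR) for g in globs)]
print(matched)  # ['data/2024/01/records.csv']
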
diff --git a/airbyte_cdk/sources/file_based/file_types/__init__.py b/airbyte_cdk/sources/file_based/file_types/__init__.py
index b9d8f1d52..50446f893 100644
--- a/airbyte_cdk/sources/file_based/file_types/__init__.py
+++ b/airbyte_cdk/sources/file_based/file_types/__init__.py
@@ -1,11 +1,5 @@
-from typing import Any, Mapping, Type
-
-from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat
-from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
-from airbyte_cdk.sources.file_based.config.excel_format import ExcelFormat
-from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat
-from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat
-from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat
+from collections.abc import Mapping
+from typing import Any, Type # noqa: F401, UP035
from .avro_parser import AvroParser
from .csv_parser import CsvParser
@@ -15,8 +9,15 @@
from .jsonl_parser import JsonlParser
from .parquet_parser import ParquetParser
from .unstructured_parser import UnstructuredParser
+from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat
+from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
+from airbyte_cdk.sources.file_based.config.excel_format import ExcelFormat
+from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat
+from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat
+from airbyte_cdk.sources.file_based.config.unstructured_format import UnstructuredFormat
+
-default_parsers: Mapping[Type[Any], FileTypeParser] = {
+default_parsers: Mapping[type[Any], FileTypeParser] = {
AvroFormat: AvroParser(),
CsvFormat: CsvParser(),
ExcelFormat: ExcelParser(),
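
default_parsers maps a format class to a parser instance, so the parser is selected from the type of the stream's configured format. A toy illustration with stand-in classes rather than the real format and parser types:

class CsvFormat: ...
class JsonlFormat: ...
class CsvParser: ...
class JsonlParser: ...

default_parsers = {CsvFormat: CsvParser(), JsonlFormat: JsonlParser()}

stream_format = CsvFormat()
parser = default_parsers[type(stream_format)]  # select the parser from the config's format class
print(type(parser).__name__)                   # CsvParser
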
diff --git a/airbyte_cdk/sources/file_based/file_types/avro_parser.py b/airbyte_cdk/sources/file_based/file_types/avro_parser.py
index e1aa2c4cb..6da06900b 100644
--- a/airbyte_cdk/sources/file_based/file_types/avro_parser.py
+++ b/airbyte_cdk/sources/file_based/file_types/avro_parser.py
@@ -3,7 +3,8 @@
#
import logging
-from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, cast
+from collections.abc import Iterable, Mapping
+from typing import Any, cast
import fastavro
@@ -18,6 +19,7 @@
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType
+
AVRO_TYPE_TO_JSON_TYPE = {
"null": "null",
"boolean": "boolean",
@@ -46,7 +48,7 @@
class AvroParser(FileTypeParser):
ENCODING = None
- def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]:
+ def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002
"""
AvroParser does not require config checks, implicit pydantic validation is enough.
"""
@@ -61,12 +63,12 @@ async def infer_schema(
) -> SchemaType:
avro_format = config.format
if not isinstance(avro_format, AvroFormat):
- raise ValueError(f"Expected ParquetFormat, got {avro_format}")
+ raise ValueError(f"Expected ParquetFormat, got {avro_format}") # noqa: TRY004
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
avro_reader = fastavro.reader(fp) # type: ignore [arg-type]
avro_schema = avro_reader.writer_schema
- if not avro_schema["type"] == "record": # type: ignore [index, call-overload]
+ if avro_schema["type"] != "record": # type: ignore [index, call-overload]
unsupported_type = avro_schema["type"] # type: ignore [index, call-overload]
raise ValueError(
f"Only record based avro files are supported. Found {unsupported_type}"
@@ -82,7 +84,7 @@ async def infer_schema(
return json_schema
@classmethod
- def _convert_avro_type_to_json(
+ def _convert_avro_type_to_json( # noqa: PLR0911, PLR0912
cls, avro_format: AvroFormat, field_name: str, avro_field: str
) -> Mapping[str, Any]:
if isinstance(avro_field, str) and avro_field in AVRO_TYPE_TO_JSON_TYPE:
@@ -101,7 +103,7 @@ def _convert_avro_type_to_json(
for object_field in avro_field["fields"]
},
}
- elif avro_field["type"] == "array":
+ if avro_field["type"] == "array":
if "items" not in avro_field:
raise ValueError(
f"{field_name} array type does not have a required field items"
@@ -112,7 +114,7 @@ def _convert_avro_type_to_json(
avro_format, "", avro_field["items"]
),
}
- elif avro_field["type"] == "enum":
+ if avro_field["type"] == "enum":
if "symbols" not in avro_field:
raise ValueError(
f"{field_name} enum type does not have a required field symbols"
@@ -120,7 +122,7 @@ def _convert_avro_type_to_json(
if "name" not in avro_field:
raise ValueError(f"{field_name} enum type does not have a required field name")
return {"type": "string", "enum": avro_field["symbols"]}
- elif avro_field["type"] == "map":
+ if avro_field["type"] == "map":
if "values" not in avro_field:
raise ValueError(f"{field_name} map type does not have a required field values")
return {
@@ -129,7 +131,7 @@ def _convert_avro_type_to_json(
avro_format, "", avro_field["values"]
),
}
- elif avro_field["type"] == "fixed" and avro_field.get("logicalType") != "duration":
+ if avro_field["type"] == "fixed" and avro_field.get("logicalType") != "duration":
if "size" not in avro_field:
raise ValueError(f"{field_name} fixed type does not have a required field size")
if not isinstance(avro_field["size"], int):
@@ -138,7 +140,7 @@ def _convert_avro_type_to_json(
"type": "string",
"pattern": f"^[0-9A-Fa-f]{{{avro_field['size'] * 2}}}$",
}
- elif avro_field.get("logicalType") == "decimal":
+ if avro_field.get("logicalType") == "decimal":
if "precision" not in avro_field:
raise ValueError(
f"{field_name} decimal type does not have a required field precision"
@@ -154,18 +156,16 @@ def _convert_avro_type_to_json(
# For example: ^-?\d{1,5}(?:\.\d{1,3})?$ would accept 12345.123 and 123456.12345 would be rejected
return {
"type": "string",
- "pattern": f"^-?\\d{{{1,max_whole_number_range}}}(?:\\.\\d{1,decimal_range})?$",
+ "pattern": f"^-?\\d{{{1, max_whole_number_range}}}(?:\\.\\d{1, decimal_range})?$",
}
- elif "logicalType" in avro_field:
+ if "logicalType" in avro_field:
if avro_field["logicalType"] not in AVRO_LOGICAL_TYPE_TO_JSON:
raise ValueError(
f"{avro_field['logicalType']} is not a valid Avro logical type"
)
return AVRO_LOGICAL_TYPE_TO_JSON[avro_field["logicalType"]]
- else:
- raise ValueError(f"Unsupported avro type: {avro_field}")
- else:
raise ValueError(f"Unsupported avro type: {avro_field}")
+ raise ValueError(f"Unsupported avro type: {avro_field}")
def parse_records(
self,
@@ -173,11 +173,11 @@ def parse_records(
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- discovered_schema: Optional[Mapping[str, SchemaType]],
- ) -> Iterable[Dict[str, Any]]:
+ discovered_schema: Mapping[str, SchemaType] | None, # noqa: ARG002
+ ) -> Iterable[dict[str, Any]]:
avro_format = config.format or AvroFormat(filetype="avro")
if not isinstance(avro_format, AvroFormat):
- raise ValueError(f"Expected ParquetFormat, got {avro_format}")
+ raise ValueError(f"Expected ParquetFormat, got {avro_format}") # noqa: TRY004
line_no = 0
try:
@@ -185,7 +185,7 @@ def parse_records(
avro_reader = fastavro.reader(fp) # type: ignore [arg-type]
schema = avro_reader.writer_schema
schema_field_name_to_type = {
- field["name"]: cast(dict[str, Any], field["type"]) # type: ignore [index]
+ field["name"]: cast(dict[str, Any], field["type"]) # type: ignore [index] # noqa: TC006
for field in schema["fields"] # type: ignore [index, call-overload] # If schema is not dict, it is not subscriptable by strings
}
for record in avro_reader:
@@ -193,7 +193,7 @@ def parse_records(
yield {
record_field: self._to_output_value(
avro_format,
- schema_field_name_to_type[record_field], # type: ignore [index] # Any not subscriptable
+ schema_field_name_to_type[record_field], # type: ignore [index] # Any not subscriptable # noqa: PLR1733
record[record_field], # type: ignore [index] # Any not subscriptable
)
for record_field, record_value in schema_field_name_to_type.items()
@@ -208,26 +208,27 @@ def file_read_mode(self) -> FileReadMode:
return FileReadMode.READ_BINARY
@staticmethod
- def _to_output_value(
- avro_format: AvroFormat, record_type: Mapping[str, Any], record_value: Any
- ) -> Any:
+ def _to_output_value( # noqa: PLR0911
+ avro_format: AvroFormat,
+ record_type: Mapping[str, Any],
+ record_value: Any, # noqa: ANN401
+ ) -> Any: # noqa: ANN401
if isinstance(record_value, bytes):
return record_value.decode()
- elif not isinstance(record_type, Mapping):
+ if not isinstance(record_type, Mapping):
if record_type == "double" and avro_format.double_as_string:
return str(record_value)
return record_value
if record_type.get("logicalType") in ("decimal", "uuid"):
return str(record_value)
- elif record_type.get("logicalType") == "date":
+ if record_type.get("logicalType") == "date":
return record_value.isoformat()
- elif record_type.get("logicalType") == "timestamp-millis":
+ if record_type.get("logicalType") == "timestamp-millis":
return record_value.isoformat(sep="T", timespec="milliseconds")
- elif record_type.get("logicalType") == "timestamp-micros":
+ if record_type.get("logicalType") == "timestamp-micros":
return record_value.isoformat(sep="T", timespec="microseconds")
- elif record_type.get("logicalType") == "local-timestamp-millis":
+ if record_type.get("logicalType") == "local-timestamp-millis":
return record_value.isoformat(sep="T", timespec="milliseconds")
- elif record_type.get("logicalType") == "local-timestamp-micros":
+ if record_type.get("logicalType") == "local-timestamp-micros":
return record_value.isoformat(sep="T", timespec="microseconds")
- else:
- return record_value
+ return record_value
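
The branch chain above converts Avro logical types into JSON-friendly output values (decimals and UUIDs become strings, dates and timestamps become ISO strings). A simplified stand-alone version that keeps the same branch order, using plain dicts instead of Mapping:

import datetime
import decimal

def to_output_value(record_type, record_value, double_as_string: bool = False):
    """Simplified sketch mirroring the branch order in the hunk above."""
    if isinstance(record_value, bytes):
        return record_value.decode()
    if not isinstance(record_type, dict):
        if record_type == "double" and double_as_string:
            return str(record_value)
        return record_value
    logical_type = record_type.get("logicalType")
    if logical_type in ("decimal", "uuid"):
        return str(record_value)
    if logical_type == "date":
        return record_value.isoformat()
    if logical_type in ("timestamp-millis", "local-timestamp-millis"):
        return record_value.isoformat(sep="T", timespec="milliseconds")
    if logical_type in ("timestamp-micros", "local-timestamp-micros"):
        return record_value.isoformat(sep="T", timespec="microseconds")
    return record_value

print(to_output_value({"logicalType": "date"}, datetime.date(2024, 1, 1)))        # 2024-01-01
print(to_output_value({"logicalType": "decimal"}, decimal.Decimal("12345.678")))  # 12345.678
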
diff --git a/airbyte_cdk/sources/file_based/file_types/csv_parser.py b/airbyte_cdk/sources/file_based/file_types/csv_parser.py
index e3010690e..cc56e29ee 100644
--- a/airbyte_cdk/sources/file_based/file_types/csv_parser.py
+++ b/airbyte_cdk/sources/file_based/file_types/csv_parser.py
@@ -7,9 +7,10 @@
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
+from collections.abc import Callable, Generator, Iterable, Mapping
from functools import partial
from io import IOBase
-from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple
+from typing import Any
from uuid import uuid4
import orjson
@@ -32,6 +33,7 @@
from airbyte_cdk.sources.file_based.schema_helpers import TYPE_PYTHON_MAPPING, SchemaType
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
DIALECT_NAME = "_config_dialect"
@@ -43,7 +45,7 @@ def read_data(
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
file_read_mode: FileReadMode,
- ) -> Generator[Dict[str, Any], None, None]:
+ ) -> Generator[dict[str, Any], None, None]:
config_format = _extract_format(config)
lineno = 0
@@ -52,7 +54,7 @@ def read_data(
# Give each stream's dialect a unique name; otherwise, when we are doing a concurrent sync we can end up
# with a race condition where a thread attempts to use a dialect before a separate thread has finished
# registering it.
- dialect_name = f"{config.name}_{str(uuid4())}_{DIALECT_NAME}"
+ dialect_name = f"{config.name}_{uuid4()!s}_{DIALECT_NAME}"
csv.register_dialect(
dialect_name,
delimiter=config_format.delimiter,
@@ -65,7 +67,7 @@ def read_data(
try:
headers = self._get_headers(fp, config_format, dialect_name)
except UnicodeError:
- raise AirbyteTracedException(
+ raise AirbyteTracedException( # noqa: B904
message=f"{FileBasedSourceError.ENCODING_ERROR.value} Expected encoding: {config_format.encoding}",
)
@@ -111,7 +113,7 @@ def read_data(
# due to RecordParseError or GeneratorExit
csv.unregister_dialect(dialect_name)
- def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> List[str]:
+ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> list[str]:
"""
Assumes the fp is pointing to the beginning of the files and will reset it as such
"""
@@ -133,7 +135,7 @@ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str)
fp.seek(0)
return headers
- def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> List[str]:
+ def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> list[str]:
"""
Generates field names as [f0, f1, ...] in the same way as pyarrow's csv reader with autogenerate_column_names=True.
See https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html
@@ -154,14 +156,14 @@ def _skip_rows(fp: IOBase, rows_to_skip: int) -> None:
class CsvParser(FileTypeParser):
_MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE = 1_000_000
- def __init__(self, csv_reader: Optional[_CsvReader] = None, csv_field_max_bytes: int = 2**31):
+ def __init__(self, csv_reader: _CsvReader | None = None, csv_field_max_bytes: int = 2**31): # noqa: ANN204
# Increase the maximum length of data that can be parsed in a single CSV field. The default is 128k, which is typically sufficient
# but given the use of Airbyte in loading a large variety of data it is best to allow for a larger maximum field size to avoid
# skipping data on load. https://stackoverflow.com/questions/15063936/csv-error-field-larger-than-field-limit-131072
csv.field_size_limit(csv_field_max_bytes)
- self._csv_reader = csv_reader if csv_reader else _CsvReader()
+ self._csv_reader = csv_reader if csv_reader else _CsvReader() # noqa: FURB110
- def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]:
+ def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002
"""
CsvParser does not require config checks, implicit pydantic validation is enough.
"""
@@ -178,10 +180,10 @@ async def infer_schema(
if input_schema:
return input_schema
- # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
+ # TODO: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
config_format = _extract_format(config)
- type_inferrer_by_field: Dict[str, _TypeInferrer] = defaultdict(
+ type_inferrer_by_field: dict[str, _TypeInferrer] = defaultdict(
lambda: _JsonTypeInferrer(
config_format.true_values, config_format.false_values, config_format.null_values
)
@@ -221,8 +223,8 @@ def parse_records(
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- discovered_schema: Optional[Mapping[str, SchemaType]],
- ) -> Iterable[Dict[str, Any]]:
+ discovered_schema: Mapping[str, SchemaType] | None,
+ ) -> Iterable[dict[str, Any]]:
line_no = 0
try:
config_format = _extract_format(config)
@@ -263,7 +265,7 @@ def _get_cast_function(
deduped_property_types: Mapping[str, str],
config_format: CsvFormat,
logger: logging.Logger,
- schemaless: bool,
+ schemaless: bool, # noqa: FBT001
) -> Callable[[Mapping[str, str]], Mapping[str, str]]:
# Only cast values if the schema is provided
if deduped_property_types and not schemaless:
@@ -273,17 +275,16 @@ def _get_cast_function(
config_format=config_format,
logger=logger,
)
- else:
- # If no schema is provided, yield the rows as they are
- return _no_cast
+ # If no schema is provided, yield the rows as they are
+ return _no_cast
@staticmethod
def _to_nullable(
row: Mapping[str, str],
deduped_property_types: Mapping[str, str],
- null_values: Set[str],
- strings_can_be_null: bool,
- ) -> Dict[str, Optional[str]]:
+ null_values: set[str],
+ strings_can_be_null: bool, # noqa: FBT001
+ ) -> dict[str, str | None]:
nullable = {
k: None
if CsvParser._value_is_none(
@@ -296,15 +297,15 @@ def _to_nullable(
@staticmethod
def _value_is_none(
- value: Any,
- deduped_property_type: Optional[str],
- null_values: Set[str],
- strings_can_be_null: bool,
+ value: Any, # noqa: ANN401
+ deduped_property_type: str | None,
+ null_values: set[str],
+ strings_can_be_null: bool, # noqa: FBT001
) -> bool:
return value in null_values and (strings_can_be_null or deduped_property_type != "string")
@staticmethod
- def _pre_propcess_property_types(property_types: Dict[str, Any]) -> Mapping[str, str]:
+ def _pre_propcess_property_types(property_types: dict[str, Any]) -> Mapping[str, str]:
"""
Transform the property types to be non-nullable and remove duplicate types if any.
Sample input:
@@ -335,11 +336,11 @@ def _pre_propcess_property_types(property_types: Dict[str, Any]) -> Mapping[str,
@staticmethod
def _cast_types(
- row: Dict[str, str],
+ row: dict[str, str],
deduped_property_types: Mapping[str, str],
config_format: CsvFormat,
logger: logging.Logger,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""
Casts the values in the input 'row' dictionary according to the types defined in the JSON schema.
@@ -358,7 +359,7 @@ def _cast_types(
_, python_type = TYPE_PYTHON_MAPPING[prop_type]
if python_type is None:
- if value == "":
+ if value == "": # noqa: PLC1901
cast_value = None
else:
warnings.append(_format_warning(key, value, prop_type))
@@ -401,7 +402,7 @@ def _cast_types(
class _TypeInferrer(ABC):
@abstractmethod
- def add_value(self, value: Any) -> None:
+ def add_value(self, value: Any) -> None: # noqa: ANN401
pass
@abstractmethod
@@ -410,7 +411,7 @@ def infer(self) -> str:
class _DisabledTypeInferrer(_TypeInferrer):
- def add_value(self, value: Any) -> None:
+ def add_value(self, value: Any) -> None: # noqa: ANN401
pass
def infer(self) -> str:
@@ -425,14 +426,14 @@ class _JsonTypeInferrer(_TypeInferrer):
_STRING_TYPE = "string"
def __init__(
- self, boolean_trues: Set[str], boolean_falses: Set[str], null_values: Set[str]
+ self, boolean_trues: set[str], boolean_falses: set[str], null_values: set[str]
) -> None:
self._boolean_trues = boolean_trues
self._boolean_falses = boolean_falses
self._null_values = null_values
- self._values: Set[str] = set()
+ self._values: set[str] = set()
- def add_value(self, value: Any) -> None:
+ def add_value(self, value: Any) -> None: # noqa: ANN401
self._values.add(value)
def infer(self) -> str:
@@ -447,13 +448,13 @@ def infer(self) -> str:
types = set.intersection(*types_excluding_null_values)
if self._BOOLEAN_TYPE in types:
return self._BOOLEAN_TYPE
- elif self._INTEGER_TYPE in types:
+ if self._INTEGER_TYPE in types:
return self._INTEGER_TYPE
- elif self._NUMBER_TYPE in types:
+ if self._NUMBER_TYPE in types:
return self._NUMBER_TYPE
return self._STRING_TYPE
- def _infer_type(self, value: str) -> Set[str]:
+ def _infer_type(self, value: str) -> set[str]:
inferred_types = set()
if value in self._null_values:
@@ -472,7 +473,7 @@ def _infer_type(self, value: str) -> Set[str]:
def _is_boolean(self, value: str) -> bool:
try:
_value_to_bool(value, self._boolean_trues, self._boolean_falses)
- return True
+ return True # noqa: TRY300
except ValueError:
return False
@@ -480,7 +481,7 @@ def _is_boolean(self, value: str) -> bool:
def _is_integer(value: str) -> bool:
try:
_value_to_python_type(value, int)
- return True
+ return True # noqa: TRY300
except ValueError:
return False
@@ -488,12 +489,12 @@ def _is_integer(value: str) -> bool:
def _is_number(value: str) -> bool:
try:
_value_to_python_type(value, float)
- return True
+ return True # noqa: TRY300
except ValueError:
return False
-def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> bool:
+def _value_to_bool(value: str, true_values: set[str], false_values: set[str]) -> bool:
if value in true_values:
return True
if value in false_values:
@@ -501,18 +502,18 @@ def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) ->
raise ValueError(f"Value {value} is not a valid boolean value")
-def _value_to_list(value: str) -> List[Any]:
+def _value_to_list(value: str) -> list[Any]:
parsed_value = json.loads(value)
if isinstance(parsed_value, list):
return parsed_value
raise ValueError(f"Value {parsed_value} is not a valid list value")
-def _value_to_python_type(value: str, python_type: type) -> Any:
+def _value_to_python_type(value: str, python_type: type) -> Any: # noqa: ANN401
return python_type(value)
-def _format_warning(key: str, value: str, expected_type: Optional[Any]) -> str:
+def _format_warning(key: str, value: str, expected_type: Any | None) -> str: # noqa: ANN401
return f"{key}: value={value},expected_type={expected_type}"
@@ -523,5 +524,5 @@ def _no_cast(row: Mapping[str, str]) -> Mapping[str, str]:
def _extract_format(config: FileBasedStreamConfig) -> CsvFormat:
config_format = config.format
if not isinstance(config_format, CsvFormat):
- raise ValueError(f"Invalid format config: {config_format}")
+ raise ValueError(f"Invalid format config: {config_format}") # noqa: TRY004
return config_format
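
_JsonTypeInferrer above intersects the candidate JSON types of every observed value and picks the most specific survivor. A miniature version of that idea (boolean handling omitted, so this is a simplification rather than the CDK's exact logic):

def infer_column_type(values: set[str], nulls: frozenset[str] = frozenset({"", "null"})) -> str:
    def candidates(v: str) -> set[str]:
        # Every value can be a string; add narrower types when the cast succeeds.
        types = {"string"}
        try:
            int(v)
            types |= {"integer", "number"}
        except ValueError:
            try:
                float(v)
                types.add("number")
            except ValueError:
                pass
        return types

    non_null = [candidates(v) for v in values if v not in nulls]
    if not non_null:
        return "null"
    common = set.intersection(*non_null)
    for t in ("integer", "number"):
        if t in common:
            return t
    return "string"

print(infer_column_type({"1", "2", "3"}))       # integer
print(infer_column_type({"1", "2.5", "null"}))  # number
print(infer_column_type({"1", "x"}))            # string
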
diff --git a/airbyte_cdk/sources/file_based/file_types/excel_parser.py b/airbyte_cdk/sources/file_based/file_types/excel_parser.py
index 5a0332171..dd3117b7c 100644
--- a/airbyte_cdk/sources/file_based/file_types/excel_parser.py
+++ b/airbyte_cdk/sources/file_based/file_types/excel_parser.py
@@ -3,9 +3,10 @@
#
import logging
+from collections.abc import Iterable, Mapping
from io import IOBase
from pathlib import Path
-from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import Any
import orjson
import pandas as pd
@@ -34,7 +35,7 @@
class ExcelParser(FileTypeParser):
ENCODING = None
- def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]:
+ def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002
"""
ExcelParser does not require config checks, implicit pydantic validation is enough.
"""
@@ -63,7 +64,7 @@ async def infer_schema(
# Validate the format of the config
self.validate_format(config.format, logger)
- fields: Dict[str, str] = {}
+ fields: dict[str, str] = {}
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
df = self.open_and_parse_file(fp)
@@ -91,8 +92,8 @@ def parse_records(
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- discovered_schema: Optional[Mapping[str, SchemaType]] = None,
- ) -> Iterable[Dict[str, Any]]:
+ discovered_schema: Mapping[str, SchemaType] | None = None, # noqa: ARG002
+ ) -> Iterable[dict[str, Any]]:
"""
Parses records from an Excel file based on the provided configuration.
@@ -140,7 +141,7 @@ def file_read_mode(self) -> FileReadMode:
@staticmethod
def dtype_to_json_type(
- current_type: Optional[str],
+ current_type: str | None,
dtype: dtype_, # type: ignore [type-arg]
) -> str:
"""
@@ -183,7 +184,7 @@ def validate_format(excel_format: BaseModel, logger: logging.Logger) -> None:
raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR)
@staticmethod
- def open_and_parse_file(fp: Union[IOBase, str, Path]) -> pd.DataFrame:
+ def open_and_parse_file(fp: IOBase | str | Path) -> pd.DataFrame:
"""
Opens and parses the Excel file.
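
dtype_to_json_type above maps pandas column dtypes onto JSON-schema types. The exact CDK mapping is not shown in this hunk, so the sketch below is only an approximation built on pandas' public dtype predicates:

import pandas as pd

def dtype_to_json_type(dtype) -> str:
    # Approximate mapping for illustration; order matters (bool before integer).
    if pd.api.types.is_bool_dtype(dtype):
        return "boolean"
    if pd.api.types.is_integer_dtype(dtype):
        return "integer"
    if pd.api.types.is_float_dtype(dtype):
        return "number"
    return "string"

df = pd.DataFrame({"id": [1, 2], "price": [9.99, 19.5], "name": ["a", "b"]})
print({col: dtype_to_json_type(df[col].dtype) for col in df.columns})
# {'id': 'integer', 'price': 'number', 'name': 'string'}
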
diff --git a/airbyte_cdk/sources/file_based/file_types/file_transfer.py b/airbyte_cdk/sources/file_based/file_types/file_transfer.py
index 154b6ff44..e8ca27e4a 100644
--- a/airbyte_cdk/sources/file_based/file_types/file_transfer.py
+++ b/airbyte_cdk/sources/file_based/file_types/file_transfer.py
@@ -3,12 +3,14 @@
#
import logging
import os
-from typing import Any, Dict, Iterable
+from collections.abc import Iterable
+from typing import Any
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
+
AIRBYTE_STAGING_DIRECTORY = os.getenv("AIRBYTE_STAGING_DIRECTORY", "/staging/files")
DEFAULT_LOCAL_DIRECTORY = "/tmp/airbyte-file-transfer"
@@ -17,21 +19,21 @@ class FileTransfer:
def __init__(self) -> None:
self._local_directory = (
AIRBYTE_STAGING_DIRECTORY
- if os.path.exists(AIRBYTE_STAGING_DIRECTORY)
+ if os.path.exists(AIRBYTE_STAGING_DIRECTORY) # noqa: PTH110
else DEFAULT_LOCAL_DIRECTORY
)
def get_file(
self,
- config: FileBasedStreamConfig,
+ config: FileBasedStreamConfig, # noqa: ARG002
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- ) -> Iterable[Dict[str, Any]]:
+ ) -> Iterable[dict[str, Any]]:
try:
yield stream_reader.get_file(
file=file, local_directory=self._local_directory, logger=logger
)
except Exception as ex:
logger.error("An error has occurred while getting file: %s", str(ex))
- raise ex
+ raise ex # noqa: TRY201
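
The staging-directory selection above (and its `# noqa: PTH110`) reduces to an env-var lookup plus an existence check. The same logic with pathlib, which would avoid the PTH warning altogether:

import os
from pathlib import Path

AIRBYTE_STAGING_DIRECTORY = os.getenv("AIRBYTE_STAGING_DIRECTORY", "/staging/files")
DEFAULT_LOCAL_DIRECTORY = "/tmp/airbyte-file-transfer"

local_directory = (
    AIRBYTE_STAGING_DIRECTORY
    if Path(AIRBYTE_STAGING_DIRECTORY).exists()  # Path.exists() instead of os.path.exists
    else DEFAULT_LOCAL_DIRECTORY
)
print(local_directory)
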
diff --git a/airbyte_cdk/sources/file_based/file_types/file_type_parser.py b/airbyte_cdk/sources/file_based/file_types/file_type_parser.py
index e6a9c5cb1..e27e27870 100644
--- a/airbyte_cdk/sources/file_based/file_types/file_type_parser.py
+++ b/airbyte_cdk/sources/file_based/file_types/file_type_parser.py
@@ -4,7 +4,8 @@
import logging
from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterable, Mapping, Optional, Tuple
+from collections.abc import Iterable, Mapping
+from typing import Any
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.file_based_stream_reader import (
@@ -14,7 +15,8 @@
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType
-Record = Dict[str, Any]
+
+Record = dict[str, Any]
class FileTypeParser(ABC):
@@ -24,27 +26,27 @@ class FileTypeParser(ABC):
"""
@property
- def parser_max_n_files_for_schema_inference(self) -> Optional[int]:
+ def parser_max_n_files_for_schema_inference(self) -> int | None:
"""
The discovery policy decides how many files are loaded for schema inference. This method can provide a parser-specific override. If it's defined, the smaller of the two values will be used.
"""
return None
@property
- def parser_max_n_files_for_parsability(self) -> Optional[int]:
+ def parser_max_n_files_for_parsability(self) -> int | None:
"""
The availability policy decides how many files are loaded for checking whether parsing works correctly. This method can provide a parser-specific override. If it's defined, the smaller of the two values will be used.
"""
return None
- def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> Optional[str]:
+ def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> str | None: # noqa: ARG002
"""
The parser can define a primary key. If no user-defined primary key is provided, this will be used.
"""
return None
@abstractmethod
- def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]:
+ def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]:
"""
Check whether the config is valid for this file type. If it is, return True and None. If it's not, return False and an error message explaining why it's invalid.
"""
@@ -70,7 +72,7 @@ def parse_records(
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- discovered_schema: Optional[Mapping[str, SchemaType]],
+ discovered_schema: Mapping[str, SchemaType] | None,
) -> Iterable[Record]:
"""
Parse and emit each record.
diff --git a/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py b/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py
index 722ad329b..c7c5f7b4b 100644
--- a/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py
+++ b/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py
@@ -4,7 +4,8 @@
import json
import logging
-from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
+from collections.abc import Iterable, Mapping
+from typing import Any
import orjson
@@ -27,7 +28,7 @@ class JsonlParser(FileTypeParser):
MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE = 1_000_000
ENCODING = "utf8"
- def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]:
+ def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002
"""
JsonlParser does not require config checks, implicit pydantic validation is enough.
"""
@@ -35,7 +36,7 @@ def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[st
async def infer_schema(
self,
- config: FileBasedStreamConfig,
+ config: FileBasedStreamConfig, # noqa: ARG002
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
@@ -54,12 +55,12 @@ async def infer_schema(
def parse_records(
self,
- config: FileBasedStreamConfig,
+ config: FileBasedStreamConfig, # noqa: ARG002
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- discovered_schema: Optional[Mapping[str, SchemaType]],
- ) -> Iterable[Dict[str, Any]]:
+ discovered_schema: Mapping[str, SchemaType] | None, # noqa: ARG002
+ ) -> Iterable[dict[str, Any]]:
"""
This code supports parsing json objects over multiple lines even though this does not align with the JSONL format. This is for
backward compatibility reasons i.e. the previous source-s3 parser did support this. The drawback is:
@@ -73,7 +74,7 @@ def parse_records(
yield from self._parse_jsonl_entries(file, stream_reader, logger)
@classmethod
- def _infer_schema_for_record(cls, record: Dict[str, Any]) -> Dict[str, Any]:
+ def _infer_schema_for_record(cls, record: dict[str, Any]) -> dict[str, Any]:
record_schema = {}
for key, value in record.items():
if value is None:
@@ -92,8 +93,8 @@ def _parse_jsonl_entries(
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- read_limit: bool = False,
- ) -> Iterable[Dict[str, Any]]:
+ read_limit: bool = False, # noqa: FBT001, FBT002
+ ) -> Iterable[dict[str, Any]]:
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
read_bytes = 0
@@ -137,9 +138,9 @@ def _parse_jsonl_entries(
FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line
)
- @staticmethod
- def _instantiate_accumulator(line: Union[bytes, str]) -> Union[bytes, str]:
+ @staticmethod # noqa: RET503
+ def _instantiate_accumulator(line: bytes | str) -> bytes | str:
if isinstance(line, bytes):
return bytes("", json.detect_encoding(line))
- elif isinstance(line, str):
+ if isinstance(line, str): # noqa: RET503, RUF100
return ""
diff --git a/airbyte_cdk/sources/file_based/file_types/parquet_parser.py b/airbyte_cdk/sources/file_based/file_types/parquet_parser.py
index 28cfb14c9..b3391d866 100644
--- a/airbyte_cdk/sources/file_based/file_types/parquet_parser.py
+++ b/airbyte_cdk/sources/file_based/file_types/parquet_parser.py
@@ -5,7 +5,8 @@
import json
import logging
import os
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from collections.abc import Iterable, Mapping
+from typing import Any
from urllib.parse import unquote
import pyarrow as pa
@@ -33,7 +34,7 @@
class ParquetParser(FileTypeParser):
ENCODING = None
- def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]:
+ def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]: # noqa: ARG002
"""
ParquetParser does not require config checks, implicit pydantic validation is enough.
"""
@@ -48,7 +49,7 @@ async def infer_schema(
) -> SchemaType:
parquet_format = config.format
if not isinstance(parquet_format, ParquetFormat):
- raise ValueError(f"Expected ParquetFormat, got {parquet_format}")
+ raise ValueError(f"Expected ParquetFormat, got {parquet_format}") # noqa: TRY004
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
parquet_file = pq.ParquetFile(fp)
@@ -74,8 +75,8 @@ def parse_records(
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- discovered_schema: Optional[Mapping[str, SchemaType]],
- ) -> Iterable[Dict[str, Any]]:
+ discovered_schema: Mapping[str, SchemaType] | None, # noqa: ARG002
+ ) -> Iterable[dict[str, Any]]:
parquet_format = config.format
if not isinstance(parquet_format, ParquetFormat):
logger.info(f"Expected ParquetFormat, got {parquet_format}")
@@ -109,8 +110,8 @@ def parse_records(
) from exc
@staticmethod
- def _extract_partitions(filepath: str) -> List[str]:
- return [unquote(partition) for partition in filepath.split(os.sep) if "=" in partition]
+ def _extract_partitions(filepath: str) -> list[str]:
+ return [unquote(partition) for partition in filepath.split(os.sep) if "=" in partition] # noqa: PTH206
@property
def file_read_mode(self) -> FileReadMode:
@@ -118,18 +119,17 @@ def file_read_mode(self) -> FileReadMode:
@staticmethod
def _to_output_value(
- parquet_value: Union[Scalar, DictionaryArray], parquet_format: ParquetFormat
- ) -> Any:
+ parquet_value: Scalar | DictionaryArray, parquet_format: ParquetFormat
+ ) -> Any: # noqa: ANN401
"""
Convert an entry in a pyarrow table to a value that can be output by the source.
"""
if isinstance(parquet_value, DictionaryArray):
return ParquetParser._dictionary_array_to_python_value(parquet_value)
- else:
- return ParquetParser._scalar_to_python_value(parquet_value, parquet_format)
+ return ParquetParser._scalar_to_python_value(parquet_value, parquet_format)
@staticmethod
- def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat) -> Any:
+ def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat) -> Any: # noqa: ANN401, PLR0911
"""
Convert a pyarrow scalar to a value that can be output by the source.
"""
@@ -155,8 +155,7 @@ def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat
if pa.types.is_decimal(parquet_value.type):
if parquet_format.decimal_as_float:
return float(parquet_value.as_py())
- else:
- return str(parquet_value.as_py())
+ return str(parquet_value.as_py())
if pa.types.is_map(parquet_value.type):
return {k: v for k, v in parquet_value.as_py()}
@@ -170,19 +169,17 @@ def _scalar_to_python_value(parquet_value: Scalar, parquet_format: ParquetFormat
duration_seconds = duration.total_seconds()
if parquet_value.type.unit == "s":
return duration_seconds
- elif parquet_value.type.unit == "ms":
+ if parquet_value.type.unit == "ms":
return duration_seconds * 1000
- elif parquet_value.type.unit == "us":
+ if parquet_value.type.unit == "us":
return duration_seconds * 1_000_000
- elif parquet_value.type.unit == "ns":
+ if parquet_value.type.unit == "ns":
return duration_seconds * 1_000_000_000 + duration.nanoseconds
- else:
- raise ValueError(f"Unknown duration unit: {parquet_value.type.unit}")
- else:
- return parquet_value.as_py()
+ raise ValueError(f"Unknown duration unit: {parquet_value.type.unit}")
+ return parquet_value.as_py()
@staticmethod
- def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> Dict[str, Any]:
+ def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> dict[str, Any]:
"""
Convert a pyarrow dictionary array to a value that can be output by the source.
@@ -196,7 +193,7 @@ def _dictionary_array_to_python_value(parquet_value: DictionaryArray) -> Dict[st
}
@staticmethod
- def parquet_type_to_schema_type(
+ def parquet_type_to_schema_type( # noqa: PLR0911
parquet_type: pa.DataType, parquet_format: ParquetFormat
) -> Mapping[str, str]:
"""
@@ -206,24 +203,23 @@ def parquet_type_to_schema_type(
if pa.types.is_timestamp(parquet_type):
return {"type": "string", "format": "date-time"}
- elif pa.types.is_date(parquet_type):
+ if pa.types.is_date(parquet_type):
return {"type": "string", "format": "date"}
- elif ParquetParser._is_string(parquet_type, parquet_format):
+ if ParquetParser._is_string(parquet_type, parquet_format):
return {"type": "string"}
- elif pa.types.is_boolean(parquet_type):
+ if pa.types.is_boolean(parquet_type):
return {"type": "boolean"}
- elif ParquetParser._is_integer(parquet_type):
+ if ParquetParser._is_integer(parquet_type):
return {"type": "integer"}
- elif ParquetParser._is_float(parquet_type, parquet_format):
+ if ParquetParser._is_float(parquet_type, parquet_format):
return {"type": "number"}
- elif ParquetParser._is_object(parquet_type):
+ if ParquetParser._is_object(parquet_type):
return {"type": "object"}
- elif ParquetParser._is_list(parquet_type):
+ if ParquetParser._is_list(parquet_type):
return {"type": "array"}
- elif pa.types.is_null(parquet_type):
+ if pa.types.is_null(parquet_type):
return {"type": "null"}
- else:
- raise ValueError(f"Unsupported parquet type: {parquet_type}")
+ raise ValueError(f"Unsupported parquet type: {parquet_type}")
@staticmethod
def _is_binary(parquet_type: pa.DataType) -> bool:
@@ -241,22 +237,20 @@ def _is_integer(parquet_type: pa.DataType) -> bool:
def _is_float(parquet_type: pa.DataType, parquet_format: ParquetFormat) -> bool:
if pa.types.is_decimal(parquet_type):
return parquet_format.decimal_as_float
- else:
- return bool(pa.types.is_floating(parquet_type))
+ return bool(pa.types.is_floating(parquet_type))
@staticmethod
def _is_string(parquet_type: pa.DataType, parquet_format: ParquetFormat) -> bool:
if pa.types.is_decimal(parquet_type):
return not parquet_format.decimal_as_float
- else:
- return bool(
- pa.types.is_time(parquet_type)
- or pa.types.is_string(parquet_type)
- or pa.types.is_large_string(parquet_type)
- or ParquetParser._is_binary(
- parquet_type
- ) # Best we can do is return as a string since we do not support binary
- )
+ return bool(
+ pa.types.is_time(parquet_type)
+ or pa.types.is_string(parquet_type)
+ or pa.types.is_large_string(parquet_type)
+ or ParquetParser._is_binary(
+ parquet_type
+ ) # Best we can do is return as a string since we do not support binary
+ )
@staticmethod
def _is_object(parquet_type: pa.DataType) -> bool:
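
_extract_partitions pulls hive-style partition values out of the file path: split on the OS separator and keep the URL-decoded "key=value" components. A minimal sketch, assuming a POSIX-style "/" separator:

import os
from urllib.parse import unquote

def extract_partitions(filepath: str) -> list[str]:
    # Keep only path components that look like hive partitions (key=value).
    return [unquote(part) for part in filepath.split(os.sep) if "=" in part]

print(extract_partitions("bucket/year=2024/month=01/data.parquet"))
# ['year=2024', 'month=01']
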
diff --git a/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py b/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py
index f55675e0a..f1ad81a09 100644
--- a/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py
+++ b/airbyte_cdk/sources/file_based/file_types/unstructured_parser.py
@@ -4,9 +4,10 @@
import logging
import os
import traceback
+from collections.abc import Iterable, Mapping
from datetime import datetime
from io import BytesIO, IOBase
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from typing import Any
import backoff
import dpath
@@ -39,6 +40,7 @@
from airbyte_cdk.utils import is_cloud_environment
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
unstructured_partition_pdf = None
unstructured_partition_docx = None
unstructured_partition_pptx = None
@@ -54,10 +56,10 @@ def get_nltk_temp_folder() -> str:
"""
try:
nltk_data_dir = AIRBYTE_NLTK_DATA_DIR
- os.makedirs(nltk_data_dir, exist_ok=True)
+ os.makedirs(nltk_data_dir, exist_ok=True) # noqa: PTH103
except OSError:
nltk_data_dir = TMP_NLTK_DATA_DIR
- os.makedirs(nltk_data_dir, exist_ok=True)
+ os.makedirs(nltk_data_dir, exist_ok=True) # noqa: PTH103
return nltk_data_dir
@@ -73,7 +75,7 @@ def get_nltk_temp_folder() -> str:
nltk.download("averaged_perceptron_tagger_eng", download_dir=nltk_data_dir, quiet=True)
-def optional_decode(contents: Union[str, bytes]) -> str:
+def optional_decode(contents: str | bytes) -> str:
if isinstance(contents, bytes):
return contents.decode("utf-8")
return contents
@@ -81,12 +83,12 @@ def optional_decode(contents: Union[str, bytes]) -> str:
def _import_unstructured() -> None:
"""Dynamically imported as needed, due to slow import speed."""
- global unstructured_partition_pdf
+ global unstructured_partition_pdf # noqa: FURB154
global unstructured_partition_docx
global unstructured_partition_pptx
- from unstructured.partition.docx import partition_docx
- from unstructured.partition.pdf import partition_pdf
- from unstructured.partition.pptx import partition_pptx
+ from unstructured.partition.docx import partition_docx # noqa: PLC0415
+ from unstructured.partition.pdf import partition_pdf # noqa: PLC0415
+ from unstructured.partition.pptx import partition_pptx # noqa: PLC0415
# separate global variables to properly propagate typing
unstructured_partition_pdf = partition_pdf
@@ -102,7 +104,7 @@ def user_error(e: Exception) -> bool:
return False
if not isinstance(e, requests.exceptions.RequestException):
return False
- return bool(e.response and 400 <= e.response.status_code < 500)
+ return bool(e.response and 400 <= e.response.status_code < 500) # noqa: PLR2004
CLOUD_DEPLOYMENT_MODE = "cloud"
@@ -110,20 +112,20 @@ def user_error(e: Exception) -> bool:
class UnstructuredParser(FileTypeParser):
@property
- def parser_max_n_files_for_schema_inference(self) -> Optional[int]:
+ def parser_max_n_files_for_schema_inference(self) -> int | None:
"""
Just check one file as the schema is static
"""
return 1
@property
- def parser_max_n_files_for_parsability(self) -> Optional[int]:
+ def parser_max_n_files_for_parsability(self) -> int | None:
"""
Do not check any files for parsability because it might be an expensive operation and doesn't give much confidence whether the sync will succeed.
"""
return 0
- def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> Optional[str]:
+ def get_parser_defined_primary_key(self, config: FileBasedStreamConfig) -> str | None: # noqa: ARG002
"""
Return the document_key field as the primary key.
@@ -138,7 +140,7 @@ async def infer_schema(
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> SchemaType:
- format = _extract_format(config)
+ format = _extract_format(config) # noqa: A001
with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle:
filetype = self._get_filetype(file_handle, file)
if filetype not in self._supported_file_types() and not format.skip_unprocessable_files:
@@ -168,9 +170,9 @@ def parse_records(
file: RemoteFile,
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
- discovered_schema: Optional[Mapping[str, SchemaType]],
- ) -> Iterable[Dict[str, Any]]:
- format = _extract_format(config)
+ discovered_schema: Mapping[str, SchemaType] | None, # noqa: ARG002
+ ) -> Iterable[dict[str, Any]]:
+ format = _extract_format(config) # noqa: A001
with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle:
try:
markdown = self._read_file(file_handle, file, format, logger)
@@ -193,18 +195,18 @@ def parse_records(
}
logger.warn(f"File {file.uri} cannot be parsed. Skipping it.")
else:
- raise e
+ raise e # noqa: TRY201
except Exception as e:
exception_str = str(e)
logger.error(f"File {file.uri} caused an error during parsing: {exception_str}.")
- raise e
+ raise e # noqa: TRY201
- def _read_file(
+ def _read_file( # noqa: RET503
self,
file_handle: IOBase,
remote_file: RemoteFile,
- format: UnstructuredFormat,
- logger: logging.Logger,
+ format: UnstructuredFormat, # noqa: A002
+ logger: logging.Logger, # noqa: ARG002
) -> str:
_import_unstructured()
if (
@@ -213,7 +215,7 @@ def _read_file(
or (not unstructured_partition_pptx)
):
# check whether unstructured library is actually available for better error message and to ensure proper typing (can't be None after this point)
- raise Exception("unstructured library is not available")
+ raise Exception("unstructured library is not available") # noqa: TRY002
filetype: FileType | None = self._get_filetype(file_handle, remote_file)
@@ -233,7 +235,7 @@ def _read_file(
format.strategy,
remote_file,
)
- elif format.processing.mode == "api":
+ if format.processing.mode == "api": # noqa: RET503, RUF100
try:
result: str = self._read_file_remotely_with_retries(
file_handle,
@@ -248,17 +250,17 @@ def _read_file(
# For other exceptions, re-throw as config error so the sync is stopped as problems with the external API need to be resolved by the user and are not considered part of the SLA.
# Once this parser leaves experimental stage, we should consider making this a system error instead for issues that might be transient.
if isinstance(e, RecordParseError):
- raise e
- raise AirbyteTracedException.from_exception(
+ raise e # noqa: TRY201
+ raise AirbyteTracedException.from_exception( # noqa: B904
e, failure_type=FailureType.config_error
)
return result
def _params_to_dict(
- self, params: Optional[List[APIParameterConfigModel]], strategy: str
- ) -> Dict[str, Union[str, List[str]]]:
- result_dict: Dict[str, Union[str, List[str]]] = {"strategy": strategy}
+ self, params: list[APIParameterConfigModel] | None, strategy: str
+ ) -> dict[str, str | list[str]]:
+ result_dict: dict[str, str | list[str]] = {"strategy": strategy}
if params is None:
return result_dict
for item in params:
@@ -277,7 +279,7 @@ def _params_to_dict(
return result_dict
- def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[str]]:
+ def check_config(self, config: FileBasedStreamConfig) -> tuple[bool, str | None]:
"""
Perform a connection check for the parser config:
- Verify that encryption is enabled if the API is hosted on a cloud instance.
@@ -313,7 +315,7 @@ def check_config(self, config: FileBasedStreamConfig) -> Tuple[bool, Optional[st
def _read_file_remotely_with_retries(
self,
file_handle: IOBase,
- format: APIProcessingConfigModel,
+ format: APIProcessingConfigModel, # noqa: A002
filetype: FileType,
strategy: str,
remote_file: RemoteFile,
@@ -326,7 +328,7 @@ def _read_file_remotely_with_retries(
def _read_file_remotely(
self,
file_handle: IOBase,
- format: APIProcessingConfigModel,
+ format: APIProcessingConfigModel, # noqa: A002
filetype: FileType,
strategy: str,
remote_file: RemoteFile,
@@ -341,12 +343,11 @@ def _read_file_remotely(
f"{format.api_url}/general/v0/general", headers=headers, data=data, files=file_data
)
- if response.status_code == 422:
+ if response.status_code == 422: # noqa: PLR2004
# 422 means the file couldn't be processed, but the API is working. Treat this as a parsing error (passing an error record to the destination).
raise self._create_parse_error(remote_file, response.json())
- else:
- # Other error statuses are raised as requests exceptions (retry everything except user errors)
- response.raise_for_status()
+ # Other error statuses are raised as requests exceptions (retry everything except user errors)
+ response.raise_for_status()
json_response = response.json()
@@ -362,7 +363,7 @@ def _read_file_locally(
or (not unstructured_partition_pptx)
):
# check whether unstructured library is actually available for better error message and to ensure proper typing (can't be None after this point)
- raise Exception("unstructured library is not available")
+ raise Exception("unstructured library is not available") # noqa: TRY002
file: Any = file_handle
@@ -383,7 +384,7 @@ def _read_file_locally(
elif filetype == FileType.PPTX:
elements = unstructured_partition_pptx(file=file)
except Exception as e:
- raise self._create_parse_error(remote_file, str(e))
+ raise self._create_parse_error(remote_file, str(e)) # noqa: B904
return self._render_markdown([element.to_dict() for element in elements])
@@ -396,7 +397,7 @@ def _create_parse_error(
FileBasedSourceError.ERROR_PARSING_RECORD, filename=remote_file.uri, message=message
)
- def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileType]:
+ def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> FileType | None:
"""
Detect the file type based on the file name and the file content.
@@ -416,7 +417,7 @@ def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileT
# if possible, try to leverage the file name to detect the file type
# if the file name is not available, use the file content
file_type: FileType | None = None
- try:
+ try: # noqa: SIM105
file_type = detect_filetype(
filename=remote_file.uri,
)
@@ -439,34 +440,33 @@ def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileT
return None
- def _supported_file_types(self) -> List[Any]:
+ def _supported_file_types(self) -> list[Any]:
return [FileType.MD, FileType.PDF, FileType.DOCX, FileType.PPTX, FileType.TXT]
def _get_file_type_error_message(
self,
file_type: FileType | None,
) -> str:
- supported_file_types = ", ".join([str(type) for type in self._supported_file_types()])
+ supported_file_types = ", ".join([str(type) for type in self._supported_file_types()]) # noqa: A001
return f"File type {file_type or 'None'!s} is not supported. Supported file types are {supported_file_types}"
- def _render_markdown(self, elements: List[Any]) -> str:
- return "\n\n".join((self._convert_to_markdown(el) for el in elements))
+ def _render_markdown(self, elements: list[Any]) -> str:
+ return "\n\n".join(self._convert_to_markdown(el) for el in elements)
- def _convert_to_markdown(self, el: Dict[str, Any]) -> str:
+ def _convert_to_markdown(self, el: dict[str, Any]) -> str:
if dpath.get(el, "type") == "Title":
category_depth = dpath.get(el, "metadata/category_depth", default=1) or 1
if not isinstance(category_depth, int):
category_depth = (
- int(category_depth) if isinstance(category_depth, (str, float)) else 1
+ int(category_depth) if isinstance(category_depth, (str, float)) else 1 # noqa: UP038
)
heading_str = "#" * category_depth
return f"{heading_str} {dpath.get(el, 'text')}"
- elif dpath.get(el, "type") == "ListItem":
+ if dpath.get(el, "type") == "ListItem":
return f"- {dpath.get(el, 'text')}"
- elif dpath.get(el, "type") == "Formula":
+ if dpath.get(el, "type") == "Formula":
return f"```\n{dpath.get(el, 'text')}\n```"
- else:
- return str(dpath.get(el, "text", default=""))
+ return str(dpath.get(el, "text", default=""))
@property
def file_read_mode(self) -> FileReadMode:
@@ -476,5 +476,5 @@ def file_read_mode(self) -> FileReadMode:
def _extract_format(config: FileBasedStreamConfig) -> UnstructuredFormat:
config_format = config.format
if not isinstance(config_format, UnstructuredFormat):
- raise ValueError(f"Invalid format config: {config_format}")
+ raise ValueError(f"Invalid format config: {config_format}") # noqa: TRY004
return config_format
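
_convert_to_markdown renders each unstructured element by type: Titles become headings whose depth follows metadata/category_depth, ListItems become bullets, Formulas become fenced code blocks, and anything else is emitted as raw text. A simplified sketch using plain dict access instead of dpath:

from typing import Any

def convert_to_markdown(el: dict[str, Any]) -> str:
    if el.get("type") == "Title":
        depth = el.get("metadata", {}).get("category_depth", 1) or 1
        return f"{'#' * int(depth)} {el.get('text', '')}"
    if el.get("type") == "ListItem":
        return f"- {el.get('text', '')}"
    if el.get("type") == "Formula":
        return f"```\n{el.get('text', '')}\n```"
    return str(el.get("text", ""))

elements = [
    {"type": "Title", "text": "Report", "metadata": {"category_depth": 1}},
    {"type": "ListItem", "text": "first point"},
]
# Elements are joined with blank lines, mirroring _render_markdown.
print("\n\n".join(convert_to_markdown(el) for el in elements))
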
diff --git a/airbyte_cdk/sources/file_based/remote_file.py b/airbyte_cdk/sources/file_based/remote_file.py
index 0197a35fd..48d4e2513 100644
--- a/airbyte_cdk/sources/file_based/remote_file.py
+++ b/airbyte_cdk/sources/file_based/remote_file.py
@@ -3,7 +3,6 @@
#
from datetime import datetime
-from typing import Optional
from pydantic.v1 import BaseModel
@@ -15,4 +14,4 @@ class RemoteFile(BaseModel):
uri: str
last_modified: datetime
- mime_type: Optional[str] = None
+ mime_type: str | None = None
diff --git a/airbyte_cdk/sources/file_based/schema_helpers.py b/airbyte_cdk/sources/file_based/schema_helpers.py
index 1b653db67..98ff86ac8 100644
--- a/airbyte_cdk/sources/file_based/schema_helpers.py
+++ b/airbyte_cdk/sources/file_based/schema_helpers.py
@@ -3,10 +3,11 @@
#
import json
+from collections.abc import Mapping
from copy import deepcopy
from enum import Enum
from functools import total_ordering
-from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Type, Union
+from typing import Any, Literal, Union
from airbyte_cdk.sources.file_based.exceptions import (
ConfigValidationError,
@@ -14,7 +15,8 @@
SchemaInferenceError,
)
-JsonSchemaSupportedType = Union[List[str], Literal["string"], str]
+
+JsonSchemaSupportedType = Union[list[str], Literal["string"], str] # noqa: PYI051, UP007
SchemaType = Mapping[str, Mapping[str, JsonSchemaSupportedType]]
schemaless_schema = {"type": "object", "properties": {"data": {"type": "object"}}}
@@ -33,14 +35,13 @@ class ComparableType(Enum):
STRING = 4
OBJECT = 5
- def __lt__(self, other: Any) -> bool:
+ def __lt__(self, other: Any) -> bool: # noqa: ANN401
if self.__class__ is other.__class__:
return self.value < other.value # type: ignore
- else:
- return NotImplemented
+ return NotImplemented
-TYPE_PYTHON_MAPPING: Mapping[str, Tuple[str, Optional[Type[Any]]]] = {
+TYPE_PYTHON_MAPPING: Mapping[str, tuple[str, type[Any] | None]] = {
"null": ("null", None),
"array": ("array", list),
"boolean": ("boolean", bool),
@@ -53,7 +54,7 @@ def __lt__(self, other: Any) -> bool:
PYTHON_TYPE_MAPPING = {t: k for k, (_, t) in TYPE_PYTHON_MAPPING.items()}
-def get_comparable_type(value: Any) -> Optional[ComparableType]:
+def get_comparable_type(value: Any) -> ComparableType | None: # noqa: ANN401, PLR0911
if value == "null":
return ComparableType.NULL
if value == "boolean":
@@ -66,11 +67,10 @@ def get_comparable_type(value: Any) -> Optional[ComparableType]:
return ComparableType.STRING
if value == "object":
return ComparableType.OBJECT
- else:
- return None
+ return None
-def get_inferred_type(value: Any) -> Optional[ComparableType]:
+def get_inferred_type(value: Any) -> ComparableType | None: # noqa: ANN401, PLR0911
if value is None:
return ComparableType.NULL
if isinstance(value, bool):
@@ -83,8 +83,7 @@ def get_inferred_type(value: Any) -> Optional[ComparableType]:
return ComparableType.STRING
if isinstance(value, dict):
return ComparableType.OBJECT
- else:
- return None
+ return None
def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType:
@@ -107,7 +106,7 @@ def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType:
if not isinstance(t, dict) or "type" not in t or not _is_valid_type(t["type"]):
raise SchemaInferenceError(FileBasedSourceError.UNRECOGNIZED_TYPE, key=k, type=t)
- merged_schema: Dict[str, Any] = deepcopy(schema1) # type: ignore # as of 2023-08-08, deepcopy can copy Mapping
+ merged_schema: dict[str, Any] = deepcopy(schema1) # type: ignore # as of 2023-08-08, deepcopy can copy Mapping
for k2, t2 in schema2.items():
t1 = merged_schema.get(k2)
if t1 is None:
@@ -136,7 +135,7 @@ def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) -
detected_types=f"{t1},{t2}",
)
# Schemas can still be merged if a key contains a null value in either t1 or t2, but it is still an object
- elif (
+ if (
(t1_type == "object" or t2_type == "object")
and t1_type != "null"
and t2_type != "null"
@@ -148,24 +147,23 @@ def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) -
key=key,
detected_types=f"{t1},{t2}",
)
- else:
- comparable_t1 = get_comparable_type(
- TYPE_PYTHON_MAPPING[t1_type][0]
- ) # accessing the type_mapping value
- comparable_t2 = get_comparable_type(
- TYPE_PYTHON_MAPPING[t2_type][0]
- ) # accessing the type_mapping value
- if not comparable_t1 and comparable_t2:
- raise SchemaInferenceError(
- FileBasedSourceError.UNRECOGNIZED_TYPE, key=key, detected_types=f"{t1},{t2}"
- )
- return max(
- [t1, t2],
- key=lambda x: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[x["type"]][0])),
- ) # accessing the type_mapping value
+ comparable_t1 = get_comparable_type(
+ TYPE_PYTHON_MAPPING[t1_type][0]
+ ) # accessing the type_mapping value
+ comparable_t2 = get_comparable_type(
+ TYPE_PYTHON_MAPPING[t2_type][0]
+ ) # accessing the type_mapping value
+ if not comparable_t1 and comparable_t2:
+ raise SchemaInferenceError(
+ FileBasedSourceError.UNRECOGNIZED_TYPE, key=key, detected_types=f"{t1},{t2}"
+ )
+ return max(
+ [t1, t2],
+ key=lambda x: ComparableType(get_comparable_type(TYPE_PYTHON_MAPPING[x["type"]][0])),
+ ) # accessing the type_mapping value
-def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool:
+def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool: # noqa: ANN401
if isinstance(value, list):
# We do not compare lists directly; the individual items are compared.
# If we hit this condition, it means that the expected type is not
@@ -180,7 +178,7 @@ def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool:
return ComparableType(inferred_type) <= ComparableType(get_comparable_type(expected_type))
-def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
+def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool: # noqa: PLR0911
"""
Return true iff the record conforms to the supplied schema.
@@ -202,9 +200,9 @@ def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) ->
if value is not None:
if isinstance(expected_type, list):
return any(is_equal_or_narrower_type(value, e) for e in expected_type)
- elif expected_type == "object":
+ if expected_type == "object":
return isinstance(value, dict)
- elif expected_type == "array":
+ if expected_type == "array":
if not isinstance(value, list):
return False
array_type = definition.get("items", {}).get("type")
@@ -216,7 +214,7 @@ def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) ->
return True
-def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[Mapping[str, str]]:
+def _parse_json_input(input_schema: str | Mapping[str, str]) -> Mapping[str, str] | None:
try:
if isinstance(input_schema, str):
schema: Mapping[str, str] = json.loads(input_schema)
@@ -235,8 +233,8 @@ def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[M
def type_mapping_to_jsonschema(
- input_schema: Optional[Union[str, Mapping[str, str]]],
-) -> Optional[Mapping[str, Any]]:
+ input_schema: str | Mapping[str, str] | None,
+) -> Mapping[str, Any] | None:
"""
Return the user input schema (type mapping), transformed to JSON Schema format.
@@ -252,14 +250,14 @@ def type_mapping_to_jsonschema(
json_mapping = _parse_json_input(input_schema) or {}
for col_name, type_name in json_mapping.items():
- col_name, type_name = col_name.strip(), type_name.strip()
+ col_name, type_name = col_name.strip(), type_name.strip() # noqa: PLW2901
if not (col_name and type_name):
raise ConfigValidationError(
FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA,
details=f"Invalid input schema; expected mapping in the format column_name: type, got {input_schema}.",
)
- _json_schema_type = TYPE_PYTHON_MAPPING.get(type_name.casefold())
+ _json_schema_type = TYPE_PYTHON_MAPPING.get(type_name.casefold()) # noqa: RUF052
if not _json_schema_type:
raise ConfigValidationError(
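
The refactored _choose_wider_type above keeps the same widening rule: each JSON schema type maps onto an ordered ComparableType and the wider of the two detected types wins. A minimal sketch of that ordering (enum values are illustrative; the real lookup goes through TYPE_PYTHON_MAPPING):

from enum import Enum
from functools import total_ordering

@total_ordering
class ComparableType(Enum):
    NULL = 0
    BOOLEAN = 1
    INTEGER = 2
    NUMBER = 3
    STRING = 4
    OBJECT = 5

    def __lt__(self, other: "ComparableType") -> bool:
        if self.__class__ is other.__class__:
            return self.value < other.value
        return NotImplemented

def wider(t1: ComparableType, t2: ComparableType) -> ComparableType:
    # The merged schema keeps the wider (larger) of the two types.
    return max(t1, t2)

print(wider(ComparableType.INTEGER, ComparableType.NUMBER))  # ComparableType.NUMBER
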
diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/__init__.py b/airbyte_cdk/sources/file_based/schema_validation_policies/__init__.py
index e687bd5b3..e044bd709 100644
--- a/airbyte_cdk/sources/file_based/schema_validation_policies/__init__.py
+++ b/airbyte_cdk/sources/file_based/schema_validation_policies/__init__.py
@@ -8,6 +8,7 @@
WaitForDiscoverPolicy,
)
+
__all__ = [
"DEFAULT_SCHEMA_VALIDATION_POLICIES",
"AbstractSchemaValidationPolicy",
diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py b/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py
index 139511a98..be3d47005 100644
--- a/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py
+++ b/airbyte_cdk/sources/file_based/schema_validation_policies/abstract_schema_validation_policy.py
@@ -3,7 +3,8 @@
#
from abc import ABC, abstractmethod
-from typing import Any, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any
class AbstractSchemaValidationPolicy(ABC):
@@ -12,9 +13,9 @@ class AbstractSchemaValidationPolicy(ABC):
@abstractmethod
def record_passes_validation_policy(
- self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]
+ self, record: Mapping[str, Any], schema: Mapping[str, Any] | None
) -> bool:
"""
Return True if the record passes the user's validation policy.
"""
- raise NotImplementedError()
+ raise NotImplementedError
diff --git a/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py b/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py
index 261b0fabd..0e4821e29 100644
--- a/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py
+++ b/airbyte_cdk/sources/file_based/schema_validation_policies/default_schema_validation_policies.py
@@ -2,7 +2,8 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Any, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any
from airbyte_cdk.sources.file_based.config.file_based_stream_config import ValidationPolicy
from airbyte_cdk.sources.file_based.exceptions import (
@@ -17,7 +18,9 @@ class EmitRecordPolicy(AbstractSchemaValidationPolicy):
name = "emit_record"
def record_passes_validation_policy(
- self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]
+ self,
+ record: Mapping[str, Any], # noqa: ARG002
+ schema: Mapping[str, Any] | None, # noqa: ARG002
) -> bool:
return True
@@ -26,7 +29,7 @@ class SkipRecordPolicy(AbstractSchemaValidationPolicy):
name = "skip_record"
def record_passes_validation_policy(
- self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]
+ self, record: Mapping[str, Any], schema: Mapping[str, Any] | None
) -> bool:
return schema is not None and conforms_to_schema(record, schema)
@@ -36,7 +39,7 @@ class WaitForDiscoverPolicy(AbstractSchemaValidationPolicy):
validate_schema_before_sync = True
def record_passes_validation_policy(
- self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]
+ self, record: Mapping[str, Any], schema: Mapping[str, Any] | None
) -> bool:
if schema is None or not conforms_to_schema(record, schema):
raise StopSyncPerValidationPolicy(
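
The three policies above differ only in what record_passes_validation_policy does with a non-conforming record: emit_record always passes it through, skip_record drops it, and wait_for_discover stops the sync for the stream. A toy sketch of that behaviour, with a trivial stand-in for conforms_to_schema and a plain RuntimeError standing in for StopSyncPerValidationPolicy:

from collections.abc import Mapping
from typing import Any

def conforms(record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
    # Stand-in check: every record key must be declared in the schema.
    return set(record) <= set(schema.get("properties", {}))

def emit_record(record, schema) -> bool:
    return True  # always emit

def skip_record(record, schema) -> bool:
    return schema is not None and conforms(record, schema)  # drop if non-conforming

def wait_for_discover(record, schema) -> bool:
    if schema is None or not conforms(record, schema):
        raise RuntimeError("record does not conform; stop syncing this stream")
    return True

schema = {"properties": {"id": {"type": "integer"}}}
print(skip_record({"id": 1}, schema), skip_record({"id": 1, "x": 2}, schema))  # True False
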
diff --git a/airbyte_cdk/sources/file_based/stream/__init__.py b/airbyte_cdk/sources/file_based/stream/__init__.py
index 4b5c4bc2e..07e5544f1 100644
--- a/airbyte_cdk/sources/file_based/stream/__init__.py
+++ b/airbyte_cdk/sources/file_based/stream/__init__.py
@@ -1,4 +1,7 @@
-from airbyte_cdk.sources.file_based.stream.abstract_file_based_stream import AbstractFileBasedStream
+from airbyte_cdk.sources.file_based.stream.abstract_file_based_stream import (
+ AbstractFileBasedStream,
+)
from airbyte_cdk.sources.file_based.stream.default_file_based_stream import DefaultFileBasedStream
+
__all__ = ["AbstractFileBasedStream", "DefaultFileBasedStream"]
diff --git a/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py b/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py
index ef258b34d..b69461531 100644
--- a/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py
+++ b/airbyte_cdk/sources/file_based/stream/abstract_file_based_stream.py
@@ -3,8 +3,9 @@
#
from abc import abstractmethod
-from functools import cache, cached_property, lru_cache
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Type
+from collections.abc import Iterable, Mapping
+from functools import cache, cached_property
+from typing import Any
from typing_extensions import deprecated
@@ -50,14 +51,14 @@ class AbstractFileBasedStream(Stream):
by the stream.
"""
- def __init__(
+ def __init__( # noqa: ANN204, PLR0913, PLR0917
self,
config: FileBasedStreamConfig,
- catalog_schema: Optional[Mapping[str, Any]],
+ catalog_schema: Mapping[str, Any] | None,
stream_reader: AbstractFileBasedStreamReader,
availability_strategy: AbstractFileBasedAvailabilityStrategy,
discovery_policy: AbstractDiscoveryPolicy,
- parsers: Dict[Type[Any], FileTypeParser],
+ parsers: dict[type[Any], FileTypeParser],
validation_policy: AbstractSchemaValidationPolicy,
errors_collector: FileBasedErrorsCollector,
cursor: AbstractFileBasedCursor,
@@ -77,8 +78,8 @@ def __init__(
@abstractmethod
def primary_key(self) -> PrimaryKeyType: ...
- @cache
- def list_files(self) -> List[RemoteFile]:
+ @cache # noqa: B019
+ def list_files(self) -> list[RemoteFile]:
"""
List all files that belong to the stream.
@@ -97,10 +98,10 @@ def get_files(self) -> Iterable[RemoteFile]:
def read_records(
self,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_slice: Optional[StreamSlice] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_slice: StreamSlice | None = None,
+ stream_state: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Iterable[Mapping[str, Any] | AirbyteMessage]:
"""
Yield all records from all remote files in `list_files_for_this_sync`.
@@ -123,10 +124,10 @@ def read_records_from_slice(
def stream_slices(
self,
*,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
- ) -> Iterable[Optional[Mapping[str, Any]]]:
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_state: Mapping[str, Any] | None = None, # noqa: ARG002
+ ) -> Iterable[Mapping[str, Any] | None]:
"""
This method acts as an adapter between the generic Stream interface and the file-based's
stream since file-based streams manage their own states.
@@ -134,7 +135,7 @@ def stream_slices(
return self.compute_slices()
@abstractmethod
- def compute_slices(self) -> Iterable[Optional[StreamSlice]]:
+ def compute_slices(self) -> Iterable[StreamSlice | None]:
"""
Return a list of slices that will be used to read files in the current sync.
:return: The slices to use for the current sync.
@@ -142,7 +143,7 @@ def compute_slices(self) -> Iterable[Optional[StreamSlice]]:
...
@abstractmethod
- @lru_cache(maxsize=None)
+ @cache # noqa: B019
def get_json_schema(self) -> Mapping[str, Any]:
"""
Return the JSON Schema for a stream.
@@ -150,7 +151,7 @@ def get_json_schema(self) -> Mapping[str, Any]:
...
@abstractmethod
- def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
+ def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]:
"""
Infer the schema for files in the stream.
"""
@@ -160,7 +161,7 @@ def get_parser(self) -> FileTypeParser:
try:
return self._parsers[type(self.config.format)]
except KeyError:
- raise UndefinedParserError(
+ raise UndefinedParserError( # noqa: B904
FileBasedSourceError.UNDEFINED_PARSER,
stream=self.name,
format=type(self.config.format),
@@ -171,12 +172,11 @@ def record_passes_validation_policy(self, record: Mapping[str, Any]) -> bool:
return self.validation_policy.record_passes_validation_policy(
record=record, schema=self.catalog_schema
)
- else:
- raise RecordParseError(
- FileBasedSourceError.UNDEFINED_VALIDATION_POLICY,
- stream=self.name,
- validation_policy=self.config.validation_policy,
- )
+ raise RecordParseError(
+ FileBasedSourceError.UNDEFINED_VALIDATION_POLICY,
+ stream=self.name,
+ validation_policy=self.config.validation_policy,
+ )
@cached_property
@deprecated("Deprecated as of CDK version 3.7.0.")
@@ -187,7 +187,7 @@ def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy:
def name(self) -> str:
return self.config.name
- def get_cursor(self) -> Optional[Cursor]:
+ def get_cursor(self) -> Cursor | None:
"""
This is a temporary hack. Because file-based, declarative, and concurrent have _slightly_ different cursor implementations
the file-based cursor isn't compatible with the cursor-based iteration flow in core.py top-level CDK. By setting this to
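
get_parser above resolves the parser by the type of the stream's format config, so an unregistered format type surfaces as UndefinedParserError. A minimal sketch of that lookup with stand-in format and parser classes (LookupError stands in for the CDK exception):

class JsonlFormat: ...
class CsvFormat: ...
class JsonlParser: ...

# Parsers are keyed by the *type* of the format config object.
parsers = {JsonlFormat: JsonlParser()}

def get_parser(format_config):
    try:
        return parsers[type(format_config)]
    except KeyError:
        raise LookupError(f"No parser registered for {type(format_config).__name__}") from None

print(type(get_parser(JsonlFormat())).__name__)  # JsonlParser
try:
    get_parser(CsvFormat())
except LookupError as exc:
    print(exc)  # No parser registered for CsvFormat
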
diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py
index fb0efc82c..1f18e3314 100644
--- a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py
+++ b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py
@@ -4,8 +4,9 @@
import copy
import logging
-from functools import cache, lru_cache
-from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union
+from collections.abc import Iterable, Mapping, MutableMapping
+from functools import cache
+from typing import TYPE_CHECKING, Any
from typing_extensions import deprecated
@@ -46,6 +47,7 @@
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
+
if TYPE_CHECKING:
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import (
AbstractConcurrentFileBasedCursor,
@@ -67,7 +69,7 @@ def create_from_stream(
stream: AbstractFileBasedStream,
source: AbstractSource,
logger: logging.Logger,
- state: Optional[MutableMapping[str, Any]],
+ state: MutableMapping[str, Any] | None,
cursor: "AbstractConcurrentFileBasedCursor",
) -> "FileBasedStreamFacade":
"""
@@ -75,7 +77,7 @@ def create_from_stream(
"""
pk = get_primary_key_from_stream(stream.primary_key)
cursor_field = get_cursor_field_from_stream(stream)
- stream._cursor = cursor
+ stream._cursor = cursor # noqa: SLF001
if not source.message_repository:
raise ValueError(
@@ -107,10 +109,10 @@ def create_from_stream(
stream,
cursor,
logger=logger,
- slice_logger=source._slice_logger,
+ slice_logger=source._slice_logger, # noqa: SLF001
)
- def __init__(
+ def __init__( # noqa: ANN204
self,
stream: DefaultStream,
legacy_stream: AbstractFileBasedStream,
@@ -131,11 +133,10 @@ def __init__(
self.validation_policy = legacy_stream.validation_policy
@property
- def cursor_field(self) -> Union[str, List[str]]:
+ def cursor_field(self) -> str | list[str]:
if self._abstract_stream.cursor_field is None:
return []
- else:
- return self._abstract_stream.cursor_field
+ return self._abstract_stream.cursor_field
@property
def name(self) -> str:
@@ -150,7 +151,7 @@ def supports_incremental(self) -> bool:
def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy:
return self._legacy_stream.availability_strategy
- @lru_cache(maxsize=None)
+ @cache # noqa: B019
def get_json_schema(self) -> Mapping[str, Any]:
return self._abstract_stream.get_json_schema()
@@ -170,10 +171,10 @@ def get_files(self) -> Iterable[RemoteFile]:
def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping[str, Any]]:
yield from self._legacy_stream.read_records_from_slice(stream_slice) # type: ignore[misc] # Only Mapping[str, Any] is expected for legacy streams, not AirbyteMessage
- def compute_slices(self) -> Iterable[Optional[StreamSlice]]:
+ def compute_slices(self) -> Iterable[StreamSlice | None]:
return self._legacy_stream.compute_slices()
- def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
+ def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]:
return self._legacy_stream.infer_schema(files)
def get_underlying_stream(self) -> DefaultStream:
@@ -181,21 +182,21 @@ def get_underlying_stream(self) -> DefaultStream:
def read(
self,
- configured_stream: ConfiguredAirbyteStream,
- logger: logging.Logger,
- slice_logger: SliceLogger,
- stream_state: MutableMapping[str, Any],
- state_manager: ConnectorStateManager,
- internal_config: InternalConfig,
+ configured_stream: ConfiguredAirbyteStream, # noqa: ARG002
+ logger: logging.Logger, # noqa: ARG002
+ slice_logger: SliceLogger, # noqa: ARG002
+ stream_state: MutableMapping[str, Any], # noqa: ARG002
+ state_manager: ConnectorStateManager, # noqa: ARG002
+ internal_config: InternalConfig, # noqa: ARG002
) -> Iterable[StreamData]:
yield from self._read_records()
def read_records(
self,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_slice: Optional[Mapping[str, Any]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002
+ stream_state: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Iterable[StreamData]:
try:
yield from self._read_records()
@@ -211,7 +212,7 @@ def read_records(
level=Level.ERROR, message=f"Cursor State at time of exception: {state}"
),
)
- raise exc
+ raise exc # noqa: TRY201
def _read_records(self) -> Iterable[StreamData]:
for partition in self._abstract_stream.generate_partitions():
@@ -222,14 +223,14 @@ def _read_records(self) -> Iterable[StreamData]:
class FileBasedStreamPartition(Partition):
- def __init__(
+ def __init__( # noqa: ANN204
self,
stream: AbstractFileBasedStream,
- _slice: Optional[Mapping[str, Any]],
+ _slice: Mapping[str, Any] | None,
message_repository: MessageRepository,
sync_mode: SyncMode,
- cursor_field: Optional[List[str]],
- state: Optional[MutableMapping[str, Any]],
+ cursor_field: list[str] | None,
+ state: MutableMapping[str, Any] | None,
):
self._stream = stream
self._slice = _slice
@@ -265,7 +266,7 @@ def read(self) -> Iterable[Record]:
else record_data.record.data
)
if not record_message_data:
- raise ExceptionWithDisplayMessage("A record without data was found")
+ raise ExceptionWithDisplayMessage("A record without data was found") # noqa: TRY301
else:
yield Record(
data=record_message_data,
@@ -279,14 +280,14 @@ def read(self) -> Iterable[Record]:
if display_message:
raise ExceptionWithDisplayMessage(display_message) from e
else:
- raise e
+ raise e # noqa: TRY201
- def to_slice(self) -> Optional[Mapping[str, Any]]:
+ def to_slice(self) -> Mapping[str, Any] | None:
if self._slice is None:
return None
- assert (
- len(self._slice["files"]) == 1
- ), f"Expected 1 file per partition but got {len(self._slice['files'])} for stream {self.stream_name()}"
+ assert len(self._slice["files"]) == 1, (
+ f"Expected 1 file per partition but got {len(self._slice['files'])} for stream {self.stream_name()}"
+ )
file = self._slice["files"][0]
return {"files": [file]}
@@ -297,16 +298,14 @@ def __hash__(self) -> int:
raise ValueError(
f"Slices for file-based streams should be of length 1, but got {len(self._slice['files'])}. This is unexpected. Please contact Support."
)
- else:
- s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}"
+ s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}"
return hash((self._stream.name, s))
- else:
- return hash(self._stream.name)
+ return hash(self._stream.name)
def stream_name(self) -> str:
return self._stream.name
- @cache
+ @cache # noqa: B019
def _use_file_transfer(self) -> bool:
return hasattr(self._stream, "use_file_transfer") and self._stream.use_file_transfer
@@ -315,13 +314,13 @@ def __repr__(self) -> str:
class FileBasedStreamPartitionGenerator(PartitionGenerator):
- def __init__(
+ def __init__( # noqa: ANN204
self,
stream: AbstractFileBasedStream,
message_repository: MessageRepository,
sync_mode: SyncMode,
- cursor_field: Optional[List[str]],
- state: Optional[MutableMapping[str, Any]],
+ cursor_field: list[str] | None,
+ state: MutableMapping[str, Any] | None,
cursor: "AbstractConcurrentFileBasedCursor",
):
self._stream = stream
@@ -338,7 +337,7 @@ def generate(self) -> Iterable[FileBasedStreamPartition]:
):
if _slice is not None:
for file in _slice.get("files", []):
- pending_partitions.append(
+ pending_partitions.append( # noqa: PERF401
FileBasedStreamPartition(
self._stream,
{"files": [copy.deepcopy(file)]},
diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/__init__.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/__init__.py
index 089cae0ad..6498e5c8c 100644
--- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/__init__.py
+++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/__init__.py
@@ -2,6 +2,7 @@
from .file_based_concurrent_cursor import FileBasedConcurrentCursor
from .file_based_final_state_cursor import FileBasedFinalStateCursor
+
__all__ = [
"AbstractConcurrentFileBasedCursor",
"FileBasedConcurrentCursor",
diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py
index 5c30fda4a..988134600 100644
--- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py
+++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/abstract_concurrent_file_based_cursor.py
@@ -4,8 +4,9 @@
import logging
from abc import ABC, abstractmethod
+from collections.abc import Iterable, MutableMapping
from datetime import datetime
-from typing import TYPE_CHECKING, Any, Iterable, List, MutableMapping
+from typing import TYPE_CHECKING, Any
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
@@ -14,12 +15,13 @@
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.types import Record
+
if TYPE_CHECKING:
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition
class AbstractConcurrentFileBasedCursor(Cursor, AbstractFileBasedCursor, ABC):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
+ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
pass
@property
@@ -33,7 +35,7 @@ def observe(self, record: Record) -> None: ...
def close_partition(self, partition: Partition) -> None: ...
@abstractmethod
- def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None: ...
+ def set_pending_partitions(self, partitions: list["FileBasedStreamPartition"]) -> None: ...
@abstractmethod
def add_file(self, file: RemoteFile) -> None: ...
diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py
index a70169197..e143dda89 100644
--- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py
+++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_concurrent_cursor.py
@@ -3,9 +3,10 @@
#
import logging
+from collections.abc import Iterable, MutableMapping
from datetime import datetime, timedelta
from threading import RLock
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
+from typing import TYPE_CHECKING, Any
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
@@ -21,6 +22,7 @@
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.types import Record
+
if TYPE_CHECKING:
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition
@@ -34,14 +36,14 @@ class FileBasedConcurrentCursor(AbstractConcurrentFileBasedCursor):
)
DEFAULT_MAX_HISTORY_SIZE = 10_000
DATE_TIME_FORMAT = DefaultFileBasedCursor.DATE_TIME_FORMAT
- zero_value = datetime.min
+ zero_value = datetime.min # noqa: DTZ901
zero_cursor_value = f"0001-01-01T00:00:00.000000Z_{_NULL_FILE}"
def __init__(
self,
stream_config: FileBasedStreamConfig,
stream_name: str,
- stream_namespace: Optional[str],
+ stream_namespace: str | None,
stream_state: MutableMapping[str, Any],
message_repository: MessageRepository,
connector_state_manager: ConnectorStateManager,
@@ -60,7 +62,7 @@ def __init__(
)
self._state_lock = RLock()
self._pending_files_lock = RLock()
- self._pending_files: Optional[Dict[str, RemoteFile]] = None
+ self._pending_files: dict[str, RemoteFile] | None = None
self._file_to_datetime_history = stream_state.get("history", {}) if stream_state else {}
self._prev_cursor_value = self._compute_prev_sync_cursor(stream_state)
self._sync_start = self._compute_start_time()
@@ -72,28 +74,28 @@ def state(self) -> MutableMapping[str, Any]:
def observe(self, record: Record) -> None:
pass
- def close_partition(self, partition: Partition) -> None:
+ def close_partition(self, partition: Partition) -> None: # noqa: ARG002
with self._pending_files_lock:
if self._pending_files is None:
raise RuntimeError(
"Expected pending partitions to be set but it was not. This is unexpected. Please contact Support."
)
- def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None:
+ def set_pending_partitions(self, partitions: list["FileBasedStreamPartition"]) -> None:
with self._pending_files_lock:
self._pending_files = {}
for partition in partitions:
- _slice = partition.to_slice()
+ _slice = partition.to_slice() # noqa: RUF052
if _slice is None:
continue
for file in _slice["files"]:
- if file.uri in self._pending_files.keys():
+ if file.uri in self._pending_files.keys(): # noqa: SIM118
raise RuntimeError(
f"Already found file {_slice} in pending files. This is unexpected. Please contact Support."
)
self._pending_files.update({file.uri: file})
- def _compute_prev_sync_cursor(self, value: Optional[StreamState]) -> Tuple[datetime, str]:
+ def _compute_prev_sync_cursor(self, value: StreamState | None) -> tuple[datetime, str]:
if not value:
return self.zero_value, ""
prev_cursor_str = value.get(self._cursor_field.cursor_field_key) or self.zero_cursor_value
@@ -112,23 +114,23 @@ def _compute_prev_sync_cursor(self, value: Optional[StreamState]) -> Tuple[datet
cursor_dt, cursor_uri = cursor_str.split("_", 1)
return datetime.strptime(cursor_dt, self.DATE_TIME_FORMAT), cursor_uri
- def _get_cursor_key_from_file(self, file: Optional[RemoteFile]) -> str:
+ def _get_cursor_key_from_file(self, file: RemoteFile | None) -> str:
if file:
return f"{datetime.strftime(file.last_modified, self.DATE_TIME_FORMAT)}_{file.uri}"
return self.zero_cursor_value
- def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]:
+ def _compute_earliest_file_in_history(self) -> RemoteFile | None:
with self._state_lock:
if self._file_to_datetime_history:
filename, last_modified = min(
- self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0])
+ self._file_to_datetime_history.items(),
+ key=lambda f: (f[1], f[0]), # noqa: FURB118
)
return RemoteFile(
uri=filename,
last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT),
)
- else:
- return None
+ return None
def add_file(self, file: RemoteFile) -> None:
"""
@@ -139,7 +141,7 @@ def add_file(self, file: RemoteFile) -> None:
raise RuntimeError(
"Expected pending partitions to be set but it was not. This is unexpected. Please contact Support."
)
- with self._pending_files_lock:
+ with self._pending_files_lock: # noqa: SIM117
with self._state_lock:
if file.uri not in self._pending_files:
self._message_repository.emit_message(
@@ -162,7 +164,7 @@ def add_file(self, file: RemoteFile) -> None:
if oldest_file:
del self._file_to_datetime_history[oldest_file.uri]
else:
- raise Exception(
+ raise Exception( # noqa: TRY002
"The history is full but there is no files in the history. This should never happen and might be indicative of a bug in the CDK."
)
self.emit_state_message()
@@ -181,7 +183,7 @@ def emit_state_message(self) -> None:
self._message_repository.emit_message(state_message)
def _get_new_cursor_value(self) -> str:
- with self._pending_files_lock:
+ with self._pending_files_lock: # noqa: SIM117
with self._state_lock:
if self._pending_files:
# If there are partitions that haven't been synced, we don't know whether the files that have been synced
@@ -189,31 +191,29 @@ def _get_new_cursor_value(self) -> str:
# To avoid missing files, we only increment the cursor up to the oldest pending file, because we know
# that all older files have been synced.
return self._get_cursor_key_from_file(self._compute_earliest_pending_file())
- elif self._file_to_datetime_history:
+ if self._file_to_datetime_history:
# If all partitions have been synced, we know that the sync is up-to-date and so can advance
# the cursor to the newest file in history.
return self._get_cursor_key_from_file(self._compute_latest_file_in_history())
- else:
- return f"{self.zero_value.strftime(self.DATE_TIME_FORMAT)}_"
+ return f"{self.zero_value.strftime(self.DATE_TIME_FORMAT)}_"
- def _compute_earliest_pending_file(self) -> Optional[RemoteFile]:
+ def _compute_earliest_pending_file(self) -> RemoteFile | None:
if self._pending_files:
return min(self._pending_files.values(), key=lambda x: x.last_modified)
- else:
- return None
+ return None
- def _compute_latest_file_in_history(self) -> Optional[RemoteFile]:
+ def _compute_latest_file_in_history(self) -> RemoteFile | None:
with self._state_lock:
if self._file_to_datetime_history:
filename, last_modified = max(
- self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0])
+ self._file_to_datetime_history.items(),
+ key=lambda f: (f[1], f[0]), # noqa: FURB118
)
return RemoteFile(
uri=filename,
last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT),
)
- else:
- return None
+ return None
def get_files_to_sync(
self, all_files: Iterable[RemoteFile], logger: logging.Logger
@@ -235,7 +235,7 @@ def get_files_to_sync(
if self._should_sync_file(f, logger):
yield f
- def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool:
+ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: # noqa: ARG002
with self._state_lock:
if file.uri in self._file_to_datetime_history:
# If the file's uri is in the history, we should sync the file if it has been modified since it was synced
@@ -253,23 +253,20 @@ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool:
)
)
return False
- else:
- return file.last_modified > updated_at_from_history
+ return file.last_modified > updated_at_from_history
prev_cursor_timestamp, prev_cursor_uri = self._prev_cursor_value
if self._is_history_full():
if file.last_modified > prev_cursor_timestamp:
# If the history is partial and the file's datetime is strictly greater than the cursor, we should sync it
return True
- elif file.last_modified == prev_cursor_timestamp:
+ if file.last_modified == prev_cursor_timestamp:
# If the history is partial and the file's datetime is equal to the earliest file in the history,
# we should sync it if its uri is greater than or equal to the cursor value.
return file.uri > prev_cursor_uri
- else:
- return file.last_modified >= self._sync_start
- else:
- # The file is not in the history and the history is complete. We know we need to sync the file
- return True
+ return file.last_modified >= self._sync_start
+ # The file is not in the history and the history is complete. We know we need to sync the file
+ return True
def _is_history_full(self) -> bool:
"""
@@ -284,14 +281,13 @@ def _is_history_full(self) -> bool:
def _compute_start_time(self) -> datetime:
if not self._file_to_datetime_history:
- return datetime.min
- else:
- earliest = min(self._file_to_datetime_history.values())
- earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT)
- if self._is_history_full():
- time_window = datetime.now() - self._time_window_if_history_is_full
- earliest_dt = min(earliest_dt, time_window)
- return earliest_dt
+ return datetime.min # noqa: DTZ901
+ earliest = min(self._file_to_datetime_history.values())
+ earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT)
+ if self._is_history_full():
+ time_window = datetime.now() - self._time_window_if_history_is_full
+ earliest_dt = min(earliest_dt, time_window)
+ return earliest_dt
def get_start_time(self) -> datetime:
return self._sync_start
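The concurrent cursor above encodes its position as a composite "<timestamp>_<uri>" string and compares candidate files against the (timestamp, uri) pair it parses back out. A minimal standalone sketch of that ordering, assuming the same DATE_TIME_FORMAT that DefaultFileBasedCursor declares further down:

from datetime import datetime

DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"  # assumed to match the cursor's own format


def cursor_key(last_modified: datetime, uri: str) -> str:
    # Composite key "<timestamp>_<uri>": zero-padded timestamps sort lexicographically,
    # so ties on last_modified fall back to URI ordering, as _should_sync_file expects.
    return f"{last_modified.strftime(DATE_TIME_FORMAT)}_{uri}"


a = cursor_key(datetime(2024, 1, 1, 12, 0), "data/a.csv")
b = cursor_key(datetime(2024, 1, 1, 12, 0), "data/b.csv")
c = cursor_key(datetime(2024, 1, 2), "data/a.csv")
assert a < b < c  # same timestamp -> the URI breaks the tie; a later timestamp always wins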
diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py
index e219292d1..04c308b3e 100644
--- a/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py
+++ b/airbyte_cdk/sources/file_based/stream/concurrent/cursor/file_based_final_state_cursor.py
@@ -3,8 +3,9 @@
#
import logging
+from collections.abc import Iterable, MutableMapping
from datetime import datetime
-from typing import TYPE_CHECKING, Any, Iterable, List, MutableMapping, Optional
+from typing import TYPE_CHECKING, Any
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
@@ -18,6 +19,7 @@
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.types import Record
+
if TYPE_CHECKING:
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartition
@@ -25,12 +27,12 @@
class FileBasedFinalStateCursor(AbstractConcurrentFileBasedCursor):
"""Cursor that is used to guarantee at least one state message is emitted for a concurrent file-based stream."""
- def __init__(
+ def __init__( # noqa: ANN204
self,
stream_config: FileBasedStreamConfig,
message_repository: MessageRepository,
- stream_namespace: Optional[str],
- **kwargs: Any,
+ stream_namespace: str | None,
+ **kwargs: Any, # noqa: ANN401, ARG002
):
self._stream_name = stream_config.name
self._stream_namespace = stream_namespace
@@ -50,25 +52,27 @@ def observe(self, record: Record) -> None:
def close_partition(self, partition: Partition) -> None:
pass
- def set_pending_partitions(self, partitions: List["FileBasedStreamPartition"]) -> None:
+ def set_pending_partitions(self, partitions: list["FileBasedStreamPartition"]) -> None:
pass
def add_file(self, file: RemoteFile) -> None:
pass
def get_files_to_sync(
- self, all_files: Iterable[RemoteFile], logger: logging.Logger
+ self,
+ all_files: Iterable[RemoteFile],
+ logger: logging.Logger, # noqa: ARG002
) -> Iterable[RemoteFile]:
return all_files
def get_state(self) -> MutableMapping[str, Any]:
return {}
- def set_initial_state(self, value: StreamState) -> None:
+ def set_initial_state(self, value: StreamState) -> None: # noqa: ARG002
return None
def get_start_time(self) -> datetime:
- return datetime.min
+ return datetime.min # noqa: DTZ901
def emit_state_message(self) -> None:
pass
diff --git a/airbyte_cdk/sources/file_based/stream/cursor/__init__.py b/airbyte_cdk/sources/file_based/stream/cursor/__init__.py
index c1bf15a5d..34d4899fc 100644
--- a/airbyte_cdk/sources/file_based/stream/cursor/__init__.py
+++ b/airbyte_cdk/sources/file_based/stream/cursor/__init__.py
@@ -1,4 +1,5 @@
from .abstract_file_based_cursor import AbstractFileBasedCursor
from .default_file_based_cursor import DefaultFileBasedCursor
+
__all__ = ["AbstractFileBasedCursor", "DefaultFileBasedCursor"]
diff --git a/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py
index 4a5eadb4e..e1094df5b 100644
--- a/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py
+++ b/airbyte_cdk/sources/file_based/stream/cursor/abstract_file_based_cursor.py
@@ -4,8 +4,9 @@
import logging
from abc import ABC, abstractmethod
+from collections.abc import Iterable, MutableMapping
from datetime import datetime
-from typing import Any, Iterable, MutableMapping
+from typing import Any
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
@@ -18,7 +19,7 @@ class AbstractFileBasedCursor(ABC):
"""
@abstractmethod
- def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any):
+ def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any): # noqa: ANN204, ANN401
"""
Common interface for all cursors.
"""
diff --git a/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py b/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py
index 08ad8c3ae..fa876a24a 100644
--- a/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py
+++ b/airbyte_cdk/sources/file_based/stream/cursor/default_file_based_cursor.py
@@ -3,8 +3,9 @@
#
import logging
+from collections.abc import Iterable, MutableMapping
from datetime import datetime, timedelta
-from typing import Any, Iterable, MutableMapping, Optional
+from typing import Any
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
@@ -20,7 +21,7 @@ class DefaultFileBasedCursor(AbstractFileBasedCursor):
DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
CURSOR_FIELD = "_ab_source_file_last_modified"
- def __init__(self, stream_config: FileBasedStreamConfig, **_: Any):
+ def __init__(self, stream_config: FileBasedStreamConfig, **_: Any): # noqa: ANN204, ANN401
super().__init__(stream_config) # type: ignore [safe-super]
self._file_to_datetime_history: MutableMapping[str, str] = {}
self._time_window_if_history_is_full = timedelta(
@@ -34,7 +35,7 @@ def __init__(self, stream_config: FileBasedStreamConfig, **_: Any):
)
self._start_time = self._compute_start_time()
- self._initial_earliest_file_in_history: Optional[RemoteFile] = None
+ self._initial_earliest_file_in_history: RemoteFile | None = None
def set_initial_state(self, value: StreamState) -> None:
self._file_to_datetime_history = value.get("history", {})
@@ -51,7 +52,7 @@ def add_file(self, file: RemoteFile) -> None:
if oldest_file:
del self._file_to_datetime_history[oldest_file.uri]
else:
- raise Exception(
+ raise Exception( # noqa: TRY002
"The history is full but there is no files in the history. This should never happen and might be indicative of a bug in the CDK."
)
@@ -59,7 +60,7 @@ def get_state(self) -> StreamState:
state = {"history": self._file_to_datetime_history, self.CURSOR_FIELD: self._get_cursor()}
return state
- def _get_cursor(self) -> Optional[str]:
+ def _get_cursor(self) -> str | None:
"""
Returns the cursor value.
@@ -68,7 +69,8 @@ def _get_cursor(self) -> Optional[str]:
"""
if self._file_to_datetime_history.items():
filename, timestamp = max(
- self._file_to_datetime_history.items(), key=lambda x: (x[1], x[0])
+ self._file_to_datetime_history.items(),
+ key=lambda x: (x[1], x[0]), # noqa: FURB118
)
return f"{timestamp}_{filename}"
return None
@@ -79,7 +81,7 @@ def _is_history_full(self) -> bool:
"""
return len(self._file_to_datetime_history) >= self.DEFAULT_MAX_HISTORY_SIZE
- def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool:
+ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool: # noqa: PLR0911
if file.uri in self._file_to_datetime_history:
# If the file's uri is in the history, we should sync the file if it has been modified since it was synced
updated_at_from_history = datetime.strptime(
@@ -99,16 +101,14 @@ def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool:
# If the history is partial and the file's datetime is strictly greater than the earliest file in the history,
# we should sync it
return True
- elif file.last_modified == self._initial_earliest_file_in_history.last_modified:
+ if file.last_modified == self._initial_earliest_file_in_history.last_modified:
# If the history is partial and the file's datetime is equal to the earliest file in the history,
# we should sync it if its uri is strictly greater than the earliest file in the history
return file.uri > self._initial_earliest_file_in_history.uri
- else:
- # Otherwise, only sync the file if it has been modified since the start of the time window
- return file.last_modified >= self.get_start_time()
- else:
- # The file is not in the history and the history is complete. We know we need to sync the file
- return True
+ # Otherwise, only sync the file if it has been modified since the start of the time window
+ return file.last_modified >= self.get_start_time()
+ # The file is not in the history and the history is complete. We know we need to sync the file
+ return True
def get_files_to_sync(
self, all_files: Iterable[RemoteFile], logger: logging.Logger
@@ -126,24 +126,23 @@ def get_files_to_sync(
def get_start_time(self) -> datetime:
return self._start_time
- def _compute_earliest_file_in_history(self) -> Optional[RemoteFile]:
+ def _compute_earliest_file_in_history(self) -> RemoteFile | None:
if self._file_to_datetime_history:
filename, last_modified = min(
- self._file_to_datetime_history.items(), key=lambda f: (f[1], f[0])
+ self._file_to_datetime_history.items(),
+ key=lambda f: (f[1], f[0]), # noqa: FURB118
)
return RemoteFile(
uri=filename, last_modified=datetime.strptime(last_modified, self.DATE_TIME_FORMAT)
)
- else:
- return None
+ return None
def _compute_start_time(self) -> datetime:
if not self._file_to_datetime_history:
- return datetime.min
- else:
- earliest = min(self._file_to_datetime_history.values())
- earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT)
- if self._is_history_full():
- time_window = datetime.now() - self._time_window_if_history_is_full
- earliest_dt = min(earliest_dt, time_window)
- return earliest_dt
+ return datetime.min # noqa: DTZ901
+ earliest = min(self._file_to_datetime_history.values())
+ earliest_dt = datetime.strptime(earliest, self.DATE_TIME_FORMAT)
+ if self._is_history_full():
+ time_window = datetime.now() - self._time_window_if_history_is_full
+ earliest_dt = min(earliest_dt, time_window)
+ return earliest_dt
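The flattened _compute_start_time above keeps the prior behavior: no history means syncing from datetime.min, otherwise the start is the earliest synced file, clamped by the rolling window once the history is full. A standalone restatement with a plain dict in place of the cursor instance:

from datetime import datetime, timedelta

DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


def compute_start_time(history: dict[str, str], *, history_is_full: bool, window: timedelta) -> datetime:
    if not history:
        return datetime.min  # no history yet: sync everything
    # Start at the earliest synced file, clamped by the rolling window once history is full.
    earliest_dt = datetime.strptime(min(history.values()), DATE_TIME_FORMAT)
    if history_is_full:
        earliest_dt = min(earliest_dt, datetime.now() - window)
    return earliest_dt


history = {"a.csv": "2024-01-01T00:00:00.000000Z", "b.csv": "2024-02-01T00:00:00.000000Z"}
print(compute_start_time(history, history_is_full=False, window=timedelta(days=3)))
# -> 2024-01-01 00:00:00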
diff --git a/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py b/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py
index 604322549..6d414d5ad 100644
--- a/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py
+++ b/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py
@@ -6,10 +6,11 @@
import itertools
import traceback
from collections import defaultdict
+from collections.abc import Iterable, Mapping, MutableMapping
from copy import deepcopy
from functools import cache
from os import path
-from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
+from typing import Any
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, FailureType, Level
from airbyte_cdk.models import Type as MessageType
@@ -53,11 +54,11 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
ab_file_name_col = "_ab_source_file_url"
modified = "modified"
source_file_url = "source_file_url"
- airbyte_columns = [ab_last_mod_col, ab_file_name_col]
+ airbyte_columns = [ab_last_mod_col, ab_file_name_col] # noqa: RUF012
use_file_transfer = False
preserve_directory_structure = True
- def __init__(self, **kwargs: Any):
+ def __init__(self, **kwargs: Any): # noqa: ANN204, ANN401
if self.FILE_TRANSFER_KW in kwargs:
self.use_file_transfer = kwargs.pop(self.FILE_TRANSFER_KW, False)
if self.PRESERVE_DIRECTORY_STRUCTURE_KW in kwargs:
@@ -76,7 +77,7 @@ def state(self, value: MutableMapping[str, Any]) -> None:
self._cursor.set_initial_state(value)
@property # type: ignore # mypy complains wrong type, but AbstractFileBasedCursor is parent of file-based cursors
- def cursor(self) -> Optional[AbstractFileBasedCursor]:
+ def cursor(self) -> AbstractFileBasedCursor | None:
return self._cursor
@cursor.setter
@@ -94,8 +95,8 @@ def primary_key(self) -> PrimaryKeyType:
)
def _filter_schema_invalid_properties(
- self, configured_catalog_json_schema: Dict[str, Any]
- ) -> Dict[str, Any]:
+ self, configured_catalog_json_schema: dict[str, Any]
+ ) -> dict[str, Any]:
if self.use_file_transfer:
return {
"type": "object",
@@ -105,22 +106,21 @@ def _filter_schema_invalid_properties(
self.ab_file_name_col: {"type": "string"},
},
}
- else:
- return super()._filter_schema_invalid_properties(configured_catalog_json_schema)
+ return super()._filter_schema_invalid_properties(configured_catalog_json_schema)
def _duplicated_files_names(
- self, slices: List[dict[str, List[RemoteFile]]]
- ) -> List[dict[str, List[str]]]:
- seen_file_names: Dict[str, List[str]] = defaultdict(list)
+ self, slices: list[dict[str, list[RemoteFile]]]
+ ) -> list[dict[str, list[str]]]:
+ seen_file_names: dict[str, list[str]] = defaultdict(list)
for file_slice in slices:
for file_found in file_slice[self.FILES_KEY]:
- file_name = path.basename(file_found.uri)
+ file_name = path.basename(file_found.uri) # noqa: PTH119
seen_file_names[file_name].append(file_found.uri)
return [
{file_name: paths} for file_name, paths in seen_file_names.items() if len(paths) > 1
]
- def compute_slices(self) -> Iterable[Optional[Mapping[str, Any]]]:
+ def compute_slices(self) -> Iterable[Mapping[str, Any] | None]:
# Sort files by last_modified, uri and return them grouped by last_modified
all_files = self.list_files()
files_to_read = self._cursor.get_files_to_sync(all_files, self.logger)
@@ -174,7 +174,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte
try:
if self.use_file_transfer:
self.logger.info(f"{self.name}: {file} file-based syncing")
- # todo: complete here the code to not rely on local parser
+ # TODO: complete here the code to not rely on local parser
file_transfer = FileTransfer()
for record in file_transfer.get_file(
self.config, file, self.stream_reader, self.logger
@@ -183,7 +183,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte
if not self.record_passes_validation_policy(record):
n_skipped += 1
continue
- record = self.transform_record_for_file_transfer(record, file)
+ record = self.transform_record_for_file_transfer(record, file) # noqa: PLW2901
yield stream_data_to_airbyte_message(
self.name, record, is_file_transfer_message=True
)
@@ -193,11 +193,11 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte
):
line_no += 1
if self.config.schemaless:
- record = {"data": record}
+ record = {"data": record} # noqa: PLW2901
elif not self.record_passes_validation_policy(record):
n_skipped += 1
continue
- record = self.transform_record(record, file, file_datetime_string)
+ record = self.transform_record(record, file, file_datetime_string) # noqa: PLW2901
yield stream_data_to_airbyte_message(self.name, record)
self._cursor.add_file(file)
@@ -227,7 +227,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte
except AirbyteTracedException as exc:
# Re-raise the exception to stop the whole sync immediately as this is a fatal error
- raise exc
+ raise exc # noqa: TRY201
except Exception:
yield AirbyteMessage(
@@ -250,14 +250,14 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte
)
@property
- def cursor_field(self) -> Union[str, List[str]]:
+ def cursor_field(self) -> str | list[str]:
"""
Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field.
:return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor.
"""
return self.ab_last_mod_col
- @cache
+ @cache # noqa: B019
def get_json_schema(self) -> JsonSchema:
extra_fields = {
self.ab_last_mod_col: {"type": "string"},
@@ -266,14 +266,14 @@ def get_json_schema(self) -> JsonSchema:
try:
schema = self._get_raw_json_schema()
except InvalidSchemaError as config_exception:
- raise AirbyteTracedException(
+ raise AirbyteTracedException( # noqa: B904
internal_message="Please check the logged errors for more information.",
message=FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value,
exception=AirbyteTracedException(exception=config_exception),
failure_type=FailureType.config_error,
)
except AirbyteTracedException as ate:
- raise ate
+ raise ate # noqa: TRY201
except Exception as exc:
raise SchemaInferenceError(
FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name
@@ -284,22 +284,21 @@ def get_json_schema(self) -> JsonSchema:
def _get_raw_json_schema(self) -> JsonSchema:
if self.use_file_transfer:
return file_transfer_schema
- elif self.config.input_schema:
+ if self.config.input_schema:
return self.config.get_input_schema() # type: ignore
- elif self.config.schemaless:
+ if self.config.schemaless:
return schemaless_schema
- else:
- files = self.list_files()
- first_n_files = len(files)
-
- if self.config.recent_n_files_to_read_for_schema_discovery:
- self.logger.info(
- msg=(
- f"Only first {self.config.recent_n_files_to_read_for_schema_discovery} files will be used to infer schema "
- f"for stream {self.name} due to limitation in config."
- )
+ files = self.list_files()
+ first_n_files = len(files)
+
+ if self.config.recent_n_files_to_read_for_schema_discovery:
+ self.logger.info(
+ msg=(
+ f"Only first {self.config.recent_n_files_to_read_for_schema_discovery} files will be used to infer schema "
+ f"for stream {self.name} due to limitation in config."
)
- first_n_files = self.config.recent_n_files_to_read_for_schema_discovery
+ )
+ first_n_files = self.config.recent_n_files_to_read_for_schema_discovery
if first_n_files == 0:
self.logger.warning(
@@ -341,7 +340,7 @@ def get_files(self) -> Iterable[RemoteFile]:
self.config.globs or [], self.config.legacy_prefix, self.logger
)
- def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
+ def infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]:
loop = asyncio.get_event_loop()
schema = loop.run_until_complete(self._infer_schema(files))
# as infer schema returns a Mapping that is assumed to be immutable, we need to create a deepcopy to avoid modifying the reference
@@ -354,7 +353,7 @@ def _fill_nulls(schema: Mapping[str, Any]) -> Mapping[str, Any]:
if k == "type":
if isinstance(v, list):
if "null" not in v:
- schema[k] = ["null"] + v
+ schema[k] = ["null"] + v # noqa: RUF005
elif v != "null":
schema[k] = ["null", v]
else:
@@ -364,7 +363,7 @@ def _fill_nulls(schema: Mapping[str, Any]) -> Mapping[str, Any]:
DefaultFileBasedStream._fill_nulls(item)
return schema
- async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
+ async def _infer_schema(self, files: list[RemoteFile]) -> Mapping[str, Any]:
"""
Infer the schema for a stream.
@@ -372,7 +371,7 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
Dispatch on file type.
"""
base_schema: SchemaType = {}
- pending_tasks: Set[asyncio.tasks.Task[SchemaType]] = set()
+ pending_tasks: set[asyncio.tasks.Task[SchemaType]] = set()
n_started, n_files = 0, len(files)
files_iterator = iter(files)
@@ -391,7 +390,7 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
try:
base_schema = merge_schemas(base_schema, task.result())
except AirbyteTracedException as ate:
- raise ate
+ raise ate # noqa: TRY201
except Exception as exc:
self.logger.error(
f"An error occurred inferring the schema. \n {traceback.format_exc()}",
@@ -406,7 +405,7 @@ async def _infer_file_schema(self, file: RemoteFile) -> SchemaType:
self.config, file, self.stream_reader, self.logger
)
except AirbyteTracedException as ate:
- raise ate
+ raise ate # noqa: TRY201
except Exception as exc:
raise SchemaInferenceError(
FileBasedSourceError.SCHEMA_INFERENCE_ERROR,
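The _fill_nulls helper touched in the hunks above widens every "type" entry so the inferred schema also accepts null values. A simplified standalone restatement of the visible logic:

from typing import Any


def fill_nulls(schema: dict[str, Any]) -> dict[str, Any]:
    # Recursively widen every declared type so it also accepts null.
    for k, v in schema.items():
        if k == "type":
            if isinstance(v, list):
                if "null" not in v:
                    schema[k] = ["null", *v]
            elif v != "null":
                schema[k] = ["null", v]
        elif isinstance(v, dict):
            fill_nulls(v)
        elif isinstance(v, list):
            for item in v:
                if isinstance(item, dict):
                    fill_nulls(item)
    return schema


print(fill_nulls({"type": "object", "properties": {"id": {"type": "integer"}}}))
# -> {'type': ['null', 'object'], 'properties': {'id': {'type': ['null', 'integer']}}}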
diff --git a/airbyte_cdk/sources/file_based/types.py b/airbyte_cdk/sources/file_based/types.py
index b83bf37a3..fa05978f7 100644
--- a/airbyte_cdk/sources/file_based/types.py
+++ b/airbyte_cdk/sources/file_based/types.py
@@ -1,10 +1,12 @@
-#
+# # noqa: A005
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
from __future__ import annotations
-from typing import Any, Mapping, MutableMapping
+from collections.abc import Mapping, MutableMapping
+from typing import Any
+
StreamSlice = Mapping[str, Any]
StreamState = MutableMapping[str, Any]
diff --git a/airbyte_cdk/sources/http_logger.py b/airbyte_cdk/sources/http_logger.py
index 33ccc68ac..91b6078cf 100644
--- a/airbyte_cdk/sources/http_logger.py
+++ b/airbyte_cdk/sources/http_logger.py
@@ -2,7 +2,6 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Optional, Union
import requests
@@ -13,8 +12,8 @@ def format_http_message(
response: requests.Response,
title: str,
description: str,
- stream_name: Optional[str],
- is_auxiliary: bool | None = None,
+ stream_name: str | None,
+ is_auxiliary: bool | None = None, # noqa: FBT001
) -> LogMessage:
request = response.request
log_message = {
@@ -48,5 +47,5 @@ def format_http_message(
return log_message # type: ignore [return-value] # got "dict[str, object]", expected "dict[str, JsonType]"
-def _normalize_body_string(body_str: Optional[Union[str, bytes]]) -> Optional[str]:
- return body_str.decode() if isinstance(body_str, (bytes, bytearray)) else body_str
+def _normalize_body_string(body_str: str | bytes | None) -> str | None:
+ return body_str.decode() if isinstance(body_str, (bytes, bytearray)) else body_str # noqa: UP038
diff --git a/airbyte_cdk/sources/message/__init__.py b/airbyte_cdk/sources/message/__init__.py
index 31c484ab5..fc43482f4 100644
--- a/airbyte_cdk/sources/message/__init__.py
+++ b/airbyte_cdk/sources/message/__init__.py
@@ -10,6 +10,7 @@
NoopMessageRepository,
)
+
__all__ = [
"InMemoryMessageRepository",
"LogAppenderMessageRepositoryDecorator",
diff --git a/airbyte_cdk/sources/message/repository.py b/airbyte_cdk/sources/message/repository.py
index 2fc156e8c..952b0b882 100644
--- a/airbyte_cdk/sources/message/repository.py
+++ b/airbyte_cdk/sources/message/repository.py
@@ -6,12 +6,13 @@
import logging
from abc import ABC, abstractmethod
from collections import deque
-from typing import Callable, Deque, Iterable, List, Optional
+from collections.abc import Callable, Iterable
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
from airbyte_cdk.sources.utils.types import JsonType
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
+
_LOGGER = logging.getLogger("MessageRepository")
_SUPPORTED_MESSAGE_TYPES = {Type.CONTROL, Type.LOG}
LogMessage = dict[str, JsonType]
@@ -45,7 +46,7 @@ def _is_severe_enough(threshold: Level, level: Level) -> bool:
class MessageRepository(ABC):
@abstractmethod
def emit_message(self, message: AirbyteMessage) -> None:
- raise NotImplementedError()
+ raise NotImplementedError
@abstractmethod
def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None:
@@ -53,11 +54,11 @@ def log_message(self, level: Level, message_provider: Callable[[], LogMessage])
Computing messages can be resource consuming. This method is specialized for logging because we want to allow for lazy evaluation if
the log level is less severe than what is configured
"""
- raise NotImplementedError()
+ raise NotImplementedError
@abstractmethod
def consume_queue(self) -> Iterable[AirbyteMessage]:
- raise NotImplementedError()
+ raise NotImplementedError
class NoopMessageRepository(MessageRepository):
@@ -73,7 +74,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]:
class InMemoryMessageRepository(MessageRepository):
def __init__(self, log_level: Level = Level.INFO) -> None:
- self._message_queue: Deque[AirbyteMessage] = deque()
+ self._message_queue: deque[AirbyteMessage] = deque()
self._log_level = log_level
def emit_message(self, message: AirbyteMessage) -> None:
@@ -96,7 +97,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]:
class LogAppenderMessageRepositoryDecorator(MessageRepository):
- def __init__(
+ def __init__( # noqa: ANN204
self,
dict_to_append: LogMessage,
decorated: MessageRepository,
@@ -119,7 +120,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]:
return self._decorated.consume_queue()
def _append_second_to_first(
- self, first: LogMessage, second: LogMessage, path: Optional[List[str]] = None
+ self, first: LogMessage, second: LogMessage, path: list[str] | None = None
) -> LogMessage:
if path is None:
path = []
@@ -127,10 +128,10 @@ def _append_second_to_first(
for key in second:
if key in first:
if isinstance(first[key], dict) and isinstance(second[key], dict):
- self._append_second_to_first(first[key], second[key], path + [str(key)]) # type: ignore # type is verified above
+ self._append_second_to_first(first[key], second[key], path + [str(key)]) # type: ignore # type is verified above # noqa: RUF005
else:
if first[key] != second[key]:
- _LOGGER.warning("Conflict at %s" % ".".join(path + [str(key)]))
+ _LOGGER.warning("Conflict at %s" % ".".join(path + [str(key)])) # noqa: UP031, RUF005
first[key] = second[key]
else:
first[key] = second[key]
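The decorator's _append_second_to_first merge shown above can be read as the sketch below: nested mappings are merged recursively, and conflicting scalar values are logged before the appended value wins.

import logging

logging.basicConfig(level=logging.WARNING)
_LOGGER = logging.getLogger("MessageRepository")


def append_second_to_first(first: dict, second: dict, path: list[str] | None = None) -> dict:
    # Nested dicts merge recursively; scalar conflicts are logged and `second` wins.
    path = path or []
    for key in second:
        if key in first and isinstance(first[key], dict) and isinstance(second[key], dict):
            append_second_to_first(first[key], second[key], [*path, str(key)])
        else:
            if key in first and first[key] != second[key]:
                _LOGGER.warning("Conflict at %s", ".".join([*path, str(key)]))
            first[key] = second[key]
    return first


print(append_second_to_first({"a": {"b": 1}}, {"a": {"b": 2, "c": 3}}))
# logs "Conflict at a.b", then prints {'a': {'b': 2, 'c': 3}}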
diff --git a/airbyte_cdk/sources/source.py b/airbyte_cdk/sources/source.py
index 2958d82ca..8336b7d1b 100644
--- a/airbyte_cdk/sources/source.py
+++ b/airbyte_cdk/sources/source.py
@@ -5,7 +5,8 @@
import logging
from abc import ABC, abstractmethod
-from typing import Any, Generic, Iterable, List, Mapping, Optional, TypeVar
+from collections.abc import Iterable, Mapping
+from typing import Any, Generic, TypeVar
from airbyte_cdk.connector import BaseConnector, DefaultConnectorMixin, TConfig
from airbyte_cdk.models import (
@@ -17,6 +18,7 @@
ConfiguredAirbyteCatalogSerializer,
)
+
TState = TypeVar("TState")
TCatalog = TypeVar("TCatalog")
@@ -38,7 +40,7 @@ def read(
logger: logging.Logger,
config: TConfig,
catalog: TCatalog,
- state: Optional[TState] = None,
+ state: TState | None = None,
) -> Iterable[AirbyteMessage]:
"""
Returns a generator of the AirbyteMessages generated by reading the source with the given configuration, catalog, and state.
@@ -54,12 +56,12 @@ def discover(self, logger: logging.Logger, config: TConfig) -> AirbyteCatalog:
class Source(
DefaultConnectorMixin,
- BaseSource[Mapping[str, Any], List[AirbyteStateMessage], ConfiguredAirbyteCatalog],
+ BaseSource[Mapping[str, Any], list[AirbyteStateMessage], ConfiguredAirbyteCatalog],
ABC,
):
# can be overridden to change an input state.
@classmethod
- def read_state(cls, state_path: str) -> List[AirbyteStateMessage]:
+ def read_state(cls, state_path: str) -> list[AirbyteStateMessage]:
"""
Retrieves the input state of a sync by reading from the specified JSON file. Incoming state can be deserialized into either
a JSON object for legacy state input or as a list of AirbyteStateMessages for the per-stream state format. Regardless of the
@@ -69,7 +71,7 @@ def read_state(cls, state_path: str) -> List[AirbyteStateMessage]:
"""
parsed_state_messages = []
if state_path:
- state_obj = BaseConnector._read_json_file(state_path)
+ state_obj = BaseConnector._read_json_file(state_path) # noqa: SLF001
if state_obj:
for state in state_obj: # type: ignore # `isinstance(state_obj, List)` ensures that this is a list
parsed_message = AirbyteStateMessageSerializer.load(state)
diff --git a/airbyte_cdk/sources/streams/__init__.py b/airbyte_cdk/sources/streams/__init__.py
index dc735b617..eec23592e 100644
--- a/airbyte_cdk/sources/streams/__init__.py
+++ b/airbyte_cdk/sources/streams/__init__.py
@@ -5,4 +5,5 @@
# Initialize Streams Package
from .core import NO_CURSOR_STATE_KEY, CheckpointMixin, IncrementalMixin, Stream
+
__all__ = ["NO_CURSOR_STATE_KEY", "IncrementalMixin", "CheckpointMixin", "Stream"]
diff --git a/airbyte_cdk/sources/streams/availability_strategy.py b/airbyte_cdk/sources/streams/availability_strategy.py
index 312ddae19..efdeb5837 100644
--- a/airbyte_cdk/sources/streams/availability_strategy.py
+++ b/airbyte_cdk/sources/streams/availability_strategy.py
@@ -5,11 +5,13 @@
import logging
import typing
from abc import ABC, abstractmethod
-from typing import Any, Mapping, Optional, Tuple
+from collections.abc import Mapping
+from typing import Any, Optional
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.streams.core import Stream, StreamData
+
if typing.TYPE_CHECKING:
from airbyte_cdk.sources import Source
@@ -22,7 +24,7 @@ class AvailabilityStrategy(ABC):
@abstractmethod
def check_availability(
self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None
- ) -> Tuple[bool, Optional[str]]:
+ ) -> tuple[bool, str | None]:
"""
Checks stream availability.
@@ -36,7 +38,7 @@ def check_availability(
"""
@staticmethod
- def get_first_stream_slice(stream: Stream) -> Optional[Mapping[str, Any]]:
+ def get_first_stream_slice(stream: Stream) -> Mapping[str, Any] | None:
"""
Gets the first stream_slice from a given stream's stream_slices.
:param stream: stream
@@ -55,7 +57,7 @@ def get_first_stream_slice(stream: Stream) -> Optional[Mapping[str, Any]]:
@staticmethod
def get_first_record_for_slice(
- stream: Stream, stream_slice: Optional[Mapping[str, Any]]
+ stream: Stream, stream_slice: Mapping[str, Any] | None
) -> StreamData:
"""
Gets the first record for a stream_slice of a stream.
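The two static helpers above let concrete strategies probe a stream cheaply. A hypothetical subclass (the class name and error wording are illustrative, not part of the CDK) might use them roughly like this, assuming both helpers raise StopIteration on an empty stream:

import logging

from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.core import Stream


class FirstRecordAvailabilityStrategy(AvailabilityStrategy):
    """Hypothetical strategy: a stream is available if its first record can be read."""

    def check_availability(
        self, stream: Stream, logger: logging.Logger, source=None
    ) -> tuple[bool, str | None]:
        try:
            stream_slice = self.get_first_stream_slice(stream)
            self.get_first_record_for_slice(stream, stream_slice)
        except StopIteration:
            return True, None  # an empty stream is still reachable
        except Exception as exc:  # noqa: BLE001 - surface any read failure as "unavailable"
            return False, f"Unable to read {stream.name}: {exc!r}"
        return True, None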
diff --git a/airbyte_cdk/sources/streams/call_rate.py b/airbyte_cdk/sources/streams/call_rate.py
index 81ebac78e..bc1c62a7a 100644
--- a/airbyte_cdk/sources/streams/call_rate.py
+++ b/airbyte_cdk/sources/streams/call_rate.py
@@ -7,9 +7,10 @@
import datetime
import logging
import time
+from collections.abc import Mapping
from datetime import timedelta
from threading import RLock
-from typing import TYPE_CHECKING, Any, Mapping, Optional
+from typing import TYPE_CHECKING, Any
from urllib import parse
import requests
@@ -18,6 +19,7 @@
from pyrate_limiter import Rate as PyRateRate
from pyrate_limiter.exceptions import BucketFullException
+
# prevents mypy from complaining about missing session attributes in LimiterMixin
if TYPE_CHECKING:
MIXIN_BASE = requests.Session
@@ -36,7 +38,7 @@ class Rate:
class CallRateLimitHit(Exception):
- def __init__(self, error: str, item: Any, weight: int, rate: str, time_to_wait: timedelta):
+ def __init__(self, error: str, item: Any, weight: int, rate: str, time_to_wait: timedelta): # noqa: ANN204, ANN401
"""Constructor
:param error: error message
@@ -58,7 +60,7 @@ class AbstractCallRatePolicy(abc.ABC):
"""
@abc.abstractmethod
- def matches(self, request: Any) -> bool:
+ def matches(self, request: Any) -> bool: # noqa: ANN401
"""Tells if this policy matches specific request and should apply to it
:param request:
@@ -66,7 +68,7 @@ def matches(self, request: Any) -> bool:
"""
@abc.abstractmethod
- def try_acquire(self, request: Any, weight: int) -> None:
+ def try_acquire(self, request: Any, weight: int) -> None: # noqa: ANN401
"""Try to acquire request
:param request: a request object representing a single call to API
@@ -75,9 +77,7 @@ def try_acquire(self, request: Any, weight: int) -> None:
"""
@abc.abstractmethod
- def update(
- self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime]
- ) -> None:
+ def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None:
"""Update call rate counting with current values
:param available_calls:
@@ -89,7 +89,7 @@ class RequestMatcher(abc.ABC):
"""Callable that help to match a request object with call rate policies."""
@abc.abstractmethod
- def __call__(self, request: Any) -> bool:
+ def __call__(self, request: Any) -> bool: # noqa: ANN401
"""
:param request:
@@ -100,12 +100,12 @@ def __call__(self, request: Any) -> bool:
class HttpRequestMatcher(RequestMatcher):
"""Simple implementation of RequestMatcher for http requests case"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
- method: Optional[str] = None,
- url: Optional[str] = None,
- params: Optional[Mapping[str, Any]] = None,
- headers: Optional[Mapping[str, Any]] = None,
+ method: str | None = None,
+ url: str | None = None,
+ params: Mapping[str, Any] | None = None,
+ headers: Mapping[str, Any] | None = None,
):
"""Constructor
@@ -129,7 +129,7 @@ def _match_dict(obj: Mapping[str, Any], pattern: Mapping[str, Any]) -> bool:
"""
return pattern.items() <= obj.items()
- def __call__(self, request: Any) -> bool:
+ def __call__(self, request: Any) -> bool: # noqa: ANN401
"""
:param request:
@@ -142,7 +142,7 @@ def __call__(self, request: Any) -> bool:
else:
return False
- if self._method is not None:
+ if self._method is not None: # noqa: SIM102
if prepared_request.method != self._method:
return False
if self._url is not None and prepared_request.url is not None:
@@ -154,17 +154,17 @@ def __call__(self, request: Any) -> bool:
params = dict(parse.parse_qsl(str(parsed_url.query)))
if not self._match_dict(params, self._params):
return False
- if self._headers is not None:
+ if self._headers is not None: # noqa: SIM102
if not self._match_dict(prepared_request.headers, self._headers):
return False
return True
class BaseCallRatePolicy(AbstractCallRatePolicy, abc.ABC):
- def __init__(self, matchers: list[RequestMatcher]):
+ def __init__(self, matchers: list[RequestMatcher]): # noqa: ANN204
self._matchers = matchers
- def matches(self, request: Any) -> bool:
+ def matches(self, request: Any) -> bool: # noqa: ANN401
"""Tell if this policy matches specific request and should apply to it
:param request:
@@ -200,17 +200,15 @@ class UnlimitedCallRatePolicy(BaseCallRatePolicy):
The code above will limit all calls to /some/method except calls that have header sandbox=True
"""
- def try_acquire(self, request: Any, weight: int) -> None:
+ def try_acquire(self, request: Any, weight: int) -> None: # noqa: ANN401
"""Do nothing"""
- def update(
- self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime]
- ) -> None:
+ def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None:
"""Do nothing"""
class FixedWindowCallRatePolicy(BaseCallRatePolicy):
- def __init__(
+ def __init__( # noqa: ANN204
self,
next_reset_ts: datetime.datetime,
period: timedelta,
@@ -232,7 +230,7 @@ def __init__(
self._lock = RLock()
super().__init__(matchers=matchers)
- def try_acquire(self, request: Any, weight: int) -> None:
+ def try_acquire(self, request: Any, weight: int) -> None: # noqa: ANN401
if weight > self._call_limit:
raise ValueError("Weight can not exceed the call limit")
if not self.matches(request):
@@ -257,9 +255,7 @@ def try_acquire(self, request: Any, weight: int) -> None:
self._calls_num += weight
- def update(
- self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime]
- ) -> None:
+ def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None:
"""Update call rate counters, by default, only reacts to decreasing updates of available_calls and changes to call_reset_ts.
We ignore updates with available_calls > current_available_calls to support call rate limits that are lower than API limits.
@@ -290,7 +286,7 @@ def _update_current_window(self) -> None:
now = datetime.datetime.now()
if now > self._next_reset_ts:
logger.debug("started new window, %s calls available now", self._call_limit)
- self._next_reset_ts = self._next_reset_ts + self._offset
+ self._next_reset_ts = self._next_reset_ts + self._offset # noqa: PLR6104
self._calls_num = 0
@@ -302,7 +298,7 @@ class MovingWindowCallRatePolicy(BaseCallRatePolicy):
This strategy requires saving of timestamps of all requests within a window.
"""
- def __init__(self, rates: list[Rate], matchers: list[RequestMatcher]):
+ def __init__(self, rates: list[Rate], matchers: list[RequestMatcher]): # noqa: ANN204
"""Constructor
:param rates: list of rates, the order is important and must be ascending
@@ -319,7 +315,7 @@ def __init__(self, rates: list[Rate], matchers: list[RequestMatcher]):
self._limiter = Limiter(self._bucket)
super().__init__(matchers=matchers)
- def try_acquire(self, request: Any, weight: int) -> None:
+ def try_acquire(self, request: Any, weight: int) -> None: # noqa: ANN401
if not self.matches(request):
raise ValueError("Request does not match the policy")
@@ -333,7 +329,7 @@ def try_acquire(self, request: Any, weight: int) -> None:
time_to_wait = self._bucket.waiting(item)
assert isinstance(time_to_wait, int)
- raise CallRateLimitHit(
+ raise CallRateLimitHit( # noqa: B904
error=str(exc.meta_info["error"]),
item=request,
weight=int(exc.meta_info["weight"]),
@@ -341,16 +337,14 @@ def try_acquire(self, request: Any, weight: int) -> None:
time_to_wait=timedelta(milliseconds=time_to_wait),
)
- def update(
- self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime]
- ) -> None:
+ def update(self, available_calls: int | None, call_reset_ts: datetime.datetime | None) -> None:
"""Adjust call bucket to reflect the state of the API server
:param available_calls:
:param call_reset_ts:
:return:
"""
- if (
+ if ( # noqa: SIM102
available_calls is not None and call_reset_ts is None
): # we do our best to sync buckets with API
if available_calls == 0:
@@ -376,7 +370,11 @@ class AbstractAPIBudget(abc.ABC):
@abc.abstractmethod
def acquire_call(
- self, request: Any, block: bool = True, timeout: Optional[float] = None
+ self,
+ request: Any, # noqa: ANN401
+ *,
+ block: bool = True,
+ timeout: float | None = None,
) -> None:
"""Try to get a call from budget, will block by default
@@ -387,11 +385,11 @@ def acquire_call(
"""
@abc.abstractmethod
- def get_matching_policy(self, request: Any) -> Optional[AbstractCallRatePolicy]:
+ def get_matching_policy(self, request: Any) -> AbstractCallRatePolicy | None: # noqa: ANN401
"""Find matching call rate policy for specific request"""
@abc.abstractmethod
- def update_from_response(self, request: Any, response: Any) -> None:
+ def update_from_response(self, request: Any, response: Any) -> None: # noqa: ANN401
"""Update budget information based on response from API
:param request: the initial request that triggered this response
@@ -415,14 +413,18 @@ def __init__(
self._policies = policies
self._maximum_attempts_to_acquire = maximum_attempts_to_acquire
- def get_matching_policy(self, request: Any) -> Optional[AbstractCallRatePolicy]:
+ def get_matching_policy(self, request: Any) -> AbstractCallRatePolicy | None: # noqa: ANN401
for policy in self._policies:
if policy.matches(request):
return policy
return None
def acquire_call(
- self, request: Any, block: bool = True, timeout: Optional[float] = None
+ self,
+ request: Any, # noqa: ANN401
+ *,
+ block: bool = True,
+ timeout: float | None = None,
) -> None:
"""Try to get a call from budget, will block by default.
Matchers will be called sequentially in the same order they were added.
@@ -440,7 +442,7 @@ def acquire_call(
elif self._policies:
logger.info("no policies matched with requests, allow call by default")
- def update_from_response(self, request: Any, response: Any) -> None:
+ def update_from_response(self, request: Any, response: Any) -> None: # noqa: ANN401
"""Update budget information based on response from API
:param request: the initial request that triggered this response
@@ -449,7 +451,12 @@ def update_from_response(self, request: Any, response: Any) -> None:
pass
def _do_acquire(
- self, request: Any, policy: AbstractCallRatePolicy, block: bool, timeout: Optional[float]
+ self,
+ request: Any, # noqa: ANN401
+ policy: AbstractCallRatePolicy,
+ *,
+ block: bool,
+ timeout: float | None,
) -> None:
"""Internal method to try to acquire a call credit
@@ -460,10 +467,10 @@ def _do_acquire(
"""
last_exception = None
# sometimes we spend all budget before a second attempt, so we have few more here
- for attempt in range(1, self._maximum_attempts_to_acquire):
+ for attempt in range(1, self._maximum_attempts_to_acquire): # noqa: B007
try:
policy.try_acquire(request, weight=1)
- return
+ return # noqa: TRY300
except CallRateLimitHit as exc:
last_exception = exc
if block:
@@ -492,12 +499,12 @@ def _do_acquire(
class HttpAPIBudget(APIBudget):
"""Implementation of AbstractAPIBudget for HTTP"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
ratelimit_reset_header: str = "ratelimit-reset",
ratelimit_remaining_header: str = "ratelimit-remaining",
status_codes_for_ratelimit_hit: tuple[int] = (429,),
- **kwargs: Any,
+ **kwargs: Any, # noqa: ANN401
):
"""Constructor
@@ -510,7 +517,7 @@ def __init__(
self._status_codes_for_ratelimit_hit = status_codes_for_ratelimit_hit
super().__init__(**kwargs)
- def update_from_response(self, request: Any, response: Any) -> None:
+ def update_from_response(self, request: Any, response: Any) -> None: # noqa: ANN401
policy = self.get_matching_policy(request)
if not policy:
return
@@ -520,16 +527,14 @@ def update_from_response(self, request: Any, response: Any) -> None:
reset_ts = self.get_reset_ts_from_response(response)
policy.update(available_calls=available_calls, call_reset_ts=reset_ts)
- def get_reset_ts_from_response(
- self, response: requests.Response
- ) -> Optional[datetime.datetime]:
+ def get_reset_ts_from_response(self, response: requests.Response) -> datetime.datetime | None:
if response.headers.get(self._ratelimit_reset_header):
- return datetime.datetime.fromtimestamp(
+ return datetime.datetime.fromtimestamp( # noqa: DTZ006
int(response.headers[self._ratelimit_reset_header])
)
return None
- def get_calls_left_from_response(self, response: requests.Response) -> Optional[int]:
+ def get_calls_left_from_response(self, response: requests.Response) -> int | None:
if response.headers.get(self._ratelimit_remaining_header):
return int(response.headers[self._ratelimit_remaining_header])
@@ -542,15 +547,15 @@ def get_calls_left_from_response(self, response: requests.Response) -> Optional[
class LimiterMixin(MIXIN_BASE):
"""Mixin class that adds rate-limiting behavior to requests."""
- def __init__(
+ def __init__( # noqa: ANN204
self,
api_budget: AbstractAPIBudget,
- **kwargs: Any,
+ **kwargs: Any, # noqa: ANN401
):
self._api_budget = api_budget
super().__init__(**kwargs) # type: ignore # Base Session doesn't take any kwargs
- def send(self, request: requests.PreparedRequest, **kwargs: Any) -> requests.Response:
+ def send(self, request: requests.PreparedRequest, **kwargs: Any) -> requests.Response: # noqa: ANN401
"""Send a request with rate-limiting."""
self._api_budget.acquire_call(request)
response = super().send(request, **kwargs)
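Wiring the pieces touched in this file together, a rate-limited budget might be assembled as in the sketch below. The Rate(limit=..., interval=...) field names and the exact URL-matching semantics of HttpRequestMatcher are assumptions based on the definitions above, not verified usage:

from datetime import timedelta

import requests

from airbyte_cdk.sources.streams.call_rate import (
    HttpAPIBudget,
    HttpRequestMatcher,
    MovingWindowCallRatePolicy,
    Rate,
)

# At most 10 calls per minute to GET https://api.example.com/items (example endpoint).
policy = MovingWindowCallRatePolicy(
    rates=[Rate(limit=10, interval=timedelta(minutes=1))],  # field names assumed
    matchers=[HttpRequestMatcher(method="GET", url="https://api.example.com/items")],
)
budget = HttpAPIBudget(policies=[policy])

request = requests.Request("GET", "https://api.example.com/items").prepare()
budget.acquire_call(request)  # blocks, or raises after the configured attempts, once the window is exhausted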
diff --git a/airbyte_cdk/sources/streams/checkpoint/__init__.py b/airbyte_cdk/sources/streams/checkpoint/__init__.py
index ae4e0e46f..866df6875 100644
--- a/airbyte_cdk/sources/streams/checkpoint/__init__.py
+++ b/airbyte_cdk/sources/streams/checkpoint/__init__.py
@@ -13,6 +13,7 @@
from .cursor import Cursor
from .resumable_full_refresh_cursor import ResumableFullRefreshCursor
+
__all__ = [
"CheckpointMode",
"CheckpointReader",
diff --git a/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py b/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py
index 6e4ef98d7..20ff03708 100644
--- a/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py
+++ b/airbyte_cdk/sources/streams/checkpoint/checkpoint_reader.py
@@ -1,12 +1,12 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping
from enum import Enum
-from typing import Any, Iterable, Mapping, Optional
-
-from airbyte_cdk.sources.types import StreamSlice
+from typing import Any
from .cursor import Cursor
+from airbyte_cdk.sources.types import StreamSlice
class CheckpointMode(Enum):
@@ -25,7 +25,7 @@ class CheckpointReader(ABC):
"""
@abstractmethod
- def next(self) -> Optional[Mapping[str, Any]]:
+ def next(self) -> Mapping[str, Any] | None:
"""
Returns the next slice that will be used to fetch the next group of records. Returning None indicates that the reader
has finished iterating over all slices.
@@ -41,7 +41,7 @@ def observe(self, new_state: Mapping[str, Any]) -> None:
"""
@abstractmethod
- def get_checkpoint(self) -> Optional[Mapping[str, Any]]:
+ def get_checkpoint(self) -> Mapping[str, Any] | None:
"""
Retrieves the current state value of the stream. The connector does not emit state messages if the checkpoint value is None.
"""
@@ -53,18 +53,18 @@ class IncrementalCheckpointReader(CheckpointReader):
before syncing data.
"""
- def __init__(
- self, stream_state: Mapping[str, Any], stream_slices: Iterable[Optional[Mapping[str, Any]]]
+ def __init__( # noqa: ANN204
+ self, stream_state: Mapping[str, Any], stream_slices: Iterable[Mapping[str, Any] | None]
):
- self._state: Optional[Mapping[str, Any]] = stream_state
+ self._state: Mapping[str, Any] | None = stream_state
self._stream_slices = iter(stream_slices)
self._has_slices = False
- def next(self) -> Optional[Mapping[str, Any]]:
+ def next(self) -> Mapping[str, Any] | None:
try:
next_slice = next(self._stream_slices)
self._has_slices = True
- return next_slice
+ return next_slice # noqa: TRY300
except StopIteration:
# This is used to avoid sending a duplicate state message at the end of a sync since the stream has already
# emitted state at the end of each slice. If we want to avoid this extra complexity, we can also just accept
@@ -76,7 +76,7 @@ def next(self) -> Optional[Mapping[str, Any]]:
def observe(self, new_state: Mapping[str, Any]) -> None:
self._state = new_state
- def get_checkpoint(self) -> Optional[Mapping[str, Any]]:
+ def get_checkpoint(self) -> Mapping[str, Any] | None:
return self._state
@@ -89,25 +89,25 @@ class CursorBasedCheckpointReader(CheckpointReader):
that belongs to the Concurrent CDK.
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
cursor: Cursor,
- stream_slices: Iterable[Optional[Mapping[str, Any]]],
- read_state_from_cursor: bool = False,
+ stream_slices: Iterable[Mapping[str, Any] | None],
+ read_state_from_cursor: bool = False, # noqa: FBT001, FBT002
):
self._cursor = cursor
self._stream_slices = iter(stream_slices)
# read_state_from_cursor is used to delineate that partitions should determine when to stop syncing dynamically according
# to the value of the state at runtime. This currently only applies to streams that use resumable full refresh.
self._read_state_from_cursor = read_state_from_cursor
- self._current_slice: Optional[StreamSlice] = None
+ self._current_slice: StreamSlice | None = None
self._finished_sync = False
- self._previous_state: Optional[Mapping[str, Any]] = None
+ self._previous_state: Mapping[str, Any] | None = None
- def next(self) -> Optional[Mapping[str, Any]]:
+ def next(self) -> Mapping[str, Any] | None:
try:
self.current_slice = self._find_next_slice()
- return self.current_slice
+ return self.current_slice # noqa: TRY300
except StopIteration:
self._finished_sync = True
return None
@@ -117,14 +117,13 @@ def observe(self, new_state: Mapping[str, Any]) -> None:
# while processing records
pass
- def get_checkpoint(self) -> Optional[Mapping[str, Any]]:
+ def get_checkpoint(self) -> Mapping[str, Any] | None:
# This is used to avoid sending a duplicate state messages
new_state = self._cursor.get_stream_state()
if new_state != self._previous_state:
self._previous_state = new_state
return new_state
- else:
- return None
+ return None
def _find_next_slice(self) -> StreamSlice:
"""
@@ -165,36 +164,34 @@ def _find_next_slice(self) -> StreamSlice:
partition=next_slice.partition,
extra_fields=next_slice.extra_fields,
)
- else:
- state_for_slice = self._cursor.select_state(self.current_slice)
- if state_for_slice == FULL_REFRESH_COMPLETE_STATE:
- # If the current slice is is complete, move to the next slice and skip the next slices that already
- # have the terminal complete value indicating that a previous attempt was successfully read.
- # Dummy initialization for mypy since we'll iterate at least once to get the next slice
- next_candidate_slice = StreamSlice(cursor_slice={}, partition={})
- has_more = True
- while has_more:
- next_candidate_slice = self.read_and_convert_slice()
- state_for_slice = self._cursor.select_state(next_candidate_slice)
- has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE
- return StreamSlice(
- cursor_slice=state_for_slice or {},
- partition=next_candidate_slice.partition,
- extra_fields=next_candidate_slice.extra_fields,
- )
- # The reader continues to process the current partition if it's state is still in progress
+ state_for_slice = self._cursor.select_state(self.current_slice)
+ if state_for_slice == FULL_REFRESH_COMPLETE_STATE:
+ # If the current slice is is complete, move to the next slice and skip the next slices that already
+ # have the terminal complete value indicating that a previous attempt was successfully read.
+ # Dummy initialization for mypy since we'll iterate at least once to get the next slice
+ next_candidate_slice = StreamSlice(cursor_slice={}, partition={})
+ has_more = True
+ while has_more:
+ next_candidate_slice = self.read_and_convert_slice()
+ state_for_slice = self._cursor.select_state(next_candidate_slice)
+ has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE
return StreamSlice(
cursor_slice=state_for_slice or {},
- partition=self.current_slice.partition,
- extra_fields=self.current_slice.extra_fields,
+ partition=next_candidate_slice.partition,
+ extra_fields=next_candidate_slice.extra_fields,
)
- else:
- # Unlike RFR cursors that iterate dynamically according to how stream state is updated, most cursors operate
- # on a fixed set of slices determined before reading records. They just iterate to the next slice
- return self.read_and_convert_slice()
+ # The reader continues to process the current partition if it's state is still in progress
+ return StreamSlice(
+ cursor_slice=state_for_slice or {},
+ partition=self.current_slice.partition,
+ extra_fields=self.current_slice.extra_fields,
+ )
+ # Unlike RFR cursors that iterate dynamically according to how stream state is updated, most cursors operate
+ # on a fixed set of slices determined before reading records. They just iterate to the next slice
+ return self.read_and_convert_slice()
@property
- def current_slice(self) -> Optional[StreamSlice]:
+ def current_slice(self) -> StreamSlice | None:
return self._current_slice
@current_slice.setter
@@ -204,7 +201,7 @@ def current_slice(self, value: StreamSlice) -> None:
def read_and_convert_slice(self) -> StreamSlice:
next_slice = next(self._stream_slices)
if not isinstance(next_slice, StreamSlice):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"{self.current_slice} should be of type StreamSlice. This is likely a bug in the CDK, please contact Airbyte support"
)
return next_slice
@@ -231,11 +228,11 @@ class LegacyCursorBasedCheckpointReader(CursorBasedCheckpointReader):
}
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
cursor: Cursor,
- stream_slices: Iterable[Optional[Mapping[str, Any]]],
- read_state_from_cursor: bool = False,
+ stream_slices: Iterable[Mapping[str, Any] | None],
+ read_state_from_cursor: bool = False, # noqa: FBT001, FBT002
):
super().__init__(
cursor=cursor,
@@ -243,13 +240,13 @@ def __init__(
read_state_from_cursor=read_state_from_cursor,
)
- def next(self) -> Optional[Mapping[str, Any]]:
+ def next(self) -> Mapping[str, Any] | None:
try:
self.current_slice = self._find_next_slice()
if "partition" in dict(self.current_slice):
raise ValueError("Stream is configured to use invalid stream slice key 'partition'")
- elif "cursor_slice" in dict(self.current_slice):
+ if "cursor_slice" in dict(self.current_slice):
raise ValueError(
"Stream is configured to use invalid stream slice key 'cursor_slice'"
)
@@ -268,7 +265,7 @@ def next(self) -> Optional[Mapping[str, Any]]:
def read_and_convert_slice(self) -> StreamSlice:
next_mapping_slice = next(self._stream_slices)
if not isinstance(next_mapping_slice, Mapping):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"{self.current_slice} should be of type Mapping. This is likely a bug in the CDK, please contact Airbyte support"
)
@@ -287,25 +284,24 @@ class ResumableFullRefreshCheckpointReader(CheckpointReader):
fetching more pages or stopping the sync.
"""
- def __init__(self, stream_state: Mapping[str, Any]):
+ def __init__(self, stream_state: Mapping[str, Any]): # noqa: ANN204
# The first attempt of an RFR stream has an empty {} incoming state, but should still make a first attempt to read records
# from the first page in next().
self._first_page = bool(stream_state == {})
self._state: Mapping[str, Any] = stream_state
- def next(self) -> Optional[Mapping[str, Any]]:
+ def next(self) -> Mapping[str, Any] | None:
if self._first_page:
self._first_page = False
return self._state
- elif self._state == FULL_REFRESH_COMPLETE_STATE:
+ if self._state == FULL_REFRESH_COMPLETE_STATE:
return None
- else:
- return self._state
+ return self._state
def observe(self, new_state: Mapping[str, Any]) -> None:
self._state = new_state
- def get_checkpoint(self) -> Optional[Mapping[str, Any]]:
+ def get_checkpoint(self) -> Mapping[str, Any] | None:
return self._state or {}
@@ -315,11 +311,11 @@ class FullRefreshCheckpointReader(CheckpointReader):
is not capable of managing state. At the end of a sync, a final state message is emitted to signal completion.
"""
- def __init__(self, stream_slices: Iterable[Optional[Mapping[str, Any]]]):
+ def __init__(self, stream_slices: Iterable[Mapping[str, Any] | None]): # noqa: ANN204
self._stream_slices = iter(stream_slices)
self._final_checkpoint = False
- def next(self) -> Optional[Mapping[str, Any]]:
+ def next(self) -> Mapping[str, Any] | None:
try:
return next(self._stream_slices)
except StopIteration:
@@ -329,7 +325,7 @@ def next(self) -> Optional[Mapping[str, Any]]:
def observe(self, new_state: Mapping[str, Any]) -> None:
pass
- def get_checkpoint(self) -> Optional[Mapping[str, Any]]:
+ def get_checkpoint(self) -> Mapping[str, Any] | None:
if self._final_checkpoint:
return {"__ab_no_cursor_state_message": True}
return None
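The ResumableFullRefreshCheckpointReader above drives pagination purely off observed state. A small usage sketch; the terminal sentinel dict is assumed to be the FULL_REFRESH_COMPLETE_STATE value referenced elsewhere in this module:

from airbyte_cdk.sources.streams.checkpoint.checkpoint_reader import (
    ResumableFullRefreshCheckpointReader,
)

reader = ResumableFullRefreshCheckpointReader(stream_state={})
print(reader.next())            # {} -> an empty incoming state still reads the first page
reader.observe({"next_page_token": 2})
print(reader.get_checkpoint())  # {'next_page_token': 2}
print(reader.next())            # {'next_page_token': 2} -> keep paginating
reader.observe({"__ab_full_refresh_sync_complete": True})  # assumed terminal sentinel
print(reader.next())            # None -> stop fetching pages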
diff --git a/airbyte_cdk/sources/streams/checkpoint/cursor.py b/airbyte_cdk/sources/streams/checkpoint/cursor.py
index 6d758bf4e..1f57b0482 100644
--- a/airbyte_cdk/sources/streams/checkpoint/cursor.py
+++ b/airbyte_cdk/sources/streams/checkpoint/cursor.py
@@ -3,7 +3,7 @@
#
from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
@@ -23,7 +23,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
:param stream_state: The state of the stream as returned by get_stream_state
"""
- def observe(self, stream_slice: StreamSlice, record: Record) -> None:
+ def observe(self, stream_slice: StreamSlice, record: Record) -> None: # noqa: B027
"""
Register a record with the cursor; the cursor instance can then use it to manage the state of the in-progress stream read.
@@ -34,7 +34,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
pass
@abstractmethod
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401
"""
Update state based on the stream slice. Note that `stream_slice.cursor_slice` and `most_recent_record.associated_slice` are expected
to be the same but we make it explicit here that `stream_slice` should be leveraged to update the state. We do not pass in the
@@ -69,7 +69,7 @@ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
"""
@abstractmethod
- def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
+ def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None:
"""
Get the state value of a specific stream_slice. For incremental or resumable full refresh cursors which only manage state in
a single dimension this is the entire state object. For per-partition cursors used by substreams, this returns the state of
diff --git a/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py b/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py
index e0dee4a92..3acfee945 100644
--- a/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py
+++ b/airbyte_cdk/sources/streams/checkpoint/per_partition_key_serializer.py
@@ -1,7 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
import json
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any
class PerPartitionKeySerializer:
@@ -13,10 +14,10 @@ class PerPartitionKeySerializer:
"""
@staticmethod
- def to_partition_key(to_serialize: Any) -> str:
+ def to_partition_key(to_serialize: Any) -> str: # noqa: ANN401
# separators have changed in Python 3.4. To avoid being impacted by further change, we explicitly specify our own value
return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True)
@staticmethod
- def to_partition(to_deserialize: Any) -> Mapping[str, Any]:
+ def to_partition(to_deserialize: Any) -> Mapping[str, Any]: # noqa: ANN401
return json.loads(to_deserialize) # type: ignore # The partition is known to be a dict, but the type hint is Any
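A quick illustration of why the explicit separators and sorted keys matter: two logically identical partitions must serialize to the same key regardless of insertion order.

from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import (
    PerPartitionKeySerializer,
)

# Key ordering does not matter: both partitions map to the same canonical key.
key_a = PerPartitionKeySerializer.to_partition_key({"account": 1, "region": "eu"})
key_b = PerPartitionKeySerializer.to_partition_key({"region": "eu", "account": 1})
assert key_a == key_b == '{"account":1,"region":"eu"}'
assert PerPartitionKeySerializer.to_partition(key_a) == {"account": 1, "region": "eu"}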
diff --git a/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py b/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py
index 86abd253f..23fa9f31a 100644
--- a/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py
+++ b/airbyte_cdk/sources/streams/checkpoint/resumable_full_refresh_cursor.py
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
from airbyte_cdk.sources.streams.checkpoint import Cursor
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
@@ -30,22 +30,22 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
"""
pass
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002
self._cursor = stream_slice.cursor_slice
- def should_be_synced(self, record: Record) -> bool:
+ def should_be_synced(self, record: Record) -> bool: # noqa: ARG002
"""
Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages
that don't have filterable bounds. We should always return them.
"""
return True
- def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
+ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: # noqa: ARG002
"""
RFR records don't have an ordering to be compared between one another.
"""
return False
- def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
+ def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None: # noqa: ARG002
# A top-level RFR cursor only manages the state of a single partition
return self._cursor
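A rough sketch of how this cursor behaves, assuming it can be constructed without arguments and is re-exported from the checkpoint package: the caller closes each slice with the synthetic pagination state, and select_state() hands it back for the next attempt.

from airbyte_cdk.sources.streams.checkpoint import ResumableFullRefreshCursor
from airbyte_cdk.sources.types import StreamSlice

cursor = ResumableFullRefreshCursor()
# After reading a page, the caller closes the slice with the next page's synthetic state.
cursor.close_slice(StreamSlice(partition={}, cursor_slice={"next_page_token": "page-2"}))
assert cursor.select_state() == {"next_page_token": "page-2"}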
diff --git a/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py b/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py
index 9966959f0..75bf976dd 100644
--- a/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py
+++ b/airbyte_cdk/sources/streams/checkpoint/substream_resumable_full_refresh_cursor.py
@@ -1,7 +1,8 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+from collections.abc import Mapping, MutableMapping
from dataclasses import dataclass
-from typing import Any, Mapping, MutableMapping, Optional
+from typing import Any
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.streams.checkpoint import Cursor
@@ -11,6 +12,7 @@
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
from airbyte_cdk.utils import AirbyteTracedException
+
FULL_REFRESH_COMPLETE_STATE: Mapping[str, Any] = {"__ab_full_refresh_sync_complete": True}
@@ -76,26 +78,26 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
"""
pass
- def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
+ def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None: # noqa: ANN401, ARG002
self._per_partition_state[self._to_partition_key(stream_slice.partition)] = {
"partition": stream_slice.partition,
"cursor": FULL_REFRESH_COMPLETE_STATE,
}
- def should_be_synced(self, record: Record) -> bool:
+ def should_be_synced(self, record: Record) -> bool: # noqa: ARG002
"""
Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages
that don't have filterable bounds. We should always return them.
"""
return True
- def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
+ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: # noqa: ARG002
"""
RFR records don't have an ordering to be compared between one another.
"""
return False
- def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
+ def select_state(self, stream_slice: StreamSlice | None = None) -> StreamState | None:
if not stream_slice:
raise ValueError("A partition needs to be provided in order to extract a state")
diff --git a/airbyte_cdk/sources/streams/concurrent/abstract_stream.py b/airbyte_cdk/sources/streams/concurrent/abstract_stream.py
index 26e6f09d4..4e5d321d8 100644
--- a/airbyte_cdk/sources/streams/concurrent/abstract_stream.py
+++ b/airbyte_cdk/sources/streams/concurrent/abstract_stream.py
@@ -3,7 +3,8 @@
#
from abc import ABC, abstractmethod
-from typing import Any, Iterable, Mapping, Optional
+from collections.abc import Iterable, Mapping
+from typing import Any
from typing_extensions import deprecated
@@ -58,7 +59,7 @@ def name(self) -> str:
@property
@abstractmethod
- def cursor_field(self) -> Optional[str]:
+ def cursor_field(self) -> str | None:
"""
Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field.
:return: The name of the field used as a cursor. Nested cursor fields are not supported.
diff --git a/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py b/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py
index 18cacbc50..e67a6a461 100644
--- a/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py
+++ b/airbyte_cdk/sources/streams/concurrent/abstract_stream_facade.py
@@ -1,14 +1,15 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
from abc import ABC, abstractmethod
-from typing import Generic, Optional, TypeVar
+from typing import Generic, TypeVar
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
+
StreamType = TypeVar("StreamType")
-class AbstractStreamFacade(Generic[StreamType], ABC):
+class AbstractStreamFacade(Generic[StreamType], ABC): # noqa: PYI059
@abstractmethod
def get_underlying_stream(self) -> StreamType:
"""
@@ -21,7 +22,7 @@ def source_defined_cursor(self) -> bool:
# Streams must be aware of their cursor at instantiation time
return True
- def get_error_display_message(self, exception: BaseException) -> Optional[str]:
+ def get_error_display_message(self, exception: BaseException) -> str | None:
"""
Retrieves the user-friendly display message that corresponds to an exception.
This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage.
@@ -33,5 +34,4 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]:
"""
if isinstance(exception, ExceptionWithDisplayMessage):
return exception.display_message
- else:
- return None
+ return None
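A small self-contained illustration of the display-message contract: only ExceptionWithDisplayMessage surfaces a user-facing message, anything else falls through to None. The toy facade below exists purely for the example.

from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage

class _ToyFacade(AbstractStreamFacade[object]):
    # Only get_underlying_stream is abstract, so this is enough to instantiate the facade.
    def get_underlying_stream(self) -> object:
        return object()

facade = _ToyFacade()
assert (
    facade.get_error_display_message(ExceptionWithDisplayMessage("The API key is invalid"))
    == "The API key is invalid"
)
assert facade.get_error_display_message(RuntimeError("boom")) is None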
diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py
index f304bfb21..135d29f86 100644
--- a/airbyte_cdk/sources/streams/concurrent/adapters.py
+++ b/airbyte_cdk/sources/streams/concurrent/adapters.py
@@ -5,8 +5,9 @@
import copy
import json
import logging
-from functools import lru_cache
-from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
+from collections.abc import Iterable, Mapping, MutableMapping
+from functools import cache
+from typing import Any, Optional
from typing_extensions import deprecated
@@ -45,6 +46,7 @@
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.utils.slice_hasher import SliceHasher
+
"""
This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream
"""
@@ -68,7 +70,7 @@ def create_from_stream(
stream: Stream,
source: AbstractSource,
logger: logging.Logger,
- state: Optional[MutableMapping[str, Any]],
+ state: MutableMapping[str, Any] | None,
cursor: Cursor,
) -> Stream:
"""
@@ -109,7 +111,7 @@ def create_from_stream(
),
stream,
cursor,
- slice_logger=source._slice_logger,
+ slice_logger=source._slice_logger, # noqa: SLF001
logger=logger,
)
@@ -124,7 +126,7 @@ def state(self, value: Mapping[str, Any]) -> None:
if "state" in dir(self._legacy_stream):
self._legacy_stream.state = value # type: ignore # validating `state` is attribute of stream using `if` above
- def __init__(
+ def __init__( # noqa: ANN204
self,
stream: DefaultStream,
legacy_stream: Stream,
@@ -143,21 +145,21 @@ def __init__(
def read(
self,
- configured_stream: ConfiguredAirbyteStream,
- logger: logging.Logger,
- slice_logger: SliceLogger,
- stream_state: MutableMapping[str, Any],
- state_manager: ConnectorStateManager,
- internal_config: InternalConfig,
+ configured_stream: ConfiguredAirbyteStream, # noqa: ARG002
+ logger: logging.Logger, # noqa: ARG002
+ slice_logger: SliceLogger, # noqa: ARG002
+ stream_state: MutableMapping[str, Any], # noqa: ARG002
+ state_manager: ConnectorStateManager, # noqa: ARG002
+ internal_config: InternalConfig, # noqa: ARG002
) -> Iterable[StreamData]:
yield from self._read_records()
def read_records(
self,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_slice: Optional[Mapping[str, Any]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002
+ stream_state: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Iterable[StreamData]:
try:
yield from self._read_records()
@@ -173,7 +175,7 @@ def read_records(
level=Level.ERROR, message=f"Cursor State at time of exception: {state}"
),
)
- raise exc
+ raise exc # noqa: TRY201
def _read_records(self) -> Iterable[StreamData]:
for partition in self._abstract_stream.generate_partitions():
@@ -187,22 +189,21 @@ def name(self) -> str:
return self._abstract_stream.name
@property
- def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+ def primary_key(self) -> str | list[str] | list[list[str]] | None:
# This method is not expected to be called directly. It is only implemented for backward compatibility with the old interface
return self.as_airbyte_stream().source_defined_primary_key # type: ignore # source_defined_primary_key is known to be an Optional[List[List[str]]]
@property
- def cursor_field(self) -> Union[str, List[str]]:
+ def cursor_field(self) -> str | list[str]:
if self._abstract_stream.cursor_field is None:
return []
- else:
- return self._abstract_stream.cursor_field
+ return self._abstract_stream.cursor_field
@property
- def cursor(self) -> Optional[Cursor]: # type: ignore[override] # StreamFaced expects to use only airbyte_cdk.sources.streams.concurrent.cursor.Cursor
+ def cursor(self) -> Cursor | None: # type: ignore[override] # StreamFaced expects to use only airbyte_cdk.sources.streams.concurrent.cursor.Cursor
return self._cursor
- @lru_cache(maxsize=None)
+ @cache # noqa: B019
def get_json_schema(self) -> Mapping[str, Any]:
return self._abstract_stream.get_json_schema()
@@ -211,8 +212,10 @@ def supports_incremental(self) -> bool:
return self._legacy_stream.supports_incremental
def check_availability(
- self, logger: logging.Logger, source: Optional["Source"] = None
- ) -> Tuple[bool, Optional[str]]:
+ self,
+ logger: logging.Logger, # noqa: ARG002
+ source: Optional["Source"] = None, # noqa: ARG002
+ ) -> tuple[bool, str | None]:
"""
Verifies the stream is available. Delegates to the underlying AbstractStream and ignores the parameters
:param logger: (ignored)
@@ -233,7 +236,7 @@ def get_underlying_stream(self) -> DefaultStream:
class SliceEncoder(json.JSONEncoder):
- def default(self, obj: Any) -> Any:
+ def default(self, obj: Any) -> Any: # noqa: ANN401
if hasattr(obj, "__json_serializable__"):
return obj.__json_serializable__()
@@ -251,14 +254,14 @@ class StreamPartition(Partition):
In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time.
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
stream: Stream,
- _slice: Optional[Mapping[str, Any]],
+ _slice: Mapping[str, Any] | None,
message_repository: MessageRepository,
sync_mode: SyncMode,
- cursor_field: Optional[List[str]],
- state: Optional[MutableMapping[str, Any]],
+ cursor_field: list[str] | None,
+ state: MutableMapping[str, Any] | None,
):
"""
:param stream: The stream to delegate to
@@ -309,9 +312,9 @@ def read(self) -> Iterable[Record]:
if display_message:
raise ExceptionWithDisplayMessage(display_message) from e
else:
- raise e
+ raise e # noqa: TRY201
- def to_slice(self) -> Optional[Mapping[str, Any]]:
+ def to_slice(self) -> Mapping[str, Any] | None:
return self._slice
def __hash__(self) -> int:
@@ -332,13 +335,13 @@ class StreamPartitionGenerator(PartitionGenerator):
In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time.
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
stream: Stream,
message_repository: MessageRepository,
sync_mode: SyncMode,
- cursor_field: Optional[List[str]],
- state: Optional[MutableMapping[str, Any]],
+ cursor_field: list[str] | None,
+ state: MutableMapping[str, Any] | None,
):
"""
:param stream: The stream to delegate to
@@ -369,12 +372,15 @@ def generate(self) -> Iterable[Partition]:
category=ExperimentalClassWarning,
)
class AvailabilityStrategyFacade(AvailabilityStrategy):
- def __init__(self, abstract_availability_strategy: AbstractAvailabilityStrategy):
+ def __init__(self, abstract_availability_strategy: AbstractAvailabilityStrategy): # noqa: ANN204
self._abstract_availability_strategy = abstract_availability_strategy
def check_availability(
- self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None
- ) -> Tuple[bool, Optional[str]]:
+ self,
+ stream: Stream, # noqa: ARG002
+ logger: logging.Logger,
+ source: Optional["Source"] = None, # noqa: ARG002
+ ) -> tuple[bool, str | None]:
"""
Checks stream availability.
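Putting the adapters together, a hedged sketch of the intended entry point: wrapping an existing synchronous Stream so the concurrent framework can run it. `my_stream` and `my_source` are placeholders, and a FinalStateCursor is used because this example stream has no incremental cursor.

import logging

from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor

logger = logging.getLogger("airbyte")
concurrent_stream = StreamFacade.create_from_stream(
    stream=my_stream,   # placeholder: an existing synchronous Stream instance
    source=my_source,   # placeholder: the AbstractSource that owns it
    logger=logger,
    state=None,
    cursor=FinalStateCursor(
        stream_name=my_stream.name,
        stream_namespace=None,
        message_repository=my_source.message_repository,
    ),
)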
diff --git a/airbyte_cdk/sources/streams/concurrent/availability_strategy.py b/airbyte_cdk/sources/streams/concurrent/availability_strategy.py
index 118a7d0bb..9722a1b27 100644
--- a/airbyte_cdk/sources/streams/concurrent/availability_strategy.py
+++ b/airbyte_cdk/sources/streams/concurrent/availability_strategy.py
@@ -4,7 +4,6 @@
import logging
from abc import ABC, abstractmethod
-from typing import Optional
from typing_extensions import deprecated
@@ -19,7 +18,7 @@ def is_available(self) -> bool:
"""
@abstractmethod
- def message(self) -> Optional[str]:
+ def message(self) -> str | None:
"""
:return: A message describing why the stream is not available. If the stream is available, this should return None.
"""
@@ -29,18 +28,18 @@ class StreamAvailable(StreamAvailability):
def is_available(self) -> bool:
return True
- def message(self) -> Optional[str]:
+ def message(self) -> str | None:
return None
class StreamUnavailable(StreamAvailability):
- def __init__(self, message: str):
+ def __init__(self, message: str): # noqa: ANN204
self._message = message
def is_available(self) -> bool:
return False
- def message(self) -> Optional[str]:
+ def message(self) -> str | None:
return self._message
@@ -84,7 +83,7 @@ class AlwaysAvailableAvailabilityStrategy(AbstractAvailabilityStrategy):
without disrupting existing functionality.
"""
- def check_availability(self, logger: logging.Logger) -> StreamAvailability:
+ def check_availability(self, logger: logging.Logger) -> StreamAvailability: # noqa: ARG002
"""
Checks stream availability.
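As a concrete illustration of the interface above, a minimal hypothetical strategy that reports the stream as unavailable when a feature flag is off:

import logging

from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
    AbstractAvailabilityStrategy,
    StreamAvailability,
    StreamAvailable,
    StreamUnavailable,
)

class FeatureFlagAvailabilityStrategy(AbstractAvailabilityStrategy):
    def __init__(self, feature_enabled: bool) -> None:
        self._feature_enabled = feature_enabled

    def check_availability(self, logger: logging.Logger) -> StreamAvailability:
        # Report unavailable with a human-readable reason when the flag is off.
        if self._feature_enabled:
            return StreamAvailable()
        return StreamUnavailable("The 'reports' feature is not enabled for this account")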
diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py
index cbce82a94..3900388e2 100644
--- a/airbyte_cdk/sources/streams/concurrent/cursor.py
+++ b/airbyte_cdk/sources/streams/concurrent/cursor.py
@@ -5,17 +5,10 @@
import functools
import logging
from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterable, Mapping, MutableMapping
from typing import (
Any,
- Callable,
- Iterable,
- List,
- Mapping,
- MutableMapping,
- Optional,
Protocol,
- Tuple,
- Union,
)
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
@@ -28,10 +21,11 @@
)
from airbyte_cdk.sources.types import Record, StreamSlice
+
LOGGER = logging.getLogger("airbyte")
-def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
+def _extract_value(mapping: Mapping[str, Any], path: list[str]) -> Any: # noqa: ANN401
return functools.reduce(lambda a, b: a[b], path, mapping)
@@ -86,14 +80,14 @@ def observe(self, record: Record) -> None:
"""
Indicate to the cursor that the record has been emitted
"""
- raise NotImplementedError()
+ raise NotImplementedError
@abstractmethod
def close_partition(self, partition: Partition) -> None:
"""
Indicate to the cursor that the partition has been successfully processed
"""
- raise NotImplementedError()
+ raise NotImplementedError
@abstractmethod
def ensure_at_least_one_state_emitted(self) -> None:
@@ -101,7 +95,7 @@ def ensure_at_least_one_state_emitted(self) -> None:
State messages are emitted when a partition is closed. However, the platform expects at least one state to be emitted per sync per
stream. Hence, if no partitions are generated, this method needs to be called.
"""
- raise NotImplementedError()
+ raise NotImplementedError
def stream_slices(self) -> Iterable[StreamSlice]:
"""
@@ -117,7 +111,7 @@ class FinalStateCursor(Cursor):
def __init__(
self,
stream_name: str,
- stream_namespace: Optional[str],
+ stream_namespace: str | None,
message_repository: MessageRepository,
) -> None:
self._stream_name = stream_name
@@ -157,21 +151,21 @@ class ConcurrentCursor(Cursor):
_START_BOUNDARY = 0
_END_BOUNDARY = 1
- def __init__(
+ def __init__( # noqa: PLR0913, PLR0917
self,
stream_name: str,
- stream_namespace: Optional[str],
- stream_state: Any,
+ stream_namespace: str | None,
+ stream_state: Any, # noqa: ANN401
message_repository: MessageRepository,
connector_state_manager: ConnectorStateManager,
connector_state_converter: AbstractStreamStateConverter,
cursor_field: CursorField,
- slice_boundary_fields: Optional[Tuple[str, str]],
- start: Optional[CursorValueType],
+ slice_boundary_fields: tuple[str, str] | None,
+ start: CursorValueType | None,
end_provider: Callable[[], CursorValueType],
- lookback_window: Optional[GapType] = None,
- slice_range: Optional[GapType] = None,
- cursor_granularity: Optional[GapType] = None,
+ lookback_window: GapType | None = None,
+ slice_range: GapType | None = None,
+ cursor_granularity: GapType | None = None,
) -> None:
self._stream_name = stream_name
self._stream_namespace = stream_namespace
@@ -187,7 +181,7 @@ def __init__(
self._lookback_window = lookback_window
self._slice_range = slice_range
self._most_recent_cursor_value_per_partition: MutableMapping[
- Union[StreamSlice, Mapping[str, Any], None], Any
+ StreamSlice | Mapping[str, Any] | None, Any
] = {}
self._has_closed_at_least_one_slice = False
self._cursor_granularity = cursor_granularity
@@ -203,9 +197,9 @@ def cursor_field(self) -> CursorField:
return self._cursor_field
@property
- def _slice_boundary_fields_wrapper(self) -> Tuple[str, str]:
+ def _slice_boundary_fields_wrapper(self) -> tuple[str, str]:
return (
- self._slice_boundary_fields
+ self._slice_boundary_fields # noqa: FURB110
if self._slice_boundary_fields
else (
self._connector_state_converter.START_KEY,
@@ -215,7 +209,7 @@ def _slice_boundary_fields_wrapper(self) -> Tuple[str, str]:
def _get_concurrent_state(
self, state: MutableMapping[str, Any]
- ) -> Tuple[CursorValueType, MutableMapping[str, Any]]:
+ ) -> tuple[CursorValueType, MutableMapping[str, Any]]:
if self._connector_state_converter.is_state_message_compatible(state):
return (
self._start or self._connector_state_converter.zero_value,
@@ -237,7 +231,7 @@ def observe(self, record: Record) -> None:
except ValueError:
self._log_for_record_without_cursor_value()
- def _extract_cursor_value(self, record: Record) -> Any:
+ def _extract_cursor_value(self, record: Record) -> Any: # noqa: ANN401
return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record))
def close_partition(self, partition: Partition) -> None:
@@ -260,17 +254,15 @@ def _add_slice_to_state(self, partition: Partition) -> None:
raise RuntimeError(
f"The state for stream {self._stream_name} should have at least one slice to delineate the sync start time, but no slices are present. This is unexpected. Please contact Support."
)
- self.state["slices"].append(
- {
- self._connector_state_converter.START_KEY: self._extract_from_slice(
- partition, self._slice_boundary_fields[self._START_BOUNDARY]
- ),
- self._connector_state_converter.END_KEY: self._extract_from_slice(
- partition, self._slice_boundary_fields[self._END_BOUNDARY]
- ),
- self._connector_state_converter.MOST_RECENT_RECORD_KEY: most_recent_cursor_value,
- }
- )
+ self.state["slices"].append({
+ self._connector_state_converter.START_KEY: self._extract_from_slice(
+ partition, self._slice_boundary_fields[self._START_BOUNDARY]
+ ),
+ self._connector_state_converter.END_KEY: self._extract_from_slice(
+ partition, self._slice_boundary_fields[self._END_BOUNDARY]
+ ),
+ self._connector_state_converter.MOST_RECENT_RECORD_KEY: most_recent_cursor_value,
+ })
elif most_recent_cursor_value:
if self._has_closed_at_least_one_slice:
# If we track state value using records cursor field, we can only do that if there is one partition. This is because we save
@@ -288,13 +280,11 @@ def _add_slice_to_state(self, partition: Partition) -> None:
"expected. Please contact the Airbyte team."
)
- self.state["slices"].append(
- {
- self._connector_state_converter.START_KEY: self.start,
- self._connector_state_converter.END_KEY: most_recent_cursor_value,
- self._connector_state_converter.MOST_RECENT_RECORD_KEY: most_recent_cursor_value,
- }
- )
+ self.state["slices"].append({
+ self._connector_state_converter.START_KEY: self.start,
+ self._connector_state_converter.END_KEY: most_recent_cursor_value,
+ self._connector_state_converter.MOST_RECENT_RECORD_KEY: most_recent_cursor_value,
+ })
def _emit_state_message(self) -> None:
self._connector_state_manager.update_state_for_stream(
@@ -314,9 +304,9 @@ def _merge_partitions(self) -> None:
def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType:
try:
- _slice = partition.to_slice()
+ _slice = partition.to_slice() # noqa: RUF052
if not _slice:
- raise KeyError(f"Could not find key `{key}` in empty slice")
+ raise KeyError(f"Could not find key `{key}` in empty slice") # noqa: TRY301
return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a CursorValueType
except KeyError as exception:
raise KeyError(
@@ -348,7 +338,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
yield from self._split_per_slice_range(
self._start,
self.state["slices"][0][self._connector_state_converter.START_KEY],
- False,
+ False, # noqa: FBT003
)
if len(self.state["slices"]) == 1:
@@ -357,7 +347,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
self.state["slices"][0][self._connector_state_converter.END_KEY]
),
self._end_provider(),
- True,
+ True, # noqa: FBT003
)
elif len(self.state["slices"]) > 1:
for i in range(len(self.state["slices"]) - 1):
@@ -366,20 +356,20 @@ def stream_slices(self) -> Iterable[StreamSlice]:
self.state["slices"][i][self._connector_state_converter.END_KEY]
+ self._cursor_granularity,
self.state["slices"][i + 1][self._connector_state_converter.START_KEY],
- False,
+ False, # noqa: FBT003
)
else:
yield from self._split_per_slice_range(
self.state["slices"][i][self._connector_state_converter.END_KEY],
self.state["slices"][i + 1][self._connector_state_converter.START_KEY],
- False,
+ False, # noqa: FBT003
)
yield from self._split_per_slice_range(
self._calculate_lower_boundary_of_last_slice(
self.state["slices"][-1][self._connector_state_converter.END_KEY]
),
self._end_provider(),
- True,
+ True, # noqa: FBT003
)
else:
raise ValueError("Expected at least one slice")
@@ -398,7 +388,10 @@ def _calculate_lower_boundary_of_last_slice(
return lower_boundary
def _split_per_slice_range(
- self, lower: CursorValueType, upper: CursorValueType, upper_is_end: bool
+ self,
+ lower: CursorValueType,
+ upper: CursorValueType,
+ upper_is_end: bool, # noqa: FBT001
) -> Iterable[StreamSlice]:
if lower >= upper:
return
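To visualize what this cursor maintains: the concurrency-compatible state is a list of closed intervals under "slices", which _merge_partitions collapses and stream_slices uses to emit only the ranges still left to read. An illustrative shape for an epoch-based cursor (values are made up; surrounding keys depend on the state converter):

concurrent_state = {
    "state_type": "date-range",
    "slices": [
        {"start": 0, "end": 1_700_000_000, "most_recent_cursor_value": 1_699_999_950},
        {"start": 1_702_000_000, "end": 1_703_000_000, "most_recent_cursor_value": 1_702_999_990},
    ],
}
# stream_slices() would then only generate slices for the gap between the two intervals
# and for the tail between 1_703_000_000 and end_provider().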
diff --git a/airbyte_cdk/sources/streams/concurrent/default_stream.py b/airbyte_cdk/sources/streams/concurrent/default_stream.py
index 7679a1eb6..79aa79404 100644
--- a/airbyte_cdk/sources/streams/concurrent/default_stream.py
+++ b/airbyte_cdk/sources/streams/concurrent/default_stream.py
@@ -2,9 +2,10 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from functools import lru_cache
+from collections.abc import Iterable, Mapping
+from functools import cache
from logging import Logger
-from typing import Any, Iterable, List, Mapping, Optional
+from typing import Any
from airbyte_cdk.models import AirbyteStream, SyncMode
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
@@ -18,17 +19,17 @@
class DefaultStream(AbstractStream):
- def __init__(
+ def __init__( # noqa: PLR0913, PLR0917
self,
partition_generator: PartitionGenerator,
name: str,
json_schema: Mapping[str, Any],
availability_strategy: AbstractAvailabilityStrategy,
- primary_key: List[str],
- cursor_field: Optional[str],
+ primary_key: list[str],
+ cursor_field: str | None,
logger: Logger,
cursor: Cursor,
- namespace: Optional[str] = None,
+ namespace: str | None = None,
) -> None:
self._stream_partition_generator = partition_generator
self._name = name
@@ -48,17 +49,17 @@ def name(self) -> str:
return self._name
@property
- def namespace(self) -> Optional[str]:
+ def namespace(self) -> str | None:
return self._namespace
def check_availability(self) -> StreamAvailability:
return self._availability_strategy.check_availability(self._logger)
@property
- def cursor_field(self) -> Optional[str]:
+ def cursor_field(self) -> str | None:
return self._cursor_field
- @lru_cache(maxsize=None)
+ @cache # noqa: B019
def get_json_schema(self) -> Mapping[str, Any]:
return self._json_schema
diff --git a/airbyte_cdk/sources/streams/concurrent/exceptions.py b/airbyte_cdk/sources/streams/concurrent/exceptions.py
index a0cf699a4..691ba077b 100644
--- a/airbyte_cdk/sources/streams/concurrent/exceptions.py
+++ b/airbyte_cdk/sources/streams/concurrent/exceptions.py
@@ -10,7 +10,7 @@ class ExceptionWithDisplayMessage(Exception):
Exception that can be used to display a custom message to the user.
"""
- def __init__(self, display_message: str, **kwargs: Any):
+ def __init__(self, display_message: str, **kwargs: Any): # noqa: ANN204, ANN401
super().__init__(**kwargs)
self.display_message = display_message
diff --git a/airbyte_cdk/sources/streams/concurrent/helpers.py b/airbyte_cdk/sources/streams/concurrent/helpers.py
index 5e2edf055..8ea0e3201 100644
--- a/airbyte_cdk/sources/streams/concurrent/helpers.py
+++ b/airbyte_cdk/sources/streams/concurrent/helpers.py
@@ -1,18 +1,17 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
-from typing import List, Optional, Union
from airbyte_cdk.sources.streams import Stream
def get_primary_key_from_stream(
- stream_primary_key: Optional[Union[str, List[str], List[List[str]]]],
-) -> List[str]:
+ stream_primary_key: str | list[str] | list[list[str]] | None,
+) -> list[str]:
if stream_primary_key is None:
return []
- elif isinstance(stream_primary_key, str):
+ if isinstance(stream_primary_key, str):
return [stream_primary_key]
- elif isinstance(stream_primary_key, list):
+ if isinstance(stream_primary_key, list):
are_all_elements_str = all(isinstance(k, str) for k in stream_primary_key)
are_all_elements_list_of_size_one = all(
isinstance(k, list) and len(k) == 1 for k in stream_primary_key
@@ -20,23 +19,19 @@ def get_primary_key_from_stream(
if are_all_elements_str:
return stream_primary_key # type: ignore # We verified all items in the list are strings
- elif are_all_elements_list_of_size_one:
- return list(map(lambda x: x[0], stream_primary_key))
- else:
- raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}")
- else:
- raise ValueError(f"Invalid type for primary key: {stream_primary_key}")
+ if are_all_elements_list_of_size_one:
+ return list(map(lambda x: x[0], stream_primary_key)) # noqa: C417, FURB118
+ raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}")
+ raise ValueError(f"Invalid type for primary key: {stream_primary_key}")
-def get_cursor_field_from_stream(stream: Stream) -> Optional[str]:
+def get_cursor_field_from_stream(stream: Stream) -> str | None:
if isinstance(stream.cursor_field, list):
if len(stream.cursor_field) > 1:
raise ValueError(
f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}"
)
- elif len(stream.cursor_field) == 0:
+ if len(stream.cursor_field) == 0:
return None
- else:
- return stream.cursor_field[0]
- else:
- return stream.cursor_field
+ return stream.cursor_field[0]
+ return stream.cursor_field
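The normalization performed by get_primary_key_from_stream is easiest to see with the accepted shapes side by side:

from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream

assert get_primary_key_from_stream(None) == []
assert get_primary_key_from_stream("id") == ["id"]
assert get_primary_key_from_stream(["account_id", "record_id"]) == ["account_id", "record_id"]
assert get_primary_key_from_stream([["account_id"], ["record_id"]]) == ["account_id", "record_id"]
# Nested composite keys such as [["user", "id"]] raise a ValueError.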
diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/partition.py b/airbyte_cdk/sources/streams/concurrent/partitions/partition.py
index 8391a5a2b..b3f1cda50 100644
--- a/airbyte_cdk/sources/streams/concurrent/partitions/partition.py
+++ b/airbyte_cdk/sources/streams/concurrent/partitions/partition.py
@@ -3,7 +3,8 @@
#
from abc import ABC, abstractmethod
-from typing import Any, Iterable, Mapping, Optional
+from collections.abc import Iterable, Mapping
+from typing import Any
from airbyte_cdk.sources.types import Record
@@ -22,7 +23,7 @@ def read(self) -> Iterable[Record]:
pass
@abstractmethod
- def to_slice(self) -> Optional[Mapping[str, Any]]:
+ def to_slice(self) -> Mapping[str, Any] | None:
"""
Converts the partition to a slice that can be serialized and deserialized.
diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py b/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py
index eff978564..f4bd77bd9 100644
--- a/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py
+++ b/airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py
@@ -3,7 +3,7 @@
#
from abc import ABC, abstractmethod
-from typing import Iterable
+from collections.abc import Iterable
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py b/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py
index 98ac04ed7..e10fada49 100644
--- a/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py
+++ b/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from abc import ABC, abstractmethod
-from typing import Iterable
+from collections.abc import Iterable
from airbyte_cdk.sources.types import StreamSlice
diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte_cdk/sources/streams/concurrent/partitions/types.py
index 77644c6b9..c645f3b46 100644
--- a/airbyte_cdk/sources/streams/concurrent/partitions/types.py
+++ b/airbyte_cdk/sources/streams/concurrent/partitions/types.py
@@ -1,8 +1,8 @@
-#
+# # noqa: A005
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Any, Union
+from typing import Union
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import (
PartitionGenerationCompletedSentinel,
@@ -11,20 +11,20 @@
from airbyte_cdk.sources.types import Record
-class PartitionCompleteSentinel:
+class PartitionCompleteSentinel: # noqa: PLW1641
"""
A sentinel object indicating all records for a partition were produced.
Includes a pointer to the partition that was processed.
"""
- def __init__(self, partition: Partition, is_successful: bool = True):
+ def __init__(self, partition: Partition, is_successful: bool = True): # noqa: ANN204, FBT001, FBT002
"""
:param partition: The partition that was processed
"""
self.partition = partition
self.is_successful = is_successful
- def __eq__(self, other: Any) -> bool:
+ def __eq__(self, other: object) -> bool:
if isinstance(other, PartitionCompleteSentinel):
return self.partition == other.partition
return False
@@ -33,6 +33,6 @@ def __eq__(self, other: Any) -> bool:
"""
Typedef representing the items that can be added to the ThreadBasedConcurrentStream
"""
-QueueItem = Union[
+QueueItem = Union[ # noqa: UP007
Record, Partition, PartitionCompleteSentinel, PartitionGenerationCompletedSentinel, Exception
]
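An illustrative helper showing how a consumer of the concurrent queue is expected to discriminate between the QueueItem variants defined above:

from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import (
    PartitionGenerationCompletedSentinel,
)
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.types import (
    PartitionCompleteSentinel,
    QueueItem,
)
from airbyte_cdk.sources.types import Record

def describe(item: QueueItem) -> str:
    # Records and partitions are work items; the sentinels signal completion; exceptions
    # are re-raised so the failure surfaces on the consuming thread.
    if isinstance(item, Record):
        return "record to emit"
    if isinstance(item, Partition):
        return "partition waiting to be read"
    if isinstance(item, PartitionCompleteSentinel):
        return f"partition finished (successful={item.is_successful})"
    if isinstance(item, PartitionGenerationCompletedSentinel):
        return "no more partitions will be generated for this stream"
    if isinstance(item, Exception):
        raise item
    return "unknown item"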
diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py
index 987915317..78787964b 100644
--- a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py
+++ b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py
@@ -3,8 +3,10 @@
#
from abc import ABC, abstractmethod
+from collections.abc import MutableMapping
from enum import Enum
-from typing import TYPE_CHECKING, Any, List, MutableMapping, Optional, Tuple
+from typing import TYPE_CHECKING, Any
+
if TYPE_CHECKING:
from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
@@ -20,14 +22,14 @@ class AbstractStreamStateConverter(ABC):
MOST_RECENT_RECORD_KEY = "most_recent_cursor_value"
@abstractmethod
- def _from_state_message(self, value: Any) -> Any:
+ def _from_state_message(self, value: Any) -> Any: # noqa: ANN401
pass
@abstractmethod
- def _to_state_message(self, value: Any) -> Any:
+ def _to_state_message(self, value: Any) -> Any: # noqa: ANN401
pass
- def __init__(self, is_sequential_state: bool = True):
+ def __init__(self, is_sequential_state: bool = True): # noqa: ANN204, FBT001, FBT002
self._is_sequential_state = is_sequential_state
def convert_to_state_message(
@@ -43,14 +45,13 @@ def convert_to_state_message(
legacy_state = stream_state.get("legacy", {})
latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", []))
if latest_complete_time is not None:
- legacy_state.update(
- {cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)}
- )
+ legacy_state.update({
+ cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)
+ })
return legacy_state or {}
- else:
- return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range)
+ return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range)
- def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> Any:
+ def _get_latest_complete_time(self, slices: list[MutableMapping[str, Any]]) -> Any: # noqa: ANN401
"""
Get the latest time before which all records have been processed.
"""
@@ -102,8 +103,8 @@ def convert_from_sequential_state(
self,
cursor_field: "CursorField", # to deprecate as it is only needed for sequential state
stream_state: MutableMapping[str, Any],
- start: Optional[Any],
- ) -> Tuple[Any, MutableMapping[str, Any]]:
+ start: Any | None, # noqa: ANN401
+ ) -> tuple[Any, MutableMapping[str, Any]]:
"""
Convert the state message to the format required by the ConcurrentCursor.
@@ -118,22 +119,22 @@ def convert_from_sequential_state(
...
@abstractmethod
- def increment(self, value: Any) -> Any:
+ def increment(self, value: Any) -> Any: # noqa: ANN401
"""
Increment a timestamp by a single unit.
"""
...
@abstractmethod
- def output_format(self, value: Any) -> Any:
+ def output_format(self, value: Any) -> Any: # noqa: ANN401
"""
Convert the cursor value type to a JSON valid type.
"""
...
def merge_intervals(
- self, intervals: List[MutableMapping[str, Any]]
- ) -> List[MutableMapping[str, Any]]:
+ self, intervals: list[MutableMapping[str, Any]]
+ ) -> list[MutableMapping[str, Any]]:
"""
Compute and return a list of merged intervals.
@@ -144,7 +145,8 @@ def merge_intervals(
return []
sorted_intervals = sorted(
- intervals, key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY])
+ intervals,
+ key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY]), # noqa: FURB118
)
merged_intervals = [sorted_intervals[0]]
@@ -170,7 +172,7 @@ def merge_intervals(
return merged_intervals
@abstractmethod
- def parse_value(self, value: Any) -> Any:
+ def parse_value(self, value: Any) -> Any: # noqa: ANN401
"""
Parse the value of the cursor field into a comparable value.
"""
@@ -178,4 +180,4 @@ def parse_value(self, value: Any) -> Any:
@property
@abstractmethod
- def zero_value(self) -> Any: ...
+ def zero_value(self) -> Any: ... # noqa: ANN401
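The merge_intervals contract is easiest to see with a concrete converter. The sketch below assumes the epoch-based converter from the next file, whose increment() step is one second, so intervals that touch within a second collapse into one:

import pendulum

from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
    EpochValueConcurrentStreamStateConverter,
)

converter = EpochValueConcurrentStreamStateConverter()
merged = converter.merge_intervals([
    {"start": pendulum.datetime(2024, 1, 1), "end": pendulum.datetime(2024, 1, 10)},
    # Starts exactly one increment after the previous end, so it merges with it.
    {"start": pendulum.datetime(2024, 1, 10, 0, 0, 1), "end": pendulum.datetime(2024, 1, 20)},
    {"start": pendulum.datetime(2024, 2, 1), "end": pendulum.datetime(2024, 2, 5)},
])
assert len(merged) == 2  # the first two intervals are contiguous and collapse into one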
diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py
index 3f53a9234..d3a47d89f 100644
--- a/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py
+++ b/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py
@@ -3,13 +3,14 @@
#
from abc import abstractmethod
+from collections.abc import Callable, MutableMapping
from datetime import datetime, timedelta, timezone
-from typing import Any, Callable, List, MutableMapping, Optional, Tuple
+from typing import Any
import pendulum
from pendulum.datetime import DateTime
-# FIXME We would eventually like the Concurrent package to be agnostic of the declarative package. However, this is a breaking change and
+# FIXME We would eventually like the Concurrent package to be agnostic of the declarative package. However, this is a breaking change and # noqa: FIX001, TD001, TD004
# the goal in the short term is only to fix the issue we are seeing for source-declarative-manifest.
from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser
from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
@@ -20,15 +21,15 @@
class DateTimeStreamStateConverter(AbstractStreamStateConverter):
- def _from_state_message(self, value: Any) -> Any:
+ def _from_state_message(self, value: Any) -> Any: # noqa: ANN401
return self.parse_timestamp(value)
- def _to_state_message(self, value: Any) -> Any:
+ def _to_state_message(self, value: Any) -> Any: # noqa: ANN401
return self.output_format(value)
@property
@abstractmethod
- def _zero_value(self) -> Any: ...
+ def _zero_value(self) -> Any: ... # noqa: ANN401
@property
def zero_value(self) -> datetime:
@@ -42,26 +43,26 @@ def get_end_provider(cls) -> Callable[[], datetime]:
def increment(self, timestamp: datetime) -> datetime: ...
@abstractmethod
- def parse_timestamp(self, timestamp: Any) -> datetime: ...
+ def parse_timestamp(self, timestamp: Any) -> datetime: ... # noqa: ANN401
@abstractmethod
- def output_format(self, timestamp: datetime) -> Any: ...
+ def output_format(self, timestamp: datetime) -> Any: ... # noqa: ANN401
- def parse_value(self, value: Any) -> Any:
+ def parse_value(self, value: Any) -> Any: # noqa: ANN401
"""
Parse the value of the cursor field into a comparable value.
"""
return self.parse_timestamp(value)
- def _compare_intervals(self, end_time: Any, start_time: Any) -> bool:
+ def _compare_intervals(self, end_time: Any, start_time: Any) -> bool: # noqa: ANN401
return bool(self.increment(end_time) >= start_time)
def convert_from_sequential_state(
self,
cursor_field: CursorField,
stream_state: MutableMapping[str, Any],
- start: Optional[datetime],
- ) -> Tuple[datetime, MutableMapping[str, Any]]:
+ start: datetime | None,
+ ) -> tuple[datetime, MutableMapping[str, Any]]:
"""
Convert the state message to the format required by the ConcurrentCursor.
@@ -99,7 +100,7 @@ def _get_sync_start(
self,
cursor_field: CursorField,
stream_state: MutableMapping[str, Any],
- start: Optional[datetime],
+ start: datetime | None,
) -> datetime:
sync_start = start if start is not None else self.zero_value
prev_sync_low_water_mark = (
@@ -109,8 +110,7 @@ def _get_sync_start(
)
if prev_sync_low_water_mark and prev_sync_low_water_mark >= sync_start:
return prev_sync_low_water_mark
- else:
- return sync_start
+ return sync_start
class EpochValueConcurrentStreamStateConverter(DateTimeStreamStateConverter):
@@ -138,7 +138,7 @@ def output_format(self, timestamp: datetime) -> int:
def parse_timestamp(self, timestamp: int) -> datetime:
dt_object = pendulum.from_timestamp(timestamp)
if not isinstance(dt_object, DateTime):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"DateTime object was expected but got {type(dt_object)} from pendulum.parse({timestamp})"
)
return dt_object
@@ -160,8 +160,11 @@ class IsoMillisConcurrentStreamStateConverter(DateTimeStreamStateConverter):
_zero_value = "0001-01-01T00:00:00.000Z"
- def __init__(
- self, is_sequential_state: bool = True, cursor_granularity: Optional[timedelta] = None
+ def __init__( # noqa: ANN204
+ self,
+ *,
+ is_sequential_state: bool = True,
+ cursor_granularity: timedelta | None = None,
):
super().__init__(is_sequential_state=is_sequential_state)
self._cursor_granularity = cursor_granularity or timedelta(milliseconds=1)
@@ -169,13 +172,13 @@ def __init__(
def increment(self, timestamp: datetime) -> datetime:
return timestamp + self._cursor_granularity
- def output_format(self, timestamp: datetime) -> Any:
+ def output_format(self, timestamp: datetime) -> Any: # noqa: ANN401
return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def parse_timestamp(self, timestamp: str) -> datetime:
dt_object = pendulum.parse(timestamp)
if not isinstance(dt_object, DateTime):
- raise ValueError(
+ raise ValueError( # noqa: TRY004
f"DateTime object was expected but got {type(dt_object)} from pendulum.parse({timestamp})"
)
return dt_object
@@ -187,18 +190,18 @@ class CustomFormatConcurrentStreamStateConverter(IsoMillisConcurrentStreamStateC
incoming state in any valid datetime format via Pendulum.
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
datetime_format: str,
- input_datetime_formats: Optional[List[str]] = None,
- is_sequential_state: bool = True,
- cursor_granularity: Optional[timedelta] = None,
+ input_datetime_formats: list[str] | None = None,
+ is_sequential_state: bool = True, # noqa: FBT001, FBT002
+ cursor_granularity: timedelta | None = None,
):
super().__init__(
is_sequential_state=is_sequential_state, cursor_granularity=cursor_granularity
)
self._datetime_format = datetime_format
- self._input_datetime_formats = input_datetime_formats if input_datetime_formats else []
+ self._input_datetime_formats = input_datetime_formats if input_datetime_formats else [] # noqa: FURB110
self._input_datetime_formats += [self._datetime_format]
self._parser = DatetimeParser()
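A quick round trip through the two built-in formats, to make the parse_timestamp/output_format pairing concrete:

import pendulum

from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
    EpochValueConcurrentStreamStateConverter,
    IsoMillisConcurrentStreamStateConverter,
)

epoch_converter = EpochValueConcurrentStreamStateConverter()
iso_converter = IsoMillisConcurrentStreamStateConverter()

# Epoch seconds survive a parse/format round trip unchanged.
assert epoch_converter.output_format(epoch_converter.parse_timestamp(1717245045)) == 1717245045
# The ISO converter renders millisecond precision with a trailing "Z".
dt = pendulum.datetime(2024, 6, 1, 12, 30, 45, 123000)
assert iso_converter.output_format(dt) == "2024-06-01T12:30:45.123Z"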
diff --git a/airbyte_cdk/sources/streams/core.py b/airbyte_cdk/sources/streams/core.py
index a9aa8550a..dc814a2a7 100644
--- a/airbyte_cdk/sources/streams/core.py
+++ b/airbyte_cdk/sources/streams/core.py
@@ -6,13 +6,13 @@
import itertools
import logging
from abc import ABC, abstractmethod
+from collections.abc import Iterable, Iterator, Mapping, MutableMapping
from dataclasses import dataclass
-from functools import cached_property, lru_cache
-from typing import Any, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Union
+from functools import cache, cached_property
+from typing import Any, Union
from typing_extensions import deprecated
-import airbyte_cdk.sources.utils.casing as casing
from airbyte_cdk.models import (
AirbyteMessage,
AirbyteStream,
@@ -32,16 +32,18 @@
ResumableFullRefreshCheckpointReader,
)
from airbyte_cdk.sources.types import StreamSlice
+from airbyte_cdk.sources.utils import casing
# list of all possible HTTP methods which can be used for sending of request bodies
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, ResourceSchemaLoader
from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
+
# A stream's read method can return one of the following types:
# Mapping[str, Any]: The content of an AirbyteRecordMessage
# AirbyteMessage: An AirbyteMessage. Could be of any type
-StreamData = Union[Mapping[str, Any], AirbyteMessage]
+StreamData = Union[Mapping[str, Any], AirbyteMessage] # noqa: UP007
JsonSchema = Mapping[str, Any]
@@ -53,8 +55,7 @@ def package_name_from_class(cls: object) -> str:
module = inspect.getmodule(cls)
if module is not None:
return module.__name__.split(".")[0]
- else:
- raise ValueError(f"Could not find package name for class {cls}")
+ raise ValueError(f"Could not find package name for class {cls}")
class CheckpointMixin(ABC):
@@ -115,12 +116,12 @@ class StreamClassification:
has_multiple_slices: bool
-class Stream(ABC):
+class Stream(ABC): # noqa: PLR0904
"""
Base abstract class for an Airbyte Stream. Makes no assumption of the Stream's underlying transport protocol.
"""
- _configured_json_schema: Optional[Dict[str, Any]] = None
+ _configured_json_schema: dict[str, Any] | None = None
_exit_on_rate_limit: bool = False
# Use self.logger in subclasses to log any messages
@@ -131,7 +132,7 @@ def logger(self) -> logging.Logger:
# TypeTransformer object to perform output data transformation
transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform)
- cursor: Optional[Cursor] = None
+ cursor: Cursor | None = None
has_multiple_slices = False
@@ -142,7 +143,7 @@ def name(self) -> str:
"""
return casing.camel_to_snake(self.__class__.__name__)
- def get_error_display_message(self, exception: BaseException) -> Optional[str]:
+ def get_error_display_message(self, exception: BaseException) -> str | None: # noqa: ARG002
"""
Retrieves the user-friendly display message that corresponds to an exception.
This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage.
@@ -160,7 +161,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
logger: logging.Logger,
slice_logger: SliceLogger,
stream_state: MutableMapping[str, Any],
- state_manager,
+ state_manager, # noqa: ANN001
internal_config: InternalConfig,
) -> Iterable[StreamData]:
sync_mode = configured_stream.sync_mode
@@ -171,7 +172,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
# opposed to the incoming stream_state value. Because some connectors like ones using the file-based CDK modify
# state before setting the value on the Stream attribute, the most up-to-date state is derived from Stream.state
# instead of the stream_state parameter. This does not apply to legacy connectors using get_updated_state().
- try:
+ try: # noqa: SIM105
stream_state = self.state # type: ignore # we know the field might not exist...
except AttributeError:
pass
@@ -188,7 +189,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
if slice_logger.should_log_slice_message(logger):
yield slice_logger.create_slice_log_message(next_slice)
records = self.read_records(
- sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior
+ sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior
stream_slice=next_slice,
stream_state=stream_state,
cursor_field=cursor_field or None,
@@ -252,7 +253,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
airbyte_state_message = self._checkpoint_state(checkpoint, state_manager=state_manager)
yield airbyte_state_message
- def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterable[StreamData]:
+ def read_only_records(self, state: Mapping[str, Any] | None = None) -> Iterable[StreamData]:
"""
Helper method that performs a read on a stream with an optional state and emits records. If the parent stream supports
incremental, this operation does not update the stream's internal state (if it uses the modern state setter/getter)
@@ -284,15 +285,15 @@ def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterab
def read_records(
self,
sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_slice: Optional[Mapping[str, Any]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
+ cursor_field: list[str] | None = None,
+ stream_slice: Mapping[str, Any] | None = None,
+ stream_state: Mapping[str, Any] | None = None,
) -> Iterable[StreamData]:
"""
This method should be overridden by subclasses to read records based on the inputs
"""
- @lru_cache(maxsize=None)
+ @cache # noqa: B019
def get_json_schema(self) -> Mapping[str, Any]:
"""
:return: A dict of the JSON schema representing this stream.
@@ -300,7 +301,7 @@ def get_json_schema(self) -> Mapping[str, Any]:
The default implementation of this method looks for a JSONSchema file with the same name as this stream's "name" property.
Override as needed.
"""
- # TODO show an example of using pydantic to define the JSON schema, or reading an OpenAPI spec
+ # TODO show an example of using pydantic to define the JSON schema, or reading an OpenAPI spec # noqa: TD004
return ResourceSchemaLoader(package_name_from_class(self.__class__)).get_schema(self.name)
def as_airbyte_stream(self) -> AirbyteStream:
@@ -348,19 +349,18 @@ def is_resumable(self) -> bool:
# to structure stream state in a very specific way. We also can't check for issubclass(HttpSubStream) because
# not all substreams implement the interface and it would be a circular dependency so we use parent as a surrogate
return False
- elif hasattr(type(self), "state") and getattr(type(self), "state").fset is not None:
+ if hasattr(type(self), "state") and type(self).state.fset is not None:
# Modern case where a stream manages state using getter/setter
return True
- else:
- # Legacy case where the CDK manages state via the get_updated_state() method. This is determined by checking if
- # the stream's get_updated_state() differs from the Stream class and therefore has been overridden
- return type(self).get_updated_state != Stream.get_updated_state
+ # Legacy case where the CDK manages state via the get_updated_state() method. This is determined by checking if
+ # the stream's get_updated_state() differs from the Stream class and therefore has been overridden
+ return type(self).get_updated_state != Stream.get_updated_state
- def _wrapped_cursor_field(self) -> List[str]:
+ def _wrapped_cursor_field(self) -> list[str]:
return [self.cursor_field] if isinstance(self.cursor_field, str) else self.cursor_field
@property
- def cursor_field(self) -> Union[str, List[str]]:
+ def cursor_field(self) -> str | list[str]:
"""
Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field.
:return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor.
@@ -368,7 +368,7 @@ def cursor_field(self) -> Union[str, List[str]]:
return []
@property
- def namespace(self) -> Optional[str]:
+ def namespace(self) -> str | None:
"""
Override to return the namespace of this stream, e.g. the Postgres schema which this stream will emit records for.
:return: A string containing the name of the namespace.
@@ -394,7 +394,7 @@ def exit_on_rate_limit(self, value: bool) -> None:
@property
@abstractmethod
- def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+ def primary_key(self) -> str | list[str] | list[list[str]] | None:
"""
:return: string if single primary key, list of strings if composite primary key, list of list of strings if composite primary key consisting of nested fields.
If the stream has no primary keys, return None.
@@ -403,10 +403,10 @@ def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
def stream_slices(
self,
*,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
- ) -> Iterable[Optional[Mapping[str, Any]]]:
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_state: Mapping[str, Any] | None = None, # noqa: ARG002
+ ) -> Iterable[Mapping[str, Any] | None]:
"""
Override to define the slices for this stream. See the stream slicing section of the docs for more information.
@@ -418,7 +418,7 @@ def stream_slices(
yield StreamSlice(partition={}, cursor_slice={})
@property
- def state_checkpoint_interval(self) -> Optional[int]:
+ def state_checkpoint_interval(self) -> int | None:
"""
Decides how often to checkpoint state (i.e: emit a STATE message). E.g: if this returns a value of 100, then state is persisted after reading
100 records, then 200, 300, etc.. A good default value is 1000 although your mileage may vary depending on the underlying data source.
@@ -438,7 +438,9 @@ def state_checkpoint_interval(self) -> Optional[int]:
# "Please use explicit state property instead, see `IncrementalMixin` docs."
# )
def get_updated_state(
- self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]
+ self,
+ current_stream_state: MutableMapping[str, Any], # noqa: ARG002
+ latest_record: Mapping[str, Any], # noqa: ARG002
) -> MutableMapping[str, Any]:
"""DEPRECATED. Please use explicit state property instead, see `IncrementalMixin` docs.
@@ -455,7 +457,7 @@ def get_updated_state(
"""
return {}
- def get_cursor(self) -> Optional[Cursor]:
+ def get_cursor(self) -> Cursor | None:
"""
A Cursor is an interface that a stream can implement to manage how its internal state is read and updated while
reading records. Historically, Python connectors had no concept of a cursor to manage state. Python streams need
@@ -465,14 +467,14 @@ def get_cursor(self) -> Optional[Cursor]:
def _get_checkpoint_reader(
self,
- logger: logging.Logger,
- cursor_field: Optional[List[str]],
+ logger: logging.Logger, # noqa: ARG002
+ cursor_field: list[str] | None,
sync_mode: SyncMode,
stream_state: MutableMapping[str, Any],
) -> CheckpointReader:
mappings_or_slices = self.stream_slices(
cursor_field=cursor_field,
- sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior
+ sync_mode=sync_mode, # TODO: change this interface to no longer rely on sync_mode for behavior
stream_state=stream_state,
)
@@ -505,35 +507,33 @@ def _get_checkpoint_reader(
return LegacyCursorBasedCheckpointReader(
stream_slices=slices_iterable_copy, cursor=cursor, read_state_from_cursor=True
)
- elif cursor:
+ if cursor:
return CursorBasedCheckpointReader(
stream_slices=slices_iterable_copy,
cursor=cursor,
read_state_from_cursor=checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH,
)
- elif checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH:
+ if checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH:
# Resumable full refresh readers rely on the stream state dynamically being updated during pagination and do
# not iterate over a static set of slices.
return ResumableFullRefreshCheckpointReader(stream_state=stream_state)
- elif checkpoint_mode == CheckpointMode.INCREMENTAL:
+ if checkpoint_mode == CheckpointMode.INCREMENTAL:
return IncrementalCheckpointReader(
stream_slices=slices_iterable_copy, stream_state=stream_state
)
- else:
- return FullRefreshCheckpointReader(stream_slices=slices_iterable_copy)
+ return FullRefreshCheckpointReader(stream_slices=slices_iterable_copy)
@property
def _checkpoint_mode(self) -> CheckpointMode:
if self.is_resumable and len(self._wrapped_cursor_field()) > 0:
return CheckpointMode.INCREMENTAL
- elif self.is_resumable:
+ if self.is_resumable:
return CheckpointMode.RESUMABLE_FULL_REFRESH
- else:
- return CheckpointMode.FULL_REFRESH
+ return CheckpointMode.FULL_REFRESH
@staticmethod
def _classify_stream(
- mappings_or_slices: Iterator[Optional[Union[Mapping[str, Any], StreamSlice]]],
+ mappings_or_slices: Iterator[Mapping[str, Any] | StreamSlice | None],
) -> StreamClassification:
"""
This is a bit of a crazy solution, but also the only way we can detect certain attributes about the stream since Python
@@ -601,8 +601,8 @@ def log_stream_sync_configuration(self) -> None:
@staticmethod
def _wrapped_primary_key(
- keys: Optional[Union[str, List[str], List[List[str]]]],
- ) -> Optional[List[List[str]]]:
+ keys: str | list[str] | list[list[str]] | None,
+ ) -> list[list[str]] | None:
"""
:return: wrap the primary_key property in a list of list of strings required by the Airbyte Stream object.
"""
@@ -611,7 +611,7 @@ def _wrapped_primary_key(
if isinstance(keys, str):
return [[keys]]
- elif isinstance(keys, list):
+ if isinstance(keys, list):
wrapped_keys = []
for component in keys:
if isinstance(component, str):
@@ -619,13 +619,12 @@ def _wrapped_primary_key(
elif isinstance(component, list):
wrapped_keys.append(component)
else:
- raise ValueError(f"Element must be either list or str. Got: {type(component)}")
+ raise ValueError(f"Element must be either list or str. Got: {type(component)}") # noqa: TRY004
return wrapped_keys
- else:
- raise ValueError(f"Element must be either list or str. Got: {type(keys)}")
+ raise ValueError(f"Element must be either list or str. Got: {type(keys)}")
def _observe_state(
- self, checkpoint_reader: CheckpointReader, stream_state: Optional[Mapping[str, Any]] = None
+ self, checkpoint_reader: CheckpointReader, stream_state: Mapping[str, Any] | None = None
) -> None:
"""
Convenience method that attempts to read the Stream's state using the recommended way of connector's managing their
@@ -652,15 +651,15 @@ def _observe_state(
def _checkpoint_state( # type: ignore # ignoring typing for ConnectorStateManager because of circular dependencies
self,
stream_state: Mapping[str, Any],
- state_manager,
+ state_manager, # noqa: ANN001
) -> AirbyteMessage:
- # todo: This can be consolidated into one ConnectorStateManager.update_and_create_state_message() method, but I want
+ # TODO: This can be consolidated into one ConnectorStateManager.update_and_create_state_message() method, but I want
# to reduce changes right now and this would span concurrent as well
state_manager.update_state_for_stream(self.name, self.namespace, stream_state)
return state_manager.create_state_message(self.name, self.namespace) # type: ignore [no-any-return]
@property
- def configured_json_schema(self) -> Optional[Dict[str, Any]]:
+ def configured_json_schema(self) -> dict[str, Any] | None:
"""
This property is set from the read method.
@@ -669,12 +668,12 @@ def configured_json_schema(self) -> Optional[Dict[str, Any]]:
return self._configured_json_schema
@configured_json_schema.setter
- def configured_json_schema(self, json_schema: Dict[str, Any]) -> None:
+ def configured_json_schema(self, json_schema: dict[str, Any]) -> None:
self._configured_json_schema = self._filter_schema_invalid_properties(json_schema)
def _filter_schema_invalid_properties(
- self, configured_catalog_json_schema: Dict[str, Any]
- ) -> Dict[str, Any]:
+ self, configured_catalog_json_schema: dict[str, Any]
+ ) -> dict[str, Any]:
"""
Filters the properties in json_schema that are not present in the stream schema.
Configured schemas can contain very old fields, so we need to do some housekeeping ourselves.
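For context on the primary-key wrapping rule touched above, here is a minimal standalone sketch; the wrap_primary_key helper and the sample keys are illustrative assumptions, not CDK code, but the normalization mirrors the logic shown in _wrapped_primary_key:

# Sketch of the wrapping rule: normalize a key spec into the list-of-lists shape
# required by the Airbyte Stream object.
def wrap_primary_key(keys: str | list[str] | list[list[str]] | None) -> list[list[str]] | None:
    if not keys:
        return None
    if isinstance(keys, str):
        return [[keys]]  # "id" -> [["id"]]
    if isinstance(keys, list):
        wrapped = []
        for component in keys:
            if isinstance(component, str):
                wrapped.append([component])  # ["account_id", "id"] -> [["account_id"], ["id"]]
            elif isinstance(component, list):
                wrapped.append(component)  # nested field paths are kept as-is
            else:
                raise ValueError(f"Element must be either list or str. Got: {type(component)}")
        return wrapped
    raise ValueError(f"Element must be either list or str. Got: {type(keys)}")

assert wrap_primary_key("id") == [["id"]]
assert wrap_primary_key(["account_id", "id"]) == [["account_id"], ["id"]]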
diff --git a/airbyte_cdk/sources/streams/http/__init__.py b/airbyte_cdk/sources/streams/http/__init__.py
index 74804614c..70342f5d0 100644
--- a/airbyte_cdk/sources/streams/http/__init__.py
+++ b/airbyte_cdk/sources/streams/http/__init__.py
@@ -7,4 +7,5 @@
from .http import HttpStream, HttpSubStream
from .http_client import HttpClient
+
__all__ = ["HttpClient", "HttpStream", "HttpSubStream", "UserDefinedBackoffException"]
diff --git a/airbyte_cdk/sources/streams/http/availability_strategy.py b/airbyte_cdk/sources/streams/http/availability_strategy.py
index 494fcf151..8adcc89e1 100644
--- a/airbyte_cdk/sources/streams/http/availability_strategy.py
+++ b/airbyte_cdk/sources/streams/http/availability_strategy.py
@@ -4,20 +4,24 @@
import logging
import typing
-from typing import Optional, Tuple
+from typing import Optional
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
if typing.TYPE_CHECKING:
from airbyte_cdk.sources import Source
class HttpAvailabilityStrategy(AvailabilityStrategy):
def check_availability(
- self, stream: Stream, logger: logging.Logger, source: Optional["Source"] = None
- ) -> Tuple[bool, Optional[str]]:
+ self,
+ stream: Stream,
+ logger: logging.Logger,
+ source: Optional["Source"] = None, # noqa: ARG002
+ ) -> tuple[bool, str | None]:
"""
Check stream availability by attempting to read the first record of the
stream.
@@ -30,7 +34,7 @@ def check_availability(
for some reason and the str should describe what went wrong and how to
resolve the unavailability, if possible.
"""
- reason: Optional[str]
+ reason: str | None
try:
# Some streams need a stream slice to read records (e.g. if they have a SubstreamPartitionRouter)
# Streams that don't need a stream slice will return `None` as their first stream slice.
@@ -46,7 +50,7 @@ def check_availability(
try:
self.get_first_record_for_slice(stream, stream_slice)
- return True, None
+ return True, None # noqa: TRY300
except StopIteration:
logger.info(f"Successfully connected to stream {stream.name}, but got 0 records.")
return True, None
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/__init__.py b/airbyte_cdk/sources/streams/http/error_handlers/__init__.py
index 1f97d5cc7..9e65db1d6 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/__init__.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/__init__.py
@@ -10,6 +10,7 @@
from .json_error_message_parser import JsonErrorMessageParser
from .response_models import ErrorResolution, ResponseAction
+
__all__ = [
"BackoffStrategy",
"DefaultBackoffStrategy",
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py b/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py
index 6ed821791..c30992627 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/backoff_strategy.py
@@ -3,7 +3,6 @@
#
from abc import ABC, abstractmethod
-from typing import Optional, Union
import requests
@@ -12,9 +11,9 @@ class BackoffStrategy(ABC):
@abstractmethod
def backoff_time(
self,
- response_or_exception: Optional[Union[requests.Response, requests.RequestException]],
+ response_or_exception: requests.Response | requests.RequestException | None,
attempt_count: int,
- ) -> Optional[float]:
+ ) -> float | None:
"""
Override this method to dynamically determine backoff time e.g: by reading the X-Retry-After header.
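To make that override point concrete, here is a hedged sketch of a custom strategy that honors a numeric Retry-After header; the class name and the header handling are assumptions for illustration, not CDK code:

import requests

from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy


class RetryAfterBackoffStrategy(BackoffStrategy):
    """Illustrative strategy: back off for the number of seconds given by Retry-After, if present."""

    def backoff_time(
        self,
        response_or_exception: requests.Response | requests.RequestException | None,
        attempt_count: int,
    ) -> float | None:
        if isinstance(response_or_exception, requests.Response):
            retry_after = response_or_exception.headers.get("Retry-After")
            if retry_after and retry_after.isdigit():  # assumes a seconds value, not an HTTP date
                return float(retry_after)
        return None  # None falls back to the default backoff behavior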
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py b/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py
index 2c3e10ad7..d8ae1f3ee 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/default_backoff_strategy.py
@@ -1,8 +1,6 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
-from typing import Optional, Union
-
import requests
from .backoff_strategy import BackoffStrategy
@@ -11,7 +9,7 @@
class DefaultBackoffStrategy(BackoffStrategy):
def backoff_time(
self,
- response_or_exception: Optional[Union[requests.Response, requests.RequestException]],
- attempt_count: int,
- ) -> Optional[float]:
+ response_or_exception: requests.Response | requests.RequestException | None, # noqa: ARG002
+ attempt_count: int, # noqa: ARG002
+ ) -> float | None:
return None
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py b/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py
index 74840e2d2..7720f3eff 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/default_error_mapping.py
@@ -2,7 +2,7 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
-from typing import Mapping, Type, Union
+from collections.abc import Mapping
from requests.exceptions import InvalidSchema, InvalidURL, RequestException
@@ -12,7 +12,8 @@
ResponseAction,
)
-DEFAULT_ERROR_MAPPING: Mapping[Union[int, str, Type[Exception]], ErrorResolution] = {
+
+DEFAULT_ERROR_MAPPING: Mapping[int | str | type[Exception], ErrorResolution] = {
InvalidSchema: ErrorResolution(
response_action=ResponseAction.FAIL,
failure_type=FailureType.config_error,
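As a hedged illustration of how a connector might extend this mapping (the 404 override below is a made-up example, not a CDK default), entries can be keyed by status code, error string, or exception type:

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.streams.http.error_handlers.default_error_mapping import DEFAULT_ERROR_MAPPING
from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction

# Illustrative override: ignore 404s instead of failing the sync.
CUSTOM_ERROR_MAPPING = {
    **DEFAULT_ERROR_MAPPING,
    404: ErrorResolution(
        response_action=ResponseAction.IGNORE,
        failure_type=FailureType.config_error,
        error_message="Requested resource was not found; skipping.",
    ),
}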
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py b/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py
index b231e72e0..7af046202 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/error_handler.py
@@ -1,7 +1,6 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from abc import ABC, abstractmethod
-from typing import Optional, Union
import requests
@@ -15,7 +14,7 @@ class ErrorHandler(ABC):
@property
@abstractmethod
- def max_retries(self) -> Optional[int]:
+ def max_retries(self) -> int | None:
"""
The maximum number of retries to attempt before giving up.
"""
@@ -23,16 +22,14 @@ def max_retries(self) -> Optional[int]:
@property
@abstractmethod
- def max_time(self) -> Optional[int]:
+ def max_time(self) -> int | None:
"""
The maximum amount of time in seconds to retry before giving up.
"""
pass
@abstractmethod
- def interpret_response(
- self, response: Optional[Union[requests.Response, Exception]]
- ) -> ErrorResolution:
+ def interpret_response(self, response: requests.Response | Exception | None) -> ErrorResolution:
"""
Interpret the response or exception and return the corresponding response action, failure type, and error message.
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py b/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py
index 966fe93a1..d5d413f3d 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/error_message_parser.py
@@ -3,14 +3,13 @@
#
from abc import ABC, abstractmethod
-from typing import Optional
import requests
class ErrorMessageParser(ABC):
@abstractmethod
- def parse_response_error_message(self, response: requests.Response) -> Optional[str]:
+ def parse_response_error_message(self, response: requests.Response) -> str | None:
"""
Parse error message from response.
:param response: response received for the request
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py b/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py
index 18daca3de..9bd47df66 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/http_status_error_handler.py
@@ -3,8 +3,8 @@
#
import logging
+from collections.abc import Mapping
from datetime import timedelta
-from typing import Mapping, Optional, Union
import requests
@@ -23,7 +23,7 @@ class HttpStatusErrorHandler(ErrorHandler):
def __init__(
self,
logger: logging.Logger,
- error_mapping: Optional[Mapping[Union[int, str, type[Exception]], ErrorResolution]] = None,
+ error_mapping: Mapping[int | str | type[Exception], ErrorResolution] | None = None,
max_retries: int = 5,
max_time: timedelta = timedelta(seconds=600),
) -> None:
@@ -38,15 +38,15 @@ def __init__(
self._max_time = int(max_time.total_seconds())
@property
- def max_retries(self) -> Optional[int]:
+ def max_retries(self) -> int | None:
return self._max_retries
@property
- def max_time(self) -> Optional[int]:
+ def max_time(self) -> int | None:
return self._max_time
- def interpret_response(
- self, response_or_exception: Optional[Union[requests.Response, Exception]] = None
+ def interpret_response( # noqa: PLR0911
+ self, response_or_exception: requests.Response | Exception | None = None
) -> ErrorResolution:
"""
Interpret the response and return the corresponding response action, failure type, and error message.
@@ -56,23 +56,20 @@ def interpret_response(
"""
if isinstance(response_or_exception, Exception):
- mapped_error: Optional[ErrorResolution] = self._error_mapping.get(
+ mapped_error: ErrorResolution | None = self._error_mapping.get(
response_or_exception.__class__
)
if mapped_error is not None:
return mapped_error
- else:
- self._logger.error(
- f"Unexpected exception in error handler: {response_or_exception}"
- )
- return ErrorResolution(
- response_action=ResponseAction.RETRY,
- failure_type=FailureType.system_error,
- error_message=f"Unexpected exception in error handler: {response_or_exception}",
- )
+ self._logger.error(f"Unexpected exception in error handler: {response_or_exception}")
+ return ErrorResolution(
+ response_action=ResponseAction.RETRY,
+ failure_type=FailureType.system_error,
+ error_message=f"Unexpected exception in error handler: {response_or_exception}",
+ )
- elif isinstance(response_or_exception, requests.Response):
+ if isinstance(response_or_exception, requests.Response):
if response_or_exception.status_code is None:
self._logger.error("Response does not include an HTTP status code.")
return ErrorResolution(
@@ -94,17 +91,15 @@ def interpret_response(
if mapped_error is not None:
return mapped_error
- else:
- self._logger.warning(f"Unexpected HTTP Status Code in error handler: '{error_key}'")
- return ErrorResolution(
- response_action=ResponseAction.RETRY,
- failure_type=FailureType.system_error,
- error_message=f"Unexpected HTTP Status Code in error handler: {error_key}",
- )
- else:
- self._logger.error(f"Received unexpected response type: {type(response_or_exception)}")
+ self._logger.warning(f"Unexpected HTTP Status Code in error handler: '{error_key}'")
return ErrorResolution(
- response_action=ResponseAction.FAIL,
+ response_action=ResponseAction.RETRY,
failure_type=FailureType.system_error,
- error_message=f"Received unexpected response type: {type(response_or_exception)}",
+ error_message=f"Unexpected HTTP Status Code in error handler: {error_key}",
)
+ self._logger.error(f"Received unexpected response type: {type(response_or_exception)}")
+ return ErrorResolution(
+ response_action=ResponseAction.FAIL,
+ failure_type=FailureType.system_error,
+ error_message=f"Received unexpected response type: {type(response_or_exception)}",
+ )
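A hedged usage sketch of the handler above; the logger name and sample exception are assumptions, and interpret_response accepts either a requests.Response or an Exception:

import logging
from datetime import timedelta

import requests

from airbyte_cdk.sources.streams.http.error_handlers.http_status_error_handler import (
    HttpStatusErrorHandler,
)

handler = HttpStatusErrorHandler(
    logger=logging.getLogger("airbyte"),
    max_retries=3,
    max_time=timedelta(seconds=120),
)

# The returned ErrorResolution carries the response action, failure type, and error message.
resolution = handler.interpret_response(requests.exceptions.InvalidURL())
print(resolution.response_action, resolution.failure_type, resolution.error_message)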
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py b/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py
index 7c58280c7..d41e3e70d 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/json_error_message_parser.py
@@ -2,7 +2,6 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Optional
import requests
@@ -11,13 +10,13 @@
class JsonErrorMessageParser(ErrorMessageParser):
- def _try_get_error(self, value: Optional[JsonType]) -> Optional[str]:
+ def _try_get_error(self, value: JsonType | None) -> str | None:
if isinstance(value, str):
return value
- elif isinstance(value, list):
+ if isinstance(value, list):
errors_in_value = [self._try_get_error(v) for v in value]
return ", ".join(v for v in errors_in_value if v is not None)
- elif isinstance(value, dict):
+ if isinstance(value, dict):
new_value = (
value.get("message")
or value.get("messages")
@@ -35,7 +34,7 @@ def _try_get_error(self, value: Optional[JsonType]) -> Optional[str]:
return self._try_get_error(new_value)
return None
- def parse_response_error_message(self, response: requests.Response) -> Optional[str]:
+ def parse_response_error_message(self, response: requests.Response) -> str | None:
"""
Parses the raw response object from a failed request into a user-friendly error message.
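A small hedged sketch of that parsing in practice; the locally built response (setting _content directly) is a test-style stand-in for a real failed request, not how the CDK constructs responses:

import requests

from airbyte_cdk.sources.streams.http.error_handlers import JsonErrorMessageParser

# Build a fake failed response locally so the example is self-contained.
response = requests.Response()
response.status_code = 400
response._content = b'{"message": "Invalid API key"}'

print(JsonErrorMessageParser().parse_response_error_message(response))  # -> Invalid API key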
diff --git a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py
index e882b89bd..8a06c230a 100644
--- a/airbyte_cdk/sources/streams/http/error_handlers/response_models.py
+++ b/airbyte_cdk/sources/streams/http/error_handlers/response_models.py
@@ -2,7 +2,6 @@
from dataclasses import dataclass
from enum import Enum
-from typing import Optional, Union
import requests
from requests import HTTPError
@@ -21,13 +20,13 @@ class ResponseAction(Enum):
@dataclass
class ErrorResolution:
- response_action: Optional[ResponseAction] = None
- failure_type: Optional[FailureType] = None
- error_message: Optional[str] = None
+ response_action: ResponseAction | None = None
+ failure_type: FailureType | None = None
+ error_message: str | None = None
def _format_exception_error_message(exception: Exception) -> str:
- return f"{type(exception).__name__}: {str(exception)}"
+ return f"{type(exception).__name__}: {exception!s}"
def _format_response_error_message(response: requests.Response) -> str:
@@ -35,7 +34,7 @@ def _format_response_error_message(response: requests.Response) -> str:
response.raise_for_status()
except HTTPError as exception:
return filter_secrets(
- f"Response was not ok: `{str(exception)}`. Response content is: {response.text}"
+ f"Response was not ok: `{exception!s}`. Response content is: {response.text}"
)
# We purposefully do not add the response.content because the response is "ok" so there might be sensitive information in the payload.
# Feel free the
@@ -43,7 +42,7 @@ def _format_response_error_message(response: requests.Response) -> str:
def create_fallback_error_resolution(
- response_or_exception: Optional[Union[requests.Response, Exception]],
+ response_or_exception: requests.Response | Exception | None,
) -> ErrorResolution:
if response_or_exception is None:
# We do not expect this case to happen but if it does, it would be good to understand the cause and improve the error message
diff --git a/airbyte_cdk/sources/streams/http/exceptions.py b/airbyte_cdk/sources/streams/http/exceptions.py
index ee4687626..a3cf57689 100644
--- a/airbyte_cdk/sources/streams/http/exceptions.py
+++ b/airbyte_cdk/sources/streams/http/exceptions.py
@@ -3,16 +3,14 @@
#
-from typing import Optional, Union
-
import requests
class BaseBackoffException(requests.exceptions.HTTPError):
- def __init__(
+ def __init__( # noqa: ANN204
self,
request: requests.PreparedRequest,
- response: Optional[Union[requests.Response, Exception]],
+ response: requests.Response | Exception | None,
error_message: str = "",
):
if isinstance(response, requests.Response):
@@ -37,11 +35,11 @@ class UserDefinedBackoffException(BaseBackoffException):
An exception that exposes how long it attempted to backoff
"""
- def __init__(
+ def __init__( # noqa: ANN204
self,
- backoff: Union[int, float],
+ backoff: int | float, # noqa: PYI041
request: requests.PreparedRequest,
- response: Optional[Union[requests.Response, Exception]],
+ response: requests.Response | Exception | None,
error_message: str = "",
):
"""
diff --git a/airbyte_cdk/sources/streams/http/http.py b/airbyte_cdk/sources/streams/http/http.py
index 40eab27a3..a3522e26f 100644
--- a/airbyte_cdk/sources/streams/http/http.py
+++ b/airbyte_cdk/sources/streams/http/http.py
@@ -1,11 +1,12 @@
-#
+# # noqa: A005
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterable, Mapping, MutableMapping
from datetime import timedelta
-from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import Any
from urllib.parse import urljoin
import requests
@@ -37,22 +38,23 @@
from airbyte_cdk.sources.types import Record, StreamSlice
from airbyte_cdk.sources.utils.types import JsonType
+
# list of all possible HTTP methods which can be used for sending of request bodies
BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH")
-class HttpStream(Stream, CheckpointMixin, ABC):
+class HttpStream(Stream, CheckpointMixin, ABC): # noqa: PLR0904
"""
Base abstract class for an Airbyte Stream using the HTTP protocol. Basic building block for users building an Airbyte source for an HTTP API.
"""
source_defined_cursor = True # Most HTTP streams use a source defined cursor (i.e: the user can't configure it like on a SQL table)
- page_size: Optional[int] = (
+ page_size: int | None = (
None # Use this variable to define page size for API http requests with pagination support
)
- def __init__(
- self, authenticator: Optional[AuthBase] = None, api_budget: Optional[APIBudget] = None
+ def __init__( # noqa: ANN204
+ self, authenticator: AuthBase | None = None, api_budget: APIBudget | None = None
):
self._exit_on_rate_limit: bool = False
self._http_client = HttpClient(
@@ -135,7 +137,7 @@ def raise_on_http_errors(self) -> bool:
"Deprecated as of CDK version 3.0.0. "
"You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead."
)
- def max_retries(self) -> Union[int, None]:
+ def max_retries(self) -> int | None:
"""
Override if needed. Specifies maximum amount of retries for backoff policy. Return None for no limit.
"""
@@ -146,7 +148,7 @@ def max_retries(self) -> Union[int, None]:
"Deprecated as of CDK version 3.0.0. "
"You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead."
)
- def max_time(self) -> Union[int, None]:
+ def max_time(self) -> int | None:
"""
Override if needed. Specifies maximum total waiting time (in seconds) for backoff policy. Return None for no limit.
"""
@@ -164,7 +166,7 @@ def retry_factor(self) -> float:
return 5
@abstractmethod
- def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
+ def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None:
"""
Override this method to define a pagination strategy.
@@ -177,9 +179,9 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str,
def path(
self,
*,
- stream_state: Optional[Mapping[str, Any]] = None,
- stream_slice: Optional[Mapping[str, Any]] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: Mapping[str, Any] | None = None,
+ stream_slice: Mapping[str, Any] | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> str:
"""
Returns the URL path for the API endpoint e.g: if you wanted to hit https://myapi.com/v1/some_entity then this should return "some_entity"
@@ -187,9 +189,9 @@ def path(
def request_params(
self,
- stream_state: Optional[Mapping[str, Any]],
- stream_slice: Optional[Mapping[str, Any]] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: Mapping[str, Any] | None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> MutableMapping[str, Any]:
"""
Override this method to define the query parameters that should be set on an outgoing HTTP request given the inputs.
@@ -200,9 +202,9 @@ def request_params(
def request_headers(
self,
- stream_state: Optional[Mapping[str, Any]],
- stream_slice: Optional[Mapping[str, Any]] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: Mapping[str, Any] | None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
"""
Override to return any non-auth headers. Authentication headers will overwrite any overlapping headers returned from this method.
@@ -211,10 +213,10 @@ def request_headers(
def request_body_data(
self,
- stream_state: Optional[Mapping[str, Any]],
- stream_slice: Optional[Mapping[str, Any]] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Optional[Union[Mapping[str, Any], str]]:
+ stream_state: Mapping[str, Any] | None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
+ ) -> Mapping[str, Any] | str | None:
"""
Override when creating POST/PUT/PATCH requests to populate the body of the request with a non-JSON payload.
@@ -228,10 +230,10 @@ def request_body_data(
def request_body_json(
self,
- stream_state: Optional[Mapping[str, Any]],
- stream_slice: Optional[Mapping[str, Any]] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Optional[Mapping[str, Any]]:
+ stream_state: Mapping[str, Any] | None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
+ ) -> Mapping[str, Any] | None:
"""
Override when creating POST/PUT/PATCH requests to populate the body of the request with a JSON payload.
@@ -241,9 +243,9 @@ def request_body_json(
def request_kwargs(
self,
- stream_state: Optional[Mapping[str, Any]],
- stream_slice: Optional[Mapping[str, Any]] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_state: Mapping[str, Any] | None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None, # noqa: ARG002
+ next_page_token: Mapping[str, Any] | None = None, # noqa: ARG002
) -> Mapping[str, Any]:
"""
Override to return a mapping of keyword arguments to be used when creating the HTTP request.
@@ -258,8 +260,8 @@ def parse_response(
response: requests.Response,
*,
stream_state: Mapping[str, Any],
- stream_slice: Optional[Mapping[str, Any]] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
+ stream_slice: Mapping[str, Any] | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
) -> Iterable[Mapping[str, Any]]:
"""
Parses the raw response object into a list of records.
@@ -271,7 +273,7 @@ def parse_response(
:return: An iterable containing the parsed response
"""
- def get_backoff_strategy(self) -> Optional[Union[BackoffStrategy, List[BackoffStrategy]]]:
+ def get_backoff_strategy(self) -> BackoffStrategy | list[BackoffStrategy] | None:
"""
Used to initialize Adapter to avoid breaking changes.
If Stream has a `backoff_time` method implementation, we know this stream uses old (pre-HTTPClient) backoff handlers and thus an adapter is needed.
@@ -281,10 +283,9 @@ def get_backoff_strategy(self) -> Optional[Union[BackoffStrategy, List[BackoffSt
"""
if hasattr(self, "backoff_time"):
return HttpStreamAdapterBackoffStrategy(self)
- else:
- return None
+ return None
- def get_error_handler(self) -> Optional[ErrorHandler]:
+ def get_error_handler(self) -> ErrorHandler | None:
"""
Used to initialize Adapter to avoid breaking changes.
If Stream has a `should_retry` method implementation, we know this stream uses old (pre-HTTPClient) error handlers and thus an adapter is needed.
@@ -300,15 +301,14 @@ def get_error_handler(self) -> Optional[ErrorHandler]:
max_time=timedelta(seconds=self.max_time or 0),
)
return error_handler
- else:
- return None
+ return None
@classmethod
def _join_url(cls, url_base: str, path: str) -> str:
return urljoin(url_base, path)
@classmethod
- def parse_response_error_message(cls, response: requests.Response) -> Optional[str]:
+ def parse_response_error_message(cls, response: requests.Response) -> str | None:
"""
Parses the raw response object from a failed request into a user-friendly error message.
By default, this method tries to grab the error message from JSON responses by following common API patterns. Override to parse differently.
@@ -318,13 +318,13 @@ def parse_response_error_message(cls, response: requests.Response) -> Optional[s
"""
# default logic to grab error from common fields
- def _try_get_error(value: Optional[JsonType]) -> Optional[str]:
+ def _try_get_error(value: JsonType | None) -> str | None:
if isinstance(value, str):
return value
- elif isinstance(value, list):
+ if isinstance(value, list):
errors_in_value = [_try_get_error(v) for v in value]
return ", ".join(v for v in errors_in_value if v is not None)
- elif isinstance(value, dict):
+ if isinstance(value, dict):
new_value = (
value.get("message")
or value.get("messages")
@@ -343,7 +343,7 @@ def _try_get_error(value: Optional[JsonType]) -> Optional[str]:
except requests.exceptions.JSONDecodeError:
return None
- def get_error_display_message(self, exception: BaseException) -> Optional[str]:
+ def get_error_display_message(self, exception: BaseException) -> str | None:
"""
Retrieves the user-friendly display message that corresponds to an exception.
This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage.
@@ -360,15 +360,15 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]:
def read_records(
self,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_slice: Optional[Mapping[str, Any]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_slice: Mapping[str, Any] | None = None,
+ stream_state: Mapping[str, Any] | None = None,
) -> Iterable[StreamData]:
# A cursor_field indicates this is an incremental stream which offers better checkpointing than RFR enabled via the cursor
if self.cursor_field or not isinstance(self.get_cursor(), ResumableFullRefreshCursor):
yield from self._read_pages(
- lambda req, res, state, _slice: self.parse_response(
+ lambda req, res, state, _slice: self.parse_response( # noqa: ARG005
res, stream_slice=_slice, stream_state=state
),
stream_slice,
@@ -376,7 +376,7 @@ def read_records(
)
else:
yield from self._read_single_page(
- lambda req, res, state, _slice: self.parse_response(
+ lambda req, res, state, _slice: self.parse_response( # noqa: ARG005
res, stream_slice=_slice, stream_state=state
),
stream_slice,
@@ -397,7 +397,7 @@ def state(self, value: MutableMapping[str, Any]) -> None:
cursor.set_initial_state(value)
self._state = value
- def get_cursor(self) -> Optional[Cursor]:
+ def get_cursor(self) -> Cursor | None:
# I don't love that this is semi-stateful but not sure what else to do. We don't know exactly what type of cursor to
# instantiate when creating the class. We can make a few assumptions like if there is a cursor_field which implies
# incremental, but we don't know until runtime if this is a substream. Ideally, a stream should explicitly define
@@ -406,8 +406,7 @@ def get_cursor(self) -> Optional[Cursor]:
if self.has_multiple_slices and isinstance(self.cursor, ResumableFullRefreshCursor):
self.cursor = SubstreamResumableFullRefreshCursor()
return self.cursor
- else:
- return self.cursor
+ return self.cursor
def _read_pages(
self,
@@ -416,12 +415,12 @@ def _read_pages(
requests.PreparedRequest,
requests.Response,
Mapping[str, Any],
- Optional[Mapping[str, Any]],
+ Mapping[str, Any] | None,
],
Iterable[StreamData],
],
- stream_slice: Optional[Mapping[str, Any]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
+ stream_slice: Mapping[str, Any] | None = None,
+ stream_state: Mapping[str, Any] | None = None,
) -> Iterable[StreamData]:
partition, _, _ = self._extract_slice_fields(stream_slice=stream_slice)
@@ -452,12 +451,12 @@ def _read_single_page(
requests.PreparedRequest,
requests.Response,
Mapping[str, Any],
- Optional[Mapping[str, Any]],
+ Mapping[str, Any] | None,
],
Iterable[StreamData],
],
- stream_slice: Optional[Mapping[str, Any]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
+ stream_slice: Mapping[str, Any] | None = None,
+ stream_state: Mapping[str, Any] | None = None,
) -> Iterable[StreamData]:
partition, cursor_slice, remaining_slice = self._extract_slice_fields(
stream_slice=stream_slice
@@ -481,7 +480,7 @@ def _read_single_page(
@staticmethod
def _extract_slice_fields(
- stream_slice: Optional[Mapping[str, Any]],
+ stream_slice: Mapping[str, Any] | None,
) -> tuple[Mapping[str, Any], Mapping[str, Any], Mapping[str, Any]]:
if not stream_slice:
return {}, {}, {}
@@ -499,16 +498,16 @@ def _extract_slice_fields(
remaining = {
key: val
for key, val in stream_slice.items()
- if key != "partition" and key != "cursor_slice"
+ if key != "partition" and key != "cursor_slice" # noqa: PLR1714
}
return partition, cursor_slice, remaining
def _fetch_next_page(
self,
- stream_slice: Optional[Mapping[str, Any]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
- next_page_token: Optional[Mapping[str, Any]] = None,
- ) -> Tuple[requests.PreparedRequest, requests.Response]:
+ stream_slice: Mapping[str, Any] | None = None,
+ stream_state: Mapping[str, Any] | None = None,
+ next_page_token: Mapping[str, Any] | None = None,
+ ) -> tuple[requests.PreparedRequest, requests.Response]:
request, response = self._http_client.send_request(
http_method=self.http_method,
url=self._join_url(
@@ -551,7 +550,7 @@ def _fetch_next_page(
return request, response
- def get_log_formatter(self) -> Optional[Callable[[requests.Response], Any]]:
+ def get_log_formatter(self) -> Callable[[requests.Response], Any] | None:
"""
:return Optional[Callable[[requests.Response], Any]]: Function that will be used in logging inside HttpClient
@@ -560,7 +559,7 @@ def get_log_formatter(self) -> Optional[Callable[[requests.Response], Any]]:
class HttpSubStream(HttpStream, ABC):
- def __init__(self, parent: HttpStream, **kwargs: Any):
+ def __init__(self, parent: HttpStream, **kwargs: Any): # noqa: ANN204, ANN401
"""
:param parent: should be the instance of HttpStream class
"""
@@ -584,21 +583,21 @@ def __init__(self, parent: HttpStream, **kwargs: Any):
def stream_slices(
self,
- sync_mode: SyncMode,
- cursor_field: Optional[List[str]] = None,
- stream_state: Optional[Mapping[str, Any]] = None,
- ) -> Iterable[Optional[Mapping[str, Any]]]:
+ sync_mode: SyncMode, # noqa: ARG002
+ cursor_field: list[str] | None = None, # noqa: ARG002
+ stream_state: Mapping[str, Any] | None = None,
+ ) -> Iterable[Mapping[str, Any] | None]:
# read_stateless() assumes the parent is not concurrent. This is currently okay since the concurrent CDK does
# not support either substreams or RFR, but this is something that needs to be considered once we do
for parent_record in self.parent.read_only_records(stream_state):
# Skip non-records (eg AirbyteLogMessage)
if isinstance(parent_record, AirbyteMessage):
if parent_record.type == MessageType.RECORD:
- parent_record = parent_record.record.data # type: ignore [assignment, union-attr] # Incorrect type for assignment
+ parent_record = parent_record.record.data # type: ignore [assignment, union-attr] # Incorrect type for assignment # noqa: PLW2901
else:
continue
elif isinstance(parent_record, Record):
- parent_record = parent_record.data
+ parent_record = parent_record.data # noqa: PLW2901
yield {"parent": parent_record}
@@ -607,15 +606,15 @@ def stream_slices(
"You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead."
)
class HttpStreamAdapterBackoffStrategy(BackoffStrategy):
- def __init__(self, stream: HttpStream):
+ def __init__(self, stream: HttpStream): # noqa: ANN204
self.stream = stream
def backoff_time(
self,
- response_or_exception: Optional[Union[requests.Response, requests.RequestException]],
- attempt_count: int,
- ) -> Optional[float]:
- return self.stream.backoff_time(response_or_exception) # type: ignore # noqa # HttpStream.backoff_time has been deprecated
+ response_or_exception: requests.Response | requests.RequestException | None,
+ attempt_count: int, # noqa: ARG002
+ ) -> float | None:
+ return self.stream.backoff_time(response_or_exception) # type: ignore # HttpStream.backoff_time has been deprecated
@deprecated(
@@ -623,19 +622,19 @@ def backoff_time(
"You should set error_handler explicitly in HttpStream.get_error_handler() instead."
)
class HttpStreamAdapterHttpStatusErrorHandler(HttpStatusErrorHandler):
- def __init__(self, stream: HttpStream, **kwargs): # type: ignore # noqa
+ def __init__(self, stream: HttpStream, **kwargs) -> None: # type: ignore
self.stream = stream
super().__init__(**kwargs)
- def interpret_response(
- self, response_or_exception: Optional[Union[requests.Response, Exception]] = None
+ def interpret_response( # noqa: PLR0911
+ self, response_or_exception: requests.Response | Exception | None = None
) -> ErrorResolution:
if isinstance(response_or_exception, Exception):
return super().interpret_response(response_or_exception)
- elif isinstance(response_or_exception, requests.Response):
- should_retry = self.stream.should_retry(response_or_exception) # type: ignore # noqa
+ if isinstance(response_or_exception, requests.Response):
+ should_retry = self.stream.should_retry(response_or_exception) # type: ignore
if should_retry:
- if response_or_exception.status_code == 429:
+ if response_or_exception.status_code == 429: # noqa: PLR2004
return ErrorResolution(
response_action=ResponseAction.RATE_LIMITED,
failure_type=FailureType.transient_error,
@@ -646,29 +645,26 @@ def interpret_response(
failure_type=FailureType.transient_error,
error_message=f"Response status code: {response_or_exception.status_code}. Retrying...",
)
- else:
- if response_or_exception.ok:
- return ErrorResolution(
- response_action=ResponseAction.SUCCESS,
- failure_type=None,
- error_message=None,
- )
- if self.stream.raise_on_http_errors:
- return ErrorResolution(
- response_action=ResponseAction.FAIL,
- failure_type=FailureType.transient_error,
- error_message=f"Response status code: {response_or_exception.status_code}. Unexpected error. Failed.",
- )
- else:
- return ErrorResolution(
- response_action=ResponseAction.IGNORE,
- failure_type=FailureType.transient_error,
- error_message=f"Response status code: {response_or_exception.status_code}. Ignoring...",
- )
- else:
- self._logger.error(f"Received unexpected response type: {type(response_or_exception)}")
+ if response_or_exception.ok:
+ return ErrorResolution(
+ response_action=ResponseAction.SUCCESS,
+ failure_type=None,
+ error_message=None,
+ )
+ if self.stream.raise_on_http_errors:
+ return ErrorResolution(
+ response_action=ResponseAction.FAIL,
+ failure_type=FailureType.transient_error,
+ error_message=f"Response status code: {response_or_exception.status_code}. Unexpected error. Failed.",
+ )
return ErrorResolution(
- response_action=ResponseAction.FAIL,
- failure_type=FailureType.system_error,
- error_message=f"Received unexpected response type: {type(response_or_exception)}",
+ response_action=ResponseAction.IGNORE,
+ failure_type=FailureType.transient_error,
+ error_message=f"Response status code: {response_or_exception.status_code}. Ignoring...",
)
+ self._logger.error(f"Received unexpected response type: {type(response_or_exception)}")
+ return ErrorResolution(
+ response_action=ResponseAction.FAIL,
+ failure_type=FailureType.system_error,
+ error_message=f"Received unexpected response type: {type(response_or_exception)}",
+ )
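To ground the abstract methods shown in this file, here is a hedged, minimal HttpStream subclass; the endpoint URL, stream name, and field names are made-up examples rather than anything from this change:

from collections.abc import Iterable, Mapping
from typing import Any

import requests

from airbyte_cdk.sources.streams.http import HttpStream


class Customers(HttpStream):
    url_base = "https://api.example.com/v1/"
    primary_key = "id"

    def path(self, **kwargs: Any) -> str:
        return "customers"

    def next_page_token(self, response: requests.Response) -> Mapping[str, Any] | None:
        return None  # single page for this sketch

    def parse_response(self, response: requests.Response, **kwargs: Any) -> Iterable[Mapping[str, Any]]:
        yield from response.json().get("customers", [])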
diff --git a/airbyte_cdk/sources/streams/http/http_client.py b/airbyte_cdk/sources/streams/http/http_client.py
index c4fa86866..bfca42a71 100644
--- a/airbyte_cdk/sources/streams/http/http_client.py
+++ b/airbyte_cdk/sources/streams/http/http_client.py
@@ -5,8 +5,9 @@
import logging
import os
import urllib
+from collections.abc import Callable, Mapping
from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any
import orjson
import requests
@@ -53,6 +54,7 @@
)
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
+
BODY_REQUEST_METHODS = ("GET", "POST", "PUT", "PATCH")
@@ -68,7 +70,7 @@ class MessageRepresentationAirbyteTracedErrors(AirbyteTracedException):
def __str__(self) -> str:
if self.message:
return self.message
- elif self.internal_message:
+ if self.internal_message:
return self.internal_message
return ""
@@ -76,21 +78,21 @@ def __str__(self) -> str:
class HttpClient:
_DEFAULT_MAX_RETRY: int = 5
_DEFAULT_MAX_TIME: int = 60 * 10
- _ACTIONS_TO_RETRY_ON = {ResponseAction.RETRY, ResponseAction.RATE_LIMITED}
+ _ACTIONS_TO_RETRY_ON = {ResponseAction.RETRY, ResponseAction.RATE_LIMITED} # noqa: RUF012
- def __init__(
+ def __init__( # noqa: ANN204, PLR0913, PLR0917
self,
name: str,
logger: logging.Logger,
- error_handler: Optional[ErrorHandler] = None,
- api_budget: Optional[APIBudget] = None,
- session: Optional[Union[requests.Session, requests_cache.CachedSession]] = None,
- authenticator: Optional[AuthBase] = None,
- use_cache: bool = False,
- backoff_strategy: Optional[Union[BackoffStrategy, List[BackoffStrategy]]] = None,
- error_message_parser: Optional[ErrorMessageParser] = None,
- disable_retries: bool = False,
- message_repository: Optional[MessageRepository] = None,
+ error_handler: ErrorHandler | None = None,
+ api_budget: APIBudget | None = None,
+ session: requests.Session | requests_cache.CachedSession | None = None,
+ authenticator: AuthBase | None = None,
+ use_cache: bool = False, # noqa: FBT001, FBT002
+ backoff_strategy: BackoffStrategy | list[BackoffStrategy] | None = None,
+ error_message_parser: ErrorMessageParser | None = None,
+ disable_retries: bool = False, # noqa: FBT001, FBT002
+ message_repository: MessageRepository | None = None,
):
self._name = name
self._api_budget: APIBudget = api_budget or APIBudget(policies=[])
@@ -117,7 +119,7 @@ def __init__(
else:
self._backoff_strategies = [DefaultBackoffStrategy()]
self._error_message_parser = error_message_parser or JsonErrorMessageParser()
- self._request_attempt_count: Dict[requests.PreparedRequest, int] = {}
+ self._request_attempt_count: dict[requests.PreparedRequest, int] = {}
self._disable_retries = disable_retries
self._message_repository = message_repository
@@ -155,8 +157,7 @@ def _request_session(self) -> requests.Session:
return CachedLimiterSession(
sqlite_path, backend=backend, api_budget=self._api_budget, match_headers=True
)
- else:
- return LimiterSession(api_budget=self._api_budget)
+ return LimiterSession(api_budget=self._api_budget)
def clear_cache(self) -> None:
"""
@@ -165,9 +166,7 @@ def clear_cache(self) -> None:
if isinstance(self._session, requests_cache.CachedSession):
self._session.cache.clear() # type: ignore # cache.clear is not typed
- def _dedupe_query_params(
- self, url: str, params: Optional[Mapping[str, str]]
- ) -> Mapping[str, str]:
+ def _dedupe_query_params(self, url: str, params: Mapping[str, str] | None) -> Mapping[str, str]:
"""
Remove query parameters from params mapping if they are already encoded in the URL.
:param url: URL with
@@ -180,7 +179,7 @@ def _dedupe_query_params(
query_dict = {k: v[0] for k, v in urllib.parse.parse_qs(query_string).items()}
duplicate_keys_with_same_value = {
- k for k in query_dict.keys() if str(params.get(k)) == str(query_dict[k])
+ k for k in query_dict if str(params.get(k)) == str(query_dict[k])
}
return {k: v for k, v in params.items() if k not in duplicate_keys_with_same_value}
@@ -188,11 +187,11 @@ def _create_prepared_request(
self,
http_method: str,
url: str,
- dedupe_query_params: bool = False,
- headers: Optional[Mapping[str, str]] = None,
- params: Optional[Mapping[str, str]] = None,
- json: Optional[Mapping[str, Any]] = None,
- data: Optional[Union[str, Mapping[str, Any]]] = None,
+ dedupe_query_params: bool = False, # noqa: FBT001, FBT002
+ headers: Mapping[str, str] | None = None,
+ params: Mapping[str, str] | None = None,
+ json: Mapping[str, Any] | None = None,
+ data: str | Mapping[str, Any] | None = None,
) -> requests.PreparedRequest:
if dedupe_query_params:
query_params = self._dedupe_query_params(url, params)
@@ -204,7 +203,7 @@ def _create_prepared_request(
raise RequestBodyException(
"At the same time only one of the 'request_body_data' and 'request_body_json' functions can return data"
)
- elif json:
+ if json:
args["json"] = json
elif data:
args["data"] = data
@@ -220,7 +219,7 @@ def _max_retries(self) -> int:
Determines the max retries based on the provided error handler.
"""
max_retries = None
- if self._disable_retries:
+ if self._disable_retries: # noqa: SIM108
max_retries = 0
else:
max_retries = self._error_handler.max_retries
@@ -241,8 +240,8 @@ def _send_with_retry(
self,
request: requests.PreparedRequest,
request_kwargs: Mapping[str, Any],
- log_formatter: Optional[Callable[[requests.Response], Any]] = None,
- exit_on_rate_limit: Optional[bool] = False,
+ log_formatter: Callable[[requests.Response], Any] | None = None,
+ exit_on_rate_limit: bool | None = False, # noqa: FBT001, FBT002
) -> requests.Response:
"""
Sends a request with retry logic.
@@ -280,8 +279,8 @@ def _send(
self,
request: requests.PreparedRequest,
request_kwargs: Mapping[str, Any],
- log_formatter: Optional[Callable[[requests.Response], Any]] = None,
- exit_on_rate_limit: Optional[bool] = False,
+ log_formatter: Callable[[requests.Response], Any] | None = None,
+ exit_on_rate_limit: bool | None = False, # noqa: FBT001, FBT002
) -> requests.Response:
if request not in self._request_attempt_count:
self._request_attempt_count[request] = 1
@@ -295,8 +294,8 @@ def _send(
extra={"headers": request.headers, "url": request.url, "request_body": request.body},
)
- response: Optional[requests.Response] = None
- exc: Optional[requests.RequestException] = None
+ response: requests.Response | None = None
+ exc: requests.RequestException | None = None
try:
response = self._session.send(request, **request_kwargs)
@@ -347,7 +346,7 @@ def _send(
return response # type: ignore # will either return a valid response of type requests.Response or raise an exception
- def _get_response_body(self, response: requests.Response) -> Optional[JsonType]:
+ def _get_response_body(self, response: requests.Response) -> JsonType | None:
"""
Extracts and returns the body of an HTTP response.
@@ -383,11 +382,11 @@ def _evict_key(self, prepared_request: requests.PreparedRequest) -> None:
def _handle_error_resolution(
self,
- response: Optional[requests.Response],
- exc: Optional[requests.RequestException],
+ response: requests.Response | None,
+ exc: requests.RequestException | None,
request: requests.PreparedRequest,
error_resolution: ErrorResolution,
- exit_on_rate_limit: Optional[bool] = False,
+ exit_on_rate_limit: bool | None = False, # noqa: FBT001, FBT002
) -> None:
if error_resolution.response_action not in self._ACTIONS_TO_RETRY_ON:
self._evict_key(request)
@@ -413,7 +412,7 @@ def _handle_error_resolution(
if error_resolution.response_action == ResponseAction.FAIL:
if response is not None:
filtered_response_message = filter_secrets(
- f"Request (body): '{str(request.body)}'. Response (body): '{self._get_response_body(response)}'. Response (headers): '{response.headers}'."
+ f"Request (body): '{request.body!s}'. Response (body): '{self._get_response_body(response)}'. Response (headers): '{response.headers}'."
)
error_message = f"'{request.method}' request to '{request.url}' failed with status code '{response.status_code}' and error message: '{self._error_message_parser.parse_response_error_message(response)}'. {filtered_response_message}"
else:
@@ -430,7 +429,7 @@ def _handle_error_resolution(
failure_type=error_resolution.failure_type,
)
- elif error_resolution.response_action == ResponseAction.IGNORE:
+ if error_resolution.response_action == ResponseAction.IGNORE:
if response is not None:
log_message = f"Ignoring response for '{request.method}' request to '{request.url}' with response code '{response.status_code}'"
else:
@@ -440,7 +439,7 @@ def _handle_error_resolution(
# TODO: Consider dynamic retry count depending on subsequent error codes
elif (
- error_resolution.response_action == ResponseAction.RETRY
+ error_resolution.response_action == ResponseAction.RETRY # noqa: PLR1714
or error_resolution.response_action == ResponseAction.RATE_LIMITED
):
user_defined_backoff_time = None
@@ -470,7 +469,7 @@ def _handle_error_resolution(
error_message=error_message,
)
- elif retry_endlessly:
+ if retry_endlessly:
raise RateLimitBackoffException(
request=request,
response=(response if response is not None else exc),
@@ -488,25 +487,25 @@ def _handle_error_resolution(
response.raise_for_status()
except requests.HTTPError as e:
self._logger.error(response.text)
- raise e
+ raise e # noqa: TRY201
@property
def name(self) -> str:
return self._name
- def send_request(
+ def send_request( # noqa: PLR0913, PLR0917
self,
http_method: str,
url: str,
request_kwargs: Mapping[str, Any],
- headers: Optional[Mapping[str, str]] = None,
- params: Optional[Mapping[str, str]] = None,
- json: Optional[Mapping[str, Any]] = None,
- data: Optional[Union[str, Mapping[str, Any]]] = None,
- dedupe_query_params: bool = False,
- log_formatter: Optional[Callable[[requests.Response], Any]] = None,
- exit_on_rate_limit: Optional[bool] = False,
- ) -> Tuple[requests.PreparedRequest, requests.Response]:
+ headers: Mapping[str, str] | None = None,
+ params: Mapping[str, str] | None = None,
+ json: Mapping[str, Any] | None = None,
+ data: str | Mapping[str, Any] | None = None,
+ dedupe_query_params: bool = False, # noqa: FBT001, FBT002
+ log_formatter: Callable[[requests.Response], Any] | None = None,
+ exit_on_rate_limit: bool | None = False, # noqa: FBT001, FBT002
+ ) -> tuple[requests.PreparedRequest, requests.Response]:
"""
Prepares and sends the request, and returns both the prepared request and the response objects.
"""
diff --git a/airbyte_cdk/sources/streams/http/rate_limiting.py b/airbyte_cdk/sources/streams/http/rate_limiting.py
index 926a7ad56..33a55d9e1 100644
--- a/airbyte_cdk/sources/streams/http/rate_limiting.py
+++ b/airbyte_cdk/sources/streams/http/rate_limiting.py
@@ -5,7 +5,8 @@
import logging
import sys
import time
-from typing import Any, Callable, Mapping, Optional
+from collections.abc import Callable, Mapping
+from typing import Any
import backoff
from requests import PreparedRequest, RequestException, Response, codes, exceptions
@@ -16,6 +17,7 @@
UserDefinedBackoffException,
)
+
TRANSIENT_EXCEPTIONS = (
DefaultBackoffException,
exceptions.ConnectTimeout,
@@ -31,7 +33,10 @@
def default_backoff_handler(
- max_tries: Optional[int], factor: float, max_time: Optional[int] = None, **kwargs: Any
+ max_tries: int | None,
+ factor: float,
+ max_time: int | None = None,
+ **kwargs: Any, # noqa: ANN401
) -> Callable[[SendRequestCallableType], SendRequestCallableType]:
def log_retry_attempt(details: Mapping[str, Any]) -> None:
_, exc, _ = sys.exc_info()
@@ -40,17 +45,17 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None:
f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}"
)
logger.info(
- f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
+ f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
)
def should_give_up(exc: Exception) -> bool:
# If a non-rate-limiting related 4XX error makes it this far, it means it was unexpected and probably consistent, so we shouldn't back off
- if isinstance(exc, RequestException):
+ if isinstance(exc, RequestException): # noqa: SIM102
if exc.response is not None:
give_up: bool = (
exc.response is not None
and exc.response.status_code != codes.too_many_requests
- and 400 <= exc.response.status_code < 500
+ and 400 <= exc.response.status_code < 500 # noqa: PLR2004
)
if give_up:
logger.info(f"Giving up for returned HTTP status: {exc.response.status_code!r}")
@@ -72,7 +77,9 @@ def should_give_up(exc: Exception) -> bool:
def http_client_default_backoff_handler(
- max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any
+ max_tries: int | None,
+ max_time: int | None = None,
+ **kwargs: Any, # noqa: ANN401
) -> Callable[[SendRequestCallableType], SendRequestCallableType]:
def log_retry_attempt(details: Mapping[str, Any]) -> None:
_, exc, _ = sys.exc_info()
@@ -81,10 +88,10 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None:
f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}"
)
logger.info(
- f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
+ f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
)
- def should_give_up(exc: Exception) -> bool:
+ def should_give_up(exc: Exception) -> bool: # noqa: ARG001
# If made it here, the ResponseAction was RETRY and therefore should not give up
return False
@@ -101,9 +108,11 @@ def should_give_up(exc: Exception) -> bool:
def user_defined_backoff_handler(
- max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any
+ max_tries: int | None,
+ max_time: int | None = None,
+ **kwargs: Any, # noqa: ANN401
) -> Callable[[SendRequestCallableType], SendRequestCallableType]:
- def sleep_on_ratelimit(details: Mapping[str, Any]) -> None:
+ def sleep_on_ratelimit(details: Mapping[str, Any]) -> None: # noqa: ARG001
_, exc, _ = sys.exc_info()
if isinstance(exc, UserDefinedBackoffException):
if exc.response:
@@ -137,7 +146,7 @@ def log_give_up(details: Mapping[str, Any]) -> None:
def rate_limit_default_backoff_handler(
- **kwargs: Any,
+ **kwargs: Any, # noqa: ANN401
) -> Callable[[SendRequestCallableType], SendRequestCallableType]:
def log_retry_attempt(details: Mapping[str, Any]) -> None:
_, exc, _ = sys.exc_info()
@@ -146,7 +155,7 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None:
f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}"
)
logger.info(
- f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
+ f"Caught retryable error '{exc!s}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
)
return backoff.on_exception( # type: ignore # Decorator function returns a function with a different signature than the input function, so mypy can't infer the type of the returned function
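As a hedged illustration of how these handlers are applied, default_backoff_handler returns a decorator that retries a send callable on the module's transient exceptions; the callable signature below is an assumption for the sketch:

from collections.abc import Mapping
from typing import Any

import requests

from airbyte_cdk.sources.streams.http.rate_limiting import default_backoff_handler

session = requests.Session()


@default_backoff_handler(max_tries=5, factor=5, max_time=600)
def send(request: requests.PreparedRequest, request_kwargs: Mapping[str, Any]) -> requests.Response:
    # Retried on TRANSIENT_EXCEPTIONS such as DefaultBackoffException and connection timeouts.
    return session.send(request, **request_kwargs)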
diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py b/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py
index 307f91f40..e792ba0b4 100644
--- a/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py
+++ b/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py
@@ -5,6 +5,7 @@
from .oauth import Oauth2Authenticator, SingleUseRefreshTokenOauth2Authenticator
from .token import BasicHttpAuthenticator, MultipleTokenAuthenticator, TokenAuthenticator
+
__all__ = [
"Oauth2Authenticator",
"SingleUseRefreshTokenOauth2Authenticator",
diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py
index 753d79269..44b86940c 100644
--- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py
+++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py
@@ -4,27 +4,28 @@
import logging
from abc import abstractmethod
+from collections.abc import Mapping, MutableMapping
from json import JSONDecodeError
-from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import Any
import backoff
import pendulum
import requests
from requests.auth import AuthBase
+from ..exceptions import DefaultBackoffException # noqa: TID252
from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
-from ..exceptions import DefaultBackoffException
logger = logging.getLogger("airbyte")
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()
-class AbstractOauth2Authenticator(AuthBase):
+class AbstractOauth2Authenticator(AuthBase): # noqa: PLR0904
"""
Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
is designed to generically perform the refresh flow without regard to how config fields are get/set by
@@ -35,9 +36,9 @@ class AbstractOauth2Authenticator(AuthBase):
def __init__(
self,
- refresh_token_error_status_codes: Tuple[int, ...] = (),
+ refresh_token_error_status_codes: tuple[int, ...] = (),
refresh_token_error_key: str = "",
- refresh_token_error_values: Tuple[str, ...] = (),
+ refresh_token_error_values: tuple[str, ...] = (),
) -> None:
"""
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
@@ -104,7 +105,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
"""
headers = self.get_refresh_request_headers()
- return headers if headers else None
+ return headers if headers else None # noqa: FURB110
def _wrap_refresh_token_exception(
self, exception: requests.exceptions.RequestException
@@ -130,7 +131,7 @@ def _wrap_refresh_token_exception(
),
max_time=300,
)
- def _get_refresh_access_token_response(self) -> Any:
+ def _get_refresh_access_token_response(self) -> Any: # noqa: ANN401
try:
response = requests.request(
method="POST",
@@ -144,30 +145,29 @@ def _get_refresh_access_token_response(self) -> Any:
# An argument could be made to remove the previous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
access_key = response_json.get(self.get_access_token_name())
if not access_key:
- raise Exception(
+ raise Exception( # noqa: TRY002, TRY301
"Token refresh API response was missing access token {self.get_access_token_name()}"
)
add_to_secrets(access_key)
self._log_response(response)
return response_json
- else:
- # log the response even if the request failed for troubleshooting purposes
- self._log_response(response)
- response.raise_for_status()
+ # log the response even if the request failed for troubleshooting purposes
+ self._log_response(response)
+ response.raise_for_status()
except requests.exceptions.RequestException as e:
- if e.response is not None:
- if e.response.status_code == 429 or e.response.status_code >= 500:
- raise DefaultBackoffException(request=e.response.request, response=e.response)
+ if e.response is not None: # noqa: SIM102
+ if e.response.status_code == 429 or e.response.status_code >= 500: # noqa: PLR2004
+ raise DefaultBackoffException(request=e.response.request, response=e.response) # noqa: B904
if self._wrap_refresh_token_exception(e):
message = "Refresh token is invalid or expired. Please re-authenticate from Sources//Settings."
- raise AirbyteTracedException(
+ raise AirbyteTracedException( # noqa: B904
internal_message=message, message=message, failure_type=FailureType.config_error
)
raise
except Exception as e:
- raise Exception(f"Error while refreshing access token: {e}") from e
+ raise Exception(f"Error while refreshing access token: {e}") from e # noqa: TRY002
- def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
+ def refresh_access_token(self) -> tuple[str, str | int]:
"""
Returns the refresh token and its expiration datetime
@@ -179,7 +179,7 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
self.get_expires_in_name()
]
- def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateTime:
+ def _parse_token_expiration_date(self, value: str | int) -> pendulum.DateTime:
"""
Return the expiration datetime of the refresh token
@@ -192,8 +192,7 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateT
f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
)
return pendulum.from_format(str(value), self.token_expiry_date_format)
- else:
- return pendulum.now().add(seconds=int(float(value)))
+ return pendulum.now().add(seconds=int(float(value)))
@property
def token_expiry_is_time_of_expiration(self) -> bool:
@@ -204,7 +203,7 @@ def token_expiry_is_time_of_expiration(self) -> bool:
return False
@property
- def token_expiry_date_format(self) -> Optional[str]:
+ def token_expiry_date_format(self) -> str | None:
"""
Format of the datetime; set this if expires_in is returned as the expiration datetime instead of seconds until it expires
"""
@@ -212,7 +211,7 @@ def token_expiry_date_format(self) -> Optional[str]:
return None
@abstractmethod
- def get_token_refresh_endpoint(self) -> Optional[str]:
+ def get_token_refresh_endpoint(self) -> str | None:
"""Returns the endpoint to refresh the access token"""
@abstractmethod
@@ -236,11 +235,11 @@ def get_refresh_token_name(self) -> str:
"""The refresh token name to authenticate"""
@abstractmethod
- def get_refresh_token(self) -> Optional[str]:
+ def get_refresh_token(self) -> str | None:
"""The token used to refresh the access token when it expires"""
@abstractmethod
- def get_scopes(self) -> List[str]:
+ def get_scopes(self) -> list[str]:
"""List of requested scopes"""
@abstractmethod
@@ -248,7 +247,7 @@ def get_token_expiry_date(self) -> pendulum.DateTime:
"""Expiration date of the access token"""
@abstractmethod
- def set_token_expiry_date(self, value: Union[str, int]) -> None:
+ def set_token_expiry_date(self, value: str | int) -> None:
"""Setter for access token expiration date"""
@abstractmethod
@@ -286,7 +285,7 @@ def access_token(self, value: str) -> str:
"""Setter for the access token"""
@property
- def _message_repository(self) -> Optional[MessageRepository]:
+ def _message_repository(self) -> MessageRepository | None:
"""
The implementation can define a message_repository if it wants debugging logs for HTTP requests
"""
diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py
index ffcc8e851..9afab38b4 100644
--- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py
+++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_token.py
@@ -3,7 +3,8 @@
#
from abc import abstractmethod
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any
import requests
from requests.auth import AuthBase
@@ -12,7 +13,7 @@
class AbstractHeaderAuthenticator(AuthBase):
"""Abstract class for an header-based authenticators that add a header to outgoing HTTP requests."""
- def __call__(self, request: requests.PreparedRequest) -> Any:
+ def __call__(self, request: requests.PreparedRequest) -> Any: # noqa: ANN401
"""Attach the HTTP headers required to authenticate on the HTTP request"""
request.headers.update(self.get_auth_header())
return request
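The `__call__` hook shown above is the standard `requests.auth.AuthBase` protocol: the authenticator mutates each outgoing `PreparedRequest`. A minimal sketch of that protocol with a hypothetical `StaticHeaderAuth` class (not part of the CDK):

```python
# Minimal AuthBase sketch mirroring the __call__ hook above.
import requests
from requests.auth import AuthBase


class StaticHeaderAuth(AuthBase):
    def __init__(self, header: str, value: str) -> None:
        self._header = header
        self._value = value

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        # Attach the HTTP header required to authenticate on this request.
        request.headers[self._header] = self._value
        return request


session = requests.Session()
session.auth = StaticHeaderAuth("Authorization", "Bearer example-token")
```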
diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py
index f244e6508..8333c56e8 100644
--- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py
+++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py
@@ -2,7 +2,8 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
+from collections.abc import Mapping, Sequence
+from typing import Any
import dpath
import pendulum
@@ -24,7 +25,7 @@ class Oauth2Authenticator(AbstractOauth2Authenticator):
If a connector_config is provided, any mutation of its value in the scope of this class will emit an AirbyteControlConnectorConfigMessage.
"""
- def __init__(
+ def __init__( # noqa: ANN204, PLR0913, PLR0917
self,
token_refresh_endpoint: str,
client_id: str,
@@ -33,7 +34,7 @@ def __init__(
client_id_name: str = "client_id",
client_secret_name: str = "client_secret",
refresh_token_name: str = "refresh_token",
- scopes: List[str] | None = None,
+ scopes: list[str] | None = None,
token_expiry_date: pendulum.DateTime | None = None,
token_expiry_date_format: str | None = None,
access_token_name: str = "access_token",
@@ -42,10 +43,10 @@ def __init__(
refresh_request_headers: Mapping[str, Any] | None = None,
grant_type_name: str = "grant_type",
grant_type: str = "refresh_token",
- token_expiry_is_time_of_expiration: bool = False,
- refresh_token_error_status_codes: Tuple[int, ...] = (),
+ token_expiry_is_time_of_expiration: bool = False, # noqa: FBT001, FBT002
+ refresh_token_error_status_codes: tuple[int, ...] = (),
refresh_token_error_key: str = "",
- refresh_token_error_values: Tuple[str, ...] = (),
+ refresh_token_error_values: tuple[str, ...] = (),
):
self._token_refresh_endpoint = token_refresh_endpoint
self._client_secret_name = client_secret_name
@@ -115,7 +116,7 @@ def get_grant_type(self) -> str:
def get_token_expiry_date(self) -> pendulum.DateTime:
return self._token_expiry_date
- def set_token_expiry_date(self, value: Union[str, int]) -> None:
+ def set_token_expiry_date(self, value: str | int) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)
@property
@@ -123,7 +124,7 @@ def token_expiry_is_time_of_expiration(self) -> bool:
return self._token_expiry_is_time_of_expiration
@property
- def token_expiry_date_format(self) -> Optional[str]:
+ def token_expiry_date_format(self) -> str | None:
return self._token_expiry_date_format
@property
@@ -145,11 +146,11 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
client_secret_config_path, refresh_token_config_path constructor arguments.
"""
- def __init__(
+ def __init__( # noqa: ANN204, PLR0913, PLR0917
self,
connector_config: Mapping[str, Any],
token_refresh_endpoint: str,
- scopes: List[str] | None = None,
+ scopes: list[str] | None = None,
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_token_name: str = "refresh_token",
@@ -158,18 +159,18 @@ def __init__(
grant_type_name: str = "grant_type",
grant_type: str = "refresh_token",
client_id_name: str = "client_id",
- client_id: Optional[str] = None,
+ client_id: str | None = None,
client_secret_name: str = "client_secret",
- client_secret: Optional[str] = None,
+ client_secret: str | None = None,
access_token_config_path: Sequence[str] = ("credentials", "access_token"),
refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
token_expiry_date_config_path: Sequence[str] = ("credentials", "token_expiry_date"),
- token_expiry_date_format: Optional[str] = None,
- message_repository: MessageRepository = NoopMessageRepository(),
- token_expiry_is_time_of_expiration: bool = False,
- refresh_token_error_status_codes: Tuple[int, ...] = (),
+ token_expiry_date_format: str | None = None,
+ message_repository: MessageRepository = NoopMessageRepository(), # noqa: B008
+ token_expiry_is_time_of_expiration: bool = False, # noqa: FBT001, FBT002
+ refresh_token_error_status_codes: tuple[int, ...] = (),
refresh_token_error_key: str = "",
- refresh_token_error_values: Tuple[str, ...] = (),
+ refresh_token_error_values: tuple[str, ...] = (),
):
"""
Args:
@@ -282,7 +283,7 @@ def get_token_expiry_date(self) -> pendulum.DateTime:
self._token_expiry_date_config_path,
default="",
)
- return pendulum.now().subtract(days=1) if expiry_date == "" else pendulum.parse(expiry_date) # type: ignore [arg-type, return-value, no-untyped-call]
+ return pendulum.now().subtract(days=1) if expiry_date == "" else pendulum.parse(expiry_date) # type: ignore [arg-type, return-value, no-untyped-call] # noqa: PLC1901
def set_token_expiry_date( # type: ignore[override]
self,
@@ -305,8 +306,7 @@ def get_new_token_expiry_date(
) -> pendulum.DateTime:
if token_expiry_date_format:
return pendulum.from_format(access_token_expires_in, token_expiry_date_format)
- else:
- return pendulum.now("UTC").add(seconds=int(access_token_expires_in))
+ return pendulum.now("UTC").add(seconds=int(access_token_expires_in))
def get_access_token(self) -> str:
"""Retrieve new access and refresh token if the access token has expired.
@@ -324,7 +324,7 @@ def get_access_token(self) -> str:
self.access_token = new_access_token
self.set_refresh_token(new_refresh_token)
self.set_token_expiry_date(new_token_expiry_date)
- # FIXME emit_configuration_as_airbyte_control_message as been deprecated in favor of package airbyte_cdk.sources.message
+ # FIXME emit_configuration_as_airbyte_control_message has been deprecated in favor of package airbyte_cdk.sources.message # noqa: FIX001, TD001, TD004
# Usually, a class shouldn't care about the implementation details but to keep backward compatibility where we print the
# message directly in the console, this is needed
if not isinstance(self._message_repository, NoopMessageRepository):
@@ -337,7 +337,7 @@ def get_access_token(self) -> str:
def refresh_access_token( # type: ignore[override] # Signature doesn't match base class
self,
- ) -> Tuple[str, str, str]:
+ ) -> tuple[str, str, str]:
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
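For reference, a hedged instantiation sketch of `Oauth2Authenticator` using only keyword arguments that appear in the signature above; the endpoint and credential values are placeholders, not a working configuration.

```python
# Placeholder values throughout; construction alone performs no network call.
import pendulum

from airbyte_cdk.sources.streams.http.requests_native_auth import Oauth2Authenticator

authenticator = Oauth2Authenticator(
    token_refresh_endpoint="https://example.com/oauth/token",
    client_id="my-client-id",
    client_secret="my-client-secret",
    refresh_token="my-refresh-token",
    scopes=["read", "write"],
    token_expiry_date=pendulum.now().add(hours=1),
)

# get_auth_header() refreshes lazily: the token endpoint is only called once
# the stored expiry date has passed.
```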
diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/token.py b/airbyte_cdk/sources/streams/http/requests_native_auth/token.py
index eec7fd0c5..bb7a79e60 100644
--- a/airbyte_cdk/sources/streams/http/requests_native_auth/token.py
+++ b/airbyte_cdk/sources/streams/http/requests_native_auth/token.py
@@ -1,10 +1,9 @@
-#
+# # noqa: A005
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import base64
from itertools import cycle
-from typing import List
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import (
AbstractHeaderAuthenticator,
@@ -26,8 +25,8 @@ def auth_header(self) -> str:
def token(self) -> str:
return f"{self._auth_method} {next(self._tokens_iter)}"
- def __init__(
- self, tokens: List[str], auth_method: str = "Bearer", auth_header: str = "Authorization"
+ def __init__( # noqa: ANN204
+ self, tokens: list[str], auth_method: str = "Bearer", auth_header: str = "Authorization"
):
self._auth_method = auth_method
self._auth_header = auth_header
@@ -49,7 +48,7 @@ def auth_header(self) -> str:
def token(self) -> str:
return f"{self._auth_method} {self._token}"
- def __init__(self, token: str, auth_method: str = "Bearer", auth_header: str = "Authorization"):
+ def __init__(self, token: str, auth_method: str = "Bearer", auth_header: str = "Authorization"): # noqa: ANN204
self._auth_header = auth_header
self._auth_method = auth_method
self._token = token
@@ -69,14 +68,14 @@ def auth_header(self) -> str:
def token(self) -> str:
return f"{self._auth_method} {self._token}"
- def __init__(
+ def __init__( # noqa: ANN204
self,
username: str,
password: str = "",
auth_method: str = "Basic",
auth_header: str = "Authorization",
):
- auth_string = f"{username}:{password}".encode("utf8")
+ auth_string = f"{username}:{password}".encode()
b64_encoded = base64.b64encode(auth_string).decode("utf8")
self._auth_header = auth_header
self._auth_method = auth_method
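A quick usage sketch for the three header authenticators touched above (token values are placeholders); the expected headers follow directly from the constructors shown in this hunk:

```python
from airbyte_cdk.sources.streams.http.requests_native_auth import (
    BasicHttpAuthenticator,
    MultipleTokenAuthenticator,
    TokenAuthenticator,
)

single = TokenAuthenticator(token="abc123")
print(single.get_auth_header())    # {'Authorization': 'Bearer abc123'}

# Cycles through the provided tokens, one per header lookup.
rotating = MultipleTokenAuthenticator(tokens=["t1", "t2"])
print(rotating.get_auth_header())  # {'Authorization': 'Bearer t1'}
print(rotating.get_auth_header())  # {'Authorization': 'Bearer t2'}

# Base64-encodes "username:password" for the Authorization header.
basic = BasicHttpAuthenticator(username="user", password="pass")
print(basic.get_auth_header())     # {'Authorization': 'Basic dXNlcjpwYXNz'}
```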
diff --git a/airbyte_cdk/sources/types.py b/airbyte_cdk/sources/types.py
index 3c466ccd8..21bf39b17 100644
--- a/airbyte_cdk/sources/types.py
+++ b/airbyte_cdk/sources/types.py
@@ -1,28 +1,30 @@
-#
+# # noqa: A005
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
from __future__ import annotations
-from typing import Any, ItemsView, Iterator, KeysView, List, Mapping, Optional, ValuesView
+from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView
+from typing import Any
import orjson
+
# A FieldPointer designates a path to a field inside a mapping. For example, retrieving ["k1", "k1.2"] in the object {"k1" :{"k1.2":
# "hello"}] returns "hello"
-FieldPointer = List[str]
+FieldPointer = list[str]
Config = Mapping[str, Any]
ConnectionDefinition = Mapping[str, Any]
StreamState = Mapping[str, Any]
-class Record(Mapping[str, Any]):
- def __init__(
+class Record(Mapping[str, Any]): # noqa: PLW1641
+ def __init__( # noqa: ANN204
self,
data: Mapping[str, Any],
stream_name: str,
- associated_slice: Optional[StreamSlice] = None,
- is_file_transfer_message: bool = False,
+ associated_slice: StreamSlice | None = None,
+ is_file_transfer_message: bool = False, # noqa: FBT001, FBT002
):
self._data = data
self._associated_slice = associated_slice
@@ -34,19 +36,19 @@ def data(self) -> Mapping[str, Any]:
return self._data
@property
- def associated_slice(self) -> Optional[StreamSlice]:
+ def associated_slice(self) -> StreamSlice | None:
return self._associated_slice
def __repr__(self) -> str:
return repr(self._data)
- def __getitem__(self, key: str) -> Any:
+ def __getitem__(self, key: str) -> Any: # noqa: ANN401
return self._data[key]
def __len__(self) -> int:
return len(self._data)
- def __iter__(self) -> Any:
+ def __iter__(self) -> Any: # noqa: ANN401
return iter(self._data)
def __contains__(self, item: object) -> bool:
@@ -68,7 +70,7 @@ def __init__(
*,
partition: Mapping[str, Any],
cursor_slice: Mapping[str, Any],
- extra_fields: Optional[Mapping[str, Any]] = None,
+ extra_fields: Mapping[str, Any] | None = None,
) -> None:
"""
:param partition: The partition keys representing a unique partition in the stream.
@@ -109,10 +111,10 @@ def extra_fields(self) -> Mapping[str, Any]:
def __repr__(self) -> str:
return repr(self._stream_slice)
- def __setitem__(self, key: str, value: Any) -> None:
+ def __setitem__(self, key: str, value: Any) -> None: # noqa: ANN401
raise ValueError("StreamSlice is immutable")
- def __getitem__(self, key: str) -> Any:
+ def __getitem__(self, key: str) -> Any: # noqa: ANN401
return self._stream_slice[key]
def __len__(self) -> int:
@@ -121,7 +123,7 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[str]:
return iter(self._stream_slice)
- def __contains__(self, item: Any) -> bool:
+ def __contains__(self, item: Any) -> bool: # noqa: ANN401
return item in self._stream_slice
def keys(self) -> KeysView[str]:
@@ -133,10 +135,10 @@ def items(self) -> ItemsView[str, Any]:
def values(self) -> ValuesView[Any]:
return self._stream_slice.values()
- def get(self, key: str, default: Any = None) -> Optional[Any]:
+ def get(self, key: str, default: Any = None) -> Any | None: # noqa: ANN401
return self._stream_slice.get(key, default)
- def __eq__(self, other: Any) -> bool:
+ def __eq__(self, other: object) -> bool:
if isinstance(other, dict):
return self._stream_slice == other
if isinstance(other, StreamSlice):
@@ -144,10 +146,10 @@ def __eq__(self, other: Any) -> bool:
return self._partition == other._partition and self._cursor_slice == other._cursor_slice
return False
- def __ne__(self, other: Any) -> bool:
+ def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
- def __json_serializable__(self) -> Any:
+ def __json_serializable__(self) -> Any: # noqa: ANN401, PLW3201
return self._stream_slice
def __hash__(self) -> int:
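Both classes updated above behave as read-only mappings; a short usage sketch with illustrative stream and field names:

```python
# Illustrative names only.
from airbyte_cdk.sources.types import Record, StreamSlice

record = Record(data={"id": 1, "name": "widget"}, stream_name="products")
print(record["id"], len(record), "name" in record)  # 1 2 True

stream_slice = StreamSlice(
    partition={"account_id": "a-1"},
    cursor_slice={"start": "2024-01-01", "end": "2024-01-31"},
)
print(dict(stream_slice))  # merged view of partition + cursor_slice
# stream_slice["start"] = "x"  # would raise ValueError: StreamSlice is immutable
```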
diff --git a/airbyte_cdk/sources/utils/__init__.py b/airbyte_cdk/sources/utils/__init__.py
index b609a6c7a..c941b3045 100644
--- a/airbyte_cdk/sources/utils/__init__.py
+++ b/airbyte_cdk/sources/utils/__init__.py
@@ -1,7 +1,3 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
-
-# Initialize Utils Package
-
-__all__ = ["record_helper"]
diff --git a/airbyte_cdk/sources/utils/casing.py b/airbyte_cdk/sources/utils/casing.py
index 806e077ae..c78d5e91e 100644
--- a/airbyte_cdk/sources/utils/casing.py
+++ b/airbyte_cdk/sources/utils/casing.py
@@ -8,5 +8,5 @@
# https://stackoverflow.com/a/1176023
def camel_to_snake(s: str) -> str:
- s = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s)
- return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s).lower()
+ s = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s) # noqa: RUF039
+ return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s).lower() # noqa: RUF039
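Expected behaviour of `camel_to_snake` given the two regexes above (outputs inferred from the patterns, not from test fixtures):

```python
from airbyte_cdk.sources.utils.casing import camel_to_snake

print(camel_to_snake("HttpRequestMatcher"))  # http_request_matcher
print(camel_to_snake("OAuth2Flow"))          # o_auth2_flow
```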
diff --git a/airbyte_cdk/sources/utils/record_helper.py b/airbyte_cdk/sources/utils/record_helper.py
index 3d2cbcecf..1a56a18cd 100644
--- a/airbyte_cdk/sources/utils/record_helper.py
+++ b/airbyte_cdk/sources/utils/record_helper.py
@@ -2,8 +2,9 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import time
+from collections.abc import Mapping
from collections.abc import Mapping as ABCMapping
-from typing import Any, Mapping, Optional
+from typing import Any
from airbyte_cdk.models import (
AirbyteLogMessage,
@@ -20,9 +21,9 @@
def stream_data_to_airbyte_message(
stream_name: str,
data_or_message: StreamData,
- transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform),
- schema: Optional[Mapping[str, Any]] = None,
- is_file_transfer_message: bool = False,
+ transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform), # noqa: B008
+ schema: Mapping[str, Any] | None = None,
+ is_file_transfer_message: bool = False, # noqa: FBT001, FBT002
) -> AirbyteMessage:
if schema is None:
schema = {}
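A hedged usage sketch of the helper whose signature is modernized above; the stream name and record payload are placeholders, and the default `TypeTransformer` (no transformation) is kept:

```python
from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message

message = stream_data_to_airbyte_message(
    stream_name="products",
    data_or_message={"id": 1, "name": "widget"},
)
print(message.type)           # Type.RECORD
print(message.record.stream)  # products
```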
diff --git a/airbyte_cdk/sources/utils/schema_helpers.py b/airbyte_cdk/sources/utils/schema_helpers.py
index f15578238..241e008d6 100644
--- a/airbyte_cdk/sources/utils/schema_helpers.py
+++ b/airbyte_cdk/sources/utils/schema_helpers.py
@@ -7,7 +7,8 @@
import json
import os
import pkgutil
-from typing import Any, ClassVar, Dict, List, Mapping, MutableMapping, Optional, Tuple
+from collections.abc import Mapping, MutableMapping
+from typing import Any, ClassVar
import jsonref
from jsonschema import RefResolver, validate
@@ -25,21 +26,20 @@ class JsonFileLoader:
pointing to shared_schema.json file instead of shared/shared_schema.json
"""
- def __init__(self, uri_base: str, shared: str):
+ def __init__(self, uri_base: str, shared: str): # noqa: ANN204
self.shared = shared
self.uri_base = uri_base
- def __call__(self, uri: str) -> Dict[str, Any]:
+ def __call__(self, uri: str) -> dict[str, Any]:
uri = uri.replace(self.uri_base, f"{self.uri_base}/{self.shared}/")
- with open(uri) as f:
+ with open(uri) as f: # noqa: PLW1514, PTH123
data = json.load(f)
if isinstance(data, dict):
return data
- else:
- raise ValueError(f"Expected to read a dictionary from {uri}. Got: {data}")
+ raise ValueError(f"Expected to read a dictionary from {uri}. Got: {data}")
-def resolve_ref_links(obj: Any) -> Any:
+def resolve_ref_links(obj: Any) -> Any: # noqa: ANN401
"""
Scan resolved schema and convert jsonref.JsonRef object to JSON serializable dict.
@@ -53,17 +53,15 @@ def resolve_ref_links(obj: Any) -> Any:
if isinstance(obj, dict):
obj.pop("definitions", None)
return obj
- else:
- raise ValueError(f"Expected obj to be a dict. Got {obj}")
- elif isinstance(obj, dict):
+ raise ValueError(f"Expected obj to be a dict. Got {obj}")
+ if isinstance(obj, dict):
return {k: resolve_ref_links(v) for k, v in obj.items()}
- elif isinstance(obj, list):
+ if isinstance(obj, list):
return [resolve_ref_links(item) for item in obj]
- else:
- return obj
+ return obj
-def _expand_refs(schema: Any, ref_resolver: Optional[RefResolver] = None) -> None:
+def _expand_refs(schema: Any, ref_resolver: RefResolver | None = None) -> None: # noqa: ANN401
"""Internal function to iterate over schema and replace all occurrences of $ref with their definitions. Recursive.
:param schema: schema that will be patched
@@ -80,14 +78,14 @@ def _expand_refs(schema: Any, ref_resolver: Optional[RefResolver] = None) -> Non
) # expand refs in definitions as well
schema.update(definition)
else:
- for key, value in schema.items():
+ for key, value in schema.items(): # noqa: B007, PERF102
_expand_refs(value, ref_resolver=ref_resolver)
- elif isinstance(schema, List):
+ elif isinstance(schema, list):
for value in schema:
_expand_refs(value, ref_resolver=ref_resolver)
-def expand_refs(schema: Any) -> None:
+def expand_refs(schema: Any) -> None: # noqa: ANN401
"""Iterate over schema and replace all occurrences of $ref with their definitions.
:param schema: schema that will be patched
@@ -96,7 +94,7 @@ def expand_refs(schema: Any) -> None:
schema.pop("definitions", None) # remove definitions created by $ref
-def rename_key(schema: Any, old_key: str, new_key: str) -> None:
+def rename_key(schema: Any, old_key: str, new_key: str) -> None: # noqa: ANN401
"""Iterate over nested dictionary and replace one key with another. Used to replace anyOf with oneOf. Recursive."
:param schema: schema that will be patched
@@ -106,7 +104,7 @@ def rename_key(schema: Any, old_key: str, new_key: str) -> None:
if not isinstance(schema, MutableMapping):
return
- for key, value in schema.items():
+ for key, value in schema.items(): # noqa: B007, PERF102
rename_key(value, old_key, new_key)
if old_key in schema:
schema[new_key] = schema.pop(old_key)
@@ -115,7 +113,7 @@ def rename_key(schema: Any, old_key: str, new_key: str) -> None:
class ResourceSchemaLoader:
"""JSONSchema loader from package resources"""
- def __init__(self, package_name: str):
+ def __init__(self, package_name: str): # noqa: ANN204
self.package_name = package_name
def get_schema(self, name: str) -> dict[str, Any]:
@@ -134,7 +132,7 @@ def get_schema(self, name: str) -> dict[str, Any]:
schema_filename = f"schemas/{name}.json"
raw_file = pkgutil.get_data(self.package_name, schema_filename)
if not raw_file:
- raise IOError(f"Cannot find file {schema_filename}")
+ raise OSError(f"Cannot find file {schema_filename}")
try:
raw_schema = json.loads(raw_file)
except ValueError as err:
@@ -152,7 +150,7 @@ def _resolve_schema_references(self, raw_schema: dict[str, Any]) -> dict[str, An
package = importlib.import_module(self.package_name)
if package.__file__:
- base = os.path.dirname(package.__file__) + "/"
+ base = os.path.dirname(package.__file__) + "/" # noqa: PTH120
else:
raise ValueError(f"Package {package} does not have a valid __file__ field")
resolved = jsonref.JsonRef.replace_refs(
@@ -161,8 +159,7 @@ def _resolve_schema_references(self, raw_schema: dict[str, Any]) -> dict[str, An
resolved = resolve_ref_links(resolved)
if isinstance(resolved, dict):
return resolved
- else:
- raise ValueError(f"Expected resolved to be a dict. Got {resolved}")
+ raise ValueError(f"Expected resolved to be a dict. Got {resolved}")
def check_config_against_spec_or_exit(
@@ -191,7 +188,7 @@ class InternalConfig(BaseModel):
limit: int = Field(None, alias="_limit")
page_size: int = Field(None, alias="_page_size")
- def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
+ def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401
kwargs["by_alias"] = True
kwargs["exclude_unset"] = True
return super().dict(*args, **kwargs)
@@ -202,13 +199,13 @@ def is_limit_reached(self, records_counter: int) -> bool:
:param records_counter - number of records already read
:return True if limit reached, False otherwise
"""
- if self.limit:
+ if self.limit: # noqa: SIM102
if records_counter >= self.limit:
return True
return False
-def split_config(config: Mapping[str, Any]) -> Tuple[dict[str, Any], InternalConfig]:
+def split_config(config: Mapping[str, Any]) -> tuple[dict[str, Any], InternalConfig]:
"""
Break config map object into 2 instances: first is a dict with user defined
configuration and second is internal config that contains private keys for
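Two of the helpers touched in this file shown in action, with illustrative inputs: `expand_refs` inlines local `$ref` pointers in place, and `split_config` separates underscore-prefixed internal keys from the user config.

```python
from airbyte_cdk.sources.utils.schema_helpers import expand_refs, split_config

schema = {
    "type": "object",
    "properties": {"owner": {"$ref": "#/definitions/person"}},
    "definitions": {"person": {"type": "string"}},
}
expand_refs(schema)
print(schema)  # {'type': 'object', 'properties': {'owner': {'type': 'string'}}}

config, internal = split_config({"api_key": "secret", "_limit": 10})
print(config)          # {'api_key': 'secret'}
print(internal.limit)  # 10
```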
diff --git a/airbyte_cdk/sources/utils/slice_logger.py b/airbyte_cdk/sources/utils/slice_logger.py
index ee802a7a6..f394ef72f 100644
--- a/airbyte_cdk/sources/utils/slice_logger.py
+++ b/airbyte_cdk/sources/utils/slice_logger.py
@@ -5,7 +5,8 @@
import json
import logging
from abc import ABC, abstractmethod
-from typing import Any, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level
from airbyte_cdk.models import Type as MessageType
@@ -19,7 +20,7 @@ class SliceLogger(ABC):
SLICE_LOG_PREFIX = "slice:"
- def create_slice_log_message(self, _slice: Optional[Mapping[str, Any]]) -> AirbyteMessage:
+ def create_slice_log_message(self, _slice: Mapping[str, Any] | None) -> AirbyteMessage:
"""
Mapping is an interface that can be implemented in various ways. However, json.dumps will just do a `str(