Skip to content

Commit

Permalink
rebased on 1.8.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Mucio authored and Francesco Mucio committed Mar 10, 2025
1 parent 491465f commit 4ab196e
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 78 deletions.
15 changes: 13 additions & 2 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _create_request(
path_or_url: str,
method: HTTPMethod,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
hooks: Optional[Hooks] = None,
Expand All @@ -110,10 +111,12 @@ def _create_request(
else:
url = join_url(self.base_url, path_or_url)

request_headers = (self.headers or {}) | (headers or {})

return Request(
method=method,
url=url,
headers=self.headers,
headers=request_headers,
params=params,
json=json,
auth=auth or self.auth,
Expand Down Expand Up @@ -144,6 +147,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) ->
path_or_url=path,
method=method,
params=kwargs.pop("params", None),
headers=kwargs.pop("headers", None),
json=kwargs.pop("json", None),
auth=kwargs.pop("auth", None),
hooks=kwargs.pop("hooks", None),
Expand All @@ -161,6 +165,7 @@ def paginate(
path: str = "",
method: HTTPMethodBasic = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
Expand Down Expand Up @@ -213,7 +218,13 @@ def paginate(
hooks["response"] = [raise_for_status]

request = self._create_request(
path_or_url=path, method=method, params=params, json=json, auth=auth, hooks=hooks
path_or_url=path,
headers=headers,
method=method,
params=params,
json=json,
auth=auth,
hooks=hooks,
)

if paginator:
Expand Down
46 changes: 32 additions & 14 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generic API Source"""

from copy import deepcopy
from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union
import graphlib
Expand Down Expand Up @@ -70,7 +71,11 @@ def rest_api(
) -> List[DltResource]:
"""Creates and configures a REST API source with default settings"""
return rest_api_resources(
{"client": client, "resources": resources, "resource_defaults": resource_defaults}
{
"client": client,
"resources": resources,
"resource_defaults": resource_defaults,
}
)


Expand Down Expand Up @@ -242,6 +247,7 @@ def create_resources(
endpoint_config = cast(Endpoint, endpoint_resource["endpoint"])
request_params = endpoint_config.get("params", {})
request_json = endpoint_config.get("json", None)
request_headers = endpoint_config.get("headers")
paginator = create_paginator(endpoint_config.get("paginator"))
processing_steps = endpoint_resource.pop("processing_steps", [])

Expand Down Expand Up @@ -288,6 +294,7 @@ def process(
def paginate_resource(
method: HTTPMethodBasic,
path: str,
headers: Dict[str, Any],
params: Dict[str, Any],
json: Optional[Dict[str, Any]],
paginator: Optional[BasePaginator],
Expand Down Expand Up @@ -323,6 +330,7 @@ def paginate_resource(
yield from client.paginate(
method=method,
path=path,
headers=headers,
params=params,
json=json,
paginator=paginator,
Expand All @@ -336,6 +344,7 @@ def paginate_resource(
)(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
headers=request_headers,
params=request_params,
json=request_json,
paginator=paginator,
Expand All @@ -355,6 +364,7 @@ def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
request_headers: Optional[Dict[str, Any]],
params: Dict[str, Any],
json: Optional[Dict[str, Any]],
paginator: Optional[BasePaginator],
Expand All @@ -378,23 +388,29 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, expanded_params, updated_json, parent_record = (
process_parent_data_item(
path=path,
item=item,
params=params,
request_json=json,
resolved_params=resolved_params,
include_from_parent=include_from_parent,
incremental=incremental_object,
incremental_value_convert=incremental_cursor_transform,
)
(
formatted_path,
expanded_params,
updated_json,
updated_headers,
parent_record,
) = process_parent_data_item(
path=path,
item=item,
params=params,
request_headers=request_headers,
request_json=json,
resolved_params=resolved_params,
include_from_parent=include_from_parent,
incremental=incremental_object,
incremental_value_convert=incremental_cursor_transform,
)

for child_page in client.paginate(
method=method,
path=formatted_path,
params=expanded_params,
headers=updated_headers,
json=updated_json,
paginator=paginator,
data_selector=data_selector,
Expand All @@ -413,6 +429,7 @@ def paginate_dependent_resource(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
params=base_params,
request_headers=request_headers,
json=request_json,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
Expand Down Expand Up @@ -456,7 +473,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS)
if (
isinstance(
auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials)
auth_config,
(APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials),
)
or has_sensitive_key
):
Expand Down Expand Up @@ -503,7 +521,7 @@ def identity_func(x: Any) -> Any:


def _validate_param_type(
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]],
) -> None:
for _, value in request_params.items():
if isinstance(value, dict) and value.get("type") not in PARAM_TYPES:
Expand Down
30 changes: 24 additions & 6 deletions dlt/sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def setup_incremental_object(


def parse_convert_or_deprecated_transform(
config: Union[IncrementalConfig, Dict[str, Any]]
config: Union[IncrementalConfig, Dict[str, Any]],
) -> Optional[Callable[..., Any]]:
convert = config.get("convert", None)
deprecated_transform = config.get("transform", None)
Expand Down Expand Up @@ -317,15 +317,20 @@ def build_resource_dependency_graph(
endpoint_resource["endpoint"]["path"], available_contexts
)

# Find all expressions in params and json, but error if any of them is not in available_contexts
# Find all expressions in params, json, and headers; raise an error if any of them is not in available_contexts
params_expressions = _find_expressions(endpoint_resource["endpoint"].get("params", {}))
_raise_if_any_not_in(params_expressions, available_contexts, message="params")

json_expressions = _find_expressions(endpoint_resource["endpoint"].get("json", {}))
_raise_if_any_not_in(json_expressions, available_contexts, message="json")

headers_expressions = _find_expressions(endpoint_resource["endpoint"].get("headers", {}))
_raise_if_any_not_in(headers_expressions, available_contexts, message="headers")

resolved_params += _expressions_to_resolved_params(
_filter_resource_expressions(path_expressions | params_expressions | json_expressions)
_filter_resource_expressions(
path_expressions | params_expressions | json_expressions | headers_expressions
)
)

# set of resources in resolved params
Expand Down Expand Up @@ -723,11 +728,12 @@ def process_parent_data_item(
item: Dict[str, Any],
resolved_params: List[ResolvedParam],
params: Optional[Dict[str, Any]] = None,
request_headers: Optional[Dict[str, Any]] = None,
request_json: Optional[Dict[str, Any]] = None,
include_from_parent: Optional[List[str]] = None,
incremental: Optional[Incremental[Any]] = None,
incremental_value_convert: Optional[Callable[..., Any]] = None,
) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
params_values = collect_resolved_values(
item, resolved_params, incremental, incremental_value_convert
)
Expand All @@ -737,10 +743,20 @@ def process_parent_data_item(
None if request_json is None else expand_placeholders(request_json, params_values)
)

expanded_headers = (
None if request_headers is None else expand_placeholders(request_headers, params_values)
)

parent_resource_name = resolved_params[0].resolve_config["resource"]
parent_record = build_parent_record(item, parent_resource_name, include_from_parent)

return expanded_path, expanded_params, expanded_json, parent_record
return (
expanded_path,
expanded_params,
expanded_json,
expanded_headers,
parent_record,
)


def convert_incremental_values(
Expand Down Expand Up @@ -819,7 +835,9 @@ def expand_placeholders(obj: Any, placeholders: Dict[str, Any]) -> Any:


def build_parent_record(
item: Dict[str, Any], parent_resource_name: str, include_from_parent: Optional[List[str]]
item: Dict[str, Any],
parent_resource_name: str,
include_from_parent: Optional[List[str]],
) -> Dict[str, Any]:
"""
Builds a dictionary of the `include_from_parent` fields from the parent,
Expand Down
1 change: 1 addition & 0 deletions dlt/sources/rest_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class Endpoint(TypedDict, total=False):
response_actions: Optional[List[ResponseAction]]
incremental: Optional[IncrementalConfig]
auth: Optional[AuthConfig]
headers: Optional[Dict[str, Any]]


class ProcessingSteps(TypedDict):
Expand Down
Loading

0 comments on commit 4ab196e

Please sign in to comment.