Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable resolve for query and json parameters #2208

Open
wants to merge 9 commits into
base: devel
Choose a base branch
from
48 changes: 37 additions & 11 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generic API Source"""

from copy import deepcopy
from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union
import graphlib
Expand Down Expand Up @@ -68,7 +69,11 @@ def rest_api(
) -> List[DltResource]:
"""Creates and configures a REST API source with default settings"""
return rest_api_resources(
{"client": client, "resources": resources, "resource_defaults": resource_defaults}
{
"client": client,
"resources": resources,
"resource_defaults": resource_defaults,
}
)


Expand Down Expand Up @@ -250,7 +255,9 @@ def create_resources(

resolved_params: List[ResolvedParam] = resolved_param_map[resource_name]

include_from_parent: List[str] = endpoint_resource.get("include_from_parent", [])
include_from_parent: List[str] = endpoint_resource.get(
"include_from_parent", []
)
if not resolved_params and include_from_parent:
raise ValueError(
f"Resource {resource_name} has include_from_parent but is not "
Expand All @@ -273,7 +280,9 @@ def create_resources(

hooks = create_response_hooks(endpoint_config.get("response_actions"))

resource_kwargs = exclude_keys(endpoint_resource, {"endpoint", "include_from_parent"})
resource_kwargs = exclude_keys(
endpoint_resource, {"endpoint", "include_from_parent"}
)

def process(
resource: DltResource,
Expand Down Expand Up @@ -334,18 +343,23 @@ def paginate_resource(
hooks=hooks,
)

resources[resource_name] = process(resources[resource_name], processing_steps)
resources[resource_name] = process(
resources[resource_name], processing_steps
)

else:
first_param = resolved_params[0]
predecessor = resources[first_param.resolve_config["resource"]]

base_params = exclude_keys(request_params, {x.param_name for x in resolved_params})
base_params = exclude_keys(
request_params, {x.param_name for x in resolved_params}
)

def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
request_json: Optional[Dict[str, Any]],
params: Dict[str, Any],
paginator: Optional[BasePaginator],
data_selector: Optional[jsonpath.TJsonPath],
Expand All @@ -368,14 +382,22 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent
formatted_path, parent_record, updated_params, updated_json = (
process_parent_data_item(
path=path,
item=item,
params=params,
request_json=request_json,
resolved_params=resolved_params,
include_from_parent=include_from_parent,
)
)

for child_page in client.paginate(
method=method,
path=formatted_path,
params=params,
params=updated_params,
json=updated_json,
paginator=paginator,
data_selector=data_selector,
hooks=hooks,
Expand All @@ -393,12 +415,15 @@ def paginate_dependent_resource(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
params=base_params,
request_json=request_json,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
hooks=hooks,
)

resources[resource_name] = process(resources[resource_name], processing_steps)
resources[resource_name] = process(
resources[resource_name], processing_steps
)

return resources

Expand Down Expand Up @@ -435,7 +460,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS)
if (
isinstance(
auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials)
auth_config,
(APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials),
)
or has_sensitive_key
):
Expand Down Expand Up @@ -481,7 +507,7 @@ def identity_func(x: Any) -> Any:


def _validate_param_type(
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]],
) -> None:
for _, value in request_params.items():
if isinstance(value, dict) and value.get("type") not in PARAM_TYPES:
Expand Down
Loading
Loading