diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index a619a05a00..2178a20e94 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -100,6 +100,7 @@ def _create_request( path_or_url: str, method: HTTPMethod, params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, auth: Optional[AuthBase] = None, hooks: Optional[Hooks] = None, @@ -110,10 +111,12 @@ def _create_request( else: url = join_url(self.base_url, path_or_url) + request_headers = (self.headers or {}) | (headers or {}) + return Request( method=method, url=url, - headers=self.headers, + headers=request_headers, params=params, json=json, auth=auth or self.auth, @@ -144,6 +147,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) -> path_or_url=path, method=method, params=kwargs.pop("params", None), + headers=kwargs.pop("headers", None), json=kwargs.pop("json", None), auth=kwargs.pop("auth", None), hooks=kwargs.pop("hooks", None), @@ -161,6 +165,7 @@ def paginate( path: str = "", method: HTTPMethodBasic = "GET", params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, auth: Optional[AuthBase] = None, paginator: Optional[BasePaginator] = None, @@ -213,7 +218,13 @@ def paginate( hooks["response"] = [raise_for_status] request = self._create_request( - path_or_url=path, method=method, params=params, json=json, auth=auth, hooks=hooks + path_or_url=path, + headers=headers, + method=method, + params=params, + json=json, + auth=auth, + hooks=hooks, ) if paginator: diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py index 3bf2a1b3d2..824d3878ab 100644 --- a/dlt/sources/rest_api/__init__.py +++ b/dlt/sources/rest_api/__init__.py @@ -1,4 +1,5 @@ """Generic API Source""" + from copy import deepcopy from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union import graphlib @@ -70,7 +71,11 @@ def rest_api( ) -> List[DltResource]: """Creates and configures a REST API source with default settings""" return rest_api_resources( - {"client": client, "resources": resources, "resource_defaults": resource_defaults} + { + "client": client, + "resources": resources, + "resource_defaults": resource_defaults, + } ) @@ -242,6 +247,7 @@ def create_resources( endpoint_config = cast(Endpoint, endpoint_resource["endpoint"]) request_params = endpoint_config.get("params", {}) request_json = endpoint_config.get("json", None) + request_headers = endpoint_config.get("headers") paginator = create_paginator(endpoint_config.get("paginator")) processing_steps = endpoint_resource.pop("processing_steps", []) @@ -288,6 +294,7 @@ def process( def paginate_resource( method: HTTPMethodBasic, path: str, + headers: Dict[str, Any], params: Dict[str, Any], json: Optional[Dict[str, Any]], paginator: Optional[BasePaginator], @@ -323,6 +330,7 @@ def paginate_resource( yield from client.paginate( method=method, path=path, + headers=headers, params=params, json=json, paginator=paginator, @@ -336,6 +344,7 @@ def paginate_resource( )( method=endpoint_config.get("method", "get"), path=endpoint_config.get("path"), + headers=request_headers, params=request_params, json=request_json, paginator=paginator, @@ -355,6 +364,7 @@ def paginate_dependent_resource( items: List[Dict[str, Any]], method: HTTPMethodBasic, path: str, + request_headers: Optional[Dict[str, Any]], params: Dict[str, Any], json: Optional[Dict[str, Any]], paginator: Optional[BasePaginator], @@ -378,23 +388,29 @@ def paginate_dependent_resource( ) for item in items: - formatted_path, expanded_params, updated_json, parent_record = ( - process_parent_data_item( - path=path, - item=item, - params=params, - request_json=json, - resolved_params=resolved_params, - include_from_parent=include_from_parent, - incremental=incremental_object, - incremental_value_convert=incremental_cursor_transform, - ) + ( + formatted_path, + expanded_params, + updated_json, + updated_headers, + parent_record, + ) = process_parent_data_item( + path=path, + item=item, + params=params, + request_headers=request_headers, + request_json=json, + resolved_params=resolved_params, + include_from_parent=include_from_parent, + incremental=incremental_object, + incremental_value_convert=incremental_cursor_transform, ) for child_page in client.paginate( method=method, path=formatted_path, params=expanded_params, + headers=updated_headers, json=updated_json, paginator=paginator, data_selector=data_selector, @@ -413,6 +429,7 @@ def paginate_dependent_resource( method=endpoint_config.get("method", "get"), path=endpoint_config.get("path"), params=base_params, + request_headers=request_headers, json=request_json, paginator=paginator, data_selector=endpoint_config.get("data_selector"), @@ -456,7 +473,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig: has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS) if ( isinstance( - auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials) + auth_config, + (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials), ) or has_sensitive_key ): @@ -503,7 +521,7 @@ def identity_func(x: Any) -> Any: def _validate_param_type( - request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]] + request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]], ) -> None: for _, value in request_params.items(): if isinstance(value, dict) and value.get("type") not in PARAM_TYPES: diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index 50e1b85bd8..9f5702cb3f 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -264,7 +264,7 @@ def setup_incremental_object( def parse_convert_or_deprecated_transform( - config: Union[IncrementalConfig, Dict[str, Any]] + config: Union[IncrementalConfig, Dict[str, Any]], ) -> Optional[Callable[..., Any]]: convert = config.get("convert", None) deprecated_transform = config.get("transform", None) @@ -317,15 +317,20 @@ def build_resource_dependency_graph( endpoint_resource["endpoint"]["path"], available_contexts ) - # Find all expressions in params and json, but error if any of them is not in available_contexts + # Find all expressions in params, json, or header, but error if any of them is not in available_contexts params_expressions = _find_expressions(endpoint_resource["endpoint"].get("params", {})) _raise_if_any_not_in(params_expressions, available_contexts, message="params") json_expressions = _find_expressions(endpoint_resource["endpoint"].get("json", {})) _raise_if_any_not_in(json_expressions, available_contexts, message="json") + headers_expressions = _find_expressions(endpoint_resource["endpoint"].get("headers", {})) + _raise_if_any_not_in(headers_expressions, available_contexts, message="headers") + resolved_params += _expressions_to_resolved_params( - _filter_resource_expressions(path_expressions | params_expressions | json_expressions) + _filter_resource_expressions( + path_expressions | params_expressions | json_expressions | headers_expressions + ) ) # set of resources in resolved params @@ -723,11 +728,12 @@ def process_parent_data_item( item: Dict[str, Any], resolved_params: List[ResolvedParam], params: Optional[Dict[str, Any]] = None, + request_headers: Optional[Dict[str, Any]] = None, request_json: Optional[Dict[str, Any]] = None, include_from_parent: Optional[List[str]] = None, incremental: Optional[Incremental[Any]] = None, incremental_value_convert: Optional[Callable[..., Any]] = None, -) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]: params_values = collect_resolved_values( item, resolved_params, incremental, incremental_value_convert ) @@ -737,10 +743,20 @@ def process_parent_data_item( None if request_json is None else expand_placeholders(request_json, params_values) ) + expanded_headers = ( + None if request_headers is None else expand_placeholders(request_headers, params_values) + ) + parent_resource_name = resolved_params[0].resolve_config["resource"] parent_record = build_parent_record(item, parent_resource_name, include_from_parent) - return expanded_path, expanded_params, expanded_json, parent_record + return ( + expanded_path, + expanded_params, + expanded_json, + expanded_headers, + parent_record, + ) def convert_incremental_values( @@ -819,7 +835,9 @@ def expand_placeholders(obj: Any, placeholders: Dict[str, Any]) -> Any: def build_parent_record( - item: Dict[str, Any], parent_resource_name: str, include_from_parent: Optional[List[str]] + item: Dict[str, Any], + parent_resource_name: str, + include_from_parent: Optional[List[str]], ) -> Dict[str, Any]: """ Builds a dictionary of the `include_from_parent` fields from the parent, diff --git a/dlt/sources/rest_api/typing.py b/dlt/sources/rest_api/typing.py index 2d3b23c737..076f338e8d 100644 --- a/dlt/sources/rest_api/typing.py +++ b/dlt/sources/rest_api/typing.py @@ -249,6 +249,7 @@ class Endpoint(TypedDict, total=False): response_actions: Optional[List[ResponseAction]] incremental: Optional[IncrementalConfig] auth: Optional[AuthConfig] + headers: Optional[Dict[str, Any]] class ProcessingSteps(TypedDict): diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index 3af3196c5e..a37316ee25 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -92,11 +92,13 @@ def test_process_parent_data_item() -> None: ResolvedParam("id", {"field": "obj_id", "resource": "issues", "type": "resolve"}) ] - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/{id}/comments", - item={"obj_id": 12345}, - resolved_params=resolved_params, - include_from_parent=None, + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/{id}/comments", + item={"obj_id": 12345}, + resolved_params=resolved_params, + include_from_parent=None, + ) ) assert bound_path == "dlt-hub/dlt/issues/12345/comments" assert expanded_params == {} # defaults to empty dict @@ -104,66 +106,79 @@ def test_process_parent_data_item() -> None: assert parent_record == {} # same but with empty params and json - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/{id}/comments", - item={"obj_id": 12345}, - params={}, - request_json={}, - resolved_params=resolved_params, + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/{id}/comments", + item={"obj_id": 12345}, + params={}, + request_json={}, + resolved_params=resolved_params, + ) ) # those got propagated assert expanded_params == {} assert request_json == {} # generates empty body! # also test params and json - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/comments", - item={"obj_id": 12345}, - params={"orig_id": "{id}"}, - request_json={"orig_id": "{id}"}, - resolved_params=resolved_params, + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/comments", + item={"obj_id": 12345}, + params={"orig_id": "{id}"}, + request_json={"orig_id": "{id}"}, + resolved_params=resolved_params, + ) ) assert expanded_params == {"orig_id": "12345"} assert request_json == {"orig_id": "12345"} - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/{id}/comments", - item={"obj_id": 12345}, - resolved_params=resolved_params, - include_from_parent=["obj_id"], + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/{id}/comments", + item={"obj_id": 12345}, + resolved_params=resolved_params, + include_from_parent=["obj_id"], + ) ) assert parent_record == {"_issues_obj_id": 12345} - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/{id}/comments", - item={"obj_id": 12345, "obj_node": "node_1"}, - resolved_params=resolved_params, - include_from_parent=["obj_id", "obj_node"], + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/{id}/comments", + item={"obj_id": 12345, "obj_node": "node_1"}, + resolved_params=resolved_params, + include_from_parent=["obj_id", "obj_node"], + ) ) assert parent_record == {"_issues_obj_id": 12345, "_issues_obj_node": "node_1"} # Test resource field reference in path resolved_params_reference = [ ResolvedParam( - "resources.issues.obj_id", {"field": "obj_id", "resource": "issues", "type": "resolve"} + "resources.issues.obj_id", + {"field": "obj_id", "resource": "issues", "type": "resolve"}, ) ] - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/{resources.issues.obj_id}/comments", - item={"obj_id": 12345, "obj_node": "node_1"}, - resolved_params=resolved_params_reference, - include_from_parent=["obj_id", "obj_node"], + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/{resources.issues.obj_id}/comments", + item={"obj_id": 12345, "obj_node": "node_1"}, + resolved_params=resolved_params_reference, + include_from_parent=["obj_id", "obj_node"], + ) ) assert bound_path == "dlt-hub/dlt/issues/12345/comments" # Test resource field reference in params - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/comments", - item={"obj_id": 12345, "obj_node": "node_1"}, - params={"id": "{resources.issues.obj_id}"}, - request_json={"id": "{resources.issues.obj_id}"}, - resolved_params=resolved_params_reference, - include_from_parent=["obj_id", "obj_node"], + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/comments", + item={"obj_id": 12345, "obj_node": "node_1"}, + params={"id": "{resources.issues.obj_id}"}, + request_json={"id": "{resources.issues.obj_id}"}, + resolved_params=resolved_params_reference, + include_from_parent=["obj_id", "obj_node"], + ) ) assert bound_path == "dlt-hub/dlt/issues/comments" assert expanded_params == {"id": "12345"} @@ -172,16 +187,19 @@ def test_process_parent_data_item() -> None: # Test nested data resolved_param_nested = [ ResolvedParam( - "id", {"field": "some_results.obj_id", "resource": "issues", "type": "resolve"} + "id", + {"field": "some_results.obj_id", "resource": "issues", "type": "resolve"}, ) ] item = {"some_results": {"obj_id": 12345}} - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/{id}/comments", - item=item, - params={}, - resolved_params=resolved_param_nested, - include_from_parent=None, + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/{id}/comments", + item=item, + params={}, + resolved_params=resolved_param_nested, + include_from_parent=None, + ) ) assert bound_path == "dlt-hub/dlt/issues/12345/comments" @@ -217,12 +235,14 @@ def test_process_parent_data_item() -> None: ResolvedParam("id", {"field": "id", "resource": "comments", "type": "resolve"}), ] - bound_path, expanded_params, request_json, parent_record = process_parent_data_item( - path="dlt-hub/dlt/issues/{issue_id}/comments/{id}", - item={"issue": 12345, "id": 56789}, - params={}, - resolved_params=multi_resolve_params, - include_from_parent=None, + bound_path, expanded_params, request_json, request_headers, parent_record = ( + process_parent_data_item( + path="dlt-hub/dlt/issues/{issue_id}/comments/{id}", + item={"issue": 12345, "id": 56789}, + params={}, + resolved_params=multi_resolve_params, + include_from_parent=None, + ) ) assert bound_path == "dlt-hub/dlt/issues/12345/comments/56789" assert parent_record == {} diff --git a/tests/sources/rest_api/conftest.py b/tests/sources/rest_api/conftest.py index eb89a58a4a..8c9083580e 100644 --- a/tests/sources/rest_api/conftest.py +++ b/tests/sources/rest_api/conftest.py @@ -103,7 +103,7 @@ def posts_header_cursor(request, context): response = paginator.page_records if paginator.next_page_url_params: - context.headers["cursor"] = f"{page_number+1}" + context.headers["cursor"] = f"{page_number + 1}" return response @@ -144,6 +144,18 @@ def post_comments_via_query_param(request, context): post_id = int(request.qs.get("post_id", [0])[0]) return paginate_by_page_number(request, generate_comments(post_id)) + @router.get(r"/post_comments_via_headers(\?.*)?$") + def post_comments_via_header(request, context): + # raise ValueError(f"{request.headers}") + post_id = int(request.headers.get("post_id")) + return paginate_by_page_number(request, generate_comments(post_id)) + + @router.post(r"/post_comments(\?.*)?$") + def post_comments_via_json_param(request, context): + body = request.json() + post_id = int(body.get("post_id", 0)) + return paginate_by_page_number(request, generate_comments(post_id)) + @router.get(r"/posts/\d+$") def post_detail(request, context): post_id = request.url.split("/")[-1] @@ -154,6 +166,18 @@ def post_detail_via_query_param(request, context): post_id = int(request.qs.get("post_id", [0])[0]) return {"id": int(post_id), "body": f"Post body {post_id}"} + @router.get(r"/post_detail_via_headers(\?.*)?$") + def post_detail_via_header(request, context): + # raise ValueError(f"{request.headers}") + post_id = int(request.headers.get("post_id")) + return {"id": int(post_id), "body": f"Post body {post_id}"} + + @router.post(r"/post_detail(\?.*)?$") + def post_detail_via_json_param(request, context): + body = request.json() + post_id = int(body.get("post_id", 0)) + return {"id": int(post_id), "body": f"Post body {post_id}"} + @router.get(r"/posts/\d+/some_details_404") def post_detail_404(request, context): """Return 404 for post with id > 0. Used to test ignoring 404 errors.""" diff --git a/tests/sources/rest_api/integration/test_offline.py b/tests/sources/rest_api/integration/test_offline.py index bfaf91d9eb..fec11ed66a 100644 --- a/tests/sources/rest_api/integration/test_offline.py +++ b/tests/sources/rest_api/integration/test_offline.py @@ -105,6 +105,64 @@ }, id="interpolated_params_in_query_string", ), + # Using interpolated params in json object + pytest.param( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + "posts", + { + "name": "post_comments", + "endpoint": { + "path": "post_comments", + "method": "POST", + "json": { + "post_id": "{resources.posts.id}", + }, + }, + }, + { + "name": "post_details", + "endpoint": { + "path": "post_detail", + "method": "POST", + "json": { + "post_id": "{resources.posts.id}", + }, + }, + }, + ], + }, + id="interpolated_params_in_json_string", + ), + # Using interpolated params in the headers + pytest.param( + { + "client": {"base_url": "https://api.example.com"}, + "resources": [ + "posts", + { + "name": "post_comments", + "endpoint": { + "path": "post_comments_via_headers", + "headers": { + "post_id": "{resources.posts.id}", + }, + }, + }, + { + "name": "post_details", + "endpoint": { + "path": "post_detail_via_headers", + "headers": { + "post_id": "{resources.posts.id}", + }, + }, + }, + ], + }, + id="interpolated_params_in_the_headers", + ), ], ) def test_load_mock_api(mock_api_server, config): @@ -471,7 +529,11 @@ def test_dependent_resource_query_string_params( pytest.param( { "path": "post_detail", - "params": {"post_id": "{resources.posts.id}", "sort": "desc", "locale": ""}, + "params": { + "post_id": "{resources.posts.id}", + "sort": "desc", + "locale": "", + }, }, {"sort": ["desc"], "locale": [""]}, id="one_static_param_is_empty", @@ -657,7 +719,11 @@ def update_request(self, request: Request) -> None: "endpoint": { "path": "/posts/search", "method": "POST", - "json": {"ids_greater_than": 50, "page_size": 25, "page_count": 4}, + "json": { + "ids_greater_than": 50, + "page_size": 25, + "page_count": 4, + }, "paginator": JSONBodyPageCursorPaginator(), }, }