diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 75ec52f..b44b287 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.3.0" + ".": "2.4.0" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index c0cf1c6..7f8b2d8 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ -configured_endpoints: 6 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/channel3%2Fpublic-sdk-a86114ce098255360b65356eedfe9c93f9db44aa99cb90d8c36756d39c2c2de0.yml -openapi_spec_hash: 113158785b160e8b67d66e2820137df8 +configured_endpoints: 5 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/channel3%2Fpublic-sdk-3366fbfe5ea0c833c184c33d00d301e23e23c0cfa7398b0ebc34a90ab03f65fd.yml +openapi_spec_hash: e428021f51d697d779a5ddd3ee7109b7 config_hash: 0ec132fef7cbcef12aebece85f2ef2b1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 683c5ae..388dfc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,24 @@ # Changelog +## 2.4.0 (2025-11-10) + +Full Changelog: [v2.3.0...v2.4.0](https://github.com/channel3-ai/sdk-python/compare/v2.3.0...v2.4.0) + +### Features + +* **api:** api update ([f50613e](https://github.com/channel3-ai/sdk-python/commit/f50613e9cf5d2067b282896dfafcb73ef99ee0ae)) + + +### Bug Fixes + +* **client:** close streams without requiring full consumption ([4d44bb1](https://github.com/channel3-ai/sdk-python/commit/4d44bb1d045d7d2447bc305dd8463e27dd170175)) + + +### Chores + +* **internal/tests:** avoid race condition with implicit client cleanup ([ef7fe91](https://github.com/channel3-ai/sdk-python/commit/ef7fe91e5e5c14e15172642362aee07b96f19b3d)) +* **internal:** grammar fix (it's -> its) ([ae7ab14](https://github.com/channel3-ai/sdk-python/commit/ae7ab1494bc21b8b5aeefb91af3353b696949bcb)) + ## 2.3.0 (2025-10-28) Full Changelog: [v2.2.1...v2.3.0](https://github.com/channel3-ai/sdk-python/compare/v2.2.1...v2.3.0) diff --git a/api.md b/api.md index 046c050..4359e79 100644 --- a/api.md +++ b/api.md @@ -33,13 +33,12 @@ Methods: Types: ```python -from channel3_sdk.types import Brand, BrandListResponse +from channel3_sdk.types import Brand ``` Methods: -- client.brands.retrieve(brand_id) -> Brand -- client.brands.list(\*\*params) -> BrandListResponse +- client.brands.list(\*\*params) -> Brand # Enrich diff --git a/pyproject.toml b/pyproject.toml index 1ebbba5..3961e1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "channel3_sdk" -version = "2.3.0" +version = "2.4.0" description = "The official Python library for the channel3 API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/channel3_sdk/_streaming.py b/src/channel3_sdk/_streaming.py index eefa638..e5f4e09 100644 --- a/src/channel3_sdk/_streaming.py +++ b/src/channel3_sdk/_streaming.py @@ -57,9 +57,8 @@ def __stream__(self) -> Iterator[_T]: for sse in iterator: yield process_data(data=sse.json(), cast_to=cast_to, response=response) - # Ensure the entire stream is consumed - for _sse in iterator: - ... + # As we might not fully consume the response stream, we need to close it explicitly + response.close() def __enter__(self) -> Self: return self @@ -121,9 +120,8 @@ async def __stream__(self) -> AsyncIterator[_T]: async for sse in iterator: yield process_data(data=sse.json(), cast_to=cast_to, response=response) - # Ensure the entire stream is consumed - async for _sse in iterator: - ... + # As we might not fully consume the response stream, we need to close it explicitly + await response.aclose() async def __aenter__(self) -> Self: return self diff --git a/src/channel3_sdk/_utils/_utils.py b/src/channel3_sdk/_utils/_utils.py index 50d5926..eec7f4a 100644 --- a/src/channel3_sdk/_utils/_utils.py +++ b/src/channel3_sdk/_utils/_utils.py @@ -133,7 +133,7 @@ def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]: # Type safe methods for narrowing types with TypeVars. # The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], # however this cause Pyright to rightfully report errors. As we know we don't -# care about the contained types we can safely use `object` in it's place. +# care about the contained types we can safely use `object` in its place. # # There are two separate functions defined, `is_*` and `is_*_t` for different use cases. # `is_*` is for when you're dealing with an unknown input diff --git a/src/channel3_sdk/_version.py b/src/channel3_sdk/_version.py index 1fac827..66f444f 100644 --- a/src/channel3_sdk/_version.py +++ b/src/channel3_sdk/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "channel3_sdk" -__version__ = "2.3.0" # x-release-please-version +__version__ = "2.4.0" # x-release-please-version diff --git a/src/channel3_sdk/resources/brands.py b/src/channel3_sdk/resources/brands.py index 306db20..57aa10f 100644 --- a/src/channel3_sdk/resources/brands.py +++ b/src/channel3_sdk/resources/brands.py @@ -2,12 +2,10 @@ from __future__ import annotations -from typing import Optional - import httpx from ..types import brand_list_params -from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given +from .._types import Body, Query, Headers, NotGiven, not_given from .._utils import maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -19,7 +17,6 @@ ) from ..types.brand import Brand from .._base_client import make_request_options -from ..types.brand_list_response import BrandListResponse __all__ = ["BrandsResource", "AsyncBrandsResource"] @@ -44,54 +41,19 @@ def with_streaming_response(self) -> BrandsResourceWithStreamingResponse: """ return BrandsResourceWithStreamingResponse(self) - def retrieve( - self, - brand_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = not_given, - ) -> Brand: - """ - Get detailed information for a specific brand by its ID. - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not brand_id: - raise ValueError(f"Expected a non-empty value for `brand_id` but received {brand_id!r}") - return self._get( - f"/v0/brands/{brand_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Brand, - ) - def list( self, *, - page: int | Omit = omit, - query: Optional[str] | Omit = omit, - size: int | Omit = omit, + query: str, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, - ) -> BrandListResponse: + ) -> Brand: """ - Get all brands that the vendor currently sells. + Find a brand by name. Args: extra_headers: Send extra headers @@ -109,16 +71,9 @@ def list( extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=maybe_transform( - { - "page": page, - "query": query, - "size": size, - }, - brand_list_params.BrandListParams, - ), + query=maybe_transform({"query": query}, brand_list_params.BrandListParams), ), - cast_to=BrandListResponse, + cast_to=Brand, ) @@ -142,54 +97,19 @@ def with_streaming_response(self) -> AsyncBrandsResourceWithStreamingResponse: """ return AsyncBrandsResourceWithStreamingResponse(self) - async def retrieve( - self, - brand_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = not_given, - ) -> Brand: - """ - Get detailed information for a specific brand by its ID. - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not brand_id: - raise ValueError(f"Expected a non-empty value for `brand_id` but received {brand_id!r}") - return await self._get( - f"/v0/brands/{brand_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Brand, - ) - async def list( self, *, - page: int | Omit = omit, - query: Optional[str] | Omit = omit, - size: int | Omit = omit, + query: str, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, - ) -> BrandListResponse: + ) -> Brand: """ - Get all brands that the vendor currently sells. + Find a brand by name. Args: extra_headers: Send extra headers @@ -207,16 +127,9 @@ async def list( extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=await async_maybe_transform( - { - "page": page, - "query": query, - "size": size, - }, - brand_list_params.BrandListParams, - ), + query=await async_maybe_transform({"query": query}, brand_list_params.BrandListParams), ), - cast_to=BrandListResponse, + cast_to=Brand, ) @@ -224,9 +137,6 @@ class BrandsResourceWithRawResponse: def __init__(self, brands: BrandsResource) -> None: self._brands = brands - self.retrieve = to_raw_response_wrapper( - brands.retrieve, - ) self.list = to_raw_response_wrapper( brands.list, ) @@ -236,9 +146,6 @@ class AsyncBrandsResourceWithRawResponse: def __init__(self, brands: AsyncBrandsResource) -> None: self._brands = brands - self.retrieve = async_to_raw_response_wrapper( - brands.retrieve, - ) self.list = async_to_raw_response_wrapper( brands.list, ) @@ -248,9 +155,6 @@ class BrandsResourceWithStreamingResponse: def __init__(self, brands: BrandsResource) -> None: self._brands = brands - self.retrieve = to_streamed_response_wrapper( - brands.retrieve, - ) self.list = to_streamed_response_wrapper( brands.list, ) @@ -260,9 +164,6 @@ class AsyncBrandsResourceWithStreamingResponse: def __init__(self, brands: AsyncBrandsResource) -> None: self._brands = brands - self.retrieve = async_to_streamed_response_wrapper( - brands.retrieve, - ) self.list = async_to_streamed_response_wrapper( brands.list, ) diff --git a/src/channel3_sdk/resources/enrich.py b/src/channel3_sdk/resources/enrich.py index 518cf9a..c8c46af 100644 --- a/src/channel3_sdk/resources/enrich.py +++ b/src/channel3_sdk/resources/enrich.py @@ -53,7 +53,8 @@ def enrich_url( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> EnrichEnrichURLResponse: """ - Enrich a product URL with additional information. + Search by product URL, get back full product information from Channel3’s product + database. Args: url: The URL of the product to enrich @@ -108,7 +109,8 @@ async def enrich_url( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> EnrichEnrichURLResponse: """ - Enrich a product URL with additional information. + Search by product URL, get back full product information from Channel3’s product + database. Args: url: The URL of the product to enrich diff --git a/src/channel3_sdk/resources/search.py b/src/channel3_sdk/resources/search.py index ca0bf09..8d2f327 100644 --- a/src/channel3_sdk/resources/search.py +++ b/src/channel3_sdk/resources/search.py @@ -70,7 +70,8 @@ def perform( context: Optional customer information to personalize search results - filters: Optional filters + filters: Optional filters. Search will only consider products that match all of the + filters. image_url: Image URL @@ -154,7 +155,8 @@ async def perform( context: Optional customer information to personalize search results - filters: Optional filters + filters: Optional filters. Search will only consider products that match all of the + filters. image_url: Image URL diff --git a/src/channel3_sdk/types/__init__.py b/src/channel3_sdk/types/__init__.py index cf95e97..a520e33 100644 --- a/src/channel3_sdk/types/__init__.py +++ b/src/channel3_sdk/types/__init__.py @@ -7,7 +7,6 @@ from .variant import Variant as Variant from .brand_list_params import BrandListParams as BrandListParams from .availability_status import AvailabilityStatus as AvailabilityStatus -from .brand_list_response import BrandListResponse as BrandListResponse from .search_perform_params import SearchPerformParams as SearchPerformParams from .search_perform_response import SearchPerformResponse as SearchPerformResponse from .enrich_enrich_url_params import EnrichEnrichURLParams as EnrichEnrichURLParams diff --git a/src/channel3_sdk/types/brand.py b/src/channel3_sdk/types/brand.py index ac624c5..b2c1a33 100644 --- a/src/channel3_sdk/types/brand.py +++ b/src/channel3_sdk/types/brand.py @@ -12,6 +12,9 @@ class Brand(BaseModel): name: str + best_commission_rate: Optional[float] = None + """The maximum commission rate for the brand, as a percentage""" + description: Optional[str] = None logo_url: Optional[str] = None diff --git a/src/channel3_sdk/types/brand_list_params.py b/src/channel3_sdk/types/brand_list_params.py index fdbb5fe..9475030 100644 --- a/src/channel3_sdk/types/brand_list_params.py +++ b/src/channel3_sdk/types/brand_list_params.py @@ -2,15 +2,10 @@ from __future__ import annotations -from typing import Optional -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict __all__ = ["BrandListParams"] class BrandListParams(TypedDict, total=False): - page: int - - query: Optional[str] - - size: int + query: Required[str] diff --git a/src/channel3_sdk/types/brand_list_response.py b/src/channel3_sdk/types/brand_list_response.py deleted file mode 100644 index 8caf935..0000000 --- a/src/channel3_sdk/types/brand_list_response.py +++ /dev/null @@ -1,25 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import List - -from .brand import Brand -from .._models import BaseModel - -__all__ = ["BrandListResponse", "Pagination"] - - -class Pagination(BaseModel): - current_page: int - - page_size: int - - total_count: int - - total_pages: int - - -class BrandListResponse(BaseModel): - items: List[Brand] - - pagination: Pagination - """Pagination metadata for responses""" diff --git a/src/channel3_sdk/types/product_retrieve_response.py b/src/channel3_sdk/types/product_retrieve_response.py index 393efa8..79f5244 100644 --- a/src/channel3_sdk/types/product_retrieve_response.py +++ b/src/channel3_sdk/types/product_retrieve_response.py @@ -26,6 +26,8 @@ class ProductRetrieveResponse(BaseModel): brand_name: Optional[str] = None + categories: Optional[List[str]] = None + description: Optional[str] = None gender: Optional[Literal["male", "female", "unisex"]] = None diff --git a/src/channel3_sdk/types/search_perform_params.py b/src/channel3_sdk/types/search_perform_params.py index 32794ad..fe34c78 100644 --- a/src/channel3_sdk/types/search_perform_params.py +++ b/src/channel3_sdk/types/search_perform_params.py @@ -22,7 +22,10 @@ class SearchPerformParams(TypedDict, total=False): """Optional customer information to personalize search results""" filters: Filters - """Optional filters""" + """Optional filters. + + Search will only consider products that match all of the filters. + """ image_url: Optional[str] """Image URL""" @@ -36,6 +39,13 @@ class SearchPerformParams(TypedDict, total=False): class Config(TypedDict, total=False): enrich_query: bool + """ + If True, search will use AI to enrich the query, for example pulling the gender, + brand, and price range from the query. + """ + + monetizable_only: bool + """If True, search will only consider products that offer commission.""" redirect_mode: Optional[Literal["brand", "price", "commission"]] """ @@ -44,8 +54,6 @@ class Config(TypedDict, total=False): to the brand's product page """ - semantic_search: bool - class FiltersPrice(TypedDict, total=False): max_price: Optional[float] @@ -57,15 +65,28 @@ class FiltersPrice(TypedDict, total=False): class Filters(TypedDict, total=False): availability: Optional[List[AvailabilityStatus]] - """List of availability statuses""" + """If provided, only products with these availability statuses will be returned""" brand_ids: Optional[SequenceNotStr[str]] - """List of brand IDs""" + """If provided, only products from these brands will be returned""" + + category_ids: Optional[SequenceNotStr[str]] + """If provided, only products from these categories will be returned""" + + condition: Optional[Literal["new", "refurbished", "used"]] + """Filter by product condition. + + Incubating: condition data is currently incomplete; products without condition + data will be included in all condition filter results. + """ exclude_product_ids: Optional[SequenceNotStr[str]] - """List of product IDs to exclude""" + """If provided, products with these IDs will be excluded from the results""" gender: Optional[Literal["male", "female", "unisex"]] price: Optional[FiltersPrice] """Price filter. Values are inclusive.""" + + website_ids: Optional[SequenceNotStr[str]] + """If provided, only products from these websites will be returned""" diff --git a/src/channel3_sdk/types/search_perform_response.py b/src/channel3_sdk/types/search_perform_response.py index 45c5620..3e257fe 100644 --- a/src/channel3_sdk/types/search_perform_response.py +++ b/src/channel3_sdk/types/search_perform_response.py @@ -28,6 +28,8 @@ class SearchPerformResponseItem(BaseModel): url: str + categories: Optional[List[str]] = None + description: Optional[str] = None variants: Optional[List[Variant]] = None diff --git a/tests/api_resources/test_brands.py b/tests/api_resources/test_brands.py index 7cc8fb5..bb6b914 100644 --- a/tests/api_resources/test_brands.py +++ b/tests/api_resources/test_brands.py @@ -9,7 +9,7 @@ from tests.utils import assert_matches_type from channel3_sdk import Channel3, AsyncChannel3 -from channel3_sdk.types import Brand, BrandListResponse +from channel3_sdk.types import Brand base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -17,83 +17,37 @@ class TestBrands: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - def test_method_retrieve(self, client: Channel3) -> None: - brand = client.brands.retrieve( - "brand_id", - ) - assert_matches_type(Brand, brand, path=["response"]) - - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - def test_raw_response_retrieve(self, client: Channel3) -> None: - response = client.brands.with_raw_response.retrieve( - "brand_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - brand = response.parse() - assert_matches_type(Brand, brand, path=["response"]) - - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - def test_streaming_response_retrieve(self, client: Channel3) -> None: - with client.brands.with_streaming_response.retrieve( - "brand_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - brand = response.parse() - assert_matches_type(Brand, brand, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - def test_path_params_retrieve(self, client: Channel3) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `brand_id` but received ''"): - client.brands.with_raw_response.retrieve( - "", - ) - @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_list(self, client: Channel3) -> None: - brand = client.brands.list() - assert_matches_type(BrandListResponse, brand, path=["response"]) - - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - def test_method_list_with_all_params(self, client: Channel3) -> None: brand = client.brands.list( - page=0, query="query", - size=0, ) - assert_matches_type(BrandListResponse, brand, path=["response"]) + assert_matches_type(Brand, brand, path=["response"]) @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_raw_response_list(self, client: Channel3) -> None: - response = client.brands.with_raw_response.list() + response = client.brands.with_raw_response.list( + query="query", + ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" brand = response.parse() - assert_matches_type(BrandListResponse, brand, path=["response"]) + assert_matches_type(Brand, brand, path=["response"]) @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_streaming_response_list(self, client: Channel3) -> None: - with client.brands.with_streaming_response.list() as response: + with client.brands.with_streaming_response.list( + query="query", + ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" brand = response.parse() - assert_matches_type(BrandListResponse, brand, path=["response"]) + assert_matches_type(Brand, brand, path=["response"]) assert cast(Any, response.is_closed) is True @@ -103,82 +57,36 @@ class TestAsyncBrands: "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] ) - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - async def test_method_retrieve(self, async_client: AsyncChannel3) -> None: - brand = await async_client.brands.retrieve( - "brand_id", - ) - assert_matches_type(Brand, brand, path=["response"]) - - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncChannel3) -> None: - response = await async_client.brands.with_raw_response.retrieve( - "brand_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - brand = await response.parse() - assert_matches_type(Brand, brand, path=["response"]) - - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncChannel3) -> None: - async with async_client.brands.with_streaming_response.retrieve( - "brand_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - brand = await response.parse() - assert_matches_type(Brand, brand, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - async def test_path_params_retrieve(self, async_client: AsyncChannel3) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `brand_id` but received ''"): - await async_client.brands.with_raw_response.retrieve( - "", - ) - @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_list(self, async_client: AsyncChannel3) -> None: - brand = await async_client.brands.list() - assert_matches_type(BrandListResponse, brand, path=["response"]) - - @pytest.mark.skip(reason="Prism tests are disabled") - @parametrize - async def test_method_list_with_all_params(self, async_client: AsyncChannel3) -> None: brand = await async_client.brands.list( - page=0, query="query", - size=0, ) - assert_matches_type(BrandListResponse, brand, path=["response"]) + assert_matches_type(Brand, brand, path=["response"]) @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_raw_response_list(self, async_client: AsyncChannel3) -> None: - response = await async_client.brands.with_raw_response.list() + response = await async_client.brands.with_raw_response.list( + query="query", + ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" brand = await response.parse() - assert_matches_type(BrandListResponse, brand, path=["response"]) + assert_matches_type(Brand, brand, path=["response"]) @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_streaming_response_list(self, async_client: AsyncChannel3) -> None: - async with async_client.brands.with_streaming_response.list() as response: + async with async_client.brands.with_streaming_response.list( + query="query", + ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" brand = await response.parse() - assert_matches_type(BrandListResponse, brand, path=["response"]) + assert_matches_type(Brand, brand, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_search.py b/tests/api_resources/test_search.py index 7c9d28d..fa4236f 100644 --- a/tests/api_resources/test_search.py +++ b/tests/api_resources/test_search.py @@ -30,19 +30,22 @@ def test_method_perform_with_all_params(self, client: Channel3) -> None: base64_image="base64_image", config={ "enrich_query": True, + "monetizable_only": True, "redirect_mode": "brand", - "semantic_search": True, }, context="context", filters={ "availability": ["InStock"], "brand_ids": ["string"], + "category_ids": ["string"], + "condition": "new", "exclude_product_ids": ["string"], "gender": "male", "price": { "max_price": 0, "min_price": 0, }, + "website_ids": ["string"], }, image_url="image_url", limit=0, @@ -91,19 +94,22 @@ async def test_method_perform_with_all_params(self, async_client: AsyncChannel3) base64_image="base64_image", config={ "enrich_query": True, + "monetizable_only": True, "redirect_mode": "brand", - "semantic_search": True, }, context="context", filters={ "availability": ["InStock"], "brand_ids": ["string"], + "category_ids": ["string"], + "condition": "new", "exclude_product_ids": ["string"], "gender": "male", "price": { "max_price": 0, "min_price": 0, }, + "website_ids": ["string"], }, image_url="image_url", limit=0, diff --git a/tests/test_client.py b/tests/test_client.py index 3227185..400fdc1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -59,51 +59,49 @@ def _get_open_connections(client: Channel3 | AsyncChannel3) -> int: class TestChannel3: - client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - def test_raw_response(self, respx_mock: MockRouter) -> None: + def test_raw_response(self, respx_mock: MockRouter, client: Channel3) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + def test_raw_response_for_binary(self, respx_mock: MockRouter, client: Channel3) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, client: Channel3) -> None: + copied = client.copy() + assert id(copied) != id(client) - copied = self.client.copy(api_key="another My API Key") + copied = client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, client: Channel3) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(client.timeout, httpx.Timeout) + copied = client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(client.timeout, httpx.Timeout) def test_copy_default_headers(self) -> None: client = Channel3( @@ -138,6 +136,7 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + client.close() def test_copy_default_query(self) -> None: client = Channel3( @@ -175,13 +174,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + client.close() + + def test_copy_signature(self, client: Channel3) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -192,12 +193,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, client: Channel3) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -254,14 +255,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + def test_request_timeout(self, client: Channel3) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( - FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(100.0) @@ -274,6 +273,8 @@ def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + client.close() + def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used with httpx.Client(timeout=None) as http_client: @@ -285,6 +286,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + client.close() + # no timeout given to the httpx client should not use the httpx default with httpx.Client() as http_client: client = Channel3( @@ -295,6 +298,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + client.close() + # explicitly passing the default timeout currently results in it being ignored with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = Channel3( @@ -305,6 +310,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + client.close() + async def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): async with httpx.AsyncClient() as http_client: @@ -316,14 +323,14 @@ async def test_invalid_http_client(self) -> None: ) def test_default_headers_option(self) -> None: - client = Channel3( + test_client = Channel3( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = Channel3( + test_client2 = Channel3( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -332,10 +339,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + test_client.close() + test_client2.close() + def test_validate_headers(self) -> None: client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -364,8 +374,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + client.close() + + def test_request_extra_json(self, client: Channel3) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -376,7 +388,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -387,7 +399,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -398,8 +410,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: Channel3) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -409,7 +421,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -420,8 +432,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: Channel3) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -434,7 +446,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -448,7 +460,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -491,7 +503,7 @@ def test_multipart_repeating_array(self, client: Channel3) -> None: ] @pytest.mark.respx(base_url=base_url) - def test_basic_union_response(self, respx_mock: MockRouter) -> None: + def test_basic_union_response(self, respx_mock: MockRouter, client: Channel3) -> None: class Model1(BaseModel): name: str @@ -500,12 +512,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + def test_union_response_different_types(self, respx_mock: MockRouter, client: Channel3) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -516,18 +528,18 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: Channel3) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -543,7 +555,7 @@ class Model(BaseModel): ) ) - response = self.client.get("/foo", cast_to=Model) + response = client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 @@ -555,6 +567,8 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" + client.close() + def test_base_url_env(self) -> None: with update_env(CHANNEL3_BASE_URL="http://localhost:5000/from/env"): client = Channel3(api_key=api_key, _strict_response_validation=True) @@ -582,6 +596,7 @@ def test_base_url_trailing_slash(self, client: Channel3) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -605,6 +620,7 @@ def test_base_url_no_trailing_slash(self, client: Channel3) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -628,35 +644,36 @@ def test_absolute_request_url(self, client: Channel3) -> None: ), ) assert request.url == "https://myapi.com/foo" + client.close() def test_copied_client_does_not_close_http(self) -> None: - client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied - assert not client.is_closed() + assert not test_client.is_closed() def test_client_context_manager(self) -> None: - client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) - with client as c2: - assert c2 is client + test_client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) + with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + def test_client_response_validation_error(self, respx_mock: MockRouter, client: Channel3) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - self.client.get("/foo", cast_to=Model) + client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -676,11 +693,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): strict_client.get("/foo", cast_to=Model) - client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = client.get("/foo", cast_to=Model) + response = non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + strict_client.close() + non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -703,9 +723,9 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, client: Channel3 + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) calculated = client._calculate_retry_timeout(remaining_retries, options, headers) @@ -719,7 +739,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, clien with pytest.raises(APITimeoutError): client.search.with_streaming_response.perform().__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -728,7 +748,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client with pytest.raises(APIStatusError): client.search.with_streaming_response.perform().__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -830,83 +850,77 @@ def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - def test_follow_redirects(self, respx_mock: MockRouter) -> None: + def test_follow_redirects(self, respx_mock: MockRouter, client: Channel3) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: Channel3) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - self.client.post( - "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response - ) + client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response) assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" class TestAsyncChannel3: - client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response(self, respx_mock: MockRouter) -> None: + async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, async_client: AsyncChannel3) -> None: + copied = async_client.copy() + assert id(copied) != id(async_client) - copied = self.client.copy(api_key="another My API Key") + copied = async_client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert async_client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, async_client: AsyncChannel3) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = async_client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert async_client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(async_client.timeout, httpx.Timeout) + copied = async_client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(async_client.timeout, httpx.Timeout) - def test_copy_default_headers(self) -> None: + async def test_copy_default_headers(self) -> None: client = AsyncChannel3( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) @@ -939,8 +953,9 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + await client.close() - def test_copy_default_query(self) -> None: + async def test_copy_default_query(self) -> None: client = AsyncChannel3( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} ) @@ -976,13 +991,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + await client.close() + + def test_copy_signature(self, async_client: AsyncChannel3) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + async_client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(async_client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -993,12 +1010,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, async_client: AsyncChannel3) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = async_client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -1055,12 +1072,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - async def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + async def test_request_timeout(self, async_client: AsyncChannel3) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( + request = async_client._build_request( FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) ) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore @@ -1075,6 +1092,8 @@ async def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + await client.close() + async def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used async with httpx.AsyncClient(timeout=None) as http_client: @@ -1086,6 +1105,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + await client.close() + # no timeout given to the httpx client should not use the httpx default async with httpx.AsyncClient() as http_client: client = AsyncChannel3( @@ -1096,6 +1117,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + await client.close() + # explicitly passing the default timeout currently results in it being ignored async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = AsyncChannel3( @@ -1106,6 +1129,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + await client.close() + def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): with httpx.Client() as http_client: @@ -1116,15 +1141,15 @@ def test_invalid_http_client(self) -> None: http_client=cast(Any, http_client), ) - def test_default_headers_option(self) -> None: - client = AsyncChannel3( + async def test_default_headers_option(self) -> None: + test_client = AsyncChannel3( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = AsyncChannel3( + test_client2 = AsyncChannel3( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -1133,10 +1158,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + await test_client.close() + await test_client2.close() + def test_validate_headers(self) -> None: client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -1147,7 +1175,7 @@ def test_validate_headers(self) -> None: client2 = AsyncChannel3(base_url=base_url, api_key=None, _strict_response_validation=True) _ = client2 - def test_default_query_option(self) -> None: + async def test_default_query_option(self) -> None: client = AsyncChannel3( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} ) @@ -1165,8 +1193,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + await client.close() + + def test_request_extra_json(self, client: Channel3) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1177,7 +1207,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1188,7 +1218,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1199,8 +1229,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: Channel3) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1210,7 +1240,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1221,8 +1251,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: Channel3) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1235,7 +1265,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1249,7 +1279,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1292,7 +1322,7 @@ def test_multipart_repeating_array(self, async_client: AsyncChannel3) -> None: ] @pytest.mark.respx(base_url=base_url) - async def test_basic_union_response(self, respx_mock: MockRouter) -> None: + async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None: class Model1(BaseModel): name: str @@ -1301,12 +1331,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - async def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -1317,18 +1347,20 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + async def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, async_client: AsyncChannel3 + ) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -1344,11 +1376,11 @@ class Model(BaseModel): ) ) - response = await self.client.get("/foo", cast_to=Model) + response = await async_client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 - def test_base_url_setter(self) -> None: + async def test_base_url_setter(self) -> None: client = AsyncChannel3( base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True ) @@ -1358,7 +1390,9 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" - def test_base_url_env(self) -> None: + await client.close() + + async def test_base_url_env(self) -> None: with update_env(CHANNEL3_BASE_URL="http://localhost:5000/from/env"): client = AsyncChannel3(api_key=api_key, _strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @@ -1378,7 +1412,7 @@ def test_base_url_env(self) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_trailing_slash(self, client: AsyncChannel3) -> None: + async def test_base_url_trailing_slash(self, client: AsyncChannel3) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1387,6 +1421,7 @@ def test_base_url_trailing_slash(self, client: AsyncChannel3) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1403,7 +1438,7 @@ def test_base_url_trailing_slash(self, client: AsyncChannel3) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_no_trailing_slash(self, client: AsyncChannel3) -> None: + async def test_base_url_no_trailing_slash(self, client: AsyncChannel3) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1412,6 +1447,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncChannel3) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1428,7 +1464,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncChannel3) -> None: ], ids=["standard", "custom http client"], ) - def test_absolute_request_url(self, client: AsyncChannel3) -> None: + async def test_absolute_request_url(self, client: AsyncChannel3) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1437,37 +1473,37 @@ def test_absolute_request_url(self, client: AsyncChannel3) -> None: ), ) assert request.url == "https://myapi.com/foo" + await client.close() async def test_copied_client_does_not_close_http(self) -> None: - client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied await asyncio.sleep(0.2) - assert not client.is_closed() + assert not test_client.is_closed() async def test_client_context_manager(self) -> None: - client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) - async with client as c2: - assert c2 is client + test_client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) + async with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - await self.client.get("/foo", cast_to=Model) + await async_client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -1478,7 +1514,6 @@ async def test_client_max_retries_validation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str @@ -1490,11 +1525,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): await strict_client.get("/foo", cast_to=Model) - client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = await client.get("/foo", cast_to=Model) + response = await non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + await strict_client.close() + await non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -1517,13 +1555,12 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - @pytest.mark.asyncio - async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + async def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncChannel3 + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) - calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] @mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -1536,7 +1573,7 @@ async def test_retrying_timeout_errors_doesnt_leak( with pytest.raises(APITimeoutError): await async_client.search.with_streaming_response.perform().__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -1547,12 +1584,11 @@ async def test_retrying_status_errors_doesnt_leak( with pytest.raises(APIStatusError): await async_client.search.with_streaming_response.perform().__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( self, @@ -1584,7 +1620,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_omit_retry_count_header( self, async_client: AsyncChannel3, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1608,7 +1643,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_overwrite_retry_count_header( self, async_client: AsyncChannel3, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1656,26 +1690,26 @@ async def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - await self.client.post( + await async_client.post( "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response )