Skip to content

Commit 8a2cc9f

Browse files
feat(api): update pagination configs (#25)
1 parent d31ada3 commit 8a2cc9f

File tree

8 files changed

+90
-58
lines changed

8 files changed

+90
-58
lines changed

.stats.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
configured_endpoints: 27
22
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-b45f922f6a041550870a96f5acec02aa6d8830046fc98b95a275c6486f7586fc.yml
33
openapi_spec_hash: ef7fddfb49b4d9c440b0635d2c86f341
4-
config_hash: a97b3049608e3cfca813a523902d499b
4+
config_hash: 93f687135e6d45a0f0f83fbfdcb1d8c9

api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Methods:
2222
- <code title="post /deployments">client.deployments.<a href="./src/replicate/resources/deployments/deployments.py">create</a>(\*\*<a href="src/replicate/types/deployment_create_params.py">params</a>) -> <a href="./src/replicate/types/deployment_create_response.py">DeploymentCreateResponse</a></code>
2323
- <code title="get /deployments/{deployment_owner}/{deployment_name}">client.deployments.<a href="./src/replicate/resources/deployments/deployments.py">retrieve</a>(deployment_name, \*, deployment_owner) -> <a href="./src/replicate/types/deployment_retrieve_response.py">DeploymentRetrieveResponse</a></code>
2424
- <code title="patch /deployments/{deployment_owner}/{deployment_name}">client.deployments.<a href="./src/replicate/resources/deployments/deployments.py">update</a>(deployment_name, \*, deployment_owner, \*\*<a href="src/replicate/types/deployment_update_params.py">params</a>) -> <a href="./src/replicate/types/deployment_update_response.py">DeploymentUpdateResponse</a></code>
25-
- <code title="get /deployments">client.deployments.<a href="./src/replicate/resources/deployments/deployments.py">list</a>() -> <a href="./src/replicate/types/deployment_list_response.py">DeploymentListResponse</a></code>
25+
- <code title="get /deployments">client.deployments.<a href="./src/replicate/resources/deployments/deployments.py">list</a>() -> <a href="./src/replicate/types/deployment_list_response.py">SyncCursorURLPage[DeploymentListResponse]</a></code>
2626
- <code title="delete /deployments/{deployment_owner}/{deployment_name}">client.deployments.<a href="./src/replicate/resources/deployments/deployments.py">delete</a>(deployment_name, \*, deployment_owner) -> None</code>
2727
- <code title="get /collections">client.deployments.<a href="./src/replicate/resources/deployments/deployments.py">list_em_all</a>() -> None</code>
2828

@@ -88,7 +88,7 @@ Methods:
8888

8989
- <code title="post /predictions">client.predictions.<a href="./src/replicate/resources/predictions.py">create</a>(\*\*<a href="src/replicate/types/prediction_create_params.py">params</a>) -> <a href="./src/replicate/types/prediction.py">Prediction</a></code>
9090
- <code title="get /predictions/{prediction_id}">client.predictions.<a href="./src/replicate/resources/predictions.py">retrieve</a>(prediction_id) -> <a href="./src/replicate/types/prediction.py">Prediction</a></code>
91-
- <code title="get /predictions">client.predictions.<a href="./src/replicate/resources/predictions.py">list</a>(\*\*<a href="src/replicate/types/prediction_list_params.py">params</a>) -> <a href="./src/replicate/types/prediction.py">SyncCursorURLPage[Prediction]</a></code>
91+
- <code title="get /predictions">client.predictions.<a href="./src/replicate/resources/predictions.py">list</a>(\*\*<a href="src/replicate/types/prediction_list_params.py">params</a>) -> <a href="./src/replicate/types/prediction.py">SyncCursorURLPageWithCreatedFilters[Prediction]</a></code>
9292
- <code title="post /predictions/{prediction_id}/cancel">client.predictions.<a href="./src/replicate/resources/predictions.py">cancel</a>(prediction_id) -> None</code>
9393

9494
# Trainings

src/replicate/pagination.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,56 @@
77

88
from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage
99

10-
__all__ = ["SyncCursorURLPage", "AsyncCursorURLPage"]
10+
__all__ = [
11+
"SyncCursorURLPageWithCreatedFilters",
12+
"AsyncCursorURLPageWithCreatedFilters",
13+
"SyncCursorURLPage",
14+
"AsyncCursorURLPage",
15+
]
1116

1217
_T = TypeVar("_T")
1318

1419

20+
class SyncCursorURLPageWithCreatedFilters(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
21+
results: List[_T]
22+
next: Optional[str] = None
23+
24+
@override
25+
def _get_page_items(self) -> List[_T]:
26+
results = self.results
27+
if not results:
28+
return []
29+
return results
30+
31+
@override
32+
def next_page_info(self) -> Optional[PageInfo]:
33+
url = self.next
34+
if url is None:
35+
return None
36+
37+
return PageInfo(url=httpx.URL(url))
38+
39+
40+
class AsyncCursorURLPageWithCreatedFilters(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
41+
results: List[_T]
42+
next: Optional[str] = None
43+
44+
@override
45+
def _get_page_items(self) -> List[_T]:
46+
results = self.results
47+
if not results:
48+
return []
49+
return results
50+
51+
@override
52+
def next_page_info(self) -> Optional[PageInfo]:
53+
url = self.next
54+
if url is None:
55+
return None
56+
57+
return PageInfo(url=httpx.URL(url))
58+
59+
1560
class SyncCursorURLPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
1661
results: List[_T]
1762
next: Optional[str] = None

src/replicate/resources/deployments/deployments.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
PredictionsResourceWithStreamingResponse,
2727
AsyncPredictionsResourceWithStreamingResponse,
2828
)
29-
from ..._base_client import make_request_options
29+
from ...pagination import SyncCursorURLPage, AsyncCursorURLPage
30+
from ..._base_client import AsyncPaginator, make_request_options
3031
from ...types.deployment_list_response import DeploymentListResponse
3132
from ...types.deployment_create_response import DeploymentCreateResponse
3233
from ...types.deployment_update_response import DeploymentUpdateResponse
@@ -345,7 +346,7 @@ def list(
345346
extra_query: Query | None = None,
346347
extra_body: Body | None = None,
347348
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
348-
) -> DeploymentListResponse:
349+
) -> SyncCursorURLPage[DeploymentListResponse]:
349350
"""
350351
Get a list of deployments associated with the current account, including the
351352
latest release configuration for each deployment.
@@ -392,12 +393,13 @@ def list(
392393
}
393394
```
394395
"""
395-
return self._get(
396+
return self._get_api_list(
396397
"/deployments",
398+
page=SyncCursorURLPage[DeploymentListResponse],
397399
options=make_request_options(
398400
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
399401
),
400-
cast_to=DeploymentListResponse,
402+
model=DeploymentListResponse,
401403
)
402404

403405
def delete(
@@ -798,7 +800,7 @@ async def update(
798800
cast_to=DeploymentUpdateResponse,
799801
)
800802

801-
async def list(
803+
def list(
802804
self,
803805
*,
804806
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -807,7 +809,7 @@ async def list(
807809
extra_query: Query | None = None,
808810
extra_body: Body | None = None,
809811
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
810-
) -> DeploymentListResponse:
812+
) -> AsyncPaginator[DeploymentListResponse, AsyncCursorURLPage[DeploymentListResponse]]:
811813
"""
812814
Get a list of deployments associated with the current account, including the
813815
latest release configuration for each deployment.
@@ -854,12 +856,13 @@ async def list(
854856
}
855857
```
856858
"""
857-
return await self._get(
859+
return self._get_api_list(
858860
"/deployments",
861+
page=AsyncCursorURLPage[DeploymentListResponse],
859862
options=make_request_options(
860863
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
861864
),
862-
cast_to=DeploymentListResponse,
865+
model=DeploymentListResponse,
863866
)
864867

865868
async def delete(

src/replicate/resources/predictions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
async_to_raw_response_wrapper,
2424
async_to_streamed_response_wrapper,
2525
)
26-
from ..pagination import SyncCursorURLPage, AsyncCursorURLPage
26+
from ..pagination import SyncCursorURLPageWithCreatedFilters, AsyncCursorURLPageWithCreatedFilters
2727
from .._base_client import AsyncPaginator, make_request_options
2828
from ..types.prediction import Prediction
2929

@@ -302,7 +302,7 @@ def list(
302302
extra_query: Query | None = None,
303303
extra_body: Body | None = None,
304304
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
305-
) -> SyncCursorURLPage[Prediction]:
305+
) -> SyncCursorURLPageWithCreatedFilters[Prediction]:
306306
"""
307307
Get a paginated list of all predictions created by the user or organization
308308
associated with the provided API token.
@@ -389,7 +389,7 @@ def list(
389389
"""
390390
return self._get_api_list(
391391
"/predictions",
392-
page=SyncCursorURLPage[Prediction],
392+
page=SyncCursorURLPageWithCreatedFilters[Prediction],
393393
options=make_request_options(
394394
extra_headers=extra_headers,
395395
extra_query=extra_query,
@@ -713,7 +713,7 @@ def list(
713713
extra_query: Query | None = None,
714714
extra_body: Body | None = None,
715715
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
716-
) -> AsyncPaginator[Prediction, AsyncCursorURLPage[Prediction]]:
716+
) -> AsyncPaginator[Prediction, AsyncCursorURLPageWithCreatedFilters[Prediction]]:
717717
"""
718718
Get a paginated list of all predictions created by the user or organization
719719
associated with the provided API token.
@@ -800,7 +800,7 @@ def list(
800800
"""
801801
return self._get_api_list(
802802
"/predictions",
803-
page=AsyncCursorURLPage[Prediction],
803+
page=AsyncCursorURLPageWithCreatedFilters[Prediction],
804804
options=make_request_options(
805805
extra_headers=extra_headers,
806806
extra_query=extra_query,
Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
22

3-
from typing import List, Optional
3+
from typing import Optional
44
from datetime import datetime
55
from typing_extensions import Literal
66

77
from .._models import BaseModel
88

9-
__all__ = [
10-
"DeploymentListResponse",
11-
"Result",
12-
"ResultCurrentRelease",
13-
"ResultCurrentReleaseConfiguration",
14-
"ResultCurrentReleaseCreatedBy",
15-
]
9+
__all__ = ["DeploymentListResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"]
1610

1711

18-
class ResultCurrentReleaseConfiguration(BaseModel):
12+
class CurrentReleaseConfiguration(BaseModel):
1913
hardware: Optional[str] = None
2014
"""The SKU for the hardware used to run the model."""
2115

@@ -26,7 +20,7 @@ class ResultCurrentReleaseConfiguration(BaseModel):
2620
"""The minimum number of instances for scaling."""
2721

2822

29-
class ResultCurrentReleaseCreatedBy(BaseModel):
23+
class CurrentReleaseCreatedBy(BaseModel):
3024
type: Literal["organization", "user"]
3125
"""The account type of the creator. Can be a user or an organization."""
3226

@@ -43,13 +37,13 @@ class ResultCurrentReleaseCreatedBy(BaseModel):
4337
"""The name of the account that created the release."""
4438

4539

46-
class ResultCurrentRelease(BaseModel):
47-
configuration: Optional[ResultCurrentReleaseConfiguration] = None
40+
class CurrentRelease(BaseModel):
41+
configuration: Optional[CurrentReleaseConfiguration] = None
4842

4943
created_at: Optional[datetime] = None
5044
"""The time the release was created."""
5145

52-
created_by: Optional[ResultCurrentReleaseCreatedBy] = None
46+
created_by: Optional[CurrentReleaseCreatedBy] = None
5347

5448
model: Optional[str] = None
5549
"""The model identifier string in the format of `{model_owner}/{model_name}`."""
@@ -65,22 +59,11 @@ class ResultCurrentRelease(BaseModel):
6559
"""The ID of the model version used in the release."""
6660

6761

68-
class Result(BaseModel):
69-
current_release: Optional[ResultCurrentRelease] = None
62+
class DeploymentListResponse(BaseModel):
63+
current_release: Optional[CurrentRelease] = None
7064

7165
name: Optional[str] = None
7266
"""The name of the deployment."""
7367

7468
owner: Optional[str] = None
7569
"""The owner of the deployment."""
76-
77-
78-
class DeploymentListResponse(BaseModel):
79-
next: Optional[str] = None
80-
"""A URL pointing to the next page of deployment objects if any"""
81-
82-
previous: Optional[str] = None
83-
"""A URL pointing to the previous page of deployment objects if any"""
84-
85-
results: Optional[List[Result]] = None
86-
"""An array containing a page of deployment objects"""

tests/api_resources/test_deployments.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
DeploymentUpdateResponse,
1616
DeploymentRetrieveResponse,
1717
)
18+
from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
1819

1920
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
2021

@@ -192,7 +193,7 @@ def test_path_params_update(self, client: ReplicateClient) -> None:
192193
@parametrize
193194
def test_method_list(self, client: ReplicateClient) -> None:
194195
deployment = client.deployments.list()
195-
assert_matches_type(DeploymentListResponse, deployment, path=["response"])
196+
assert_matches_type(SyncCursorURLPage[DeploymentListResponse], deployment, path=["response"])
196197

197198
@pytest.mark.skip()
198199
@parametrize
@@ -202,7 +203,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None:
202203
assert response.is_closed is True
203204
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
204205
deployment = response.parse()
205-
assert_matches_type(DeploymentListResponse, deployment, path=["response"])
206+
assert_matches_type(SyncCursorURLPage[DeploymentListResponse], deployment, path=["response"])
206207

207208
@pytest.mark.skip()
208209
@parametrize
@@ -212,7 +213,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None:
212213
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
213214

214215
deployment = response.parse()
215-
assert_matches_type(DeploymentListResponse, deployment, path=["response"])
216+
assert_matches_type(SyncCursorURLPage[DeploymentListResponse], deployment, path=["response"])
216217

217218
assert cast(Any, response.is_closed) is True
218219

@@ -470,7 +471,7 @@ async def test_path_params_update(self, async_client: AsyncReplicateClient) -> N
470471
@parametrize
471472
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
472473
deployment = await async_client.deployments.list()
473-
assert_matches_type(DeploymentListResponse, deployment, path=["response"])
474+
assert_matches_type(AsyncCursorURLPage[DeploymentListResponse], deployment, path=["response"])
474475

475476
@pytest.mark.skip()
476477
@parametrize
@@ -480,7 +481,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No
480481
assert response.is_closed is True
481482
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
482483
deployment = await response.parse()
483-
assert_matches_type(DeploymentListResponse, deployment, path=["response"])
484+
assert_matches_type(AsyncCursorURLPage[DeploymentListResponse], deployment, path=["response"])
484485

485486
@pytest.mark.skip()
486487
@parametrize
@@ -490,7 +491,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient)
490491
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
491492

492493
deployment = await response.parse()
493-
assert_matches_type(DeploymentListResponse, deployment, path=["response"])
494+
assert_matches_type(AsyncCursorURLPage[DeploymentListResponse], deployment, path=["response"])
494495

495496
assert cast(Any, response.is_closed) is True
496497

0 commit comments

Comments
 (0)