Skip to content

Commit a38cd7a

Browse files
committed
fix: fix mistakes in parsing responses + add types
1 parent 9d5a58f commit a38cd7a

File tree

4 files changed

+87
-35
lines changed

4 files changed

+87
-35
lines changed

src/jaqpot_python_sdk/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
# SPDX-FileCopyrightText: 2025-present Alex Arvanitidis <[email protected]>
22
#
33
# SPDX-License-Identifier: MIT
4+
5+
from .types import ModelSummary, SearchResult, PredictionResult
6+
from .jaqpot_api_client import JaqpotApiClient
7+
8+
__all__ = ["JaqpotApiClient", "ModelSummary", "SearchResult", "PredictionResult"]

src/jaqpot_python_sdk/jaqpot_api_client.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import os
33

44
from .patches.patched_dataset import PatchedDataset as Dataset
5+
from .types import ModelSummary, SearchResult, PredictionResult
6+
from typing import Optional, Union, List, Dict, Any
57
from jaqpot_api_client import Model
68
from jaqpot_api_client import (
79
ModelApi,
@@ -43,7 +45,7 @@ class JaqpotApiClient:
4345
The logger object for logging messages.
4446
"""
4547

46-
def __init__(self, base_url=None, api_url=None, create_logs=False, api_key=None, api_secret=None):
48+
def __init__(self, base_url: Optional[str] = None, api_url: Optional[str] = None, create_logs: bool = False, api_key: Optional[str] = None, api_secret: Optional[str] = None) -> None:
4749
"""Initialize the JaqpotApiClient.
4850
4951
Parameters
@@ -72,7 +74,7 @@ def __init__(self, base_url=None, api_url=None, create_logs=False, api_key=None,
7274
.build()
7375
)
7476

75-
def get_model_by_id(self, model_id) -> Model:
77+
def get_model_by_id(self, model_id: int) -> Model:
7678
"""Get a model from Jaqpot by its ID.
7779
7880
Parameters
@@ -93,13 +95,13 @@ def get_model_by_id(self, model_id) -> Model:
9395
model_api = ModelApi(self.http_client)
9496
response = model_api.get_model_by_id_with_http_info(id=model_id)
9597
if response.status_code < 300:
96-
return response.data.to_dict()
98+
return response.data
9799
raise JaqpotApiException(
98-
message=response.data.to_dict().message,
99-
status_code=response.status_code.value,
100+
message=response.data.to_dict().get('message', 'Unknown error'),
101+
status_code=response.status_code,
100102
)
101103

102-
def get_model_summary(self, model_id):
104+
def get_model_summary(self, model_id: int) -> ModelSummary:
103105
"""Get a summary of a model from Jaqpot by its ID.
104106
105107
Parameters
@@ -109,8 +111,8 @@ def get_model_summary(self, model_id):
109111
110112
Returns
111113
-------
112-
dict
113-
A dictionary containing the model summary.
114+
ModelSummary
115+
A typed dictionary containing the model summary with standardized fields.
114116
115117
Raises
116118
------
@@ -128,7 +130,7 @@ def get_model_summary(self, model_id):
128130
}
129131
return model_summary
130132

131-
def get_shared_models(self, page=None, size=None, sort=None, organization_id=None):
133+
def get_shared_models(self, page: Optional[int] = None, size: Optional[int] = None, sort: Optional[str] = None, organization_id: Optional[int] = None) -> Any:
132134
"""Get shared models from Jaqpot.
133135
134136
Parameters
@@ -157,13 +159,13 @@ def get_shared_models(self, page=None, size=None, sort=None, organization_id=Non
157159
page=page, size=size, sort=sort, organization_id=organization_id
158160
)
159161
if response.status_code < 300:
160-
return response
162+
return response.data
161163
raise JaqpotApiException(
162-
message=response.data.to_dict().message,
163-
status_code=response.status_code.value,
164+
message=response.data.to_dict().get('message', 'Unknown error'),
165+
status_code=response.status_code,
164166
)
165167

166-
def search_models(self, query, page=None, size=None):
168+
def search_models(self, query: str, page: Optional[int] = None, size: Optional[int] = None) -> SearchResult:
167169
"""Search for models on Jaqpot based on keywords.
168170
169171
Parameters
@@ -201,7 +203,7 @@ def search_models(self, query, page=None, size=None):
201203
status_code=response.status_code,
202204
)
203205

204-
def get_dataset_by_id(self, dataset_id) -> Dataset:
206+
def get_dataset_by_id(self, dataset_id: int) -> Dataset:
205207
"""Get a dataset from Jaqpot by its ID.
206208
207209
Parameters
@@ -224,11 +226,11 @@ def get_dataset_by_id(self, dataset_id) -> Dataset:
224226
if response.status_code < 300:
225227
return response.data
226228
raise JaqpotApiException(
227-
message=response.data.to_dict().message,
228-
status_code=response.status_code.value,
229+
message=response.data.to_dict().get('message', 'Unknown error'),
230+
status_code=response.status_code,
229231
)
230232

231-
def predict_sync(self, model_id, dataset):
233+
def predict_sync(self, model_id: int, dataset: Union[List[Dict[str, Any]], Dict[str, Any]]) -> Any:
232234
"""Make a synchronous prediction with a model on Jaqpot.
233235
234236
Parameters
@@ -267,11 +269,11 @@ def predict_sync(self, model_id, dataset):
267269
elif dataset.status == "FAILURE":
268270
raise JaqpotPredictionFailureException(dataset.failure_reason)
269271
raise JaqpotApiException(
270-
message=response.data.to_dict().message,
271-
status_code=response.status_code.value,
272+
message=response.data.to_dict().get('message', 'Unknown error'),
273+
status_code=response.status_code,
272274
)
273275

274-
def predict_async(self, model_id, dataset):
276+
def predict_async(self, model_id: int, dataset: Union[List[Dict[str, Any]], Dict[str, Any]]) -> int:
275277
"""Make an asynchronous prediction with a model on Jaqpot.
276278
277279
Parameters
@@ -306,11 +308,11 @@ def predict_async(self, model_id, dataset):
306308
dataset_id = int(dataset_location.split("/")[-1])
307309
return dataset_id
308310
raise JaqpotApiException(
309-
message=response.data.to_dict().message,
310-
status_code=response.status_code.value,
311+
message=response.data.to_dict().get('message', 'Unknown error'),
312+
status_code=response.status_code,
311313
)
312314

313-
def predict_with_csv_sync(self, model_id, csv_path):
315+
def predict_with_csv_sync(self, model_id: int, csv_path: str) -> Any:
314316
"""Make a synchronous prediction with a model on Jaqpot using a CSV file.
315317
316318
Parameters
@@ -347,8 +349,8 @@ def predict_with_csv_sync(self, model_id, csv_path):
347349
elif dataset.status == "FAILURE":
348350
raise JaqpotPredictionFailureException(message=dataset.failure_reason)
349351
raise JaqpotApiException(
350-
message=response.data.to_dict().message,
351-
status_code=response.status_code.value,
352+
message=response.data.to_dict().get('message', 'Unknown error'),
353+
status_code=response.status_code,
352354
)
353355

354356
def _get_dataset_with_polling(self, response):
@@ -385,7 +387,7 @@ def _get_dataset_with_polling(self, response):
385387
dataset = self.get_dataset_by_id(dataset_id)
386388
return dataset
387389

388-
def qsartoolbox_calculator_predict_sync(self, smiles, calculator_guid):
390+
def qsartoolbox_calculator_predict_sync(self, smiles: str, calculator_guid: str) -> Any:
389391
"""Synchronously predict using the QSAR Toolbox calculator.
390392
391393
Parameters
@@ -404,7 +406,7 @@ def qsartoolbox_calculator_predict_sync(self, smiles, calculator_guid):
404406
prediction = self.predict_sync(QSARTOOLBOX_CALCULATOR_MODEL_ID, dataset)
405407
return prediction
406408

407-
def qsartoolbox_qsar_model_predict_sync(self, smiles, qsar_guid):
409+
def qsartoolbox_qsar_model_predict_sync(self, smiles: str, qsar_guid: str) -> Any:
408410
"""Synchronously predict QSAR model results using the QSAR Toolbox.
409411
410412
Parameters
@@ -423,7 +425,7 @@ def qsartoolbox_qsar_model_predict_sync(self, smiles, qsar_guid):
423425
prediction = self.predict_sync(QSARTOOLBOX_MODEL_MODEL_ID, dataset)
424426
return prediction
425427

426-
def qsartoolbox_profiler_predict_sync(self, smiles, profiler_guid):
428+
def qsartoolbox_profiler_predict_sync(self, smiles: str, profiler_guid: str) -> Any:
427429
"""Synchronously predict using the QSAR Toolbox profiler.
428430
429431
Parameters

src/jaqpot_python_sdk/types.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Type definitions for the Jaqpot Python SDK."""
2+
3+
from typing import List, Any, Optional, TypedDict, Union, Dict
4+
5+
6+
class ModelSummary(TypedDict, total=False):
7+
"""Summary representation of a Jaqpot model.
8+
9+
Attributes:
10+
name: The name of the model
11+
modelId: The unique identifier of the model
12+
description: A description of what the model does
13+
type: The type/category of the model
14+
independentFeatures: List of input features the model expects
15+
dependentFeatures: List of output features the model produces
16+
"""
17+
name: Optional[str]
18+
modelId: Optional[int]
19+
description: Optional[str]
20+
type: Optional[Any]
21+
independentFeatures: Optional[List[Any]]
22+
dependentFeatures: Optional[List[Any]]
23+
24+
25+
class SearchResult(TypedDict, total=False):
26+
"""Result from model search operations.
27+
28+
Attributes:
29+
content: List of models matching the search query
30+
totalElements: Total number of matching models
31+
totalPages: Total number of pages
32+
pageSize: Number of models per page
33+
pageNumber: Current page number
34+
"""
35+
content: Optional[List[Dict[str, Any]]]
36+
totalElements: Optional[int]
37+
totalPages: Optional[int]
38+
pageSize: Optional[int]
39+
pageNumber: Optional[int]
40+
41+
42+
class PredictionResult(TypedDict, total=False):
43+
"""Result from prediction operations.
44+
45+
Attributes:
46+
predictions: The prediction results
47+
status: Status of the prediction
48+
message: Any message associated with the prediction
49+
"""
50+
predictions: Optional[List[Dict[str, Any]]]
51+
status: Optional[str]
52+
message: Optional[str]

test.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)