2
2
import os
3
3
4
4
from .patches .patched_dataset import PatchedDataset as Dataset
5
+ from .types import ModelSummary , SearchResult , PredictionResult
6
+ from typing import Optional , Union , List , Dict , Any
5
7
from jaqpot_api_client import Model
6
8
from jaqpot_api_client import (
7
9
ModelApi ,
@@ -43,7 +45,7 @@ class JaqpotApiClient:
43
45
The logger object for logging messages.
44
46
"""
45
47
46
- def __init__ (self , base_url = None , api_url = None , create_logs = False , api_key = None , api_secret = None ):
48
+ def __init__ (self , base_url : Optional [ str ] = None , api_url : Optional [ str ] = None , create_logs : bool = False , api_key : Optional [ str ] = None , api_secret : Optional [ str ] = None ) -> None :
47
49
"""Initialize the JaqpotApiClient.
48
50
49
51
Parameters
@@ -72,7 +74,7 @@ def __init__(self, base_url=None, api_url=None, create_logs=False, api_key=None,
72
74
.build ()
73
75
)
74
76
75
- def get_model_by_id (self , model_id ) -> Model :
77
+ def get_model_by_id (self , model_id : int ) -> Model :
76
78
"""Get a model from Jaqpot by its ID.
77
79
78
80
Parameters
@@ -93,13 +95,13 @@ def get_model_by_id(self, model_id) -> Model:
93
95
model_api = ModelApi (self .http_client )
94
96
response = model_api .get_model_by_id_with_http_info (id = model_id )
95
97
if response .status_code < 300 :
96
- return response .data . to_dict ()
98
+ return response .data
97
99
raise JaqpotApiException (
98
- message = response .data .to_dict ().message ,
99
- status_code = response .status_code . value ,
100
+ message = response .data .to_dict ().get ( ' message' , 'Unknown error' ) ,
101
+ status_code = response .status_code ,
100
102
)
101
103
102
- def get_model_summary (self , model_id ) :
104
+ def get_model_summary (self , model_id : int ) -> ModelSummary :
103
105
"""Get a summary of a model from Jaqpot by its ID.
104
106
105
107
Parameters
@@ -109,8 +111,8 @@ def get_model_summary(self, model_id):
109
111
110
112
Returns
111
113
-------
112
- dict
113
- A dictionary containing the model summary.
114
+ ModelSummary
115
+ A typed dictionary containing the model summary with standardized fields .
114
116
115
117
Raises
116
118
------
@@ -128,7 +130,7 @@ def get_model_summary(self, model_id):
128
130
}
129
131
return model_summary
130
132
131
- def get_shared_models (self , page = None , size = None , sort = None , organization_id = None ):
133
+ def get_shared_models (self , page : Optional [ int ] = None , size : Optional [ int ] = None , sort : Optional [ str ] = None , organization_id : Optional [ int ] = None ) -> Any :
132
134
"""Get shared models from Jaqpot.
133
135
134
136
Parameters
@@ -157,13 +159,13 @@ def get_shared_models(self, page=None, size=None, sort=None, organization_id=Non
157
159
page = page , size = size , sort = sort , organization_id = organization_id
158
160
)
159
161
if response .status_code < 300 :
160
- return response
162
+ return response . data
161
163
raise JaqpotApiException (
162
- message = response .data .to_dict ().message ,
163
- status_code = response .status_code . value ,
164
+ message = response .data .to_dict ().get ( ' message' , 'Unknown error' ) ,
165
+ status_code = response .status_code ,
164
166
)
165
167
166
- def search_models (self , query , page = None , size = None ):
168
+ def search_models (self , query : str , page : Optional [ int ] = None , size : Optional [ int ] = None ) -> SearchResult :
167
169
"""Search for models on Jaqpot based on keywords.
168
170
169
171
Parameters
@@ -201,7 +203,7 @@ def search_models(self, query, page=None, size=None):
201
203
status_code = response .status_code ,
202
204
)
203
205
204
- def get_dataset_by_id (self , dataset_id ) -> Dataset :
206
+ def get_dataset_by_id (self , dataset_id : int ) -> Dataset :
205
207
"""Get a dataset from Jaqpot by its ID.
206
208
207
209
Parameters
@@ -224,11 +226,11 @@ def get_dataset_by_id(self, dataset_id) -> Dataset:
224
226
if response .status_code < 300 :
225
227
return response .data
226
228
raise JaqpotApiException (
227
- message = response .data .to_dict ().message ,
228
- status_code = response .status_code . value ,
229
+ message = response .data .to_dict ().get ( ' message' , 'Unknown error' ) ,
230
+ status_code = response .status_code ,
229
231
)
230
232
231
- def predict_sync (self , model_id , dataset ) :
233
+ def predict_sync (self , model_id : int , dataset : Union [ List [ Dict [ str , Any ]], Dict [ str , Any ]]) -> Any :
232
234
"""Make a synchronous prediction with a model on Jaqpot.
233
235
234
236
Parameters
@@ -267,11 +269,11 @@ def predict_sync(self, model_id, dataset):
267
269
elif dataset .status == "FAILURE" :
268
270
raise JaqpotPredictionFailureException (dataset .failure_reason )
269
271
raise JaqpotApiException (
270
- message = response .data .to_dict ().message ,
271
- status_code = response .status_code . value ,
272
+ message = response .data .to_dict ().get ( ' message' , 'Unknown error' ) ,
273
+ status_code = response .status_code ,
272
274
)
273
275
274
- def predict_async (self , model_id , dataset ) :
276
+ def predict_async (self , model_id : int , dataset : Union [ List [ Dict [ str , Any ]], Dict [ str , Any ]]) -> int :
275
277
"""Make an asynchronous prediction with a model on Jaqpot.
276
278
277
279
Parameters
@@ -306,11 +308,11 @@ def predict_async(self, model_id, dataset):
306
308
dataset_id = int (dataset_location .split ("/" )[- 1 ])
307
309
return dataset_id
308
310
raise JaqpotApiException (
309
- message = response .data .to_dict ().message ,
310
- status_code = response .status_code . value ,
311
+ message = response .data .to_dict ().get ( ' message' , 'Unknown error' ) ,
312
+ status_code = response .status_code ,
311
313
)
312
314
313
- def predict_with_csv_sync (self , model_id , csv_path ) :
315
+ def predict_with_csv_sync (self , model_id : int , csv_path : str ) -> Any :
314
316
"""Make a synchronous prediction with a model on Jaqpot using a CSV file.
315
317
316
318
Parameters
@@ -347,8 +349,8 @@ def predict_with_csv_sync(self, model_id, csv_path):
347
349
elif dataset .status == "FAILURE" :
348
350
raise JaqpotPredictionFailureException (message = dataset .failure_reason )
349
351
raise JaqpotApiException (
350
- message = response .data .to_dict ().message ,
351
- status_code = response .status_code . value ,
352
+ message = response .data .to_dict ().get ( ' message' , 'Unknown error' ) ,
353
+ status_code = response .status_code ,
352
354
)
353
355
354
356
def _get_dataset_with_polling (self , response ):
@@ -385,7 +387,7 @@ def _get_dataset_with_polling(self, response):
385
387
dataset = self .get_dataset_by_id (dataset_id )
386
388
return dataset
387
389
388
- def qsartoolbox_calculator_predict_sync (self , smiles , calculator_guid ) :
390
+ def qsartoolbox_calculator_predict_sync (self , smiles : str , calculator_guid : str ) -> Any :
389
391
"""Synchronously predict using the QSAR Toolbox calculator.
390
392
391
393
Parameters
@@ -404,7 +406,7 @@ def qsartoolbox_calculator_predict_sync(self, smiles, calculator_guid):
404
406
prediction = self .predict_sync (QSARTOOLBOX_CALCULATOR_MODEL_ID , dataset )
405
407
return prediction
406
408
407
- def qsartoolbox_qsar_model_predict_sync (self , smiles , qsar_guid ) :
409
+ def qsartoolbox_qsar_model_predict_sync (self , smiles : str , qsar_guid : str ) -> Any :
408
410
"""Synchronously predict QSAR model results using the QSAR Toolbox.
409
411
410
412
Parameters
@@ -423,7 +425,7 @@ def qsartoolbox_qsar_model_predict_sync(self, smiles, qsar_guid):
423
425
prediction = self .predict_sync (QSARTOOLBOX_MODEL_MODEL_ID , dataset )
424
426
return prediction
425
427
426
- def qsartoolbox_profiler_predict_sync (self , smiles , profiler_guid ) :
428
+ def qsartoolbox_profiler_predict_sync (self , smiles : str , profiler_guid : str ) -> Any :
427
429
"""Synchronously predict using the QSAR Toolbox profiler.
428
430
429
431
Parameters
0 commit comments