Skip to content

Commit b509bc0

Browse files
committed
Merge remote-tracking branch 'origin/master' into PS-12943
# Conflicts: # gradient/api_sdk/serializers/cluster.py
2 parents 683f53e + 2b57657 commit b509bc0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+870
-849
lines changed

gradient/api_sdk/clients/base_client.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,27 @@
1+
import copy
2+
3+
from gradient.api_sdk.repositories.common import BaseRepository
14
from .. import logger as sdk_logger
2-
from ..repositories.clusters import ValidateClusterRepository
3-
from gradient.api_sdk.sdk_exceptions import GradientSdkError
45
from ..repositories.tags import ListTagRepository, UpdateTagRepository
56
from ...exceptions import ReceivingDataFailedError
67

78

89
class BaseClient(object):
910
def __init__(
10-
self, api_key,
11+
self,
12+
api_key,
13+
ps_client_name=None,
1114
logger=sdk_logger.MuteLogger()
1215
):
1316
"""
1417
Base class. All client classes inherit from it.
1518
16-
An API key can be created at paperspace.com after you sign in to your account. After obtaining it, you can set
17-
it in the CLI using the command::
18-
19-
gradient apiKey XXXXXXXXXXXXXXXXXXX
20-
21-
or you can provide your API key in any command, for example::
22-
23-
gradient experiments run ... --apiKey XXXXXXXXXXXXXXXXXXX
24-
2519
:param str api_key: your API key
20+
:param str ps_client_name:
2621
:param sdk_logger.Logger logger:
2722
"""
2823
self.api_key = api_key
24+
self.ps_client_name = ps_client_name
2925
self.logger = logger
3026

3127
KNOWN_TAGS_ENTITIES = [
@@ -73,13 +69,13 @@ def add_tags(self, entity_id, entity, tags):
7369
"""
7470
self._validate_entities(entity)
7571

76-
list_tag_repository = ListTagRepository(api_key=self.api_key, logger=self.logger)
72+
list_tag_repository = self.build_repository(ListTagRepository)
7773
entity_tags = list_tag_repository.list(entity=entity, entity_ids=[entity_id])
7874

7975
if entity_tags:
8076
tags = self.merge_tags(entity_id, entity_tags, tags)
8177

82-
update_tag_repository = UpdateTagRepository(api_key=self.api_key, logger=self.logger)
78+
update_tag_repository = self.build_repository(UpdateTagRepository)
8379
update_tag_repository.update(entity=entity, entity_id=entity_id, tags=tags)
8480

8581
def remove_tags(self, entity_id, entity, tags):
@@ -92,13 +88,13 @@ def remove_tags(self, entity_id, entity, tags):
9288
"""
9389
self._validate_entities(entity)
9490

95-
list_tag_repository = ListTagRepository(api_key=self.api_key, logger=self.logger)
91+
list_tag_repository = self.build_repository(ListTagRepository)
9692
entity_tags = list_tag_repository.list(entity=entity, entity_ids=[entity_id])
9793

9894
if entity_tags:
9995
entity_tags = self.diff_tags(entity_id, entity_tags, tags)
10096

101-
update_tag_repository = UpdateTagRepository(api_key=self.api_key, logger=self.logger)
97+
update_tag_repository = self.build_repository(UpdateTagRepository)
10298
update_tag_repository.update(entity=entity, entity_id=entity_id, tags=entity_tags)
10399

104100
def list_tags(self, entity_ids, entity):
@@ -110,6 +106,19 @@ def list_tags(self, entity_ids, entity):
110106
"""
111107
self._validate_entities(entity)
112108

113-
list_tag_repository = ListTagRepository(api_key=self.api_key, logger=self.logger)
109+
list_tag_repository = self.build_repository(ListTagRepository)
114110
entity_tags = list_tag_repository.list(entity=entity, entity_ids=entity_ids)
115111
return entity_tags
112+
113+
def build_repository(self, repository_class, *args, **kwargs):
114+
"""
115+
:param type[BaseRepository] repository_class:
116+
:rtype: BaseRepository
117+
"""
118+
119+
if self.ps_client_name is not None and kwargs.get("ps_client_name") is None:
120+
kwargs = copy.deepcopy(kwargs)
121+
kwargs["ps_client_name"] = self.ps_client_name
122+
123+
repository = repository_class(*args, api_key=self.api_key, logger=self.logger, **kwargs)
124+
return repository

gradient/api_sdk/clients/deployment_client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def create(
118118
auth_password=auth_password,
119119
)
120120

121-
repository = repositories.CreateDeployment(api_key=self.api_key, logger=self.logger)
121+
repository = self.build_repository(repositories.CreateDeployment)
122122
deployment_id = repository.create(deployment)
123123
if tags:
124124
self.add_tags(entity_id=deployment_id, entity=self.entity, tags=tags)
@@ -132,7 +132,7 @@ def get(self, deployment_id):
132132
:return: Deployment instance
133133
:rtype: models.Deployment
134134
"""
135-
repository = repositories.GetDeployment(self.api_key, logger=self.logger)
135+
repository = self.build_repository(repositories.GetDeployment)
136136
deployment = repository.get(deployment_id=deployment_id)
137137
return deployment
138138

@@ -147,7 +147,7 @@ def start(self, deployment_id):
147147
:param str deployment_id: Deployment ID
148148
"""
149149

150-
repository = repositories.StartDeployment(api_key=self.api_key, logger=self.logger)
150+
repository = self.build_repository(repositories.StartDeployment)
151151
repository.start(deployment_id)
152152

153153
def stop(self, deployment_id):
@@ -161,7 +161,7 @@ def stop(self, deployment_id):
161161
:param deployment_id: Deployment ID
162162
"""
163163

164-
repository = repositories.StopDeployment(api_key=self.api_key, logger=self.logger)
164+
repository = self.build_repository(repositories.StopDeployment)
165165
repository.stop(deployment_id)
166166

167167
def list(self, state=None, project_id=None, model_id=None, tags=None):
@@ -177,12 +177,12 @@ def list(self, state=None, project_id=None, model_id=None, tags=None):
177177
:rtype: list[models.Deployment]
178178
"""
179179

180-
repository = repositories.ListDeployments(api_key=self.api_key, logger=self.logger)
180+
repository = self.build_repository(repositories.ListDeployments)
181181
deployments = repository.list(state=state, project_id=project_id, model_id=model_id, tags=tags)
182182
return deployments
183183

184184
def delete(self, deployment_id):
185-
repository = repositories.DeleteDeployment(api_key=self.api_key, logger=self.logger)
185+
repository = self.build_repository(repositories.DeleteDeployment)
186186
repository.delete(deployment_id)
187187

188188
def update(
@@ -232,5 +232,5 @@ def update(
232232
auth_password=auth_password,
233233
)
234234

235-
repository = repositories.UpdateDeployment(api_key=self.api_key, logger=self.logger)
235+
repository = self.build_repository(repositories.UpdateDeployment)
236236
repository.update(deployment_id, deployment)

gradient/api_sdk/clients/experiment_client.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def create_single_node(
9696
registry_url=registry_url,
9797
)
9898

99-
repository = repositories.CreateSingleNodeExperiment(api_key=self.api_key, logger=self.logger)
99+
repository = self.build_repository(repositories.CreateSingleNodeExperiment)
100100
handle = repository.create(experiment)
101101

102102
if tags:
@@ -224,7 +224,7 @@ def create_multi_node(
224224
parameter_server_registry_url=parameter_server_registry_url,
225225
)
226226

227-
repository = repositories.CreateMultiNodeExperiment(api_key=self.api_key, logger=self.logger)
227+
repository = self.build_repository(repositories.CreateMultiNodeExperiment)
228228
handle = repository.create(experiment)
229229

230230
if tags:
@@ -349,7 +349,7 @@ def create_mpi_multi_node(
349349
master_registry_url=master_registry_url,
350350
)
351351

352-
repository = repositories.CreateMpiMultiNodeExperiment(api_key=self.api_key, logger=self.logger)
352+
repository = self.build_repository(repositories.CreateMpiMultiNodeExperiment)
353353
handle = repository.create(experiment)
354354

355355
if tags:
@@ -445,7 +445,7 @@ def run_single_node(
445445
registry_url=registry_url,
446446
)
447447

448-
repository = repositories.RunSingleNodeExperiment(api_key=self.api_key, logger=self.logger)
448+
repository = self.build_repository(repositories.RunSingleNodeExperiment)
449449
handle = repository.create(experiment)
450450

451451
if tags:
@@ -571,11 +571,12 @@ def run_multi_node(
571571
parameter_server_registry_url=parameter_server_registry_url,
572572
)
573573

574-
repository = repositories.RunMultiNodeExperiment(api_key=self.api_key, logger=self.logger)
574+
repository = self.build_repository(repositories.RunMultiNodeExperiment)
575575
handle = repository.create(experiment)
576576

577577
if tags:
578578
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
579+
579580
return handle
580581

581582
def run_mpi_multi_node(
@@ -695,11 +696,12 @@ def run_mpi_multi_node(
695696
master_registry_url=master_registry_url,
696697
)
697698

698-
repository = repositories.RunMpiMultiNodeExperiment(api_key=self.api_key, logger=self.logger)
699+
repository = self.build_repository(repositories.RunMpiMultiNodeExperiment)
699700
handle = repository.create(experiment)
700701

701702
if tags:
702703
self.add_tags(entity_id=handle, entity=self.entity, tags=tags)
704+
703705
return handle
704706

705707
def start(self, experiment_id):
@@ -710,7 +712,7 @@ def start(self, experiment_id):
710712
:raises: exceptions.GradientSdkError
711713
"""
712714

713-
repository = repositories.StartExperiment(api_key=self.api_key, logger=self.logger)
715+
repository = self.build_repository(repositories.StartExperiment)
714716
repository.start(experiment_id)
715717

716718
def stop(self, experiment_id):
@@ -721,7 +723,7 @@ def stop(self, experiment_id):
721723
:raises: exceptions.GradientSdkError
722724
"""
723725

724-
repository = repositories.StopExperiment(api_key=self.api_key, logger=self.logger)
726+
repository = self.build_repository(repositories.StopExperiment)
725727
repository.stop(experiment_id)
726728

727729
def list(self, project_id=None, offset=None, limit=None, get_meta=False, tags=None):
@@ -737,7 +739,7 @@ def list(self, project_id=None, offset=None, limit=None, get_meta=False, tags=No
737739
:rtype: list[models.SingleNodeExperiment|models.MultiNodeExperiment]|tuple[list[models.SingleNodeExperiment|models.MultiNodeExperiment],dict]
738740
"""
739741

740-
repository = repositories.ListExperiments(api_key=self.api_key, logger=self.logger)
742+
repository = self.build_repository(repositories.ListExperiments)
741743
experiments = repository.list(project_id=project_id, limit=limit, offset=offset, get_meta=get_meta, tags=tags)
742744
return experiments
743745

@@ -747,7 +749,7 @@ def get(self, experiment_id):
747749
:param str experiment_id: Experiment ID
748750
:rtype: models.SingleNodeExperiment|models.MultiNodeExperiment|MpiMultiNodeExperiment
749751
"""
750-
repository = repositories.GetExperiment(api_key=self.api_key, logger=self.logger)
752+
repository = self.build_repository(repositories.GetExperiment)
751753
experiment = repository.get(experiment_id=experiment_id)
752754
return experiment
753755

@@ -762,7 +764,7 @@ def logs(self, experiment_id, line=0, limit=10000):
762764
:rtype: list[models.LogRow]
763765
"""
764766

765-
repository = repositories.ListExperimentLogs(api_key=self.api_key, logger=self.logger)
767+
repository = self.build_repository(repositories.ListExperimentLogs)
766768
logs = repository.list(experiment_id, line, limit)
767769
return logs
768770

@@ -777,12 +779,12 @@ def yield_logs(self, experiment_id, line=0, limit=10000):
777779
:rtype: Iterator[models.LogRow]
778780
"""
779781

780-
repository = repositories.ListExperimentLogs(api_key=self.api_key, logger=self.logger)
782+
repository = self.build_repository(repositories.ListExperimentLogs)
781783
logs_generator = repository.yield_logs(experiment_id, line, limit)
782784
return logs_generator
783785

784786
def delete(self, experiment_id):
785-
repository = repositories.DeleteExperiment(api_key=self.api_key, logger=self.logger)
787+
repository = self.build_repository(repositories.DeleteExperiment)
786788
repository.delete(experiment_id)
787789

788790
def _validate_arguments(self, **kwargs):

gradient/api_sdk/clients/http_client.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,30 @@
77
from ..config import config
88

99
default_headers = {"X-API-Key": config.PAPERSPACE_API_KEY,
10-
"ps_client_name": "gradient-cli",
10+
"ps_client_name": "gradient-cli-sdk",
1111
"ps_client_version": version.version}
1212

1313

1414
class API(object):
15-
def __init__(self, api_url, headers=None, api_key=None, logger=sdk_logger.MuteLogger()):
15+
def __init__(self, api_url, headers=None, api_key=None, ps_client_name=None, logger=sdk_logger.MuteLogger()):
1616
"""
1717
18-
:type str api_url: url you want to connect
19-
:type dict headers: headers
20-
:type str api_key: your API key
21-
:type sdk_logger.Logger logger:
18+
:param str api_url: url you want to connect
19+
:param dict headers: headers
20+
:param str api_key: your API key
21+
:param str ps_client_name: Client name
22+
:param sdk_logger.Logger logger:
2223
"""
2324
self.api_url = api_url
2425
headers = headers or default_headers
2526
self.headers = headers.copy()
27+
2628
if api_key:
2729
self.api_key = api_key
30+
31+
if ps_client_name is not None:
32+
self.ps_client_name = ps_client_name
33+
2834
self.logger = logger
2935

3036
@property
@@ -35,6 +41,14 @@ def api_key(self):
3541
def api_key(self, value):
3642
self.headers["X-API-Key"] = value
3743

44+
@property
45+
def ps_client_name(self):
46+
return self.headers.get("ps_client_name")
47+
48+
@ps_client_name.setter
49+
def ps_client_name(self, value):
50+
self.headers["ps_client_name"] = value
51+
3852
def get_path(self, url):
3953
api_url = self.api_url if not self.api_url.endswith("/") else self.api_url[:-1]
4054
template = "{}{}" if url.startswith("/") else "{}/{}"

gradient/api_sdk/clients/hyperparameter_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def create(
106106
use_dockerfile=use_dockerfile,
107107
)
108108

109-
repository = repositories.CreateHyperparameterJob(api_key=self.api_key, logger=self.logger)
109+
repository = self.build_repository(repositories.CreateHyperparameterJob)
110110
handle = repository.create(hyperparameter)
111111

112112
if tags:
@@ -228,7 +228,7 @@ def run(
228228
use_dockerfile=use_dockerfile,
229229
)
230230

231-
repository = repositories.CreateAndStartHyperparameterJob(api_key=self.api_key, logger=self.logger)
231+
repository = self.build_repository(repositories.CreateAndStartHyperparameterJob)
232232
handle = repository.create(hyperparameter)
233233

234234
if tags:
@@ -245,7 +245,7 @@ def get(self, id):
245245
:rtype: models.Hyperparameter
246246
"""
247247

248-
repository = repositories.GetHyperparameterTuningJob(api_key=self.api_key, logger=self.logger)
248+
repository = self.build_repository(repositories.GetHyperparameterTuningJob)
249249
job = repository.get(id=id)
250250
return job
251251

@@ -256,7 +256,7 @@ def start(self, id):
256256
:raises: exceptions.GradientSdkError
257257
"""
258258

259-
repository = repositories.StartHyperparameterTuningJob(api_key=self.api_key, logger=self.logger)
259+
repository = self.build_repository(repositories.StartHyperparameterTuningJob)
260260
repository.start(id_=id)
261261

262262
def list(self):
@@ -281,6 +281,6 @@ def list(self):
281281
282282
:rtype: list[models.Hyperparameter]
283283
"""
284-
repository = repositories.ListHyperparameterJobs(api_key=self.api_key, logger=self.logger)
284+
repository = self.build_repository(repositories.ListHyperparameterJobs)
285285
experiments = repository.list()
286286
return experiments

0 commit comments

Comments
 (0)