Skip to content

Commit 6aa7b6d

Browse files
committed
Merge branch 'master' of github.com:Paperspace/gradient-cli into PS-12720-Make_workspace_optional_for_v2
2 parents 5f371ca + 2743c5f commit 6aa7b6d

File tree

147 files changed

+5111
-1909
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

147 files changed

+5111
-1909
lines changed
Lines changed: 109 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,124 @@
1+
import copy
2+
3+
from gradient.api_sdk.repositories.common import BaseRepository
14
from .. import logger as sdk_logger
5+
from ..repositories.tags import ListTagRepository, UpdateTagRepository
6+
from ...exceptions import ReceivingDataFailedError
27

38

49
class BaseClient(object):
510
def __init__(
6-
self, api_key,
11+
self,
12+
api_key,
13+
ps_client_name=None,
714
logger=sdk_logger.MuteLogger()
815
):
916
"""
1017
Base class. All client classes inherit from it.
1118
12-
An API key can be created at paperspace.com after you sign in to your account. After obtaining it, you can set
13-
it in the CLI using the command::
14-
15-
gradient apiKey XXXXXXXXXXXXXXXXXXX
16-
17-
or you can provide your API key in any command, for example::
18-
19-
gradient experiments run ... --apiKey XXXXXXXXXXXXXXXXXXX
20-
2119
:param str api_key: your API key
20+
:param str ps_client_name:
2221
:param sdk_logger.Logger logger:
2322
"""
2423
self.api_key = api_key
24+
self.ps_client_name = ps_client_name
2525
self.logger = logger
26+
27+
KNOWN_TAGS_ENTITIES = [
28+
"project", "job", "notebook", "experiment", "deployment", "mlModel", "machine",
29+
]
30+
entity = ""
31+
32+
def _validate_entities(self, entity):
33+
"""
34+
Method to validate if passed entity is correct
35+
:param entity:
36+
:return:
37+
"""
38+
if entity not in self.KNOWN_TAGS_ENTITIES:
39+
raise ReceivingDataFailedError("Not known entity type provided")
40+
41+
@staticmethod
42+
def merge_tags(entity_id, entity_tags, new_tags):
43+
result_tags = []
44+
if entity_tags:
45+
entity_tags = entity_tags[0].get(entity_id, [])
46+
47+
result_tags = entity_tags + new_tags
48+
else:
49+
result_tags += new_tags
50+
return sorted(list(set(result_tags)))
51+
52+
@staticmethod
53+
def diff_tags(entity_id, entity_tags, tags_to_remove):
54+
result_tags = []
55+
if entity_tags:
56+
entity_tags = entity_tags[0].get(entity_id, [])
57+
entity_tags = set(entity_tags) - set(tags_to_remove)
58+
result_tags = sorted(list(entity_tags))
59+
60+
return result_tags
61+
62+
def add_tags(self, entity_id, entity, tags):
63+
"""
64+
Add tags to entity.
65+
:param entity_id:
66+
:param entity:
67+
:param tags:
68+
:return:
69+
"""
70+
self._validate_entities(entity)
71+
72+
list_tag_repository = self.build_repository(ListTagRepository)
73+
entity_tags = list_tag_repository.list(entity=entity, entity_ids=[entity_id])
74+
75+
if entity_tags:
76+
tags = self.merge_tags(entity_id, entity_tags, tags)
77+
78+
update_tag_repository = self.build_repository(UpdateTagRepository)
79+
update_tag_repository.update(entity=entity, entity_id=entity_id, tags=tags)
80+
81+
def remove_tags(self, entity_id, entity, tags):
82+
"""
83+
Remove tags from entity.
84+
:param str entity_id:
85+
:param str entity:
86+
:param list[str] tags: list of tags to remove from entity
87+
:return:
88+
"""
89+
self._validate_entities(entity)
90+
91+
list_tag_repository = self.build_repository(ListTagRepository)
92+
entity_tags = list_tag_repository.list(entity=entity, entity_ids=[entity_id])
93+
94+
if entity_tags:
95+
entity_tags = self.diff_tags(entity_id, entity_tags, tags)
96+
97+
update_tag_repository = self.build_repository(UpdateTagRepository)
98+
update_tag_repository.update(entity=entity, entity_id=entity_id, tags=entity_tags)
99+
100+
def list_tags(self, entity_ids, entity):
101+
"""
102+
List tags for entity
103+
:param list[str] entity_ids:
104+
:param str entity:
105+
:return:
106+
"""
107+
self._validate_entities(entity)
108+
109+
list_tag_repository = self.build_repository(ListTagRepository)
110+
entity_tags = list_tag_repository.list(entity=entity, entity_ids=entity_ids)
111+
return entity_tags
112+
113+
def build_repository(self, repository_class, *args, **kwargs):
114+
"""
115+
:param type[BaseRepository] repository_class:
116+
:rtype: BaseRepository
117+
"""
118+
119+
if self.ps_client_name is not None and kwargs.get("ps_client_name") is None:
120+
kwargs = copy.deepcopy(kwargs)
121+
kwargs["ps_client_name"] = self.ps_client_name
122+
123+
repository = repository_class(*args, api_key=self.api_key, logger=self.logger, **kwargs)
124+
return repository
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from gradient.api_sdk import repositories
2+
3+
from gradient.api_sdk.clients.base_client import BaseClient
4+
5+
6+
class ClustersClient(BaseClient):
7+
def list(self, limit=11, offset=0, **kwargs):
8+
"""
9+
Get a list of clusters for your team
10+
11+
:param int limit: how many element to return on request
12+
:param int offset: from what position we should return clusters
13+
14+
:return: clusters
15+
:rtype: list
16+
"""
17+
repository = repositories.ListClusters(api_key=self.api_key, logger=self.logger)
18+
clusters = repository.list(limit=limit, offset=offset)
19+
return clusters

gradient/api_sdk/clients/deployment_client.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class DeploymentsClient(BaseClient):
2424
)
2525
"""
2626
HOST_URL = config.config.CONFIG_HOST
27+
entity = "deployment"
2728

2829
def create(
2930
self,
@@ -46,7 +47,7 @@ def create(
4647
cluster_id=None,
4748
auth_username=None,
4849
auth_password=None,
49-
use_vpc=False,
50+
tags=None,
5051
):
5152
"""
5253
Method to create a Deployment instance.
@@ -89,7 +90,7 @@ def create(
8990
:param str cluster_id: cluster ID
9091
:param str auth_username: Username
9192
:param str auth_password: Password
92-
:param bool use_vpc:
93+
:param list[str] tags: List of tags
9394
9495
:returns: Created deployment id
9596
:rtype: str
@@ -117,8 +118,10 @@ def create(
117118
auth_password=auth_password,
118119
)
119120

120-
repository = repositories.CreateDeployment(api_key=self.api_key, logger=self.logger)
121-
deployment_id = repository.create(deployment, use_vpc=use_vpc)
121+
repository = self.build_repository(repositories.CreateDeployment)
122+
deployment_id = repository.create(deployment)
123+
if tags:
124+
self.add_tags(entity_id=deployment_id, entity=self.entity, tags=tags)
122125
return deployment_id
123126

124127
def get(self, deployment_id):
@@ -129,11 +132,11 @@ def get(self, deployment_id):
129132
:return: Deployment instance
130133
:rtype: models.Deployment
131134
"""
132-
repository = repositories.GetDeployment(self.api_key, logger=self.logger)
135+
repository = self.build_repository(repositories.GetDeployment)
133136
deployment = repository.get(deployment_id=deployment_id)
134137
return deployment
135138

136-
def start(self, deployment_id, use_vpc=False):
139+
def start(self, deployment_id):
137140
"""
138141
Start deployment
139142
@@ -142,13 +145,12 @@ def start(self, deployment_id, use_vpc=False):
142145
gradient deployments start --id <your-deployment-id>
143146
144147
:param str deployment_id: Deployment ID
145-
:param bool use_vpc:
146148
"""
147149

148-
repository = repositories.StartDeployment(api_key=self.api_key, logger=self.logger)
149-
repository.start(deployment_id, use_vpc=use_vpc)
150+
repository = self.build_repository(repositories.StartDeployment)
151+
repository.start(deployment_id)
150152

151-
def stop(self, deployment_id, use_vpc=False):
153+
def stop(self, deployment_id):
152154
"""
153155
Stop deployment
154156
@@ -157,29 +159,31 @@ def stop(self, deployment_id, use_vpc=False):
157159
gradient deployments stop --id <your-deployment-id>
158160
159161
:param deployment_id: Deployment ID
160-
:param bool use_vpc:
161162
"""
162163

163-
repository = repositories.StopDeployment(api_key=self.api_key, logger=self.logger)
164-
repository.stop(deployment_id, use_vpc=use_vpc)
164+
repository = self.build_repository(repositories.StopDeployment)
165+
repository.stop(deployment_id)
165166

166-
def list(self, state=None, project_id=None, model_id=None, use_vpc=False):
167+
def list(self, state=None, project_id=None, model_id=None, tags=None):
167168
"""
168169
List deployments with optional filtering
169170
170-
:param str state:
171-
:param str project_id:
172-
:param str model_id:
173-
:param bool use_vpc:
171+
:param str state: state to filter deployments
172+
:param str project_id: project ID to filter deployments
173+
:param str model_id: model ID to filter deployments
174+
:param list[str]|tuple[str] tags: tags to filter deployments with OR
175+
176+
:returns: List of Deployment model instances
177+
:rtype: list[models.Deployment]
174178
"""
175179

176-
repository = repositories.ListDeployments(api_key=self.api_key, logger=self.logger)
177-
deployments = repository.list(state=state, project_id=project_id, model_id=model_id, use_vpc=use_vpc)
180+
repository = self.build_repository(repositories.ListDeployments)
181+
deployments = repository.list(state=state, project_id=project_id, model_id=model_id, tags=tags)
178182
return deployments
179183

180-
def delete(self, deployment_id, use_vpc=False):
181-
repository = repositories.DeleteDeployment(api_key=self.api_key, logger=self.logger)
182-
repository.delete(deployment_id, use_vpc=use_vpc)
184+
def delete(self, deployment_id):
185+
repository = self.build_repository(repositories.DeleteDeployment)
186+
repository.delete(deployment_id)
183187

184188
def update(
185189
self,
@@ -204,7 +208,6 @@ def update(
204208
cluster_id=None,
205209
auth_username=None,
206210
auth_password=None,
207-
use_vpc=False,
208211
):
209212
deployment = models.Deployment(
210213
deployment_type=deployment_type,
@@ -229,5 +232,5 @@ def update(
229232
auth_password=auth_password,
230233
)
231234

232-
repository = repositories.UpdateDeployment(api_key=self.api_key, logger=self.logger)
233-
repository.update(deployment_id, deployment, use_vpc=use_vpc)
235+
repository = self.build_repository(repositories.UpdateDeployment)
236+
repository.update(deployment_id, deployment)

0 commit comments

Comments
 (0)