Skip to content

Commit 63681af

Browse files
authored
Merge pull request #80 from Paperspace/PS-11043-add-machines-api-support-to-the-SDK
Ps 11043 add machines api support to the sdk
2 parents b8d0377 + af0b2d8 commit 63681af

File tree

20 files changed

+987
-428
lines changed

20 files changed

+987
-428
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from .deployment_client import DeploymentsClient
22
from .experiment_client import ExperimentsClient
33
from .hyperparameter_client import HyperparameterJobsClient
4+
from .job_client import JobsClient
5+
from .machines_client import MachinesClient
46
from .model_client import ModelsClient
57
from .project_client import ProjectsClient
6-
from .job_client import JobsClient
78
from .sdk_client import SdkClient

gradient/api_sdk/clients/hyperparameter_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,28 +222,28 @@ def run(
222222
handle = repository.create(hyperparameter)
223223
return handle
224224

225-
def get(self, id_):
225+
def get(self, id):
226226
"""Get Hyperparameter tuning job's instance
227227
228-
:param str id_: Hyperparameter job id
228+
:param str id: Hyperparameter job id
229229
230230
:returns: instance of Hyperparameter
231231
:rtype: models.Hyperparameter
232232
"""
233233

234234
repository = repositories.GetHyperparameterTuningJob(api_key=self.api_key, logger=self.logger)
235-
job = repository.get(id=id_)
235+
job = repository.get(id=id)
236236
return job
237237

238-
def start(self, id_):
238+
def start(self, id):
239239
"""Start existing hyperparameter tuning job
240240
241-
:param str id_: Hyperparameter job id
241+
:param str id: Hyperparameter job id
242242
:raises: exceptions.GradientSdkError
243243
"""
244244

245245
repository = repositories.StartHyperparameterTuningJob(api_key=self.api_key, logger=self.logger)
246-
repository.start(id_=id_)
246+
repository.start(id_=id)
247247

248248
def list(self):
249249
"""Get a list of hyperparameter tuning jobs
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
from gradient.api_sdk import repositories, models
2+
from gradient.api_sdk.repositories.machines import CheckMachineAvailability, DeleteMachine, ListMachines, WaitForState
3+
from .base_client import BaseClient
4+
5+
6+
class MachinesClient(BaseClient):
    """Client exposing the machine-related operations of the SDK.

    Every method builds the matching repository object with this client's
    ``api_key`` and ``logger`` and delegates the actual request to it.
    """

    def create(
            self,
            name,
            machine_type,
            region,
            size,
            billing_type,
            template_id,
            assign_public_ip=None,
            dynamic_public_ip=None,
            network_id=None,
            team_id=None,
            user_id=None,
            email=None,
            password=None,
            first_name=None,
            last_name=None,
            notification_email=None,
            script_id=None,
    ):
        """Create new machine

        :param str name: A memorable name for this machine [required]
        :param str machine_type: Machine type [required]
        :param str region: Name of the region [required]
        :param str size: Storage size for the machine in GB [required]
        :param str billing_type: Either 'monthly' or 'hourly' billing [required]
        :param str template_id: Template id of the template to use for creating this machine [required]
        :param bool assign_public_ip: Assign a new public ip address. Cannot be used with dynamic_public_ip
        :param bool dynamic_public_ip: Temporarily assign a new public ip address on machine.
                                       Cannot be used with assign_public_ip
        :param str network_id: If creating on a specific network, specify its id
        :param str team_id: If creating the machine for a team, specify the team id
        :param str user_id: If assigning to an existing user other than yourself, specify the user id
                            (mutually exclusive with email, password, first_name, last_name)
        :param str email: If creating a new user for this machine, specify their email address
                          (mutually exclusive with user_id)
        :param str password: If creating a new user, specify their password (mutually exclusive with user_id)
        :param str first_name: If creating a new user, specify their first name (mutually exclusive with user_id)
        :param str last_name: If creating a new user, specify their last name (mutually exclusive with user_id)
        :param str notification_email: Send a notification to this email address when complete
        :param str script_id: The script id of a script to be run on startup

        :returns: ID of created machine
        :rtype: str
        """

        # Arguments are forwarded untouched; the mutual-exclusion rules described
        # above are not checked in this method.
        instance = models.Machine(
            name=name,
            machine_type=machine_type,
            region=region,
            size=size,
            billing_type=billing_type,
            template_id=template_id,
            assign_public_ip=assign_public_ip,
            dynamic_public_ip=dynamic_public_ip,
            network_id=network_id,
            team_id=team_id,
            user_id=user_id,
            email=email,
            password=password,
            first_name=first_name,
            last_name=last_name,
            notification_email=notification_email,
            script_id=script_id,
        )

        repository = repositories.CreateMachine(api_key=self.api_key, logger=self.logger)
        handle = repository.create(instance)
        return handle

    def get(self, id):
        """Get machine instance

        :param str id: ID of a machine [required]

        :return: Machine instance
        :rtype: models.Machine
        """
        repository = repositories.GetMachine(api_key=self.api_key, logger=self.logger)
        instance = repository.get(id=id)
        return instance

    def is_available(self, machine_type, region):
        """Check if specified machine is available in certain region

        :param str machine_type: Machine type [required]
        :param str region: Name of the region [required]

        :return: If specified machine is available in the region
        :rtype: bool
        """

        repository = CheckMachineAvailability(api_key=self.api_key, logger=self.logger)
        handle = repository.get(machine_type=machine_type, region=region)
        return handle

    def restart(self, id):
        """Restart machine

        :param str id: ID of a machine [required]
        """

        repository = repositories.RestartMachine(api_key=self.api_key, logger=self.logger)
        repository.restart(id)

    def start(self, id):
        """Start machine

        :param str id: ID of a machine [required]
        """

        repository = repositories.StartMachine(api_key=self.api_key, logger=self.logger)
        repository.start(id)

    def stop(self, id):
        """Stop machine

        :param str id: ID of a machine [required]
        """

        repository = repositories.StopMachine(api_key=self.api_key, logger=self.logger)
        repository.stop(id)

    def update(
            self,
            id,
            name=None,
            shutdown_timeout_in_hours=None,
            shutdown_timeout_forces=None,
            perform_auto_snapshot=None,
            auto_snapshot_frequency=None,
            auto_snapshot_save_count=None,
            dynamic_public_ip=None,
    ):
        """Update machine instance

        :param str id: Id of the machine to update [required]
        :param str name: New name for the machine
        :param int shutdown_timeout_in_hours: Number of hours before machine is shutdown if no one is logged in
                                              via the Paperspace client
        :param bool shutdown_timeout_forces: Force shutdown at shutdown timeout, even if there is
                                             a Paperspace client connection
        :param bool perform_auto_snapshot: Perform auto snapshots
        :param str auto_snapshot_frequency: One of 'hour', 'day', 'week', or None
        :param int auto_snapshot_save_count: Number of snapshots to save
        :param bool dynamic_public_ip: If true, assigns a new public ip address on machine start and releases it
                                       from the account on machine stop
        """
        # The machine id is passed to the repository separately; the model only
        # carries the fields listed here.
        instance = models.Machine(
            name=name,
            dynamic_public_ip=dynamic_public_ip,
            shutdown_timeout_in_hours=shutdown_timeout_in_hours,
            shutdown_timeout_forces=shutdown_timeout_forces,
            perform_auto_snapshot=perform_auto_snapshot,
            auto_snapshot_frequency=auto_snapshot_frequency,
            auto_snapshot_save_count=auto_snapshot_save_count,
        )

        repository = repositories.UpdateMachine(api_key=self.api_key, logger=self.logger)
        repository.update(id, instance)

    def get_utilization(self, id, billing_month):
        """Get machine utilization info for given billing month

        :param str id: ID of the machine
        :param str billing_month: Billing month in "YYYY-MM" format

        :return: Machine utilization info
        :rtype: models.MachineUtilization
        """
        repository = repositories.GetMachineUtilization(api_key=self.api_key, logger=self.logger)
        usage = repository.get(id=id, billing_month=billing_month)
        return usage

    def delete(self, machine_id, release_public_ip=False):
        """Destroy machine with given ID

        :param str machine_id: ID of the machine
        :param bool release_public_ip: If the assigned public IP should be released
        """

        repository = DeleteMachine(api_key=self.api_key, logger=self.logger)
        repository.delete(machine_id, release_public_ip=release_public_ip)

    def wait_for_state(self, machine_id, state, interval=5):
        """Wait for defined machine state

        :param str machine_id: ID of the machine
        :param str state: State of machine to wait for
        :param int interval: interval between polls
        """

        repository = WaitForState(api_key=self.api_key, logger=self.logger)
        repository.wait_for_state(machine_id, state, interval)

    def list(
            self,
            id=None,
            name=None,
            os=None,
            ram=None,
            cpus=None,
            gpu=None,
            storage_total=None,
            storage_used=None,
            usage_rate=None,
            shutdown_timeout_in_hours=None,
            perform_auto_snapshot=None,
            auto_snapshot_frequency=None,
            auto_snapshot_save_count=None,
            agent_type=None,
            created_timestamp=None,
            state=None,
            updates_pending=None,
            network_id=None,
            private_ip_address=None,
            public_ip_address=None,
            region=None,
            user_id=None,
            team_id=None,
            last_run_timestamp=None,
    ):
        """Get a list of machines, optionally narrowed down by the given filters

        :param str id: Optional machine id to match on
        :param str name: Filter by machine name
        :param str os: Filter by os used
        :param int ram: Filter by machine RAM (in bytes)
        :param int cpus: Filter by CPU count
        :param str gpu: Filter by GPU type
        :param str storage_total: Filter by total storage
        :param str storage_used: Filter by storage used
        :param str usage_rate: Filter by usage rate
        :param int shutdown_timeout_in_hours: Filter by shutdown timeout
        :param bool perform_auto_snapshot: Filter by performAutoSnapshot flag
        :param str auto_snapshot_frequency: Filter by autoSnapshotFrequency flag
        :param int auto_snapshot_save_count: Filter by auto snapshots count
        :param str agent_type: Filter by agent type
        :param datetime created_timestamp: Filter by date created
        :param str state: Filter by state
        :param str updates_pending: Filter by updates pending
        :param str network_id: Filter by network ID
        :param str private_ip_address: Filter by private IP address
        :param str public_ip_address: Filter by public IP address
        :param str region: Filter by region. One of {CA, NY2, AMS1}
        :param str user_id: Filter by user ID
        :param str team_id: Filter by team ID
        :param str last_run_timestamp: Filter by last run date

        :return: List of machines
        :rtype: list[models.Machine]
        """

        repository = ListMachines(api_key=self.api_key, logger=self.logger)
        machines = repository.list(
            id=id,
            name=name,
            os=os,
            ram=ram,
            cpus=cpus,
            gpu=gpu,
            storage_total=storage_total,
            storage_used=storage_used,
            usage_rate=usage_rate,
            shutdown_timeout_in_hours=shutdown_timeout_in_hours,
            perform_auto_snapshot=perform_auto_snapshot,
            auto_snapshot_frequency=auto_snapshot_frequency,
            auto_snapshot_save_count=auto_snapshot_save_count,
            agent_type=agent_type,
            created_timestamp=created_timestamp,
            state=state,
            updates_pending=updates_pending,
            network_id=network_id,
            private_ip_address=private_ip_address,
            public_ip_address=public_ip_address,
            region=region,
            user_id=user_id,
            team_id=team_id,
            last_run_timestamp=last_run_timestamp,
        )
        return machines

gradient/api_sdk/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from .deployment import Deployment
22
from .experiment import BaseExperiment, MultiNodeExperiment, SingleNodeExperiment
33
from .hyperparameter import Hyperparameter
4+
from .job import Job
45
from .log import LogRow
6+
from .machine import Machine, MachineEvent, MachineUtilization
57
from .model import Model
68
from .project import Project
7-
from .job import Job

0 commit comments

Comments
 (0)