Skip to content

Commit 7c59a2a

Browse files
authored
Merge pull request #81 from Paperspace/PS-11155-Add-notebooks-support-to-CLI-and-SDK
Ps 11155 add notebooks support to cli and sdk
2 parents 49927dc + 877b899 commit 7c59a2a

27 files changed

+1790
-174
lines changed

gradient/api_sdk/clients/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
from .job_client import JobsClient
55
from .machines_client import MachinesClient
66
from .model_client import ModelsClient
7+
from .notebook_client import NotebooksClient
78
from .project_client import ProjectsClient
89
from .sdk_client import SdkClient
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from .base_client import BaseClient
2+
from .. import repositories, models
3+
4+
5+
class NotebooksClient(BaseClient):
6+
def create(
7+
self,
8+
vm_type_id,
9+
container_id,
10+
cluster_id,
11+
container_name=None,
12+
name=None,
13+
registry_username=None,
14+
registry_password=None,
15+
default_entrypoint=None,
16+
container_user=None,
17+
shutdown_timeout=None,
18+
is_preemptible=None,
19+
):
20+
"""Create new notebook
21+
22+
:param int vm_type_id:
23+
:param int container_id:
24+
:param int cluster_id:
25+
:param str container_name:
26+
:param str name:
27+
:param str registry_username:
28+
:param str registry_password:
29+
:param str default_entrypoint:
30+
:param str container_user:
31+
:param int|float shutdown_timeout:
32+
:param bool is_preemptible:
33+
34+
:return: Notebook ID
35+
:rtype str:
36+
"""
37+
38+
notebook = models.Notebook(
39+
vm_type_id=vm_type_id,
40+
container_id=container_id,
41+
cluster_id=cluster_id,
42+
container_name=container_name,
43+
name=name,
44+
registry_username=registry_username,
45+
registry_password=registry_password,
46+
default_entrypoint=default_entrypoint,
47+
container_user=container_user,
48+
shutdown_timeout=shutdown_timeout,
49+
is_preemptible=is_preemptible,
50+
)
51+
52+
repository = repositories.CreateNotebook(api_key=self.api_key, logger=self.logger)
53+
handle = repository.create(notebook)
54+
return handle
55+
56+
def get(self, id):
57+
"""Get Notebook
58+
59+
:param str id: Notebook ID
60+
:rtype: models.Notebook
61+
"""
62+
repository = repositories.GetNotebook(api_key=self.api_key, logger=self.logger)
63+
notebook = repository.get(id=id)
64+
return notebook
65+
66+
def delete(self, id):
67+
"""Delete existing notebook
68+
69+
:param str id: Notebook ID
70+
"""
71+
repository = repositories.DeleteNotebook(api_key=self.api_key, logger=self.logger)
72+
repository.delete(id)
73+
74+
def list(self):
75+
"""Get list of Notebooks
76+
77+
:rtype: list[models.Notebook]
78+
"""
79+
repository = repositories.ListNotebooks(api_key=self.api_key, logger=self.logger)
80+
notebooks = repository.list()
81+
return notebooks

gradient/api_sdk/clients/sdk_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from . import DeploymentsClient, ExperimentsClient, HyperparameterJobsClient, ModelsClient, ProjectsClient
1+
from . import DeploymentsClient, ExperimentsClient, HyperparameterJobsClient, ModelsClient, ProjectsClient, \
2+
MachinesClient, NotebooksClient
23
from .job_client import JobsClient
34
from .. import logger as sdk_logger
45

@@ -15,3 +16,5 @@ def __init__(self, api_key, logger=sdk_logger.MuteLogger()):
1516
self.models = ModelsClient(api_key=api_key, logger=logger)
1617
self.jobs = JobsClient(api_key=api_key, logger=logger)
1718
self.projects = ProjectsClient(api_key=api_key, logger=logger)
19+
self.machines = MachinesClient(api_key=api_key, logger=logger)
20+
self.notebooks = NotebooksClient(api_key=api_key, logger=logger)

gradient/api_sdk/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
from .log import LogRow
66
from .machine import Machine, MachineEvent, MachineUtilization
77
from .model import Model
8+
from .notebook import Notebook
89
from .project import Project
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import attr
2+
3+
4+
@attr.s
5+
class Notebook(object):
6+
id = attr.ib(type=str, default=None)
7+
vm_type_id = attr.ib(type=int, default=None)
8+
container_id = attr.ib(type=int, default=None)
9+
container_name = attr.ib(type=str, default=None)
10+
name = attr.ib(type=str, default=None)
11+
cluster_id = attr.ib(type=int, default=None)
12+
registry_username = attr.ib(type=str, default=None)
13+
registry_password = attr.ib(type=str, default=None)
14+
default_entrypoint = attr.ib(type=str, default=None)
15+
container_user = attr.ib(type=str, default=None)
16+
shutdown_timeout = attr.ib(type=int, default=None)
17+
is_preemptible = attr.ib(type=bool, default=None)
18+
project_id = attr.ib(type=bool, default=None)
19+
state = attr.ib(type=bool, default=None)
20+
vm_type = attr.ib(type=bool, default=None)
21+
fqdn = attr.ib(type=bool, default=None)

gradient/api_sdk/repositories/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
from .machines import CheckMachineAvailability, CreateMachine, CreateResource, StartMachine, StopMachine, \
88
RestartMachine, GetMachine, UpdateMachine, GetMachineUtilization
99
from .models import ListModels
10+
from .notebooks import CreateNotebook, DeleteNotebook, GetNotebook, ListNotebooks
1011
from .projects import CreateProject, ListProjects

gradient/api_sdk/repositories/jobs.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,6 @@ def _get_api_url(self, **_):
99
return config.config.CONFIG_HOST
1010

1111

12-
class ParseJobDictMixin(object):
13-
@staticmethod
14-
def _parse_object(job_dict, **kwargs):
15-
"""
16-
17-
:param job_dict:
18-
:param kwargs:
19-
:return:
20-
:rtype: Job
21-
"""
22-
job = JobSchema().get_instance(job_dict)
23-
return job
24-
25-
2612
class ListJobs(GetBaseJobApiUrlMixin, ListResources):
2713

2814
def get_request_url(self, **kwargs):
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from gradient import config
2+
from .common import CreateResource, DeleteResource, ListResources, GetResource
3+
from .. import serializers
4+
5+
6+
class GetNotebookApiUrlMixin(object):
7+
def _get_api_url(self, use_vpc=False):
8+
return config.config.CONFIG_HOST
9+
10+
11+
class CreateNotebook(GetNotebookApiUrlMixin, CreateResource):
12+
SERIALIZER_CLS = serializers.NotebookSchema
13+
14+
def get_request_url(self, **kwargs):
15+
return "notebooks/createNotebook"
16+
17+
def _process_instance_dict(self, instance_dict):
18+
# the API requires this field but marshmallow does not create it if it's value is None
19+
instance_dict.setdefault("containerId")
20+
return instance_dict
21+
22+
23+
class DeleteNotebook(GetNotebookApiUrlMixin, DeleteResource):
24+
def get_request_url(self, **kwargs):
25+
return "notebooks/v2/deleteNotebook"
26+
27+
def _get_request_json(self, kwargs):
28+
notebook_id = kwargs["id"]
29+
d = {"notebookId": notebook_id}
30+
return d
31+
32+
def _send_request(self, client, url, json_data=None):
33+
response = client.post(url, json=json_data)
34+
return response
35+
36+
37+
class GetNotebook(GetNotebookApiUrlMixin, GetResource):
38+
def get_request_url(self, **kwargs):
39+
notebook_id = kwargs["id"]
40+
url = "notebooks/{}/getNotebook".format(notebook_id)
41+
return url
42+
43+
def _parse_object(self, data, **kwargs):
44+
# this ugly hack is here because marshmallow disallows reading value into `id` field
45+
# if JSON's field was named differently (despite using load_from in schema definition)
46+
data["id"] = data["handle"]
47+
48+
serializer = serializers.NotebookSchema()
49+
notebooks = serializer.get_instance(data)
50+
return notebooks
51+
52+
53+
class ListNotebooks(GetNotebookApiUrlMixin, ListResources):
54+
def get_request_url(self, **kwargs):
55+
return "notebooks/getNotebooks"
56+
57+
def _parse_objects(self, data, **kwargs):
58+
notebook_dicts = data["notebookList"]
59+
# this ugly hack is here because marshmallow disallows reading value into `id` field
60+
# if JSON's field was named differently (despite using load_from in schema definition)
61+
for d in notebook_dicts:
62+
d["id"] = d["handle"]
63+
64+
serializer = serializers.NotebookSchema()
65+
notebooks = serializer.get_instance(notebook_dicts, many=True)
66+
return notebooks
67+
68+
def _get_request_json(self, kwargs):
69+
json_ = {
70+
"filter": {
71+
"filter": {
72+
"limit": 11,
73+
"offset": 0,
74+
"where": {
75+
"dtDeleted": None,
76+
},
77+
"order": "jobId desc",
78+
},
79+
},
80+
}
81+
return json_

gradient/api_sdk/serializers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
from .log import LogRowSchema
77
from .machine import MachineSchema, MachineSchemaForListing, MachineEventSchema
88
from .model import Model
9+
from .notebook import NotebookSchema
910
from .project import Project
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import marshmallow
2+
3+
from . import BaseSchema
4+
from .. import models
5+
6+
7+
class NotebookSchema(BaseSchema):
8+
MODEL = models.Notebook
9+
10+
id = marshmallow.fields.Str()
11+
vm_type_id = marshmallow.fields.Int(load_from="vmTypeId", dump_to="vmTypeId")
12+
container_id = marshmallow.fields.Int(load_from="containerId", dump_to="containerId", allow_none=True)
13+
container_name = marshmallow.fields.Str(load_from="containerName", dump_to="containerName", allow_none=True)
14+
name = marshmallow.fields.Str()
15+
cluster_id = marshmallow.fields.Int(load_from="clusterId", dump_to="clusterId")
16+
registry_username = marshmallow.fields.Str(load_from="registryUsername", dump_to="registryUsername")
17+
registry_password = marshmallow.fields.Str(load_from="registryPassword", dump_to="registryPassword")
18+
default_entrypoint = marshmallow.fields.Str(load_from="defaultEntrypoint", dump_to="defaultEntrypoint")
19+
container_user = marshmallow.fields.Str(load_from="containerUser", dump_to="containerUser")
20+
shutdown_timeout = marshmallow.fields.Int(load_from="shutdownTimeout", dump_to="shutdownTimeout")
21+
is_preemptible = marshmallow.fields.Bool(load_from="isPreemptible", dump_to="isPreemptible")
22+
project_id = marshmallow.fields.Str(load_from="projectHandle", dump_to="projectHandle")
23+
state = marshmallow.fields.Str()
24+
vm_type = marshmallow.fields.Str(load_from="vmType", dump_to="vmType")
25+
fqdn = marshmallow.fields.Str()

0 commit comments

Comments
 (0)