Skip to content

Commit 5658ace

Browse files
committed
base64-encode commands sent to the APIs so they are not blocked by Cloudflare's filter
(cherry picked from commit 8b6857e)
1 parent 76b5cf8 commit 5658ace

File tree

6 files changed

+272
-201
lines changed

6 files changed

+272
-201
lines changed

Pipfile.lock

Lines changed: 208 additions & 182 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

gradient/api_sdk/serializers/experiment.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import copy
2+
13
import marshmallow
24

3-
from . import BaseSchema
4-
from .. import models
5+
from .base import BaseSchema
6+
from .. import models, utils
57

68

79
class BaseExperimentSchema(BaseSchema):
@@ -20,7 +22,7 @@ class BaseExperimentSchema(BaseSchema):
2022
id = marshmallow.fields.Str(load_from="handle")
2123
state = marshmallow.fields.Int()
2224

23-
def get_instance(self, obj_dict):
25+
def get_instance(self, obj_dict, many=False):
2426
# without popping these marshmallow wouldn't use load_from
2527
obj_dict.pop("id", None)
2628
obj_dict.pop("project_id", None)
@@ -29,7 +31,7 @@ def get_instance(self, obj_dict):
2931
if isinstance(ports, int):
3032
obj_dict["ports"] = str(ports)
3133

32-
instance = super(BaseExperimentSchema, self).get_instance(obj_dict)
34+
instance = super(BaseExperimentSchema, self).get_instance(obj_dict, many=many)
3335
return instance
3436

3537

@@ -44,6 +46,13 @@ class SingleNodeExperimentSchema(BaseExperimentSchema):
4446
registry_password = marshmallow.fields.Str(dump_to="registryPassword", load_from="registryPassword")
4547
registry_url = marshmallow.fields.Str(dump_to="registryUrl", load_from="registryUrl")
4648

49+
@marshmallow.pre_dump
50+
def preprocess(self, data, **kwargs):
51+
data = copy.copy(data)
52+
53+
utils.base64_encode_attribute(data, "command")
54+
return data
55+
4756

4857
class MultiNodeExperimentSchema(BaseExperimentSchema):
4958
MODEL = models.MultiNodeExperiment
@@ -76,3 +85,11 @@ class MultiNodeExperimentSchema(BaseExperimentSchema):
7685
load_from="parameterServerRegistryPassword")
7786
parameter_server_registry_url = marshmallow.fields.Str(dump_to="parameterServerRegistryUrl",
7887
load_from="parameterServerRegistryUrl")
88+
89+
@marshmallow.pre_dump
90+
def preprocess(self, data, **kwargs):
91+
data = copy.copy(data)
92+
93+
utils.base64_encode_attribute(data, "worker_command")
94+
utils.base64_encode_attribute(data, "parameter_server_command")
95+
return data

gradient/api_sdk/serializers/hyperparameter.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import copy
2+
13
import marshmallow
24

35
from .experiment import BaseExperimentSchema
4-
from .. import models
6+
from .. import models, utils
57

68

79
class HyperparameterSchema(BaseExperimentSchema):
@@ -28,3 +30,11 @@ class HyperparameterSchema(BaseExperimentSchema):
2830
load_from="hyperparameterServerContainer")
2931
hyperparameter_server_container_user = marshmallow.fields.Str(dump_to="hyperparameterServerContainerUser",
3032
load_from="hyperparameterServerContainerUser")
33+
34+
@marshmallow.pre_dump
35+
def preprocess(self, data, **kwargs):
36+
data = copy.copy(data)
37+
38+
utils.base64_encode_attribute(data, "worker_command")
39+
utils.base64_encode_attribute(data, "tuning_command")
40+
return data

gradient/api_sdk/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
from collections import OrderedDict
23

34
import six
@@ -48,3 +49,20 @@ def print_dict_recursive(input_dict, logger, indent=0, tabulator=" "):
4849
print_dict_recursive(OrderedDict(val), logger, indent + 1)
4950
else:
5051
logger.log("%s%s" % (tabulator * (indent + 1), val))
52+
53+
54+
def base64_encode(s):
55+
if six.PY3:
56+
s = bytes(s, encoding="utf8")
57+
58+
encoded_str = base64.b64encode(s)
59+
60+
if six.PY3: # Python3's base64.b64encode returns a bytes instance so it should be converted back to unicode
61+
encoded_str = encoded_str.decode("utf-8")
62+
63+
return encoded_str
64+
65+
66+
def base64_encode_attribute(data, name):
67+
encoded_value = base64_encode(getattr(data, name))
68+
setattr(data, name, encoded_value)

tests/functional/test_experiments.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class TestExperimentsCreateSingleNode(object):
5454
"projectHandle": u"testHandle",
5555
"container": u"testContainer",
5656
"machineType": u"testType",
57-
"command": u"testCommand",
57+
"command": u"dGVzdENvbW1hbmQ=",
5858
"experimentTypeId": constants.ExperimentType.SINGLE_NODE,
5959
"workspaceUrl": u"some-workspace",
6060
}
@@ -69,7 +69,7 @@ class TestExperimentsCreateSingleNode(object):
6969
"projectHandle": u"testHandle",
7070
"container": u"testContainer",
7171
"machineType": u"testType",
72-
"command": u"testCommand",
72+
"command": u"dGVzdENvbW1hbmQ=",
7373
"containerUser": u"conUser",
7474
"registryUsername": u"userName",
7575
"registryPassword": u"passwd",
@@ -252,11 +252,11 @@ class TestExperimentsCreateMultiNode(object):
252252
u"experimentTypeId": 2,
253253
u"workerContainer": u"wcon",
254254
u"workerMachineType": u"mty",
255-
u"workerCommand": u"wcom",
255+
u"workerCommand": u"d2NvbQ==",
256256
u"workerCount": 2,
257257
u"parameterServerContainer": u"pscon",
258258
u"parameterServerMachineType": u"psmtype",
259-
u"parameterServerCommand": u"ls",
259+
u"parameterServerCommand": u"bHM=",
260260
u"parameterServerCount": 2,
261261
u"workerContainerUser": u"usr",
262262
u"workspaceUrl": u"https://github.com/Paperspace/gradient-cli.git",
@@ -273,11 +273,11 @@ class TestExperimentsCreateMultiNode(object):
273273
"experimentTypeId": 3,
274274
"workerContainer": u"wcon",
275275
"workerMachineType": u"mty",
276-
"workerCommand": u"wcom",
276+
"workerCommand": u"d2NvbQ==",
277277
"workerCount": 2,
278278
"parameterServerContainer": u"pscon",
279279
"parameterServerMachineType": u"psmtype",
280-
"parameterServerCommand": u"ls",
280+
"parameterServerCommand": u"bHM=",
281281
"parameterServerCount": 2,
282282
"workerContainerUser": u"usr",
283283
"workerRegistryUsername": u"rusr",

tests/functional/test_hyperparameters.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ class TestCreateHyperparameters(object):
2424
"workerContainer": "some_container",
2525
"workerMachineType": "k80",
2626
"name": "some_name",
27-
"tuningCommand": "some command",
27+
"tuningCommand": "c29tZSBjb21tYW5k",
2828
"workerCount": 1,
29-
"workerCommand": "some worker command",
29+
"workerCommand": "c29tZSB3b3JrZXIgY29tbWFuZA==",
3030
"experimentTypeId": constants.ExperimentType.HYPERPARAMETER_TUNING,
3131
"projectHandle": "pr4yxj956",
3232
}
@@ -66,9 +66,9 @@ class TestCreateHyperparameters(object):
6666
"workerContainer": "some_worker_container",
6767
"workerMachineType": "k80",
6868
"name": "some_name",
69-
"tuningCommand": "some command",
69+
"tuningCommand": "c29tZSBjb21tYW5k",
7070
"workerCount": 666,
71-
"workerCommand": "some worker command",
71+
"workerCommand": "c29tZSB3b3JrZXIgY29tbWFuZA==",
7272
"workerRegistryUsername": "some_registry_username",
7373
"workerRegistryPassword": "some_registry_password",
7474
"workerContainerUser": "some_worker_container_user",
@@ -273,9 +273,9 @@ class TestCreateAndStartHyperparameters(object):
273273
"workerContainer": "some_container",
274274
"workerMachineType": "k80",
275275
"name": "some_name",
276-
"tuningCommand": "some command",
276+
"tuningCommand": "c29tZSBjb21tYW5k",
277277
"workerCount": 1,
278-
"workerCommand": "some worker command",
278+
"workerCommand": "c29tZSB3b3JrZXIgY29tbWFuZA==",
279279
"projectHandle": "pr4yxj956",
280280
"experimentTypeId": constants.ExperimentType.HYPERPARAMETER_TUNING,
281281
}
@@ -315,9 +315,9 @@ class TestCreateAndStartHyperparameters(object):
315315
"workerContainer": "some_worker_container",
316316
"workerMachineType": "k80",
317317
"name": "some_name",
318-
"tuningCommand": "some command",
318+
"tuningCommand": "c29tZSBjb21tYW5k",
319319
"workerCount": 666,
320-
"workerCommand": "some worker command",
320+
"workerCommand": "c29tZSB3b3JrZXIgY29tbWFuZA==",
321321
"workerRegistryUsername": "some_registry_username",
322322
"workerRegistryPassword": "some_registry_password",
323323
"workerContainerUser": "some_worker_container_user",

0 commit comments

Comments
 (0)