Skip to content

Commit 806d932

Browse files
committed
base64-encode commands sent to the APIs so they are not blocked by Cloudflare's filter
(cherry picked from commit 8b6857e)
1 parent 4fd9928 commit 806d932

File tree

6 files changed

+207
-131
lines changed

6 files changed

+207
-131
lines changed

Pipfile.lock

Lines changed: 133 additions & 108 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

gradient/api_sdk/serializers/experiment.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import copy
2+
13
import marshmallow
24

3-
from . import BaseSchema
4-
from .. import models
5+
from .base import BaseSchema
6+
from .. import models, utils
57

68

79
class BaseExperimentSchema(BaseSchema):
@@ -22,7 +24,7 @@ class BaseExperimentSchema(BaseSchema):
2224
id = marshmallow.fields.Str(load_from="handle")
2325
state = marshmallow.fields.Int()
2426

25-
def get_instance(self, obj_dict):
27+
def get_instance(self, obj_dict, many=False):
2628
# without popping these marshmallow wouldn't use load_from
2729
obj_dict.pop("id", None)
2830
obj_dict.pop("project_id", None)
@@ -31,7 +33,7 @@ def get_instance(self, obj_dict):
3133
if isinstance(ports, int):
3234
obj_dict["ports"] = str(ports)
3335

34-
instance = super(BaseExperimentSchema, self).get_instance(obj_dict)
36+
instance = super(BaseExperimentSchema, self).get_instance(obj_dict, many=many)
3537
return instance
3638

3739

@@ -46,6 +48,13 @@ class SingleNodeExperimentSchema(BaseExperimentSchema):
4648
registry_password = marshmallow.fields.Str(dump_to="registryPassword", load_from="registryPassword")
4749
registry_url = marshmallow.fields.Str(dump_to="registryUrl", load_from="registryUrl")
4850

51+
@marshmallow.pre_dump
52+
def preprocess(self, data, **kwargs):
53+
data = copy.copy(data)
54+
55+
utils.base64_encode_attribute(data, "command")
56+
return data
57+
4958

5059
class MultiNodeExperimentSchema(BaseExperimentSchema):
5160
MODEL = models.MultiNodeExperiment
@@ -79,6 +88,14 @@ class MultiNodeExperimentSchema(BaseExperimentSchema):
7988
parameter_server_registry_url = marshmallow.fields.Str(dump_to="parameterServerRegistryUrl",
8089
load_from="parameterServerRegistryUrl")
8190

91+
@marshmallow.pre_dump
92+
def preprocess(self, data, **kwargs):
93+
data = copy.copy(data)
94+
95+
utils.base64_encode_attribute(data, "worker_command")
96+
utils.base64_encode_attribute(data, "parameter_server_command")
97+
return data
98+
8299

83100
class MpiMultiNodeExperimentSchema(BaseExperimentSchema):
84101
MODEL = models.MpiMultiNodeExperiment
@@ -92,10 +109,8 @@ class MpiMultiNodeExperimentSchema(BaseExperimentSchema):
92109
load_from="masterContainer")
93110
master_machine_type = marshmallow.fields.Str(required=True, dump_to="masterMachineType",
94111
load_from="masterMachineType")
95-
master_command = marshmallow.fields.Str(required=True, dump_to="masterCommand",
96-
load_from="masterCommand")
97-
master_count = marshmallow.fields.Int(required=True, dump_to="masterCount",
98-
load_from="masterCount")
112+
master_command = marshmallow.fields.Str(required=True, dump_to="masterCommand", load_from="masterCommand")
113+
master_count = marshmallow.fields.Int(required=True, dump_to="masterCount", load_from="masterCount")
99114
worker_container_user = marshmallow.fields.Str(dump_to="workerContainerUser", load_from="workerContainerUser")
100115
worker_registry_username = marshmallow.fields.Str(dump_to="workerRegistryUsername",
101116
load_from="workerRegistryUsername")
@@ -111,3 +126,11 @@ class MpiMultiNodeExperimentSchema(BaseExperimentSchema):
111126
load_from="masterRegistryPassword")
112127
master_registry_url = marshmallow.fields.Str(dump_to="masterRegistryUrl",
113128
load_from="masterRegistryUrl")
129+
130+
@marshmallow.pre_dump
131+
def preprocess(self, data, **kwargs):
132+
data = copy.copy(data)
133+
134+
utils.base64_encode_attribute(data, "worker_command")
135+
utils.base64_encode_attribute(data, "master_command")
136+
return data

gradient/api_sdk/serializers/hyperparameter.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import copy
2+
13
import marshmallow
24

35
from .experiment import BaseExperimentSchema
4-
from .. import models
6+
from .. import models, utils
57

68

79
class HyperparameterSchema(BaseExperimentSchema):
@@ -28,3 +30,11 @@ class HyperparameterSchema(BaseExperimentSchema):
2830
load_from="hyperparameterServerContainer")
2931
hyperparameter_server_container_user = marshmallow.fields.Str(dump_to="hyperparameterServerContainerUser",
3032
load_from="hyperparameterServerContainerUser")
33+
34+
@marshmallow.pre_dump
35+
def preprocess(self, data, **kwargs):
36+
data = copy.copy(data)
37+
38+
utils.base64_encode_attribute(data, "worker_command")
39+
utils.base64_encode_attribute(data, "tuning_command")
40+
return data

gradient/api_sdk/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
from collections import OrderedDict
23

34
import six
@@ -48,3 +49,20 @@ def print_dict_recursive(input_dict, logger, indent=0, tabulator=" "):
4849
print_dict_recursive(OrderedDict(val), logger, indent + 1)
4950
else:
5051
logger.log("%s%s" % (tabulator * (indent + 1), val))
52+
53+
54+
def base64_encode(s):
55+
if six.PY3:
56+
s = bytes(s, encoding="utf8")
57+
58+
encoded_str = base64.b64encode(s)
59+
60+
if six.PY3: # Python3's base64.b64encode returns a bytes instance so it should be converted back to unicode
61+
encoded_str = encoded_str.decode("utf-8")
62+
63+
return encoded_str
64+
65+
66+
def base64_encode_attribute(data, name):
67+
encoded_value = base64_encode(getattr(data, name))
68+
setattr(data, name, encoded_value)

tests/functional/test_experiments.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class TestExperimentsCreateSingleNode(object):
5656
"projectHandle": u"testHandle",
5757
"container": u"testContainer",
5858
"machineType": u"testType",
59-
"command": u"testCommand",
59+
"command": u"dGVzdENvbW1hbmQ=",
6060
"experimentTypeId": constants.ExperimentType.SINGLE_NODE,
6161
"workspaceUrl": u"some-workspace",
6262
}
@@ -73,7 +73,7 @@ class TestExperimentsCreateSingleNode(object):
7373
"projectHandle": u"testHandle",
7474
"container": u"testContainer",
7575
"machineType": u"testType",
76-
"command": u"testCommand",
76+
"command": u"dGVzdENvbW1hbmQ=",
7777
"containerUser": u"conUser",
7878
"registryUsername": u"userName",
7979
"registryPassword": u"passwd",
@@ -333,11 +333,11 @@ class TestExperimentsCreateMultiNode(object):
333333
u"experimentTypeId": 2,
334334
u"workerContainer": u"wcon",
335335
u"workerMachineType": u"mty",
336-
u"workerCommand": u"wcom",
336+
u"workerCommand": u"d2NvbQ==",
337337
u"workerCount": 2,
338338
u"parameterServerContainer": u"pscon",
339339
u"parameterServerMachineType": u"psmtype",
340-
u"parameterServerCommand": u"ls",
340+
u"parameterServerCommand": u"bHM=",
341341
u"parameterServerCount": 2,
342342
u"workerContainerUser": u"usr",
343343
u"workspaceUrl": u"https://github.com/Paperspace/gradient-cli.git",
@@ -356,11 +356,11 @@ class TestExperimentsCreateMultiNode(object):
356356
"experimentTypeId": 3,
357357
"workerContainer": u"wcon",
358358
"workerMachineType": u"mty",
359-
"workerCommand": u"wcom",
359+
"workerCommand": u"d2NvbQ==",
360360
"workerCount": 2,
361361
"masterContainer": u"pscon",
362362
"masterMachineType": u"psmtype",
363-
"masterCommand": u"ls",
363+
"masterCommand": u"bHM=",
364364
"masterCount": 2,
365365
"workerContainerUser": u"usr",
366366
"workerRegistryUsername": u"rusr",

tests/functional/test_hyperparameters.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ class TestCreateHyperparameters(object):
2424
"workerContainer": "some_container",
2525
"workerMachineType": "k80",
2626
"name": "some_name",
27-
"tuningCommand": "some command",
27+
"tuningCommand": "c29tZSBjb21tYW5k",
2828
"workerCount": 1,
29-
"workerCommand": "some worker command",
29+
"workerCommand": "c29tZSB3b3JrZXIgY29tbWFuZA==",
3030
"experimentTypeId": constants.ExperimentType.HYPERPARAMETER_TUNING,
3131
"projectHandle": "pr4yxj956",
3232
}
@@ -66,9 +66,9 @@ class TestCreateHyperparameters(object):
6666
"workerContainer": "some_worker_container",
6767
"workerMachineType": "k80",
6868
"name": "some_name",
69-
"tuningCommand": "some command",
69+
"tuningCommand": "c29tZSBjb21tYW5k",
7070
"workerCount": 666,
71-
"workerCommand": "some worker command",
71+
"workerCommand": "c29tZSB3b3JrZXIgY29tbWFuZA==",
7272
"workerRegistryUsername": "some_registry_username",
7373
"workerRegistryPassword": "some_registry_password",
7474
"workerContainerUser": "some_worker_container_user",
@@ -273,9 +273,9 @@ class TestCreateAndStartHyperparameters(object):
273273
"workerContainer": "some_container",
274274
"workerMachineType": "k80",
275275
"name": "some_name",
276-
"tuningCommand": "some command",
276+
"tuningCommand": "c29tZSBjb21tYW5k",
277277
"workerCount": 1,
278-
"workerCommand": "some worker command",
278+
"workerCommand": "c29tZSB3b3JrZXIgY29tbWFuZA==",
279279
"projectHandle": "pr4yxj956",
280280
"experimentTypeId": constants.ExperimentType.HYPERPARAMETER_TUNING,
281281
}
@@ -315,9 +315,9 @@ class TestCreateAndStartHyperparameters(object):
315315
"workerContainer": "some_worker_container",
316316
"workerMachineType": "k80",
317317
"name": "some_name",
318-
"tuningCommand": "some command",
318+
"tuningCommand": "c29tZSBjb21tYW5k",
319319
"workerCount": 666,
320-
"workerCommand": "some worker command",
320+
"workerCommand": "c29tZSB3b3JrZXIgY29tbWFuZA==",
321321
"workerRegistryUsername": "some_registry_username",
322322
"workerRegistryPassword": "some_registry_password",
323323
"workerContainerUser": "some_worker_container_user",

0 commit comments

Comments
 (0)