Skip to content

Commit b5f356b

Browse files
authored
Merge pull request #270 from Paperspace/model-upload-fix
model upload fix
2 parents f423974 + e0bf46e commit b5f356b

File tree

2 files changed

+47
-29
lines changed

2 files changed

+47
-29
lines changed

gradient/api_sdk/s3_uploader.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ def get_bucket_url(bucket_name, s3_fields):
7979
return url
8080

8181

82+
class S3PutFileUploader(S3FileUploader):
83+
def _upload(self, url, data):
84+
"""Send data to S3 and raise exception if it was not a success
85+
86+
:param str url:
87+
:param encoder.MultipartEncoderMonitor data:
88+
"""
89+
file_path = data.encoder.fields['file'][0]
90+
client = self._get_client(url)
91+
client.headers = {"Content-Type": mimetypes.guess_type(file_path)[0] or ""}
92+
93+
response = client.put("", data=data)
94+
if not response.ok:
95+
raise sdk_exceptions.S3UploadFailedError(response)
96+
97+
8298
class S3ProjectFileUploader(object):
8399
def __init__(self, api_key, s3uploader=None, logger=None, ps_client_name=None):
84100
"""
@@ -164,7 +180,7 @@ def __init__(self, api_key, multipart_encoder_cls=None, logger=None, ps_client_n
164180
api_key=api_key,
165181
ps_client_name=ps_client_name,
166182
)
167-
self.s3uploader = s3uploader or S3FileUploader(
183+
self.s3uploader = s3uploader or S3PutFileUploader(
168184
logger=self.logger,
169185
ps_client_name=ps_client_name,
170186
multipart_encoder_cls=self.multipart_encoder_cls

tests/functional/test_models.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,12 @@ class TestModelUpload(object):
303303
UPDATE_TAGS_RESPONSE_JSON_200 = example_responses.UPDATE_TAGS_RESPONSE
304304

305305
@mock.patch("gradient.api_sdk.clients.http_client.requests.get")
306+
@mock.patch("gradient.api_sdk.clients.http_client.requests.put")
306307
@mock.patch("gradient.api_sdk.clients.http_client.requests.post")
307308
def test_should_send_post_request_when_models_update_command_was_used_with_basic_options(
308-
self, post_patched, get_patched):
309+
self, post_patched, put_patched, get_patched):
309310
post_patched.return_value = MockResponse(self.CREATE_MODEL_V2_REPONSE)
311+
put_patched.return_value = MockResponse()
310312
get_patched.return_value = MockResponse(self.GET_PRESIGNED_URL_RESPONSE)
311313

312314
runner = CliRunner()
@@ -324,10 +326,11 @@ def test_should_send_post_request_when_models_update_command_was_used_with_basic
324326
files=None,
325327
data=None,
326328
params=self.BASE_PARAMS),
329+
])
330+
put_patched.assert_has_calls([
327331
mock.call(self.GET_PRESIGNED_URL_RESPONSE,
328332
headers={"Content-Type": mock.ANY},
329333
json=None,
330-
files=None,
331334
params=None,
332335
data=mock.ANY)
333336
])
@@ -336,15 +339,17 @@ def test_should_send_post_request_when_models_update_command_was_used_with_basic
336339
params=self.GET_PRESIGNED_URL_PARAMS,
337340
json=None,
338341
)
339-
assert post_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
342+
assert put_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
340343

341344
assert EXPECTED_HEADERS["X-API-Key"] != "some_key"
342345

343346
@mock.patch("gradient.api_sdk.clients.http_client.requests.get")
347+
@mock.patch("gradient.api_sdk.clients.http_client.requests.put")
344348
@mock.patch("gradient.api_sdk.clients.http_client.requests.post")
345349
def test_should_send_post_request_when_models_update_command_was_used_with_all_options(
346-
self, post_patched, get_patched):
350+
self, post_patched, put_patched, get_patched):
347351
post_patched.return_value = MockResponse(self.CREATE_MODEL_V2_REPONSE)
352+
put_patched.return_value = MockResponse()
348353
get_patched.return_value = MockResponse(self.GET_PRESIGNED_URL_RESPONSE)
349354

350355
runner = CliRunner()
@@ -362,19 +367,19 @@ def test_should_send_post_request_when_models_update_command_was_used_with_all_o
362367
files=None,
363368
data=None,
364369
params=self.ALL_OPTIONS_PARAMS),
370+
])
371+
put_patched.assert_has_calls([
365372
mock.call(self.GET_PRESIGNED_URL_RESPONSE,
366373
headers={"Content-Type": mock.ANY},
367374
json=None,
368-
files=None,
369375
params=None,
370-
data=mock.ANY)
371-
])
376+
data=mock.ANY)])
372377
get_patched.assert_called_once_with(self.GET_PRESIGNED_URL,
373378
headers=EXPECTED_HEADERS,
374379
params=self.GET_PRESIGNED_URL_PARAMS,
375380
json=None,
376381
)
377-
assert post_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
382+
assert put_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
378383

379384
assert EXPECTED_HEADERS["X-API-Key"] != "some_key"
380385

@@ -403,20 +408,20 @@ def test_should_replace_api_key_in_headers_when_api_key_parameter_was_used(
403408
data=None,
404409
params=self.ALL_OPTIONS_PARAMS
405410
),
411+
])
412+
put_patched.assert_has_calls([
406413
mock.call(self.GET_PRESIGNED_URL_RESPONSE,
407414
headers={"Content-Type": mock.ANY},
408-
files=None,
409415
json=None,
410416
params=None,
411417
data=mock.ANY)
412418
])
413-
414419
get_patched.assert_called_once_with(self.GET_PRESIGNED_URL,
415420
headers=EXPECTED_HEADERS_WITH_CHANGED_API_KEY,
416421
params=self.GET_PRESIGNED_URL_PARAMS,
417422
json=None,
418423
)
419-
assert post_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
424+
assert put_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
420425

421426
assert EXPECTED_HEADERS["X-API-Key"] != "some_key"
422427

@@ -447,21 +452,18 @@ def test_should_read_options_from_yaml_file(
447452
data=None,
448453
params=self.ALL_OPTIONS_PARAMS
449454
),
450-
mock.call(
451-
self.GET_PRESIGNED_URL_RESPONSE,
452-
headers={'Content-Type': mock.ANY},
453-
json=None,
454-
files=None,
455-
params=None,
456-
data=mock.ANY
457-
)
458455
])
456+
put_patched.assert_called_once_with(self.GET_PRESIGNED_URL_RESPONSE,
457+
headers={"Content-Type": ""},
458+
json=None,
459+
params=None,
460+
data=mock.ANY)
459461
get_patched.assert_called_once_with(self.GET_PRESIGNED_URL,
460462
headers=EXPECTED_HEADERS_WITH_CHANGED_API_KEY,
461463
params=self.GET_PRESIGNED_URL_PARAMS,
462464
json=None,
463465
)
464-
assert post_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
466+
assert put_patched.call_args.kwargs["data"].encoder.fields["file"][0] == self.MODEL_FILE
465467

466468
assert EXPECTED_HEADERS["X-API-Key"] != "some_key"
467469

@@ -511,15 +513,8 @@ def test_should_send_proper_data_and_tag_machine(
511513
data=None,
512514
params=self.BASE_PARAMS
513515
),
514-
mock.call(
515-
self.GET_PRESIGNED_URL_RESPONSE,
516-
headers={"Content-Type": mock.ANY},
517-
files=None,
518-
json=None,
519-
params=None,
520-
data=mock.ANY,
521-
),
522516
])
517+
523518
get_patched.assert_has_calls(
524519
[
525520
mock.call(
@@ -532,6 +527,13 @@ def test_should_send_proper_data_and_tag_machine(
532527
)
533528
put_patched.assert_has_calls(
534529
[
530+
mock.call(
531+
self.GET_PRESIGNED_URL_RESPONSE,
532+
headers={"Content-Type": mock.ANY},
533+
json=None,
534+
params=None,
535+
data=mock.ANY,
536+
),
535537
mock.call(
536538
self.TAGS_URL,
537539
headers=EXPECTED_HEADERS,

0 commit comments

Comments
 (0)