diff --git a/modal/_utils/blob_utils.py b/modal/_utils/blob_utils.py index 53a4e36c5..38f8c6d2e 100644 --- a/modal/_utils/blob_utils.py +++ b/modal/_utils/blob_utils.py @@ -18,7 +18,7 @@ from ..exception import ExecutionError from .async_utils import TaskContext, retry from .grpc_utils import retry_transient_errors -from .hash_utils import UploadHashes, get_upload_hashes +from .hash_utils import DUMMY_HASH_HEX, UploadHashes, get_upload_hashes from .http_utils import ClientSessionRegistry from .logger import logger @@ -307,6 +307,7 @@ def _get_file_upload_spec( source_description: Any, mount_filename: PurePosixPath, mode: int, + multipart_hash: bool = True, ) -> FileUploadSpec: with source() as fp: # Current position is ignored - we always upload from position 0 @@ -316,10 +317,11 @@ def _get_file_upload_spec( if size >= LARGE_FILE_LIMIT: # TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs - md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None + md5_hex = DUMMY_HASH_HEX if size > MULTIPART_UPLOAD_THRESHOLD else None + sha256_hex = DUMMY_HASH_HEX if size > MULTIPART_UPLOAD_THRESHOLD and not multipart_hash else None use_blob = True content = None - hashes = get_upload_hashes(fp, md5_hex=md5_hex) + hashes = get_upload_hashes(fp, md5_hex=md5_hex, sha256_hex=sha256_hex) else: use_blob = False content = fp.read() @@ -339,7 +341,10 @@ def _get_file_upload_spec( def get_file_upload_spec_from_path( - filename: Path, mount_filename: PurePosixPath, mode: Optional[int] = None + filename: Path, + mount_filename: PurePosixPath, + mode: Optional[int] = None, + multipart_hash: bool = True, ) -> FileUploadSpec: # Python appears to give files 0o666 bits on Windows (equal for user, group, and global), # so we mask those out to 0o755 for compatibility with POSIX-based permissions. @@ -349,10 +354,16 @@ def get_file_upload_spec_from_path( filename, mount_filename, mode, + multipart_hash=multipart_hash, ) -def get_file_upload_spec_from_fileobj(fp: BinaryIO, mount_filename: PurePosixPath, mode: int) -> FileUploadSpec: +def get_file_upload_spec_from_fileobj( + fp: BinaryIO, + mount_filename: PurePosixPath, + mode: int, + multipart_hash: bool = True, +) -> FileUploadSpec: @contextmanager def source(): # We ignore position in stream and always upload from position 0 @@ -364,6 +375,7 @@ def source(): str(fp), mount_filename, mode, + multipart_hash=multipart_hash, ) diff --git a/modal/_utils/hash_utils.py b/modal/_utils/hash_utils.py index 7f48beda3..931bd51dc 100644 --- a/modal/_utils/hash_utils.py +++ b/modal/_utils/hash_utils.py @@ -64,6 +64,13 @@ def sha256_hex(self) -> str: return base64.b64decode(self.sha256_base64).hex() +DUMMY_HASH_HEX = "baadbaadbaadbaadbaadbaadbaadbaad" +DUMMY_HASHES = UploadHashes( + base64.b64encode(bytes.fromhex(DUMMY_HASH_HEX)).decode("ascii"), + base64.b64encode(bytes.fromhex(DUMMY_HASH_HEX)).decode("ascii"), +) + + def get_upload_hashes( data: Union[bytes, BinaryIO], sha256_hex: Optional[str] = None, md5_hex: Optional[str] = None ) -> UploadHashes: diff --git a/modal/volume.py b/modal/volume.py index 3f2a295c1..439a7eab8 100644 --- a/modal/volume.py +++ b/modal/volume.py @@ -38,6 +38,7 @@ ) from ._utils.deprecation import deprecation_error, deprecation_warning, renamed_parameter from ._utils.grpc_utils import retry_transient_errors +from ._utils.hash_utils import DUMMY_HASH_HEX from ._utils.name_utils import check_object_name from .client import _Client from .config import logger @@ -581,9 +582,13 @@ def put_file( def gen(): if isinstance(local_file, str) or isinstance(local_file, Path): - yield lambda: get_file_upload_spec_from_path(local_file, PurePosixPath(remote_path), mode) + yield lambda: get_file_upload_spec_from_path( + local_file, PurePosixPath(remote_path), mode, multipart_hash=False + ) else: - yield lambda: get_file_upload_spec_from_fileobj(local_file, PurePosixPath(remote_path), mode or 0o644) + yield lambda: get_file_upload_spec_from_fileobj( + local_file, PurePosixPath(remote_path), mode or 0o644, multipart_hash=False + ) self._upload_generators.append(gen()) @@ -604,7 +609,7 @@ def put_directory( def create_file_spec_provider(subpath): relpath_str = subpath.relative_to(local_path) - return lambda: get_file_upload_spec_from_path(subpath, remote_path / relpath_str) + return lambda: get_file_upload_spec_from_path(subpath, remote_path / relpath_str, multipart_hash=False) def gen(): glob = local_path.rglob("*") if recursive else local_path.glob("*") @@ -618,11 +623,17 @@ def gen(): async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile: remote_filename = file_spec.mount_filename progress_task_id = self._progress_cb(name=remote_filename, size=file_spec.size) - request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex) - response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1) + + exists = False + resulting_sha256_hex = None + if file_spec.sha256_hex != DUMMY_HASH_HEX: + resulting_sha256_hex = file_spec.sha256_hex + request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex) + response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1) + exists = response.exists start_time = time.monotonic() - if not response.exists: + if not exists: if file_spec.use_blob: logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)") with file_spec.source() as fp: @@ -631,10 +642,11 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile: self._client.stub, functools.partial(self._progress_cb, progress_task_id), sha256_hex=file_spec.sha256_hex, - md5_hex=file_spec.md5_hex, + md5_hex=file_spe.md5_hex(), ) logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}") - request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex) + sha256_hex = file_spec.sha256_hex if file_spec.sha256_hex != DUMMY_HASH_HEX else "" + request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=sha256_hex) else: logger.debug( f"Uploading file {file_spec.source_description} to {remote_filename} ({file_spec.size} bytes)" @@ -645,6 +657,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile: while (time.monotonic() - start_time) < VOLUME_PUT_FILE_CLIENT_TIMEOUT: response = await retry_transient_errors(self._client.stub.MountPutFile, request2, base_delay=1) if response.exists: + resulting_sha256_hex = response.sha256_hex break if not response.exists: @@ -653,7 +666,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile: self._progress_cb(task_id=progress_task_id, complete=True) return api_pb2.MountFile( filename=remote_filename, - sha256_hex=file_spec.sha256_hex, + sha256_hex=resulting_sha256_hex, mode=file_spec.mode, )