diff --git a/modal/_utils/blob_utils.py b/modal/_utils/blob_utils.py
index 53a4e36c5..38f8c6d2e 100644
--- a/modal/_utils/blob_utils.py
+++ b/modal/_utils/blob_utils.py
@@ -18,7 +18,7 @@
 from ..exception import ExecutionError
 from .async_utils import TaskContext, retry
 from .grpc_utils import retry_transient_errors
-from .hash_utils import UploadHashes, get_upload_hashes
+from .hash_utils import DUMMY_HASH_HEX, UploadHashes, get_upload_hashes
 from .http_utils import ClientSessionRegistry
 from .logger import logger
@@ -307,6 +307,7 @@ def _get_file_upload_spec(
     source_description: Any,
     mount_filename: PurePosixPath,
     mode: int,
+    multipart_hash: bool = True,
 ) -> FileUploadSpec:
     with source() as fp:
         # Current position is ignored - we always upload from position 0
@@ -316,10 +317,11 @@ def _get_file_upload_spec(
 
         if size >= LARGE_FILE_LIMIT:
             # TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
-            md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
+            md5_hex = DUMMY_HASH_HEX if size > MULTIPART_UPLOAD_THRESHOLD else None
+            sha256_hex = DUMMY_HASH_HEX if size > MULTIPART_UPLOAD_THRESHOLD and not multipart_hash else None
             use_blob = True
             content = None
-            hashes = get_upload_hashes(fp, md5_hex=md5_hex)
+            hashes = get_upload_hashes(fp, md5_hex=md5_hex, sha256_hex=sha256_hex)
         else:
             use_blob = False
             content = fp.read()
@@ -339,7 +341,10 @@
 
 
 def get_file_upload_spec_from_path(
-    filename: Path, mount_filename: PurePosixPath, mode: Optional[int] = None
+    filename: Path,
+    mount_filename: PurePosixPath,
+    mode: Optional[int] = None,
+    multipart_hash: bool = True,
 ) -> FileUploadSpec:
     # Python appears to give files 0o666 bits on Windows (equal for user, group, and global),
     # so we mask those out to 0o755 for compatibility with POSIX-based permissions.
@@ -349,10 +354,16 @@ def get_file_upload_spec_from_path(
         filename,
         mount_filename,
         mode,
+        multipart_hash=multipart_hash,
     )
 
 
-def get_file_upload_spec_from_fileobj(fp: BinaryIO, mount_filename: PurePosixPath, mode: int) -> FileUploadSpec:
+def get_file_upload_spec_from_fileobj(
+    fp: BinaryIO,
+    mount_filename: PurePosixPath,
+    mode: int,
+    multipart_hash: bool = True,
+) -> FileUploadSpec:
     @contextmanager
     def source():
         # We ignore position in stream and always upload from position 0
@@ -364,6 +375,7 @@ def source():
         str(fp),
         mount_filename,
         mode,
+        multipart_hash=multipart_hash,
     )
diff --git a/modal/_utils/hash_utils.py b/modal/_utils/hash_utils.py
index 7f48beda3..bd54c47fb 100644
--- a/modal/_utils/hash_utils.py
+++ b/modal/_utils/hash_utils.py
@@ -8,6 +8,7 @@
 from modal.config import logger
 
 HASH_CHUNK_SIZE = 65536
+DUMMY_HASH_HEX = "baadbaadbaadbaadbaadbaadbaadbaad"
 
 
 def _update(hashers: Sequence[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
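Reviewer note: hash_utils internals beyond the new constant are not part of this diff, but the blob_utils change above leans on one property of get_upload_hashes: it presumably computes only the digests the caller did not already supply, so passing sha256_hex=DUMMY_HASH_HEX (for multipart uploads with multipart_hash=False) skips the sha256 pass over the file entirely. A minimal sketch of that assumed contract follows; the names mirror the patch, but the body is illustrative, not modal's actual implementation.

    import hashlib
    from dataclasses import dataclass
    from typing import BinaryIO, Optional

    HASH_CHUNK_SIZE = 65536
    DUMMY_HASH_HEX = "baadbaadbaadbaadbaadbaadbaadbaad"


    @dataclass
    class UploadHashes:
        md5_hex: Optional[str]
        sha256_hex: Optional[str]


    def get_upload_hashes(
        fp: BinaryIO, md5_hex: Optional[str] = None, sha256_hex: Optional[str] = None
    ) -> UploadHashes:
        # Only hash what the caller did not already provide; one read pass
        # feeds every hasher that is still needed.
        hashers = {}
        if md5_hex is None:
            hashers["md5"] = hashlib.md5()
        if sha256_hex is None:
            hashers["sha256"] = hashlib.sha256()
        if hashers:
            fp.seek(0)
            for chunk in iter(lambda: fp.read(HASH_CHUNK_SIZE), b""):
                for hasher in hashers.values():
                    hasher.update(chunk)
            fp.seek(0)
        return UploadHashes(
            md5_hex=md5_hex or hashers["md5"].hexdigest(),
            sha256_hex=sha256_hex or hashers["sha256"].hexdigest(),
        )

Under this contract, _get_file_upload_spec pays for at most one hashing pass up front, and for large multipart blobs it pays for none; the dummy digest is reconciled later by the volume.py changes below.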
diff --git a/modal/volume.py b/modal/volume.py
index 3f2a295c1..c00f72c92 100644
--- a/modal/volume.py
+++ b/modal/volume.py
@@ -3,6 +3,8 @@
 import concurrent.futures
 import enum
 import functools
+import hashlib
+import io
 import os
 import platform
 import re
@@ -38,6 +40,7 @@
 )
 from ._utils.deprecation import deprecation_error, deprecation_warning, renamed_parameter
 from ._utils.grpc_utils import retry_transient_errors
+from ._utils.hash_utils import DUMMY_HASH_HEX
 from ._utils.name_utils import check_object_name
 from .client import _Client
 from .config import logger
@@ -515,6 +518,7 @@ class _VolumeUploadContextManager:
     _force: bool
     progress_cb: Callable[..., Any]
     _upload_generators: list[Generator[Callable[[], FileUploadSpec], None, None]]
+    _executor: concurrent.futures.ThreadPoolExecutor
 
     def __init__(
         self, volume_id: str, client: _Client, progress_cb: Optional[Callable[..., Any]] = None, force: bool = False
@@ -525,12 +529,13 @@ def __init__(
         self._upload_generators = []
         self._progress_cb = progress_cb or (lambda *_, **__: None)
         self._force = force
+        self._executor = concurrent.futures.ThreadPoolExecutor()
 
     async def __aenter__(self):
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        if not exc_val:
+        async def _upload():
             # Flatten all the uploads yielded by the upload generators in the batch
             def gen_upload_providers():
                 for gen in self._upload_generators:
@@ -538,12 +543,11 @@ def gen_upload_providers():
 
             async def gen_file_upload_specs() -> AsyncGenerator[FileUploadSpec, None]:
                 loop = asyncio.get_event_loop()
-                with concurrent.futures.ThreadPoolExecutor() as exe:
-                    # TODO: avoid eagerly expanding
-                    futs = [loop.run_in_executor(exe, f) for f in gen_upload_providers()]
-                    logger.debug(f"Computing checksums for {len(futs)} files using {exe._max_workers} workers")
-                    for fut in asyncio.as_completed(futs):
-                        yield await fut
+                # TODO: avoid eagerly expanding
+                futs = [loop.run_in_executor(self._executor, f) for f in gen_upload_providers()]
+                logger.debug(f"Computing checksums for {len(futs)} files using {self._executor._max_workers} workers")
+                for fut in asyncio.as_completed(futs):
+                    yield await fut
 
             # Compute checksums & Upload files
             files: list[api_pb2.MountFile] = []
@@ -563,6 +567,10 @@ async def gen_file_upload_specs() -> AsyncGenerator[FileUploadSpec, None]:
             except GRPCError as exc:
                 raise FileExistsError(exc.message) if exc.status == Status.ALREADY_EXISTS else exc
 
+        with self._executor:
+            if not exc_val:
+                await _upload()
+
     def put_file(
         self,
         local_file: Union[Path, str, BinaryIO],
@@ -581,9 +589,13 @@ def put_file(
 
         def gen():
            if isinstance(local_file, str) or isinstance(local_file, Path):
-                yield lambda: get_file_upload_spec_from_path(local_file, PurePosixPath(remote_path), mode)
+                yield lambda: get_file_upload_spec_from_path(
+                    local_file, PurePosixPath(remote_path), mode, multipart_hash=False
+                )
            else:
-                yield lambda: get_file_upload_spec_from_fileobj(local_file, PurePosixPath(remote_path), mode or 0o644)
+                yield lambda: get_file_upload_spec_from_fileobj(
+                    local_file, PurePosixPath(remote_path), mode or 0o644, multipart_hash=False
+                )
 
         self._upload_generators.append(gen())
@@ -604,7 +616,7 @@ def put_directory(
 
         def create_file_spec_provider(subpath):
             relpath_str = subpath.relative_to(local_path)
-            return lambda: get_file_upload_spec_from_path(subpath, remote_path / relpath_str)
+            return lambda: get_file_upload_spec_from_path(subpath, remote_path / relpath_str, multipart_hash=False)
 
         def gen():
             glob = local_path.rglob("*") if recursive else local_path.glob("*")
@@ -618,14 +630,31 @@ def gen():
     async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
         remote_filename = file_spec.mount_filename
         progress_task_id = self._progress_cb(name=remote_filename, size=file_spec.size)
-        request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex)
-        response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1)
+
+        exists = False
+        resulting_sha256_hex = None
+        if file_spec.sha256_hex != DUMMY_HASH_HEX:
+            resulting_sha256_hex = file_spec.sha256_hex
+            request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex)
+            response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1)
+            exists = response.exists
+
+        def get_sha256(fp: BinaryIO):
+            if isinstance(fp, io.BytesIO):
+                data = fp
+            else:
+                data = open(fp.name, "rb")
+            return hashlib.file_digest(data, "sha256").hexdigest()
 
         start_time = time.monotonic()
-        if not response.exists:
+        if not exists:
             if file_spec.use_blob:
                 logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)")
                 with file_spec.source() as fp:
+                    sha256_fut = None
+                    if file_spec.sha256_hex == DUMMY_HASH_HEX:
+                        sha256_fut = asyncio.get_event_loop().run_in_executor(self._executor, get_sha256, fp)
+
                     blob_id = await blob_upload_file(
                         fp,
                         self._client.stub,
@@ -633,8 +662,16 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
                         sha256_hex=file_spec.sha256_hex,
                         md5_hex=file_spec.md5_hex,
                     )
+                    if sha256_fut:
+                        t0 = time.monotonic()
+                        sha256_hex = await sha256_fut
+                        logger.debug(
+                            f"Awaited concurrent sha256 of {file_spec.source_description} for {time.monotonic() - t0:.3}s"
+                        )
+                    else:
+                        sha256_hex = file_spec.sha256_hex
                 logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}")
-                request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex)
+                request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=sha256_hex)
             else:
                 logger.debug(
                     f"Uploading file {file_spec.source_description} to {remote_filename} ({file_spec.size} bytes)"
                 )
@@ -645,6 +682,7 @@
             while (time.monotonic() - start_time) < VOLUME_PUT_FILE_CLIENT_TIMEOUT:
                 response = await retry_transient_errors(self._client.stub.MountPutFile, request2, base_delay=1)
                 if response.exists:
+                    resulting_sha256_hex = response.sha256_hex
                     break
 
             if not response.exists:
@@ -653,7 +691,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
         self._progress_cb(task_id=progress_task_id, complete=True)
 
         return api_pb2.MountFile(
             filename=remote_filename,
-            sha256_hex=file_spec.sha256_hex,
+            sha256_hex=resulting_sha256_hex,
             mode=file_spec.mode,
         )
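Reviewer note: the net effect of the volume.py changes is that, for large multipart uploads, the sha256 is computed in a worker thread while the blob transfer is in flight, rather than in a blocking pass before the upload starts; the real digest then replaces DUMMY_HASH_HEX in the follow-up MountPutFileRequest and in the returned MountFile. A self-contained sketch of that overlap pattern follows; upload_blob is a stand-in for modal's blob_upload_file, the rest is illustrative, and hashlib.file_digest requires Python 3.11+.

    import asyncio
    import hashlib
    from concurrent.futures import ThreadPoolExecutor


    def sha256_of(path: str) -> str:
        # Open a separate handle so the hashing thread never races the
        # uploader over a shared file position.
        with open(path, "rb") as f:
            return hashlib.file_digest(f, "sha256").hexdigest()


    async def upload_blob(path: str) -> str:
        await asyncio.sleep(1.0)  # stand-in for the actual network transfer
        return "bl-123"


    async def upload_with_concurrent_hash(path: str, executor: ThreadPoolExecutor):
        loop = asyncio.get_running_loop()
        # Kick off the sha256 in a worker thread, then start the upload; the
        # digest is typically ready by the time the transfer completes, so the
        # final await costs close to nothing.
        sha_fut = loop.run_in_executor(executor, sha256_of, path)
        blob_id = await upload_blob(path)
        sha256_hex = await sha_fut
        return blob_id, sha256_hex


    async def main():
        with open("example.bin", "wb") as f:
            f.write(b"x" * (8 * 1024 * 1024))  # something to hash and "upload"
        with ThreadPoolExecutor() as executor:
            print(await upload_with_concurrent_hash("example.bin", executor))


    asyncio.run(main())

Note also the design choice in __aexit__: the same long-lived ThreadPoolExecutor serves both the checksum precomputation and these concurrent hashes, and the `with self._executor:` block shuts it down even when the upload is skipped because an exception is propagating.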