Skip to content

Commit

Permalink
volume put: concurrent sha256 hash for multipart uploads
Browse files Browse the repository at this point in the history
  • Loading branch information
danielnorberg committed Jan 4, 2025
1 parent 4ca850c commit 844c96f
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 20 deletions.
22 changes: 17 additions & 5 deletions modal/_utils/blob_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..exception import ExecutionError
from .async_utils import TaskContext, retry
from .grpc_utils import retry_transient_errors
from .hash_utils import UploadHashes, get_upload_hashes
from .hash_utils import DUMMY_HASH_HEX, UploadHashes, get_upload_hashes
from .http_utils import ClientSessionRegistry
from .logger import logger

Expand Down Expand Up @@ -307,6 +307,7 @@ def _get_file_upload_spec(
source_description: Any,
mount_filename: PurePosixPath,
mode: int,
multipart_hash: bool = True,
) -> FileUploadSpec:
with source() as fp:
# Current position is ignored - we always upload from position 0
Expand All @@ -316,10 +317,11 @@ def _get_file_upload_spec(

if size >= LARGE_FILE_LIMIT:
# TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
md5_hex = DUMMY_HASH_HEX if size > MULTIPART_UPLOAD_THRESHOLD else None
sha256_hex = DUMMY_HASH_HEX if size > MULTIPART_UPLOAD_THRESHOLD and not multipart_hash else None
use_blob = True
content = None
hashes = get_upload_hashes(fp, md5_hex=md5_hex)
hashes = get_upload_hashes(fp, md5_hex=md5_hex, sha256_hex=sha256_hex)
else:
use_blob = False
content = fp.read()
Expand All @@ -339,7 +341,10 @@ def _get_file_upload_spec(


def get_file_upload_spec_from_path(
filename: Path, mount_filename: PurePosixPath, mode: Optional[int] = None
filename: Path,
mount_filename: PurePosixPath,
mode: Optional[int] = None,
multipart_hash: bool = True,
) -> FileUploadSpec:
# Python appears to give files 0o666 bits on Windows (equal for user, group, and global),
# so we mask those out to 0o755 for compatibility with POSIX-based permissions.
Expand All @@ -349,10 +354,16 @@ def get_file_upload_spec_from_path(
filename,
mount_filename,
mode,
multipart_hash=multipart_hash,
)


def get_file_upload_spec_from_fileobj(fp: BinaryIO, mount_filename: PurePosixPath, mode: int) -> FileUploadSpec:
def get_file_upload_spec_from_fileobj(
fp: BinaryIO,
mount_filename: PurePosixPath,
mode: int,
multipart_hash: bool = True,
) -> FileUploadSpec:
@contextmanager
def source():
# We ignore position in stream and always upload from position 0
Expand All @@ -364,6 +375,7 @@ def source():
str(fp),
mount_filename,
mode,
multipart_hash=multipart_hash,
)


Expand Down
1 change: 1 addition & 0 deletions modal/_utils/hash_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from modal.config import logger

HASH_CHUNK_SIZE = 65536
DUMMY_HASH_HEX = "baadbaadbaadbaadbaadbaadbaadbaad"


def _update(hashers: Sequence[Callable[[bytes], None]], data: Union[bytes, BinaryIO]) -> None:
Expand Down
68 changes: 53 additions & 15 deletions modal/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import concurrent.futures
import enum
import functools
import hashlib
import io
import os
import platform
import re
Expand Down Expand Up @@ -38,6 +40,7 @@
)
from ._utils.deprecation import deprecation_error, deprecation_warning, renamed_parameter
from ._utils.grpc_utils import retry_transient_errors
from ._utils.hash_utils import DUMMY_HASH_HEX
from ._utils.name_utils import check_object_name
from .client import _Client
from .config import logger
Expand Down Expand Up @@ -515,6 +518,7 @@ class _VolumeUploadContextManager:
_force: bool
progress_cb: Callable[..., Any]
_upload_generators: list[Generator[Callable[[], FileUploadSpec], None, None]]
_executor: concurrent.futures.ThreadPoolExecutor

def __init__(
self, volume_id: str, client: _Client, progress_cb: Optional[Callable[..., Any]] = None, force: bool = False
Expand All @@ -525,25 +529,25 @@ def __init__(
self._upload_generators = []
self._progress_cb = progress_cb or (lambda *_, **__: None)
self._force = force
self._executor = concurrent.futures.ThreadPoolExecutor()

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if not exc_val:
async def _upload():
# Flatten all the uploads yielded by the upload generators in the batch
def gen_upload_providers():
for gen in self._upload_generators:
yield from gen

async def gen_file_upload_specs() -> AsyncGenerator[FileUploadSpec, None]:
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as exe:
# TODO: avoid eagerly expanding
futs = [loop.run_in_executor(exe, f) for f in gen_upload_providers()]
logger.debug(f"Computing checksums for {len(futs)} files using {exe._max_workers} workers")
for fut in asyncio.as_completed(futs):
yield await fut
# TODO: avoid eagerly expanding
futs = [loop.run_in_executor(self._executor, f) for f in gen_upload_providers()]
logger.debug(f"Computing checksums for {len(futs)} files using {self._executor._max_workers} workers")
for fut in asyncio.as_completed(futs):
yield await fut

# Compute checksums & Upload files
files: list[api_pb2.MountFile] = []
Expand All @@ -563,6 +567,10 @@ async def gen_file_upload_specs() -> AsyncGenerator[FileUploadSpec, None]:
except GRPCError as exc:
raise FileExistsError(exc.message) if exc.status == Status.ALREADY_EXISTS else exc

with self._executor:
if not exc_val:
await _upload()

def put_file(
self,
local_file: Union[Path, str, BinaryIO],
Expand All @@ -581,9 +589,13 @@ def put_file(

def gen():
if isinstance(local_file, str) or isinstance(local_file, Path):
yield lambda: get_file_upload_spec_from_path(local_file, PurePosixPath(remote_path), mode)
yield lambda: get_file_upload_spec_from_path(
local_file, PurePosixPath(remote_path), mode, multipart_hash=False
)
else:
yield lambda: get_file_upload_spec_from_fileobj(local_file, PurePosixPath(remote_path), mode or 0o644)
yield lambda: get_file_upload_spec_from_fileobj(
local_file, PurePosixPath(remote_path), mode or 0o644, multipart_hash=False
)

self._upload_generators.append(gen())

Expand All @@ -604,7 +616,7 @@ def put_directory(

def create_file_spec_provider(subpath):
relpath_str = subpath.relative_to(local_path)
return lambda: get_file_upload_spec_from_path(subpath, remote_path / relpath_str)
return lambda: get_file_upload_spec_from_path(subpath, remote_path / relpath_str, multipart_hash=False)

def gen():
glob = local_path.rglob("*") if recursive else local_path.glob("*")
Expand All @@ -618,23 +630,48 @@ def gen():
async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
remote_filename = file_spec.mount_filename
progress_task_id = self._progress_cb(name=remote_filename, size=file_spec.size)
request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex)
response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1)

exists = False
resulting_sha256_hex = None
if file_spec.sha256_hex != DUMMY_HASH_HEX:
resulting_sha256_hex = file_spec.sha256_hex
request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex)
response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1)
exists = response.exists

def get_sha256(fp: BinaryIO) -> str:
    """Return the hex-encoded SHA-256 digest of the full content behind ``fp``.

    This runs on a worker thread while the same ``fp`` is being uploaded, so
    it must not disturb the stream position of ``fp``:

    * For an in-memory ``io.BytesIO`` the whole underlying buffer is hashed
      through ``getbuffer()``, which does not read from or seek the stream.
    * For anything else, the file is re-opened by ``fp.name`` through an
      independent handle and hashed from offset 0.

    Args:
        fp: The binary stream being uploaded. Non-``BytesIO`` objects must
            expose a ``name`` that can be passed to ``open()``.

    Returns:
        The lowercase hexadecimal SHA-256 digest.
    """
    hasher = hashlib.sha256()

    if isinstance(fp, io.BytesIO):
        # Release the exported view promptly so the buffer is not kept locked
        # against resizing for longer than the hash computation itself.
        with fp.getbuffer() as view:
            hasher.update(view)
        return hasher.hexdigest()

    # Use a context manager so the extra handle is always closed (the previous
    # implementation leaked it), and hash in chunks rather than relying on
    # hashlib.file_digest, which only exists on Python 3.11+.
    with open(fp.name, "rb") as data:
        while chunk := data.read(1024 * 1024):
            hasher.update(chunk)
    return hasher.hexdigest()

start_time = time.monotonic()
if not response.exists:
if not exists:
if file_spec.use_blob:
logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)")
with file_spec.source() as fp:
sha256_fut = None
if file_spec.sha256_hex == DUMMY_HASH_HEX:
sha256_fut = asyncio.get_event_loop().run_in_executor(self._executor, get_sha256, fp)

blob_id = await blob_upload_file(
fp,
self._client.stub,
functools.partial(self._progress_cb, progress_task_id),
sha256_hex=file_spec.sha256_hex,
md5_hex=file_spec.md5_hex,
)
if sha256_fut:
t0 = time.monotonic()
sha256_hex = await sha256_fut
logger.debug(
f"Awaited concurrent sha256 of {file_spec.source_description} for {time.monotonic() - t0:.3}s"
)
else:
sha256_hex = file_spec.sha256_hex
logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}")
request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex)
request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=sha256_hex)
else:
logger.debug(
f"Uploading file {file_spec.source_description} to {remote_filename} ({file_spec.size} bytes)"
Expand All @@ -645,6 +682,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
while (time.monotonic() - start_time) < VOLUME_PUT_FILE_CLIENT_TIMEOUT:
response = await retry_transient_errors(self._client.stub.MountPutFile, request2, base_delay=1)
if response.exists:
resulting_sha256_hex = response.sha256_hex
break

if not response.exists:
Expand All @@ -653,7 +691,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
self._progress_cb(task_id=progress_task_id, complete=True)
return api_pb2.MountFile(
filename=remote_filename,
sha256_hex=file_spec.sha256_hex,
sha256_hex=resulting_sha256_hex,
mode=file_spec.mode,
)

Expand Down

0 comments on commit 844c96f

Please sign in to comment.