Skip to content

Commit

Permalink
volume put: upload without sha256 hash
Browse files Browse the repository at this point in the history
Let the backend return the resulting sha256 hash instead of computing it
up front.

Requires #2722 and
corresponding internal change to make sha256_hex optional in the
MountPutFile handler.
  • Loading branch information
danielnorberg committed Jan 4, 2025
1 parent bad590b commit be07828
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 14 deletions.
22 changes: 17 additions & 5 deletions modal/_utils/blob_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..exception import ExecutionError
from .async_utils import TaskContext, retry
from .grpc_utils import retry_transient_errors
from .hash_utils import UploadHashes, get_upload_hashes
from .hash_utils import DUMMY_HASH_HEX, UploadHashes, get_upload_hashes
from .http_utils import ClientSessionRegistry
from .logger import logger

Expand Down Expand Up @@ -307,6 +307,7 @@ def _get_file_upload_spec(
source_description: Any,
mount_filename: PurePosixPath,
mode: int,
multipart_hash: bool = True,
) -> FileUploadSpec:
with source() as fp:
# Current position is ignored - we always upload from position 0
Expand All @@ -316,10 +317,11 @@ def _get_file_upload_spec(

if size >= LARGE_FILE_LIMIT:
# TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
md5_hex = DUMMY_HASH_HEX if size > MULTIPART_UPLOAD_THRESHOLD else None
sha256_hex = DUMMY_HASH_HEX if size > MULTIPART_UPLOAD_THRESHOLD and not multipart_hash else None
use_blob = True
content = None
hashes = get_upload_hashes(fp, md5_hex=md5_hex)
hashes = get_upload_hashes(fp, md5_hex=md5_hex, sha256_hex=sha256_hex)
else:
use_blob = False
content = fp.read()
Expand All @@ -339,7 +341,10 @@ def _get_file_upload_spec(


def get_file_upload_spec_from_path(
filename: Path, mount_filename: PurePosixPath, mode: Optional[int] = None
filename: Path,
mount_filename: PurePosixPath,
mode: Optional[int] = None,
multipart_hash: bool = True,
) -> FileUploadSpec:
# Python appears to give files 0o666 bits on Windows (equal for user, group, and global),
# so we mask those out to 0o755 for compatibility with POSIX-based permissions.
Expand All @@ -349,10 +354,16 @@ def get_file_upload_spec_from_path(
filename,
mount_filename,
mode,
multipart_hash=multipart_hash,
)


def get_file_upload_spec_from_fileobj(fp: BinaryIO, mount_filename: PurePosixPath, mode: int) -> FileUploadSpec:
def get_file_upload_spec_from_fileobj(
fp: BinaryIO,
mount_filename: PurePosixPath,
mode: int,
multipart_hash: bool = True,
) -> FileUploadSpec:
@contextmanager
def source():
# We ignore position in stream and always upload from position 0
Expand All @@ -364,6 +375,7 @@ def source():
str(fp),
mount_filename,
mode,
multipart_hash=multipart_hash,
)


Expand Down
7 changes: 7 additions & 0 deletions modal/_utils/hash_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def sha256_hex(self) -> str:
return base64.b64decode(self.sha256_base64).hex()


DUMMY_HASH_HEX = "baadbaadbaadbaadbaadbaadbaadbaad"
DUMMY_HASHES = UploadHashes(
base64.b64encode(bytes.fromhex(DUMMY_HASH_HEX)).decode("ascii"),
base64.b64encode(bytes.fromhex(DUMMY_HASH_HEX)).decode("ascii"),
)


def get_upload_hashes(
data: Union[bytes, BinaryIO], sha256_hex: Optional[str] = None, md5_hex: Optional[str] = None
) -> UploadHashes:
Expand Down
31 changes: 22 additions & 9 deletions modal/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from ._utils.deprecation import deprecation_error, deprecation_warning, renamed_parameter
from ._utils.grpc_utils import retry_transient_errors
from ._utils.hash_utils import DUMMY_HASH_HEX
from ._utils.name_utils import check_object_name
from .client import _Client
from .config import logger
Expand Down Expand Up @@ -581,9 +582,13 @@ def put_file(

def gen():
if isinstance(local_file, str) or isinstance(local_file, Path):
yield lambda: get_file_upload_spec_from_path(local_file, PurePosixPath(remote_path), mode)
yield lambda: get_file_upload_spec_from_path(
local_file, PurePosixPath(remote_path), mode, multipart_hash=False
)
else:
yield lambda: get_file_upload_spec_from_fileobj(local_file, PurePosixPath(remote_path), mode or 0o644)
yield lambda: get_file_upload_spec_from_fileobj(
local_file, PurePosixPath(remote_path), mode or 0o644, multipart_hash=False
)

self._upload_generators.append(gen())

Expand All @@ -604,7 +609,7 @@ def put_directory(

def create_file_spec_provider(subpath):
relpath_str = subpath.relative_to(local_path)
return lambda: get_file_upload_spec_from_path(subpath, remote_path / relpath_str)
return lambda: get_file_upload_spec_from_path(subpath, remote_path / relpath_str, multipart_hash=False)

def gen():
glob = local_path.rglob("*") if recursive else local_path.glob("*")
Expand All @@ -618,11 +623,17 @@ def gen():
async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
remote_filename = file_spec.mount_filename
progress_task_id = self._progress_cb(name=remote_filename, size=file_spec.size)
request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex)
response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1)

exists = False
resulting_sha256_hex = None
if file_spec.sha256_hex != DUMMY_HASH_HEX:
resulting_sha256_hex = file_spec.sha256_hex
request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex)
response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1)
exists = response.exists

start_time = time.monotonic()
if not response.exists:
if not exists:
if file_spec.use_blob:
logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)")
with file_spec.source() as fp:
Expand All @@ -631,10 +642,11 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
self._client.stub,
functools.partial(self._progress_cb, progress_task_id),
sha256_hex=file_spec.sha256_hex,
md5_hex=file_spec.md5_hex,
md5_hex=file_spe.md5_hex(),
)
logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}")
request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex)
sha256_hex = file_spec.sha256_hex if file_spec.sha256_hex != DUMMY_HASH_HEX else ""
request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=sha256_hex)
else:
logger.debug(
f"Uploading file {file_spec.source_description} to {remote_filename} ({file_spec.size} bytes)"
Expand All @@ -645,6 +657,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
while (time.monotonic() - start_time) < VOLUME_PUT_FILE_CLIENT_TIMEOUT:
response = await retry_transient_errors(self._client.stub.MountPutFile, request2, base_delay=1)
if response.exists:
resulting_sha256_hex = response.sha256_hex
break

if not response.exists:
Expand All @@ -653,7 +666,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
self._progress_cb(task_id=progress_task_id, complete=True)
return api_pb2.MountFile(
filename=remote_filename,
sha256_hex=file_spec.sha256_hex,
sha256_hex=resulting_sha256_hex,
mode=file_spec.mode,
)

Expand Down

0 comments on commit be07828

Please sign in to comment.