Skip to content

Commit 86a2952

Browse files
committed
backup
Signed-off-by: laurentsimon <[email protected]>
1 parent 00aed55 commit 86a2952

File tree

2 files changed

+116
-13
lines changed

2 files changed

+116
-13
lines changed

model_signing/model.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import psutil
4141
import sys
4242

43+
from _manifest import Manifest
4344

4445
def chunk_size() -> int:
4546
return int(psutil.virtual_memory().available // 2)
@@ -82,7 +83,7 @@ def get_identity_token(self) -> Optional[IdentityToken]:
8283
# Happy path: we've detected an ambient credential,
8384
# so we can return early.
8485
if token:
85-
return token
86+
return IdentityToken(token)
8687

8788
# TODO(): Support staging for testing.
8889
if self.oidc_issuer is not None:
@@ -105,12 +106,13 @@ def sign(self, inputfn: Path, signaturefn: Path,
105106
file=sys.stderr)
106107
print(f"identity: {oidc_token.identity}", file=sys.stderr)
107108

108-
contentio = io.BytesIO(Serializer.serialize_v1(
109-
inputfn, chunk_size(), signaturefn, ignorepaths))
109+
serialized_paths = Serializer.serialize_v2(
110+
inputfn, chunk_size(), signaturefn, ignorepaths)
110111
with self.signing_ctx.signer(oidc_token) as signer:
111-
result = signer.sign(input_=contentio)
112+
manifest = Manifest(serialized_paths)
113+
result = signer.sign(input_=manifest.to_intoto_statement())
112114
with signaturefn.open(mode="w") as b:
113-
print(result.to_bundle().to_json(), file=b)
115+
print(result.to_json(), file=b)
114116
return SignatureResult()
115117
except ExpiredIdentity:
116118
return SignatureResult(success=False,
@@ -143,6 +145,11 @@ def verify(self, inputfn: Path, signaturefn: Path,
143145
bundle = Bundle().from_json(bundle_bytes)
144146

145147
material: tuple[Path, VerificationMaterials]
148+
# TODO: verification
149+
# serialized_paths = Serializer.serialize_v2(
150+
# inputfn, chunk_size(), signaturefn, ignorepaths)
151+
# manifest = Manifest(serialized_paths)
152+
# result = signer.sign(input_=manifest.to_intoto_statement())
146153
contentio = io.BytesIO(Serializer.serialize_v1(
147154
inputfn, chunk_size(), signaturefn, ignorepaths))
148155
material = VerificationMaterials.from_bundle(input_=contentio,

model_signing/serialize.py

+104-8
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from multiprocessing import get_start_method, set_start_method
2020
from pathlib import Path
2121
import platform
22+
from typing import Callable
23+
24+
from _manifest import PathMetadata, DigestAlgorithm, Hashed
2225

2326
# Use for testing while keeping disk size low.
2427
allow_symlinks = False
@@ -111,7 +114,7 @@ def remove_prefix(text, prefix):
111114
return text
112115

113116

114-
def validate_signature_path(model_path: Path, sig_path: Path):
117+
def _validate_signature_path(model_path: Path, sig_path: Path):
115118
if model_path.is_file():
116119
return
117120
# Note: Only allow top-level folder to have the signature for simplicity.
@@ -131,7 +134,7 @@ def is_relative_to(p: Path, path_list: [Path]) -> bool:
131134
class Serializer:
132135
@staticmethod
133136
# TODO: type of returned value.
134-
def _ordered_files(path: Path, ignorepaths: [Path]) -> []:
137+
def _ordered_files(path: Path, ignorepaths: [Path], ignore_folder: bool = False) -> []:
135138
children: [Path]
136139
if path.is_file():
137140
children = [path]
@@ -158,6 +161,9 @@ def _ordered_files(path: Path, ignorepaths: [Path]) -> []:
158161

159162
if not child.is_file() and not child.is_dir():
160163
raise ValueError(f"{str(child)} is not a dir or file")
164+
165+
if ignore_folder and child.is_dir():
166+
continue
161167

162168
# The recorded path must *not* contain the folder name,
163169
since users may rename it.
@@ -226,7 +232,7 @@ def _create_tasks(children: [], shard_size: int) -> [[]]:
226232

227233
@staticmethod
228234
# TODO: type of tasks
229-
def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:
235+
def _run_tasks(path: Path, chunk: int, tasks: [], fn: Callable[[], bytes]) -> bytes:
230236
# See https://superfastpython.com/processpoolexecutor-in-python/
231237
# NOTE: 32 = length of sha256 digest.
232238
digest_len = 32
@@ -237,7 +243,7 @@ def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:
237243
if platform.system() == "Linux" and get_start_method() != "fork":
238244
set_start_method('fork')
239245
with ProcessPoolExecutor() as ppe:
240-
futures = [ppe.submit(Serializer.task, (path, chunk, task))
246+
futures = [ppe.submit(fn, (path, chunk, task))
241247
for task in tasks]
242248
results = [f.result() for f in futures]
243249
for i, result in enumerate(results):
@@ -249,7 +255,7 @@ def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:
249255

250256
@staticmethod
251257
# TODO: type of task_info.
252-
def task(task_info: []):
258+
def _task_v1(task_info: any) -> bytes:
253259
# NOTE: we can get process info using:
254260
# from multiprocessing import current_process
255261
# worker = current_process()
@@ -303,7 +309,7 @@ def _serialize_v1(path: Path, chunk: int, shard: int, signature_path: Path,
303309
raise ValueError(f"{str(path)} is not a dir or file")
304310

305311
# Validate the signature path.
306-
validate_signature_path(path, signature_path)
312+
_validate_signature_path(path, signature_path)
307313

308314
# Children to hash.
309315
children = Serializer._ordered_files(path,
@@ -317,11 +323,101 @@ def _serialize_v1(path: Path, chunk: int, shard: int, signature_path: Path,
317323
# Share the computation of hashes.
318324
# For simplicity, we pre-allocate the entire array that will hold
319325
# the concatenation of all hashes.
320-
all_hashes = Serializer._run_tasks(path, chunk, tasks)
326+
all_hashes = Serializer._run_tasks(path, chunk, tasks, Serializer._task_v1)
321327

322328
# Finally, we hash everything.
323329
return hashlib.sha256(bytes(all_hashes)).digest()
324330

331+
@staticmethod
332+
# TODO: type of task_info.
333+
def _task_v2(task_info: any) -> bytes:
334+
# NOTE: we can get process info using:
335+
# from multiprocessing import current_process
336+
# worker = current_process()
337+
# print(f'Task {task_info},
338+
# worker name={worker.name}, pid={worker.pid}', flush=True)
339+
_, chunk, (name, ty, start_pos, end_pos) = task_info
340+
# Only files are recorded.
341+
if ty != "file":
342+
raise ValueError(f"internal: got a non-file path {name}")
343+
344+
return Hasher._node_file_compute_v1(name,
345+
b'', start_pos, end_pos, chunk)
346+
347+
@staticmethod
348+
def _to_path_metadata(task_info: [any], all_hashes: bytes) -> [PathMetadata]:
349+
if not task_info:
350+
raise ValueError("internal: task_info is empty")
351+
352+
paths: [PathMetadata] = []
353+
# Iterate over all tasks.
354+
prev_task = task_info[0]
355+
prev_i = 0
356+
prev_name, _, _, _ = prev_task
357+
for curr_i, curr_task in enumerate(task_info[1:]):
358+
curr_name, _, _, _ = curr_task
359+
if prev_name == curr_name:
360+
continue
361+
# End of a group of sharded digests for the same file.
362+
# NOTE: each digest is 32-byte long.
363+
h = hashlib.sha256(bytes(all_hashes[prev_i: curr_i+32])).digest()
364+
paths += [PathMetadata(prev_name, Hashed(DigestAlgorithm.SHA256_P1, h))]
365+
prev_i = curr_i
366+
prev_name = curr_name
367+
368+
# Compute the digest for the last (unfinished) task.
369+
if prev_i < len(task_info):
370+
h = hashlib.sha256(bytes(all_hashes[prev_i:])).digest()
371+
paths += [PathMetadata(prev_name, Hashed(DigestAlgorithm.SHA256_P1, h))]
372+
# paths += [PathMetadata("path/to/file1", Hashed(DigestAlgorithm.SHA256_P1, b'\abcdef1'))]
373+
# paths += [PathMetadata("path/to/file2", Hashed(DigestAlgorithm.SHA256_P1, b'\abcdef2'))]
374+
return paths
375+
376+
@staticmethod
377+
def _serialize_v2(path: Path, chunk: int, shard: int, signature_path: Path,
378+
ignorepaths: [Path] = []) -> bytes:
379+
if not path.exists():
380+
raise ValueError(f"{str(path)} does not exist")
381+
382+
if not allow_symlinks and path.is_symlink():
383+
raise ValueError(f"{str(path)} is a symlink")
384+
385+
if chunk < 0:
386+
raise ValueError(f"{str(chunk)} is invalid")
387+
388+
if not path.is_file() and not path.is_dir():
389+
raise ValueError(f"{str(path)} is not a dir or file")
390+
391+
# Validate the signature path.
392+
_validate_signature_path(path, signature_path)
393+
394+
# Children to hash.
395+
children = Serializer._ordered_files(path,
396+
[signature_path] + ignorepaths,
397+
True)
398+
399+
# We shard the computation by creating independent "tasks".
400+
if shard < 0:
401+
raise ValueError(f"{str(shard)} is invalid")
402+
tasks = Serializer._create_tasks(children, shard)
403+
404+
# Share the computation of hashes.
405+
# For simplicity, we pre-allocate the entire array that will hold
406+
# the concatenation of all hashes.
407+
all_hashes = Serializer._run_tasks(path, chunk, tasks, Serializer._task_v2)
408+
409+
# Turn hashes into PathMetadata
410+
return Serializer._to_path_metadata(tasks, all_hashes)
411+
412+
def serialize_v2(path: Path, chunk: int, signature_path: Path,
413+
ignorepaths: [Path] = []) -> [PathMetadata]:
414+
# NOTE: The shard size must be the same for all clients for
415+
# compatibility. We could make it configurable; but in this
416+
# case the signature file must contain the value used by the signer.
417+
shard_size = 1000000000 # 1GB
418+
return Serializer._serialize_v2(path, chunk, shard_size,
419+
signature_path, ignorepaths)
420+
325421
def serialize_v1(path: Path, chunk: int, signature_path: Path,
326422
ignorepaths: [Path] = []) -> bytes:
327423
# NOTE: The shard size must be the same for all clients for
@@ -350,7 +446,7 @@ def serialize_v0(path: Path, chunk: int, signature_path: Path,
350446
raise ValueError(f"{str(path)} is not a dir")
351447

352448
# Validate the signature path.
353-
validate_signature_path(path, signature_path)
449+
_validate_signature_path(path, signature_path)
354450

355451
children = sorted([x for x in path.iterdir()
356452
if x != signature_path and x not in ignorepaths])

0 commit comments

Comments
 (0)