diff --git a/model_signing/_manifest.py b/model_signing/_manifest.py
new file mode 100644
index 00000000..16d2f78b
--- /dev/null
+++ b/model_signing/_manifest.py
@@ -0,0 +1,101 @@
+# Copyright Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+from sigstore import dsse
+
+
+class DigestAlgorithm(Enum):
+    SHA256_P1 = 1
+
+    def __str__(self):
+        return str(self.name).replace("_", "-").lower()
+
+    @staticmethod
+    def from_string(s: str):
+        return DigestAlgorithm[s.replace("-", "_").upper()]
+
+
+class Hashed:
+    algorithm: DigestAlgorithm
+    digest: bytes
+
+    def __init__(self, algorithm_: DigestAlgorithm, digest_: bytes):
+        self.algorithm = algorithm_
+        self.digest = digest_
+
+
+class PathMetadata:
+    hashed: Hashed
+    path: str
+
+    def __init__(self, path_: str, hashed_: Hashed):
+        self.path = path_
+        self.hashed = hashed_
+
+
+class Manifest:
+    paths: [PathMetadata]
+    predicate_type: str
+
+    def __init__(self, paths: [PathMetadata]):
+        self.paths = paths
+        self.predicate_type = "sigstore.dev/model-transparency/manifest/v1"
+
+    def verify(self, verified_manifest: any) -> None:
+        # The manifest is the one constructed from disk and is untrusted.
+        # The statement is from the verified bundle and is trusted.
+        # Verify the predicate type and version.
+        predicateType = verified_manifest["predicateType"]
+        if predicateType != self.predicate_type:
+            raise ValueError(f"invalid predicate type: {predicateType}")
+        files = verified_manifest["predicate"]["files"]
+        if len(self.paths) != len(files):
+            raise ValueError(f"mismatched number of files: expected {len(files)}, got {len(self.paths)}")
+        for i in range(len(self.paths)):
+            actual_path = self.paths[i]
+            verified_path = files[i]
+            # Verify the path.
+            if actual_path.path != verified_path["path"]:
+                raise ValueError(f"mismatched path name: expected '{verified_path['path']}'. Got '{actual_path.path}'")
+            # Verify the hash algorithm name in the verified manifest.
+            if str(DigestAlgorithm.SHA256_P1) not in verified_path["digest"]:
+                raise ValueError(f"unrecognized hash algorithm: {set(verified_path['digest'].keys())}")
+            # Verify the hash algorithm name in the actual path.
+            if actual_path.hashed.algorithm != DigestAlgorithm.SHA256_P1:
+                raise ValueError(f"internal error: unexpected algorithm {str(actual_path.hashed.algorithm)}")
+            # Verify the hash value.
+            verified_digest = verified_path["digest"][str(actual_path.hashed.algorithm)]
+            if actual_path.hashed.digest.hex() != verified_digest:
+                raise ValueError(f"mismatched hash for file '{actual_path.path}': expected '{verified_digest}'. Got '{actual_path.hashed.digest.hex()}'")
+
+    def to_intoto_statement(self) -> dsse.Statement:
+        # See the example at
+        # https://github.com/in-toto/attestation/blob/main/python/tests/test_statement.py.
+        files: [any] = []
+        for p in self.paths:
+            f = {
+                "path": p.path,
+                "digest": {
+                    str(p.hashed.algorithm): p.hashed.digest.hex(),
+                },
+            }
+            files += [f]
+        stmt = (
+            dsse._StatementBuilder()
+            .subjects(
+                [dsse._Subject(name="-", digest={"sha256": "-"})]
+            )
+            .predicate_type(self.predicate_type)
+            .predicate(
+                {
+                    "files": files,
+                }
+            )
+        ).build()
+        return stmt
diff --git a/model_signing/model.py b/model_signing/model.py
index efc55063..4e29d57b 100644
--- a/model_signing/model.py
+++ b/model_signing/model.py
@@ -20,26 +20,27 @@
     Issuer,
     detect_credential,
 )
-from sigstore_protobuf_specs.dev.sigstore.bundle.v1 import Bundle
+
+from sigstore_protobuf_specs.dev.sigstore.bundle import v1 as bundle_v1
+
 from sigstore.verify import (
     policy,
     Verifier,
-)
-from sigstore.verify.models import (
-    VerificationMaterials,
+    Bundle,
 )
 from sigstore._internal.fulcio.client import (
     ExpiredCertificate,
 )
-import io
+import json
 from pathlib import Path
 from typing import Optional
 from serialize import Serializer
 import psutil
 import sys
+from _manifest import Manifest
 
 
 def chunk_size() -> int:
     return int(psutil.virtual_memory().available // 2)
@@ -83,7 +84,6 @@ def get_identity_token(self) -> Optional[IdentityToken]:
         # so we can return early.
         if token:
             return IdentityToken(token)
-        # TODO(): Support staging for testing.
 
         if self.oidc_issuer is not None:
             issuer = Issuer(self.oidc_issuer)
@@ -101,16 +101,19 @@ def sign(self, inputfn: Path, signaturefn: Path,
             oidc_token = self.get_identity_token()
             if not oidc_token:
                 raise ValueError("No identity token supplied or detected!")
-            print(f"identity-provider: {oidc_token.issuer}",
-                  file=sys.stderr)
+            # print(f"identity-provider: {oidc_token.issuer}",
+            #       file=sys.stderr)
             print(f"identity: {oidc_token.identity}", file=sys.stderr)
-            contentio = io.BytesIO(Serializer.serialize_v1(
-                inputfn, chunk_size(), signaturefn, ignorepaths))
+            serialized_paths = Serializer.serialize_v2(
+                inputfn, chunk_size(), signaturefn, ignorepaths)
             with self.signing_ctx.signer(oidc_token) as signer:
-                result = signer.sign(input_=contentio)
-                with signaturefn.open(mode="w") as b:
-                    print(result.to_bundle().to_json(), file=b)
+                manifest = Manifest(serialized_paths)
+                bundle = signer.sign_intoto(input_=manifest.to_intoto_statement())
+                signaturefn.write_bytes(bundle.to_json().encode('utf-8'))
+                # TODO: Check that sign() does verify the signature.
+                verifier = Verifier.production()
+                _, _ = verifier.verify_dsse(bundle, policy.UnsafeNoOp())
             return SignatureResult()
         except ExpiredIdentity:
             return SignatureResult(success=False,
@@ -140,22 +143,19 @@ def verify(self, inputfn: Path, signaturefn: Path,
                ignorepaths: [Path], offline: bool) -> VerificationResult:
         try:
             bundle_bytes = signaturefn.read_bytes()
-            bundle = Bundle().from_json(bundle_bytes)
-
-            material: tuple[Path, VerificationMaterials]
-            contentio = io.BytesIO(Serializer.serialize_v1(
-                inputfn, chunk_size(), signaturefn, ignorepaths))
-            material = VerificationMaterials.from_bundle(input_=contentio,
-                                                         bundle=bundle,
-                                                         offline=offline)
+            bundle = Bundle.from_json(bundle_bytes)
             policy_ = policy.Identity(
                 identity=self.identity,
                 issuer=self.oidc_provider,
             )
-            result = self.verifier.verify(materials=material, policy=policy_)
-            if result:
-                return VerificationResult()
-            return VerificationResult(success=False, reason=result.reason)
+            payload_type, payload = self.verifier.verify_dsse(bundle, policy_)
+            if payload_type != "application/vnd.in-toto+json":
+                raise ValueError(f"invalid payload type {payload_type}")
+            serialized_paths = Serializer.serialize_v2(
+                inputfn, chunk_size(), signaturefn, ignorepaths)
+            manifest = Manifest(serialized_paths)
+            manifest.verify(json.loads(payload))
+            return VerificationResult()
         except Exception as e:
             return VerificationResult(success=False,
                                       reason=f"exception caught: {str(e)}")
diff --git a/model_signing/serialize.py b/model_signing/serialize.py
index 2a745d1f..a0e58ccc 100644
--- a/model_signing/serialize.py
+++ b/model_signing/serialize.py
@@ -19,6 +19,9 @@
 from multiprocessing import get_start_method, set_start_method
 from pathlib import Path
 import platform
+from typing import Callable
+
+from _manifest import PathMetadata, DigestAlgorithm, Hashed
 
 # Use for testing while keeping disk size low.
 allow_symlinks = False
@@ -111,7 +114,7 @@ def remove_prefix(text, prefix):
     return text
 
 
-def validate_signature_path(model_path: Path, sig_path: Path):
+def _validate_signature_path(model_path: Path, sig_path: Path):
     if model_path.is_file():
         return
     # Note: Only allow top-level folder to have the signature for simplicity.
@@ -131,7 +134,7 @@ def is_relative_to(p: Path, path_list: [Path]) -> bool:
 class Serializer:
     @staticmethod
     # TODO: type of returned value.
-    def _ordered_files(path: Path, ignorepaths: [Path]) -> []:
+    def _ordered_files(path: Path, ignorepaths: [Path], ignore_folder: bool = False) -> []:
         children: [Path]
         if path.is_file():
             children = [path]
@@ -158,6 +161,9 @@ def _ordered_files(path: Path, ignorepaths: [Path]) -> []:
 
             if not child.is_file() and not child.is_dir():
                 raise ValueError(f"{str(child)} is not a dir or file")
+
+            if ignore_folder and child.is_dir():
+                continue
 
             # The recorded path must *not* contains the folder name,
             # since users may rename it.
@@ -226,7 +232,7 @@ def _create_tasks(children: [], shard_size: int) -> [[]]:
 
     @staticmethod
     # TODO: type of tasks
-    def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:
+    def _run_tasks(path: Path, chunk: int, tasks: [], fn: Callable[..., bytes]) -> bytes:
         # See https://superfastpython.com/processpoolexecutor-in-python/
         # NOTE: 32 = length of sha256 digest.
         digest_len = 32
@@ -237,7 +243,7 @@ def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:
         if platform.system() == "Linux" and get_start_method() != "fork":
             set_start_method('fork')
         with ProcessPoolExecutor() as ppe:
-            futures = [ppe.submit(Serializer.task, (path, chunk, task))
+            futures = [ppe.submit(fn, (path, chunk, task))
                        for task in tasks]
             results = [f.result() for f in futures]
             for i, result in enumerate(results):
@@ -249,7 +255,7 @@ def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:
 
     @staticmethod
     # TODO: type of task_info.
-    def task(task_info: []):
+    def _task_v1(task_info: any) -> bytes:
         # NOTE: we can get process info using:
         # from multiprocessing import current_process
         # worker = current_process()
@@ -303,7 +309,7 @@ def _serialize_v1(path: Path, chunk: int, shard: int, signature_path: Path,
             raise ValueError(f"{str(path)} is not a dir or file")
 
         # Validate the signature path.
-        validate_signature_path(path, signature_path)
+        _validate_signature_path(path, signature_path)
 
         # Children to hash.
         children = Serializer._ordered_files(path,
@@ -317,11 +323,100 @@ def _serialize_v1(path: Path, chunk: int, shard: int, signature_path: Path,
         # Share the computation of hashes.
         # For simplicity, we pre-allocate the entire array that will hold
         # the concatenation of all hashes.
-        all_hashes = Serializer._run_tasks(path, chunk, tasks)
+        all_hashes = Serializer._run_tasks(path, chunk, tasks, Serializer._task_v1)
 
         # Finally, we hash everything.
         return hashlib.sha256(bytes(all_hashes)).digest()
 
+    @staticmethod
+    # TODO: type of task_info.
+    def _task_v2(task_info: any) -> bytes:
+        # NOTE: we can get process info using:
+        # from multiprocessing import current_process
+        # worker = current_process()
+        # print(f'Task {task_info},
+        #       worker name={worker.name}, pid={worker.pid}', flush=True)
+        model_path, chunk, (name, ty, start_pos, end_pos) = task_info
+        # Only files are recorded.
+        if ty != "file":
+            raise ValueError(f"internal: got a non-file path {name}")
+
+        return Hasher._node_file_compute_v1(model_path.joinpath(name),
+                                            b'', start_pos, end_pos, chunk)
+
+    @staticmethod
+    def _to_path_metadata(tasks_info: [any], all_hashes: bytes) -> [PathMetadata]:
+        if not tasks_info:
+            raise ValueError("internal: tasks_info is empty")
+        paths: [PathMetadata] = []
+        # Group the sharded digests by file: consecutive tasks sharing a name
+        # belong to the same file, and the file's digest is the hash of the
+        # concatenation of its shard digests.
+        # NOTE: each digest is 32-byte long; the digest for tasks_info[i]
+        # lives at all_hashes[i*32: (i + 1)*32].
+        prev_name, _, _, _ = tasks_info[0]
+        h = hashlib.sha256(bytes(all_hashes[0: 32]))
+        for curr_i, curr_task in enumerate(tasks_info[1:]):
+            curr_name, _, _, _ = curr_task
+            digest = bytes(all_hashes[(curr_i + 1)*32: (curr_i + 2)*32])
+            if prev_name == curr_name:
+                h.update(digest)
+                continue
+            # End of a group of sharded digests for the same file.
+            paths += [PathMetadata(prev_name, Hashed(DigestAlgorithm.SHA256_P1, h.digest()))]
+            # Start the hash for the next group.
+            h = hashlib.sha256(digest)
+            prev_name = curr_name
+
+        # Compute the digest for the last group.
+        paths += [PathMetadata(prev_name, Hashed(DigestAlgorithm.SHA256_P1, h.digest()))]
+        # TODO: Test this function properly.
+        # paths += [PathMetadata("path/to/file1", Hashed(DigestAlgorithm.SHA256_P1, b'\abcdef1'))]
+        # paths += [PathMetadata("path/to/file2", Hashed(DigestAlgorithm.SHA256_P1, b'\abcdef2'))]
+        return paths
+
+    @staticmethod
+    def _serialize_v2(path: Path, chunk: int, shard: int, signature_path: Path,
+                      ignorepaths: [Path] = []) -> [PathMetadata]:
+        if not path.exists():
+            raise ValueError(f"{str(path)} does not exist")
+
+        if not allow_symlinks and path.is_symlink():
+            raise ValueError(f"{str(path)} is a symlink")
+
+        if chunk < 0:
+            raise ValueError(f"{str(chunk)} is invalid")
+
+        if not path.is_file() and not path.is_dir():
+            raise ValueError(f"{str(path)} is not a dir or file")
+
+        # Validate the signature path.
+        _validate_signature_path(path, signature_path)
+
+        # Children to hash.
+        children = Serializer._ordered_files(path,
+                                             [signature_path] + ignorepaths,
+                                             True)
+
+        # We shard the computation by creating independent "tasks".
+        if shard < 0:
+            raise ValueError(f"{str(shard)} is invalid")
+        tasks = Serializer._create_tasks(children, shard)
+
+        # Share the computation of hashes.
+        # For simplicity, we pre-allocate the entire array that will hold
+        # the concatenation of all hashes.
+        all_hashes = Serializer._run_tasks(path, chunk, tasks, Serializer._task_v2)
+
+        # Turn the hashes into PathMetadata entries.
+        return Serializer._to_path_metadata(tasks, all_hashes)
+
+    def serialize_v2(path: Path, chunk: int, signature_path: Path,
+                     ignorepaths: [Path] = []) -> [PathMetadata]:
+        # NOTE: The shard size must be the same for all clients for
+        # compatibility. We could make it configurable, but then the
+        # signature file must contain the value used by the signer.
+        shard_size = 1000000000  # 1GB
+        return Serializer._serialize_v2(path, chunk, shard_size,
+                                        signature_path, ignorepaths)
+
     def serialize_v1(path: Path, chunk: int, signature_path: Path,
                      ignorepaths: [Path] = []) -> bytes:
         # NOTE: The shard size must be the same for all clients for
@@ -350,7 +445,7 @@ def serialize_v0(path: Path, chunk: int, signature_path: Path,
             raise ValueError(f"{str(path)} is not a dir")
 
         # Validate the signature path.
-        validate_signature_path(path, signature_path)
+        _validate_signature_path(path, signature_path)
 
         children = sorted([x for x in path.iterdir()
                            if x != signature_path and x not in ignorepaths])
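Reviewer note (not part of the patch): a rough sketch of the in-toto statement that Manifest.to_intoto_statement() builds and that Manifest.verify() later checks against the paths re-hashed from disk by Serializer.serialize_v2(). The field layout follows the code above; the file names and digest values are illustrative placeholders, and the _type/subject values are simply whatever sigstore's dsse statement builder fills in for the pinned "-" subject.

    # Approximate shape of the signed DSSE payload (hypothetical example):
    statement = {
        "_type": "https://in-toto.io/Statement/v1",  # set by the dsse builder
        "subject": [{"name": "-", "digest": {"sha256": "-"}}],
        "predicateType": "sigstore.dev/model-transparency/manifest/v1",
        "predicate": {
            "files": [
                {"path": "weights.bin",
                 "digest": {"sha256-p1": "<hex digest of weights.bin>"}},
                {"path": "tokenizer/vocab.json",
                 "digest": {"sha256-p1": "<hex digest of vocab.json>"}},
            ],
        },
    }

During verification, verify_dsse() returns this payload as JSON bytes with payload type application/vnd.in-toto+json; Manifest.verify() then compares each files entry, in order, against the sha256-p1 digests recomputed locally.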