WIP feat: Use manifest #112

Closed
wants to merge 7 commits into from
101 changes: 101 additions & 0 deletions model_signing/_manifest.py
@@ -0,0 +1,101 @@
# Copyright Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import List

from sigstore import dsse

class DigestAlgorithm(Enum):
SHA256_P1 = 1

def __str__(self):
return str(self.name).replace("_", "-").lower()

@staticmethod
def from_string(s: str):
return DigestAlgorithm[s.replace("-", "_")]

class Hashed:
algorithm: DigestAlgorithm
digest: bytes
def __init__(self, algorithm_: DigestAlgorithm, digest_: bytes):
self.algorithm = algorithm_
self.digest = digest_

class PathMetadata:
hashed: Hashed
path: str
def __init__(self, path_: str, hashed_: Hashed):
self.path = path_
self.hashed = hashed_

class Manifest:
paths: List[PathMetadata]
predicate_type: str
def __init__(self, paths: [PathMetadata]):
self.paths = paths
self.predicate_type = "sigstore.dev/model-transparency/manifest/v1"

def verify(self, verified_manifest: any) -> None:
# The manifest is the one constructed from disk and is untrusted.
# The statement is from the verified bundle and is trusted.
# Verify the type and version.
predicateType = verified_manifest["predicateType"]
if predicateType != self.predicate_type:
raise ValueError(f"invalid predicate type: {predicateType}")
files = verified_manifest["predicate"]["files"]
if len(self.paths) != len(files):
raise ValueError(f"mismatch number of files: expected {len(files)}, got {len(self.paths)}")
for i in range(len(self.paths)):
actual_path = self.paths[i]
verified_path = files[i]
# Verify the path.
if actual_path.path != verified_path["path"]:
raise ValueError(f"mismatch path name: expected '{verified_path['path']}'. Got '{actual_path.path}'")
# Verify the hash name in verified manifest.
if str(DigestAlgorithm.SHA256_P1) not in verified_path["digest"]:
raise ValueError(f"unrecognized hash algorithm: {set(verified_path['digest'].keys())}")
# Verify the hash name in actual path.
if actual_path.hashed.algorithm != DigestAlgorithm.SHA256_P1:
raise ValueError(f"internal error: algorithm {str(actual_path.hashed.algorithm)}")
# Verify the hash value.
verified_digest = verified_path["digest"][str(actual_path.hashed.algorithm)]
if actual_path.hashed.digest.hex() != verified_digest:
raise ValueError(f"mismatch hash for file '{actual_path.path}': expected '{verified_digest}'. Got '{actual_path.hashed.digest.hex()}'")


def to_intoto_statement(self) -> dsse.Statement:
# See example at https://github.com/in-toto/attestation/blob/main/python/tests/test_statement.py.
files: [any] = []
for _, p in enumerate(self.paths):
f = {
"path": p.path,
"digest": {
str(p.hashed.algorithm): p.hashed.digest.hex(),
},
}
files += [f]
stmt = (
dsse._StatementBuilder()
.subjects(
A contributor commented:
Shouldn't the subjects be the files?

My mental model had the statement subjects as the artifacts we want to sign, so you get subject[n].name = path and subject[n].digest = file_hash. We would then only set the predicate_type field to sigstore.dev/model-transparency/manifest/v1, and that describes everything we need to know about the subjects. Furthermore, the predicate could hold metadata like {"hash_algorithm": "abc"}.

PLMK if there is something I don't get here.

See https://github.com/in-toto/attestation/blob/main/spec/v1/statement.md

@laurentsimon (Collaborator, Author) commented on Apr 24, 2024:
You're correct, that's how in-toto works, but there is some discussion in #111 about whether we want to use in-toto or not.

For this PR, I could not pack the files into the subject because of a current limitation on the hash types accepted by the sigstore-python library (it only accepts known hashes, but we use a parallel hash). And I could not use a non-in-toto payload for this PoC, because sigstore-python does not support that yet (WIP).
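
For readers following the thread, here is a minimal sketch of the two statement layouts under discussion. The digests and paths are made up; the field names come from the in-toto v1 spec linked above and from this PR's to_intoto_statement.

# Illustrative only: both layouts as plain dicts, with made-up values.

# Layout suggested in the review: each file becomes an in-toto subject.
# (Per the reply, sigstore-python currently rejects digest names it does not
# know, such as the parallel "sha256-p1" digest, which is why the PR does not
# do this yet.)
statement_files_as_subjects = {
    "_type": "https://in-toto.io/Statement/v1",
    "subject": [
        {"name": "path/to/file1", "digest": {"sha256-p1": "aaaa"}},
        {"name": "path/to/file2", "digest": {"sha256-p1": "bbbb"}},
    ],
    "predicateType": "sigstore.dev/model-transparency/manifest/v1",
    "predicate": {"hash_algorithm": "sha256-p1"},
}

# Layout this PR currently emits: a placeholder subject, with the per-file
# digests carried in the predicate (see to_intoto_statement above).
statement_files_in_predicate = {
    "_type": "https://in-toto.io/Statement/v1",
    "subject": [{"name": "-", "digest": {"sha256": "-"}}],
    "predicateType": "sigstore.dev/model-transparency/manifest/v1",
    "predicate": {
        "files": [
            {"path": "path/to/file1", "digest": {"sha256-p1": "aaaa"}},
            {"path": "path/to/file2", "digest": {"sha256-p1": "bbbb"}},
        ],
    },
}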

[dsse._Subject(name="-", digest={"sha256": "-"})]
)
.predicate_type(self.predicate_type)
.predicate(
{
"files": files,
}
)
).build()
return stmt
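
For orientation before the next file, here is a condensed sketch of the sign and verify flows that the model.py changes below wire up. It is not part of the diff; it only restates those changes in one place, and it assumes the Serializer.serialize_v2 and Manifest helpers added in this PR plus sigstore-python's DSSE entry points (sign_intoto, verify_dsse), whose upstream support the discussion above notes is still WIP.

# Sketch only; condensed from the model.py changes below, not part of this diff.
import json
from pathlib import Path

from _manifest import Manifest
from serialize import Serializer


def sign_sketch(signer, inputfn: Path, signaturefn: Path, ignorepaths, chunk: int) -> None:
    # Hash every model file (sharded, in parallel), wrap the digests in a
    # Manifest, and sign the resulting in-toto statement as a DSSE bundle.
    serialized_paths = Serializer.serialize_v2(inputfn, chunk, signaturefn, ignorepaths)
    manifest = Manifest(serialized_paths)
    bundle = signer.sign_intoto(input_=manifest.to_intoto_statement())
    signaturefn.write_bytes(bundle.to_json().encode("utf-8"))


def verify_sketch(verifier, policy_, bundle, inputfn: Path, signaturefn: Path,
                  ignorepaths, chunk: int) -> None:
    # Verify the DSSE bundle, re-hash the files on disk, and compare the freshly
    # computed Manifest against the verified payload.
    payload_type, payload = verifier.verify_dsse(bundle, policy_)
    if payload_type != "application/vnd.in-toto+json":
        raise ValueError(f"invalid payload type {payload_type}")
    manifest = Manifest(Serializer.serialize_v2(inputfn, chunk, signaturefn, ignorepaths))
    manifest.verify(json.loads(payload))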
50 changes: 25 additions & 25 deletions model_signing/model.py
@@ -20,26 +20,27 @@
Issuer,
detect_credential,
)
from sigstore_protobuf_specs.dev.sigstore.bundle.v1 import Bundle

from sigstore_protobuf_specs.dev.sigstore.bundle import v1 as bundle_v1

from sigstore.verify import (
policy,
Verifier,
)
from sigstore.verify.models import (
VerificationMaterials,
Bundle,
)

from sigstore._internal.fulcio.client import (
ExpiredCertificate,
)

import io
import json
from pathlib import Path
from typing import Optional
from serialize import Serializer
import psutil
import sys

from _manifest import Manifest

def chunk_size() -> int:
return int(psutil.virtual_memory().available // 2)
@@ -83,7 +84,6 @@ def get_identity_token(self) -> Optional[IdentityToken]:
# so we can return early.
if token:
return IdentityToken(token)

# TODO(): Support staging for testing.
if self.oidc_issuer is not None:
issuer = Issuer(self.oidc_issuer)
@@ -101,16 +101,19 @@ def sign(self, inputfn: Path, signaturefn: Path,
oidc_token = self.get_identity_token()
if not oidc_token:
raise ValueError("No identity token supplied or detected!")
print(f"identity-provider: {oidc_token.issuer}",
file=sys.stderr)
#print(f"identity-provider: {oidc_token.issuer}",
# file=sys.stderr)
print(f"identity: {oidc_token.identity}", file=sys.stderr)

contentio = io.BytesIO(Serializer.serialize_v1(
inputfn, chunk_size(), signaturefn, ignorepaths))
serialized_paths = Serializer.serialize_v2(
inputfn, chunk_size(), signaturefn, ignorepaths)
with self.signing_ctx.signer(oidc_token) as signer:
result = signer.sign(input_=contentio)
with signaturefn.open(mode="w") as b:
print(result.to_bundle().to_json(), file=b)
manifest = Manifest(serialized_paths)
bundle = signer.sign_intoto(input_=manifest.to_intoto_statement())
signaturefn.write_bytes(bundle.to_json().encode('utf-8'))
## TODO: Check that sign() does verify the signature.
verifier = Verifier.production()
_, _ = verifier.verify_dsse(bundle, policy.UnsafeNoOp())
return SignatureResult()
except ExpiredIdentity:
return SignatureResult(success=False,
@@ -140,22 +143,19 @@ def verify(self, inputfn: Path, signaturefn: Path,
ignorepaths: [Path], offline: bool) -> VerificationResult:
try:
bundle_bytes = signaturefn.read_bytes()
bundle = Bundle().from_json(bundle_bytes)

material: tuple[Path, VerificationMaterials]
contentio = io.BytesIO(Serializer.serialize_v1(
inputfn, chunk_size(), signaturefn, ignorepaths))
material = VerificationMaterials.from_bundle(input_=contentio,
bundle=bundle,
offline=offline)
bundle = Bundle.from_json(bundle_bytes)
policy_ = policy.Identity(
identity=self.identity,
issuer=self.oidc_provider,
)
result = self.verifier.verify(materials=material, policy=policy_)
if result:
return VerificationResult()
return VerificationResult(success=False, reason=result.reason)
payload_type, payload = self.verifier.verify_dsse(bundle, policy_)
if payload_type != "application/vnd.in-toto+json":
raise ValueError(f"invalid payload type {payload_type}")
serialized_paths = Serializer.serialize_v2(
inputfn, chunk_size(), signaturefn, ignorepaths)
manifest = Manifest(serialized_paths)
manifest.verify(json.loads(payload))
return VerificationResult()
except Exception as e:
return VerificationResult(success=False,
reason=f"exception caught: {str(e)}")
111 changes: 103 additions & 8 deletions model_signing/serialize.py
@@ -19,6 +19,9 @@
from multiprocessing import get_start_method, set_start_method
from pathlib import Path
import platform
from typing import Callable

from _manifest import PathMetadata, DigestAlgorithm, Hashed

# Use for testing while keeping disk size low.
allow_symlinks = False
@@ -111,7 +114,7 @@ def remove_prefix(text, prefix):
return text


def validate_signature_path(model_path: Path, sig_path: Path):
def _validate_signature_path(model_path: Path, sig_path: Path):
if model_path.is_file():
return
# Note: Only allow top-level folder to have the signature for simplicity.
@@ -131,7 +134,7 @@ def is_relative_to(p: Path, path_list: [Path]) -> bool:
class Serializer:
@staticmethod
# TODO: type of returned value.
def _ordered_files(path: Path, ignorepaths: [Path]) -> []:
def _ordered_files(path: Path, ignorepaths: [Path], ignore_folder: bool = False) -> []:
children: [Path]
if path.is_file():
children = [path]
@@ -158,6 +161,9 @@ def _ordered_files(path: Path, ignorepaths: [Path]) -> []:

if not child.is_file() and not child.is_dir():
raise ValueError(f"{str(child)} is not a dir or file")

if ignore_folder and child.is_dir():
continue

# The recorded path must *not* contain the folder name,
# since users may rename it.
@@ -226,7 +232,7 @@ def _create_tasks(children: [], shard_size: int) -> [[]]:

@staticmethod
# TODO: type of tasks
def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:
def _run_tasks(path: Path, chunk: int, tasks: [], fn: Callable[[], bytes]) -> bytes:
# See https://superfastpython.com/processpoolexecutor-in-python/
# NOTE: 32 = length of sha256 digest.
digest_len = 32
@@ -237,7 +243,7 @@ def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:
if platform.system() == "Linux" and get_start_method() != "fork":
set_start_method('fork')
with ProcessPoolExecutor() as ppe:
futures = [ppe.submit(Serializer.task, (path, chunk, task))
futures = [ppe.submit(fn, (path, chunk, task))
for task in tasks]
results = [f.result() for f in futures]
for i, result in enumerate(results):
@@ -249,7 +255,7 @@ def _run_tasks(path: Path, chunk: int, tasks: []) -> bytes:

@staticmethod
# TODO: type of task_info.
def task(task_info: []):
def _task_v1(task_info: any) -> bytes:
# NOTE: we can get process info using:
# from multiprocessing import current_process
# worker = current_process()
@@ -303,7 +309,7 @@ def _serialize_v1(path: Path, chunk: int, shard: int, signature_path: Path,
raise ValueError(f"{str(path)} is not a dir or file")

# Validate the signature path.
validate_signature_path(path, signature_path)
_validate_signature_path(path, signature_path)

# Children to hash.
children = Serializer._ordered_files(path,
@@ -317,11 +323,100 @@
# Share the computation of hashes.
# For simplicity, we pre-allocate the entire array that will hold
# the concatenation of all hashes.
all_hashes = Serializer._run_tasks(path, chunk, tasks)
all_hashes = Serializer._run_tasks(path, chunk, tasks, Serializer._task_v1)

# Finally, we hash everything.
return hashlib.sha256(bytes(all_hashes)).digest()

@staticmethod
# TODO: type of task_info.
def _task_v2(task_info: any) -> bytes:
# NOTE: we can get process info using:
# from multiprocessing import current_process
# worker = current_process()
# print(f'Task {task_info},
# worker name={worker.name}, pid={worker.pid}', flush=True)
model_path, chunk, (name, ty, start_pos, end_pos) = task_info
# Only files are recorded.
if ty != "file":
raise ValueError(f"internal: got a non-file path {name}")

return Hasher._node_file_compute_v1(model_path.joinpath(name),
b'', start_pos, end_pos, chunk)

@staticmethod
def _to_path_metadata(tasks_info: [any], all_hashes: bytes) -> [PathMetadata]:
if not tasks_info:
raise ValueError("internal: tasks_info is empty")
paths: [PathMetadata] = []
# Iterate over all tasks.
prev_task = tasks_info[0]
prev_name, _, _, _ = prev_task
h = hashlib.sha256(bytes(all_hashes[0: 32]))
for curr_i, curr_task in enumerate(tasks_info[1:]):
curr_name, _, _, _ = curr_task
if prev_name == curr_name:
h.update(bytes(all_hashes[curr_i*32: (curr_i + 1)*32]))
continue
# End of a group of sharded digests for the same file.
# NOTE: each digest is 32-byte long.
paths += [PathMetadata(prev_name, Hashed(DigestAlgorithm.SHA256_P1, h.digest()))]
# Compute the hash for the next group.
h.update(bytes(all_hashes[curr_i*32: (curr_i + 1)*32]))
prev_name = curr_name

# Compute the digest for the last (unfinished) group.
paths += [PathMetadata(prev_name, Hashed(DigestAlgorithm.SHA256_P1, h.digest()))]
# TODO: Test this function properly.
# paths += [PathMetadata("path/to/file1", Hashed(DigestAlgorithm.SHA256_P1, b'\abcdef1'))]
# paths += [PathMetadata("path/to/file2", Hashed(DigestAlgorithm.SHA256_P1, b'\abcdef2'))]
return paths

@staticmethod
def _serialize_v2(path: Path, chunk: int, shard: int, signature_path: Path,
ignorepaths: [Path] = []) -> bytes:
if not path.exists():
raise ValueError(f"{str(path)} does not exist")

if not allow_symlinks and path.is_symlink():
raise ValueError(f"{str(path)} is a symlink")

if chunk < 0:
raise ValueError(f"{str(chunk)} is invalid")

if not path.is_file() and not path.is_dir():
raise ValueError(f"{str(path)} is not a dir or file")

# Validate the signature path.
_validate_signature_path(path, signature_path)

# Children to hash.
children = Serializer._ordered_files(path,
[signature_path] + ignorepaths,
True)

# We shard the computation by creating independent "tasks".
if shard < 0:
raise ValueError(f"{str(shard)} is invalid")
tasks = Serializer._create_tasks(children, shard)

# Share the computation of hashes.
# For simplicity, we pre-allocate the entire array that will hold
# the concatenation of all hashes.
all_hashes = Serializer._run_tasks(path, chunk, tasks, Serializer._task_v2)

# Turn hashes into PathMetadata
return Serializer._to_path_metadata(tasks, all_hashes)

def serialize_v2(path: Path, chunk: int, signature_path: Path,
ignorepaths: [Path] = []) -> [PathMetadata]:
# NOTE: The shard size must be the same for all clients for
# compatibility. We could make it configurable; but in this
# case the signature file must contain the value used by the signer.
shard_size = 1000000000 # 1GB
return Serializer._serialize_v2(path, chunk, shard_size,
signature_path, ignorepaths)

def serialize_v1(path: Path, chunk: int, signature_path: Path,
ignorepaths: [Path] = []) -> bytes:
# NOTE: The shard size must be the same for all clients for
Expand Down Expand Up @@ -350,7 +445,7 @@ def serialize_v0(path: Path, chunk: int, signature_path: Path,
raise ValueError(f"{str(path)} is not a dir")

# Validate the signature path.
validate_signature_path(path, signature_path)
_validate_signature_path(path, signature_path)

children = sorted([x for x in path.iterdir()
if x != signature_path and x not in ignorepaths])