Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions docs/specs/code/fs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#
# Runs in sage if the pycryptodome package is installed in the environment.
# `sage --pip install pycryptodome`
# In sage, use load("fs.py") to run this test
#
from Crypto.Cipher import AES

import hashlib
import struct
import math

def hash(data):
assert isinstance(data, bytes), "data not bytes"
return hashlib.sha256(data).digest()

class FSPRF:
"""
Fiat-Shamir Pseudorandom Function object.
Produces an infinite stream of bytes organized in 16-byte blocks.
Block i = AES256(SEED, ID(i))
"""
def __init__(self, seed: bytes):
assert len(seed) == 32, "Seed must be 32 bytes (AES-256 key size)."
self.counter = 0
self.buffer = bytearray()
self.cipher = AES.new(seed, AES.MODE_ECB)

def bytes(self, n: int) -> bytes:
"""Returns the next n bytes in the stream."""
# Fill buffer until we have enough bytes
while len(self.buffer) < n:
# block_id is the 16-byte little-endian representation of integer i
block_id = self.counter.to_bytes(16, 'little')

# Block i = AES256(SEED, ID(i))
block_output = self.cipher.encrypt(block_id)

self.buffer.extend(block_output)
self.counter += 1

# Consume n bytes from the front of the buffer
result = self.buffer[:n]
self.buffer = self.buffer[n:]
return bytes(result)


class Transcript:
def __init__(self):
self.tr = bytearray()
self._is_initialized = False
self._fs = None
self._tr_snapshot_len = 0

def init(self, session_id: bytes):
"""
Initializes the transcript with a session_id.
Must be called exactly once before any other method.
"""
assert not self._is_initialized, "Transcript.init() must be called exactly once."
self._is_initialized = True
self.write_bytes(session_id)

def write_field(self, elt, sz = 32):
assert self._is_initialized, "init not called"
self.tr.append(0x01)
self.tr.extend( int(elt).to_bytes(sz, byteorder="little"))

def write_bytes(self, b):
assert self._is_initialized, "init not called"
self.tr.append(0x00)
# packs an unsigned long long (8 bytes) in Little Endian (<)
length_prefix = struct.pack('<Q', len(b))
self.tr.extend(length_prefix)
self.tr.extend(b)

def write_field_element_array(self, elems, sz=32):
"""
Spec: Append byte designator 0x3, 8-byte LE count, then serialized elements.
"""
assert self._is_initialized, "init not called"
self.tr.append(0x02)
count_prefix = struct.pack('<Q', len(elems))
self.tr.extend(count_prefix)

for elem in elems:
self.tr.extend( int(elem).to_bytes(sz, byteorder="little"))


def _get_fs(self) -> FSPRF:
"""
Retrieves the current FSPRF object.
If 'write' has been called since the last retrieval, a new FSPRF
is seeded using H(tr).
"""
assert self._is_initialized, "init not called"
# If the transcript has changed, create a new FSPRF.
if self._fs is None or len(self.tr) != self._tr_snapshot_len:
# Spec: "Next, a seed is generated by applying the function H to the (entire) string tr."
seed = hash(bytes(self.tr))
self._fs = FSPRF(seed)
self._tr_snapshot_len = len(self.tr)

return self._fs

def generate_nat(self, m):
"""
Generates a random natural number between 0 and m-1 inclusive via rejection sampling.
"""
assert m > 0, "m must be > 0"

l = m.bit_length()
nbytes = math.ceil(l / 8)
mask = (1 << l) - 1 # Bitmask to isolate lower l bits
fs = self._get_fs()
while True:
b = fs.bytes(nbytes)
k = int.from_bytes(b, 'little')
r = k & mask
if r < m:
return r

def generate_field(self, p):
fs = self._get_fs()
sz = math.ceil(p.bit_length() / 8)
while True:
b = fs.bytes(sz)
x = int.from_bytes(b, byteorder='little', signed=False)
if x < p:
return x

def generate_nats_wo_replacement(self, m, n):
assert m > n, "invalid parameter"
A = list(range(0, m))
for i in range(0, n):
j = i + self.generate_nat(m - i)
A[i], A[j] = A[j], A[i]
return A[:n]


# --- Test Example ---

if __name__ == "__main__":
t = Transcript()

p = 115792089210356248762697446949407573530086143415290314195533631308867097853951
session_id = b"test"
t.init(session_id)

arr = bytearray()
for bi in range(0, 100):
arr.append(bi)
t.write_bytes(arr)

tv1 = [t.generate_field(p) for i in range(0,16)]
for ti in tv1:
print(hex(ti))

t.write_field(7)

tv2 = [t.generate_field(p) for i in range(0,16)]
for ti in tv2:
print(hex(ti))

fe_array = [(8), (9)]
t.write_field_element_array(fe_array)

tv3 = [t.generate_field(p) for i in range(0,16)]
for ti in tv3:
print(hex(ti))

t.write_bytes(b'nats')

ns = [1, 1, 1, 2, 2, 2, 7, 7, 7, 7, 32, 32, 32, 32,
256, 256, 256, 256, 1000, 10000, 60000, 65535, 100000, 100000]
nats = [t.generate_nat(n) for n in ns]
print(nats)

t.write_bytes(b'choose')
choose_sizes = [31, 32, 63, 64, 1000, 65535]
for cs in choose_sizes:
gotc = t.generate_nats_wo_replacement(cs, 20)
print(gotc)

150 changes: 150 additions & 0 deletions docs/specs/code/merkle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
## Code from the spec for operating on MerkleTree datastructures.

import hashlib

def hash(data):
assert isinstance(data, bytes), "data not bytes"
return hashlib.sha256(data).digest()

class MerkleTree:
def __init__(self, n):
self.n = n
self.a = [b''] * (2 * n)

def set_leaf(self, pos, leaf):
"""
Sets a leaf at a specific position.
pos: 0-based index relative to the leaves (0 to n-1)
"""
assert 0 <= pos < self.n, f"{pos} is out of bounds"
self.a[pos + self.n] = leaf

def build_tree(self):
"""
Computes the internal nodes from n-1 down to 1.
Returns the root (M.a[1]).
"""
for i in range(self.n - 1, 0, -1):
left = self.a[2 * i]
right = self.a[2 * i + 1]

self.a[i] = hash(left + right)

return self.a[1]

def mark_tree(self, requested_leaves):
marked = [False] * (2 * self.n)

for i in requested_leaves:
assert 0 <= i < self.n, f"invalid requested index {i}"
marked[i + self.n] = True

for i in range(self.n - 1, 0, -1):
marked[i] = marked[2 * i] or marked[2 * i + 1]

return marked

def compressed_proof(self, requested_leaves):
"""
Generates a compressed proof for the requested leaves.
"""
proof = []

marked = self.mark_tree(requested_leaves)

for i in range(self.n - 1, 0, -1):
if marked[i]:
child = 2 * i

# If the left child is marked, we need the right child (sibling).
if marked[child]:
child += 1

# If the identified child/sibling is NOT marked,
# we must provide its hash in the proof so the verifier can calculate the parent.
if not marked[child]:
proof.append(self.a[child])

return proof

def verify_merkle(self, root, n, k, s, indices, proof):
"""
Verifies that the provided leaves (s) at specific positions (indices)
are part of the Merkle tree defined by 'root'.

:param root: The expected Root Hash
:param n: Total number of leaves in the tree
:param k: Number of leaves being verified
:param s: List of leaf data/hashes to verify
:param indices: List of positions for the leaves in 's'
:param proof: List of proof hashes
"""
tmp = [None] * (2 * n)
defined = [False] * (2 * n)

proof_index = 0

if n != self.n: return False

marked = self.mark_tree(indices)

for i in range(n - 1, 0, -1):
if marked[i]:
child = 2 * i
if marked[child]:
child += 1

if not marked[child]:
if proof_index >= len(proof):
return False

tmp[child] = proof[proof_index]
proof_index += 1
defined[child] = True

for i in range(k):
pos = indices[i] + n
tmp[pos] = s[i]
defined[pos] = True

for i in range(n - 1, 0, -1):
if defined[2 * i] and defined[2 * i + 1]:
left = tmp[2 * i]
right = tmp[2 * i + 1]
tmp[i] = hash(left + right)
defined[i] = True

return defined[1] and (tmp[1] == root)


if __name__ == "__main__":
# Example from the test vector section in the Appendix.
n = 5
mt = MerkleTree(n)

c0 = bytes.fromhex('4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a')
c1 = bytes.fromhex('dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986')
c3 = bytes.fromhex('e52d9c508c502347344d8c07ad91cbd6068afc75ff6292f062a09ca381c89e71')
mt.set_leaf(0, c0)
mt.set_leaf(1, c1)
mt.set_leaf(2,bytes.fromhex('084fed08b978af4d7d196a7446a86b58009e636b611db16211b65a9aadff29c5'))
mt.set_leaf(3, c3)
mt.set_leaf(4,bytes.fromhex('e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db'))

root_hash = mt.build_tree()

print(f"Merkle Root: {root_hash.hex()}")

print(f"Requesting [0,1]:")
req_leaves = [0, 1]
proof = mt.compressed_proof(req_leaves)
for p in proof:
print(p.hex())
assert mt.verify_merkle(root_hash, n, 2, [c0, c1], [0, 1], proof), "Bad proof"

print(f"Requesting [1,3]:")
req_leaves = [1, 3]
proof = mt.compressed_proof(req_leaves)
for p in proof:
print(p.hex())
assert mt.verify_merkle(root_hash, n, 2, [c1, c3], [1, 3], proof), "Bad proof"
Loading
Loading