diff --git a/clvm/SExp.py b/clvm/SExp.py index 48e90316..2f197c7a 100644 --- a/clvm/SExp.py +++ b/clvm/SExp.py @@ -179,9 +179,9 @@ def as_int(self) -> int: raise TypeError("Unable to convert a pair to an int") return int_from_bytes(self.atom) - def as_bin(self) -> bytes: + def as_bin(self, *, allow_backrefs: bool = False) -> bytes: f = io.BytesIO() - sexp_to_stream(self, f) + sexp_to_stream(self, f, allow_backrefs=allow_backrefs) return f.getvalue() # TODO: should be `v: CastableType` diff --git a/clvm/object_cache.py b/clvm/object_cache.py new file mode 100644 index 00000000..335e9357 --- /dev/null +++ b/clvm/object_cache.py @@ -0,0 +1,106 @@ +from typing import Callable, Dict, Generic, Optional, Tuple, TypeVar + +import hashlib + +from .CLVMObject import CLVMStorage + +T = TypeVar("T") + + +class ObjectCache(Generic[T]): + """ + `ObjectCache` provides a way to calculate and cache values for each node + in a clvm object tree. It can be used to calculate the sha256 tree hash + for an object and save the hash for all the child objects for building + usage tables, for example. + + It also allows a function that's defined recursively on a clvm tree to + have a non-recursive implementation (as it keeps a stack of uncached + objects locally). + """ + + def __init__(self, f: Callable[["ObjectCache[T]", CLVMStorage], Optional[T]]): + """ + `f`: Callable[ObjectCache, CLVMObject] -> Union[None, T] + + The function `f` is expected to calculate its T value recursively based + on the T values for the left and right child for a pair. For an atom, the + function f must calculate the T value directly. + + If a pair is passed and one of the children does not have its T value cached + in `ObjectCache` yet, return `None` and f will be called with each child in turn. + Don't recurse in f; that's part of the point of this function. + """ + self.f = f + self.lookup: Dict[int, Tuple[T, CLVMStorage]] = dict() + + def get(self, obj: CLVMStorage) -> T: + obj_id = id(obj) + if obj_id not in self.lookup: + obj_list = [obj] + while obj_list: + node = obj_list.pop() + node_id = id(node) + if node_id not in self.lookup: + v = self.f(self, node) + if v is None: + if node.pair is None: + raise ValueError("f returned None for atom", node) + obj_list.append(node) + obj_list.append(node.pair[0]) + obj_list.append(node.pair[1]) + else: + self.lookup[node_id] = (v, node) + return self.lookup[obj_id][0] + + def contains(self, obj: CLVMStorage) -> bool: + return id(obj) in self.lookup + + +def treehash(cache: ObjectCache[bytes], obj: CLVMStorage) -> Optional[bytes]: + """ + This function can be fed to `ObjectCache` to calculate the sha256 tree + hash for all objects in a tree. + """ + if obj.pair: + left, right = obj.pair + + # ensure both `left` and `right` have cached values + if cache.contains(left) and cache.contains(right): + left_hash = cache.get(left) + right_hash = cache.get(right) + return hashlib.sha256(b"\2" + left_hash + right_hash).digest() + return None + assert obj.atom is not None + return hashlib.sha256(b"\1" + obj.atom).digest() + + +def serialized_length(cache: ObjectCache[int], obj: CLVMStorage) -> Optional[int]: + """ + This function can be fed to `ObjectCache` to calculate the serialized + length for all objects in a tree. 
+ """ + if obj.pair: + left, right = obj.pair + + # ensure both `left` and `right` have cached values + if cache.contains(left) and cache.contains(right): + left_length = cache.get(left) + right_length = cache.get(right) + return 1 + left_length + right_length + return None + assert obj.atom is not None + lb = len(obj.atom) + if lb == 0 or (lb == 1 and obj.atom[0] < 128): + return 1 + if lb < 0x40: + return 1 + lb + if lb < 0x2000: + return 2 + lb + if lb < 0x100000: + return 3 + lb + if lb < 0x8000000: + return 4 + lb + if lb < 0x400000000: + return 5 + lb + raise ValueError("atom of size %d too long" % lb) diff --git a/clvm/read_cache_lookup.py b/clvm/read_cache_lookup.py new file mode 100644 index 00000000..b0d16f46 --- /dev/null +++ b/clvm/read_cache_lookup.py @@ -0,0 +1,178 @@ +from collections import Counter +from typing import Dict, Optional, List, Set, Tuple + +import hashlib + + +LEFT = 0 +RIGHT = 1 + + +class ReadCacheLookup: + """ + When deserializing a clvm object, a stack of deserialized child objects + is created, which can be used with back-references. A `ReadCacheLookup` keeps + track of the state of this stack and all child objects under each root + node in the stack so that we can quickly determine if a relevant + back-reference is available. + + In other words, if we've already serialized an object with tree hash T, + and we encounter another object with that tree hash, we don't re-serialize + it, but rather include a back-reference to it. This data structure lets + us quickly determine which back-reference has the shortest path. + + Note that there is a counter. This is because the stack contains some + child objects that are transient, and no longer appear in the stack + at later times in the parsing. We don't want to waste time looking for + these objects that no longer exist, so we reference-count them. + + All hashes correspond to sha256 tree hashes. + """ + + def __init__(self) -> None: + """ + Create a new `ReadCacheLookup` object with just the null terminator + (ie. an empty list of objects). + """ + self.root_hash = hashlib.sha256(b"\1").digest() + self.read_stack: List[Tuple[bytes, bytes]] = [] + self.count: Counter[bytes] = Counter() + self.parent_paths_for_child: Dict[bytes, List[Tuple[bytes, int]]] = {} + + def push(self, obj_hash: bytes) -> None: + """ + This function is used to note that an object with the given hash has just + been pushed to the read stack, and update the lookups as appropriate. + """ + # we add two new entries: the new root of the tree, and this object (by id) + # new_root: (obj_hash, old_root) + new_root_hash = hashlib.sha256(b"\2" + obj_hash + self.root_hash).digest() + + self.read_stack.append((obj_hash, self.root_hash)) + + self.count.update([obj_hash, new_root_hash]) + + new_parent_to_old_root = (new_root_hash, LEFT) + self.parent_paths_for_child.setdefault(obj_hash, list()).append( + new_parent_to_old_root + ) + + new_parent_to_id = (new_root_hash, RIGHT) + self.parent_paths_for_child.setdefault(self.root_hash, list()).append( + new_parent_to_id + ) + self.root_hash = new_root_hash + + def pop(self) -> Tuple[bytes, bytes]: + """ + This function is used to note that the top object has just been popped + from the read stack. Return the 2-tuple of the child hashes. + """ + item = self.read_stack.pop() + self.count[item[0]] -= 1 + self.count[self.root_hash] -= 1 + self.root_hash = item[1] + return item + + def pop2_and_cons(self) -> None: + """ + This function is used to note that a "pop-and-cons" operation has just + happened. 
We remove two objects, cons them together, and push the cons, + updating the internal look-ups as necessary. + """ + # we remove two items: the right side of each left/right pair + right = self.pop() + left = self.pop() + + self.count.update([left[0], right[0]]) + + new_root_hash = hashlib.sha256(b"\2" + left[0] + right[0]).digest() + + self.parent_paths_for_child.setdefault(left[0], list()).append( + (new_root_hash, LEFT) + ) + self.parent_paths_for_child.setdefault(right[0], list()).append( + (new_root_hash, RIGHT) + ) + self.push(new_root_hash) + + def find_paths(self, obj_hash: bytes, serialized_length: int) -> Set[bytes]: + """ + This function looks for a path from the root to a child node with a given hash + by using the read cache. + """ + valid_paths: Set[bytes] = set() + if serialized_length < 3: + return valid_paths + + seen_ids: Set[bytes] = set() + + max_bytes_for_path_encoding = serialized_length - 2 + # 1 byte for 0xfe, 1 min byte for savings + + max_path_length = max_bytes_for_path_encoding * 8 - 1 + seen_ids.add(obj_hash) + + partial_paths: List[Tuple[bytes, List[int]]] = [(obj_hash, [])] + + while partial_paths: + new_seen_ids = set(seen_ids) + new_partial_paths = [] + for node, path in partial_paths: + if node == self.root_hash: + valid_paths.add(reversed_path_to_bytes(path)) + continue + + parent_paths = self.parent_paths_for_child.get(node) + + if parent_paths: + for parent, direction in parent_paths: + if self.count[parent] > 0 and parent not in seen_ids: + new_path = list(path) + new_path.append(direction) + if len(new_path) > max_path_length: + return set() + new_partial_paths.append((parent, new_path)) + new_seen_ids.add(parent) + partial_paths = new_partial_paths + if valid_paths: + return valid_paths + seen_ids = set(new_seen_ids) + return valid_paths + + def find_path(self, obj_hash: bytes, serialized_length: int) -> Optional[bytes]: + r = self.find_paths(obj_hash, serialized_length) + return min(r) if len(r) > 0 else None + + +def reversed_path_to_bytes(path: List[int]) -> bytes: + """ + Convert a list of 0/1 (for left/right) values to a path expected by clvm. + + Reverse the list; convert to a binary number; prepend a 1; break into bytes. + + [] => bytes([0b1]) + [0] => bytes([0b10]) + [1] => bytes([0b11]) + [0, 0] => bytes([0b100]) + [0, 1] => bytes([0b101]) + [1, 0] => bytes([0b110]) + [1, 1] => bytes([0b111]) + [0, 0, 1] => bytes([0b1001]) + [1, 1, 1, 1, 0, 0, 0, 0, 1] => bytes([0b11, 0b11100001]) + """ + + byte_count = (len(path) + 1 + 7) >> 3 + v = bytearray(byte_count) + index = byte_count - 1 + mask = 1 + for p in reversed(path): + if p: + v[index] |= mask + if mask == 0x80: + index -= 1 + mask = 1 + else: + mask <<= 1 + v[index] |= mask + return bytes(v) diff --git a/clvm/serialize.py b/clvm/serialize.py index ad7c3016..9d0eb265 100644 --- a/clvm/serialize.py +++ b/clvm/serialize.py @@ -1,6 +1,7 @@ # decoding: # read a byte # if it's 0x80, it's nil (which might be same as 0) +# if it's 0xfe, it's a back-reference. Read an atom, and treat it as a path in the cache tree. # if it's 0xff, it's a cons box. 
Read two items, build cons # otherwise, number of leading set bits is length in bytes to read size # For example, if the bit fields of the first byte read are: @@ -12,37 +13,45 @@ # If the first byte read is one of the following: # 1000 0000 -> 0 bytes : nil # 0000 0000 -> 1 byte : zero (b'\x00') -from __future__ import annotations import io import typing -from .CLVMObject import CLVMObject, CLVMStorage - - -if typing.TYPE_CHECKING: - from .SExp import CastableType, SExp +from .read_cache_lookup import ReadCacheLookup +from .object_cache import ObjectCache, treehash, serialized_length +from .CLVMObject import CLVMStorage MAX_SINGLE_BYTE = 0x7F +BACK_REFERENCE = 0xFE CONS_BOX_MARKER = 0xFF - T = typing.TypeVar("T") +_T_CLVMStorage = typing.TypeVar("_T_CLVMStorage", bound=CLVMStorage) + +CS = typing.TypeVar("CS", bound=CLVMStorage) ToCLVMStorage = typing.Callable[ - [typing.Union[bytes, typing.Tuple[CLVMStorage, CLVMStorage]]], CLVMStorage + [typing.Union[CLVMStorage, bytes, typing.Tuple[CLVMStorage, CLVMStorage]]], + _T_CLVMStorage, ] +ValStackType = CLVMStorage + OpCallable = typing.Callable[ - ["OpStackType", "ValStackType", typing.BinaryIO, ToCLVMStorage], None + ["OpStackType[T]", ValStackType, typing.BinaryIO, ToCLVMStorage[T]], ValStackType ] -ValStackType = typing.List[CLVMStorage] -OpStackType = typing.List[OpCallable] +OpStackType = typing.List[OpCallable[T]] + +def sexp_to_byte_iterator( + sexp: CLVMStorage, *, allow_backrefs: bool = False +) -> typing.Iterator[bytes]: + if allow_backrefs: + yield from sexp_to_byte_iterator_with_backrefs(sexp) + return -def sexp_to_byte_iterator(sexp: CLVMStorage) -> typing.Iterator[bytes]: todo_stack = [sexp] while todo_stack: sexp = todo_stack.pop() @@ -56,6 +65,52 @@ def sexp_to_byte_iterator(sexp: CLVMStorage) -> typing.Iterator[bytes]: yield from atom_to_byte_iterator(sexp.atom) +def sexp_to_byte_iterator_with_backrefs(sexp: CLVMStorage) -> typing.Iterator[bytes]: + # in `read_op_stack`: + # "P" = "push" + # "C" = "pop two objects, create and push a new cons with them" + + read_op_stack = ["P"] + + write_stack = [sexp] + + read_cache_lookup = ReadCacheLookup() + + thc = ObjectCache(treehash) + slc = ObjectCache(serialized_length) + + while write_stack: + node_to_write = write_stack.pop() + op = read_op_stack.pop() + assert op == "P" + + node_serialized_length = slc.get(node_to_write) + + node_tree_hash = thc.get(node_to_write) + path = read_cache_lookup.find_path(node_tree_hash, node_serialized_length) + if path: + yield bytes([BACK_REFERENCE]) + yield from atom_to_byte_iterator(path) + read_cache_lookup.push(node_tree_hash) + elif node_to_write.pair: + left, right = node_to_write.pair + yield bytes([CONS_BOX_MARKER]) + write_stack.append(right) + write_stack.append(left) + read_op_stack.append("C") + read_op_stack.append("P") + read_op_stack.append("P") + else: + atom = node_to_write.atom + assert atom is not None + yield from atom_to_byte_iterator(atom) + read_cache_lookup.push(node_tree_hash) + + while read_op_stack[-1:] == ["C"]: + read_op_stack.pop() + read_cache_lookup.pop2_and_cons() + + def atom_to_byte_iterator(as_atom: bytes) -> typing.Iterator[bytes]: size = len(as_atom) if size == 0: @@ -97,17 +152,56 @@ def atom_to_byte_iterator(as_atom: bytes) -> typing.Iterator[bytes]: yield as_atom -def sexp_to_stream(sexp: SExp, f: typing.BinaryIO) -> None: - for b in sexp_to_byte_iterator(sexp): +def sexp_to_stream( + sexp: CLVMStorage, f: typing.BinaryIO, *, allow_backrefs: bool = False +) -> None: + for b in sexp_to_byte_iterator(sexp, 
allow_backrefs=allow_backrefs): f.write(b) +def msb_mask(byte: int) -> int: + byte |= byte >> 1 + byte |= byte >> 2 + byte |= byte >> 4 + return (byte + 1) >> 1 + + +def traverse_path( + obj: CLVMStorage, path: bytes, to_sexp: ToCLVMStorage[CS] +) -> CLVMStorage: + path_as_int = int.from_bytes(path, "big") + if path_as_int == 0: + return to_sexp(b"") + + while path_as_int > 1: + if obj.pair is None: + raise ValueError("path into atom", obj) + obj = obj.pair[path_as_int & 1] + path_as_int >>= 1 + + return obj + + +def _op_cons( + op_stack: OpStackType[CS], + val_stack: ValStackType, + f: typing.BinaryIO, + to_sexp: ToCLVMStorage[CS], +) -> ValStackType: + assert val_stack.pair is not None + right, val_stack = val_stack.pair + assert val_stack.pair is not None + left, val_stack = val_stack.pair + new_cons = to_sexp((left, right)) + return to_sexp((new_cons, val_stack)) + + def _op_read_sexp( - op_stack: OpStackType, + op_stack: OpStackType[CS], val_stack: ValStackType, f: typing.BinaryIO, - to_sexp: ToCLVMStorage, -) -> None: + to_sexp: ToCLVMStorage[CS], +) -> ValStackType: blob = f.read(1) if len(blob) == 0: raise ValueError("bad encoding") @@ -116,29 +210,50 @@ def _op_read_sexp( op_stack.append(_op_cons) op_stack.append(_op_read_sexp) op_stack.append(_op_read_sexp) - return - val_stack.append(_atom_from_stream(f, b, to_sexp)) + return val_stack + atom_as_sexp = to_sexp(_atom_from_stream(f, b)) + return to_sexp((atom_as_sexp, val_stack)) -def _op_cons( - op_stack: OpStackType, +def _op_read_sexp_allow_backrefs( + op_stack: OpStackType[CS], val_stack: ValStackType, f: typing.BinaryIO, - to_sexp: ToCLVMStorage, -) -> None: - right = val_stack.pop() - left = val_stack.pop() - val_stack.append(to_sexp((left, right))) + to_sexp: ToCLVMStorage[CS], +) -> CLVMStorage: + blob = f.read(1) + if len(blob) == 0: + raise ValueError("bad encoding") + b = blob[0] + if b == CONS_BOX_MARKER: + op_stack.append(_op_cons) + op_stack.append(_op_read_sexp_allow_backrefs) + op_stack.append(_op_read_sexp_allow_backrefs) + return val_stack + if b == BACK_REFERENCE: + blob = f.read(1) + if len(blob) == 0: + raise ValueError("bad encoding") + path = _atom_from_stream(f, blob[0]) + backref = traverse_path(val_stack, path, to_sexp) + return to_sexp((backref, val_stack)) + atom_as_sexp = to_sexp(_atom_from_stream(f, b)) + return to_sexp((atom_as_sexp, val_stack)) -def sexp_from_stream(f: typing.BinaryIO, to_sexp: typing.Callable[["CastableType"], T]) -> T: - op_stack: OpStackType = [_op_read_sexp] - val_stack: ValStackType = [] +def sexp_from_stream( + f: typing.BinaryIO, to_sexp: ToCLVMStorage[CS], *, allow_backrefs: bool = False +) -> CS: + op_stack: OpStackType[CS] = [ + _op_read_sexp_allow_backrefs if allow_backrefs else _op_read_sexp + ] + val_stack: ValStackType = to_sexp(b"") while op_stack: func = op_stack.pop() - func(op_stack, val_stack, f, CLVMObject) - return to_sexp(val_stack.pop()) + val_stack = func(op_stack, val_stack, f, to_sexp) + assert val_stack.pair is not None + return to_sexp(val_stack.pair[0]) def _op_consume_sexp(f: typing.BinaryIO) -> typing.Tuple[bytes, int]: @@ -193,13 +308,11 @@ def sexp_buffer_from_stream(f: typing.BinaryIO) -> bytes: return ret.getvalue() -def _atom_from_stream( - f: typing.BinaryIO, b: int, to_sexp: ToCLVMStorage -) -> CLVMStorage: +def _atom_from_stream(f: typing.BinaryIO, b: int) -> bytes: if b == 0x80: - return to_sexp(b"") + return b"" if b <= MAX_SINGLE_BYTE: - return to_sexp(bytes([b])) + return bytes([b]) bit_count = 0 bit_mask = 0x80 while b & bit_mask: @@ 
-218,4 +331,4 @@ def _atom_from_stream( blob = f.read(size) if len(blob) != size: raise ValueError("bad encoding") - return to_sexp(blob) + return blob diff --git a/tests/generator.bin.gz b/tests/generator.bin.gz new file mode 100644 index 00000000..b120ff3e Binary files /dev/null and b/tests/generator.bin.gz differ diff --git a/tests/object_cache_test.py b/tests/object_cache_test.py new file mode 100644 index 00000000..8dbd373c --- /dev/null +++ b/tests/object_cache_test.py @@ -0,0 +1,41 @@ +import unittest + +from clvm.object_cache import ObjectCache, treehash, serialized_length + +from clvm_tools.binutils import assemble + + +class ObjectCacheTest(unittest.TestCase): + def check(self, obj_text: str, expected_hash: str, expected_length: int) -> None: + obj = assemble(obj_text) + th = ObjectCache(treehash) + self.assertEqual(th.get(obj).hex(), expected_hash) + sl = ObjectCache(serialized_length) + self.assertEqual(sl.get(obj), expected_length) + + def test_various(self) -> None: + self.check( + "0x00", + "47dc540c94ceb704a23875c11273e16bb0b8a87aed84de911f2133568115f254", + 1, + ) + + self.check( + "0", "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", 1 + ) + + self.check( + "foo", "0080b50a51ecd0ccfaaa4d49dba866fe58724f18445d30202bafb03e21eef6cb", 4 + ) + + self.check( + "(foo . bar)", + "c518e45ae6a7b4146017b7a1d81639051b132f1f5572ce3088a3898a9ed1280b", + 9, + ) + + self.check( + "(this is a longer test of a deeper tree)", + "0a072d7d860d77d8e290ced0fdb29a271198ca3db54d701c45d831e3aae6422c", + 47, + ) diff --git a/tests/read_cache_lookup_test.py b/tests/read_cache_lookup_test.py new file mode 100644 index 00000000..9cd0d1cd --- /dev/null +++ b/tests/read_cache_lookup_test.py @@ -0,0 +1,105 @@ +import unittest + +from clvm import to_sexp_f +from clvm.read_cache_lookup import ReadCacheLookup +from clvm.object_cache import ObjectCache, treehash + + +class ReadCacheLookupTest(unittest.TestCase): + def test_various(self) -> None: + rcl = ReadCacheLookup() + treehasher = ObjectCache(treehash) + + # rcl = () + nil = to_sexp_f(b"") + nil_hash = treehasher.get(nil) + self.assertEqual(rcl.root_hash, nil_hash) + + foo = to_sexp_f(b"foo") + foo_hash = treehasher.get(foo) + rcl.push(foo_hash) + + # rcl = (foo . 0) + + current_stack = to_sexp_f([foo]) + current_stack_hash = treehasher.get(current_stack) + + self.assertEqual(rcl.root_hash, current_stack_hash) + self.assertEqual(rcl.find_path(foo_hash, serialized_length=20), bytes([2])) + self.assertEqual(rcl.find_path(nil_hash, serialized_length=20), bytes([3])) + self.assertEqual( + rcl.find_path(current_stack_hash, serialized_length=20), bytes([1]) + ) + + bar = to_sexp_f(b"bar") + bar_hash = treehasher.get(bar) + rcl.push(bar_hash) + + # rcl = (bar foo) + + current_stack = to_sexp_f([bar, foo]) + current_stack_hash = treehasher.get(current_stack) + foo_list_hash = treehasher.get(to_sexp_f([b"foo"])) + self.assertEqual(rcl.root_hash, current_stack_hash) + self.assertEqual(rcl.find_path(bar_hash, serialized_length=20), bytes([2])) + self.assertEqual(rcl.find_path(foo_list_hash, serialized_length=20), bytes([3])) + self.assertEqual(rcl.find_path(foo_hash, serialized_length=20), bytes([5])) + self.assertEqual(rcl.find_path(nil_hash, serialized_length=20), bytes([7])) + self.assertEqual( + rcl.find_path(current_stack_hash, serialized_length=20), bytes([1]) + ) + self.assertEqual(rcl.count[foo_list_hash], 1) + + rcl.pop2_and_cons() + # rcl = ((foo . bar) . 
0) + + current_stack = to_sexp_f([(foo, bar)]) + current_stack_hash = treehasher.get(current_stack) + self.assertEqual(rcl.root_hash, current_stack_hash) + + # we no longer have `(foo . 0)` in the read stack + # check that its count is zero + self.assertEqual(rcl.count[foo_list_hash], 0) + + self.assertEqual(rcl.find_path(bar_hash, serialized_length=20), bytes([6])) + self.assertEqual(rcl.find_path(foo_list_hash, serialized_length=20), None) + self.assertEqual(rcl.find_path(foo_hash, serialized_length=20), bytes([4])) + self.assertEqual(rcl.find_path(nil_hash, serialized_length=20), bytes([3])) + self.assertEqual( + rcl.find_path(current_stack_hash, serialized_length=20), bytes([1]) + ) + + rcl.push(foo_hash) + rcl.push(foo_hash) + rcl.pop2_and_cons() + + # rcl = ((foo . foo) (foo . bar)) + + current_stack = to_sexp_f([(foo, foo), (foo, bar)]) + current_stack_hash = treehasher.get(current_stack) + self.assertEqual(rcl.root_hash, current_stack_hash) + self.assertEqual(rcl.find_path(bar_hash, serialized_length=20), bytes([13])) + self.assertEqual(rcl.find_path(foo_list_hash, serialized_length=20), None) + self.assertEqual(rcl.find_path(foo_hash, serialized_length=20), bytes([4])) + self.assertEqual(rcl.find_path(nil_hash, serialized_length=20), bytes([7])) + + # find BOTH minimal paths to `foo` + self.assertEqual( + rcl.find_paths(foo_hash, serialized_length=20), + set([bytes([4]), bytes([6])]), + ) + + rcl = ReadCacheLookup() + rcl.push(foo_hash) + rcl.push(foo_hash) + rcl.pop2_and_cons() + rcl.push(foo_hash) + rcl.push(foo_hash) + rcl.pop2_and_cons() + rcl.pop2_and_cons() + # rcl = ((foo . foo) . (foo . foo)) + # find ALL minimal paths to `foo` + self.assertEqual( + rcl.find_paths(foo_hash, serialized_length=20), + set([bytes([8]), bytes([10]), bytes([12]), bytes([14])]), + ) diff --git a/tests/serialize_test.py b/tests/serialize_test.py index b84deb1c..d40b77e8 100644 --- a/tests/serialize_test.py +++ b/tests/serialize_test.py @@ -1,10 +1,16 @@ +import gzip import io import unittest from typing import Optional from clvm import to_sexp_f -from clvm.SExp import CastableType -from clvm.serialize import (sexp_from_stream, sexp_buffer_from_stream, atom_to_byte_iterator) +from clvm.SExp import CastableType, SExp +from clvm.serialize import ( + _atom_from_stream, + sexp_from_stream, + sexp_buffer_from_stream, + atom_to_byte_iterator, +) TEXT = b"the quick brown fox jumps over the lazy dogs" @@ -26,8 +32,26 @@ def __len__(self) -> int: return 0x400000001 +def has_backrefs(blob: bytes) -> bool: + """ + Return `True` iff blob has a backref in it. 
+ """ + f = io.BytesIO(blob) + obj_count = 1 + while obj_count > 0: + b = f.read(1)[0] + if b == 0xfe: + return True + if b == 0xff: + obj_count += 1 + else: + _atom_from_stream(f, b) + obj_count -= 1 + return False + + class SerializeTest(unittest.TestCase): - def check_serde(self, s: CastableType) -> None: + def check_serde(self, s: CastableType) -> bytes: v = to_sexp_f(s) b = v.as_bin() v1 = sexp_from_stream(io.BytesIO(b), to_sexp_f) @@ -43,6 +67,23 @@ def check_serde(self, s: CastableType) -> None: buf = sexp_buffer_from_stream(io.BytesIO(b)) self.assertEqual(buf, b) + # now turn on backrefs and make sure everything still works + + b2 = v.as_bin(allow_backrefs=True) + self.assertTrue(len(b2) <= len(b)) + if has_backrefs(b2) or len(b2) < len(b): + # if we have any backrefs, ensure they actually save space + self.assertTrue(len(b2) < len(b)) + print("%d bytes before %d after %d saved" % (len(b), len(b2), len(b) - len(b2))) + io_b2 = io.BytesIO(b2) + self.assertRaises(ValueError, lambda: sexp_from_stream(io_b2, to_sexp_f)) + io_b2 = io.BytesIO(b2) + v2 = sexp_from_stream(io_b2, to_sexp_f, allow_backrefs=True) + self.assertEqual(v2, s) + b3 = v2.as_bin() + self.assertEqual(b, b3) + return b2 + def test_zero(self) -> None: v = to_sexp_f(b"\x00") self.assertEqual(v.as_bin(), b"\x00") @@ -145,3 +186,50 @@ def test_deserialize_large_blob(self) -> None: with self.assertRaises(ValueError): sexp_buffer_from_stream(InfiniteStream(bytes_in)) + + def test_deserialize_generator(self) -> None: + blob = gzip.GzipFile("tests/generator.bin.gz").read() + s = sexp_from_stream(io.BytesIO(blob), to_sexp_f) + b = self.check_serde(s) + assert len(b) == 19124 + + def test_deserialize_bomb(self) -> None: + def make_bomb(depth: int) -> SExp: + bomb = to_sexp_f(TEXT) + for _ in range(depth): + bomb = to_sexp_f((bomb, bomb)) + return bomb + + bomb_10 = make_bomb(10) + b10_1 = bomb_10.as_bin(allow_backrefs=False) + b10_2 = bomb_10.as_bin(allow_backrefs=True) + self.assertEqual(len(b10_1), 47103) + self.assertEqual(len(b10_2), 75) + + bomb_20 = make_bomb(20) + b20_1 = bomb_20.as_bin(allow_backrefs=False) + b20_2 = bomb_20.as_bin(allow_backrefs=True) + self.assertEqual(len(b20_1), 48234495) + self.assertEqual(len(b20_2), 105) + + bomb_30 = make_bomb(30) + # do not uncomment the next line unless you want to run out of memory + # b30_1 = bomb_30.as_bin(allow_backrefs=False) + b30_2 = bomb_30.as_bin(allow_backrefs=True) + + # self.assertEqual(len(b30_1), 1) + self.assertEqual(len(b30_2), 135) + + def test_specific_tree(self) -> None: + sexp1 = to_sexp_f((("AAA", "BBB"), ("CCC", "AAA"))) + serialized_sexp1_v1 = sexp1.as_bin(allow_backrefs=False) + serialized_sexp1_v2 = sexp1.as_bin(allow_backrefs=True) + self.assertEqual(len(serialized_sexp1_v1), 19) + self.assertEqual(len(serialized_sexp1_v2), 17) + deserialized_sexp1_v1 = sexp_from_stream( + io.BytesIO(serialized_sexp1_v1), to_sexp_f, allow_backrefs=False + ) + deserialized_sexp1_v2 = sexp_from_stream( + io.BytesIO(serialized_sexp1_v2), to_sexp_f, allow_backrefs=True + ) + self.assertTrue(deserialized_sexp1_v1 == deserialized_sexp1_v2)
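
A minimal usage sketch, assuming this diff is applied, of the `ObjectCache` helpers introduced in clvm/object_cache.py. The expected length and tree hash below are copied from tests/object_cache_test.py for the tree `(foo . bar)`.

# Sketch only: per-node values computed and cached over a small tree
# (expected values copied from tests/object_cache_test.py).
from clvm import to_sexp_f
from clvm.object_cache import ObjectCache, serialized_length, treehash

pair = to_sexp_f((b"foo", b"bar"))

# one cons marker (0xff) plus two atoms of one size byte + three content bytes each
assert ObjectCache(serialized_length).get(pair) == 9

# sha256 tree hash, computed bottom-up with child hashes taken from the cache
assert ObjectCache(treehash).get(pair).hex() == (
    "c518e45ae6a7b4146017b7a1d81639051b132f1f5572ce3088a3898a9ed1280b"
)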
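The bit packing behind the back-reference paths returned by `ReadCacheLookup.find_path` is documented on `reversed_path_to_bytes` in clvm/read_cache_lookup.py; a few spot checks, copied from that docstring, are shown below as a sketch.

# Sketch only: 0 = left, 1 = right; the list is reversed, a 1 bit is
# prepended, and the result is packed into bytes (a standard clvm path).
from clvm.read_cache_lookup import reversed_path_to_bytes

assert reversed_path_to_bytes([]) == bytes([0b1])
assert reversed_path_to_bytes([0, 1]) == bytes([0b101])
assert reversed_path_to_bytes([1, 1, 1, 1, 0, 0, 0, 0, 1]) == bytes([0b11, 0b11100001])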
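Finally, a hedged end-to-end sketch of the new `allow_backrefs` flag on `SExp.as_bin` and `sexp_from_stream`. The tree and byte counts mirror `test_specific_tree` in tests/serialize_test.py: the repeated "AAA" atom is encoded as a 0xfe back-reference, and the compact blob must be parsed with `allow_backrefs=True` (the classic parser rejects the 0xfe marker with a ValueError).

# Sketch only: round-trip a tree with a repeated subtree through the
# back-reference-aware serializer (sizes from tests/serialize_test.py).
import io

from clvm import to_sexp_f
from clvm.serialize import sexp_from_stream

sexp = to_sexp_f((("AAA", "BBB"), ("CCC", "AAA")))

plain = sexp.as_bin()                        # classic encoding: 19 bytes
compact = sexp.as_bin(allow_backrefs=True)   # with a 0xfe back-reference: 17 bytes
assert len(compact) < len(plain)

# the compact blob needs the back-reference-aware reader
roundtrip = sexp_from_stream(io.BytesIO(compact), to_sexp_f, allow_backrefs=True)
assert roundtrip == sexp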