diff --git a/core/state.go b/core/state.go index 378ba65bec..27c20f0572 100644 --- a/core/state.go +++ b/core/state.go @@ -139,10 +139,11 @@ func (s *State) globalTrie(bucket db.Bucket, newTrie trie.NewTrieFunc) (*trie.Tr // fetch root key rootKeyDBKey := dbPrefix - var rootKey *trie.Key + var rootKey *trie.BitArray // TODO: use value instead of pointer err := s.txn.Get(rootKeyDBKey, func(val []byte) error { - rootKey = new(trie.Key) - return rootKey.UnmarshalBinary(val) + rootKey = new(trie.BitArray) + rootKey.UnmarshalBinary(val) + return nil }) // if some error other than "not found" @@ -169,7 +170,7 @@ func (s *State) globalTrie(bucket db.Bucket, newTrie trie.NewTrieFunc) (*trie.Tr if resultingRootKey != nil { var rootKeyBytes bytes.Buffer - _, marshalErr := resultingRootKey.WriteTo(&rootKeyBytes) + _, marshalErr := resultingRootKey.Write(&rootKeyBytes) if marshalErr != nil { return marshalErr } diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go new file mode 100644 index 0000000000..75d1ebb3ee --- /dev/null +++ b/core/trie/bitarray.go @@ -0,0 +1,655 @@ +package trie + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "math" + "math/bits" + + "github.com/NethermindEth/juno/core/felt" +) + +const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) +) + +var emptyBitArray = new(BitArray) + +// Represents a bit array with length representing the number of used bits. +// It uses a little endian representation to do bitwise operations of the words efficiently. +// For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. +// The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. +// Although words can be used to represent 256 bits, we don't want to add an additional byte for the length. +type BitArray struct { + len uint8 // number of used bits + words [4]uint64 // little endian (i.e. words[0] is the least significant) +} + +func NewBitArray(length uint8, val uint64) BitArray { + var b BitArray + b.SetUint64(length, val) + return b +} + +// Returns the felt representation of the bit array. +func (b *BitArray) Felt() felt.Felt { + var f felt.Felt + f.SetBytes(b.Bytes()) + return f +} + +func (b *BitArray) Len() uint8 { + return b.len +} + +// Returns the bytes representation of the bit array in big endian format +func (b *BitArray) Bytes() []byte { + var res [32]byte + + b.truncateToLength() + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + + return res[:] +} + +// Sets the bit array to the least significant 'n' bits of x. +// n is counted from the least significant bit, starting at 0. +// If length >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// LSBsFromLSB(x, 4) = 1011 (len=4) +// LSBsFromLSB(x, 10) = 11001011 (len=8, original x) +// LSBsFromLSB(x, 0) = 0 (len=0) +// +//nolint:mnd +func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + b.Set(x) + b.len = n + + // Clear all words beyond what's needed + switch { + case n == 0: + b.words = [4]uint64{0, 0, 0, 0} + case n <= 64: + mask := maxUint64 >> (64 - n) + b.words[0] &= mask + b.words[1] = 0 + b.words[2] = 0 + b.words[3] = 0 + case n <= 128: + mask := maxUint64 >> (128 - n) + b.words[1] &= mask + b.words[2] = 0 + b.words[3] = 0 + case n <= 192: + mask := maxUint64 >> (192 - n) + b.words[2] &= mask + b.words[3] = 0 + default: + mask := maxUint64 >> (256 - uint16(n)) + b.words[3] &= mask + } + + return b +} + +// Returns the least significant bits of `x` with `n` counted from the most significant bit, starting at 0. +// For example: +// +// x = 11001011 (len=8) +// LSBs(x, 1) = 1001011 (len=7) +// LSBs(x, 10) = 0 (len=0) +// LSBs(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n == 0 { + return b.Set(x) + } + + if n > x.Len() { + return b.clear() + } + + return b.LSBsFromLSB(x, x.Len()-n) +} + +// Checks if the current bit array share the same most significant bits with another, where the length of +// the check is determined by the shorter array. Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *BitArray) EqualMSBs(x *BitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + if b.len == 0 || x.len == 0 { + return true + } + + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len + } + + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) +} + +// Sets the bit array to the most significant 'n' bits of x, that is position 0 to n (exclusive). +// If n >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// MSBs(x, 4) = 1100 (len=4) +// MSBs(x, 10) = 11001011 (len=8, original x) +// MSBs(x, 0) = 0 (len=0) +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.Rsh(x, x.len-n) +} + +// Sets the bit array to the longest sequence of matching most significant bits between two bit arrays. +// For example: +// +// x = 1101 0111 (len=8) +// y = 1101 0000 (len=8) +// CommonMSBs(x,y) = 1101 (len=4) +func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { + if x.len == 0 || y.len == 0 { + return b.clear() + } + + long, short := x, y + if x.len < y.len { + long, short = y, x + } + + // Align arrays by right-shifting longer array and then XOR to find differences + // Example: + // short = 1100 (len=4) + // long = 1101 0111 (len=8) + // + // Step 1: Right shift longer array by 4 + // short = 1100 + // long = 1101 + // + // Step 2: XOR shows difference at last bit + // 1100 (short) + // 1101 (aligned long) + // ---- XOR + // 0001 (difference at last position) + // We can then use the position of the first set bit and right-shift to get the common MSBs + diff := long.len - short.len + b.Rsh(long, diff).Xor(b, short) + divergentBit := findFirstSetBit(b) + + return b.Rsh(short, divergentBit) +} + +// Sets the bit array to x >> n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 { + return b.Set(x) + } + + if n >= x.len { + return b.clear() + } + + switch { + case n == 0: + return b.Set(x) + case n >= 192: + b.rsh192(x) + b.len = x.len - n + n -= 192 + b.words[0] >>= n + case n >= 128: + b.rsh128(x) + b.len = x.len - n + n -= 128 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] >>= n + case n >= 64: + b.rsh64(x) + b.len = x.len - n + n -= 64 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] >>= n + default: + b.Set(x) + b.len -= n + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n + } + + b.truncateToLength() + return b +} + +// Lsh sets the bit array to x << n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { + b.Set(x) + + if x.len == 0 || n == 0 { + return b + } + + // If the result will overflow, we set the length to the max length + // but we still shift `n` bits + if n > maxUint8-x.len { + b.len = maxUint8 + } else { + b.len = x.len + n + } + + switch { + case n == 0: + return b + case n >= 192: + b.lsh192(x) + n -= 192 + b.words[3] <<= n + case n >= 128: + b.lsh128(x) + n -= 128 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] <<= n + case n >= 64: + b.lsh64(x) + n -= 64 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] <<= n + default: + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) + b.words[0] <<= n + } + + b.truncateToLength() + return b +} + +// Sets the bit array to the concatenation of x and y and returns the bit array. +// For example: +// +// x = 000 (len=3) +// y = 111 (len=3) +// Append(x,y) = 000111 (len=6) +func (b *BitArray) Append(x, y *BitArray) *BitArray { + if x.len == 0 { + return b.Set(y) + } + if y.len == 0 { + return b.Set(x) + } + + // First copy x + b.Set(x) + + // Then shift left by y's length and OR with y + return b.Lsh(b, y.len).Or(b, y) +} + +// Sets the bit array to x | y and returns the bit array. +func (b *BitArray) Or(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] | y.words[0] + b.words[1] = x.words[1] | y.words[1] + b.words[2] = x.words[2] | y.words[2] + b.words[3] = x.words[3] | y.words[3] + b.len = x.len + return b +} + +// Sets the bit array to x ^ y and returns the bit array. +func (b *BitArray) Xor(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] ^ y.words[0] + b.words[1] = x.words[1] ^ y.words[1] + b.words[2] = x.words[2] ^ y.words[2] + b.words[3] = x.words[3] ^ y.words[3] + return b +} + +// Checks if two bit arrays are equal +func (b *BitArray) Equal(x *BitArray) bool { + // TODO(weiihann): this is really not a good thing to do... + if b == nil && x == nil { + return true + } else if b == nil || x == nil { + return false + } + + return b.len == x.len && + b.words[0] == x.words[0] && + b.words[1] == x.words[1] && + b.words[2] == x.words[2] && + b.words[3] == x.words[3] +} + +// Returns true if bit n-th is set, where n = 0 is LSB. +func (b *BitArray) IsBitSetFromLSB(n uint8) bool { + return b.BitFromLSB(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is LSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitFromLSB(n uint8) uint8 { + if n >= b.len { + return 0 + } + + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 + } + + return 0 +} + +func (b *BitArray) IsBitSet(n uint8) bool { + return b.Bit(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is MSB. +// If n is out of bounds, returns 0. +func (b *BitArray) Bit(n uint8) uint8 { + if n >= b.Len() { + return 0 + } + + return b.BitFromLSB(b.Len() - n - 1) +} + +// Returns the bit value at the most significant bit +func (b *BitArray) MSB() uint8 { + return b.Bit(0) +} + +func (b *BitArray) LSB() uint8 { + return b.BitFromLSB(0) +} + +func (b *BitArray) IsEmpty() bool { + return b.len == 0 +} + +// Serialises the BitArray into a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// BitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] +func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { + if err := buf.WriteByte(b.len); err != nil { + return 0, err + } + + n, err := buf.Write(b.activeBytes()) + return n + 1, err +} + +// Deserialises the BitArray from a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} +func (b *BitArray) UnmarshalBinary(data []byte) { + b.len = data[0] + + var bs [32]byte + copy(bs[32-b.byteCount():], data[1:]) + b.setBytes32(bs[:]) +} + +// Sets the bit array to the same value as x. +func (b *BitArray) Set(x *BitArray) *BitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b +} + +// Sets the bit array to the bytes representation of a felt. +func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { + b.len = length + b.setFelt(f) + b.truncateToLength() + return b +} + +// Sets the bit array to the bytes representation of a felt with length 251. +func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { + b.len = 251 + b.setFelt(f) + b.truncateToLength() + return b +} + +// Interprets the data as the big-endian bytes, sets the bit array to that value and returns it. +// If the data is larger than 32 bytes, only the first 32 bytes are used. +func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { + b.setBytes32(data) + b.len = length + b.truncateToLength() + return b +} + +// Sets the bit array to the uint64 representation of a bit array. +func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { + b.words[0] = data + b.len = length + b.truncateToLength() + return b +} + +// Sets the bit array to a single bit. +func (b *BitArray) SetBit(bit uint8) *BitArray { + b.len = 1 + b.words[0] = uint64(bit & 1) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + return b +} + +// Returns the length of the encoded bit array in bytes. +func (b *BitArray) EncodedLen() uint { + return b.byteCount() + 1 +} + +// Returns a deep copy of the bit array. +func (b *BitArray) Copy() BitArray { + var res BitArray + res.Set(b) + return res +} + +// Returns the encoded string representation of the bit array. +func (b *BitArray) EncodedString() string { + var res []byte + res = append(res, b.len) + res = append(res, b.Bytes()...) + return string(res) +} + +// Returns a string representation of the bit array. +// This is typically used for logging or debugging. +func (b *BitArray) String() string { + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) +} + +func (b *BitArray) setFelt(f *felt.Felt) { + res := f.Bytes() + b.words[3] = binary.BigEndian.Uint64(res[0:8]) + b.words[2] = binary.BigEndian.Uint64(res[8:16]) + b.words[1] = binary.BigEndian.Uint64(res[16:24]) + b.words[0] = binary.BigEndian.Uint64(res[24:32]) +} + +func (b *BitArray) setBytes32(data []byte) { + _ = data[31] + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) +} + +// Returns the minimum number of bytes needed to represent the bit array. +// It rounds up to the nearest byte. +func (b *BitArray) byteCount() uint { + const bits8 = 8 + // Cast to uint16 to avoid overflow + return (uint(b.len) + (bits8 - 1)) / uint(bits8) +} + +// Returns a slice containing only the bytes that are actually used by the bit array, +// as specified by the length. The returned slice is in big-endian order. +// +// Example: +// +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] +func (b *BitArray) activeBytes() []byte { + wordsBytes := b.Bytes() + return wordsBytes[32-b.byteCount():] +} + +func (b *BitArray) rsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] +} + +func (b *BitArray) rsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] +} + +func (b *BitArray) rsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] +} + +func (b *BitArray) lsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[2], x.words[1], x.words[0], 0 +} + +func (b *BitArray) lsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[1], x.words[0], 0, 0 +} + +func (b *BitArray) lsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[0], 0, 0, 0 +} + +func (b *BitArray) clear() *BitArray { + b.len = 0 + b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 + return b +} + +// Truncates the bit array to the specified length, ensuring that any unused bits are all zeros. +// +//nolint:mnd +func (b *BitArray) truncateToLength() { + switch { + case b.len == 0: + b.words = [4]uint64{0, 0, 0, 0} + case b.len <= 64: + b.words[0] &= maxUint64 >> (64 - b.len) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case b.len <= 128: + b.words[1] &= maxUint64 >> (128 - b.len) + b.words[2], b.words[3] = 0, 0 + case b.len <= 192: + b.words[2] &= maxUint64 >> (192 - b.len) + b.words[3] = 0 + default: + b.words[3] &= maxUint64 >> (256 - uint16(b.len)) + } +} + +// Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. +// The bit position is counted from the least significant bit, starting at 0. +// For example: +// +// array = 0000 0000 ... 0100 (len=251) +// findFirstSetBit() = 3 // third bit from right is set +func findFirstSetBit(b *BitArray) uint8 { + if b.len == 0 { + return 0 + } + + // Start from the most significant and move towards the least significant + for i := 3; i >= 0; i-- { + if word := b.words[i]; word != 0 { + return uint8((i+1)*64 - bits.LeadingZeros64(word)) + } + } + + // All bits are zero, no set bit found + return 0 +} + +// Cmp compares two bit arrays lexicographically. +// The comparison is first done by length, then by content if lengths are equal. +// Returns: +// +// -1 if b < x +// 0 if b == x +// 1 if b > x +func (b *BitArray) Cmp(x *BitArray) int { + // First compare lengths + if b.len < x.len { + return -1 + } + if b.len > x.len { + return 1 + } + + // Lengths are equal, compare the actual bits + d0, carry := bits.Sub64(b.words[0], x.words[0], 0) + d1, carry := bits.Sub64(b.words[1], x.words[1], carry) + d2, carry := bits.Sub64(b.words[2], x.words[2], carry) + d3, carry := bits.Sub64(b.words[3], x.words[3], carry) + + if carry == 1 { + return -1 + } + + if d0|d1|d2|d3 == 0 { + return 0 + } + + return 1 +} diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go new file mode 100644 index 0000000000..e3d7c795a0 --- /dev/null +++ b/core/trie/bitarray_test.go @@ -0,0 +1,1805 @@ +package trie + +import ( + "bytes" + "encoding/binary" + "math" + "math/bits" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + +const ( + ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 +) + +func TestBytes(t *testing.T) { + tests := []struct { + name string + ba BitArray + want [32]byte + }{ + { + name: "length == 0", + ba: BitArray{len: 0, words: maxBits}, + want: [32]byte{}, + }, + { + name: "length < 64", + ba: BitArray{len: 38, words: maxBits}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) + return b + }(), + }, + { + name: "64 <= length < 128", + ba: BitArray{len: 100, words: maxBits}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "128 <= length < 192", + ba: BitArray{len: 130, words: maxBits}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[8:16], 0x3) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "192 <= length < 255", + ba: BitArray{len: 201, words: maxBits}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 254", + ba: BitArray{len: 254, words: maxBits}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 255", + ba: BitArray{len: 255, words: maxBits}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], ones63) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.Bytes() + if !bytes.Equal(got, tt.want[:]) { + t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) + } + + // check if the received bytes has the same bit count as the BitArray.len + count := 0 + for _, b := range got { + count += bits.OnesCount8(b) + } + if count != int(tt.ba.len) { + t.Errorf("BitArray.Bytes() bit count = %v, want %v", count, tt.ba.len) + } + }) + } +} + +func TestRsh(t *testing.T) { + tests := []struct { + name string + initial *BitArray + shiftBy uint8 + expected *BitArray + }{ + { + name: "zero length array", + initial: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + shiftBy: 5, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by 0", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + shiftBy: 0, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by more than length", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 65, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by less than 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 32, + expected: &BitArray{ + len: 96, + words: [4]uint64{maxUint64, 0x00000000FFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by exactly 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 64, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by 127", + initial: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + shiftBy: 127, + expected: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "shift by 128", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 128, + expected: &BitArray{ + len: 123, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by 192", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 192, + expected: &BitArray{ + len: 59, + words: [4]uint64{0x7FFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).Rsh(tt.initial, tt.shiftBy) + if !result.Equal(tt.expected) { + t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestLsh(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 5, + want: emptyBitArray, + }, + { + name: "shift by 0", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 4, + want: &BitArray{ + len: 8, + words: [4]uint64{0xF0, 0, 0, 0}, // 11110000 + }, + }, + { + name: "shift across word boundary", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 62, + want: &BitArray{ + len: 66, + words: [4]uint64{0xC000000000000000, 0x3, 0, 0}, + }, + }, + { + name: "shift by 64 (full word)", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 64, + want: &BitArray{ + len: 72, + words: [4]uint64{0, 0xFF, 0, 0}, + }, + }, + { + name: "shift by 128", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 128, + want: &BitArray{ + len: 136, + words: [4]uint64{0, 0, 0xFF, 0}, + }, + }, + { + name: "shift by 192", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 192, + want: &BitArray{ + len: 200, + words: [4]uint64{0, 0, 0, 0xFF}, + }, + }, + { + name: "shift causing length overflow", + x: &BitArray{ + len: 200, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{ + 0xF000000000000000, + 0xF, + 0, + 0, + }, + }, + }, + { + name: "shift sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + n: 4, + want: &BitArray{ + len: 12, + words: [4]uint64{0xAA0, 0, 0, 0}, // 101010100000 + }, + }, + { + name: "shift partial word across boundary", + x: &BitArray{ + len: 100, + words: [4]uint64{0xFF, 0xFF, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 160, + words: [4]uint64{ + 0xF000000000000000, + 0xF00000000000000F, + 0xF, + 0, + }, + }, + }, + { + name: "near maximum length shift", + x: &BitArray{ + len: 251, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 4, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{0xFF0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Lsh(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("Lsh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAppend(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "both empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "first array empty", + x: emptyBitArray, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "second array empty", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: emptyBitArray, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + }, + { + name: "different lengths within word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 2, + words: [4]uint64{0x3, 0, 0, 0}, // 11 + }, + want: &BitArray{ + len: 6, + words: [4]uint64{0x3F, 0, 0, 0}, // 111111 + }, + }, + { + name: "across word boundary", + x: &BitArray{ + len: 62, + words: [4]uint64{0x3FFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 66, + words: [4]uint64{maxUint64, 0x3, 0, 0}, + }, + }, + { + name: "across multiple words", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + y: &BitArray{ + len: 8, + words: [4]uint64{0x55, 0, 0, 0}, // 01010101 + }, + want: &BitArray{ + len: 16, + words: [4]uint64{0xAA55, 0, 0, 0}, // 1010101001010101 + }, + }, + { + name: "result exactly at length limit", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + want: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Append(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("Append() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEqualMSBs(t *testing.T) { + tests := []struct { + name string + a *BitArray + b *BitArray + want bool + }{ + { + name: "equal lengths, equal values", + a: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + b: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: true, + }, + { + name: "equal lengths, different values", + a: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + b: &BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "different lengths, a longer but same prefix", + a: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + b: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, b longer but same prefix", + a: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + b: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, different prefix", + a: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + b: &BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "zero length arrays", + a: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + b: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "one zero length array", + a: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + b: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "max length difference", + a: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + b: &BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.a.EqualMSBs(tt.b); got != tt.want { + t.Errorf("PrefixEqual() = %v, want %v", got, tt.want) + } + // Test symmetry: a.PrefixEqual(b) should equal b.PrefixEqual(a) + if got := tt.b.EqualMSBs(tt.a); got != tt.want { + t.Errorf("PrefixEqual() symmetric test = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).LSBs(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBsFromLSB(t *testing.T) { + tests := []struct { + name string + initial BitArray + length uint8 + expected BitArray + }{ + { + name: "zero", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 0, + expected: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get 32 LSBs", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 32, + expected: BitArray{ + len: 32, + words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get 1 LSB", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 1, + expected: BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + { + name: "get 100 LSBs across words", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{maxUint64, 0x0000000FFFFFFFFF, 0, 0}, + }, + }, + { + name: "get 64 LSBs at word boundary", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get 128 LSBs at word boundary", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 128, + expected: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "get 150 LSBs in third word", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 150, + expected: BitArray{ + len: 150, + words: [4]uint64{maxUint64, maxUint64, 0x3FFFFF, 0}, + }, + }, + { + name: "get 220 LSBs in fourth word", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 220, + expected: BitArray{ + len: 220, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0xFFFFFFF}, + }, + }, + { + name: "get 251 LSBs", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 251, + expected: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + }, + { + name: "get 100 LSBs from sparse bits", + initial: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, + }, + }, + { + name: "no change when new length equals current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "no change when new length greater than current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 128, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).LSBsFromLSB(&tt.initial, tt.length) + if !result.Equal(&tt.expected) { + t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestMSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 0, + want: emptyBitArray, + }, + { + name: "get all bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get more bits than available", + x: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get half of available bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 32, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF00000000 >> 32, 0, 0, 0}, + }, + }, + { + name: "get MSBs across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + n: 100, + want: &BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64 >> 28, 0, 0}, + }, + }, + { + name: "get MSBs from max length array", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get zero bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{0x5555555555555555, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).MSBs(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + + if got.len != tt.want.len { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteAndUnmarshalBinary(t *testing.T) { + tests := []struct { + name string + ba BitArray + want []byte // Expected bytes after writing + }{ + { + name: "empty bit array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: []byte{0}, // Just the length byte + }, + { + name: "8 bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + want: []byte{8, 0xFF}, // length byte + 1 data byte + }, + { + name: "10 bits requiring 2 bytes", + ba: BitArray{ + len: 10, + words: [4]uint64{0x3FF, 0, 0, 0}, // 1111111111 in binary + }, + want: []byte{10, 0x3, 0xFF}, // length byte + 2 data bytes + }, + { + name: "64 bits", + ba: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: append( + []byte{64}, // length byte + []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}..., // 8 data bytes + ), + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + want: func() []byte { + b := make([]byte, 33) // 1 length byte + 32 data bytes + b[0] = 251 // length byte + // First byte is 0x07 (from the most significant bits) + b[1] = 0x07 + // Rest of the bytes are 0xFF + for i := 2; i < 33; i++ { + b[i] = 0xFF + } + return b + }(), + }, + { + name: "sparse bits", + ba: BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 in binary + }, + want: []byte{16, 0xAA, 0xAA}, // length byte + 2 data bytes + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := new(bytes.Buffer) + gotN, err := tt.ba.Write(buf) + assert.NoError(t, err) + + // Check number of bytes written + if gotN != len(tt.want) { + t.Errorf("Write() wrote %d bytes, want %d", gotN, len(tt.want)) + } + + // Check written bytes + got := buf.Bytes() + if !bytes.Equal(got, tt.want) { + t.Errorf("Write() = %v, want %v", got, tt.want) + } + + var gotBitArray BitArray + gotBitArray.UnmarshalBinary(got) + if !gotBitArray.Equal(&tt.ba) { + t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.ba) + } + }) + } +} + +func TestCommonPrefix(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "one empty array", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "identical arrays - single word", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "identical arrays - multiple words", + x: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + y: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "different lengths with common prefix - first word", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different lengths with common prefix - multiple words", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + y: &BitArray{ + len: 127, + words: [4]uint64{maxUint64, ones63, 0, 0}, + }, + want: &BitArray{ + len: 127, + words: [4]uint64{maxUint64, ones63, 0, 0}, + }, + }, + { + name: "different at first bit", + x: &BitArray{ + len: 64, + words: [4]uint64{ones63, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "different in middle of first word", + x: &BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFF0FFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in second word", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, 0xFFFFFFFF0FFFFFFF, 0, 0}, + }, + y: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in third word", + x: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + y: &BitArray{ + len: 192, + words: [4]uint64{0, 0, 0xFFFFFFFFFFFFFF0F, 0}, + }, + want: &BitArray{ + len: 56, + words: [4]uint64{0xFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in last word", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + y: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFF0FFFFFFF}, + }, + want: &BitArray{ + len: 27, + words: [4]uint64{0x7FFFFFF}, + }, + }, + { + name: "sparse bits with common prefix", + x: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}, + }, + y: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}, + }, + want: &BitArray{ + len: 52, + words: [4]uint64{0xAAAAAAAAAAAAA, 0, 0, 0}, + }, + }, + { + name: "max length difference", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + y: &BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: &BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray) + gotSymmetric := new(BitArray) + + got.CommonMSBs(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("CommonMSBs() = %v, want %v", got, tt.want) + } + + // Test symmetry: x.CommonMSBs(y) should equal y.CommonMSBs(x) + gotSymmetric.CommonMSBs(tt.y, tt.x) + if !gotSymmetric.Equal(tt.want) { + t.Errorf("CommonMSBs() symmetric test = %v, want %v", gotSymmetric, tt.want) + } + }) + } +} + +func TestIsBitSetFromLSB(t *testing.T) { + tests := []struct { + name string + ba BitArray + pos uint8 + want bool + }{ + { + name: "empty array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + pos: 0, + want: false, + }, + { + name: "first bit set", + ba: BitArray{ + len: 64, + words: [4]uint64{1, 0, 0, 0}, + }, + pos: 0, + want: true, + }, + { + name: "last bit in first word", + ba: BitArray{ + len: 64, + words: [4]uint64{1 << 63, 0, 0, 0}, + }, + pos: 63, + want: true, + }, + { + name: "first bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 64, + want: true, + }, + { + name: "bit beyond length", + ba: BitArray{ + len: 64, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 65, + want: false, + }, + { + name: "alternating bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary + }, + pos: 1, + want: true, + }, + { + name: "alternating bits - unset position", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary + }, + pos: 0, + want: false, + }, + { + name: "bit in last word", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 59}, + }, + pos: 251, + want: false, // position 251 is beyond the highest valid bit (250) + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 58}, + }, + pos: 250, + want: true, + }, + { + name: "highest valid bit (255)", + ba: BitArray{ + len: 255, + words: [4]uint64{0, 0, 0, 1 << 62}, // bit 255 set + }, + pos: 254, + want: true, + }, + { + name: "position at length boundary", + ba: BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 100, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSetFromLSB(tt.pos) + if got != tt.want { + t.Errorf("IsBitSetFromLSB(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} + +func TestIsBitSet(t *testing.T) { + tests := []struct { + name string + ba BitArray + pos uint8 + want bool + }{ + { + name: "empty array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + pos: 0, + want: false, + }, + { + name: "first bit (MSB) set", + ba: BitArray{ + len: 8, + words: [4]uint64{0x80, 0, 0, 0}, // 10000000 + }, + pos: 0, + want: true, + }, + { + name: "last bit (LSB) set", + ba: BitArray{ + len: 8, + words: [4]uint64{0x01, 0, 0, 0}, // 00000001 + }, + pos: 7, + want: true, + }, + { + name: "alternating bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + pos: 0, + want: true, + }, + { + name: "alternating bits - unset position", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + pos: 1, + want: false, + }, + { + name: "position beyond length", + ba: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + pos: 8, + want: false, + }, + { + name: "bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 63, + want: true, + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 58}, + }, + pos: 0, + want: true, + }, + { + name: "position at length boundary", + ba: BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 99, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSet(tt.pos) + if got != tt.want { + t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} + +func TestFeltConversion(t *testing.T) { + tests := []struct { + name string + ba BitArray + length uint8 + want string // hex representation of felt + }{ + { + name: "empty bit array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + length: 0, + want: "0x0", + }, + { + name: "single word", + ba: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 64, + want: "0xffffffffffffffff", + }, + { + name: "two words", + ba: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 128, + want: "0xffffffffffffffffffffffffffffffff", + }, + { + name: "three words", + ba: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 192, + want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + length: 251, + want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "sparse bits", + ba: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 128, + want: "0x5555555555555555aaaaaaaaaaaaaaaa", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test Felt() conversion + gotFelt := tt.ba.Felt() + assert.Equal(t, tt.want, gotFelt.String()) + + // Test SetFelt() conversion (round trip) + var newBA BitArray + newBA.SetFelt(tt.length, &gotFelt) + assert.Equal(t, tt.ba.len, newBA.len) + assert.Equal(t, tt.ba.words, newBA.words) + }) + } +} + +func TestSetFeltValidation(t *testing.T) { + tests := []struct { + name string + feltStr string + length uint8 + shouldMatch bool + }{ + { + name: "valid felt with matching length", + feltStr: "0xf", + length: 4, + shouldMatch: true, + }, + { + name: "felt larger than specified length", + feltStr: "0xff", + length: 4, + shouldMatch: false, + }, + { + name: "zero felt with non-zero length", + feltStr: "0x0", + length: 8, + shouldMatch: true, + }, + { + name: "max felt with max length", + feltStr: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + length: 251, + shouldMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var f felt.Felt + _, err := f.SetString(tt.feltStr) + require.NoError(t, err) + + var ba BitArray + ba.SetFelt(tt.length, &f) + + // Convert back to felt and compare + roundTrip := ba.Felt() + if tt.shouldMatch { + assert.True(t, roundTrip.Equal(&f), + "expected %s, got %s", f.String(), roundTrip.String()) + } else { + assert.False(t, roundTrip.Equal(&f), + "values should not match: original %s, roundtrip %s", + f.String(), roundTrip.String()) + } + }) + } +} + +func TestSetBit(t *testing.T) { + tests := []struct { + name string + bit uint8 + want BitArray + }{ + { + name: "set bit 0", + bit: 0, + want: BitArray{ + len: 1, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "set bit 1", + bit: 1, + want: BitArray{ + len: 1, + words: [4]uint64{1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBit(tt.bit) + if !got.Equal(&tt.want) { + t.Errorf("SetBit(%v) = %v, want %v", tt.bit, got, tt.want) + } + }) + } +} + +func TestCmp(t *testing.T) { + tests := []struct { + name string + x BitArray + y BitArray + want int + }{ + { + name: "equal empty arrays", + x: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + y: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + want: 0, + }, + { + name: "equal non-empty arrays", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: 0, + }, + { + name: "different lengths - x shorter", + x: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: -1, + }, + { + name: "different lengths - x longer", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: 1, + }, + { + name: "same length, x < y in first word", + x: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: -1, + }, + { + name: "same length, x > y in first word", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}}, + want: 1, + }, + { + name: "same length, difference in last word", + x: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFFF}}, + y: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFF0}}, + want: 1, + }, + { + name: "same length, sparse bits", + x: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}}, + y: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}}, + want: 1, + }, + { + name: "max length difference", + x: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}}, + y: BitArray{len: 1, words: [4]uint64{0x1, 0, 0, 0}}, + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.x.Cmp(&tt.y) + if got != tt.want { + t.Errorf("Cmp() = %v, want %v", got, tt.want) + } + + // Test anti-symmetry: if x.Cmp(y) = z then y.Cmp(x) = -z + gotReverse := tt.y.Cmp(&tt.x) + if gotReverse != -tt.want { + t.Errorf("Reverse Cmp() = %v, want %v", gotReverse, -tt.want) + } + + // Test transitivity with self: x.Cmp(x) should always be 0 + if tt.x.Cmp(&tt.x) != 0 { + t.Error("Self Cmp() != 0") + } + }) + } +} diff --git a/core/trie/key.go b/core/trie/key.go deleted file mode 100644 index 0d0ca7aa88..0000000000 --- a/core/trie/key.go +++ /dev/null @@ -1,187 +0,0 @@ -package trie - -import ( - "bytes" - "encoding/hex" - "errors" - "fmt" - "math/big" - - "github.com/NethermindEth/juno/core/felt" -) - -var NilKey = &Key{len: 0, bitset: [32]byte{}} - -type Key struct { - len uint8 - bitset [32]byte -} - -func NewKey(length uint8, keyBytes []byte) Key { - k := Key{len: length} - if len(keyBytes) > len(k.bitset) { - panic("bytes does not fit in bitset") - } - copy(k.bitset[len(k.bitset)-len(keyBytes):], keyBytes) - return k -} - -func (k *Key) bytesNeeded() uint { - const byteBits = 8 - return (uint(k.len) + (byteBits - 1)) / byteBits -} - -func (k *Key) inUseBytes() []byte { - return k.bitset[len(k.bitset)-int(k.bytesNeeded()):] -} - -func (k *Key) unusedBytes() []byte { - return k.bitset[:len(k.bitset)-int(k.bytesNeeded())] -} - -func (k *Key) WriteTo(buf *bytes.Buffer) (int64, error) { - if err := buf.WriteByte(k.len); err != nil { - return 0, err - } - - n, err := buf.Write(k.inUseBytes()) - return int64(1 + n), err -} - -func (k *Key) UnmarshalBinary(data []byte) error { - k.len = data[0] - k.bitset = [32]byte{} - copy(k.inUseBytes(), data[1:1+k.bytesNeeded()]) - return nil -} - -func (k *Key) EncodedLen() uint { - return k.bytesNeeded() + 1 -} - -func (k *Key) Len() uint8 { - return k.len -} - -func (k *Key) Felt() felt.Felt { - var f felt.Felt - f.SetBytes(k.bitset[:]) - return f -} - -func (k *Key) Equal(other *Key) bool { - if k == nil && other == nil { - return true - } else if k == nil || other == nil { - return false - } - return k.len == other.len && k.bitset == other.bitset -} - -// IsBitSet returns whether the bit at the given position is 1. -// Position 0 represents the least significant (rightmost) bit. -func (k *Key) IsBitSet(position uint8) bool { - const LSB = uint8(0x1) - byteIdx := position / 8 - byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1] - bitIdx := position % 8 - return ((byteAtIdx >> bitIdx) & LSB) != 0 -} - -// shiftRight removes n least significant bits from the key by performing a right shift -// operation and reducing the key length. For example, if the key contains bits -// "1111 0000" (length=8) and n=4, the result will be "1111" (length=4). -// -// The operation is destructive - it modifies the key in place. -func (k *Key) shiftRight(n uint8) { - if k.len < n { - panic("deleting more bits than there are") - } - - if n == 0 { - return - } - - var bigInt big.Int - bigInt.SetBytes(k.bitset[:]) - bigInt.Rsh(&bigInt, uint(n)) - bigInt.FillBytes(k.bitset[:]) - k.len -= n -} - -// MostSignificantBits returns a new key with the most significant n bits of the current key. -func (k *Key) MostSignificantBits(n uint8) (*Key, error) { - if n > k.len { - return nil, errors.New("cannot get more bits than the key length") - } - - keyCopy := k.Copy() - keyCopy.shiftRight(k.len - n) - return &keyCopy, nil -} - -// Truncate truncates key to `length` bits by clearing the remaining upper bits -func (k *Key) Truncate(length uint8) { - k.len = length - - unusedBytes := k.unusedBytes() - clear(unusedBytes) - - // clear upper bits on the last used byte - inUseBytes := k.inUseBytes() - unusedBitsCount := 8 - (k.len % 8) - if unusedBitsCount != 8 && len(inUseBytes) > 0 { - inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount - } -} - -func (k *Key) String() string { - return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:])) -} - -// Copy returns a deep copy of the key -func (k *Key) Copy() Key { - newKey := Key{len: k.len} - copy(newKey.bitset[:], k.bitset[:]) - return newKey -} - -func (k *Key) Bytes() [32]byte { - var result [32]byte - copy(result[:], k.bitset[:]) - return result -} - -// findCommonKey finds the set of common MSB bits in two key bitsets. -func findCommonKey(longerKey, shorterKey *Key) (Key, bool) { - divergentBit := findDivergentBit(longerKey, shorterKey) - - if divergentBit == 0 { - return *NilKey, false - } - - commonKey := *shorterKey - commonKey.shiftRight(shorterKey.Len() - divergentBit + 1) - return commonKey, divergentBit == shorterKey.Len()+1 -} - -// findDivergentBit finds the first bit that is different between two keys, -// starting from the most significant bit of both keys. -func findDivergentBit(longerKey, shorterKey *Key) uint8 { - divergentBit := uint8(0) - for divergentBit <= shorterKey.Len() && - longerKey.IsBitSet(longerKey.Len()-divergentBit) == shorterKey.IsBitSet(shorterKey.Len()-divergentBit) { - divergentBit++ - } - return divergentBit -} - -func isSubset(longerKey, shorterKey *Key) bool { - divergentBit := findDivergentBit(longerKey, shorterKey) - return divergentBit == shorterKey.Len()+1 -} - -func FeltToKey(length uint8, key *felt.Felt) Key { - keyBytes := key.Bytes() - return NewKey(length, keyBytes[:]) -} diff --git a/core/trie/key_test.go b/core/trie/key_test.go deleted file mode 100644 index 3867678e6e..0000000000 --- a/core/trie/key_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package trie_test - -import ( - "bytes" - "testing" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestKeyEncoding(t *testing.T) { - tests := map[string]struct { - Len uint8 - Bytes []byte - }{ - "multiple of 8": { - Len: 4 * 8, - Bytes: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - }, - "0 len": { - Len: 0, - Bytes: []byte{}, - }, - "odd len": { - Len: 3, - Bytes: []byte{0x03}, - }, - } - - for desc, test := range tests { - t.Run(desc, func(t *testing.T) { - key := trie.NewKey(test.Len, test.Bytes) - - var keyBuffer bytes.Buffer - n, err := key.WriteTo(&keyBuffer) - require.NoError(t, err) - assert.Equal(t, len(test.Bytes)+1, int(n)) - - keyBytes := keyBuffer.Bytes() - require.Len(t, keyBytes, int(n)) - assert.Equal(t, test.Len, keyBytes[0]) - assert.Equal(t, test.Bytes, keyBytes[1:]) - - var decodedKey trie.Key - require.NoError(t, decodedKey.UnmarshalBinary(keyBytes)) - assert.Equal(t, key, decodedKey) - }) - } -} - -func BenchmarkKeyEncoding(b *testing.B) { - val, err := new(felt.Felt).SetRandom() - require.NoError(b, err) - valBytes := val.Bytes() - - key := trie.NewKey(felt.Bits, valBytes[:]) - buffer := bytes.Buffer{} - buffer.Grow(felt.Bytes + 1) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := key.WriteTo(&buffer) - require.NoError(b, err) - require.NoError(b, key.UnmarshalBinary(buffer.Bytes())) - buffer.Reset() - } -} - -func TestTruncate(t *testing.T) { - tests := map[string]struct { - key trie.Key - newLen uint8 - expectedKey trie.Key - }{ - "truncate to 12 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 12, - expectedKey: trie.NewKey(12, []byte{0x03, 0x14}), - }, - "truncate to 9 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 9, - expectedKey: trie.NewKey(9, []byte{0x01, 0x14}), - }, - "truncate to 3 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 3, - expectedKey: trie.NewKey(3, []byte{0x04}), - }, - "truncate to multiple of 8": { - key: trie.NewKey(251, []uint8{ - 0x7, 0x40, 0x33, 0x8c, 0xbc, 0x9, 0xeb, 0xf, 0xb7, 0xab, - 0xc5, 0x20, 0x35, 0xc6, 0x4d, 0x4e, 0xa5, 0x78, 0x18, 0x9e, 0xd6, 0x37, 0x47, 0x91, 0xd0, - 0x6e, 0x44, 0x1e, 0xf7, 0x7f, 0xf, 0x5f, - }), - newLen: 248, - expectedKey: trie.NewKey(248, []uint8{ - 0x0, 0x40, 0x33, 0x8c, 0xbc, 0x9, 0xeb, 0xf, 0xb7, 0xab, - 0xc5, 0x20, 0x35, 0xc6, 0x4d, 0x4e, 0xa5, 0x78, 0x18, 0x9e, 0xd6, 0x37, 0x47, 0x91, 0xd0, - 0x6e, 0x44, 0x1e, 0xf7, 0x7f, 0xf, 0x5f, - }), - }, - } - - for desc, test := range tests { - t.Run(desc, func(t *testing.T) { - copyKey := test.key - copyKey.Truncate(test.newLen) - assert.Equal(t, test.expectedKey, copyKey) - }) - } -} - -func TestKeyTest(t *testing.T) { - key := trie.NewKey(44, []byte{0x10, 0x02}) - for i := 0; i < int(key.Len()); i++ { - assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i) - } -} - -func TestIsBitSet(t *testing.T) { - tests := map[string]struct { - key trie.Key - position uint8 - expected bool - }{ - "single byte, LSB set": { - key: trie.NewKey(8, []byte{0x01}), - position: 0, - expected: true, - }, - "single byte, MSB set": { - key: trie.NewKey(8, []byte{0x80}), - position: 7, - expected: true, - }, - "single byte, middle bit set": { - key: trie.NewKey(8, []byte{0x10}), - position: 4, - expected: true, - }, - "single byte, bit not set": { - key: trie.NewKey(8, []byte{0xFE}), - position: 0, - expected: false, - }, - "multiple bytes, LSB set": { - key: trie.NewKey(16, []byte{0x00, 0x02}), - position: 1, - expected: true, - }, - "multiple bytes, MSB set": { - key: trie.NewKey(16, []byte{0x01, 0x00}), - position: 8, - expected: true, - }, - "multiple bytes, no bits set": { - key: trie.NewKey(16, []byte{0x00, 0x00}), - position: 7, - expected: false, - }, - "check all bits in pattern": { - key: trie.NewKey(8, []byte{0xA5}), // 10100101 - position: 0, - expected: true, - }, - } - - // Additional test for 0xA5 pattern - key := trie.NewKey(8, []byte{0xA5}) // 10100101 - expectedBits := []bool{true, false, true, false, false, true, false, true} - for i, expected := range expectedBits { - assert.Equal(t, expected, key.IsBitSet(uint8(i)), "bit %d in 0xA5", i) - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - result := tc.key.IsBitSet(tc.position) - assert.Equal(t, tc.expected, result) - }) - } -} - -func TestMostSignificantBits(t *testing.T) { - tests := []struct { - name string - key trie.Key - n uint8 - want trie.Key - expectErr bool - }{ - { - name: "Valid case", - key: trie.NewKey(8, []byte{0b11110000}), - n: 4, - want: trie.NewKey(4, []byte{0b00001111}), - expectErr: false, - }, - { - name: "Request more bits than available", - key: trie.NewKey(8, []byte{0b11110000}), - n: 10, - want: trie.Key{}, - expectErr: true, - }, - { - name: "Zero bits requested", - key: trie.NewKey(8, []byte{0b11110000}), - n: 0, - want: trie.NewKey(0, []byte{}), - expectErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.key.MostSignificantBits(tt.n) - if (err != nil) != tt.expectErr { - t.Errorf("MostSignificantBits() error = %v, expectErr %v", err, tt.expectErr) - return - } - if !tt.expectErr && !got.Equal(&tt.want) { - t.Errorf("MostSignificantBits() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/core/trie/node.go b/core/trie/node.go index 172869cb10..51d5b76785 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -13,14 +13,14 @@ import ( // https://docs.starknet.io/architecture-and-concepts/network-architecture/starknet-state/#trie_construction type Node struct { Value *felt.Felt - Left *Key - Right *Key + Left *BitArray + Right *BitArray LeftHash *felt.Felt RightHash *felt.Felt } // Hash calculates the hash of a [Node] -func (n *Node) Hash(path *Key, hashFn crypto.HashFn) *felt.Felt { +func (n *Node) Hash(path *BitArray, hashFn crypto.HashFn) *felt.Felt { if path.Len() == 0 { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -35,32 +35,32 @@ func (n *Node) Hash(path *Key, hashFn crypto.HashFn) *felt.Felt { } // Hash calculates the hash of a [Node] -func (n *Node) HashFromParent(parentKey, nodeKey *Key, hashFn crypto.HashFn) *felt.Felt { +func (n *Node) HashFromParent(parentKey, nodeKey *BitArray, hashFn crypto.HashFn) *felt.Felt { path := path(nodeKey, parentKey) return n.Hash(&path, hashFn) } -func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { +func (n *Node) WriteTo(buf *bytes.Buffer) (int, error) { if n.Value == nil { return 0, errors.New("cannot marshal node with nil value") } - totalBytes := int64(0) + var totalBytes int valueB := n.Value.Bytes() wrote, err := buf.Write(valueB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } if n.Left != nil { - wrote, errInner := n.Left.WriteTo(buf) + wrote, errInner := n.Left.Write(buf) totalBytes += wrote if errInner != nil { return totalBytes, errInner } - wrote, errInner = n.Right.WriteTo(buf) // n.Right is non-nil by design + wrote, errInner = n.Right.Write(buf) // n.Right is non-nil by design totalBytes += wrote if errInner != nil { return totalBytes, errInner @@ -75,14 +75,14 @@ func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { leftHashB := n.LeftHash.Bytes() wrote, err = buf.Write(leftHashB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } rightHashB := n.RightHash.Bytes() wrote, err = buf.Write(rightHashB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } @@ -110,17 +110,13 @@ func (n *Node) UnmarshalBinary(data []byte) error { } if n.Left == nil { - n.Left = new(Key) - n.Right = new(Key) + n.Left = new(BitArray) + n.Right = new(BitArray) } - if err := n.Left.UnmarshalBinary(data); err != nil { - return err - } + n.Left.UnmarshalBinary(data) data = data[n.Left.EncodedLen():] - if err := n.Right.UnmarshalBinary(data); err != nil { - return err - } + n.Right.UnmarshalBinary(data) data = data[n.Right.EncodedLen():] if n.LeftHash == nil { @@ -157,11 +153,13 @@ func (n *Node) Update(other *Node) error { return fmt.Errorf("conflicting Values: %v != %v", n.Value, other.Value) } - if n.Left != nil && other.Left != nil && !n.Left.Equal(NilKey) && !other.Left.Equal(NilKey) && !n.Left.Equal(other.Left) { + if n.Left != nil && other.Left != nil && !n.Left.Equal(emptyBitArray) && !other.Left.Equal(emptyBitArray) && !n.Left.Equal(other.Left) { return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left) } - if n.Right != nil && other.Right != nil && !n.Right.Equal(NilKey) && !other.Right.Equal(NilKey) && !n.Right.Equal(other.Right) { + if n.Right != nil && other.Right != nil && + !n.Right.Equal(emptyBitArray) && !other.Right.Equal(emptyBitArray) && + !n.Right.Equal(other.Right) { return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right) } @@ -177,10 +175,10 @@ func (n *Node) Update(other *Node) error { if other.Value != nil { n.Value = other.Value } - if other.Left != nil && !other.Left.Equal(NilKey) { + if other.Left != nil && !other.Left.Equal(emptyBitArray) { n.Left = other.Left } - if other.Right != nil && !other.Right.Equal(NilKey) { + if other.Right != nil && !other.Right.Equal(emptyBitArray) { n.Right = other.Right } if other.LeftHash != nil { diff --git a/core/trie/node_test.go b/core/trie/node_test.go index ccb52b3eac..cc1bb06eda 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -22,7 +22,7 @@ func TestNodeHash(t *testing.T) { node := trie.Node{ Value: new(felt.Felt).SetBytes(valueBytes), } - path := trie.NewKey(6, []byte{42}) + path := trie.NewBitArray(6, 42) assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") } diff --git a/core/trie/proof.go b/core/trie/proof.go index 9f1fd3ab10..4afed36dfc 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -40,14 +40,15 @@ func (b *Binary) String() string { type Edge struct { Child *felt.Felt // child hash - Path *Key // path from parent to child + Path *BitArray // path from parent to child } func (e *Edge) Hash(hash crypto.HashFn) *felt.Felt { - length := make([]byte, len(e.Path.bitset)) - length[len(e.Path.bitset)-1] = e.Path.len + var length [32]byte + length[31] = e.Path.len pathFelt := e.Path.Felt() - lengthFelt := new(felt.Felt).SetBytes(length) + lengthFelt := new(felt.Felt).SetBytes(length[:]) + // TODO: no need to return reference, just return value to avoid heap allocation return new(felt.Felt).Add(hash(e.Child, &pathFelt), lengthFelt) } @@ -71,7 +72,7 @@ func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { return err } - var parentKey *Key + var parentKey *BitArray for i, sNode := range nodesFromRoot { sNodeEdge, sNodeBinary, err := storageNodeToProofNode(t, parentKey, sNode) @@ -138,9 +139,8 @@ func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSe // - The path bits don't match the key bits // - The proof ends before processing all key bits func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn) (*felt.Felt, error) { - key := FeltToKey(globalTrieHeight, keyFelt) + keyBits := new(BitArray).SetFelt(globalTrieHeight, keyFelt) expectedHash := root - keyLen := key.Len() var curPos uint8 for { @@ -156,17 +156,17 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash switch node := proofNode.(type) { case *Binary: // Binary nodes represent left/right choices - if key.Len() <= curPos { - return nil, fmt.Errorf("key length less than current position, key length: %d, current position: %d", key.Len(), curPos) + if keyBits.Len() <= curPos { + return nil, fmt.Errorf("key length less than current position, key length: %d, current position: %d", keyBits.Len(), curPos) } // Determine the next node to traverse based on the next bit position expectedHash = node.LeftHash - if key.IsBitSet(keyLen - curPos - 1) { + if keyBits.IsBitSet(curPos) { expectedHash = node.RightHash } curPos++ case *Edge: // Edge nodes represent paths between binary nodes - if !verifyEdgePath(&key, node.Path, curPos) { + if !verifyEdgePath(keyBits, node.Path, curPos) { return &felt.Zero, nil } @@ -176,7 +176,7 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash crypto.Hash } // We've consumed all bits in our path - if curPos >= keyLen { + if curPos >= keyBits.Len() { return expectedHash, nil } } @@ -235,18 +235,18 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * } nodes := NewStorageNodeSet() - firstKey := FeltToKey(globalTrieHeight, first) + firstKey := new(BitArray).SetFelt(globalTrieHeight, first) // Special case: there is a provided proof but no key-value pairs, make sure regenerated trie has no more values // Empty range proof with more elements on the right is not accepted in this function. // This is due to snap sync specification detail, where the responder must send an existing key (if any) if the requested range is empty. if len(keys) == 0 { - rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) + rootKey, val, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - if val != nil || hasRightElement(rootKey, &firstKey, nodes) { + if val != nil || hasRightElement(rootKey, firstKey, nodes) { return false, errors.New("more entries available") } @@ -254,17 +254,17 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * } last := keys[len(keys)-1] - lastKey := FeltToKey(globalTrieHeight, last) + lastKey := new(BitArray).SetFelt(globalTrieHeight, last) // Special case: there is only one element and two edge keys are the same - if len(keys) == 1 && firstKey.Equal(&lastKey) { - rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) + if len(keys) == 1 && firstKey.Equal(lastKey) { + rootKey, val, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - elementKey := FeltToKey(globalTrieHeight, keys[0]) - if !firstKey.Equal(&elementKey) { + elementKey := new(BitArray).SetFelt(globalTrieHeight, keys[0]) + if !firstKey.Equal(elementKey) { return false, errors.New("correct proof but invalid key") } @@ -272,7 +272,7 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, errors.New("correct proof but invalid value") } - return hasRightElement(rootKey, &firstKey, nodes), nil + return hasRightElement(rootKey, firstKey, nodes), nil } // In all other cases, we require two edge paths available. @@ -281,12 +281,12 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, errors.New("last key is less than first key") } - rootKey, _, err := proofToPath(root, &firstKey, proof, nodes) + rootKey, _, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - lastRootKey, _, err := proofToPath(root, &lastKey, proof, nodes) + lastRootKey, _, err := proofToPath(root, lastKey, proof, nodes) if err != nil { return false, err } @@ -311,11 +311,11 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) } - return hasRightElement(rootKey, &lastKey, nodes), nil + return hasRightElement(rootKey, lastKey, nodes), nil } // isEdge checks if the storage node is an edge node. -func isEdge(parentKey *Key, sNode StorageNode) bool { +func isEdge(parentKey *BitArray, sNode StorageNode) bool { sNodeLen := sNode.key.len if parentKey == nil { // Root return sNodeLen != 0 @@ -326,7 +326,7 @@ func isEdge(parentKey *Key, sNode StorageNode) bool { // storageNodeToProofNode converts a StorageNode to the ProofNode(s). // Juno's Trie has nodes that are Binary AND Edge, whereas the protocol requires nodes that are Binary XOR Edge. // We need to convert the former to the latter for proof generation. -func storageNodeToProofNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { +func storageNodeToProofNode(tri *Trie, parentKey *BitArray, sNode StorageNode) (*Edge, *Binary, error) { var edge *Edge if isEdge(parentKey, sNode) { edgePath := path(sNode.key, parentKey) @@ -375,8 +375,8 @@ func storageNodeToProofNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge // proofToPath converts a Merkle proof to trie node path. All necessary nodes will be resolved and leave the remaining // as hashes. The given edge proof can be existent or non-existent. -func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageNodeSet) (*Key, *felt.Felt, error) { - rootKey, val, err := buildPath(root, key, 0, nil, proof, nodes) +func proofToPath(root *felt.Felt, keyBits *BitArray, proof *ProofNodeSet, nodes *StorageNodeSet) (*BitArray, *felt.Felt, error) { + rootKey, val, err := buildPath(root, keyBits, 0, nil, proof, nodes) if err != nil { return nil, nil, err } @@ -400,7 +400,7 @@ func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageN sn := NewPartialStorageNode(edge.Path, edge.Child) // Handle leaf edge case (single key trie) - if edge.Path.Len() == key.Len() { + if edge.Path.Len() == keyBits.Len() { if err := nodes.Put(*sn.key, sn); err != nil { return nil, nil, fmt.Errorf("failed to store leaf edge: %w", err) } @@ -433,12 +433,12 @@ func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageN // It returns the current node's key and any leaf value found along this path. func buildPath( nodeHash *felt.Felt, - key *Key, + key *BitArray, curPos uint8, curNode *StorageNode, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // We reached the leaf if curPos == key.Len() { leafKey := key.Copy() @@ -451,7 +451,7 @@ func buildPath( proofNode, ok := proof.Get(*nodeHash) if !ok { // non-existent proof node - return NilKey, nil, nil + return emptyBitArray, nil, nil } switch pn := proofNode.(type) { @@ -470,30 +470,26 @@ func buildPath( func handleBinaryNode( binary *Binary, nodeHash *felt.Felt, - key *Key, + key *BitArray, curPos uint8, curNode *StorageNode, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // If curNode is nil, it means that this current binary node is the root node. // Or, it's an internal binary node and the parent is also a binary node. // A standalone binary proof node always corresponds to a single storage node. // If curNode is not nil, it means that the parent node is an edge node. // In this case, the key of the storage node is based on the parent edge node. if curNode == nil { - nodeKey, err := key.MostSignificantBits(curPos) - if err != nil { - return nil, nil, err - } - curNode = NewPartialStorageNode(nodeKey, nodeHash) + curNode = NewPartialStorageNode(new(BitArray).MSBs(key, curPos), nodeHash) } curNode.node.LeftHash = binary.LeftHash curNode.node.RightHash = binary.RightHash // Calculate next position and determine to take left or right path nextPos := curPos + 1 - isRightPath := key.IsBitSet(key.Len() - nextPos) + isRightPath := key.IsBitSet(curPos) nextHash := binary.LeftHash if isRightPath { nextHash = binary.RightHash @@ -523,23 +519,19 @@ func handleBinaryNode( // the current node's key and any leaf value found along this path. func handleEdgeNode( edge *Edge, - key *Key, + key *BitArray, curPos uint8, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // Verify the edge path matches the key path if !verifyEdgePath(key, edge.Path, curPos) { - return NilKey, nil, nil + return emptyBitArray, nil, nil } // The next node position is the end of the edge path nextPos := curPos + edge.Path.Len() - nodeKey, err := key.MostSignificantBits(nextPos) - if err != nil { - return nil, nil, fmt.Errorf("failed to get MSB for internal edge: %w", err) - } - curNode := NewPartialStorageNode(nodeKey, edge.Child) + curNode := NewPartialStorageNode(new(BitArray).MSBs(key, nextPos), edge.Child) // This is an edge leaf, stop traversing the trie if nextPos == key.Len() { @@ -562,24 +554,12 @@ func handleEdgeNode( } // verifyEdgePath checks if the edge path matches the key path at the current position. -func verifyEdgePath(key, edgePath *Key, curPos uint8) bool { - if key.Len() < curPos+edgePath.Len() { - return false - } - - // Ensure the bits between segment of the key and the node path match - start := key.Len() - curPos - edgePath.Len() - end := key.Len() - curPos - for i := start; i < end; i++ { - if key.IsBitSet(i) != edgePath.IsBitSet(i-start) { - return false // paths diverge - this proves non-membership - } - } - return true +func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { + return new(BitArray).LSBs(key, curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. -func buildTrie(height uint8, rootKey *Key, nodes []*StorageNode, keys, values []*felt.Felt) (*Trie, error) { +func buildTrie(height uint8, rootKey *BitArray, nodes []*StorageNode, keys, values []*felt.Felt) (*Trie, error) { tr, err := NewTriePedersen(newMemStorage(), height) if err != nil { return nil, err @@ -607,9 +587,9 @@ func buildTrie(height uint8, rootKey *Key, nodes []*StorageNode, keys, values [] // hasRightElement checks if there is a right sibling for the given key in the trie. // This function assumes that the entire path has been resolved. -func hasRightElement(rootKey, key *Key, nodes *StorageNodeSet) bool { +func hasRightElement(rootKey, key *BitArray, nodes *StorageNodeSet) bool { cur := rootKey - for cur != nil && !cur.Equal(NilKey) { + for cur != nil && !cur.Equal(emptyBitArray) { sn, ok := nodes.Get(*cur) if !ok { return false @@ -622,8 +602,7 @@ func hasRightElement(rootKey, key *Key, nodes *StorageNodeSet) bool { // If we're taking a left path and there's a right sibling, // then there are elements with larger values - bitPos := key.Len() - cur.Len() - 1 - isLeft := !key.IsBitSet(bitPos) + isLeft := !key.IsBitSet(cur.Len()) if isLeft && sn.node.RightHash != nil { return true } diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 94eaabc549..046b1b1bca 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -360,7 +360,7 @@ func TestOneElementRangeProof(t *testing.T) { }) } -// TestAllElementsProof tests the range proof with all elements and nil proof. +// TestAllElementsRangeProof tests the range proof with all elements and nil proof. func TestAllElementsRangeProof(t *testing.T) { t.Parallel() diff --git a/core/trie/storage.go b/core/trie/storage.go index c4e5ae0915..6fe994fe3b 100644 --- a/core/trie/storage.go +++ b/core/trie/storage.go @@ -42,17 +42,17 @@ func NewStorage(txn db.Transaction, prefix []byte) *Storage { // dbKey creates a byte array to be used as a key to our KV store // it simply appends the given key to the configured prefix -func (t *Storage) dbKey(key *Key, buffer *bytes.Buffer) (int64, error) { +func (t *Storage) dbKey(key *BitArray, buffer *bytes.Buffer) (int, error) { _, err := buffer.Write(t.prefix) if err != nil { return 0, err } - keyLen, err := key.WriteTo(buffer) - return int64(len(t.prefix)) + keyLen, err + keyLen, err := key.Write(buffer) + return len(t.prefix) + keyLen, err } -func (t *Storage) Put(key *Key, value *Node) error { +func (t *Storage) Put(key *BitArray, value *Node) error { buffer := getBuffer() defer bufferPool.Put(buffer) keyLen, err := t.dbKey(key, buffer) @@ -69,7 +69,7 @@ func (t *Storage) Put(key *Key, value *Node) error { return t.txn.Set(encodedBytes[:keyLen], encodedBytes[keyLen:]) } -func (t *Storage) Get(key *Key) (*Node, error) { +func (t *Storage) Get(key *BitArray) (*Node, error) { buffer := getBuffer() defer bufferPool.Put(buffer) _, err := t.dbKey(key, buffer) @@ -87,7 +87,7 @@ func (t *Storage) Get(key *Key) (*Node, error) { return node, err } -func (t *Storage) Delete(key *Key) error { +func (t *Storage) Delete(key *BitArray) error { buffer := getBuffer() defer bufferPool.Put(buffer) _, err := t.dbKey(key, buffer) @@ -97,21 +97,22 @@ func (t *Storage) Delete(key *Key) error { return t.txn.Delete(buffer.Bytes()) } -func (t *Storage) RootKey() (*Key, error) { - var rootKey *Key +func (t *Storage) RootKey() (*BitArray, error) { + var rootKey *BitArray if err := t.txn.Get(t.prefix, func(val []byte) error { - rootKey = new(Key) - return rootKey.UnmarshalBinary(val) + rootKey = new(BitArray) + rootKey.UnmarshalBinary(val) + return nil }); err != nil { return nil, err } return rootKey, nil } -func (t *Storage) PutRootKey(newRootKey *Key) error { +func (t *Storage) PutRootKey(newRootKey *BitArray) error { buffer := getBuffer() defer bufferPool.Put(buffer) - _, err := newRootKey.WriteTo(buffer) + _, err := newRootKey.Write(buffer) if err != nil { return err } diff --git a/core/trie/storage_test.go b/core/trie/storage_test.go index 809ded4791..21302f1308 100644 --- a/core/trie/storage_test.go +++ b/core/trie/storage_test.go @@ -15,7 +15,7 @@ import ( func TestStorage(t *testing.T) { testDB := pebble.NewMemTest(t) prefix := []byte{37, 44} - key := trie.NewKey(44, nil) + key := trie.NewBitArray(44, 0) value, err := new(felt.Felt).SetRandom() require.NoError(t, err) @@ -77,7 +77,7 @@ func TestStorage(t *testing.T) { }), db.ErrKeyNotFound.Error()) }) - rootKey := trie.NewKey(8, []byte{0x2}) + rootKey := trie.NewBitArray(8, 2) t.Run("put root key", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { @@ -91,7 +91,7 @@ func TestStorage(t *testing.T) { tTxn := trie.NewStorage(txn, prefix) gotRootKey, err := tTxn.RootKey() require.NoError(t, err) - assert.Equal(t, rootKey, *gotRootKey) + assert.Equal(t, &rootKey, gotRootKey) return nil })) }) diff --git a/core/trie/trie.go b/core/trie/trie.go index 5f8a51d9c0..bb2320d801 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -35,12 +35,12 @@ const globalTrieHeight = 251 // TODO(weiihann): this is declared in core also, s // [specification]: https://docs.starknet.io/architecture-and-concepts/network-architecture/starknet-state/#merkle_patricia_trie type Trie struct { height uint8 - rootKey *Key + rootKey *BitArray maxKey *felt.Felt storage *Storage hash crypto.HashFn - dirtyNodes []*Key + dirtyNodes []*BitArray rootKeyIsDirty bool } @@ -94,32 +94,36 @@ func RunOnTempTriePoseidon(height uint8, do func(*Trie) error) error { return do(trie) } -// feltToKey Converts a key, given in felt, to a trie.Key which when followed on a [Trie], +// FeltToKey converts a key, given in felt, to a trie.Key which when followed on a [Trie], // leads to the corresponding [Node] -func (t *Trie) FeltToKey(k *felt.Felt) Key { - return FeltToKey(t.height, k) +func (t *Trie) FeltToKey(k *felt.Felt) BitArray { + var ba BitArray + ba.SetFelt(t.height, k) + return ba } // path returns the path as mentioned in the [specification] for commitment calculations. // path is suffix of key that diverges from parentKey. For example, // for a key 0b1011 and parentKey 0b10, this function would return the path object of 0b0. -func path(key, parentKey *Key) Key { - path := *key +func path(key, parentKey *BitArray) BitArray { // drop parent key, and one more MSB since left/right relation already encodes that information - if parentKey != nil { - path.Truncate(path.Len() - parentKey.Len() - 1) + if parentKey == nil { + return key.Copy() } - return path + + var pathKey BitArray + pathKey.LSBs(key, parentKey.Len()+1) + return pathKey } // storageNode is the on-disk representation of a [Node], // where key is the storage key and node is the value. type StorageNode struct { - key *Key + key *BitArray node *Node } -func (sn *StorageNode) Key() *Key { +func (sn *StorageNode) Key() *BitArray { return sn.key } @@ -133,7 +137,7 @@ func (sn *StorageNode) String() string { func (sn *StorageNode) Update(other *StorageNode) error { // First validate all fields for conflicts - if sn.key != nil && other.key != nil && !sn.key.Equal(NilKey) && !other.key.Equal(NilKey) { + if sn.key != nil && other.key != nil && !sn.key.Equal(emptyBitArray) && !other.key.Equal(emptyBitArray) { if !sn.key.Equal(other.key) { return fmt.Errorf("keys do not match: %s != %s", sn.key, other.key) } @@ -147,47 +151,47 @@ func (sn *StorageNode) Update(other *StorageNode) error { } // After validation, perform update - if other.key != nil && !other.key.Equal(NilKey) { + if other.key != nil && !other.key.Equal(emptyBitArray) { sn.key = other.key } return nil } -func NewStorageNode(key *Key, node *Node) *StorageNode { +func NewStorageNode(key *BitArray, node *Node) *StorageNode { return &StorageNode{key: key, node: node} } // NewPartialStorageNode creates a new StorageNode with a given key and value, // where the right and left children are nil. -func NewPartialStorageNode(key *Key, value *felt.Felt) *StorageNode { +func NewPartialStorageNode(key *BitArray, value *felt.Felt) *StorageNode { return &StorageNode{ key: key, node: &Node{ Value: value, - Left: NilKey, - Right: NilKey, + Left: emptyBitArray, + Right: emptyBitArray, }, } } // StorageNodeSet wraps OrderedSet to provide specific functionality for StorageNodes type StorageNodeSet struct { - set *utils.OrderedSet[Key, *StorageNode] + set *utils.OrderedSet[BitArray, *StorageNode] } func NewStorageNodeSet() *StorageNodeSet { return &StorageNodeSet{ - set: utils.NewOrderedSet[Key, *StorageNode](), + set: utils.NewOrderedSet[BitArray, *StorageNode](), } } -func (s *StorageNodeSet) Get(key Key) (*StorageNode, bool) { +func (s *StorageNodeSet) Get(key BitArray) (*StorageNode, bool) { return s.set.Get(key) } // Put adds a new StorageNode or updates an existing one. -func (s *StorageNodeSet) Put(key Key, node *StorageNode) error { +func (s *StorageNodeSet) Put(key BitArray, node *StorageNode) error { if node == nil { return errors.New("cannot put nil node") } @@ -217,7 +221,7 @@ func (s *StorageNodeSet) Size() int { // nodesFromRoot enumerates the set of [Node] objects that need to be traversed from the root // of the Trie to the node which is given by the key. // The [storageNode]s are returned in descending order beginning with the root. -func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { +func (t *Trie) nodesFromRoot(key *BitArray) ([]StorageNode, error) { var nodes []StorageNode cur := t.rootKey for cur != nil { @@ -236,12 +240,11 @@ func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { node: node, }) - subset := isSubset(key, cur) - if cur.Len() >= key.Len() || !subset { + if cur.Len() >= key.Len() || !key.EqualMSBs(cur) { return nodes, nil } - if key.IsBitSet(key.Len() - cur.Len() - 1) { + if key.IsBitSet(cur.Len()) { cur = node.Right } else { cur = node.Left @@ -267,12 +270,12 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { } // GetNodeFromKey returns the node for a given key. -func (t *Trie) GetNodeFromKey(key *Key) (*Node, error) { +func (t *Trie) GetNodeFromKey(key *BitArray) (*Node, error) { return t.storage.Get(key) } // check if we are updating an existing leaf, if yes avoid traversing the trie -func (t *Trie) updateLeaf(nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) { +func (t *Trie) updateLeaf(nodeKey BitArray, node *Node, value *felt.Felt) (*felt.Felt, error) { // Check if we are updating an existing leaf if !value.IsZero() { if existingLeaf, err := t.storage.Get(&nodeKey); err == nil { @@ -289,7 +292,7 @@ func (t *Trie) updateLeaf(nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt return nil, nil } -func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) { +func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey BitArray, node *Node, value *felt.Felt) (*felt.Felt, error) { if value.IsZero() { return nil, nil // no-op } @@ -301,7 +304,7 @@ func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *fe return &old, nil } -func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey Key, nodes []StorageNode) (*felt.Felt, error) { +func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey BitArray, nodes []StorageNode) (*felt.Felt, error) { if nodeKey.Equal(sibling.key) { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -314,7 +317,7 @@ func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey Key, nodes []Stora return nil, nil } -func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent StorageNode) { +func (t *Trie) replaceLinkWithNewParent(key *BitArray, commonKey BitArray, siblingParent StorageNode) { if siblingParent.node.Left.Equal(key) { *siblingParent.node.Left = commonKey } else { @@ -323,8 +326,15 @@ func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent S } // TODO(weiihann): not a good idea to couple proof verification logic with trie logic -func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { - commonKey, _ := findCommonKey(nodeKey, sibling.key) +func (t *Trie) insertOrUpdateValue( + nodeKey *BitArray, + node *Node, + nodes []StorageNode, + sibling StorageNode, + siblingIsParentProof bool, +) error { + var commonKey BitArray + commonKey.CommonMSBs(nodeKey, sibling.key) newParent := &Node{} var leftChild, rightChild *Node @@ -336,7 +346,7 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode if err != nil { return err } - if nodeKey.IsBitSet(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSet(commonKey.Len()) { newParent.Right = nodeKey newParent.RightHash = node.Hash(nodeKey, t.hash) } else { @@ -348,7 +358,7 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode } t.dirtyNodes = append(t.dirtyNodes, &commonKey) } else { - if nodeKey.IsBitSet(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSet(commonKey.Len()) { newParent.Left, newParent.Right = sibling.key, nodeKey leftChild, rightChild = sibling.node, node } else { @@ -497,19 +507,19 @@ func (t *Trie) PutWithProof(key, value *felt.Felt, proof []*StorageNode) (*felt. } // Put updates the corresponding `value` for a `key` -func (t *Trie) PutInner(key *Key, node *Node) error { +func (t *Trie) PutInner(key *BitArray, node *Node) error { if err := t.storage.Put(key, node); err != nil { return err } return nil } -func (t *Trie) setRootKey(newRootKey *Key) { +func (t *Trie) setRootKey(newRootKey *BitArray) { t.rootKey = newRootKey t.rootKeyIsDirty = true } -func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo +func (t *Trie) updateValueIfDirty(key *BitArray) (*Node, error) { //nolint:gocyclo node, err := t.storage.Get(key) if err != nil { return nil, err @@ -523,7 +533,7 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo shouldUpdate := false for _, dirtyNode := range t.dirtyNodes { if key.Len() < dirtyNode.Len() { - shouldUpdate = isSubset(dirtyNode, key) + shouldUpdate = key.EqualMSBs(dirtyNode) if shouldUpdate { break } @@ -531,9 +541,9 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo } // Update inner proof nodes - if node.Left.Equal(NilKey) && node.Right.Equal(NilKey) { // leaf + if node.Left.Equal(emptyBitArray) && node.Right.Equal(emptyBitArray) { // leaf shouldUpdate = false - } else if node.Left.Equal(NilKey) || node.Right.Equal(NilKey) { // inner + } else if node.Left.Equal(emptyBitArray) || node.Right.Equal(emptyBitArray) { // inner shouldUpdate = true } if !shouldUpdate { @@ -542,11 +552,11 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo var leftIsProof, rightIsProof bool var leftHash, rightHash *felt.Felt - if node.Left.Equal(NilKey) { // key could be nil but hash cannot be + if node.Left.Equal(emptyBitArray) { // key could be nil but hash cannot be leftIsProof = true leftHash = node.LeftHash } - if node.Right.Equal(NilKey) { + if node.Right.Equal(emptyBitArray) { rightIsProof = true rightHash = node.RightHash } @@ -643,7 +653,7 @@ func (t *Trie) deleteLast(nodes []StorageNode) error { return err } - var siblingKey Key + var siblingKey BitArray if parent.node.Left.Equal(last.key) { siblingKey = *parent.node.Right } else { @@ -710,7 +720,7 @@ func (t *Trie) Commit() error { } // RootKey returns db key of the [Trie] root node -func (t *Trie) RootKey() *Key { +func (t *Trie) RootKey() *BitArray { return t.rootKey } @@ -732,7 +742,7 @@ The following can be printed: The spacing to represent the levels of the trie can remain the same. */ -func (t *Trie) dump(level int, parentP *Key) { +func (t *Trie) dump(level int, parentP *BitArray) { if t.rootKey == nil { fmt.Printf("%sEMPTY\n", strings.Repeat("\t", level)) return diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 5426cbcafa..d9d13b1e4c 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -55,16 +55,17 @@ func TestTrieKeys(t *testing.T) { // Check parent and its left right children l := tempTrie.FeltToKey(leftKey) r := tempTrie.FeltToKey(rightKey) - commonKey, isSame := findCommonKey(&l, &r) - require.False(t, isSame) + var commonKey BitArray + commonKey.CommonMSBs(&l, &r) // Common key should be 0b100, length 251-2; - expectKey := NewKey(251-2, []byte{0x4}) + // expectKey := NewKey(251-2, []byte{0x4}) + expectKey := NewBitArray(249, 4) assert.Equal(t, expectKey, commonKey) // Current rootKey should be the common key - assert.Equal(t, expectKey, *tempTrie.rootKey) + assert.Equal(t, &expectKey, tempTrie.rootKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -98,12 +99,12 @@ func TestTrieKeys(t *testing.T) { // Check parent and its left right children l := tempTrie.FeltToKey(leftKey) r := tempTrie.FeltToKey(rightKey) - commonKey, isSame := findCommonKey(&l, &r) - require.False(t, isSame) + var commonKey BitArray + commonKey.CommonMSBs(&l, &r) - expectKey := NewKey(251-2, []byte{0x4}) + expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, commonKey) + assert.Equal(t, &expectKey, &commonKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -134,23 +135,21 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) newVal := new(felt.Felt).SetUint64(12) - //nolint: dupl t.Run("Add to left branch", func(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b101) _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(250, []byte{0x2}) + commonKey := NewBitArray(250, 2) parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) }) - //nolint: dupl t.Run("Add to right branch", func(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b110) _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(250, []byte{0x3}) + commonKey := NewBitArray(250, 3) parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) @@ -166,15 +165,15 @@ func TestTrieKeys(t *testing.T) { _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(248, []byte{}) + commonKey := NewBitArray(248, 0) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) - expectRightKey := NewKey(249, []byte{0x1}) + expectRightKey := NewBitArray(249, 1) - assert.Equal(t, expectRightKey, *parentNode.Right) + assert.Equal(t, &expectRightKey, parentNode.Right) }) }) } @@ -239,9 +238,9 @@ func TestTrieKeysAfterDeleteSubtree(t *testing.T) { _, err = tempTrie.Put(test.deleteKey, zeroVal) require.NoError(t, err) - newRootKey := NewKey(251-2, []byte{0x1}) + newRootKey := NewBitArray(249, 1) - assert.Equal(t, newRootKey, *tempTrie.rootKey) + assert.Equal(t, &newRootKey, tempTrie.rootKey) rootNode, err := tempTrie.storage.Get(&newRootKey) require.NoError(t, err) diff --git a/core/trie/trie_test.go b/core/trie/trie_test.go index 51d589ab63..7384bf558b 100644 --- a/core/trie/trie_test.go +++ b/core/trie/trie_test.go @@ -164,7 +164,81 @@ func TestPutZero(t *testing.T) { var keys []*felt.Felt // put random 64 keys and record roots - for range 64 { + for i := 0; i < 64; i++ { + key, value := new(felt.Felt), new(felt.Felt) + + _, err = key.SetRandom() + require.NoError(t, err) + + t.Logf("key: %s", key.String()) + + _, err = value.SetRandom() + require.NoError(t, err) + + t.Logf("value: %s", value.String()) + + _, err = tempTrie.Put(key, value) + require.NoError(t, err) + + keys = append(keys, key) + + var root *felt.Felt + root, err = tempTrie.Root() + require.NoError(t, err) + + roots = append(roots, root) + } + + t.Run("adding a zero value to a non-existent key should not change Trie", func(t *testing.T) { + var key, root *felt.Felt + key, err = new(felt.Felt).SetRandom() + require.NoError(t, err) + + _, err = tempTrie.Put(key, new(felt.Felt)) + require.NoError(t, err) + + root, err = tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, true, root.Equal(roots[len(roots)-1])) + }) + + t.Run("remove keys one by one, check roots", func(t *testing.T) { + var gotRoot *felt.Felt + // put zero in reverse order and check roots still match + for i := range 64 { + root := roots[len(roots)-1-i] + + gotRoot, err = tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, root, gotRoot) + + key := keys[len(keys)-1-i] + _, err = tempTrie.Put(key, new(felt.Felt)) + require.NoError(t, err) + } + }) + + t.Run("empty roots should match", func(t *testing.T) { + actualEmptyRoot, err := tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, true, actualEmptyRoot.Equal(emptyRoot)) + }) + return nil + })) +} + +func TestTrie(t *testing.T) { + require.NoError(t, trie.RunOnTempTriePedersen(251, func(tempTrie *trie.Trie) error { + emptyRoot, err := tempTrie.Root() + require.NoError(t, err) + var roots []*felt.Felt + var keys []*felt.Felt + + // put random 64 keys and record roots + for i := 0; i < 64; i++ { key, value := new(felt.Felt), new(felt.Felt) _, err = key.SetRandom() diff --git a/migration/migration.go b/migration/migration.go index 107bd40f10..97ce613f58 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -511,7 +511,7 @@ func calculateL1MsgHashes(txn db.Transaction, n *utils.Network) error { return processBlocks(txn, processBlockFunc) } -func bitset2Key(bs *bitset.BitSet) *trie.Key { +func bitset2BitArray(bs *bitset.BitSet) *trie.BitArray { bsWords := bs.Words() if len(bsWords) > felt.Limbs { panic("key too long to fit in Felt") @@ -524,9 +524,7 @@ func bitset2Key(bs *bitset.BitSet) *trie.Key { } f := new(felt.Felt).SetBytes(bsBytes[:]) - fBytes := f.Bytes() - k := trie.NewKey(uint8(bs.Len()), fBytes[:]) - return &k + return new(trie.BitArray).SetFelt(uint8(bs.Len()), f) } func migrateTrieRootKeysFromBitsetToTrieKeys(txn db.Transaction, key, value []byte, _ *utils.Network) error { @@ -535,8 +533,8 @@ func migrateTrieRootKeysFromBitsetToTrieKeys(txn db.Transaction, key, value []by if err := bs.UnmarshalBinary(value); err != nil { return err } - trieKey := bitset2Key(&bs) - _, err := trieKey.WriteTo(&tempBuf) + trieKey := bitset2BitArray(&bs) + _, err := trieKey.Write(&tempBuf) if err != nil { return err } @@ -574,8 +572,8 @@ func migrateTrieNodesFromBitsetToTrieKey(target db.Bucket) BucketMigratorDoFunc Value: n.Value, } if n.Left != nil { - trieNode.Left = bitset2Key(n.Left) - trieNode.Right = bitset2Key(n.Right) + trieNode.Left = bitset2BitArray(n.Left) + trieNode.Right = bitset2BitArray(n.Right) } if _, err := trieNode.WriteTo(&tempBuf); err != nil { @@ -594,7 +592,7 @@ func migrateTrieNodesFromBitsetToTrieKey(target db.Bucket) BucketMigratorDoFunc } var keyBuffer bytes.Buffer - if _, err := bitset2Key(&bs).WriteTo(&keyBuffer); err != nil { + if _, err := bitset2BitArray(&bs).Write(&keyBuffer); err != nil { return err } diff --git a/migration/migration_pkg_test.go b/migration/migration_pkg_test.go index e2d5613c48..688643386c 100644 --- a/migration/migration_pkg_test.go +++ b/migration/migration_pkg_test.go @@ -260,8 +260,11 @@ func TestMigrateTrieRootKeysFromBitsetToTrieKeys(t *testing.T) { require.NoError(t, migrateTrieRootKeysFromBitsetToTrieKeys(memTxn, key, bsBytes, &utils.Mainnet)) - var trieKey trie.Key - err = memTxn.Get(key, trieKey.UnmarshalBinary) + var trieKey trie.BitArray + err = memTxn.Get(key, func(data []byte) error { + trieKey.UnmarshalBinary(data) + return nil + }) require.NoError(t, err) require.Equal(t, bs.Len(), uint(trieKey.Len())) require.Equal(t, felt.Zero, trieKey.Felt()) @@ -357,7 +360,7 @@ func TestMigrateCairo1CompiledClass(t *testing.T) { } } -func TestMigrateTrieNodesFromBitsetToTrieKey(t *testing.T) { +func TestMigrateTrieNodesFromBitsetToBitArray(t *testing.T) { migrator := migrateTrieNodesFromBitsetToTrieKey(db.ClassesTrie) memTxn := db.NewMemTransaction() @@ -388,9 +391,9 @@ func TestMigrateTrieNodesFromBitsetToTrieKey(t *testing.T) { require.ErrorIs(t, err, db.ErrKeyNotFound) var nodeKeyBuf bytes.Buffer - newNodeKey := bitset2Key(bs) - wrote, err = newNodeKey.WriteTo(&nodeKeyBuf) - require.True(t, wrote > 0) + newNodeKey := bitset2BitArray(bs) + bWrite, err := newNodeKey.Write(&nodeKeyBuf) + require.True(t, bWrite > 0) require.NoError(t, err) var trieNode trie.Node