Skip to content

Commit c5cfe4c

Browse files
committed
Use generic trees for recovery
1 parent c307f09 commit c5cfe4c

File tree

1 file changed

+102
-56
lines changed

1 file changed

+102
-56
lines changed

packages/wallet/primitives/src/extensions/recovery.ts

+102-56
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import { Address, Bytes, Hash, Hex } from 'ox'
22
import * as Payload from '../payload'
33
import { getSignPayload } from 'ox/TypedData'
4+
import * as GenericTree from '../generic-tree'
45

56
export const FLAG_RECOVERY_LEAF = 1
67
export const FLAG_NODE = 3
78
export const FLAG_BRANCH = 4
89

10+
const RECOVERY_LEAF_PREFIX = Bytes.fromString('Sequence recovery leaf:\n')
11+
912
/**
1013
* A leaf in the Recovery tree, storing:
1114
* - signer who can queue a payload
@@ -19,23 +22,18 @@ export type RecoveryLeaf = {
1922
minTimestamp: bigint
2023
}
2124

22-
/**
23-
* A node is just a 32-byte hash
24-
*/
25-
export type NodeLeaf = Hex.Hex
26-
2725
/**
2826
* A branch is a list of subtrees (≥2 in length).
2927
*/
30-
export type Node = [Node, Node]
28+
export type Branch = [Tree, Tree]
3129

3230
/**
3331
* The topology of a recovery tree can be either:
3432
* - A node (pair of subtrees)
3533
* - A node leaf (32-byte hash)
3634
* - A recovery leaf (signer with timing constraints)
3735
*/
38-
export type Topology = Node | NodeLeaf | RecoveryLeaf
36+
export type Tree = Branch | GenericTree.Node | RecoveryLeaf
3937

4038
/**
4139
* Type guard to check if a value is a RecoveryLeaf
@@ -44,25 +42,18 @@ export function isRecoveryLeaf(cand: any): cand is RecoveryLeaf {
4442
return typeof cand === 'object' && cand !== null && cand.type === 'leaf'
4543
}
4644

47-
/**
48-
* Type guard to check if a value is a NodeLeaf (32-byte hash)
49-
*/
50-
export function isNodeLeaf(cand: any): cand is NodeLeaf {
51-
return typeof cand === 'string' && cand.length === 66 && cand.startsWith('0x')
52-
}
53-
5445
/**
5546
* Type guard to check if a value is a Node (pair of subtrees)
5647
*/
57-
export function isNode(cand: any): cand is Node {
58-
return Array.isArray(cand) && cand.length === 2 && isTopology(cand[0]) && isTopology(cand[1])
48+
export function isBranch(cand: any): cand is Branch {
49+
return Array.isArray(cand) && cand.length === 2 && isTree(cand[0]) && isTree(cand[1])
5950
}
6051

6152
/**
6253
* Type guard to check if a value is a Topology
6354
*/
64-
export function isTopology(cand: any): cand is Topology {
65-
return isRecoveryLeaf(cand) || isNodeLeaf(cand) || isNode(cand)
55+
export function isTree(cand: any): cand is Tree {
56+
return isRecoveryLeaf(cand) || GenericTree.isNode(cand) || isBranch(cand)
6657
}
6758

6859
/**
@@ -79,24 +70,8 @@ export const DOMAIN_VERSION = '1'
7970
* For node leaves, it returns the hash directly.
8071
* For nodes, it hashes the concatenation of the hashes of both subtrees.
8172
*/
82-
export function hashConfiguration(topology: Topology): Hex.Hex {
83-
if (isRecoveryLeaf(topology)) {
84-
return Hash.keccak256(
85-
Bytes.concat(
86-
Bytes.fromString('Sequence recovery leaf:\n'),
87-
Bytes.fromHex(topology.signer, { size: 20 }),
88-
Bytes.padLeft(Bytes.fromNumber(topology.requiredDeltaTime), 32),
89-
Bytes.padLeft(Bytes.fromNumber(topology.minTimestamp), 32),
90-
),
91-
{ as: 'Hex' },
92-
)
93-
} else if (isNodeLeaf(topology)) {
94-
return topology
95-
} else if (isNode(topology)) {
96-
return Hash.keccak256(Hex.concat(hashConfiguration(topology[0]), hashConfiguration(topology[1])), { as: 'Hex' })
97-
} else {
98-
throw new Error('Invalid topology')
99-
}
73+
export function hashConfiguration(topology: Tree): Hex.Hex {
74+
return GenericTree.hash(toGenericTree(topology))
10075
}
10176

10277
/**
@@ -107,13 +82,13 @@ export function hashConfiguration(topology: Topology): Hex.Hex {
10782
* - leaves: Array of RecoveryLeaf nodes
10883
* - isComplete: boolean indicating if all leaves are present (no node references)
10984
*/
110-
export function getRecoveryLeaves(topology: Topology): { leaves: RecoveryLeaf[]; isComplete: boolean } {
85+
export function getRecoveryLeaves(topology: Tree): { leaves: RecoveryLeaf[]; isComplete: boolean } {
11186
const isComplete = true
11287
if (isRecoveryLeaf(topology)) {
11388
return { leaves: [topology], isComplete }
114-
} else if (isNodeLeaf(topology)) {
89+
} else if (GenericTree.isNode(topology)) {
11590
return { leaves: [], isComplete: false }
116-
} else if (isNode(topology)) {
91+
} else if (isBranch(topology)) {
11792
const left = getRecoveryLeaves(topology[0])
11893
const right = getRecoveryLeaves(topology[1])
11994
return { leaves: [...left.leaves, ...right.leaves], isComplete: left.isComplete && right.isComplete }
@@ -129,7 +104,7 @@ export function getRecoveryLeaves(topology: Topology): { leaves: RecoveryLeaf[];
129104
* @returns The decoded Topology object
130105
* @throws Error if the encoding is invalid
131106
*/
132-
export function decodeTopology(encoded: Bytes.Bytes): Topology {
107+
export function decodeTopology(encoded: Bytes.Bytes): Tree {
133108
const { nodes, leftover } = parseBranch(encoded)
134109
if (leftover.length > 0) {
135110
throw new Error('Leftover bytes in branch')
@@ -146,12 +121,12 @@ export function decodeTopology(encoded: Bytes.Bytes): Topology {
146121
* - leftover: Any remaining unparsed bytes
147122
* @throws Error if the encoding is invalid
148123
*/
149-
export function parseBranch(encoded: Bytes.Bytes): { nodes: Topology[]; leftover: Bytes.Bytes } {
124+
export function parseBranch(encoded: Bytes.Bytes): { nodes: Tree[]; leftover: Bytes.Bytes } {
150125
if (encoded.length === 0) {
151126
throw new Error('Empty branch')
152127
}
153128

154-
const nodes: Topology[] = []
129+
const nodes: Tree[] = []
155130
let index = 0
156131

157132
while (index < encoded.length) {
@@ -208,7 +183,7 @@ export function parseBranch(encoded: Bytes.Bytes): { nodes: Topology[]; leftover
208183
* @param signer - The signer address to keep
209184
* @returns The trimmed topology
210185
*/
211-
export function trimTopology(topology: Topology, signer: Address.Address): Topology {
186+
export function trimTopology(topology: Tree, signer: Address.Address): Tree {
212187
if (isRecoveryLeaf(topology)) {
213188
if (topology.signer === signer) {
214189
return topology
@@ -217,20 +192,20 @@ export function trimTopology(topology: Topology, signer: Address.Address): Topol
217192
}
218193
}
219194

220-
if (isNodeLeaf(topology)) {
195+
if (GenericTree.isNode(topology)) {
221196
return topology
222197
}
223198

224-
if (isNode(topology)) {
199+
if (isBranch(topology)) {
225200
const left = trimTopology(topology[0], signer)
226201
const right = trimTopology(topology[1], signer)
227202

228203
// If both are hashes, we can just return the hash of the node
229-
if (isNodeLeaf(left) && isNodeLeaf(right)) {
204+
if (GenericTree.isNode(left) && GenericTree.isNode(right)) {
230205
return hashConfiguration(topology)
231206
}
232207

233-
return [left, right] as Node
208+
return [left, right] as Branch
234209
}
235210

236211
throw new Error('Invalid topology')
@@ -243,11 +218,11 @@ export function trimTopology(topology: Topology, signer: Address.Address): Topol
243218
* @returns The binary encoded topology
244219
* @throws Error if the topology is invalid
245220
*/
246-
export function encodeTopology(topology: Topology): Bytes.Bytes {
247-
if (isNode(topology)) {
221+
export function encodeTopology(topology: Tree): Bytes.Bytes {
222+
if (isBranch(topology)) {
248223
const encoded0 = encodeTopology(topology[0]!)
249224
const encoded1 = encodeTopology(topology[1]!)
250-
const isBranching = isNode(topology[1]!)
225+
const isBranching = isBranch(topology[1]!)
251226

252227
if (isBranching) {
253228
// max 3 bytes for the size
@@ -263,7 +238,7 @@ export function encodeTopology(topology: Topology): Bytes.Bytes {
263238
}
264239
}
265240

266-
if (isNodeLeaf(topology)) {
241+
if (GenericTree.isNode(topology)) {
267242
const flag = Bytes.fromNumber(FLAG_NODE)
268243
const nodeHash = Bytes.fromHex(topology, { size: 32 })
269244
return Bytes.concat(flag, nodeHash)
@@ -296,7 +271,7 @@ export function encodeTopology(topology: Topology): Bytes.Bytes {
296271
* @returns A binary tree structure
297272
* @throws Error if the nodes array is empty
298273
*/
299-
function foldNodes(nodes: Topology[]): Topology {
274+
function foldNodes(nodes: Tree[]): Tree {
300275
if (nodes.length === 0) {
301276
throw new Error('Empty signature tree')
302277
}
@@ -305,9 +280,9 @@ function foldNodes(nodes: Topology[]): Topology {
305280
return nodes[0]!
306281
}
307282

308-
let tree: Topology = nodes[0]!
283+
let tree: Tree = nodes[0]!
309284
for (let i = 1; i < nodes.length; i++) {
310-
tree = [tree, nodes[i]!] as Topology
285+
tree = [tree, nodes[i]!] as Tree
311286
}
312287
return tree
313288
}
@@ -321,7 +296,7 @@ function foldNodes(nodes: Topology[]): Topology {
321296
* @returns A topology tree structure
322297
* @throws Error if the leaves array is empty
323298
*/
324-
export function fromRecoveryLeaves(leaves: RecoveryLeaf[]): Topology {
299+
export function fromRecoveryLeaves(leaves: RecoveryLeaf[]): Tree {
325300
if (leaves.length === 0) {
326301
throw new Error('Cannot build a tree with zero leaves')
327302
}
@@ -333,7 +308,7 @@ export function fromRecoveryLeaves(leaves: RecoveryLeaf[]): Topology {
333308
const mid = Math.floor(leaves.length / 2)
334309
const left = fromRecoveryLeaves(leaves.slice(0, mid))
335310
const right = fromRecoveryLeaves(leaves.slice(mid))
336-
return [left, right] as Node
311+
return [left, right] as Branch
337312
}
338313

339314
/**
@@ -389,3 +364,74 @@ export function hashRecoveryPayload(
389364
const structHash = Bytes.fromHex(getSignPayload(Payload.toTyped(wallet, noChainId ? 0n : chainId, payload)))
390365
return Hash.keccak256(Bytes.concat(Bytes.fromString('\x19\x01'), Hex.toBytes(ds), structHash), { as: 'Hex' })
391366
}
367+
368+
/**
369+
* Convert a RecoveryTree topology to a generic tree format
370+
*
371+
* @param topology - The recovery tree topology to convert
372+
* @returns A generic tree that produces the same root hash
373+
*/
374+
export function toGenericTree(topology: Tree): GenericTree.Tree {
375+
if (isRecoveryLeaf(topology)) {
376+
// Convert recovery leaf to generic leaf
377+
return {
378+
type: 'leaf',
379+
value: Bytes.concat(
380+
RECOVERY_LEAF_PREFIX,
381+
Bytes.fromHex(topology.signer, { size: 20 }),
382+
Bytes.padLeft(Bytes.fromNumber(topology.requiredDeltaTime), 32),
383+
Bytes.padLeft(Bytes.fromNumber(topology.minTimestamp), 32),
384+
),
385+
}
386+
} else if (GenericTree.isNode(topology)) {
387+
// Node leaves are already in the correct format
388+
return topology
389+
} else if (isBranch(topology)) {
390+
// Convert node to branch
391+
return [toGenericTree(topology[0]), toGenericTree(topology[1])]
392+
} else {
393+
throw new Error('Invalid topology')
394+
}
395+
}
396+
397+
/**
398+
* Convert a generic tree back to a RecoveryTree topology
399+
*
400+
* @param tree - The generic tree to convert
401+
* @returns A recovery tree topology that produces the same root hash
402+
*/
403+
export function fromGenericTree(tree: GenericTree.Tree): Tree {
404+
if (GenericTree.isLeaf(tree)) {
405+
// Convert generic leaf back to recovery leaf
406+
const bytes = tree.value
407+
if (
408+
bytes.length !== RECOVERY_LEAF_PREFIX.length + 84 ||
409+
!Bytes.isEqual(bytes.slice(0, RECOVERY_LEAF_PREFIX.length), RECOVERY_LEAF_PREFIX)
410+
) {
411+
throw new Error('Invalid recovery leaf format')
412+
}
413+
414+
const offset = RECOVERY_LEAF_PREFIX.length
415+
const signer = Address.from(Hex.fromBytes(bytes.slice(offset, offset + 20)))
416+
const requiredDeltaTime = Bytes.toBigInt(bytes.slice(offset + 20, offset + 52))
417+
const minTimestamp = Bytes.toBigInt(bytes.slice(offset + 52, offset + 84))
418+
419+
return {
420+
type: 'leaf',
421+
signer,
422+
requiredDeltaTime,
423+
minTimestamp,
424+
}
425+
} else if (GenericTree.isNode(tree)) {
426+
// Nodes are already in the correct format
427+
return tree
428+
} else if (GenericTree.isBranch(tree)) {
429+
// Convert branch back to node
430+
if (tree.length !== 2) {
431+
throw new Error('Recovery tree only supports binary branches')
432+
}
433+
return [fromGenericTree(tree[0]), fromGenericTree(tree[1])] as Branch
434+
} else {
435+
throw new Error('Invalid tree format')
436+
}
437+
}

0 commit comments

Comments
 (0)