Skip to content

Commit 7d507d5

Browse files
committed
Implement IMnemonic.collect to support collecting a mnemonic word
1 parent 0f0325f commit 7d507d5

File tree

3 files changed

+96
-5
lines changed

3 files changed

+96
-5
lines changed

hdwallet/mnemonics/imnemonic.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
abc
1212
)
1313
from typing import (
14-
Any, Callable, Dict, Generator, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, Union
14+
Any, Callable, Collection, Dict, Generator, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, Union
1515
)
1616

1717
import os
@@ -163,6 +163,25 @@ def scan(
163163
for suffix, found in self.scan( current=child, depth=max(0, depth-1), predicate=predicate ):
164164
yield prefix + char + suffix, found
165165

166+
def options(
167+
self,
168+
prefix: str = '',
169+
current: Optional[TrieNode] = None,
170+
) -> Generator[Tuple[bool, Set[str]], str, None]:
171+
"""With each symbol provided, yields the next available symbol options.
172+
173+
Doesn't advance unless a truthy symbol is provided via <generator>send(symbol).
174+
175+
Completes when the provided symbol doesn't match one of the available options.
176+
"""
177+
last: str = ''
178+
*_, (terminal, _, current) = self.find(prefix, current=current)
179+
while current is not None:
180+
terminal = current.value is not current.EMPTY
181+
symbol: str = yield (terminal, set(current.children))
182+
if symbol:
183+
current = current.children.get(symbol)
184+
166185
def dump_lines(
167186
self,
168187
current: Optional[TrieNode] = None,
@@ -343,6 +362,9 @@ def unique( current ):
343362
# Only abbreviations (not terminal words) that led to a unique terminal word
344363
yield abbrev
345364

365+
def options(self, *args, **kwargs):
366+
return self._trie.options(*args, **kwargs)
367+
346368
def __str__(self):
347369
return str(self._trie)
348370

@@ -663,6 +685,46 @@ def find_language(
663685

664686
return language_indices[candidate], candidate
665687

688+
@classmethod
689+
def collect(
690+
cls,
691+
languages: Optional[Collection[str]] = None,
692+
wordlist_path: Optional[Dict[str, Union[str, List[str]]]] = None,
693+
) -> Generator[Tuple[Set[str], bool, Set[str]], str, None]:
694+
"""A generator taking input symbols, and producing a sequence of sets of possible next
695+
characters in all remaining languages.
696+
697+
With each symbol provided, yields the remaining candidate languages, whether the symbol
698+
indicated a terminal word in some language, and the available next symbols in all remaining
699+
languages.
700+
701+
"""
702+
candidates: Dict[str, WordIndices] = dict(
703+
(candidate, words_indices)
704+
for candidate, _, words_indices in cls.wordlist_indices( wordlist_path=wordlist_path )
705+
if languages is None or candidate in languages
706+
)
707+
708+
word: str = ''
709+
updaters = {
710+
candidate: words_indices.options()
711+
for candidate, words_indices in candidates.items()
712+
}
713+
714+
symbol = None
715+
complete = set()
716+
while complete < set(updaters):
717+
terminal = False
718+
possible = set()
719+
for candidate, updater in updaters.items():
720+
try:
721+
done, available = updater.send(symbol)
722+
except StopIteration:
723+
complete.add( candidate )
724+
terminal |= done
725+
possible |= available
726+
symbol = yield (set(updaters) - complete, terminal, possible)
727+
666728
@classmethod
667729
def is_valid(cls, mnemonic: Union[str, List[str]], language: Optional[str] = None, **kwargs) -> bool:
668730
"""Checks if the given mnemonic is valid.

tests/hdwallet/mnemonics/test_mnemonics_slip39.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import pytest
33

44
from hdwallet.exceptions import MnemonicError
5+
from hdwallet.mnemonics.imnemonic import (
6+
Trie, WordIndices,
7+
)
58
from hdwallet.mnemonics.slip39.mnemonic import (
6-
SLIP39Mnemonic, language_parser, group_parser
9+
SLIP39Mnemonic, language_parser, group_parser,
710
)
811

912
import shamir_mnemonic
@@ -70,7 +73,6 @@ def test_slip39_language():
7073
},
7174
}
7275

73-
7476
def test_slip39_mnemonics():
7577

7678
# Ensure our prefix and whitespace handling works correctly

tests/test_bip39_cross_language.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def dual_language_N_word_mnemonics(self, words=12, expected_rate=1/16, total_att
135135
except ChecksumError as exc:
136136
# Skip invalid mnemonics (e.g., checksum failures)
137137
continue
138-
138+
139139
success_rate = len(successful_both_languages) / total_attempts
140140

141141
print(f"{words}-word mnemonics: {len(successful_both_languages)}/{total_attempts} successful ({success_rate:.6f})")
@@ -438,7 +438,22 @@ class CustomTrieNode(TrieNode):
438438
o u n t == 16
439439
u s e == 17
440440
h i e v e == 18"""
441-
441+
442+
443+
444+
options = trie.options()
445+
446+
assert next( options ) == (False, set( 'a' ))
447+
assert options.send( 'a' ) == (False, set( 'bcd' ))
448+
assert options.send( 'd' ) == (False, set( 'dj' ))
449+
assert options.send( 'd' ) == (True, set( 'ir' ))
450+
assert options.send( 'i' ) == (False, set( 'c' ))
451+
assert options.send( 'c' ) == (False, set( 't' ))
452+
assert next(options) == (False, set( 't' ))
453+
assert options.send('') == (False, set( 't' ))
454+
assert options.send( 't' ) == (True, set())
455+
456+
442457
def test_ambiguous_languages(self):
443458
"""Test that find_language correctly detects and raises errors for ambiguous mnemonics.
444459
@@ -522,6 +537,18 @@ def test_ambiguous_languages(self):
522537
raise # Re-raise unexpected errors
523538

524539

540+
def test_bip39_collection():
541+
542+
languages = {'english', 'french', 'spanish', 'russian'}
543+
544+
collect = BIP39Mnemonic.collect(languages=languages)
545+
assert collect.send(None) == (languages, False, set('abcedefghijklmnopqrstuvwxyzáéíóúабвгдежзиклмнопрстуфхцчшщэюя'))
546+
assert collect.send('a') == ({'english', 'french', 'spanish'}, False, set('bcedefghijlmnpqrstuvwxyzéñ'))
547+
assert collect.send('d') == ({'english', 'french', 'spanish'}, False, set('adehijmoruvé'))
548+
assert collect.send('d') == ({'english'}, True , set('ir'))
549+
550+
551+
525552
def test_bip39_korean():
526553
# Confirm that UTF-8 Mark handling works in other languages (particularly Korean)
527554
(_, korean_nfc, korean_indices), = BIP39Mnemonic.wordlist_indices(

0 commit comments

Comments
 (0)