diff --git a/docs/byte_ensembling_example.py b/docs/byte_ensembling_example.py
new file mode 100644
index 0000000..60ba03a
--- /dev/null
+++ b/docs/byte_ensembling_example.py
@@ -0,0 +1,241 @@
+import asyncio
+from typing import Callable, Literal, List, Tuple
+from collections import defaultdict
+
+import numpy as np
+from cachetools import LRUCache
+from arsenal.maths import logsumexp
+
+from genlm.backend import load_model_by_name
+from genlm.bytes import ByteBeamState, BeamParams
+from genlm.bytes.util import split_with_atomic_tokens
+from genlm.control import Potential
+from genlm.control.sampler.token import TokenSampler
+from genlm.control.util import fast_sample_logprobs
+from genlm.control.constant import EOS
+
+
+def convert_to_logop(op: Literal["sum", "prod", "min"]) -> Callable:
+    """Convert a string operation to its log-space equivalent."""
+    if op == "sum":
+        return lambda x, y: logsumexp([x, y], axis=0)
+    elif op == "prod":
+        return lambda x, y: x + y
+    elif op == "min":
+        return lambda x, y: np.minimum(x, y)
+    else:
+        raise ValueError(f"Invalid operation: {op}. Choose from 'sum', 'prod', 'min'.")
+
+
+class ByteEnsemble(Potential):
+    """
+    An ensemble potential that combines two language models with a specified log-space operation.
+
+    Attributes:
+        p1, p2: The base LM potentials.
+        op: A function that combines log-probabilities.
+        data_dict_1, data_dict_2: Beam state caches keyed by context (bytes).
+        vocabulary: Byte-level vocabulary.
+    """
+
+    def __init__(self, p1, p2, op: Callable, data_dict_1, data_dict_2, vocab):
+        self.p1 = p1
+        self.p2 = p2
+        self.op = op
+        self.data_dict_1 = data_dict_1
+        self.data_dict_2 = data_dict_2
+        self.eos_tokens = (
+            [self.p1.byte_vocab[self.p1.tokenizer.eos_token_id]]
+            + [self.p2.byte_vocab[self.p2.tokenizer.eos_token_id]]
+        )
+        super().__init__(vocabulary=vocab)
+
+    @classmethod
+    async def create(cls, llm1, llm2, op: str, prompt1: bytes, prompt2: bytes):
+        """Factory method that initializes beam states from prompts and returns a ByteEnsemble instance."""
+        beam_params = BeamParams(K=5, prune_threshold=0.1, verbose=True)
+        data_dict_1 = defaultdict()
+        data_dict_2 = defaultdict()
+
+        async def setup():
+            beam1, beam2 = await asyncio.gather(
+                ByteBeamState.initial(llm1, beam_params),
+                ByteBeamState.initial(llm2, beam_params),
+            )
+            return await asyncio.gather(beam1.prefill(prompt1), beam2.prefill(prompt2))
+
+        beam_state_1, beam_state_2 = await setup()
+        data_dict_1[b""] = beam_state_1
+        data_dict_2[b""] = beam_state_2
+        return cls(
+            llm1,
+            llm2,
+            convert_to_logop(op),
+            data_dict_1,
+            data_dict_2,
+            vocab=list(range(256)),
+        )
+
+    async def _cleanup_cache(self):
+        """Remove old entries to avoid cache bloat."""
+        max_len = max(
+            (len(split_with_atomic_tokens(k, self.eos_tokens)) for k in self.data_dict_1),
+            default=0,
+        )
+        min_len = max_len - 2
+        for d in [self.data_dict_1, self.data_dict_2]:
+            for k in list(d.keys()):
+                # Measure keys in the same units as max_len (an atomic token counts as one).
+                if len(split_with_atomic_tokens(k, self.eos_tokens)) < min_len:
+                    del d[k]
+
+    async def get_beam_states(self, context: List[int]):
+        """Fetch beam states for the current context."""
+        ctx_bytes = bytes(context)
+        await self._cleanup_cache()
+        return self.data_dict_1[ctx_bytes], self.data_dict_2[ctx_bytes]
+
+    async def prefix(self, context: List[int]):
+        """Stub for the abstract method; not used by this example."""
+        return None
+
+    async def complete(self, context: List[int]):
+        """Stub for the abstract method; not used by this example."""
+        return None
+
+
+class ByteEnsembleTokenSampler(TokenSampler):
+    """
+    Token sampler that draws from an ensemble of potentials using a log-space proposal strategy.
+
+    Args:
+        potential: The target ensemble potential.
+        proposal: How to combine log-probabilities ('linear', 'abs', etc.).
+        n_particles: Number of particles for SMC sampling.
+        eos_tokens: List of end-of-sequence tokens (as bytes).
+        max_tokens: Maximum number of tokens to generate.
+        models_equal: Whether the two potentials share the same base LM.
+    """
+
+    def __init__(
+        self,
+        potential: ByteEnsemble,
+        proposal: Literal["linear", "abs", "square", "soft n"] = "linear",
+        n_particles: int = 10,
+        eos_tokens: List[bytes] = [],
+        max_tokens: int = None,
+        models_equal: bool = False,
+    ):
+        super().__init__(target=potential)
+        self.potential = potential
+        self.proposal = proposal
+        self.n_particles = n_particles
+        self.eos_tokens = eos_tokens
+        self.max_tokens = max_tokens
+        self.models_equal = models_equal
+
+        self.prefix_cache_1 = LRUCache(maxsize=3 * n_particles)
+        self.prefix_cache_2 = LRUCache(maxsize=3 * n_particles)
+        self.particle_prefix_log_prob_1 = defaultdict()
+        self.particle_prefix_log_prob_2 = defaultdict()
+
+        self.prefix_cache_1[()] = 0.0
+        self.prefix_cache_2[()] = 0.0
+
+    async def start_weight(self) -> float:
+        return 0.0
+
+    async def sample(self, context: List[int]) -> Tuple[int, float, float]:
+        """Sample one token from the ensemble distribution."""
+        beam1, beam2 = await self.potential.get_beam_states(context)
+        logp_1, logp_2 = await beam1.logp_next(), await beam2.logp_next()
+
+        ctx_tuple = tuple(context)
+        log_context_weight_1 = self.prefix_cache_1[ctx_tuple]
+        log_context_weight_2 = self.prefix_cache_2[ctx_tuple]
+
+        logws1 = log_context_weight_1 + logp_1.ps
+        logws2 = log_context_weight_2 + logp_2.ps
+
+        log_shaping_weight_prev = (
+            0
+            if not context
+            else self.potential.op(log_context_weight_1, log_context_weight_2)
+        )
+
+        proposal_weights = self.potential.op(logws1, logws2) - log_shaping_weight_prev
+        logps = proposal_weights - logsumexp(proposal_weights)
+        token_idx = fast_sample_logprobs(logps)[0]
+
+        token = beam1.states[0].trie.trie.trie_decode[token_idx]
+        assert token == beam2.states[0].trie.trie.trie_decode[token_idx]
+
+        next_context = (
+            bytes(context + [token])
+            if isinstance(token, int)
+            else bytes(context) + token
+        )
+        self.potential.data_dict_1[next_context] = await (beam1.prune() << token)
+        self.potential.data_dict_2[next_context] = await (beam2.prune() << token)
+
+        new_ctx_tuple = ctx_tuple + (token,)
+        self.prefix_cache_1[new_ctx_tuple] = logws1[token_idx]
+        self.prefix_cache_2[new_ctx_tuple] = logws2[token_idx]
+
+        if token in self.eos_tokens:
+            token = EOS
+
+        if token == EOS or (self.max_tokens and len(ctx_tuple) + 1 == self.max_tokens):
+            self.particle_prefix_log_prob_1[ctx_tuple + (token,)] = logws1[token_idx]
+            self.particle_prefix_log_prob_2[ctx_tuple + (token,)] = logws2[token_idx]
+
+        return token, proposal_weights[token_idx] - logps[token_idx], logps[token_idx]
+
+    async def smc(
+        self,
+        n_particles: int,
+        ess_threshold: float,
+        max_tokens: int,
+        critic=None,
+        **kwargs,
+    ):
+        """Run Sequential Monte Carlo inference."""
+        from genlm.control.sampler.sequence import EnsembleSMC
+
+        return await EnsembleSMC(self, critic)(
+            n_particles=n_particles,
+            ess_threshold=ess_threshold,
+            max_tokens=max_tokens,
+            **kwargs,
+        )
+
+
+async def main():
+    llm1 = load_model_by_name("meta-llama/Llama-3.2-1B-Instruct")
+    llm2 = load_model_by_name("meta-llama/Llama-3.2-1B-Instruct")
+
+    prompt1 = (
+        b"London is good."
+        + llm1.byte_vocab[llm1.tokenizer.eos_token_id]
+        + b"Paris is "
+    )
+    prompt2 = (
+        b"London is good."
+        + llm2.byte_vocab[llm2.tokenizer.eos_token_id]
+        + b"Paris is "
+    )
+
+    ensemble = await ByteEnsemble.create(
+        llm1, llm2, op="prod", prompt1=prompt1, prompt2=prompt2
+    )
+    max_tokens = 25
+    eos_tokens = [
+        llm1.byte_vocab[llm1.tokenizer.eos_token_id],
+        llm2.byte_vocab[llm2.tokenizer.eos_token_id],
+    ]
+    sampler = ByteEnsembleTokenSampler(
+        ensemble, max_tokens=max_tokens, eos_tokens=eos_tokens, n_particles=5
+    )
+
+    result = await sampler.smc(n_particles=10, ess_threshold=0.5, max_tokens=max_tokens)
+    print(result.posterior)
+    print(result.decoded_posterior)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
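
For intuition about what convert_to_logop sets up, here is a minimal standalone sketch of the "prod" and "sum" combinations on toy three-symbol distributions (not part of the patch; scipy's logsumexp stands in for arsenal.maths.logsumexp):

import numpy as np
from scipy.special import logsumexp  # stand-in for arsenal.maths.logsumexp

# Two toy next-byte distributions in log space.
logp1 = np.log([0.7, 0.2, 0.1])
logp2 = np.log([0.1, 0.2, 0.7])

# op="prod": add log-probs (multiply probabilities), then renormalize,
# as the sampler does with proposal_weights above.
prod = logp1 + logp2
print(np.exp(prod - logsumexp(prod)).round(2))  # [0.39 0.22 0.39]

# op="sum": logsumexp mixes the two distributions; after renormalization
# this is the 50/50 mixture.
mix = logsumexp([logp1, logp2], axis=0)
print(np.exp(mix - logsumexp(mix)).round(2))  # [0.4 0.2 0.4]
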
diff --git a/genlm/bytes/byte_lm/beam.py b/genlm/bytes/byte_lm/beam.py
index f53d739..dfd584a 100644
--- a/genlm/bytes/byte_lm/beam.py
+++ b/genlm/bytes/byte_lm/beam.py
@@ -53,7 +53,7 @@ def __init__(self, states, params):
         self.params = params

     @classmethod
-    async def initial(cls, llm, params):
+    async def initial(cls, llm, params, eos_tokens=None):
         """Creates initial beam state.

         Args:
@@ -63,8 +63,17 @@ async def initial(cls, llm, params):
         Returns:
             (ByteBeamState): Initial beam state.
         """
+        eos_tokens = (
+            [llm.tokenizer.eos_token.encode("utf-8")]
+            if eos_tokens is None
+            else eos_tokens
+        )
         state = LazyTrieState.initial(
-            llm, AsyncTokenByteTrie.from_vocab(get_byte_vocab(llm.tokenizer))
+            llm,
+            AsyncTokenByteTrie.from_vocab(
+                get_byte_vocab(llm.tokenizer),
+                eos_tokens=eos_tokens,
+            ),
         )
         return cls([await state.materialize()], params)

@@ -121,16 +130,19 @@ async def logp_next(self):
         for state in await self.extend(self.logZ):
             logqs.append(state.logp_next.ps + state.weight)

-        logqs = np.stack(logqs, axis=0)  # shape: (num_states, 257)
+        logqs = np.stack(logqs, axis=0)
         logqs[: len(self), -1] = -np.inf  # mask EOT positions of non-extended
         logps = scipy_logsumexp(logqs, axis=0)
-        return LazyByteProbs(logps - logsumexp(logps))
+        # The byte encode/decode maps are identical across states in the beam.
+        encode = self.states[0].trie.trie.trie_encode
+        decode = self.states[0].trie.trie.trie_decode
+        return LazyByteProbs(logps - logsumexp(logps), encode=encode, decode=decode)

     async def extend(self, logZ):
         """Attempts to advance each candidate in the beam by a token (EOT).

-        For each candididate with EOT available, this ends the current token and
+        For each candidate with EOT available, this ends the current token and
         starts a new one in preparation for the next byte.

         Args:
diff --git a/genlm/bytes/byte_lm/lm_state.py b/genlm/bytes/byte_lm/lm_state.py
index ee238b7..f1fc838 100644
--- a/genlm/bytes/byte_lm/lm_state.py
+++ b/genlm/bytes/byte_lm/lm_state.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from arsenal.maths import sample_dict

-from ..util import escape
+from ..util import escape, split_with_atomic_tokens


 class StatefulTokenizedLM:
@@ -117,7 +117,9 @@ async def prefill(self, bs):
             (StatefulByteLM): New state with all bytes added
         """
         state = self
-        for b in bs:
+        trie = state.states[0].trie.trie
+        atomic_tokens = (trie.atomic_tokens or []) + (trie.eos_tokens or [])
+        for b in split_with_atomic_tokens(bs, atomic_tokens):
             state = await (state.prune() << b)
         return state

@@ -133,12 +135,18 @@ async def greedy(self, context, steps):
         """
         context = list(context)
         state = await self.prefill(context)
+        eos_tokens = self.states[0].trie.trie.eos_tokens
         for _ in range(steps):
             Q = (await state.logp_next()).materialize()
             b = Q.argmax()
             state = await (state.prune() << b)
             context.append(b)
-        return bytes(context)
+            if b in eos_tokens:
+                break
+        return b''.join(
+            bytes([item]) if isinstance(item, int) else item
+            for item in context
+        )

     async def sample(self, context, steps, draw=sample_dict):
         """Samples from the model for given number of steps.
@@ -153,12 +161,19 @@ async def sample(self, context, steps, draw=sample_dict):
         """
         context = list(context)
         state = await self.prefill(context)
+        eos_tokens = self.states[0].trie.trie.eos_tokens
         for _ in range(steps):
             Q = (await state.logp_next()).materialize()
             b = draw(Q.map_values(exp))
             state = await (state.prune() << b)
             context.append(b)
-        return bytes(context)
+            if b in eos_tokens:
+                break
+        return b''.join(
+            bytes([item]) if isinstance(item, int) else item
+            for item in context
+        )
+

     async def cleanup(self):
         """Performs any necessary cleanup of the model state."""
diff --git a/genlm/bytes/byte_lm/trie_state.py b/genlm/bytes/byte_lm/trie_state.py
index a1a40bc..f6ed4ff 100644
--- a/genlm/bytes/byte_lm/trie_state.py
+++ b/genlm/bytes/byte_lm/trie_state.py
@@ -3,7 +3,7 @@
 from functools import cached_property
 from arsenal import colors
 from .lm_state import StatefulTokenizedLM
-from ..util import escape, LazyByteProbs
+from ..util import escape, LazyByteProbs, logsumexp


 class LazyTrieState:
@@ -87,12 +87,25 @@ def __lshift__(self, b):
         """
         if node := self.children[self.node].get(b):
             mass = self.mass
+            if b in self.trie.trie.eos_tokens:
+                cum_eos_logprob = logsumexp(
+                    self.mass[
+                        [
+                            result
+                            for eos_token in self.trie.trie.eos_tokens
+                            if (result := self.children[self.node].get(eos_token)) is not None
+                        ]
+                    ]
+                )
+                node_mass = cum_eos_logprob
+            else:
+                node_mass = mass[node]
             return LazyTrieState(
                 lm_state=self.lm_state,
                 trie=self.trie,
                 mass=mass,
                 node=node,
-                weight=self.weight + mass[node] - mass[self.node],
+                weight=self.weight + node_mass - mass[self.node],
             )

     def extend(self):
@@ -120,12 +133,14 @@ def logp_next(self):
         Returns:
             (LazyByteProbs): Lazy log probability distribution over possible next bytes
         """
-        logps = np.full(257, -np.inf)  # 257 for EOT
+        logps = np.full(len(self.trie.trie.trie_encode), -np.inf)
         mass = self.mass
         logZ = mass[self.node]
         for byte, node in self.actions().items():
-            logps[byte if byte is not None else 256] = mass[node] - logZ
-        return LazyByteProbs(logps)
+            logps[self.trie.trie.trie_encode[byte]] = mass[node] - logZ
+        return LazyByteProbs(
+            logps, encode=self.trie.trie.trie_encode, decode=self.trie.trie.trie_decode
+        )

     async def materialize(self):
         """Materializes the masses for each node in the trie for the current state.
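
The cum_eos_logprob branch above folds the mass of every EOS alternative into a single step: entering any one EOS byte string is credited with the total end-of-sequence probability under the current node. A toy log-space check of that sum (plain numpy/scipy, not the trie classes):

import numpy as np
from scipy.special import logsumexp

# Hypothetical log-masses of two EOS children under the current trie node.
eos_child_logmass = np.log([0.03, 0.01])

cum_eos_logprob = logsumexp(eos_child_logmass)
print(round(float(np.exp(cum_eos_logprob)), 4))  # 0.04 = 0.03 + 0.01
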
diff --git a/genlm/bytes/trie.py b/genlm/bytes/trie.py
index 22211c5..2ad56b4 100644
--- a/genlm/bytes/trie.py
+++ b/genlm/bytes/trie.py
@@ -11,7 +11,9 @@
 class TokenByteTrie:
     """A trie data structure for efficient token-to-byte mapping."""

-    def __init__(self, decode, device=None, atomic_tokens=None, eot_token=None):
+    def __init__(
+        self, decode, device=None, atomic_tokens=None, eot_token=None, eos_tokens=None
+    ):
         """Initialize a `TokenByteTrie`.

         Args:
@@ -27,6 +29,8 @@ def __init__(self, decode, device=None, atomic_tokens=None, eot_token=None):
             raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

         self.eot_token = eot_token
+        self.eos_tokens = eos_tokens or []
+        self.atomic_tokens = atomic_tokens
         self._build_trie(atomic_tokens or [])
         self._renumber()
         self._build_node2prefix()
@@ -44,7 +48,18 @@ def _build_trie(self, atomic_tokens):
         for token in atomic_tokens:
             if token not in self.decode:
                 raise ValueError(f"Atomic token {token} not in vocabulary")
+        for token in self.eos_tokens:
+            if token not in self.decode:
+                raise ValueError(f"EOS token {token} not in vocabulary")

+        # Construct mappings between the byte-level vocabulary and indices in the weight array.
+        self.trie_decode = (
+            list(range(256)) + atomic_tokens + self.eos_tokens + [self.eot_token]
+        )
+        self.trie_encode = {
+            k: v for k, v in zip(self.trie_decode, range(len(self.trie_decode)))
+        }
+        self.weight_encode = {}
         self.word2leaf = {}
         self.children = [{}]  # First node is root
         self.root = 0
@@ -57,7 +72,8 @@ def _build_trie(self, atomic_tokens):
             self.lookup[word] = token_id

             curr = self.root
-            letters = [word] if word in atomic_tokens else word
+
+            letters = [word] if word in atomic_tokens + self.eos_tokens else word
             for letter in letters:
                 if letter not in self.children[curr]:
                     self.children[curr][letter] = len(self.children)
@@ -211,7 +227,7 @@ def _preprocess_ws(self, batch_ws):
         """Preprocess weight sums for batch processing.

         Args:
-            batch_ws (list|np.ndarray|torch.Tensor): List of weight sum tensors or lists of weight sums.
+            batch_ws (list|np.ndarray|torch.Tensor): List of weight sum tensors or lists of weight sums.

         Returns:
             (torch.Tensor): Stacked weight sum tensor.
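
The trie_decode / trie_encode pair built above fixes the index layout used by logp_next and LazyByteProbs: indices 0-255 are the raw bytes, followed by any atomic tokens, then the EOS tokens, with EOT last. A small sketch of that layout (the token values are made up for illustration; in practice they come from the tokenizer, and eot_token defaults to None):

atomic_tokens = [b"<|special|>"]   # hypothetical
eos_tokens = [b"<|eot_id|>"]       # hypothetical
eot_token = None                   # EOT keeps its None marker

trie_decode = list(range(256)) + atomic_tokens + eos_tokens + [eot_token]
trie_encode = {k: v for v, k in enumerate(trie_decode)}

assert trie_encode[65] == 65                      # raw bytes keep their own index
assert trie_encode[b"<|special|>"] == 256         # atomic tokens follow the bytes
assert trie_encode[b"<|eot_id|>"] == 257          # then EOS tokens
assert trie_encode[None] == len(trie_decode) - 1  # EOT sits at the end
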
diff --git a/genlm/bytes/util.py b/genlm/bytes/util.py
index f50c4ba..b1fcf28 100644
--- a/genlm/bytes/util.py
+++ b/genlm/bytes/util.py
@@ -2,6 +2,8 @@
 import numpy as np
 import pandas as pd
 from IPython.display import HTML, SVG
+from typing import Union
+import warnings
 from arsenal import colors


@@ -9,16 +11,17 @@ class LazyByteProbs:
     """Represents a lazy (log) probability distribution over bytes.

-    Handles probability distributions over 256 possible bytes plus an EOT (End of Token) symbol.
+    Handles probability distributions over bytes plus an EOT (End of Token) symbol.

     Args:
-        ps (list): List of 257 probabilities (256 bytes + 1 EOT)
+        ps (list): List of probabilities, one per entry of `decode`
+        encode (dict): Maps each byte/token (or None for EOT) to its index in `ps`
+        decode (list): Maps each index in `ps` back to its byte/token
         log_space (bool, optional): Whether probabilities are in log space. Defaults to True
     """

-    def __init__(self, ps, log_space=True):
-        assert len(ps) == 257  # 256 bytes + 1 EOT
+    def __init__(self, ps, encode, decode, log_space=True):
         self.ps = ps
+        self.encode = encode
+        self.decode = decode
         self.log_space = log_space

     def __getitem__(self, b):
@@ -30,9 +33,7 @@ def __getitem__(self, b):
         Returns:
             (float): Probability (or log probability) for the byte/EOT
         """
-        if b is None:
-            return self.ps[-1]
-        return self.ps[b]
+        return self.ps[self.encode[b]]

     def materialize(self):
         """Materializes the probability distribution into a Chart.
@@ -41,9 +42,8 @@
             (Chart): Chart with probabilities for each byte/EOT
         """
         Q = Chart(-np.inf if self.log_space else 0)
-        for b, p in enumerate(self.ps[:-1]):
+        for b, p in zip(self.decode, self.ps):
             Q[b] = p
-        Q[None] = self.ps[-1]
         return Q

     def pretty(self):
         """
             (str): Pretty string representation of the probability distribution
         """
         return self.materialize().map_keys(
-            lambda x: bytes([x]) if x is not None else "EOT"
+            lambda x: bytes([x]) if x in range(256) else ("EOT" if x is None else x)
         )
@@ -272,3 +272,60 @@ def escape(x):
     else:
         y = repr(x)[1:-1]
     return y.replace(" ", "␣")
+
+
+def split_with_atomic_tokens(data: bytes, atomic_tokens: list[bytes]) -> list[Union[int, bytes]]:
+    """
+    Splits a bytestring into a list of individual bytes (as integers) and atomic tokens (as bytes),
+    depending on whether the current position matches an atomic token.
+
+    Args:
+        data (bytes): The input byte string to split.
+        atomic_tokens (list[bytes]): A list of byte substrings that are treated as indivisible atomic tokens.
+
+    Returns:
+        list[Union[int, bytes]]: A list where each element is either:
+            - an atomic token (as bytes), if a match is found at that position, or
+            - a single byte (as an int), if no atomic token matches.
+
+    Notes:
+        - Matching is left-to-right: at each position, the function checks for atomic token matches
+          starting from length 1 up to the maximum token length.
+        - Only the first (shortest-prefix) match is used; a longer overlapping token may be missed if a
+          shorter prefix matches first.
+        - If atomic tokens overlap (e.g., b"A" and b"AB"), a warning is raised and only the shortest
+          matching prefix is used.
+
+    Example:
+        >>> split_with_atomic_tokens(b"ABC", [b"A", b"AB"])
+        [b'A', 66, 67]  # b"AB" is not matched because b"A" matched first
+    """
+    # Detect overlapping atomic tokens
+    for i, token1 in enumerate(atomic_tokens):
+        for j, token2 in enumerate(atomic_tokens):
+            if i != j and (token1.startswith(token2) or token2.startswith(token1)):
+                warnings.warn(
+                    f"Overlapping atomic tokens detected: {token1!r} and {token2!r}. "
+                    "Only the shortest matching prefix will be used."
+                )
+                break  # one warning per token is enough
+
+    result = []
+    i = 0
+    token_set = set(atomic_tokens)
+    max_len = max(len(t) for t in atomic_tokens) if atomic_tokens else 0
+
+    while i < len(data):
+        matched = False
+        for length in range(1, max_len + 1):
+            fragment = data[i : i + length]
+            if fragment in token_set:
+                result.append(fragment)
+                i += length
+                matched = True
+                break
+        if not matched:
+            result.append(data[i])
+            i += 1
+
+    return result
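
A quick usage sketch for the helper above (plain bytes come back as ints, while a multi-byte EOS marker survives as a single unit; the b"<eos>" token is made up for illustration):

from genlm.bytes.util import split_with_atomic_tokens

print(split_with_atomic_tokens(b"hi<eos>!", [b"<eos>"]))
# [104, 105, b'<eos>', 33]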