bergson/build.py (41 additions, 0 deletions)

@@ -157,6 +157,7 @@ def worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset | IterableDataset):
     else:
         # Convert each shard to a Dataset then collect its gradients
         buf, shard_id = [], 0
+        total_tokens_processed = 0
 
         def flush():
             nonlocal buf, shard_id
@@ -180,6 +181,19 @@ def flush():
             shard_id += 1
 
         for ex in tqdm(ds, desc="Collecting gradients"):
+            # Check if adding this example would exceed max_tokens
+            if cfg.max_tokens is not None:
+                example_tokens = ex.get("length", 0)
+                if total_tokens_processed + example_tokens > cfg.max_tokens:
+                    # Flush current buffer and stop processing
+                    flush()
+                    print(
+                        f"Reached max_tokens limit ({cfg.max_tokens}). "
+                        f"Processed {total_tokens_processed} tokens."
+                    )
+                    break
+                total_tokens_processed += example_tokens
+
             buf.append(ex)
             if len(buf) == cfg.stream_shard_size:
                 flush()
@@ -229,6 +243,33 @@ def build_gradient_dataset(cfg: IndexConfig):
         new_fingerprint="advantage", # type: ignore
     )
 
+    # Apply max_tokens filtering if specified
+    if cfg.max_tokens is not None:
+        if isinstance(ds, Dataset):
+            # For non-streaming datasets, filter based on cumulative token count
+            total_tokens = 0
+            indices_to_keep = []
+
+            for i, length in enumerate(ds["length"]):
+                if total_tokens + length <= cfg.max_tokens:
+                    total_tokens += length
+                    indices_to_keep.append(i)
+                else:
+                    break
+
+            if indices_to_keep:
+                ds = ds.select(indices_to_keep)
+                print(
+                    f"Filtered dataset to {len(indices_to_keep)} examples "
+                    f"with {total_tokens} tokens (target: {cfg.max_tokens})"
+                )
+            else:
+                raise ValueError(f"No examples fit within max_tokens={cfg.max_tokens}")
+        else:
+            # For streaming datasets, max_tokens filtering is handled in the worker
+            # function during the streaming processing loop
+            pass
+
     world_size = torch.cuda.device_count()
     if world_size <= 1:
         # Run the worker directly if no distributed training is needed. This is great
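Both branches apply the same greedy prefix rule: keep examples in order until the running token total would exceed the budget, then stop. A minimal standalone sketch of that rule (a hypothetical helper, not part of this PR):

def select_within_budget(lengths: list[int], max_tokens: int) -> tuple[list[int], int]:
    """Greedy prefix selection: keep leading examples whose cumulative length fits in max_tokens."""
    total, keep = 0, []
    for i, length in enumerate(lengths):
        if total + length > max_tokens:
            break  # the first example that would overflow ends the prefix
        total += length
        keep.append(i)
    return keep, total

# With a 10-token budget, only the first two examples (4 + 5 = 9 tokens) fit.
assert select_within_budget([4, 5, 3, 2], 10) == ([0, 1], 9)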
bergson/data.py (3 additions, 0 deletions)

@@ -113,6 +113,9 @@ class IndexConfig:
     head_cfgs: dict[str, HeadConfig] = field(default_factory=dict)
     """Configuration for each attention module to be split into head matrices."""
 
+    max_tokens: int | None = None
+    """Maximum number of tokens to process. If None, process all available tokens."""
+
 
 def ceildiv(a: int, b: int) -> int:
     """Ceiling division of two integers."""
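A hedged usage sketch of the new field; IndexConfig's other fields and any required arguments are not shown in this diff, so the constructor call below assumes they have workable defaults:

from bergson.data import IndexConfig

# Hypothetical configuration: cap gradient collection at roughly one million
# tokens. max_tokens=None (the default) would process every available token.
cfg = IndexConfig(max_tokens=1_000_000)
assert cfg.max_tokens == 1_000_000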