Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ See [Features](#features) for specific training recipes.
# Preprocess SFT data
python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
--sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \
--sft_map_fn_path "dllm.utils.default_sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/llada/tulu-3-sft-mixture" \
--num_proc 64
Expand All @@ -238,7 +238,7 @@ See [Features](#features) for specific training recipes.
# Preprocess SFT data
+ python dllm/tools/preprocess_sft_dataset.py \
+ --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
+ --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \
+ --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
+ --dataset_args "allenai/tulu-3-sft-mixture" \
+ --output_dir "data/sft/llada/tulu-3-sft-mixture" \
+ --num_proc 64
Expand Down
52 changes: 34 additions & 18 deletions dllm/core/samplers/bd3lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dllm.core.samplers.utils import add_gumbel_noise, get_num_transfer_tokens


def build_staircase_attention_mask(
def _prepare_for_sampling(
x: torch.Tensor,
block_size: int,
pad_token_id: int,
Expand Down Expand Up @@ -81,7 +81,7 @@ def build_staircase_attention_mask(
return attn_mask, position_ids


def diffusion_step_block(
def _diffusion_step_block(
logits: torch.Tensor, # [B, L, V]
x_block: torch.Tensor, # [B, L]
mask_block: torch.Tensor, # [B, L] bool
Expand Down Expand Up @@ -201,9 +201,10 @@ def sample(
mask_id = self.tokenizer.mask_token_id
bos_id = self.tokenizer.bos_token_id
pad_id = self.tokenizer.pad_token_id # used as padding here
eos_id = self.tokenizer.eos_token_id

# ---- normalize inputs to tensors ----
# If right_shift_logits is true and a sequence has length 0, replace that sequence with [eos].
# If right_shift_logits is true and a sequence has length 0, replace that sequence with [bos].
if right_shift_logits:
inputs = [
[bos_id] if isinstance(p, list) and len(p) == 0 else p for p in inputs
Expand All @@ -228,16 +229,21 @@ def sample(

# ==========================================================
# 1) Initialize with prompt only (left padded with pad_id)
# pad prefix length to a multiple of block_size
# ==========================================================
padded_prompt_len = (
(max_prompt_len + block_size - 1) // block_size
) * block_size

x = torch.full(
(B, max_prompt_len),
(B, padded_prompt_len),
pad_id,
dtype=torch.long,
device=self.model.device,
)
for b, p in enumerate(inputs):
L = prompt_lens[b]
offset = max_prompt_len - L # left padding
offset = padded_prompt_len - L # left padding
x[b, offset : offset + L] = p

# Tokens considered "given" for unconditional branch in CFG.
Expand All @@ -248,6 +254,9 @@ def sample(
)
unmasked_index = unmasked_index & (~keep_mask)

# track done per sequence (EOS)
done = torch.zeros((B,), dtype=torch.bool, device=self.model.device)

# ---- block scheduling ----
num_blocks = math.ceil(max_new_tokens / block_size)
if steps_per_block is None:
Expand All @@ -260,15 +269,13 @@ def sample(
# 2) Block-by-block generation loop
# ==========================================================
for b_idx in range(num_blocks):
# Align sampling block to physical block boundaries
if done.all():
break

T_prefix = x.shape[1] # current total length before appending this block
offset = T_prefix % block_size
if offset == 0:
block_room = block_size
else:
block_room = block_size - offset

cur_block_len = min(block_room, max_new_tokens - generated)
# With padded_prompt_len aligned, we always append whole blocks (except possibly final)
cur_block_len = min(block_size, max_new_tokens - generated)
if cur_block_len <= 0:
break

Expand All @@ -278,7 +285,7 @@ def sample(
x_prefix = x # [B, T_prefix]
B_cur, T_prefix = x_prefix.shape

prefix_attn, prefix_pos = build_staircase_attention_mask(
prefix_attn, prefix_pos = _prepare_for_sampling(
x=x_prefix,
block_size=block_size,
pad_token_id=pad_id,
Expand Down Expand Up @@ -317,6 +324,13 @@ def sample(
new_block = torch.full(
(B, cur_block_len), mask_id, dtype=torch.long, device=self.model.device
)
# if done.any():
# new_block = torch.where(
# done.unsqueeze(1),
# torch.full_like(new_block, pad_id),
# new_block,
# )
Comment on lines +327 to +332

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out block is a valuable optimization. By filling new blocks with pad_id for sequences that are already done, you can avoid wasteful computation during the diffusion steps. The current implementation processes these completed sequences unnecessarily. I recommend re-enabling this logic to improve efficiency in batch generation.

Suggested change
# if done.any():
# new_block = torch.where(
# done.unsqueeze(1),
# torch.full_like(new_block, pad_id),
# new_block,
# )
if done.any():
new_block = torch.where(
done.unsqueeze(1),
torch.full_like(new_block, pad_id),
new_block,
)


x = torch.cat([x, new_block], dim=1) # [B, T_prefix + cur_block_len]

unmasked_index = torch.cat(
Expand All @@ -342,7 +356,7 @@ def sample(
effective_steps = num_transfer_tokens.size(1)

# Full staircase attention mask + pos for prefix + current block
full_attention_mask, full_position_ids = build_staircase_attention_mask(
full_attention_mask, full_position_ids = _prepare_for_sampling(
x=x,
block_size=block_size,
pad_token_id=pad_id,
Expand Down Expand Up @@ -406,7 +420,7 @@ def sample(
logits_block = shifted

# ---- One diffusion step over this block ----
x_block_updated = diffusion_step_block(
x_block_updated = _diffusion_step_block(
logits=logits_block,
x_block=x_block,
mask_block=mask_block,
Expand All @@ -421,8 +435,10 @@ def sample(
if histories is not None:
histories.append(x.clone())

if self.tokenizer.eos_token_id in x[:, T_prefix:T_total]:
break
# per-sequence EOS stopping (after finishing denoising the block)
if eos_id is not None:
eos_in_block = (x[:, T_prefix:T_total] == eos_id).any(dim=1)
done = done | eos_in_block

generated += cur_block_len

Expand All @@ -437,7 +453,7 @@ def sample(
@torch.no_grad()
def infill(
self,
inputs: list[torch.Tensor, list],
inputs: list[torch.Tensor | list],
config: SamplerConfig | None = None,
**kwargs,
) -> SamplerOutput:
Expand Down
91 changes: 48 additions & 43 deletions dllm/core/trainers/bd3lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from dllm.utils.collators import CollatorWrapper

from .mdlm import MDLMTrainer
from .utils import EpochPPLMeter


@dataclass
Expand All @@ -40,7 +39,7 @@ def before(self, features):
return features


def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
def _bd3lm_attention_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
"""
Constructs the specialized block diffusion attention mask for training
composed of three masks:
Expand Down Expand Up @@ -85,17 +84,18 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):

class BD3LMTrainer(MDLMTrainer):

@dataclass
class BD3LMConfig(MDLMTrainer.MDLMConfig):
block_size: int = 32

def __init__(
self,
block_size: int = 32,
*args,
args: BD3LMConfig,
*pargs,
**kwargs,
):
super().__init__(*args, **kwargs)
self.block_size = block_size

self.epoch_meter = EpochPPLMeter(self, train_prefix="train", eval_prefix="eval")
self.add_callback(self.epoch_meter)
super().__init__(args=args, *pargs, **kwargs)
self.block_size = args.block_size

def compute_loss(
self,
Expand Down Expand Up @@ -126,7 +126,7 @@ def compute_loss(
inputs.get("attention_mask", None),
)
b, l = input_ids.shape
token_cnt_per_seq = torch.sum(labels != -100, dim=1, keepdim=True) # [b, 1]
maskable_mask = labels != -100 # [b, l]

# === 1. Sample diffusion timesteps ===
# Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0.
Expand All @@ -139,12 +139,12 @@ def compute_loss(
# === 2. Apply stochastic masking ===
# Tokens are masked independently according to p_mask(t).
# Positions with label = -100 are excluded (ignored in loss).
masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & (
labels != -100
)
masked_mask = (
torch.rand((b, l), device=input_ids.device) < p_mask
) & maskable_mask
# Replace masked tokens with the special [MASK] token.
noised_input_ids = torch.where(
masked_indices, self.processing_class.mask_token_id, input_ids
masked_mask, self.processing_class.mask_token_id, input_ids
)

# === 3. Forward pass through the model (block-diffusion) ===
Expand All @@ -155,7 +155,7 @@ def compute_loss(

# [TODO]: others like flash attention 2
if self.accelerator.unwrap_model(model).config._attn_implementation == "sdpa":
attention_mask = block_diff_mask(
attention_mask = _bd3lm_attention_mask(
b=None,
h=None,
q_idx=torch.arange(l * 2)[:, None],
Expand All @@ -174,7 +174,7 @@ def compute_loss(
from torch.nn.attention.flex_attention import create_block_mask

attention_mask = create_block_mask(
partial(block_diff_mask, block_size=self.block_size, n=l),
partial(_bd3lm_attention_mask, block_size=self.block_size, n=l),
B=None,
H=None,
Q_LEN=l * 2,
Expand All @@ -201,46 +201,51 @@ def compute_loss(
# === 4. Handle degenerate cases (no tokens masked) ===
# If no positions were masked, return a zero loss to keep gradients valid.
# This step is necessary for Deepspeed Zero-{2,3}
if not masked_indices.any():
self.epoch_meter.update(
split="train" if model.training else "eval",
nll_sum=logits.sum() * 0.0,
token_cnt=token_cnt_per_seq.sum(),
)
return (
(logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0
if not masked_mask.any():
zero = logits.sum() * 0.0 # scalar, grad-safe
self.meter.update(
split="train" if model.training else "eval",
value=torch.zeros_like(maskable_mask, dtype=logits.dtype),
weight=maskable_mask.to(dtype=logits.dtype).detach(),
)
return (zero, outputs) if return_outputs else zero

# === 5. Compute per-token loss weights ===
# Depending on the configuration, weights may depend on timestep t
# (e.g., scheduler-based) or be uniform (ones).
loss_weights = self._compute_loss_weights(
t=t, inputs=inputs, masked_indices=masked_indices
t=t, inputs=inputs, masked_mask=masked_mask
)

# === 6. Compute weighted cross-entropy ===
# Only masked tokens contribute to the loss.
assert (input_ids[masked_indices] == labels[masked_indices]).all()
token_loss = F.cross_entropy(
logits[masked_indices], input_ids[masked_indices], reduction="none"
# Sanity check: ensure input_ids and labels match at valid positions
assert (
input_ids[maskable_mask] == labels[maskable_mask]
).all(), "Mismatch between input_ids and labels at valid positions"

token_nll = F.cross_entropy(
logits.transpose(1, 2), # [b, V, l]
input_ids, # [b, l]
reduction="none", # [b, l]
)
token_nll = token_nll * loss_weights * masked_mask.to(token_nll.dtype) # [b, l]

self.meter.update(
split="train" if model.training else "eval",
value=token_nll.detach(),
weight=maskable_mask.to(dtype=logits.dtype).detach(),
)
token_loss = token_loss * loss_weights[masked_indices]

# === 7. Normalize loss ===
if self.loss_normalization_type == "batch":
token_loss_normalized = token_loss / b
elif self.loss_normalization_type == "sequence":
token_loss_normalized = token_loss / token_cnt_per_seq.expand(-1, l)[masked_indices] / b
elif self.loss_normalization_type == "token":
token_loss_normalized = token_loss / token_cnt_per_seq.sum()
if self.loss_norm_type == "batch":
token_nll /= b
elif self.loss_norm_type == "sequence":
token_nll /= maskable_mask.sum(-1, keepdim=True).clamp_min(1) * b
elif self.loss_norm_type == "token":
token_nll /= maskable_mask.sum().clamp_min(1)
else:
raise ValueError("Invalid loss_normalization_type.")
loss = token_loss_normalized.sum()
raise ValueError("Invalid loss_norm_type.")
loss = token_nll.sum()

# === 8. Return final loss (and optionally model outputs) ===
self.epoch_meter.update(
split="train" if model.training else "eval",
nll_sum=token_loss.sum(),
token_cnt=token_cnt_per_seq.sum(),
) # `nll_sum / token_cnt` is equivalent to `loss` when `self.loss_normalization_type == "token"`
return (loss, outputs) if return_outputs else loss
Loading