feat(eval): add loglikelihood support for BD3LM eval harness#95

Open
zamal-db wants to merge 2 commits into ZHZisZZ:main from zamal-db:feat/bd3lm-loglikelihood

Conversation


@zamal-db zamal-db commented Feb 26, 2026

Summary

Implements Monte Carlo ELBO loglikelihood() for BD3LMEvalHarness, enabling BD3LM models to be evaluated on likelihood-based benchmarks (ARC-Challenge, ARC-Easy, HellaSwag, WinoGrande, PIQA, MMLU, etc.).

Addresses #93 — previously, all BD3LM evaluations were generation-based only, and loglikelihood() raised NotImplementedError.

Motivation

As noted by @lingjiechen2 in #93:

The core issue is that we have not yet implemented the loglikelihood() method for the bd3lm models. All of our current evaluations for bd3lm are generation-based.

This PR fills that gap by implementing the proper BD3LM loglikelihood with the correct [x_t ⊕ x_0] block-diffusion attention, matching the training procedure exactly.

Mathematical Framework

BD3LM factorizes the likelihood over B blocks (Arriola et al., 2025):

$$\log p_\theta(x) \geq \mathcal{L}_{BD}(x,\theta) := \sum_{b=1}^{B} \mathcal{L}_{\text{diffusion}}(x^b, x^{<b}, \theta)$$

Each per-block ELBO uses the absorbing-state diffusion parameterization. The MC estimator is:

$$\hat{\mathcal{L}} = -\frac{1}{N}\sum_{n=1}^{N} \sum_{i \in M_n} \frac{\text{CE}(f_\theta(x_t, x^{<b})_i, x_i)}{p_{\text{mask},i}}$$

where p_mask = k / L' is the importance weight, k is the number of masked tokens, and L' is the target (continuation) length.

The forward process (token-level masking) is identical to MDLM — the block structure is enforced only through the attention mask, not the masking pattern.
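Concretely, the forward process above can be sketched in a few lines. This is an illustrative sketch, not the repo's API: `forward_process` and `MASK_ID` are hypothetical names, and the real implementation operates on batched tensors rather than Python lists.

```python
import random

MASK_ID = 151669  # hypothetical: use the tokenizer's actual mask token id


def forward_process(seq, prompt_len, rng=random):
    """Token-level forward process (identical to MDLM): sample k ~ U{1..L'},
    mask k random target positions, and return p_mask = k / L', the
    importance weight that divides each masked token's cross-entropy."""
    target_len = len(seq) - prompt_len            # L': continuation length
    k = rng.randint(1, target_len)                # number of tokens to mask
    positions = rng.sample(range(prompt_len, len(seq)), k)
    perturbed = list(seq)
    for i in positions:
        perturbed[i] = MASK_ID                    # absorbing-state masking
    return perturbed, set(positions), k / target_len
```

Averaging the importance-weighted cross-entropy over N such draws gives the MC estimator above; the block structure plays no role here, only in the attention mask.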

Implementation

1 file changed: dllm/core/eval/bd3lm.py (235 insertions, 3 deletions)

The only architecturally novel method is _get_logits(), which differs from MDLM by:

| Component | MDLM | BD3LM (this PR) |
| --- | --- | --- |
| Model input | `x_t` (length L) | `[x_t ⊕ x_0]` (length 2L) |
| Attention | Full bidirectional | Block-diffusion (M_BD \| M_OBC \| M_BC) |
| Position IDs | Default sequential | `[0..L-1, 0..L-1]` (duplicated) |
| Logits used | All L | First L only (`logits[:, :l]`) |

All other methods (_forward_process, _get_loglikelihood, _suffix_greedy_prediction, loglikelihood) follow the same MC ELBO framework as MDLMEvalHarness.
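The input construction from the table can be sketched as follows. This is a minimal illustration (`make_bd3lm_inputs` is a hypothetical helper, not the PR's code); the actual `_get_logits` also builds the block-diffusion attention mask before the forward pass.

```python
import torch


def make_bd3lm_inputs(x_t: torch.Tensor, x_0: torch.Tensor):
    """Build the [x_t ⊕ x_0] input (length 2L) and the duplicated
    position IDs [0..L-1, 0..L-1] so both halves share positions."""
    b, L = x_t.shape
    input_ids = torch.cat([x_t, x_0], dim=1)                   # [b, 2L]
    position_ids = (
        torch.arange(L).repeat(2).unsqueeze(0).expand(b, -1)   # [b, 2L]
    )
    return input_ids, position_ids
```

After the forward pass, only the first half of the logits (the `x_t` predictions) is kept, i.e. something like `logits[:, :L]`.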

Key design decisions:

  • Reuses _create_bd3lm_attention_mask from dllm.core.trainers.bd3lm — same function used in training
  • Supports both SDPA and flex_attention — mirrors the trainer's dual-path implementation
  • Handles self.accelerator = None for single-GPU eval (BaseEvalHarness sets this to None when num_processes == 1)
  • Inherits BaseEvalHarness (not MDLM) — independent implementation, consistent with Dream's eval pattern
  • No changes to __init__.py: BD3LMEvalConfig and BD3LMEvalHarness are already exported
  • A2DBD3LMEvalHarness automatically inherits loglikelihood() via class hierarchy
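For intuition, the three-component attention mask over `[x_t ⊕ x_0]` can be built dense as below. This is a sketch of the mask semantics described in the paper, not the repo's `_create_bd3lm_attention_mask` (which also supports flex_attention score-mod form); the function name is illustrative.

```python
import torch


def bd3lm_attention_mask(seq_len: int, block_size: int) -> torch.Tensor:
    """Boolean [2L, 2L] mask over [x_t ⊕ x_0] (True = may attend).
    Rows/cols 0..L-1 are x_t positions, L..2L-1 are x_0 positions."""
    L = seq_len
    blk = torch.arange(L) // block_size
    same = blk[:, None] == blk[None, :]      # query and key in the same block
    earlier = blk[:, None] > blk[None, :]    # key block strictly before query block
    mask = torch.zeros(2 * L, 2 * L, dtype=torch.bool)
    mask[:L, :L] = same                      # M_BD: x_t attends within its own block
    mask[:L, L:] = earlier                   # M_OBC: x_t attends to earlier clean x_0
    mask[L:, L:] = same | earlier            # M_BC: x_0 is block-causal over x_0
    return mask                              # mask[L:, :L] stays False: x_0 never sees x_t
```

With block_size=1 the x_0 half degenerates to a strictly causal mask plus the diagonal, i.e. the autoregressive limit exercised in the edge-case tests below.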

Usage

accelerate launch \
    --num_processes 4 \
    dllm/pipelines/a2d/eval.py \
    --tasks arc_challenge \
    --model a2d_bd3lm \
    --num_fewshot 0 \
    --model_args "pretrained=dllm-collection/Qwen3-0.6B-diffusion-bd3lm-v0.1,mc_num=128,batch_size=32,block_size=32"

Verification

  • Black formatting passes
  • Import chain verified (BD3LMEvalHarness, A2DBD3LMEvalHarness)
  • Attention mask verified against paper's 3-component decomposition (M_BD, M_OBC, M_BC) with concrete traces
  • Edge cases tested: partial final blocks, single block, block_size=1 (AR limit)
  • x_0 tokens verified to NEVER attend to x_t tokens
  • Config fields (mc_num, batch_size, block_size, is_check_greedy) follow MDLM conventions

Implement Monte Carlo ELBO loglikelihood estimation for BD3LM models,
enabling evaluation on likelihood-based benchmarks (ARC, HellaSwag, etc.).

The key architectural difference from MDLM is _get_logits, which constructs
the [x_t ⊕ x_0] input with block-diffusion attention (M_BD | M_OBC | M_BC)
and duplicated position IDs, matching the BD3LM training procedure exactly.

- Add _get_logits with [x_t ⊕ x_0] construction and block-diffusion mask
- Add _create_attention_mask supporting both SDPA and flex_attention
- Add _forward_process, _get_loglikelihood, _suffix_greedy_prediction
- Add loglikelihood public API (lm-eval interface)
- Add mc_num, batch_size, is_check_greedy to BD3LMEvalConfig
- Reuse _create_bd3lm_attention_mask from dllm.core.trainers.bd3lm

Closes ZHZisZZ#93
@gemini-code-assist

Summary of Changes

Hello @zamal-db, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the evaluation capabilities for BD3LM models by introducing a loglikelihood method. This implementation, based on Monte Carlo ELBO with block-diffusion attention, allows these models to be assessed on a wider range of likelihood-based benchmarks, filling a critical gap where only generation-based evaluations were previously supported.

Highlights

  • Loglikelihood Support: Implemented the Monte Carlo ELBO loglikelihood() method for BD3LMEvalHarness, enabling BD3LM models to be evaluated on likelihood-based benchmarks.
  • Block-Diffusion Attention: Introduced the correct [x_t ⊕ x_0] block-diffusion attention mechanism for loglikelihood calculation, precisely matching the training procedure.
  • Addressing Previous Limitation: Resolved the NotImplementedError that previously occurred when attempting to use loglikelihood() with BD3LM models, which were limited to generation-based evaluations.
  • Flexible Attention Implementation: The attention mask creation supports both SDPA and flex_attention, mirroring the dual-path implementation in the trainer.
Changelog
  • dllm/core/eval/bd3lm.py
    • Updated the module docstring to reflect the new loglikelihood support via Monte Carlo ELBO.
    • Added necessary imports for torch, tqdm, Instance, partial, and _create_bd3lm_attention_mask.
    • Extended BD3LMEvalConfig with batch_size, mc_num, and is_check_greedy parameters.
    • Revised the BD3LMEvalHarness docstring to indicate loglikelihood support and the use of block-diffusion attention.
    • Initialized new instance attributes (mask_id, max_length, mc_num, is_check_greedy, block_size) in the __init__ method.
    • Implemented _encode_pair to handle context and continuation encoding, including trailing space management.
    • Added _create_attention_mask to generate the BD3LM block-diffusion attention mask, supporting both SDPA and flex_attention.
    • Introduced _get_logits for the BD3LM forward pass, processing [x_t ⊕ x_0] input with block-diffusion attention and returning x_t predictions.
    • Implemented _forward_process to apply the forward diffusion process by masking a random subset of target tokens.
    • Developed _get_loglikelihood to compute the Monte Carlo estimate of log-likelihood using the forward process and logits.
    • Added _suffix_greedy_prediction for an optional greedy unmasking check.
    • Implemented the public loglikelihood method, which orchestrates the encoding, log-likelihood estimation, and greedy prediction for evaluation requests.
Activity
  • Addressed issue Evaluation of A2D models(mdllm/bd3lm) on base tasks #93, which highlighted the lack of loglikelihood() implementation for BD3LM models.
  • Verified Black formatting, import chain correctness, and attention mask behavior against the paper's decomposition.
  • Tested edge cases including partial final blocks, single blocks, and block_size=1 (autoregressive limit).
  • Confirmed that x_0 tokens are prevented from attending to x_t tokens.
  • Ensured config fields (mc_num, batch_size, block_size, is_check_greedy) adhere to MDLM conventions.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements loglikelihood() for BD3LMEvalHarness, enabling evaluation of BD3LM models on likelihood-based benchmarks using a Monte Carlo ELBO estimate. A critical security concern was identified: an empty continuation string can lead to a Denial of Service vulnerability by causing a crash due to division by zero or an invalid range in the forward diffusion process; a graceful fix is recommended. Furthermore, to enhance maintainability, consider refactoring a helper function to remove statefulness by explicitly passing data as arguments, which will improve robustness and readability.

```python
    continuation_enc, device=self.device, dtype=torch.long
)

logprob = self._get_loglikelihood(context, continuation)
```


Severity: medium (security)

The _forward_process method crashes when the continuation (target) is empty because target_len becomes 0. This leads to an invalid range for torch.randint(1, target_len + 1) on line 164, and division/modulo by zero on lines 174 and 197. This can be triggered by malformed requests or edge cases in benchmark datasets, leading to a Denial of Service (DoS) of the evaluation process.

Suggested change

```diff
-logprob = self._get_loglikelihood(context, continuation)
+logprob = self._get_loglikelihood(context, continuation) if continuation.shape[0] > 0 else 0.0
```

Comment on lines +134 to +141
```python
def _get_logits(
    self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> torch.Tensor:
    """BD3LM forward: [x_t ⊕ x_0] with block-diffusion attention, return x_t logits."""
    b, l = batch.shape

    # [x_t ⊕ x_0]: noised first half, clean second half
    concat_input_ids = torch.cat([batch, self._x0], dim=1)  # [b, 2l]
```


Severity: medium

The use of self._x0 to pass the clean sequence to this method makes it stateful and reliant on its callers to set this property correctly. This can be fragile and harder to maintain.

The suggested change refactors this method to accept x0 as an explicit argument, making its dependencies clear.

Please update the calls to _get_logits in _get_loglikelihood and _suffix_greedy_prediction accordingly:

  • In _get_loglikelihood:

    ```python
    # ...
    x0 = seq.clone()
    # ...
    logits = self._get_logits(perturbed_seq, x0, prompt_index)
    # ...
    ```
  • In _suffix_greedy_prediction:

    ```python
    # ...
    x0 = torch.cat([prefix, target]).unsqueeze(0)
    # ...
    logits = self._get_logits(seq, x0, prompt_index)[mask_index]
    # ...
    ```
Suggested change

```diff
 def _get_logits(
-    self, batch: torch.Tensor, prompt_index: torch.Tensor
+    self, xt: torch.Tensor, x0: torch.Tensor, prompt_index: torch.Tensor
 ) -> torch.Tensor:
     """BD3LM forward: [x_t ⊕ x_0] with block-diffusion attention, return x_t logits."""
-    b, l = batch.shape
+    b, l = xt.shape
     # [x_t ⊕ x_0]: noised first half, clean second half
-    concat_input_ids = torch.cat([batch, self._x0], dim=1)  # [b, 2l]
+    concat_input_ids = torch.cat([xt, x0], dim=1)  # [b, 2l]
```

- Add guard for empty continuation (returns 0.0, False) to prevent
  DoS via division-by-zero in _forward_process (target_len=0)
- Refactor _get_logits to accept x0 as explicit parameter instead
  of reading from self._x0, improving maintainability
- Rename x0 -> x0_clean in _suffix_greedy_prediction to avoid
  shadowing by the greedy argmax variable

Co-authored-by: gemini-code-assist[bot] <176aborting-id@users.noreply.github.com>

sglucas commented Feb 27, 2026

Hi @zamal-db, thank you very much for your great work! I have verified your code for evaluation on arc-easy, arc-challenge, hellaswag, and piqa, and it works well on these tasks. But I found there is still an error when I try to evaluate on mmlu, and I am not sure whether my running command is correct:

accelerate launch \
    --num_processes 8 \
    --main_process_port 39501 \
    dllm/pipelines/a2d/eval.py \
    --tasks mmlu \
    --model a2d_bd3lm \
    --num_fewshot 5 \
    --model_args "pretrained=${model_name_or_path},mc_num=1,block_size=32,max_length=4096"

And the error is:

[rank3]: Traceback (most recent call last):
[rank3]:   File "/opt/tiger/code/dllm/dllm/pipelines/a2d/eval.py", line 38, in <module>
[rank3]:     cli_evaluate()
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/__main__.py", line 459, in cli_evaluate
[rank3]:     results = evaluator.simple_evaluate(
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/utils.py", line 458, in _wrapper
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/evaluator.py", line 357, in simple_evaluate
[rank3]:     results = evaluate(
[rank3]:               ^^^^^^^^^
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/utils.py", line 458, in _wrapper
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/evaluator.py", line 600, in evaluate
[rank3]:     task.apply_filters()
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/api/task.py", line 1246, in apply_filters
[rank3]:     f.apply(self._instances)
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/api/filter.py", line 51, in apply
[rank3]:     resps = f().apply(resps, docs)
[rank3]:             ^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/filters/extraction.py", line 56, in apply
[rank3]:     filtered_resps = list(map(lambda x: filter_set(x), resps))
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/filters/extraction.py", line 56, in <lambda>
[rank3]:     filtered_resps = list(map(lambda x: filter_set(x), resps))
[rank3]:                                         ^^^^^^^^^^^^^
[rank3]:   File "/opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/filters/extraction.py", line 41, in filter_set
[rank3]:     match = self.regex.findall(resp)
[rank3]:             ^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: TypeError: expected string or bytes-like object, got 'tuple'

@zamal-db
Author

Hey @sglucas, thanks for testing this, and great to hear arc-easy, arc-challenge, hellaswag, piqa all work!

I dug deep into the MMLU regex error and I'm pretty confident this isn't coming from our code or from --tasks mmlu. Here's why:

--tasks mmlu resolves to the mmlu/default variant. I checked _default_template_yaml directly and it's output_type: multiple_choice with no filter_list at all. When there's no filter_list, lm-eval falls back to just take_first (task.py line 887), which simply picks r[0] from each response. No regex involved.

The RegexFilter at extraction.py:41 (self.regex.findall(resp)) only runs when a YAML config explicitly declares function: "regex" in its filter_list. The only MMLU variants that do this are mmlu_generative and mmlu_generative_dream, both of which use output_type: generate_until (strings), not multiple_choice (tuples). So there's no config in the bundled harness where a regex filter would ever see a (float, bool) tuple.

And the 4 tasks you confirmed working (arc-easy, arc-challenge, hellaswag, piqa) all go through the exact same pipeline: output_type: multiple_choice, no filter_list. Identical to mmlu default.
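For reference, the failure mode in the traceback is easy to reproduce in isolation: multiple_choice responses are (loglikelihood, is_greedy) tuples, and re.Pattern.findall rejects anything that isn't a string. The regex here is illustrative, not the harness's actual pattern.

```python
import re

pattern = re.compile(r"\((.?)\)")  # hypothetical extraction regex
resp = (-1.23, False)              # a multiple_choice response: (loglikelihood, is_greedy)

try:
    pattern.findall(resp)          # RegexFilter expects generate_until strings
except TypeError as exc:
    # this is the "expected string or bytes-like object, got 'tuple'" from the traceback
    print(type(exc).__name__)
```

So a regex filter can only ever see a tuple if a mismatched lm-eval install routes multiple_choice responses through a filter_list that the bundled submodule doesn't have.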

So where is this regex coming from? My best guess is there's a separate lm-eval installed in the environment (pip/conda) that's getting loaded instead of the bundled submodule. Could you check:

python -c "import lm_eval; print(lm_eval.__file__)"

It should point to /opt/tiger/code/dllm/lm-evaluation-harness/lm_eval/__init__.py. If it's pointing somewhere else, that's the issue. The official eval scripts set export PYTHONPATH=.:$PYTHONPATH for exactly this reason (see examples/a2d/bd3lm/eval.sh).

Also worth checking: git submodule status, expected hash is dca091db.


For MMLU specifically, the official eval scripts actually use --tasks mmlu_generative_dream (generation-based, not likelihood-based):

accelerate launch \
    --num_processes 8 \
    --main_process_port 39501 \
    dllm/pipelines/a2d/eval.py \
    --tasks mmlu_generative_dream \
    --model a2d_bd3lm \
    --num_fewshot 0 \
    --apply_chat_template \
    --model_args "pretrained=${model_name_or_path},max_new_tokens=3,steps=3,block_size=32,cfg_scale=0.0"

For likelihood-based mmlu (using our loglikelihood()):

accelerate launch \
    --num_processes 8 \
    --main_process_port 39501 \
    dllm/pipelines/a2d/eval.py \
    --tasks mmlu \
    --model a2d_bd3lm \
    --num_fewshot 5 \
    --model_args "pretrained=${model_name_or_path},mc_num=128,batch_size=32,block_size=32,max_length=4096"

One small note: mc_num=1 technically works (the assert passes since CLI default batch_size=1), but it's a single MC sample so the estimate will be very noisy. mc_num=128 is the default and what I'd recommend.

Let me know what lm_eval.__file__ gives you, that should tell us exactly what's going on.
