Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dllm/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from dllm.core import trainers, schedulers, samplers
from dllm.core import samplers, schedulers, trainers
8 changes: 4 additions & 4 deletions dllm/core/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import base, utils, mdlm, bd3lm
from .base import SamplerOutput, SamplerConfig, BaseSampler
from . import base, bd3lm, mdlm, utils
from .base import BaseSampler, SamplerConfig, SamplerOutput
from .bd3lm import BD3LMSampler, BD3LMSamplerConfig
from .mdlm import MDLMSampler, MDLMSamplerConfig
from .utils import *
from .mdlm import MDLMSamplerConfig, MDLMSampler
from .bd3lm import BD3LMSamplerConfig, BD3LMSampler
2 changes: 1 addition & 1 deletion dllm/core/samplers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass

import torch
from transformers import PreTrainedTokenizer, PreTrainedModel
from transformers import PreTrainedModel, PreTrainedTokenizer

from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler

Expand Down
8 changes: 2 additions & 6 deletions dllm/core/samplers/bd3lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@
import torch
import torch.nn.functional as F

from dllm.core.samplers.base import (
SamplerOutput,
SamplerConfig,
BaseSampler,
)
from dllm.core.samplers.utils import get_num_transfer_tokens, add_gumbel_noise
from dllm.core.samplers.base import BaseSampler, SamplerConfig, SamplerOutput
from dllm.core.samplers.utils import add_gumbel_noise, get_num_transfer_tokens


def build_staircase_attention_mask(
Expand Down
8 changes: 2 additions & 6 deletions dllm/core/samplers/mdlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@
import torch
import torch.nn.functional as F

from dllm.core.samplers.base import (
SamplerOutput,
SamplerConfig,
BaseSampler,
)
from dllm.core.samplers.utils import get_num_transfer_tokens, add_gumbel_noise
from dllm.core.samplers.base import BaseSampler, SamplerConfig, SamplerOutput
from dllm.core.samplers.utils import add_gumbel_noise, get_num_transfer_tokens


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion dllm/core/schedulers/alpha.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import dataclasses
import math
from typing import ClassVar, Dict, Type, Any, Union
from typing import Any, ClassVar, Union

import torch

Expand Down
3 changes: 2 additions & 1 deletion dllm/core/schedulers/kappa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import dataclasses
import math
from typing import ClassVar, Dict, Type, Any, Union
from typing import Any, ClassVar, Union

import torch

Expand Down
2 changes: 1 addition & 1 deletion dllm/core/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import bd3lm, mdlm
from .mdlm import MDLMTrainer
from .bd3lm import BD3LMTrainer
from .mdlm import MDLMTrainer
6 changes: 3 additions & 3 deletions dllm/core/trainers/bd3lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
https://arxiv.org/abs/2503.09573
"""

from functools import partial
from dataclasses import dataclass
from functools import partial
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from typing import Any

from .mdlm import MDLMTrainer
from dllm.utils.collators import CollatorWrapper

from .mdlm import MDLMTrainer

# @dataclass
# class BD3LMSFTCollator(transformers.DataCollatorForSeq2Seq):
Expand Down
3 changes: 2 additions & 1 deletion dllm/core/trainers/mdlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
https://arxiv.org/abs/2502.09992
"""

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from typing import Any

from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
from dllm.utils.data import prepend_bos
Expand Down
2 changes: 1 addition & 1 deletion dllm/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .utils import load_sft_dataset, load_pt_dataset
from .utils import load_pt_dataset, load_sft_dataset
3 changes: 1 addition & 2 deletions dllm/data/alpaca.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional
from datasets import load_dataset, DatasetDict
from datasets import DatasetDict, load_dataset


def _build_alpaca_prompt(instruction: str, input_text: str | None) -> str:
Expand Down
16 changes: 7 additions & 9 deletions dllm/data/opc.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import Optional, Text, List, Dict
from datasets import (
load_dataset,
get_dataset_config_names,
concatenate_datasets,
DatasetDict,
Dataset,
IterableDatasetDict,
DatasetDict,
concatenate_datasets,
get_dataset_config_names,
load_dataset,
)

from dllm.data.utils import (
_merge_datasetdicts,
_merge_iterabledatasetdicts,
_ensure_datasetdict,
_ensure_iterabledatasetdict,
_ensure_datasetdict,
_merge_datasetdicts,
_merge_iterabledatasetdicts,
)


Expand Down
3 changes: 1 addition & 2 deletions dllm/data/ultrachat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional, List, Dict
from datasets import load_dataset, DatasetDict
from datasets import DatasetDict, load_dataset


def _extract_first_turn(messages: list[dict[str, str]]) -> dict[str, str] | None:
Expand Down
6 changes: 3 additions & 3 deletions dllm/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import re

from datasets import (
Dataset,
DatasetDict,
IterableDatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
load_from_disk,
)

from dllm.utils.utils import resolve_with_base_env, parse_spec, get_default_logger

from dllm.utils.utils import get_default_logger, parse_spec, resolve_with_base_env

logger = get_default_logger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion dllm/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import llada, llada2, dream, bert, editflow, a2d
from . import a2d, bert, dream, editflow, llada, llada2
17 changes: 4 additions & 13 deletions dllm/pipelines/a2d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,12 @@
# A2DGPT2Config,
# A2DGPT2LMHeadModel,
# )
from .models.llama.modeling_llama import (
A2DLlamaConfig,
A2DLlamaLMHeadModel,
)
from .models.qwen2.modeling_qwen2 import (
A2DQwen2Config,
A2DQwen2LMHeadModel,
)
from .models.qwen3.modeling_qwen3 import (
A2DQwen3Config,
A2DQwen3LMHeadModel,
)

import transformers

from .models.llama.modeling_llama import A2DLlamaConfig, A2DLlamaLMHeadModel
from .models.qwen2.modeling_qwen2 import A2DQwen2Config, A2DQwen2LMHeadModel
from .models.qwen3.modeling_qwen3 import A2DQwen3Config, A2DQwen3LMHeadModel

A2D_CONFIG_MAP = {
# transformers.GPT2Config: A2DGPT2Config,
transformers.LlamaConfig: A2DLlamaConfig,
Expand Down
2 changes: 1 addition & 1 deletion dllm/pipelines/a2d/convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

import tyro
import transformers
import tyro

import dllm

Expand Down
8 changes: 3 additions & 5 deletions dllm/pipelines/a2d/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,23 @@
--model_args "pretrained=dllm-collection/Qwen3-0.6B-diffusion-mdlm-v0.1,max_new_tokens=256,steps=256,block_size=32,cfg=0.0"
"""

from types import SimpleNamespace
from dataclasses import dataclass
from types import SimpleNamespace

import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
from tqdm import tqdm

import dllm
from dllm.core.samplers import BD3LMSampler, BD3LMSamplerConfig


from dllm.pipelines.llada.eval import LLaDAEvalHarness, LLaDAEvalConfig
from dllm.pipelines.llada.eval import LLaDAEvalConfig, LLaDAEvalHarness


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions dllm/pipelines/bert/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@
--model_args "pretrained=dllm-collection/ModernBERT-base-chat-v0.1,max_new_tokens=256,steps=256,block_size=32"
"""

from types import SimpleNamespace
from dataclasses import dataclass
from types import SimpleNamespace

import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
from tqdm import tqdm

import dllm
from dllm.core.samplers import MDLMSampler, MDLMSamplerConfig
Expand Down
4 changes: 2 additions & 2 deletions dllm/pipelines/dream/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import models, sampler, trainer, utils
from .models.modeling_dream import DreamModel
from .models.configuration_dream import DreamConfig
from .models.modeling_dream import DreamModel
from .models.tokenization_dream import DreamTokenizer
from .sampler import DreamSamplerConfig, DreamSampler
from .sampler import DreamSampler, DreamSamplerConfig
from .trainer import DreamTrainer
4 changes: 2 additions & 2 deletions dllm/pipelines/dream/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@
"""

import logging
from types import SimpleNamespace
from dataclasses import dataclass
from types import SimpleNamespace

import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
from tqdm import tqdm

import dllm
from dllm.pipelines.dream import DreamSampler, DreamSamplerConfig
Expand Down
10 changes: 3 additions & 7 deletions dllm/pipelines/dream/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import torch.distributions as dists
import torch.nn.functional as F

from dllm.pipelines.dream.models.generation_utils import top_p_logits, top_k_logits
from dllm.core.samplers.base import (
SamplerOutput,
SamplerConfig,
BaseSampler,
)
from dllm.core.samplers.base import BaseSampler, SamplerConfig, SamplerOutput
from dllm.core.samplers.utils import get_num_transfer_tokens
from dllm.pipelines.dream.models.generation_utils import top_k_logits, top_p_logits


def sample_tokens(
Expand Down
1 change: 0 additions & 1 deletion dllm/pipelines/dream/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch.nn.functional as F
import transformers


# def top_p_logits(logits, top_p=None):
# sorted_logits, sorted_indices = torch.sort(logits, descending=True)
# cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
Expand Down
13 changes: 4 additions & 9 deletions dllm/pipelines/editflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from dllm.pipelines.editflow.trainer import EditFlowTrainer

from . import trainer, utils
from .models.dream.modelling_dream import (
EditFlowDreamConfig,
EditFlowDreamModel,
)
from .models.llada.modelling_llada import (
EditFlowLLaDAConfig,
EditFlowLLaDAModel,
)
from .models.bert.modelling_modernbert import (
EditFlowModernBertConfig,
EditFlowModernBertModel,
)
from dllm.pipelines.editflow.trainer import EditFlowTrainer
from .models.dream.modelling_dream import EditFlowDreamConfig, EditFlowDreamModel
from .models.llada.modelling_llada import EditFlowLLaDAConfig, EditFlowLLaDAModel
4 changes: 1 addition & 3 deletions dllm/pipelines/editflow/trainer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import Any, Dict, Union, List, Tuple, Optional
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
Copy link

Copilot AI Dec 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The names 'Tuple', 'Optional', and 'Union' imported from `typing` are not used in this file.

Suggested change
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List

Copilot uses AI. Check for mistakes.

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers

from dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler
from dllm.pipelines.editflow.utils import pad_1d


BLANK = -1


Expand Down
4 changes: 2 additions & 2 deletions dllm/pipelines/editflow/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import math
import random
from dataclasses import dataclass
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Text
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Text, Tuple
Copy link

Copilot AI Dec 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The names 'Dict', 'List', 'Tuple', 'Optional', and 'Text' imported from `typing` are not used in this file.

Suggested change
from typing import Any, Dict, List, Optional, Text, Tuple
from typing import Any

Copilot uses AI. Check for mistakes.

import torch
import transformers
Expand Down
4 changes: 2 additions & 2 deletions dllm/pipelines/llada/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .models.modeling_llada import LLaDAModelLM
from .models.configuration_llada import LLaDAConfig
from .models.modeling_lladamoe import LLaDAMoEModelLM
from .models.configuration_lladamoe import LLaDAMoEConfig
from .models.modeling_llada import LLaDAModelLM
from .models.modeling_lladamoe import LLaDAMoEModelLM
Loading