diff --git a/dllm/core/__init__.py b/dllm/core/__init__.py
index ce1effd5..27f24664 100644
--- a/dllm/core/__init__.py
+++ b/dllm/core/__init__.py
@@ -1 +1 @@
-from dllm.core import trainers, schedulers, samplers
+from dllm.core import samplers, schedulers, trainers
diff --git a/dllm/core/samplers/__init__.py b/dllm/core/samplers/__init__.py
index b5c93b9b..efe54263 100644
--- a/dllm/core/samplers/__init__.py
+++ b/dllm/core/samplers/__init__.py
@@ -1,5 +1,5 @@
-from . import base, utils, mdlm, bd3lm
-from .base import SamplerOutput, SamplerConfig, BaseSampler
+from . import base, bd3lm, mdlm, utils
+from .base import BaseSampler, SamplerConfig, SamplerOutput
+from .bd3lm import BD3LMSampler, BD3LMSamplerConfig
+from .mdlm import MDLMSampler, MDLMSamplerConfig
 from .utils import *
-from .mdlm import MDLMSamplerConfig, MDLMSampler
-from .bd3lm import BD3LMSamplerConfig, BD3LMSampler
diff --git a/dllm/core/samplers/base.py b/dllm/core/samplers/base.py
index c7991970..3a2f9d14 100644
--- a/dllm/core/samplers/base.py
+++ b/dllm/core/samplers/base.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 
 import torch
-from transformers import PreTrainedTokenizer, PreTrainedModel
+from transformers import PreTrainedModel, PreTrainedTokenizer
 
 from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
 
diff --git a/dllm/core/samplers/bd3lm.py b/dllm/core/samplers/bd3lm.py
index 7dec5a0a..6754e41f 100644
--- a/dllm/core/samplers/bd3lm.py
+++ b/dllm/core/samplers/bd3lm.py
@@ -9,12 +9,8 @@
 import torch
 import torch.nn.functional as F
 
-from dllm.core.samplers.base import (
-    SamplerOutput,
-    SamplerConfig,
-    BaseSampler,
-)
-from dllm.core.samplers.utils import get_num_transfer_tokens, add_gumbel_noise
+from dllm.core.samplers.base import BaseSampler, SamplerConfig, SamplerOutput
+from dllm.core.samplers.utils import add_gumbel_noise, get_num_transfer_tokens
 
 
 def build_staircase_attention_mask(
diff --git a/dllm/core/samplers/mdlm.py b/dllm/core/samplers/mdlm.py
index 07757550..19661c8a 100644
--- a/dllm/core/samplers/mdlm.py
+++ b/dllm/core/samplers/mdlm.py
@@ -9,12 +9,8 @@
 import torch
 import torch.nn.functional as F
 
-from dllm.core.samplers.base import (
-    SamplerOutput,
-    SamplerConfig,
-    BaseSampler,
-)
-from dllm.core.samplers.utils import get_num_transfer_tokens, add_gumbel_noise
+from dllm.core.samplers.base import BaseSampler, SamplerConfig, SamplerOutput
+from dllm.core.samplers.utils import add_gumbel_noise, get_num_transfer_tokens
 
 
 @dataclass
diff --git a/dllm/core/schedulers/alpha.py b/dllm/core/schedulers/alpha.py
index e88ed7c1..aef211ac 100644
--- a/dllm/core/schedulers/alpha.py
+++ b/dllm/core/schedulers/alpha.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
+
 import dataclasses
 import math
-from typing import ClassVar, Dict, Type, Any, Union
+from typing import Any, ClassVar, Union
 
 import torch
 
diff --git a/dllm/core/schedulers/kappa.py b/dllm/core/schedulers/kappa.py
index db18cf54..19b8ad53 100644
--- a/dllm/core/schedulers/kappa.py
+++ b/dllm/core/schedulers/kappa.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
+
 import dataclasses
 import math
-from typing import ClassVar, Dict, Type, Any, Union
+from typing import Any, ClassVar, Union
 
 import torch
 
diff --git a/dllm/core/trainers/__init__.py b/dllm/core/trainers/__init__.py
index 8ff909cf..4d8b18d3 100644
--- a/dllm/core/trainers/__init__.py
+++ b/dllm/core/trainers/__init__.py
@@ -1,3 +1,3 @@
 from . import bd3lm, mdlm
-from .mdlm import MDLMTrainer
 from .bd3lm import BD3LMTrainer
+from .mdlm import MDLMTrainer
diff --git a/dllm/core/trainers/bd3lm.py b/dllm/core/trainers/bd3lm.py
index 9d0f751a..5d6f4add 100644
--- a/dllm/core/trainers/bd3lm.py
+++ b/dllm/core/trainers/bd3lm.py
@@ -5,18 +5,18 @@
 https://arxiv.org/abs/2503.09573
 """
 
-from functools import partial
 from dataclasses import dataclass
+from functools import partial
+from typing import Any
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import transformers
-from typing import Any
 
-from .mdlm import MDLMTrainer
 from dllm.utils.collators import CollatorWrapper
 
+from .mdlm import MDLMTrainer
 
 # @dataclass
 # class BD3LMSFTCollator(transformers.DataCollatorForSeq2Seq):
diff --git a/dllm/core/trainers/mdlm.py b/dllm/core/trainers/mdlm.py
index cfc5b6e7..cf9f0474 100644
--- a/dllm/core/trainers/mdlm.py
+++ b/dllm/core/trainers/mdlm.py
@@ -8,11 +8,12 @@
 https://arxiv.org/abs/2502.09992
 """
 
+from typing import Any
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import transformers
-from typing import Any
 
 from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
 from dllm.utils.data import prepend_bos
diff --git a/dllm/data/__init__.py b/dllm/data/__init__.py
index 3d994ba9..33313077 100644
--- a/dllm/data/__init__.py
+++ b/dllm/data/__init__.py
@@ -1 +1 @@
-from .utils import load_sft_dataset, load_pt_dataset
+from .utils import load_pt_dataset, load_sft_dataset
diff --git a/dllm/data/alpaca.py b/dllm/data/alpaca.py
index 5f15b0e3..fb941277 100644
--- a/dllm/data/alpaca.py
+++ b/dllm/data/alpaca.py
@@ -1,5 +1,4 @@
-from typing import Optional
-from datasets import load_dataset, DatasetDict
+from datasets import DatasetDict, load_dataset
 
 
 def _build_alpaca_prompt(instruction: str, input_text: str | None) -> str:
diff --git a/dllm/data/opc.py b/dllm/data/opc.py
index 88a739bd..d8b6e0dc 100644
--- a/dllm/data/opc.py
+++ b/dllm/data/opc.py
@@ -1,18 +1,16 @@
-from typing import Optional, Text, List, Dict
 from datasets import (
-    load_dataset,
-    get_dataset_config_names,
-    concatenate_datasets,
-    DatasetDict,
     Dataset,
-    IterableDatasetDict,
+    DatasetDict,
+    concatenate_datasets,
+    get_dataset_config_names,
+    load_dataset,
 )
+
 from dllm.data.utils import (
-    _merge_datasetdicts,
-    _merge_iterabledatasetdicts,
     _ensure_datasetdict,
     _ensure_iterabledatasetdict,
-    _ensure_datasetdict,
+    _merge_datasetdicts,
+    _merge_iterabledatasetdicts,
 )
 
 
diff --git a/dllm/data/ultrachat.py b/dllm/data/ultrachat.py
index badc6edb..87adea67 100644
--- a/dllm/data/ultrachat.py
+++ b/dllm/data/ultrachat.py
@@ -1,5 +1,4 @@
-from typing import Optional, List, Dict
-from datasets import load_dataset, DatasetDict
+from datasets import DatasetDict, load_dataset
 
 
 def _extract_first_turn(messages: list[dict[str, str]]) -> dict[str, str] | None:
diff --git a/dllm/data/utils.py b/dllm/data/utils.py
index 6e0ff5b2..fb6d2a4c 100644
--- a/dllm/data/utils.py
+++ b/dllm/data/utils.py
@@ -1,15 +1,15 @@
 import re
+
 from datasets import (
     Dataset,
     DatasetDict,
-    IterableDatasetDict,
     IterableDataset,
+    IterableDatasetDict,
     load_dataset,
     load_from_disk,
 )
 
-from dllm.utils.utils import resolve_with_base_env, parse_spec, get_default_logger
-
+from dllm.utils.utils import get_default_logger, parse_spec, resolve_with_base_env
 
 logger = get_default_logger(__name__)
diff --git a/dllm/pipelines/__init__.py b/dllm/pipelines/__init__.py
index 3f21545c..74d2121d 100644
--- a/dllm/pipelines/__init__.py
+++ b/dllm/pipelines/__init__.py
@@ -1 +1 @@
-from . import llada, llada2, dream, bert, editflow, a2d
+from . import a2d, bert, dream, editflow, llada, llada2
diff --git a/dllm/pipelines/a2d/__init__.py b/dllm/pipelines/a2d/__init__.py
index 4839dddf..9abe0366 100644
--- a/dllm/pipelines/a2d/__init__.py
+++ b/dllm/pipelines/a2d/__init__.py
@@ -2,21 +2,12 @@
 #     A2DGPT2Config,
 #     A2DGPT2LMHeadModel,
 # )
-from .models.llama.modeling_llama import (
-    A2DLlamaConfig,
-    A2DLlamaLMHeadModel,
-)
-from .models.qwen2.modeling_qwen2 import (
-    A2DQwen2Config,
-    A2DQwen2LMHeadModel,
-)
-from .models.qwen3.modeling_qwen3 import (
-    A2DQwen3Config,
-    A2DQwen3LMHeadModel,
-)
-
 import transformers
 
+from .models.llama.modeling_llama import A2DLlamaConfig, A2DLlamaLMHeadModel
+from .models.qwen2.modeling_qwen2 import A2DQwen2Config, A2DQwen2LMHeadModel
+from .models.qwen3.modeling_qwen3 import A2DQwen3Config, A2DQwen3LMHeadModel
+
 A2D_CONFIG_MAP = {
     # transformers.GPT2Config: A2DGPT2Config,
     transformers.LlamaConfig: A2DLlamaConfig,
diff --git a/dllm/pipelines/a2d/convert.py b/dllm/pipelines/a2d/convert.py
index 81d69923..5085a1ae 100644
--- a/dllm/pipelines/a2d/convert.py
+++ b/dllm/pipelines/a2d/convert.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 
-import tyro
 import transformers
+import tyro
 
 import dllm
 
diff --git a/dllm/pipelines/a2d/eval.py b/dllm/pipelines/a2d/eval.py
index f4bf932c..9abea55a 100644
--- a/dllm/pipelines/a2d/eval.py
+++ b/dllm/pipelines/a2d/eval.py
@@ -9,25 +9,23 @@
         --model_args "pretrained=dllm-collection/Qwen3-0.6B-diffusion-mdlm-v0.1,max_new_tokens=256,steps=256,block_size=32,cfg=0.0"
 """
 
-from types import SimpleNamespace
 from dataclasses import dataclass
+from types import SimpleNamespace
 
 import accelerate
 import torch
 import torch.nn.functional as F
 from datasets import Dataset
-from tqdm import tqdm
 from lm_eval.__main__ import cli_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import get_dtype
+from tqdm import tqdm
 
 import dllm
 from dllm.core.samplers import BD3LMSampler, BD3LMSamplerConfig
-
-
-from dllm.pipelines.llada.eval import LLaDAEvalHarness, LLaDAEvalConfig
+from dllm.pipelines.llada.eval import LLaDAEvalConfig, LLaDAEvalHarness
 
 
 @dataclass
diff --git a/dllm/pipelines/bert/eval.py b/dllm/pipelines/bert/eval.py
index fc280aa5..74276575 100644
--- a/dllm/pipelines/bert/eval.py
+++ b/dllm/pipelines/bert/eval.py
@@ -9,19 +9,19 @@
         --model_args "pretrained=dllm-collection/ModernBERT-base-chat-v0.1,max_new_tokens=256,steps=256,block_size=32"
 """
 
-from types import SimpleNamespace
 from dataclasses import dataclass
+from types import SimpleNamespace
 
 import accelerate
 import torch
 import torch.nn.functional as F
 from datasets import Dataset
-from tqdm import tqdm
 from lm_eval.__main__ import cli_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import get_dtype
+from tqdm import tqdm
 
 import dllm
 from dllm.core.samplers import MDLMSampler, MDLMSamplerConfig
diff --git a/dllm/pipelines/dream/__init__.py b/dllm/pipelines/dream/__init__.py
index a31c3319..15a569a3 100644
--- a/dllm/pipelines/dream/__init__.py
+++ b/dllm/pipelines/dream/__init__.py
@@ -1,6 +1,6 @@
 from . import models, sampler, trainer, utils
-from .models.modeling_dream import DreamModel
 from .models.configuration_dream import DreamConfig
+from .models.modeling_dream import DreamModel
 from .models.tokenization_dream import DreamTokenizer
-from .sampler import DreamSamplerConfig, DreamSampler
+from .sampler import DreamSampler, DreamSamplerConfig
 from .trainer import DreamTrainer
diff --git a/dllm/pipelines/dream/eval.py b/dllm/pipelines/dream/eval.py
index 9f5e3569..e1f79d5e 100644
--- a/dllm/pipelines/dream/eval.py
+++ b/dllm/pipelines/dream/eval.py
@@ -10,19 +10,19 @@
 """
 
 import logging
-from types import SimpleNamespace
 from dataclasses import dataclass
+from types import SimpleNamespace
 
 import accelerate
 import torch
 import torch.nn.functional as F
 from datasets import Dataset
-from tqdm import tqdm
 from lm_eval.__main__ import cli_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import get_dtype
+from tqdm import tqdm
 
 import dllm
 from dllm.pipelines.dream import DreamSampler, DreamSamplerConfig
diff --git a/dllm/pipelines/dream/sampler.py b/dllm/pipelines/dream/sampler.py
index c955464a..3826d56e 100644
--- a/dllm/pipelines/dream/sampler.py
+++ b/dllm/pipelines/dream/sampler.py
@@ -5,16 +5,12 @@
 from dataclasses import dataclass
 
 import torch
-import torch.nn.functional as F
 import torch.distributions as dists
+import torch.nn.functional as F
 
-from dllm.pipelines.dream.models.generation_utils import top_p_logits, top_k_logits
-from dllm.core.samplers.base import (
-    SamplerOutput,
-    SamplerConfig,
-    BaseSampler,
-)
+from dllm.core.samplers.base import BaseSampler, SamplerConfig, SamplerOutput
 from dllm.core.samplers.utils import get_num_transfer_tokens
+from dllm.pipelines.dream.models.generation_utils import top_k_logits, top_p_logits
 
 
 def sample_tokens(
diff --git a/dllm/pipelines/dream/utils.py b/dllm/pipelines/dream/utils.py
index 37ecb940..9ac3de3d 100644
--- a/dllm/pipelines/dream/utils.py
+++ b/dllm/pipelines/dream/utils.py
@@ -5,7 +5,6 @@
 import torch.nn.functional as F
 import transformers
 
-
 # def top_p_logits(logits, top_p=None):
 #     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
 #     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
diff --git a/dllm/pipelines/editflow/__init__.py b/dllm/pipelines/editflow/__init__.py
index c7aa217f..5cec019d 100644
--- a/dllm/pipelines/editflow/__init__.py
+++ b/dllm/pipelines/editflow/__init__.py
@@ -1,14 +1,9 @@
+from dllm.pipelines.editflow.trainer import EditFlowTrainer
+
 from . import trainer, utils
-from .models.dream.modelling_dream import (
-    EditFlowDreamConfig,
-    EditFlowDreamModel,
-)
-from .models.llada.modelling_llada import (
-    EditFlowLLaDAConfig,
-    EditFlowLLaDAModel,
-)
 from .models.bert.modelling_modernbert import (
     EditFlowModernBertConfig,
     EditFlowModernBertModel,
 )
-from dllm.pipelines.editflow.trainer import EditFlowTrainer
+from .models.dream.modelling_dream import EditFlowDreamConfig, EditFlowDreamModel
+from .models.llada.modelling_llada import EditFlowLLaDAConfig, EditFlowLLaDAModel
diff --git a/dllm/pipelines/editflow/trainer.py b/dllm/pipelines/editflow/trainer.py
index 33c67c60..2635539f 100644
--- a/dllm/pipelines/editflow/trainer.py
+++ b/dllm/pipelines/editflow/trainer.py
@@ -1,16 +1,14 @@
-from typing import Any, Dict, Union, List, Tuple, Optional
 from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
 import transformers
 
 from dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler
 from dllm.pipelines.editflow.utils import pad_1d
 
-
 BLANK = -1
 
 
diff --git a/dllm/pipelines/editflow/utils.py b/dllm/pipelines/editflow/utils.py
index f2bf8bf4..1e4ff140 100644
--- a/dllm/pipelines/editflow/utils.py
+++ b/dllm/pipelines/editflow/utils.py
@@ -1,9 +1,9 @@
 import math
 import random
-from dataclasses import dataclass
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, Text
 from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Text, Tuple
 
 import torch
 import transformers
diff --git a/dllm/pipelines/llada/__init__.py b/dllm/pipelines/llada/__init__.py
index 0c7e26ef..6a211b21 100644
--- a/dllm/pipelines/llada/__init__.py
+++ b/dllm/pipelines/llada/__init__.py
@@ -1,4 +1,4 @@
-from .models.modeling_llada import LLaDAModelLM
 from .models.configuration_llada import LLaDAConfig
-from .models.modeling_lladamoe import LLaDAMoEModelLM
 from .models.configuration_lladamoe import LLaDAMoEConfig
+from .models.modeling_llada import LLaDAModelLM
+from .models.modeling_lladamoe import LLaDAMoEModelLM
diff --git a/dllm/pipelines/llada/eval.py b/dllm/pipelines/llada/eval.py
index 7048f738..bedc75d8 100644
--- a/dllm/pipelines/llada/eval.py
+++ b/dllm/pipelines/llada/eval.py
@@ -9,19 +9,19 @@
         --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,max_new_tokens=512,steps=512,block_size=512,cfg=0.0,logits_eos_inf=False,confidence_eos_eot_inf=True"
 """
 
-from types import SimpleNamespace
 from dataclasses import dataclass
+from types import SimpleNamespace
 
 import accelerate
 import torch
 import torch.nn.functional as F
 from datasets import Dataset
-from tqdm import tqdm
 from lm_eval.__main__ import cli_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import get_dtype
+from tqdm import tqdm
 
 import dllm
 from dllm.core.samplers import MDLMSampler, MDLMSamplerConfig
diff --git a/dllm/pipelines/llada2/__init__.py b/dllm/pipelines/llada2/__init__.py
index 6e456883..1ded4c19 100644
--- a/dllm/pipelines/llada2/__init__.py
+++ b/dllm/pipelines/llada2/__init__.py
@@ -1,3 +1,3 @@
-from .models.modeling_llada2_moe import LLaDA2MoeModelLM
 from .models.configuration_llada2_moe import LLaDA2MoeConfig
+from .models.modeling_llada2_moe import LLaDA2MoeModelLM
 from .sampler import LLaDA2Sampler, LLaDA2SamplerConfig
diff --git a/dllm/tools/preprocess_pt_dataset.py b/dllm/tools/preprocess_pt_dataset.py
index 11454045..9146b5e9 100644
--- a/dllm/tools/preprocess_pt_dataset.py
+++ b/dllm/tools/preprocess_pt_dataset.py
@@ -3,13 +3,13 @@
 """
 
 import os
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from functools import partial
+from pprint import pprint
 
 import datasets
-import tyro
 import transformers
-from pprint import pprint
+import tyro
 
 import dllm
 
diff --git a/dllm/tools/preprocess_sft_dataset.py b/dllm/tools/preprocess_sft_dataset.py
index 40b80bcd..48bbfff4 100644
--- a/dllm/tools/preprocess_sft_dataset.py
+++ b/dllm/tools/preprocess_sft_dataset.py
@@ -9,8 +9,8 @@
         --num_proc 64
 """
 
-import os
 import importlib
+import os
 from dataclasses import dataclass
 from functools import partial
 
diff --git a/dllm/utils/__init__.py b/dllm/utils/__init__.py
index 002c53c8..21c4be3d 100644
--- a/dllm/utils/__init__.py
+++ b/dllm/utils/__init__.py
@@ -3,7 +3,7 @@
 from .collators import *
 from .configs import *
 from .data import *
-from .sampling import *
 from .models import *
+from .sampling import *
 from .utils import *
 from .visualizers import *
diff --git a/dllm/utils/chat.py b/dllm/utils/chat.py
index a50e8c47..a552978b 100644
--- a/dllm/utils/chat.py
+++ b/dllm/utils/chat.py
@@ -1,11 +1,9 @@
 import shutil
-from typing import List, Literal
-
 import textwrap
+from typing import List, Literal
 
 import dllm
 
-
 # ============================================================
 # Utility helpers
 # ============================================================
diff --git a/dllm/utils/collators.py b/dllm/utils/collators.py
index 7f5a1473..a01292e6 100644
--- a/dllm/utils/collators.py
+++ b/dllm/utils/collators.py
@@ -1,10 +1,9 @@
 from dataclasses import dataclass
+from typing import Any
 
 import torch
 import transformers
 
-from typing import Any
-
 
 @dataclass
 class CollatorWrapper:
diff --git a/dllm/utils/configs.py b/dllm/utils/configs.py
index e1cb6f2f..42b33393 100644
--- a/dllm/utils/configs.py
+++ b/dllm/utils/configs.py
@@ -3,7 +3,7 @@
 
 import transformers
 
-from dllm.utils.utils import resolve_with_base_env, get_default_logger
+from dllm.utils.utils import get_default_logger, resolve_with_base_env
 
 logger = get_default_logger(__name__)
 
diff --git a/dllm/utils/data.py b/dllm/utils/data.py
index 51094b78..bec9aa8e 100644
--- a/dllm/utils/data.py
+++ b/dllm/utils/data.py
@@ -4,12 +4,12 @@
 from itertools import chain
 from typing import TYPE_CHECKING
 
-import torch
 import datasets
+import torch
 import transformers
 
 if TYPE_CHECKING:
-    from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
+    from dllm.utils.configs import DataArguments, ModelArguments, TrainingArguments
 
 
 def tokenize_and_group(
diff --git a/dllm/utils/models.py b/dllm/utils/models.py
index e54ccf87..4878ad32 100644
--- a/dllm/utils/models.py
+++ b/dllm/utils/models.py
@@ -1,10 +1,10 @@
-import torch
 import accelerate
+import torch
 import transformers
 from peft import prepare_model_for_kbit_training
 
-from dllm.utils.utils import disable_caching_allocator_warmup, print_main, load_peft
 from dllm.utils.configs import ModelArguments, TrainingArguments
+from dllm.utils.utils import disable_caching_allocator_warmup, load_peft, print_main
 
 
 def get_model(
@@ -57,7 +57,7 @@ def get_model(
         model = transformers.AutoModelForMaskedLM.from_pretrained(
             model_name_or_path, **params
         )
-    except:
+    except Exception:
         model = transformers.AutoModel.from_pretrained(model_name_or_path, **params)
 
     # --- if quantized, prepare for LoRA / QLoRA training ---
@@ -83,20 +83,21 @@ def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
         transformers.PreTrainedTokenizer
     """
     # Lazy imports to avoid circular dependencies
-    from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM
-    from dllm.pipelines.llada.models.modeling_lladamoe import LLaDAMoEModelLM
-    from dllm.pipelines.llada2.models.modeling_llada2_moe import LLaDA2MoeModelLM
-    from dllm.pipelines.dream.models.modeling_dream import DreamModel
+    from transformers import (
+        BertPreTrainedModel,
+        ModernBertPreTrainedModel,
+        RobertaPreTrainedModel,
+    )
+
     from dllm.pipelines.a2d import (
         A2DLlamaLMHeadModel,
         A2DQwen2LMHeadModel,
         A2DQwen3LMHeadModel,
     )
-    from transformers import (
-        BertPreTrainedModel,
-        RobertaPreTrainedModel,
-        ModernBertPreTrainedModel,
-    )
+    from dllm.pipelines.dream.models.modeling_dream import DreamModel
+    from dllm.pipelines.llada2.models.modeling_llada2_moe import LLaDA2MoeModelLM
+    from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM
+    from dllm.pipelines.llada.models.modeling_lladamoe import LLaDAMoEModelLM
 
     model_name_or_path = getattr(model_args, "model_name_or_path")
 
@@ -106,7 +107,7 @@ def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
         padding_side="right",
     )
 
-    assert tokenizer.eos_token != None or tokenizer.pad_token != None
+    assert tokenizer.eos_token is not None or tokenizer.pad_token is not None
     if not tokenizer.pad_token:
         tokenizer.pad_token = tokenizer.eos_token
diff --git a/dllm/utils/utils.py b/dllm/utils/utils.py
index ab7c3a55..4f11d5d9 100644
--- a/dllm/utils/utils.py
+++ b/dllm/utils/utils.py
@@ -1,18 +1,19 @@
+import logging
 import os
 import re
 import sys
-import logging
 from contextlib import contextmanager
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
 
 import pprint
+
-import torch
-import peft
 import accelerate
+import peft
+import torch
 import transformers
 
 
diff --git a/dllm/utils/visualizers.py b/dllm/utils/visualizers.py
index 15b602a3..1fa26552 100644
--- a/dllm/utils/visualizers.py
+++ b/dllm/utils/visualizers.py
@@ -1,15 +1,15 @@
 from __future__ import annotations
+
 import os
 import re
 import sys
 import time
-from dataclasses import dataclass
 from abc import ABC, abstractmethod
-from typing import Sequence, Optional
+from dataclasses import dataclass
+from typing import Optional, Sequence
 
 import torch
 from tqdm import tqdm
-
 from transformers import PreTrainedTokenizer
 
 
@@ -127,18 +127,18 @@ def visualize_one_history(
         # --------- imports & env checks ----------
         try:
             from rich.console import Console
+            from rich.layout import Layout
             from rich.live import Live
-            from rich.text import Text
             from rich.panel import Panel
             from rich.progress import (
-                Progress,
                 BarColumn,
-                TextColumn,
-                TimeRemainingColumn,
                 MofNCompleteColumn,
+                Progress,
                 SpinnerColumn,
+                TextColumn,
+                TimeRemainingColumn,
             )
-            from rich.layout import Layout
+            from rich.text import Text
 
             _RICH_IMPORTED = True
         except Exception:
@@ -176,8 +176,8 @@ def visualize_one_history(
         final_text = self._truncate(final_text, max_chars)
 
         # ------------------ new: estimate height from final_text ------------------
-        import textwrap
         import shutil
+        import textwrap
 
         def strip_ansi(s: str) -> str:
             return self.ansi_escape.sub("", s) if s else ""
diff --git a/examples/a2d/bd3lm/chat.py b/examples/a2d/bd3lm/chat.py
index 46ed81cd..3c5880b1 100644
--- a/examples/a2d/bd3lm/chat.py
+++ b/examples/a2d/bd3lm/chat.py
@@ -9,6 +9,7 @@
 import sys
 from dataclasses import dataclass
+
 import transformers
 
 import dllm
 
diff --git a/examples/a2d/bd3lm/pt.py b/examples/a2d/bd3lm/pt.py
index 6292b243..5ad2ee72 100644
--- a/examples/a2d/bd3lm/pt.py
+++ b/examples/a2d/bd3lm/pt.py
@@ -12,7 +12,7 @@
         examples/a2d/bd3lm/pt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (ZeRO-2):
@@ -26,12 +26,12 @@
         --script_path "examples/bd3lm/mdlm/pt.py"
 """
 
-import os
 import functools
+import os
 from dataclasses import dataclass, field
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 
diff --git a/examples/a2d/bd3lm/sft.py b/examples/a2d/bd3lm/sft.py
index 8a96159c..b4d891a7 100644
--- a/examples/a2d/bd3lm/sft.py
+++ b/examples/a2d/bd3lm/sft.py
@@ -5,14 +5,14 @@
     accelerate launch \
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/a2d/bd3lm/sft.py
-    
+
 - 8 GPUs (ZeRO-2):
     accelerate launch \
         --config_file scripts/accelerate_configs/zero2.yaml \
         examples/a2d/bd3lm/sft.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (ZeRO-2):
@@ -30,8 +30,8 @@
 from dataclasses import dataclass, field
 from functools import partial
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 
diff --git a/examples/a2d/mdlm/chat.py b/examples/a2d/mdlm/chat.py
index c8831c3f..78b9a070 100644
--- a/examples/a2d/mdlm/chat.py
+++ b/examples/a2d/mdlm/chat.py
@@ -9,6 +9,7 @@
 import sys
 from dataclasses import dataclass
+
 import transformers
 
 import dllm
 
diff --git a/examples/a2d/mdlm/pt.py b/examples/a2d/mdlm/pt.py
index 73dcfcfd..bca892c8 100644
--- a/examples/a2d/mdlm/pt.py
+++ b/examples/a2d/mdlm/pt.py
@@ -12,7 +12,7 @@
         examples/a2d/mdlm/pt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (ZeRO-2):
@@ -26,12 +26,12 @@
         --script_path "examples/a2d/mdlm/pt.py"
 """
 
-import os
 import functools
+import os
 from dataclasses import dataclass, field
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 
diff --git a/examples/a2d/mdlm/sft.py b/examples/a2d/mdlm/sft.py
index 93434819..9fcabfcc 100644
--- a/examples/a2d/mdlm/sft.py
+++ b/examples/a2d/mdlm/sft.py
@@ -5,14 +5,14 @@
     accelerate launch \
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/a2d/mdlm/sft.py
-    
+
 - 8 GPUs (ZeRO-2):
     accelerate launch \
         --config_file scripts/accelerate_configs/zero2.yaml \
         examples/a2d/mdlm/sft.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (ZeRO-2):
@@ -30,8 +30,8 @@
 from dataclasses import dataclass, field
 from functools import partial
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 
diff --git a/examples/bert/chat.py b/examples/bert/chat.py
index a6353b48..cba761ba 100644
--- a/examples/bert/chat.py
+++ b/examples/bert/chat.py
@@ -9,6 +9,7 @@
 import sys
 from dataclasses import dataclass
+
 import transformers
 
 import dllm
 
diff --git a/examples/bert/pt.py b/examples/bert/pt.py
index d9e7b0f3..8058a312 100644
--- a/examples/bert/pt.py
+++ b/examples/bert/pt.py
@@ -12,7 +12,7 @@
         examples/bert/pt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (ZeRO-2):
@@ -26,12 +26,12 @@
         --script_path "examples/bert/pt.py"
 """
 
-import os
 import functools
+import os
 from dataclasses import dataclass, field
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 
diff --git a/examples/bert/sft.py b/examples/bert/sft.py
index 1c53ee21..57e9fc2c 100644
--- a/examples/bert/sft.py
+++ b/examples/bert/sft.py
@@ -5,14 +5,14 @@
     accelerate launch \
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/bert/sft.py
-    
+
 - 8 GPUs (ZeRO-2):
     accelerate launch \
         --config_file scripts/accelerate_configs/zero2.yaml \
         examples/bert/sft.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (ZeRO-2):
@@ -30,8 +30,8 @@
 from dataclasses import dataclass, field
 from functools import partial
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 
diff --git a/examples/dream/chat.py b/examples/dream/chat.py
index 6000c5c2..b6a0c2d0 100644
--- a/examples/dream/chat.py
+++ b/examples/dream/chat.py
@@ -12,6 +12,7 @@
 import sys
 from dataclasses import dataclass
+
 import transformers
 
 import dllm
 
diff --git a/examples/dream/pt.py b/examples/dream/pt.py
index ef5f129b..623c273b 100644
--- a/examples/dream/pt.py
+++ b/examples/dream/pt.py
@@ -13,7 +13,7 @@
         examples/dream/pt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 24 Nodes, 192 GPUs (FSDP):
@@ -22,13 +22,13 @@
         --script_path "examples/dream/pt.py"
 """
 
-import os
 import functools
+import os
 from dataclasses import dataclass, field
 
+import accelerate
 import torch
 import transformers
-import accelerate
 
 import dllm
 from dllm.pipelines import dream
diff --git a/examples/dream/sft.py b/examples/dream/sft.py
index d1645691..5d88264d 100644
--- a/examples/dream/sft.py
+++ b/examples/dream/sft.py
@@ -6,14 +6,14 @@
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/dream/sft.py \
         --load_in_4bit True --lora True
-    
+
 - 8 GPUs (FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/fsdp.yaml \
         examples/dream/sft.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (FSDP):
@@ -31,8 +31,8 @@
 from dataclasses import dataclass, field
 from functools import partial
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 from dllm.pipelines import dream
diff --git a/examples/editflow/dream/adapt.py b/examples/editflow/dream/adapt.py
index 735dd437..08d9ff19 100644
--- a/examples/editflow/dream/adapt.py
+++ b/examples/editflow/dream/adapt.py
@@ -6,14 +6,14 @@
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/editflow/dream/adapt.py \
         --lora True
-    
+
 - 8 GPUs (FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/fsdp.yaml \
         examples/editflow/dream/adapt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (FSDP):
diff --git a/examples/editflow/dream/pt.py b/examples/editflow/dream/pt.py
index b37d20c1..63ed03cd 100644
--- a/examples/editflow/dream/pt.py
+++ b/examples/editflow/dream/pt.py
@@ -6,14 +6,14 @@
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/editflow/dream/pt.py \
         --lora True
-    
+
 - 8 GPUs (FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/fsdp.yaml \
         examples/editflow/dream/pt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (FSDP):
diff --git a/examples/editflow/dream/sft.py b/examples/editflow/dream/sft.py
index e8eb9d5a..82b2f59f 100644
--- a/examples/editflow/dream/sft.py
+++ b/examples/editflow/dream/sft.py
@@ -6,14 +6,14 @@
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/editflow/dream/sft.py \
         --lora True
-    
+
 - 8 GPUs (FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/zero2.yaml \
         examples/editflow/dream/sft.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (FSDP):
diff --git a/examples/editflow/llada/adapt.py b/examples/editflow/llada/adapt.py
index 23fb199d..b0f1bf95 100644
--- a/examples/editflow/llada/adapt.py
+++ b/examples/editflow/llada/adapt.py
@@ -6,14 +6,14 @@
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/editflow/llada/adapt.py \
         --lora True
-    
+
 - 8 GPUs (FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/fsdp.yaml \
         examples/editflow/llada/adapt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (FSDP):
diff --git a/examples/editflow/llada/pt.py b/examples/editflow/llada/pt.py
index da5f1c4e..b615d138 100644
--- a/examples/editflow/llada/pt.py
+++ b/examples/editflow/llada/pt.py
@@ -6,14 +6,14 @@
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/editflow/llada/pt.py \
         --lora True
-    
+
 - 8 GPUs (DeepSpeed FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/fsdp.yaml \
         examples/editflow/llada/pt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (FSDP):
diff --git a/examples/editflow/llada/sft.py b/examples/editflow/llada/sft.py
index 1990ded6..41f25914 100644
--- a/examples/editflow/llada/sft.py
+++ b/examples/editflow/llada/sft.py
@@ -6,14 +6,14 @@
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/editflow/llada/sft.py \
         --lora True
-    
+
 - 8 GPUs (FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/fsdp.yaml \
         examples/editflow/llada/sft.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (FSDP):
diff --git a/examples/editflow/pt.py b/examples/editflow/pt.py
index ff9edbda..12d96263 100644
--- a/examples/editflow/pt.py
+++ b/examples/editflow/pt.py
@@ -1,9 +1,9 @@
-import os
 import functools
+import os
 from dataclasses import dataclass, field
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 from dllm.pipelines import editflow
diff --git a/examples/editflow/sample.py b/examples/editflow/sample.py
index 31d140d2..7a8d133f 100644
--- a/examples/editflow/sample.py
+++ b/examples/editflow/sample.py
@@ -18,13 +18,12 @@
 from dataclasses import dataclass
 from typing import Annotated
 
-import tyro
 import torch
+import tyro
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
 
 from dllm.core.schedulers import BaseKappaScheduler, LinearKappaScheduler
 
-
 # ------------------------------- Small utilities --------------------------------
 
diff --git a/examples/editflow/sft.py b/examples/editflow/sft.py
index a26fc2ea..2bb4e682 100644
--- a/examples/editflow/sft.py
+++ b/examples/editflow/sft.py
@@ -1,9 +1,9 @@
 import os
-from functools import partial
 from dataclasses import dataclass, field
+from functools import partial
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 from dllm.pipelines import editflow
diff --git a/examples/editflow/viz.py b/examples/editflow/viz.py
index 5df7cca6..f3082c01 100644
--- a/examples/editflow/viz.py
+++ b/examples/editflow/viz.py
@@ -1,11 +1,11 @@
 # ------------------------------ Visualization (NEW) ------------------------------
 # Diffusion-style consecutive output: only show the CURRENT output per frame.
 # ------------------ Visualization (sanitized, masks stripped) ------------------
-from PIL import Image, ImageDraw, ImageFont
-
 import re
 import unicodedata
-from typing import Optional, List, Tuple, Annotated
+from typing import Annotated, List, Optional, Tuple
+
+from PIL import Image, ImageDraw, ImageFont
 
 
 def render_consecutive_trace_gif(
@@ -36,9 +36,10 @@ def render_consecutive_trace_gif(
       - Substitution mask→nonmask -> stays BLACK (no extra color).
     Adds a final clean frame (5s) with no events box.
     """
-    from PIL import Image, ImageDraw, ImageFont
     import unicodedata
 
+    from PIL import Image, ImageDraw, ImageFont
+
     # ---------- font ----------
     try:
         font = ImageFont.truetype(
diff --git a/examples/llada/chat.py b/examples/llada/chat.py
index 512ef7d7..7568bc4d 100644
--- a/examples/llada/chat.py
+++ b/examples/llada/chat.py
@@ -12,6 +12,7 @@
 import sys
 from dataclasses import dataclass
+
 import transformers
 
 import dllm
 
diff --git a/examples/llada/pt.py b/examples/llada/pt.py
index 4097fa31..ad0c85e3 100644
--- a/examples/llada/pt.py
+++ b/examples/llada/pt.py
@@ -6,14 +6,14 @@
     accelerate launch \
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/llada/pt.py \
         --load_in_4bit True --lora True
-    
+
 - 8 GPUs (FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/fsdp.yaml \
         examples/llada/pt.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 24 Nodes, 192 GPUs (FSDP):
@@ -22,13 +22,13 @@
         --script_path "examples/llada/pt.py"
 """
 
-import os
 import functools
+import os
 from dataclasses import dataclass, field
 
+import accelerate
 import torch
 import transformers
-import accelerate
 
 import dllm
 
diff --git a/examples/llada/sft.py b/examples/llada/sft.py
index 885db6d1..9ad35f18 100644
--- a/examples/llada/sft.py
+++ b/examples/llada/sft.py
@@ -6,14 +6,14 @@
         --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
         examples/llada/sft.py \
         --load_in_4bit True --lora True
-    
+
 - 8 GPUs (FSDP):
     accelerate launch \
         --config_file scripts/accelerate_configs/fsdp.yaml \
         examples/llada/sft.py
 
 Slurm users
-# Note: run `mkdir logs` before running sbatch; and adjust 
+# Note: run `mkdir logs` before running sbatch; and adjust
 #       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
 ------------
 - 1 Node, 8 GPUs (FSDP):
@@ -31,8 +31,8 @@
 from dataclasses import dataclass, field
 from functools import partial
 
-import transformers
 import accelerate
+import transformers
 
 import dllm
 
diff --git a/scripts/tests/test_attention.py b/scripts/tests/test_attention.py
index 67956ec8..ef485716 100644
--- a/scripts/tests/test_attention.py
+++ b/scripts/tests/test_attention.py
@@ -38,10 +38,11 @@
 import gc
 from typing import Dict, List
 
+import pytest
 import torch
 import transformers
+
 import dllm
-import pytest
 
 # Numerical tolerance
 ERROR_THRESHOLD = 1e-3