diff --git a/protein_lm/modeling/models/mamba/lm.py b/protein_lm/modeling/models/mamba/lm.py index b9bc35d..ffc8d92 100644 --- a/protein_lm/modeling/models/mamba/lm.py +++ b/protein_lm/modeling/models/mamba/lm.py @@ -8,12 +8,13 @@ import torch import torch.nn as nn from dataclasses import dataclass, field -from mamba_ssm.modules.mamba_simple import Mamba, Block +from mamba_ssm.modules.mamba_simple import Mamba +from mamba_ssm.modules.block import Block from mamba_ssm.utils.generation import GenerationMixin from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn + from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None @@ -66,6 +67,7 @@ def create_block( block = Block( d_model, mixer_cls, + nn.Identity(), norm_cls=norm_cls, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, diff --git a/protein_lm/modeling/scripts/__initi__.py b/protein_lm/modeling/scripts/__initi__.py new file mode 100644 index 0000000..e69de29 diff --git a/protein_lm/modeling/scripts/train.py b/protein_lm/modeling/scripts/train.py index c2af0da..304a181 100644 --- a/protein_lm/modeling/scripts/train.py +++ b/protein_lm/modeling/scripts/train.py @@ -141,7 +141,7 @@ def main(config_dict: DictConfig): ) def load_ckpt(ckpt_path, tokenizer, device): - ckpt = torch.load(ckpt_path) + ckpt = torch.load(ckpt_path, weights_only=False) model_state_dict = ckpt["model"] model_config = ckpt["config"] model_config.vocab_size = tokenizer.get_vocab_size()