Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions protein_lm/modeling/models/mamba/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.block import Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

Expand Down Expand Up @@ -66,6 +67,7 @@ def create_block(
block = Block(
d_model,
mixer_cls,
nn.Identity(),
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion protein_lm/modeling/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def main(config_dict: DictConfig):
)

def load_ckpt(ckpt_path, tokenizer, device):
ckpt = torch.load(ckpt_path)
ckpt = torch.load(ckpt_path, weights_only=False)
model_state_dict = ckpt["model"]
model_config = ckpt["config"]
model_config.vocab_size = tokenizer.get_vocab_size()
Expand Down