gemma2 support (ffn_layernorm, embeddings_normalize)
francoishernandez committed Jan 23, 2025
1 parent 02a9dbe commit 80f87ad
Showing 8 changed files with 100 additions and 19 deletions.
42 changes: 30 additions & 12 deletions eole/bin/convert/convert_HF.py
@@ -54,6 +54,10 @@
"LlamaForCausalLM": {}, # default
"MistralForCausalLM": {},
"Qwen2ForCausalLM": {},
"Gemma2ForCausalLM": {
".pre_feedforward_layernorm.weight": ".pre_feedforward_layernorm.weight",
".post_feedforward_layernorm.weight": ".post_feedforward_layernorm.weight",
},
"MixtralForCausalLM": {
".mlp.gate.weight": ".block_sparse_moe.gate.weight",
**{
@@ -130,6 +134,7 @@
"PhiForCausalLM": "standard",
"GPT2LMHeadModel": "standard",
"XLMRobertaXLForMaskedLM": "standard",
"Gemma2ForCausalLM": "gemma-rms",
},
)

@@ -140,6 +145,7 @@
"PhiForCausalLM": "gelu",
"GPT2LMHeadModel": "gelu",
"XLMRobertaXLForMaskedLM": "gelu",
"Gemma2ForCausalLM": "gated-gelu",
},
)

@@ -411,16 +417,6 @@ def build_config_dict(hf):
"generator_bias": False,
"rope_config": {
"rotary_theta": config.get("rope_theta"),
"rotary_dim": config.get(
"rotary_dim",
int(
config.get("partial_rotary_factor", 1)
* (
config.get("hidden_size", config.get("n_embd"))
// config.get("num_attention_heads", config.get("n_head"))
)
),
),
"rotary_interleave": False,
},
"embeddings": {}, # Populated later
@@ -430,6 +426,18 @@
if model_config["sliding_window"] is None:
model_config["sliding_window"] = 4096

# patch rotary dim
if "rotary_dim" in config.keys():
model_config["rope_config"]["rotary_dim"] = config["rotary_dim"]
elif "partial_rotary_factor" in config.keys():
model_config["rope_config"]["rotary_dim"] = int(
config["partial_rotary_factor"] * (model_config["hidden_size"] // model_config["heads"])
)
elif model_config.get("head_dim", None) is not None:
model_config["rope_config"]["rotary_dim"] = model_config["head_dim"]
else:
model_config["rope_config"]["rotary_dim"] = model_config["hidden_size"] // model_config["heads"]

# Validate required fields
required_fields = {
"layers": "Can't find the number of layers in the config.json file",
@@ -541,6 +549,13 @@ def build_config_dict(hf):
"add_qkvbias": True,
"add_final_linear_bias": False,
},
"Gemma2ForCausalLM": {
"share_decoder_embeddings": True,
"ffn_layernorm": True,
"embeddings": {
"normalize": True,
},
},
}

# Update model_config based on architecture
@@ -758,6 +773,8 @@ def build_first_shard(hf, eole_safetensor):
"input_layernorm",
"layer_norm_res",
"post_attention_layernorm",
"pre_feedforward_layernorm",
"post_feedforward_layernorm",
"mlp.gate",
]:
module_p = f".{module}.{p}"
@@ -925,8 +942,6 @@ def run(cls, args):
# Deduce dtype from args or config, or default to fp16
compute_dtype = args.dtype or hf.config.get("torch_dtype") or "fp16"

build_shards(model_config, hf, args, params)

# Check tokenizer and vocab related configuration
(
add_bos_token,
@@ -1019,3 +1034,6 @@ def run(cls, args):

with open(os.path.join(args.output, "config.json"), "w", encoding="utf-8") as f:
json.dump(config_dict, f, indent=2, ensure_ascii=False)

# Build shards last, as it's the most io intensive
build_shards(model_config, hf, args, params)
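
The rotary-dimension handling now lives in an explicit fallback chain rather than the initial dict: an explicit rotary_dim in the HF config wins, then partial_rotary_factor, then head_dim, and finally the usual hidden_size // heads. A minimal sketch of that precedence (the helper name and its arguments are illustrative, not part of the converter):

def resolve_rotary_dim(config: dict, hidden_size: int, heads: int, head_dim=None) -> int:
    # Sketch of the fallback order used in build_config_dict above.
    if "rotary_dim" in config:
        return config["rotary_dim"]
    if "partial_rotary_factor" in config:
        return int(config["partial_rotary_factor"] * (hidden_size // heads))
    if head_dim is not None:
        return head_dim
    return hidden_size // heads

# gemma-2-2b style values: head_dim (256) wins over hidden_size // heads (2304 // 8 = 288).
assert resolve_rotary_dim({}, hidden_size=2304, heads=8, head_dim=256) == 256
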
17 changes: 16 additions & 1 deletion eole/config/models.py
@@ -39,6 +39,12 @@ class EmbeddingsConfig(Config):
default=0,
description="Positions IDS shift before making position embed " "dirty patch to cover for xlm-roberta-xl",
)
normalize: bool | None = Field(
default=False,
description="Enable embeddings scaling. "
"Not always necessary, but useful for some model compatibility, e.g. gemma. "
"https://datascience.stackexchange.com/a/87909",
)

@model_validator(mode="after")
def validate_embeddings(self):
@@ -50,6 +56,11 @@ def validate_embeddings(self):
"n_positions must be set if position_encoding_type "
f"is {PositionEncodingType.Learned} or {PositionEncodingType.Relative}"
)
elif self.position_encoding_type in [
PositionEncodingType.SinusoidalInterleaved,
PositionEncodingType.SinusoidalConcat,
]:
assert not (self.normalize), "embeddings normalization is already handled in PositionalEncoding"
return self


@@ -193,7 +204,7 @@ class TransformerConfig(Config):
default=ActivationFunction.relu,
description="The activation function to use in MLP layer.",
)
layer_norm: Literal["standard", "rms"] = Field(
layer_norm: Literal["standard", "rms", "gemma-rms"] = Field(
default="standard",
description="Type of layer normalization in transformer architecture.",
)
@@ -205,6 +216,10 @@
"Note: must be True for Falcon 7B, False for Falcon 40B, "
"same for GPT-J and GPT-NeoX models.",
)
ffn_layernorm: bool = Field(
default=False,
description="Add pre/post_feedforward_layernorm around MLP forward. " "Note: introduced for gemma2 support.",
)
add_qkvbias: bool = Field(
default=False,
description="Add bias to nn.Linear of Query/Key/Value in MHA. "
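
The new normalize flag interacts with position encoding: sinusoidal encodings already handle this scaling internally, so the validator rejects the combination. A simplified, self-contained sketch of that rule (the enum is a stand-in for eole's PositionEncodingType, trimmed to the members used here):

from enum import Enum

class PositionEncodingType(str, Enum):
    # Stand-in with only the members needed for this sketch.
    SinusoidalInterleaved = "SinusoidalInterleaved"
    SinusoidalConcat = "SinusoidalConcat"
    Rotary = "Rotary"

def validate_embeddings(position_encoding_type, normalize: bool) -> None:
    # Mirrors the validator branch added above: sinusoidal PE already scales
    # embeddings, so enabling normalize on top of it would scale twice.
    if position_encoding_type in (
        PositionEncodingType.SinusoidalInterleaved,
        PositionEncodingType.SinusoidalConcat,
    ):
        assert not normalize, "embeddings normalization is already handled in PositionalEncoding"

validate_embeddings(PositionEncodingType.Rotary, normalize=True)  # OK, the gemma2 case
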
8 changes: 6 additions & 2 deletions eole/constants.py
@@ -2,7 +2,7 @@

from enum import Enum
import torch
from eole.modules.rmsnorm import RMSNorm
from eole.modules.rmsnorm import RMSNorm, GemmaRMSNorm
import torch.nn.functional as F


@@ -79,7 +79,11 @@ class TransformType(str, Enum):
}


LayerNorm = {"standard": torch.nn.LayerNorm, "rms": RMSNorm}
LayerNorm = {
"standard": torch.nn.LayerNorm,
"rms": RMSNorm,
"gemma-rms": GemmaRMSNorm,
}


TORCH_DTYPES = {
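
Callers never name a norm class directly; they index this mapping with the configured string and instantiate the result, as the decoder changes below do. A small usage sketch (sizes are arbitrary):

import torch
from eole.constants import LayerNorm  # includes "gemma-rms" after this commit

norm = LayerNorm["gemma-rms"](16, eps=1e-6)  # GemmaRMSNorm instance
out = norm(torch.randn(2, 4, 16))            # same shape as the input
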
25 changes: 25 additions & 0 deletions eole/decoders/transformer.py
@@ -32,6 +32,7 @@ def __init__(
self.dropout_p = getattr(running_config, "dropout", [0.0])[0]
self.full_context_alignment = model_config.full_context_alignment
self.alignment_heads = model_config.alignment_heads
self.ffn_layernorm = model_config.ffn_layernorm

# order of layers corresponds to forward flow of tensors
self.input_layernorm = LayerNorm[model_config.layer_norm](model_config.hidden_size, eps=model_config.norm_eps)
@@ -50,6 +51,15 @@ def __init__(
)
else:
self.context_attn = None

if self.ffn_layernorm:
self.pre_feedforward_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
self.post_feedforward_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)

if model_config.parallel_residual and not model_config.shared_layer_norm:
self.residual_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
@@ -65,6 +75,15 @@ def __init__(
running_config=running_config,
)

def _mlp(self, hidden_states):
if self.ffn_layernorm:
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
else:
hidden_states = self.mlp(hidden_states)
return hidden_states

def forward(self, layer_in, **kwargs):
"""
Args:
@@ -103,6 +122,12 @@ def forward(self, layer_in, **kwargs):
if self.dropout_p > 0:
self_attn = self.dropout(self_attn)

if self.ffn_layernorm:
# NOTE: this case was added for Gemma2 support and might be extended for further ctx_attn support
ff_in = layer_in + self.post_attention_layernorm(self_attn)
layer_out = ff_in + self._mlp(ff_in)
return layer_out, attns

if self.parallel_residual:
if self.context_attn:
ctx_attn, attns = self.context_attn(
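
With ffn_layernorm set, the residual pattern becomes the Gemma2-style sandwich: the attention output is normalized before the first residual add, and the MLP is wrapped in pre/post feedforward norms before the second. A condensed sketch of that data flow (a free function over toy modules, not the actual decoder layer class):

import torch
import torch.nn as nn

def gemma2_residual_flow(layer_in, self_attn, post_attention_layernorm,
                         pre_feedforward_layernorm, mlp, post_feedforward_layernorm):
    # Mirrors the ffn_layernorm branch of forward() combined with _mlp() above.
    ff_in = layer_in + post_attention_layernorm(self_attn)
    return ff_in + post_feedforward_layernorm(mlp(pre_feedforward_layernorm(ff_in)))

# Toy check with identity norms and a linear MLP, hidden size 8.
x, attn = torch.randn(2, 5, 8), torch.randn(2, 5, 8)
out = gemma2_residual_flow(x, attn, nn.Identity(), nn.Identity(), nn.Linear(8, 8), nn.Identity())
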
2 changes: 2 additions & 0 deletions eole/models/model.py
@@ -82,6 +82,7 @@ def build_src_emb(model_config, vocabs, running_config=None):
sparse=getattr(running_config, "optim", None) == "sparseadam",
freeze_word_vecs=model_config.embeddings.freeze_word_vecs_enc,
n_positions=model_config.embeddings.n_positions,
normalize=model_config.embeddings.normalize,
)
return src_emb

@@ -99,6 +100,7 @@ def build_tgt_emb(model_config, vocabs, running_config=None, share_embeddings=Fa
sparse=getattr(running_config, "optim", None) == "sparseadam",
freeze_word_vecs=model_config.embeddings.freeze_word_vecs_dec,
n_positions=model_config.embeddings.n_positions,
normalize=model_config.embeddings.normalize,
)

if share_embeddings:
6 changes: 6 additions & 0 deletions eole/modules/embeddings.py
@@ -93,6 +93,7 @@
sparse=False,
freeze_word_vecs=False,
n_positions=1024,
normalize=False,
):
super(Embeddings, self).__init__()
self._validate_args()
@@ -114,6 +115,7 @@

self.position_encoding_type = position_encoding_type
self.position_shift = position_shift
self.normalize = normalize

if self.position_encoding_type == PositionEncodingType.Learned:
self.pe = nn.Embedding(n_positions, word_vec_size)
@@ -185,6 +187,10 @@ def forward(self, source, step=None):
]:
emb = self.pe(emb, step)

if self.normalize:
normalizer = torch.tensor(self.word_vec_size**0.5, dtype=emb.dtype)
emb = emb * normalizer

if self.dropout_p > 0:
return self.dropout(emb)
else:
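
The new normalize branch simply rescales the embedding output by sqrt(word_vec_size), which Gemma-family checkpoints expect, with the scale factor cast to the embedding dtype so fp16/bf16 activations keep their dtype. A numeric sketch of the same operation (shape and dtype are arbitrary):

import torch

word_vec_size = 2304  # e.g. a gemma-2-2b sized model
emb = torch.randn(2, 7, word_vec_size, dtype=torch.bfloat16)

normalizer = torch.tensor(word_vec_size**0.5, dtype=emb.dtype)  # sqrt(d_model)
emb = emb * normalizer  # what forward() does when normalize=True
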
18 changes: 14 additions & 4 deletions eole/modules/rmsnorm.py
@@ -25,14 +25,16 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
self.weight = nn.Parameter(torch.ones(hidden_size))

@torch.compile(dynamic=True)
def compute_rms(self, hidden_states, dtype):
def compute_rms(self, hidden_states, dtype, residual=False):
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(dtype)
return hidden_states * self.weight
factor = 1.0 if residual else 0.0
hidden_states = hidden_states * (factor + self.weight.float())
return hidden_states.type_as(self.weight)

def forward(self, hidden_states):
def _forward(self, hidden_states, residual=False):
inp_dtype = hidden_states.dtype
if AWQ_EXT and not self.training:
# cuda kernel support only fp16 - need to cast
@@ -44,4 +46,12 @@ def forward(self, hidden_states):
output = output.unsqueeze(0)
return output.to(inp_dtype)
else:
return self.compute_rms(hidden_states, inp_dtype)
return self.compute_rms(hidden_states, inp_dtype, residual=residual)

def forward(self, hidden_states):
return self._forward(hidden_states)


class GemmaRMSNorm(RMSNorm):
def forward(self, hidden_states):
return self._forward(hidden_states, residual=True)
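
The two classes now share one code path; the behavioural difference is only the weight offset. Standard RMSNorm scales the normalized activations by weight, while Gemma checkpoints store weight as an offset from 1, so GemmaRMSNorm scales by (1 + weight). A plain float32 reference sketch of that difference, ignoring the AWQ fast path:

import torch

def reference_rms_norm(x, weight, eps=1e-6, gemma=False):
    # gemma=True corresponds to GemmaRMSNorm: scale by (1 + weight) instead of weight.
    x32 = x.float()
    x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    scale = weight.float() + (1.0 if gemma else 0.0)
    return (x32 * scale).type_as(weight)

x = torch.randn(2, 4, 8)
# A zero Gemma weight is the identity scale, i.e. equivalent to weight == 1 in plain RMSNorm.
assert torch.allclose(reference_rms_norm(x, torch.zeros(8), gemma=True),
                      reference_rms_norm(x, torch.ones(8)))
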
1 change: 1 addition & 0 deletions recipes/model-validator/run.sh
@@ -5,6 +5,7 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
# Define the models table
models=(
# Validated
"google/gemma-2-2b"
"mistralai/Ministral-8B-Instruct-2410"
"mistralai/Mistral-7B-v0.3"
"mistralai/Mistral-7B-Instruct-v0.3"