Basic pixtral support, paving the way for vision models 🖼️ (#153)
francoishernandez authored Jan 29, 2025
1 parent 7e2051b commit f6f9a95
Showing 35 changed files with 1,026 additions and 108 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/push.yml
@@ -25,7 +25,7 @@ jobs:
pip install sacrebleu
pip install flake8
pip install rich
python -m pip install black flake8 pyproject-flake8
python -m pip install "black==24.10.0" "flake8<7.1" pyproject-flake8
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Check code with Black
run: |
151 changes: 115 additions & 36 deletions eole/bin/convert/convert_HF.py
@@ -25,6 +25,7 @@
from eole.config.models import (
TransformerEncoderModelConfig,
TransformerLMModelConfig,
VisionTransformerLMModelConfig,
)
from eole.config.run import TrainConfig
from eole.config.training import TrainingConfig
@@ -122,6 +123,33 @@
"encoder.layer_norm.weight": "roberta.encoder.LayerNorm.weight",
"encoder.layer_norm.bias": "roberta.encoder.LayerNorm.bias",
},
"LlavaForConditionalGeneration": {
"decoder_layer_prefix": "language_model.model.layers.",
"tgt_emb.embeddings.weight": "language_model.model.embed_tokens.weight",
"decoder.layer_norm.weight": "language_model.model.norm.weight",
"generator.weight": "language_model.lm_head.weight",
"encoder.patch_conv.weight": "vision_tower.patch_conv.weight",
"encoder.ln_pre.weight": "vision_tower.ln_pre.weight",
# vision_tower
"encoder_layer_prefix": "vision_tower.transformer.layers.",
"encoder": {
"layers": 24,
".self_attn.linear_query.": ".attention.q_proj.",
".self_attn.linear_keys.": ".attention.k_proj.",
".self_attn.linear_values.": ".attention.v_proj.",
".self_attn.final_linear.": ".attention.o_proj.",
".mlp.gate_up_proj.": ".feed_forward.gate_proj.",
".mlp.down_proj.": ".feed_forward.down_proj.",
".mlp.up_proj.": ".feed_forward.up_proj.",
".input_layernorm.weight": ".attention_norm.weight", # not sure about this one
".post_attention_layernorm.weight": ".ffn_norm.weight",
},
# vision_adapter
"adapter.w_in.weight": "multi_modal_projector.linear_1.weight",
"adapter.w_in.bias": "multi_modal_projector.linear_1.bias",
"adapter.w_out.weight": "multi_modal_projector.linear_2.weight",
"adapter.w_out.bias": "multi_modal_projector.linear_2.bias",
},
}

# Combine base mappings with overrides
@@ -152,7 +180,10 @@
# Eole config class
ARCH_TABLE = defaultdict(
lambda: TransformerLMModelConfig,
{"XLMRobertaXLForMaskedLM": TransformerEncoderModelConfig},
{
"XLMRobertaXLForMaskedLM": TransformerEncoderModelConfig,
"LlavaForConditionalGeneration": VisionTransformerLMModelConfig,
},
)

# Default tokenization transform
@@ -284,6 +315,11 @@ def __getattr__(self, name):
def arch(self):
return self.config["architectures"][0]

@property
def vocab_size(self):
config = self.config.get("text_config", self.config)
return config["vocab_size"]

@property
def encoder_layer_prefix(self):
return KEY_MAPS[self.arch].get("encoder_layer_prefix", None)
@@ -381,14 +417,19 @@ def build_config_dict(hf):
config = hf.config
arch = hf.arch

vision_config = config.get("vision_config", None)
config = config.get("text_config", config)

model_config = {}
training_config = {}

# Initialize model_config with defaults and fallbacks
model_config = {
"layers": config.get("num_hidden_layers", config.get("n_layer")),
"hidden_size": config.get("hidden_size", config.get("n_embd")),
"heads": config.get("num_attention_heads", config.get("n_head")),
"layers": config.get("num_hidden_layers", config.get("n_layer", config.get("n_layers"))),
"hidden_size": config.get("hidden_size", config.get("n_embd", config.get("hidden_dim"))),
"heads": config.get(
"num_attention_heads", config.get("n_head", config.get("n_heads", 32))
), # default 32 patch for mistral-community/pixtral-12b
"transformer_ff": config.get("intermediate_size", config.get("hidden_size", config.get("n_embd")) * 4),
"mlp_activation_fn": ACT_TABLE[arch],
"layer_norm": LN_TABLE[arch],
@@ -561,6 +602,29 @@ def build_config_dict(hf):
},
}

# Vision encoder
if arch == "LlavaForConditionalGeneration":
# TODO: extend to other Llava models (with CLIP vision encoder)
model_config["encoder"] = {
"mlp_activation_fn": model_config["mlp_activation_fn"],
"layer_norm": model_config["layer_norm"],
"norm_eps": model_config["norm_eps"],
"hidden_size": vision_config["image_size"],
"transformer_ff": vision_config["image_size"] * 4, # hard-coded for mistral-community/pixtral-12b
"num_channels": 3,
"image_size": vision_config["image_size"],
"patch_size": vision_config["patch_size"],
"rope_config": {
"rotary_theta": vision_config["rope_theta"],
"rotary_interleave": False,
},
"layers": 24, # hard-coded for mistral-community/pixtral-12b
"heads": vision_config["image_size"] / vision_config["head_dim"],
"heads_kv": vision_config["image_size"] / vision_config["head_dim"],
"head_dim": vision_config["head_dim"],
"image_token_id": 10,
}

# Update model_config based on architecture
if arch in arch_configs:
for key, value in arch_configs[arch].items():
@@ -673,7 +737,7 @@ def get_shards_map(model_config, hf, nshards):

# Check if a layer key belongs to the current shard
def is_layer_in_range(key, prefix, layer_range):
return prefix and key.startswith(prefix) and int(key.split(".")[2]) in layer_range
return prefix and key.startswith(prefix) and int(key.split(".")[len(prefix.split(".")) - 1]) in layer_range

# Loop over the weightmap and distribute checkpoints to the appropriate shards
for key, ckpt in weightmap.items():
@@ -716,6 +780,12 @@ def build_shards(model_config, hf, args, params):
"encoder.layer_norm.weight",
"encoder.layer_norm.bias",
"generator.weight",
"encoder.patch_conv.weight",
"encoder.ln_pre.weight",
"adapter.w_in.weight",
"adapter.w_in.bias",
"adapter.w_out.weight",
"adapter.w_out.bias",
]

def build_first_shard(hf, eole_safetensor):
@@ -754,26 +824,30 @@ def build_first_shard(hf, eole_safetensor):
if hf_prefix is None:
continue
for param in params:
for target, source in KEY_MAPS[hf.arch].items():
if target in first_shard_targets:
continue
srckey, srcmap = source if isinstance(source, tuple) else (source, None)
w = get_weight(
checkpoint,
hf_prefix + str(i) + srckey + param,
)

if w is not None:
if srcmap is not None:
w = eval(
"w" + srcmap,
{
"w": w,
"hidden_size": model_config["hidden_size"],
"transformer_ff": model_config["transformer_ff"],
},
).contiguous()
eole_safetensor[eole_prefix + str(i) + target + param] = w
# TODO: factorize this better
for key_map in [KEY_MAPS[hf.arch], KEY_MAPS[hf.arch].get("encoder", {})]:
for target, source in key_map.items():
if not isinstance(source, str):
continue
if target in first_shard_targets:
continue
srckey, srcmap = source if isinstance(source, tuple) else (source, None)
w = get_weight(
checkpoint,
hf_prefix + str(i) + srckey + param,
)

if w is not None:
if srcmap is not None:
w = eval(
"w" + srcmap,
{
"w": w,
"hidden_size": model_config["hidden_size"],
"transformer_ff": model_config["transformer_ff"],
},
).contiguous()
eole_safetensor[eole_prefix + str(i) + target + param] = w

if model_config["shared_layer_norm"]:
idx = 0
@@ -789,26 +863,30 @@ def build_first_shard(hf, eole_safetensor):
"post_feedforward_layernorm",
"mlp.gate",
]:
if hf_prefix == hf.encoder_layer_prefix:
source_map = KEY_MAPS[hf.arch]["encoder"]
else:
source_map = KEY_MAPS[hf.arch]
module_p = f".{module}.{p}"
if module_p in KEY_MAPS[hf.arch].keys():
if isinstance(KEY_MAPS[hf.arch][module_p], tuple):
if module_p in source_map.keys():
if isinstance(source_map[module_p], tuple):
w = get_weight(
checkpoint,
hf_prefix + str(i) + KEY_MAPS[hf.arch][module_p][idx],
hf_prefix + str(i) + source_map[module_p][idx],
)
else:
w = get_weight(
checkpoint,
hf_prefix + str(i) + KEY_MAPS[hf.arch][module_p],
hf_prefix + str(i) + source_map[module_p],
)
if w is not None:
eole_safetensor[eole_prefix + str(i) + module_p] = w

for j in range(model_config["num_experts"]):
if f".mlp.experts.{j}.layer_norm." + p in KEY_MAPS[hf.arch].keys():
if f".mlp.experts.{j}.layer_norm." + p in source_map.keys():
w = get_weight(
checkpoint,
hf_prefix + str(i) + KEY_MAPS[hf.arch][f".mlp.experts.{j}.layer_norm." + p],
hf_prefix + str(i) + source_map[f".mlp.experts.{j}.layer_norm." + p],
)
if w is not None:
eole_safetensor[eole_prefix + str(i) + f".mlp.experts.{j}.layer_norm." + p] = w
@@ -840,7 +918,8 @@ def check_sentencepiece_tokenizer(hf):


def check_bpe_tokenizer(hf, vocabs, directory_path):
vocab_size = hf.config["vocab_size"]
config = hf.config.get("text_config", hf.config)
vocab_size = hf.vocab_size
# gpt2_pretok
pretokenizers = hf.tokenizer.get("pre_tokenizer", {}).get("pretokenizers", [{}])
pre_tokenizer = hf.tokenizer.get("pre_tokenizer", None)
@@ -866,8 +945,8 @@ def check_bpe_tokenizer(hf, vocabs, directory_path):
src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
# TODO: not sure for which model(s) this is needed
for token_name in ["bos_token", "unk_token", "eos_token", "pad_token"]:
if f"{token_name}_id" in hf.config.keys():
token = hf.config[f"{token_name}_id"]
if f"{token_name}_id" in config.keys():
token = config[f"{token_name}_id"]
if isinstance(token, list):
vocabs["specials"][token_name] = vocab[token[0]]
elif isinstance(token, str):
@@ -1009,8 +1088,8 @@ def run(cls, args):
src_vocab=None,
tgt_vocab=None,
share_vocab=True,
src_vocab_size=hf.config["vocab_size"],
tgt_vocab_size=hf.config["vocab_size"],
src_vocab_size=hf.vocab_size,
tgt_vocab_size=hf.vocab_size,
vocab_size_multiple=8,
decoder_start_token=vocabs["decoder_start_token"],
**vocabs["specials"],
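To make the key-mapping scheme above concrete, here is a minimal, self-contained sketch of how a KEY_MAPS-style table can drive the weight copy performed in build_shards: each Hugging Face checkpoint key is rebuilt from the layer prefix, the layer index, the mapped name fragment, and the parameter name, then stored under the corresponding eole name. The eole-side prefix ("encoder.transformer_layers.") and the toy checkpoint dict are assumptions for illustration only; they are not part of this commit.

# Illustrative sketch only (not code from this commit): shows how the
# "LlavaForConditionalGeneration" encoder mapping above can be used to
# rename HF vision-tower weights into eole parameter names.
key_map = {
    ".self_attn.linear_query.": ".attention.q_proj.",
    ".self_attn.linear_keys.": ".attention.k_proj.",
}
hf_prefix = "vision_tower.transformer.layers."  # encoder_layer_prefix from KEY_MAPS
eole_prefix = "encoder.transformer_layers."  # assumed eole-side prefix

# Stand-in for the real safetensors checkpoint (values would be tensors).
checkpoint = {
    "vision_tower.transformer.layers.0.attention.q_proj.weight": "q_proj tensor",
    "vision_tower.transformer.layers.0.attention.k_proj.weight": "k_proj tensor",
}

eole_safetensor = {}
for i in range(1):  # one encoder layer is enough for the demo
    for param in ("weight",):
        for target, source in key_map.items():
            hf_key = hf_prefix + str(i) + source + param
            if hf_key in checkpoint:
                eole_safetensor[eole_prefix + str(i) + target + param] = checkpoint[hf_key]

print(eole_safetensor)
# {'encoder.transformer_layers.0.self_attn.linear_query.weight': 'q_proj tensor',
#  'encoder.transformer_layers.0.self_attn.linear_keys.weight': 'k_proj tensor'}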
6 changes: 3 additions & 3 deletions eole/config/data.py
@@ -232,13 +232,13 @@ def _validate_data(self):
logger.info(f"Missing transforms field for {cname} data, " f"set to default: {default_transforms}.")
corpus.transforms = default_transforms
# Check path
if corpus.path_src is None:
if corpus.path_src is None and corpus.path_txt is None:
raise ValueError(
f"Corpus {cname} src path is required."
f"Corpus {cname} `path_src` or `path_tgt` is required."
"tgt path is also required for non language"
" modeling tasks."
)
else:
elif corpus.path_src is not None:
self.__class__._validate_file(corpus.path_src, info=f"{cname}/path_src")
if corpus.path_tgt is None:
logger.debug("path_tgt is None, it should be set unless the task" " is language modeling")