8 changes: 8 additions & 0 deletions keras_hub/src/models/mistral/mistral_presets.py
@@ -42,4 +42,12 @@
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.3_instruct_7b_en/1",
     },
+    "devstral_small_1_1": {
+        "metadata": {
+            "description": (
+                "Devstral Small 1.1, fine-tuned from the Mistral-Small-3.1 "
+                "24B base model"
+            ),
+            "params": 23572403200,
+            "path": "devstral_small_1_1",
+        },
+        # "kaggle_handle": "kaggle://keras/mistral/keras/devstral_small_1_1/1",
+    },
 }
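
For context, once the commented-out `kaggle_handle` above is uncommented and the converted weights are published, the preset would load through the standard KerasHub API. A minimal sketch, assuming the weights end up under that handle:

    import keras_hub

    # Hypothetical usage; this fails until the `kaggle_handle` above is
    # uncommented and the converted weights are uploaded to Kaggle.
    causal_lm = keras_hub.models.MistralCausalLM.from_preset(
        "devstral_small_1_1"
    )
    print(causal_lm.generate("def fibonacci(n):", max_length=64))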
2 changes: 2 additions & 0 deletions keras_hub/src/utils/transformers/convert_mistral.py
@@ -113,4 +113,6 @@ def convert_weights(backbone, loader, transformers_config):


 def convert_tokenizer(cls, preset, **kwargs):
+    # Devstral presets fall back to the tokenizer of the Mistral-Small
+    # base model they were fine-tuned from.
+    if "devstral" in preset.lower():
+        preset = "mistralai/Mistral-Small-24B-Base-2501"
     return cls(get_file(preset, "tokenizer.model"), **kwargs)
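
This fallback can be exercised directly from a Hugging Face handle. A minimal sketch, assuming network access and any required Hugging Face authentication for the two repos:

    from keras_hub.models import MistralTokenizer

    # "devstral" in the handle triggers the fallback above, so the
    # tokenizer file is actually fetched from Mistral-Small-24B-Base-2501.
    tokenizer = MistralTokenizer.from_preset(
        "hf://mistralai/Devstral-Small-2507"
    )
    print(tokenizer("Hello, Devstral!"))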
11 changes: 9 additions & 2 deletions tools/checkpoint_conversion/convert_mistral_checkpoints.py
@@ -22,6 +22,7 @@
"mistral_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.1",
"mistral_0.2_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.2",
"mistral_0.3_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.3",
"devstral_small_1_1": "mistralai/Devstral-Small-2507",
}

FLAGS = flags.FLAGS
@@ -220,7 +221,10 @@ def main(_):
     try:
         # === Load the Huggingface model ===
         hf_model = MistralForCausalLM.from_pretrained(hf_preset)
-        hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+        # Devstral reuses the tokenizer from its Mistral-Small base model.
+        if "devstral" in hf_preset.lower():
+            hf_tokenizer = AutoTokenizer.from_pretrained(
+                "mistralai/Mistral-Small-24B-Base-2501"
+            )
+        else:
+            hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
         hf_model.eval()
         print("\n-> Huggingface model and tokenizer loaded")

@@ -239,7 +243,10 @@
         )
         keras_hub_backbone = MistralBackbone(**backbone_kwargs)
 
-        keras_hub_tokenizer = MistralTokenizer.from_preset(f"hf://{hf_preset}")
+        if "devstral" in hf_preset.lower():
+            keras_hub_tokenizer = MistralTokenizer.from_preset(
+                "hf://mistralai/Mistral-Small-24B-Base-2501"
+            )
+        else:
+            keras_hub_tokenizer = MistralTokenizer.from_preset(
+                f"hf://{hf_preset}"
+            )
         print("\n-> Keras 3 model and tokenizer loaded.")

# === Port the weights ===
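
With the `PRESET_MAP` entry in place, conversion would be kicked off the same way as for the existing presets. A sketch, assuming the script's existing `--preset` flag keeps its current name:

    python tools/checkpoint_conversion/convert_mistral_checkpoints.py \
        --preset devstral_small_1_1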