Commit 502488e

Add SafeTensors Conversion Support for Qwen2/2.5 Models
1 parent 364b9d3 commit 502488e

3 files changed: +269 −37 lines changed

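For orientation, here is a minimal usage sketch of the workflow this commit enables (not part of the diff): export a KerasHub Qwen causal LM to the Hugging Face SafeTensors layout, then reload it with transformers. The preset name, output directory, and the assumption that export_to_safetensors accepts a full CausalLM task are illustrative, not taken from the commit.

# Illustrative sketch only; preset name and output path are assumptions.
import keras_hub
import transformers

from keras_hub.src.utils.transformers.export.hf_exporter import (
    export_to_safetensors,
)

# Load a KerasHub Qwen2.5 causal LM (hypothetical preset name).
model = keras_hub.models.CausalLM.from_preset("qwen2.5_0.5b_en")

# Write config.json, model.safetensors, and the tokenizer files in the
# Hugging Face layout.
export_to_safetensors(model, "./qwen_hf_export")

# Round-trip check with Hugging Face Transformers.
hf_model = transformers.AutoModelForCausalLM.from_pretrained("./qwen_hf_export")
hf_tokenizer = transformers.AutoTokenizer.from_pretrained("./qwen_hf_export")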

keras_hub/src/utils/transformers/export/hf_exporter.py

Lines changed: 55 additions & 37 deletions
@@ -5,34 +5,43 @@

 import keras

+# --- Gemma Utils ---
 from keras_hub.src.utils.transformers.export.gemma import get_gemma_config
 from keras_hub.src.utils.transformers.export.gemma import (
     get_gemma_tokenizer_config,
 )
 from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map
+
+# --- GPT-2 Utils ---
 from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_config
 from keras_hub.src.utils.transformers.export.gpt2 import (
     get_gpt2_tokenizer_config,
 )
 from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_weights_map

+# --- Qwen Utils ---
+from keras_hub.src.utils.transformers.export.qwen import get_qwen_config
+from keras_hub.src.utils.transformers.export.qwen import (
+    get_qwen_tokenizer_config,
+)
+from keras_hub.src.utils.transformers.export.qwen import get_qwen_weights_map
+
 MODEL_CONFIGS = {
     "GemmaBackbone": get_gemma_config,
     "GPT2Backbone": get_gpt2_config,
-    # Add for future models, e.g., "MistralBackbone": get_mistral_config
+    "QwenBackbone": get_qwen_config,
 }

 MODEL_EXPORTERS = {
     "GemmaBackbone": get_gemma_weights_map,
     "GPT2Backbone": get_gpt2_weights_map,
-    # Add for future models, e.g., "MistralBackbone": get_mistral_weights_map
+    "QwenBackbone": get_qwen_weights_map,
 }

 MODEL_TOKENIZER_CONFIGS = {
     "GemmaTokenizer": get_gemma_tokenizer_config,
     "GPT2Tokenizer": get_gpt2_tokenizer_config,
-    # Add for future models, e.g., "MistralTokenizer":
-    # get_mistral_tokenizer_config
+    "QwenTokenizer": get_qwen_tokenizer_config,
 }

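The registries above are keyed by backbone/tokenizer class name, so Qwen support amounts to the three new entries. A small sketch of the lookup the exporter performs on these dictionaries (the helper name is hypothetical, mirroring the calls made in export_backbone below):

# Sketch only: resolve_export_fns is a hypothetical helper; it assumes the
# MODEL_CONFIGS and MODEL_EXPORTERS registries above are in scope.
def resolve_export_fns(backbone):
    name = backbone.__class__.__name__  # e.g. "QwenBackbone"
    if name not in MODEL_CONFIGS or name not in MODEL_EXPORTERS:
        raise ValueError(f"Export to Transformers format not implemented for {name}")
    return MODEL_CONFIGS[name], MODEL_EXPORTERS[name]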

@@ -62,50 +71,64 @@ def export_backbone(backbone, path, include_lm_head=False):
     weights_dict = get_weights_fn(backbone, include_lm_head=include_lm_head)
     if not weights_dict:
         raise ValueError("No weights to save.")
+
     # Save config
     os.makedirs(path, exist_ok=True)
     config_path = os.path.join(path, "config.json")
+
+    # Handle Config Objects (GPT2/Qwen) vs Dicts (Gemma)
+    config_to_save = hf_config
+    if hasattr(hf_config, "to_dict"):
+        config_to_save = hf_config.to_dict()
+
     with open(config_path, "w") as f:
-        json.dump(hf_config.to_dict(), f)
+        json.dump(config_to_save, f, indent=2)
+
     # Save weights based on backend
     weights_path = os.path.join(path, "model.safetensors")
     if backend == "torch":
+        # Lazy import to prevent crash on TF-only environments
         import torch
         from safetensors.torch import save_file

         weights_dict_torch = {}
-
         for k, v in weights_dict.items():
             tensor = v.value if hasattr(v, "value") else v

-            # Torch tensor -> move to CPU
             if isinstance(tensor, torch.Tensor):
                 t = tensor.detach().to("cpu")
-
-            # TensorFlow / JAX -> convert via numpy()
             elif hasattr(tensor, "numpy"):
                 t = torch.tensor(tensor.numpy())
-
-            # numpy array
             elif hasattr(tensor, "__array__"):
                 t = torch.tensor(tensor)
-
             else:
-                raise TypeError(f"Unsupported tensor type: {type(tensor)}")
+                t = tensor
+
+            if hasattr(t, "contiguous"):
+                t = t.contiguous()

-            weights_dict_torch[k] = t.contiguous()
+            weights_dict_torch[k] = t

-        # ---- GPT-2 tied weights ----
+        # Handle Tied Weights (GPT-2, Qwen)
+        # Safetensors crashes if we try to save the same shared memory twice.
         if (
+            "lm_head.weight" in weights_dict_torch
+            and "model.embed_tokens.weight" in weights_dict_torch
+        ):
+            # Qwen / Llama naming convention
+            wte = weights_dict_torch["model.embed_tokens.weight"]
+            lm = weights_dict_torch["lm_head.weight"]
+            if wte.data_ptr() == lm.data_ptr():
+                weights_dict_torch["lm_head.weight"] = lm.clone().contiguous()
+        elif (
             "lm_head.weight" in weights_dict_torch
             and "transformer.wte.weight" in weights_dict_torch
         ):
+            # GPT-2 naming convention
             wte = weights_dict_torch["transformer.wte.weight"]
             lm = weights_dict_torch["lm_head.weight"]
-
-            if wte.data_ptr() == lm.data_ptr():
-                weights_dict_torch["lm_head.weight"] = lm.clone().contiguous()
-        # --------------------------------
+            if wte.data_ptr() == lm.data_ptr():
+                weights_dict_torch["lm_head.weight"] = lm.clone().contiguous()

         save_file(weights_dict_torch, weights_path, metadata={"format": "pt"})

@@ -129,46 +152,41 @@ def export_tokenizer(tokenizer, path):
         path: str. Path to save the exported tokenizer.
     """
     os.makedirs(path, exist_ok=True)
+
     # Save tokenizer assets
+    # BytePairTokenizer (GPT2, Qwen) -> "vocabulary.json", "merges.txt"
+    # SentencePieceTokenizer (Gemma) -> "vocabulary.spm"
     tokenizer.save_assets(path)
+
     # Export tokenizer config
     tokenizer_type = tokenizer.__class__.__name__
     if tokenizer_type not in MODEL_TOKENIZER_CONFIGS:
         raise ValueError(
-            "Export to Transformers format not implemented for {tokenizer_type}"
+            f"Export to Transformer format not implemented for {tokenizer_type}"
         )
     get_tokenizer_config_fn = MODEL_TOKENIZER_CONFIGS[tokenizer_type]
     tokenizer_config = get_tokenizer_config_fn(tokenizer)
     tokenizer_config_path = os.path.join(path, "tokenizer_config.json")
     with open(tokenizer_config_path, "w") as f:
         json.dump(tokenizer_config, f, indent=4)

+    # Rename files to match Hugging Face expectations
     if tokenizer_type == "GemmaTokenizer":
-        # Rename vocabulary file
         vocab_spm_path = os.path.join(path, "vocabulary.spm")
         tokenizer_model_path = os.path.join(path, "tokenizer.model")
         if os.path.exists(vocab_spm_path):
             shutil.move(vocab_spm_path, tokenizer_model_path)
         else:
-            warnings.warn(
-                f"{vocab_spm_path} not found. Tokenizer may not load "
-                "correctly. Ensure that the tokenizer configuration "
-                "is correct and that the vocabulary file is present "
-                "in the original model."
-            )
-    elif tokenizer_type == "GPT2Tokenizer":
-        # Rename vocabulary file
+            warnings.warn(f"{vocab_spm_path} not found.")
+
+    elif tokenizer_type in ["GPT2Tokenizer", "QwenTokenizer"]:
+        # Both GPT-2 and Qwen (BPE) use vocab.json in HF
         vocab_json_path = os.path.join(path, "vocabulary.json")
-        renamed_vocab_json_path = os.path.join(path, "vocab.json")
+        vocab_hf_path = os.path.join(path, "vocab.json")
         if os.path.exists(vocab_json_path):
-            shutil.move(vocab_json_path, renamed_vocab_json_path)
+            shutil.move(vocab_json_path, vocab_hf_path)
         else:
-            warnings.warn(
-                f"{vocab_json_path} not found. Tokenizer may not load "
-                "correctly. Ensure that the tokenizer configuration "
-                "is correct and that the vocabulary file is present "
-                "in the original model."
-            )
+            warnings.warn(f"{vocab_json_path} not found.")


 def export_to_safetensors(keras_model, path):
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
import keras.ops as ops
import transformers


def get_qwen_config(backbone):
    """Convert Keras Qwen config to Hugging Face Qwen2Config."""
    # Qwen2Config handles the architecture specifics (RoPE, RMSNorm, SwiGLU)
    return transformers.Qwen2Config(
        vocab_size=backbone.vocabulary_size,
        hidden_size=backbone.hidden_dim,
        num_hidden_layers=backbone.num_layers,
        num_attention_heads=backbone.num_query_heads,
        num_key_value_heads=backbone.num_key_value_heads,
        intermediate_size=backbone.intermediate_dim,
        hidden_act="silu",  # Qwen uses SwiGLU (SiLU activation)
        rms_norm_eps=backbone.layer_norm_epsilon,
        rope_theta=backbone.rope_max_wavelength,
        tie_word_embeddings=backbone.tie_word_embeddings,
        # Default initialization parameters
        initializer_range=0.02,
        use_cache=True,
    )


def get_qwen_weights_map(backbone, include_lm_head=False):
    """Create a weights map for a given Qwen model."""
    weights_map = {}

    # 1. Embeddings
    # Keras: token_embedding.embeddings
    # HF: model.embed_tokens.weight
    weights_map["model.embed_tokens.weight"] = backbone.get_layer(
        "token_embedding"
    ).embeddings

    for i in range(backbone.num_layers):
        # Access the decoder layer
        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

        # --- Normalization ---
        # Input Norm (Pre-Attention)
        # Keras uses 'scale' (gamma), HF uses 'weight'
        weights_map[f"model.layers.{i}.input_layernorm.weight"] = (
            decoder_layer._self_attention_layernorm.scale
        )

        # Post Attention Norm (Pre-MLP)
        weights_map[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            decoder_layer._feedforward_layernorm.scale
        )

        # --- Attention ---
        # QwenAttention uses EinsumDense for Q/K/V/O
        # Keras Shape: (hidden_dim, num_heads, head_dim)
        # HF Shape: (num_heads * head_dim, hidden_dim) -> Transposed Linear

        attn_layer = decoder_layer._self_attention_layer

        # Query
        q_kernel = attn_layer._query_dense.kernel
        q_kernel = ops.reshape(q_kernel, (backbone.hidden_dim, -1))
        weights_map[f"model.layers.{i}.self_attn.q_proj.weight"] = (
            ops.transpose(q_kernel)
        )

        # Keras: (num_heads, head_dim) -> HF: (hidden_dim,)
        weights_map[f"model.layers.{i}.self_attn.q_proj.bias"] = ops.reshape(
            attn_layer._query_dense.bias, (-1,)
        )

        # Key
        k_kernel = attn_layer._key_dense.kernel
        k_kernel = ops.reshape(k_kernel, (backbone.hidden_dim, -1))
        weights_map[f"model.layers.{i}.self_attn.k_proj.weight"] = (
            ops.transpose(k_kernel)
        )

        weights_map[f"model.layers.{i}.self_attn.k_proj.bias"] = ops.reshape(
            attn_layer._key_dense.bias, (-1,)
        )

        # Value
        v_kernel = attn_layer._value_dense.kernel
        v_kernel = ops.reshape(v_kernel, (backbone.hidden_dim, -1))
        weights_map[f"model.layers.{i}.self_attn.v_proj.weight"] = (
            ops.transpose(v_kernel)
        )

        weights_map[f"model.layers.{i}.self_attn.v_proj.bias"] = ops.reshape(
            attn_layer._value_dense.bias, (-1,)
        )

        # Output
        o_kernel = attn_layer._output_dense.kernel
        o_kernel = ops.reshape(o_kernel, (-1, backbone.hidden_dim))
        weights_map[f"model.layers.{i}.self_attn.o_proj.weight"] = (
            ops.transpose(o_kernel)
        )

        # --- MLP (SwiGLU) ---
        # Gate (feedforward_gate_dense)
        gate_kernel = decoder_layer._feedforward_gate_dense.kernel
        weights_map[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
            gate_kernel
        )

        # Up (feedforward_intermediate_dense)
        up_kernel = decoder_layer._feedforward_intermediate_dense.kernel
        weights_map[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(
            up_kernel
        )

        # Down (feedforward_output_dense)
        down_kernel = decoder_layer._feedforward_output_dense.kernel
        weights_map[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(
            down_kernel
        )

    # Final Norm
    weights_map["model.norm.weight"] = backbone.get_layer(
        "sequence_output_layernorm"
    ).scale

    # LM Head
    if include_lm_head:
        if backbone.tie_word_embeddings:
            # If tied, point to embeddings (Exporter handles cloning)
            weights_map["lm_head.weight"] = weights_map[
                "model.embed_tokens.weight"
            ]
        else:
            # If not tied, QwenBackbone uses ReversibleEmbedding.
            lm_head_w = backbone.get_layer("token_embedding").reverse_embeddings
            weights_map["lm_head.weight"] = ops.transpose(lm_head_w)

    return weights_map


def get_qwen_tokenizer_config(tokenizer):
    """Convert Keras Qwen tokenizer config to Hugging Face."""
    # Qwen2 uses BPE. We specify the class and basic special tokens.
    # The actual vocab/merges files are handled by the exporter.
    return {
        "tokenizer_class": "Qwen2Tokenizer",
        "bos_token": None,  # Qwen often doesn't use BOS
        "eos_token": "<|endoftext|>",
        "pad_token": "<|endoftext|>",  # Often mapped to EOS or null
        "unk_token": None,
        "model_max_length": 32768,  # Default window size
    }
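The reshape/transpose pairs in get_qwen_weights_map convert Keras EinsumDense kernels of shape (hidden_dim, num_heads, head_dim) into the (out_features, in_features) layout that Hugging Face Qwen2 checkpoints expect for their Linear projections. A shape-only sketch with illustrative Qwen2.5-0.5B-style dimensions (the numbers are assumptions, not read from a real model):

# Shape-only sketch; dimensions are illustrative.
import keras.ops as ops

hidden_dim, num_heads, head_dim = 896, 14, 64

# Keras EinsumDense query kernel: (hidden_dim, num_heads, head_dim)
q_kernel = ops.zeros((hidden_dim, num_heads, head_dim))

# Flatten the head axes, then transpose to HF's (num_heads * head_dim, hidden_dim).
hf_q_weight = ops.transpose(ops.reshape(q_kernel, (hidden_dim, -1)))

assert tuple(hf_q_weight.shape) == (num_heads * head_dim, hidden_dim)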
