
import keras

+# --- Gemma Utils ---
from keras_hub.src.utils.transformers.export.gemma import get_gemma_config
from keras_hub.src.utils.transformers.export.gemma import (
    get_gemma_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map
+
+# --- GPT-2 Utils ---
from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_config
from keras_hub.src.utils.transformers.export.gpt2 import (
    get_gpt2_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_weights_map

+# --- Qwen Utils ---
+from keras_hub.src.utils.transformers.export.qwen import get_qwen_config
+from keras_hub.src.utils.transformers.export.qwen import (
+    get_qwen_tokenizer_config,
+)
+from keras_hub.src.utils.transformers.export.qwen import get_qwen_weights_map
+
MODEL_CONFIGS = {
    "GemmaBackbone": get_gemma_config,
    "GPT2Backbone": get_gpt2_config,
-    # Add for future models, e.g., "MistralBackbone": get_mistral_config
+    "QwenBackbone": get_qwen_config,
}

MODEL_EXPORTERS = {
    "GemmaBackbone": get_gemma_weights_map,
    "GPT2Backbone": get_gpt2_weights_map,
-    # Add for future models, e.g., "MistralBackbone": get_mistral_weights_map
+    "QwenBackbone": get_qwen_weights_map,
}

MODEL_TOKENIZER_CONFIGS = {
    "GemmaTokenizer": get_gemma_tokenizer_config,
    "GPT2Tokenizer": get_gpt2_tokenizer_config,
-    # Add for future models, e.g., "MistralTokenizer":
-    # get_mistral_tokenizer_config
+    "QwenTokenizer": get_qwen_tokenizer_config,
}


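# Illustrative sketch (not part of the diff): the registries above are keyed
# by class name, so the exporter can dispatch on `__class__.__name__`, the same
# way `export_tokenizer` does below. The helper name here is hypothetical.
def _lookup_config_fn(backbone):
    backbone_type = backbone.__class__.__name__
    if backbone_type not in MODEL_CONFIGS:
        raise ValueError(f"Export not implemented for {backbone_type}")
    return MODEL_CONFIGS[backbone_type]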
@@ -62,50 +71,64 @@ def export_backbone(backbone, path, include_lm_head=False):
    weights_dict = get_weights_fn(backbone, include_lm_head=include_lm_head)
    if not weights_dict:
        raise ValueError("No weights to save.")
+
    # Save config
    os.makedirs(path, exist_ok=True)
    config_path = os.path.join(path, "config.json")
+
+    # Handle config objects (GPT-2/Qwen) vs dicts (Gemma)
+    config_to_save = hf_config
+    if hasattr(hf_config, "to_dict"):
+        config_to_save = hf_config.to_dict()
+
    with open(config_path, "w") as f:
-        json.dump(hf_config.to_dict(), f)
+        json.dump(config_to_save, f, indent=2)
+
    # Save weights based on backend
    weights_path = os.path.join(path, "model.safetensors")
    if backend == "torch":
+        # Lazy import to prevent crash on TF-only environments
        import torch
        from safetensors.torch import save_file

        weights_dict_torch = {}
-
        for k, v in weights_dict.items():
            tensor = v.value if hasattr(v, "value") else v

-            # Torch tensor -> move to CPU
            if isinstance(tensor, torch.Tensor):
                t = tensor.detach().to("cpu")
-
-            # TensorFlow / JAX -> convert via numpy()
            elif hasattr(tensor, "numpy"):
                t = torch.tensor(tensor.numpy())
-
-            # numpy array
            elif hasattr(tensor, "__array__"):
                t = torch.tensor(tensor)
-
            else:
-                raise TypeError(f"Unsupported tensor type: {type(tensor)}")
+                t = tensor
+
+            if hasattr(t, "contiguous"):
+                t = t.contiguous()

-            weights_dict_torch[k] = t.contiguous()
+            weights_dict_torch[k] = t

-        # ---- GPT-2 tied weights ----
+        # Handle tied weights (GPT-2, Qwen)
+        # Safetensors refuses to save the same shared memory twice.
        if (
+            "lm_head.weight" in weights_dict_torch
+            and "model.embed_tokens.weight" in weights_dict_torch
+        ):
+            # Qwen / Llama naming convention
+            wte = weights_dict_torch["model.embed_tokens.weight"]
+            lm = weights_dict_torch["lm_head.weight"]
+            if wte.data_ptr() == lm.data_ptr():
+                weights_dict_torch["lm_head.weight"] = lm.clone().contiguous()
+        elif (
            "lm_head.weight" in weights_dict_torch
            and "transformer.wte.weight" in weights_dict_torch
        ):
+            # GPT-2 naming convention
            wte = weights_dict_torch["transformer.wte.weight"]
            lm = weights_dict_torch["lm_head.weight"]
-
-            if wte.data_ptr() == lm.data_ptr():
-                weights_dict_torch["lm_head.weight"] = lm.clone().contiguous()
-        # --------------------------------
+            if wte.data_ptr() == lm.data_ptr():
+                weights_dict_torch["lm_head.weight"] = lm.clone().contiguous()

        save_file(weights_dict_torch, weights_path, metadata={"format": "pt"})

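# Illustrative sketch (not part of the diff), assuming torch and safetensors
# are installed: safetensors.torch.save_file rejects tensors that share
# storage, which is why the tied lm_head.weight above is cloned when its
# data_ptr() matches the embedding matrix. Names and file path are made up.
import torch
from safetensors.torch import save_file

embedding = torch.randn(8, 4)
tied = {
    "model.embed_tokens.weight": embedding,
    "lm_head.weight": embedding,  # same storage, as with tied embeddings
}
embed_w = tied["model.embed_tokens.weight"]
lm_w = tied["lm_head.weight"]
if lm_w.data_ptr() == embed_w.data_ptr():
    # Without this clone, save_file raises an error about shared tensors.
    tied["lm_head.weight"] = lm_w.clone().contiguous()
save_file(tied, "tied_example.safetensors", metadata={"format": "pt"})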
@@ -129,46 +152,41 @@ def export_tokenizer(tokenizer, path):
        path: str. Path to save the exported tokenizer.
    """
    os.makedirs(path, exist_ok=True)
+
    # Save tokenizer assets
+    # BytePairTokenizer (GPT-2, Qwen) -> "vocabulary.json", "merges.txt"
+    # SentencePieceTokenizer (Gemma) -> "vocabulary.spm"
    tokenizer.save_assets(path)
+
    # Export tokenizer config
    tokenizer_type = tokenizer.__class__.__name__
    if tokenizer_type not in MODEL_TOKENIZER_CONFIGS:
        raise ValueError(
-            "Export to Transformers format not implemented for {tokenizer_type}"
+            f"Export to Transformers format not implemented for {tokenizer_type}"
        )
    get_tokenizer_config_fn = MODEL_TOKENIZER_CONFIGS[tokenizer_type]
    tokenizer_config = get_tokenizer_config_fn(tokenizer)
    tokenizer_config_path = os.path.join(path, "tokenizer_config.json")
    with open(tokenizer_config_path, "w") as f:
        json.dump(tokenizer_config, f, indent=4)

+    # Rename files to match Hugging Face expectations
    if tokenizer_type == "GemmaTokenizer":
-        # Rename vocabulary file
        vocab_spm_path = os.path.join(path, "vocabulary.spm")
        tokenizer_model_path = os.path.join(path, "tokenizer.model")
        if os.path.exists(vocab_spm_path):
            shutil.move(vocab_spm_path, tokenizer_model_path)
        else:
-            warnings.warn(
-                f"{vocab_spm_path} not found. Tokenizer may not load "
-                "correctly. Ensure that the tokenizer configuration "
-                "is correct and that the vocabulary file is present "
-                "in the original model."
-            )
-    elif tokenizer_type == "GPT2Tokenizer":
-        # Rename vocabulary file
+            warnings.warn(f"{vocab_spm_path} not found.")
+
+    elif tokenizer_type in ["GPT2Tokenizer", "QwenTokenizer"]:
+        # Both GPT-2 and Qwen (BPE) use vocab.json in HF
        vocab_json_path = os.path.join(path, "vocabulary.json")
-        renamed_vocab_json_path = os.path.join(path, "vocab.json")
+        vocab_hf_path = os.path.join(path, "vocab.json")
        if os.path.exists(vocab_json_path):
-            shutil.move(vocab_json_path, renamed_vocab_json_path)
+            shutil.move(vocab_json_path, vocab_hf_path)
        else:
-            warnings.warn(
-                f"{vocab_json_path} not found. Tokenizer may not load "
-                "correctly. Ensure that the tokenizer configuration "
-                "is correct and that the vocabulary file is present "
-                "in the original model."
-            )
+            warnings.warn(f"{vocab_json_path} not found.")


def export_to_safetensors(keras_model, path):
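# Usage sketch (illustration only): how an exported checkpoint might be loaded
# back with Hugging Face Transformers. The preset name and the assumption that
# export_to_safetensors accepts a CausalLM task model are not verified here.
import keras_hub
from transformers import AutoModelForCausalLM, AutoTokenizer

keras_model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
export_to_safetensors(keras_model, "./gemma_hf_export")

hf_model = AutoModelForCausalLM.from_pretrained("./gemma_hf_export")
hf_tokenizer = AutoTokenizer.from_pretrained("./gemma_hf_export")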