# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 15 | +"""Convert weights from a GPT-OSS style model to a MaxText one. |
| 16 | +
|
| 17 | +Example cmd: |
| 18 | +
|
| 19 | +python3 -m MaxText.convert_gpt_oss_unscanned_ckpt --base-model-path <path/to/hf/ckpt> \ |
| 20 | + --maxtext-model-path <GCS/path/to/save/new/maxtext/ckpt> --model-size gpt-oss-20b |
| 21 | +""" |

# pylint: disable=g-line-too-long
import argparse
import gc
import logging
import os
import pathlib
import re
from dataclasses import dataclass

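# Force the JAX CPU backend for the conversion (no accelerator is needed); this
# must be set before JAX is first imported/initialized below.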
os.environ["JAX_PLATFORMS"] = "cpu"

import ml_dtypes
import numpy as np
import psutil
from safetensors import safe_open
import torch
from tqdm import tqdm

from MaxText import max_logging
from MaxText.inference_utils import str2bool
from MaxText.llama_or_mistral_ckpt import save_weights_to_checkpoint


# NOTE: numpy has no native bfloat16 dtype, so we use ml_dtypes (which is
# quasi-native) instead.
# NOTE: a torch bfloat16 tensor cannot be cast directly to a numpy bfloat16
# array, so we go through float32 first.
CAST_DTYPE = ml_dtypes.bfloat16


def _pt_to_np(pt_weight, cast_dtype=None, transpose=False):
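  """Convert a torch tensor to a numpy array, optionally casting and transposing.

  Torch bfloat16 tensors cannot be handed to numpy directly, so the tensor is
  promoted to float32 first and then narrowed to `cast_dtype` when one is given.
  """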
  if cast_dtype:
    np_weight = pt_weight.to(torch.float32).numpy().astype(cast_dtype)
  else:
    np_weight = pt_weight.to(torch.float32).numpy()
  if transpose:
    np_weight = np_weight.transpose()
  return np_weight

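# Example (hypothetical tensor, for illustration only):
#   w = torch.zeros((4096, 2880), dtype=torch.bfloat16)
#   _pt_to_np(w, cast_dtype=CAST_DTYPE, transpose=True).shape  # -> (2880, 4096)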

MODEL_PARAMS_DICT = {
    "gpt-oss-20b": {
        "base_emb_dim": 2880,
        "base_num_query_heads": 64,
        "base_num_kv_heads": 8,
        "head_dim": 64,
        "base_num_decoder_layers": 24,
        "inhomogeneous_layer_cycle_interval": 2,
    },
    "gpt-oss-120b": {
        "base_emb_dim": 2880,
        "base_num_query_heads": 64,
        "base_num_kv_heads": 8,
        "head_dim": 64,
        "base_num_decoder_layers": 36,
        "inhomogeneous_layer_cycle_interval": 2,
    },
}
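
# NOTE: with these settings, the HF attention projection weights are expected to
# have shapes (base_num_query_heads * head_dim, base_emb_dim) = (4096, 2880) for
# q_proj and (base_num_kv_heads * head_dim, base_emb_dim) = (512, 2880) for
# k_proj/v_proj; the reshapes in _convert_huggingface_to_jax_weights below
# assume this layout.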


def _hf_to_maxtext_mapping(layer_idx: int = -1) -> dict:
  """
  Returns a mapping from HuggingFace model weight names to MaxText model weight names.

  Args:
    layer_idx (int): Layer index.

  Returns:
    dict[str, str]: Mapping from HuggingFace model weight names to MaxText model weight names.
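
  Example:
    For layer_idx=0, "model.layers.0.self_attn.q_proj.weight" maps to
    "layers.0.attention.wq.weight".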
| 93 | + """ |
| 94 | + # pylint: disable=line-too-long |
| 95 | + return { |
| 96 | + "model.embed_tokens.weight": "tok_embeddings.weight", |
| 97 | + "model.norm.weight": "norm.weight", |
| 98 | + "lm_head.weight": "output.weight", |
| 99 | + # layernorm |
| 100 | + f"model.layers.{layer_idx}.input_layernorm.weight": f"layers.{layer_idx}.attention_norm.weight", |
| 101 | + f"model.layers.{layer_idx}.post_attention_layernorm.weight": f"layers.{layer_idx}.ffn_norm.weight", |
| 102 | + # attention |
| 103 | + f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"layers.{layer_idx}.attention.wq.weight", |
| 104 | + f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"layers.{layer_idx}.attention.wk.weight", |
| 105 | + f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"layers.{layer_idx}.attention.wv.weight", |
| 106 | + f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"layers.{layer_idx}.attention.wo.weight", |
| 107 | + f"model.layers.{layer_idx}.self_attn.q_proj.bias": f"layers.{layer_idx}.attention.wq.bias", |
| 108 | + f"model.layers.{layer_idx}.self_attn.k_proj.bias": f"layers.{layer_idx}.attention.wk.bias", |
| 109 | + f"model.layers.{layer_idx}.self_attn.v_proj.bias": f"layers.{layer_idx}.attention.wv.bias", |
| 110 | + f"model.layers.{layer_idx}.self_attn.o_proj.bias": f"layers.{layer_idx}.attention.wo.bias", |
| 111 | + f"model.layers.{layer_idx}.self_attn.sinks": f"layers.{layer_idx}.attention.sinks", |
| 112 | + # MoE |
| 113 | + f"model.layers.{layer_idx}.mlp.router.weight": f"layers.{layer_idx}.feed_forward.gate.weight", |
| 114 | + f"model.layers.{layer_idx}.mlp.router.bias": f"layers.{layer_idx}.feed_forward.gate.bias", |
| 115 | + f"model.layers.{layer_idx}.mlp.experts.gate_up_proj": f"layers.{layer_idx}.feed_forward.experts.gate_up_proj", |
| 116 | + f"model.layers.{layer_idx}.mlp.experts.gate_up_proj_bias": f"layers.{layer_idx}.feed_forward.experts.gate_up_proj_bias", |
| 117 | + f"model.layers.{layer_idx}.mlp.experts.down_proj": f"layers.{layer_idx}.feed_forward.experts.down_proj", |
| 118 | + f"model.layers.{layer_idx}.mlp.experts.down_proj_bias": f"layers.{layer_idx}.feed_forward.experts.down_proj_bias", |
| 119 | + } |
| 120 | + |
| 121 | + |
def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, model_params: dict, mem_info: psutil.Process):
  """Convert a HuggingFace checkpoint to a dictionary of numpy arrays representing the weights.

  Args:
    base_model_path (str): Path to the base model checkpoint.
    model_size (str): Size of the base model.
    model_params (dict): Dictionary containing model parameters.
    mem_info (psutil.Process): Process object to track memory usage.

  Returns:
    jax_weights (dict): Dictionary containing the converted weights.
  """
  # model params
  base_num_decoder_layers = model_params["base_num_decoder_layers"]
  base_emb_dim = model_params["base_emb_dim"]
  base_num_query_heads = model_params["base_num_query_heads"]
  base_num_kv_heads = model_params["base_num_kv_heads"]
  head_dim = model_params["head_dim"]

  # load model
  max_logging.log(f"Loading the base model from {base_model_path}")
  ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
  chkpt_vars = {}
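  # Re-key every tensor with the intermediate names from _hf_to_maxtext_mapping
  # so that later lookups are uniform regardless of which shard a weight came from.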
  for i, ckpt_path in enumerate(ckpt_paths):
    max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")

    with safe_open(ckpt_path, framework="pt", device="cpu") as f:
      for key in f.keys():
        parts = key.split(".")
        layer = int(parts[2]) if "layers" in key else 0
        mapped_key = _hf_to_maxtext_mapping(layer)[key]
        chkpt_vars[mapped_key] = f.get_tensor(key)

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # initialize the data structure for storing jax_weights
  jax_weights = {
      "token_embedder": {"embedding": None},
      "decoder": {
          "decoder_norm": {"scale": None},
          "logits_dense": {"kernel": None},
      },
  }
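  # Pre-build the per-layer pytree skeleton; every leaf below is filled in by
  # the processing loops that follow.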
  for layer_idx in range(base_num_decoder_layers):
    jax_weights["decoder"][f"layers_{layer_idx}"] = {
        "pre_self_attention_layer_norm": {"scale": None},
        "post_self_attention_layer_norm": {"scale": None},
        "GptOssAttention": {
            "query": {"kernel": None, "bias": None},
            "key": {"kernel": None, "bias": None},
            "value": {"kernel": None, "bias": None},
            "out": {"kernel": None, "bias": None},
            "sinks": None,
        },
        "GptOssMlp": {
            "gate": {"kernel": None, "bias": None},
            "wi_0": None,
            "wi_0_bias": None,
            "wi_1": None,
            "wi_1_bias": None,
            "wo": None,
            "wo_bias": None,
        },
    }

  # decoder norm scale ###########################################
  max_logging.log("Processing decoder norm scale")
  jax_weights["decoder"]["decoder_norm"]["scale"] = _pt_to_np(chkpt_vars["norm.weight"], cast_dtype=CAST_DTYPE)
  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # logits dense #################################################
  max_logging.log("Processing logits dense")

  logit_dense = _pt_to_np(chkpt_vars["output.weight"], cast_dtype=CAST_DTYPE)
  jax_weights["decoder"]["logits_dense"]["kernel"] = logit_dense.transpose()  # [:, :vocab_size]

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # token embedding ##############################################
  max_logging.log("Processing token embeddings")

  jax_weights["token_embedder"]["embedding"] = _pt_to_np(chkpt_vars["tok_embeddings.weight"], cast_dtype=CAST_DTYPE)

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # self attention ###############################################
  max_logging.log("Processing self attention")
  for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
    self_attention = jax_weights["decoder"][f"layers_{layer_idx}"]["GptOssAttention"]

    wq = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wq.weight"], cast_dtype=CAST_DTYPE)
    wk = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wk.weight"], cast_dtype=CAST_DTYPE)
    wv = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wv.weight"], cast_dtype=CAST_DTYPE)
    w_post = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wo.weight"], cast_dtype=CAST_DTYPE)

    # NOTE: do not pre-scale the query weights in the checkpoint; instead apply
    # query_pre_attn_scalar=1/np.sqrt(head_dim) at attention time.
    # (num_attention_heads * head_dim, hidden_size) -> (hidden_size, num_attention_heads * head_dim) -> (hidden_size, num_attention_heads, head_dim)
    # [embed, q, head_dim]
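    # e.g. for gpt-oss-20b: (4096, 2880) -> (2880, 4096) -> (2880, 64, 64)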
    self_attention["query"]["kernel"] = wq.transpose().reshape([base_emb_dim, base_num_query_heads, head_dim])
    # [embed, kv, head_dim]
    self_attention["key"]["kernel"] = wk.transpose().reshape([base_emb_dim, base_num_kv_heads, head_dim])
    # [embed, kv, head_dim]
    self_attention["value"]["kernel"] = wv.transpose().reshape([base_emb_dim, base_num_kv_heads, head_dim])
    # (hidden_size, num_attention_heads * head_dim) -> (num_attention_heads * head_dim, hidden_size) -> (num_attention_heads, head_dim, hidden_size)
    # [q, head_dim, embed]
    self_attention["out"]["kernel"] = w_post.transpose().reshape([base_num_query_heads, head_dim, base_emb_dim])

    sinks = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.sinks"], cast_dtype=CAST_DTYPE)
    wq_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wq.bias"], cast_dtype=CAST_DTYPE)
    wk_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wk.bias"], cast_dtype=CAST_DTYPE)
    wv_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wv.bias"], cast_dtype=CAST_DTYPE)
    w_post_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wo.bias"], cast_dtype=CAST_DTYPE)

    self_attention["sinks"] = sinks
    self_attention["query"]["bias"] = wq_bias.reshape([base_num_query_heads, head_dim])
    self_attention["key"]["bias"] = wk_bias.reshape([base_num_kv_heads, head_dim])
    self_attention["value"]["bias"] = wv_bias.reshape([base_num_kv_heads, head_dim])
    self_attention["out"]["bias"] = w_post_bias

    logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # layer weight pre and post self attention norm ################
  max_logging.log("Processing pre and post self attention norms")
  for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
    layer_weight = jax_weights["decoder"][f"layers_{layer_idx}"]
    pre_self_attention_layernorm = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention_norm.weight"], cast_dtype=CAST_DTYPE)
    post_self_attention_layernorm = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.ffn_norm.weight"], cast_dtype=CAST_DTYPE)

    layer_weight["pre_self_attention_layer_norm"]["scale"] = pre_self_attention_layernorm  # pylint: disable=E1137
    layer_weight["post_self_attention_layer_norm"]["scale"] = post_self_attention_layernorm  # pylint: disable=E1137

    logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # layer weights ################################################
  max_logging.log("Processing layer weights")

  for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
    mlp_weight = jax_weights["decoder"][f"layers_{layer_idx}"]["GptOssMlp"]

    gate = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.gate.weight"], cast_dtype=CAST_DTYPE)
    gate_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.gate.bias"], cast_dtype=CAST_DTYPE)
    wi_0_1 = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.experts.gate_up_proj"], cast_dtype=CAST_DTYPE)
    wi_0_1_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.experts.gate_up_proj_bias"], cast_dtype=CAST_DTYPE)
    wo = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.experts.down_proj"], cast_dtype=CAST_DTYPE)
    wo_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.experts.down_proj_bias"], cast_dtype=CAST_DTYPE)

    # router
    mlp_weight["gate"]["kernel"] = gate.transpose()
    mlp_weight["gate"]["bias"] = gate_bias
    # experts.gate_up_proj: de-interleave the last dim (even indices are the gate projection, odd indices are up_proj)
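    # i.e. the fused last dim is laid out [g0, u0, g1, u1, ...]; slicing every
    # other element recovers the separate gate (wi_0) and up (wi_1) projections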
    wi_0 = wi_0_1[..., ::2]
    wi_1 = wi_0_1[..., 1::2]
    del wi_0_1
    wi_0_bias = wi_0_1_bias[..., ::2]
    wi_1_bias = wi_0_1_bias[..., 1::2]
    del wi_0_1_bias
    mlp_weight["wi_0"] = wi_0
    mlp_weight["wi_1"] = wi_1
    mlp_weight["wi_0_bias"] = wi_0_bias
    mlp_weight["wi_1_bias"] = wi_1_bias
    # experts.down_proj
    mlp_weight["wo"] = wo
    mlp_weight["wo_bias"] = wo_bias

    gc.collect()

  del chkpt_vars
  gc.collect()
  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
  return jax_weights


def convert_to_jax_weights(base_model_path: str, model_size: str):
  """
  Convert the checkpoint at base_model_path into a dictionary of JAX weights
  ready to be saved as a MaxText (Orbax) checkpoint.

  Args:
    base_model_path: checkpoint path
    model_size: gpt-oss-20b, gpt-oss-120b
  """
  model_params = MODEL_PARAMS_DICT[model_size]
  mem_info = psutil.Process()
  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
  max_logging.log(f"Loading the base model from {base_model_path}")
  return _convert_huggingface_to_jax_weights(base_model_path, model_size, model_params, mem_info)

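# Example (hypothetical paths; normally this module is run via the CLI shown in
# the module docstring):
#   from MaxText.convert_gpt_oss_unscanned_ckpt import convert_to_jax_weights
#   weights = convert_to_jax_weights("/tmp/gpt-oss-20b", "gpt-oss-20b")
#   print(weights["decoder"]["layers_0"]["GptOssAttention"]["query"]["kernel"].shape)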

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--base-model-path", type=str, required=True)
  parser.add_argument("--maxtext-model-path", type=str, required=True)
  parser.add_argument("--model-size", type=str, required=True)
  parser.add_argument("--simulated-cpu-devices-count", type=int, required=False, default=16)
  parser.add_argument("--use-ocdbt", type=str2bool, required=False, default=True)
  parser.add_argument("--use-zarr3", type=str2bool, required=False, default=True)
  args = parser.parse_args()

  if args.model_size not in MODEL_PARAMS_DICT:
    raise NotImplementedError(f"Model '{args.model_size}' is not supported.")

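  # Expose the host as multiple simulated CPU devices so the checkpoint can be
  # sharded and saved as if that many devices were present.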
  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}"
  base_weights_path = args.maxtext_model_path

  save_weights_to_checkpoint(
      args.maxtext_model_path,
      convert_to_jax_weights(args.base_model_path, args.model_size),
      args.simulated_cpu_devices_count,
      args.use_ocdbt,
      args.use_zarr3,
  )
  max_logging.log(f"Successfully saved base_weights to {base_weights_path}.")