
Commit ac708cd

Merge pull request #2294 from AI-Hypercomputer:shuningjin-oss3
PiperOrigin-RevId: 805870003
2 parents 3b27bed + 597d102 commit ac708cd

15 files changed (+974, -23 lines)

src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
@@ -624,6 +624,9 @@ rope_factor: 40
 beta_fast: 32
 beta_slow: 1
 mscale: 1.0
+rope_interleave: True # RoPE with sin/cos interleaved vs concatenated
+rope_truncate: True # Floor lower bound and ceil upper bound for correction range
+rope_attention_scaling: False # Scale the rotary embedding output
 
 # Ahead of time Compilation (aka AOT)
 # Only set these arguments if you are running train_compile or loading a compiled train step.
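
The new rope_interleave flag chooses between the two common RoPE layouts: rotating adjacent channel pairs, or pairing each channel with its counterpart in the other half of the head dimension (the concatenated form). A minimal NumPy sketch of the distinction, not MaxText's implementation and with hypothetical names:

import numpy as np

def apply_rope(x, positions, theta=10000.0, interleave=True):
  """Rotary embedding for x of shape [seq, head_dim]; head_dim must be even."""
  d = x.shape[-1]
  inv_freq = 1.0 / theta ** (np.arange(0, d, 2) / d)   # [d/2]
  angles = positions[:, None] * inv_freq[None, :]      # [seq, d/2]
  sin, cos = np.sin(angles), np.cos(angles)
  if interleave:
    # Pair adjacent channels: (x0, x1), (x2, x3), ...  (rope_interleave: True)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
  else:
    # Pair channel i with channel i + d/2, i.e. concatenated halves (rope_interleave: False)
    x_lo, x_hi = x[..., : d // 2], x[..., d // 2 :]
    out = np.concatenate([x_lo * cos - x_hi * sin, x_lo * sin + x_hi * cos], axis=-1)
  return out

q = np.random.randn(8, 64).astype(np.float32)
print(apply_rope(q, np.arange(8), interleave=True).shape)  # (8, 64)

The DeepSeek configs below keep the interleaved form, while the gpt-oss configs switch to the concatenated form.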

src/MaxText/configs/models/deepseek2-16b.yml

Lines changed: 3 additions & 0 deletions
@@ -48,3 +48,6 @@ original_max_position_embeddings: 4096
 rope_factor: 40
 beta_fast: 32
 mscale: 0.707
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False

src/MaxText/configs/models/deepseek2-236b.yml

Lines changed: 3 additions & 0 deletions
@@ -49,3 +49,6 @@ original_max_position_embeddings: 4096
 rope_factor: 40
 beta_fast: 32
 mscale: 0.707
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False

src/MaxText/configs/models/deepseek3-671b.yml

Lines changed: 3 additions & 0 deletions
@@ -51,3 +51,6 @@ max_position_embeddings: 163840
 original_max_position_embeddings: 4096
 rope_factor: 40
 beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False
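
rope_truncate concerns the YaRN-style correction range derived from beta_fast and beta_slow; as noted in the base.yml comment, truncation floors the lower bound and ceils the upper bound. A hedged sketch of the usual derivation, with illustrative helper names that are not taken from MaxText:

import math

def correction_dim(num_rotations, dim, base, orig_max_pos):
  # Channel index whose wavelength completes `num_rotations` turns over orig_max_pos tokens.
  return dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

def correction_range(beta_fast, beta_slow, dim, base, orig_max_pos, truncate=True):
  low = correction_dim(beta_fast, dim, base, orig_max_pos)
  high = correction_dim(beta_slow, dim, base, orig_max_pos)
  if truncate:  # rope_truncate: True floors/ceils the bounds, as in the DeepSeek configs
    low, high = math.floor(low), math.ceil(high)
  return max(low, 0), min(high, dim - 1)

print(correction_range(32, 1, 64, 10000.0, 4096))  # e.g. (10, 23) with truncation

The gpt-oss configs below disable this rounding and keep the fractional bounds.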

src/MaxText/configs/models/gpt-oss-120b.yml

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@ original_max_position_embeddings: 4096
 rope_factor: 32
 beta_fast: 32
 beta_slow: 1
+rope_interleave: False
+rope_truncate: False
+rope_attention_scaling: True
 
 # MLP
 base_mlp_dim: 2880

src/MaxText/configs/models/gpt-oss-20b.yml

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@ original_max_position_embeddings: 4096
 rope_factor: 32
 beta_fast: 32
 beta_slow: 1
+rope_interleave: False
+rope_truncate: False
+rope_attention_scaling: True
 
 # MLP
 base_mlp_dim: 2880
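
The gpt-oss configs also enable rope_attention_scaling, described in base.yml as scaling the rotary embedding output. A common YaRN-style choice, shown here only as an assumption about what such a factor can look like, multiplies the cos/sin tables by a slowly growing function of rope_factor:

import math

def attention_scaling(rope_factor):
  # Common YaRN heuristic: 0.1 * ln(scaling factor) + 1.0 (an assumption; the exact form may differ).
  return 0.1 * math.log(rope_factor) + 1.0

scale = attention_scaling(32.0)  # rope_factor: 32 in the gpt-oss configs above
# cos and sin would then be multiplied by `scale` before rotating queries and keys.
print(round(scale, 3))  # ~1.347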

src/MaxText/convert_gpt_oss_ckpt.py

Lines changed: 346 additions & 0 deletions
Large diffs are not rendered by default.
src/MaxText/convert_gpt_oss_unscanned_ckpt.py

Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert weights from a GPT-OSS style model to a MaxText one.

Example cmd:

python3 -m MaxText.convert_gpt_oss_unscanned_ckpt --base-model-path <path/to/hf/ckpt> \
    --maxtext-model-path <GCS/path/to/save/new/maxtext/ckpt> --model-size gpt-oss-20b
"""

# pylint: disable=g-line-too-long
import argparse
import gc
import logging
import os
import pathlib
import re
from dataclasses import dataclass

os.environ["JAX_PLATFORMS"] = "cpu"

import ml_dtypes
import numpy as np
import psutil
from safetensors import safe_open
import torch
from tqdm import tqdm

from MaxText import max_logging
from MaxText.inference_utils import str2bool
from MaxText.llama_or_mistral_ckpt import save_weights_to_checkpoint


# NOTE: numpy doesn't have native support for bfloat16, so
# we'll use ml_dtypes instead (which is quasi native)
# NOTE: it's incredibly silly but you can't directly cast from
# a torch tensor of type bfloat16 to a numpy array of type bfloat16
# so we have to cast to float32 first
CAST_DTYPE = ml_dtypes.bfloat16

def _pt_to_np(pt_weight, cast_dtype=None, transpose=False):
  if cast_dtype:
    np_weight = pt_weight.to(torch.float32).numpy().astype(cast_dtype)
  else:
    np_weight = pt_weight.to(torch.float32).numpy()
  if transpose:
    np_weight = np_weight.transpose()
  return np_weight


MODEL_PARAMS_DICT = {
    "gpt-oss-20b": {
        "base_emb_dim": 2880,
        "base_num_query_heads": 64,
        "base_num_kv_heads": 8,
        "head_dim": 64,
        "base_num_decoder_layers": 24,
        "inhomogeneous_layer_cycle_interval": 2,
    },
    "gpt-oss-120b": {
        "base_emb_dim": 2880,
        "base_num_query_heads": 64,
        "base_num_kv_heads": 8,
        "head_dim": 64,
        "base_num_decoder_layers": 36,
        "inhomogeneous_layer_cycle_interval": 2,
    },
}

def _hf_to_maxtext_mapping(layer_idx: int = -1) -> dict:
  """
  Returns a mapping from HuggingFace model weight names to MaxText model weight names.

  Args:
    layer_idx (int): Layer index.

  Returns:
    dict [str, str]: Mapping from HuggingFace model weight names to MaxText model weight names.
  """
  # pylint: disable=line-too-long
  return {
      "model.embed_tokens.weight": "tok_embeddings.weight",
      "model.norm.weight": "norm.weight",
      "lm_head.weight": "output.weight",
      # layernorm
      f"model.layers.{layer_idx}.input_layernorm.weight": f"layers.{layer_idx}.attention_norm.weight",
      f"model.layers.{layer_idx}.post_attention_layernorm.weight": f"layers.{layer_idx}.ffn_norm.weight",
      # attention
      f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"layers.{layer_idx}.attention.wq.weight",
      f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"layers.{layer_idx}.attention.wk.weight",
      f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"layers.{layer_idx}.attention.wv.weight",
      f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"layers.{layer_idx}.attention.wo.weight",
      f"model.layers.{layer_idx}.self_attn.q_proj.bias": f"layers.{layer_idx}.attention.wq.bias",
      f"model.layers.{layer_idx}.self_attn.k_proj.bias": f"layers.{layer_idx}.attention.wk.bias",
      f"model.layers.{layer_idx}.self_attn.v_proj.bias": f"layers.{layer_idx}.attention.wv.bias",
      f"model.layers.{layer_idx}.self_attn.o_proj.bias": f"layers.{layer_idx}.attention.wo.bias",
      f"model.layers.{layer_idx}.self_attn.sinks": f"layers.{layer_idx}.attention.sinks",
      # MoE
      f"model.layers.{layer_idx}.mlp.router.weight": f"layers.{layer_idx}.feed_forward.gate.weight",
      f"model.layers.{layer_idx}.mlp.router.bias": f"layers.{layer_idx}.feed_forward.gate.bias",
      f"model.layers.{layer_idx}.mlp.experts.gate_up_proj": f"layers.{layer_idx}.feed_forward.experts.gate_up_proj",
      f"model.layers.{layer_idx}.mlp.experts.gate_up_proj_bias": f"layers.{layer_idx}.feed_forward.experts.gate_up_proj_bias",
      f"model.layers.{layer_idx}.mlp.experts.down_proj": f"layers.{layer_idx}.feed_forward.experts.down_proj",
      f"model.layers.{layer_idx}.mlp.experts.down_proj_bias": f"layers.{layer_idx}.feed_forward.experts.down_proj_bias",
  }

def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, model_params: dict, mem_info: psutil.Process):
  """Convert a Huggingface Checkpoint to a dictionary of Numpy arrays representing the weights.

  Args:
    base_model_path (str): Path to the base model checkpoint.
    model_size (str): Size of the base model.
    model_params (dict): Dictionary containing model parameters.
    mem_info (psutil.Process): Process object to track memory usage.

  Returns:
    jax_weights (dict): Dictionary containing the converted weights.
  """
  # model params
  base_num_decoder_layers = model_params["base_num_decoder_layers"]
  base_emb_dim = model_params["base_emb_dim"]
  base_num_query_heads = model_params["base_num_query_heads"]
  base_num_kv_heads = model_params["base_num_kv_heads"]
  head_dim = model_params["head_dim"]

  # load model
  max_logging.log(f"Loading the base model from {base_model_path}")
  ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
  chkpt_vars = {}
  for i, ckpt_path in enumerate(ckpt_paths):
    max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")

    with safe_open(ckpt_path, framework="pt", device="cpu") as f:
      for key in f.keys():
        parts = key.split(".")
        layer = int(parts[2]) if "layers" in key else 0
        mapped_key = _hf_to_maxtext_mapping(layer)[key]
        chkpt_vars[mapped_key] = f.get_tensor(key)

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # initialize the data structure for storing jax_weights
  jax_weights = {
      "token_embedder": {"embedding": None},
      "decoder": {
          "decoder_norm": {"scale": None},
          "logits_dense": {"kernel": None},
      },
  }
  for layer_idx in range(base_num_decoder_layers):
    jax_weights["decoder"][f"layers_{layer_idx}"] = {
        "pre_self_attention_layer_norm": {"scale": None},
        "post_self_attention_layer_norm": {"scale": None},
        "GptOssAttention": {
            "query": {"kernel": None, "bias": None},
            "key": {"kernel": None, "bias": None},
            "value": {"kernel": None, "bias": None},
            "out": {"kernel": None, "bias": None},
            "sinks": None,
        },
        "GptOssMlp": {
            "gate": {"kernel": None, "bias": None},
            "wi_0": None,
            "wi_0_bias": None,
            "wi_1": None,
            "wi_1_bias": None,
            "wo": None,
            "wo_bias": None,
        },
    }

  # decoder norm scale ###########################################
  max_logging.log("Processing decoder norm scale")
  jax_weights["decoder"]["decoder_norm"]["scale"] = _pt_to_np(chkpt_vars["norm.weight"], cast_dtype=CAST_DTYPE)
  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # logits dense #################################################
  max_logging.log("Processing logits dense")

  logit_dense = _pt_to_np(chkpt_vars["output.weight"], cast_dtype=CAST_DTYPE)
  jax_weights["decoder"]["logits_dense"]["kernel"] = logit_dense.transpose()  # [:, :vocab_size]

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # token embedding ##############################################
  max_logging.log("Processing token embeddings")

  jax_weights["token_embedder"]["embedding"] = _pt_to_np(chkpt_vars["tok_embeddings.weight"], cast_dtype=CAST_DTYPE)

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # self attention ###############################################
  max_logging.log("Processing self attention")
  for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
    self_attention = jax_weights["decoder"][f"layers_{layer_idx}"]["GptOssAttention"]

    wq = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wq.weight"], cast_dtype=CAST_DTYPE)
    wk = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wk.weight"], cast_dtype=CAST_DTYPE)
    wv = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wv.weight"], cast_dtype=CAST_DTYPE)
    w_post = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wo.weight"], cast_dtype=CAST_DTYPE)

    # NOTE: do not scale the query weights in the checkpoint, but apply query_pre_attn_scalar=1/np.sqrt(head_dim) for attention
    # (num_attention_heads * head_dim, hidden_size) -> (hidden_size, num_attention_heads * head_dim) -> (hidden_size, num_attention_heads, head_dim)
    # [embed, q, head_dim]
    self_attention["query"]["kernel"] = wq.transpose().reshape([base_emb_dim, base_num_query_heads, head_dim])
    # [embed, kv, head_dim]
    self_attention["key"]["kernel"] = wk.transpose().reshape([base_emb_dim, base_num_kv_heads, head_dim])
    # [embed, kv, head_dim]
    self_attention["value"]["kernel"] = wv.transpose().reshape([base_emb_dim, base_num_kv_heads, head_dim])
    # (hidden_size, num_attention_heads * head_dim) -> (num_attention_heads * head_dim, hidden_size) -> (num_attention_heads, head_dim, hidden_size)
    # [q, head_dim, embed]
    self_attention["out"]["kernel"] = w_post.transpose().reshape([base_num_query_heads, head_dim, base_emb_dim])

    sinks = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.sinks"], cast_dtype=CAST_DTYPE)
    wq_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wq.bias"], cast_dtype=CAST_DTYPE)
    wk_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wk.bias"], cast_dtype=CAST_DTYPE)
    wv_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wv.bias"], cast_dtype=CAST_DTYPE)
    w_post_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention.wo.bias"], cast_dtype=CAST_DTYPE)

    self_attention["sinks"] = sinks
    self_attention["query"]["bias"] = wq_bias.reshape([base_num_query_heads, head_dim])
    self_attention["key"]["bias"] = wk_bias.reshape([base_num_kv_heads, head_dim])
    self_attention["value"]["bias"] = wv_bias.reshape([base_num_kv_heads, head_dim])
    self_attention["out"]["bias"] = w_post_bias

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # layer weight pre and post self attention norm ################
  max_logging.log("Processing pre and post self attention norms")
  for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
    layer_weight = jax_weights["decoder"][f"layers_{layer_idx}"]
    pre_self_attention_layernorm = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.attention_norm.weight"], cast_dtype=CAST_DTYPE)
    post_self_attention_layernorm = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.ffn_norm.weight"], cast_dtype=CAST_DTYPE)

    layer_weight["pre_self_attention_layer_norm"]["scale"] = pre_self_attention_layernorm  # pylint: disable=E1137
    layer_weight["post_self_attention_layer_norm"]["scale"] = post_self_attention_layernorm  # pylint: disable=E1137

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  # layer weights ################################################
  max_logging.log("Processing layer weights")

  for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
    mlp_weight = jax_weights["decoder"][f"layers_{layer_idx}"]["GptOssMlp"]

    gate = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.gate.weight"], cast_dtype=CAST_DTYPE)
    gate_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.gate.bias"], cast_dtype=CAST_DTYPE)
    wi_0_1 = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.experts.gate_up_proj"], cast_dtype=CAST_DTYPE)
    wi_0_1_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.experts.gate_up_proj_bias"], cast_dtype=CAST_DTYPE)
    wo = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.experts.down_proj"], cast_dtype=CAST_DTYPE)
    wo_bias = _pt_to_np(chkpt_vars[f"layers.{layer_idx}.feed_forward.experts.down_proj_bias"], cast_dtype=CAST_DTYPE)

    # router
    mlp_weight["gate"]["kernel"] = gate.transpose()
    mlp_weight["gate"]["bias"] = gate_bias
    # experts.gate_up_proj: de-interleave last dim, even for gate, odd for up_proj
    wi_0 = wi_0_1[..., ::2]
    wi_1 = wi_0_1[..., 1::2]
    del wi_0_1
    wi_0_bias = wi_0_1_bias[..., ::2]
    wi_1_bias = wi_0_1_bias[..., 1::2]
    del wi_0_1_bias
    mlp_weight["wi_0"] = wi_0
    mlp_weight["wi_1"] = wi_1
    mlp_weight["wi_0_bias"] = wi_0_bias
    mlp_weight["wi_1_bias"] = wi_1_bias
    # experts.down_proj
    mlp_weight["wo"] = wo
    mlp_weight["wo_bias"] = wo_bias

    gc.collect()

  del chkpt_vars
  gc.collect()
  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
  return jax_weights

def convert_to_jax_weights(base_model_path: str, model_size: str):
  """
  Function to convert the checkpoint at base_model_path into Orbax checkpoint
  for MaxText and output jax_weights ready for MaxText

  Attributes:
    base_model_path: checkpoint path
    model_size: gpt-oss-20b, gpt-oss-120b
  """
  model_params = MODEL_PARAMS_DICT[model_size]
  mem_info = psutil.Process()
  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
  max_logging.log(f"Loading the base model from {base_model_path}")
  return _convert_huggingface_to_jax_weights(base_model_path, model_size, model_params, mem_info)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--base-model-path", type=str, required=True)
  parser.add_argument("--maxtext-model-path", type=str, required=True)
  parser.add_argument("--model-size", type=str, required=True)
  parser.add_argument("--simulated-cpu-devices-count", type=int, required=False, default=16)
  parser.add_argument("--use-ocdbt", type=str2bool, required=False, default=True)
  parser.add_argument("--use-zarr3", type=str2bool, required=False, default=True)
  args = parser.parse_args()

  if args.model_size not in MODEL_PARAMS_DICT:
    raise NotImplementedError(f"Model '{args.model_size}' is not supported.")

  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}"
  base_weights_path = args.maxtext_model_path

  save_weights_to_checkpoint(
      args.maxtext_model_path,
      convert_to_jax_weights(args.base_model_path, args.model_size),
      args.simulated_cpu_devices_count,
      args.use_ocdbt,
      args.use_zarr3,
  )
  max_logging.log(f"Successfully saved base_weights to {base_weights_path}.")
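
As a quick aside on the MoE conversion above: the gate_up_proj de-interleave splits the fused HuggingFace tensor into separate gate and up projections by taking even and odd channels of the last axis. A toy check with made-up shapes:

import numpy as np

# Fake gate_up_proj with shape [num_experts, emb_dim, 2 * mlp_dim]; even channels on the
# last axis become the gate projection, odd channels the up projection.
gate_up = np.arange(2 * 3 * 8, dtype=np.float32).reshape(2, 3, 8)
wi_0 = gate_up[..., ::2]   # gate -> [2, 3, 4]
wi_1 = gate_up[..., 1::2]  # up   -> [2, 3, 4]
assert wi_0.shape == wi_1.shape == (2, 3, 4)
assert wi_0[0, 0, 0] == 0.0 and wi_1[0, 0, 0] == 1.0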

src/MaxText/layers/attentions.py

Lines changed: 3 additions & 0 deletions
@@ -703,6 +703,9 @@ def init_rotary_embedding(self):
           rope_factor=self.config.rope_factor,
           embedding_dims=rope_embedding_dims,
           fprop_dtype=self.dtype,
+          interleave=self.config.rope_interleave,
+          truncate=self.config.rope_truncate,
+          attention_scaling=self.config.rope_attention_scaling,
           rngs=self.rngs,
       )
     else:
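
Since init_rotary_embedding reads the three flags from self.config, they can presumably be overridden per run like any other MaxText config key. A hypothetical invocation, assuming the usual key=value override syntax and eliding the remaining required arguments:

python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gpt-oss-20b rope_truncate=False rope_attention_scaling=True ...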
