143 changes: 3 additions & 140 deletions python/sgl_jax/srt/utils/weight_utils.py
@@ -78,7 +78,6 @@ def __init__(
         )

         self.head_dim = (self.head_dim_original + 127) // 128 * 128
-        self.head_dim_pad = self.head_dim - self.head_dim_original

         if hasattr(self.mesh, "shape") and "tensor" in self.mesh.shape:
             self.sharding_size = self.mesh.shape["tensor"]
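
Note: the surviving line above rounds head_dim up to the next multiple of 128, the usual TPU lane-width alignment. A standalone illustration of the integer trick (round_up is a hypothetical helper, not part of this file):

def round_up(x: int, multiple: int = 128) -> int:
    # (x + multiple - 1) // multiple is integer ceiling division;
    # multiplying back gives the smallest multiple of `multiple` >= x.
    return (x + multiple - 1) // multiple * multiple

assert round_up(64) == 128    # a 64-wide head pads up to 128
assert round_up(128) == 128   # an already-aligned head is unchanged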
@@ -281,9 +280,6 @@ def _handle_single_weight(
         if mapping.reshape is not None:
             processed_weight = jnp.reshape(processed_weight, mapping.reshape)

-        if mapping.head_dim_padding and self.head_dim_pad > 0:
-            processed_weight = self._apply_head_dim_padding(processed_weight, hf_key, mapping)
-
         if mapping.kv_head_padding:
             processed_weight = self._apply_kv_head_padding(processed_weight, hf_key)

@@ -313,32 +309,16 @@ def _split_qkv_weight(
     ):
         jax_paths = mapping.target_path

-        if hf_key.endswith(".bias"):
-            q_dim = self.num_heads * self.head_dim_original
-            kv_dim = self.num_kv_heads * self.head_dim_original
+        q_dim = self.num_heads * self.head_dim_original
+        kv_dim = self.num_kv_heads * self.head_dim_original

+        if hf_key.endswith(".bias"):
             q_bias = weight[:q_dim]
             k_bias = weight[q_dim : q_dim + kv_dim]
             v_bias = weight[q_dim + kv_dim : q_dim + 2 * kv_dim]

-            if mapping.head_dim_padding and self.head_dim_pad > 0:
-                q_bias = jnp.reshape(q_bias, (self.num_heads, self.head_dim_original))
-                q_bias = jnp.pad(q_bias, ((0, 0), (0, self.head_dim_pad)))
-                q_bias = jnp.reshape(q_bias, (self.num_heads * self.head_dim,))
-
-                k_bias = jnp.reshape(k_bias, (self.num_kv_heads, self.head_dim_original))
-                k_bias = jnp.pad(k_bias, ((0, 0), (0, self.head_dim_pad)))
-                k_bias = jnp.reshape(k_bias, (self.num_kv_heads * self.head_dim,))
-
-                v_bias = jnp.reshape(v_bias, (self.num_kv_heads, self.head_dim_original))
-                v_bias = jnp.pad(v_bias, ((0, 0), (0, self.head_dim_pad)))
-                v_bias = jnp.reshape(v_bias, (self.num_kv_heads * self.head_dim,))
-
             splits = [q_bias, k_bias, v_bias]
         else:
-            q_dim = self.num_heads * self.head_dim_original
-            kv_dim = self.num_kv_heads * self.head_dim_original
-
             if mapping.transpose:
                 q_weight = weight[:, :q_dim]
                 k_weight = weight[:, q_dim : q_dim + kv_dim]
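
Note: the slicing above splits a fused QKV matrix along its output axis. A minimal sketch of the same pattern under assumed shapes (split_fused_qkv and its arguments are illustrative, not part of the PR):

import jax.numpy as jnp

def split_fused_qkv(weight, num_heads, num_kv_heads, head_dim):
    # Fused layout stacks [Q | K | V] along the output (row) axis,
    # mirroring the non-transposed branch above.
    q_dim = num_heads * head_dim
    kv_dim = num_kv_heads * head_dim
    q = weight[:q_dim, :]
    k = weight[q_dim : q_dim + kv_dim, :]
    v = weight[q_dim + kv_dim : q_dim + 2 * kv_dim, :]
    return q, k, v

w = jnp.zeros((12 * 64, 256))  # 8 Q heads + 2 K heads + 2 V heads, head_dim 64
q, k, v = split_fused_qkv(w, 8, 2, 64)
assert q.shape == (512, 256) and k.shape == (128, 256) and v.shape == (128, 256)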
@@ -348,62 +328,6 @@
                 k_weight = weight[q_dim : q_dim + kv_dim, :]
                 v_weight = weight[q_dim + kv_dim : q_dim + 2 * kv_dim, :]

-            if mapping.head_dim_padding and self.head_dim_pad > 0:
-                if mapping.transpose:
-                    q_weight = jnp.reshape(
-                        q_weight,
-                        (self.hidden_size, self.num_heads, self.head_dim_original),
-                    )
-                    q_weight = jnp.pad(q_weight, ((0, 0), (0, 0), (0, self.head_dim_pad)))
-                    q_weight = jnp.reshape(
-                        q_weight, (self.hidden_size, self.num_heads * self.head_dim)
-                    )
-
-                    k_weight = jnp.reshape(
-                        k_weight,
-                        (self.hidden_size, self.num_kv_heads, self.head_dim_original),
-                    )
-                    k_weight = jnp.pad(k_weight, ((0, 0), (0, 0), (0, self.head_dim_pad)))
-                    k_weight = jnp.reshape(
-                        k_weight, (self.hidden_size, self.num_kv_heads * self.head_dim)
-                    )
-
-                    v_weight = jnp.reshape(
-                        v_weight,
-                        (self.hidden_size, self.num_kv_heads, self.head_dim_original),
-                    )
-                    v_weight = jnp.pad(v_weight, ((0, 0), (0, 0), (0, self.head_dim_pad)))
-                    v_weight = jnp.reshape(
-                        v_weight, (self.hidden_size, self.num_kv_heads * self.head_dim)
-                    )
-                else:
-                    q_weight = jnp.reshape(
-                        q_weight,
-                        (self.num_heads, self.head_dim_original, self.hidden_size),
-                    )
-                    q_weight = jnp.pad(q_weight, ((0, 0), (0, self.head_dim_pad), (0, 0)))
-                    q_weight = jnp.reshape(
-                        q_weight, (self.num_heads * self.head_dim, self.hidden_size)
-                    )
-
-                    k_weight = jnp.reshape(
-                        k_weight,
-                        (self.num_kv_heads, self.head_dim_original, self.hidden_size),
-                    )
-                    k_weight = jnp.pad(k_weight, ((0, 0), (0, self.head_dim_pad), (0, 0)))
-                    k_weight = jnp.reshape(
-                        k_weight, (self.num_kv_heads * self.head_dim, self.hidden_size)
-                    )
-
-                    v_weight = jnp.reshape(
-                        v_weight,
-                        (self.num_kv_heads, self.head_dim_original, self.hidden_size),
-                    )
-                    v_weight = jnp.pad(v_weight, ((0, 0), (0, self.head_dim_pad), (0, 0)))
-                    v_weight = jnp.reshape(
-                        v_weight, (self.num_kv_heads * self.head_dim, self.hidden_size)
-                    )
-
             splits = [q_weight, k_weight, v_weight]

         for split_weight, jax_path in zip(splits, jax_paths):
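
Note: every deleted branch above performs the same transform: expose the per-head axis, zero-pad head_dim, and flatten back. A condensed sketch of that pattern (pad_head_dim is illustrative; assumes the transposed [hidden, heads * head_dim] layout):

import jax.numpy as jnp

def pad_head_dim(w, hidden_size, num_heads, head_dim_orig, head_dim):
    # Unflatten to (hidden, heads, head_dim_orig), zero-pad the last axis
    # up to the padded head_dim, then flatten the head axes again.
    w = jnp.reshape(w, (hidden_size, num_heads, head_dim_orig))
    w = jnp.pad(w, ((0, 0), (0, 0), (0, head_dim - head_dim_orig)))
    return jnp.reshape(w, (hidden_size, num_heads * head_dim))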
@@ -460,67 +384,6 @@ def _get_param(self, params: nnx.State, path: str) -> nnx.State:

         return current_level

-    def _apply_head_dim_padding(
-        self, weight: jax.Array, hf_key: str, mapping: WeightMapping
-    ) -> jax.Array:
-        if hf_key.endswith(".bias"):
-            if any(proj in hf_key for proj in ["q_proj", "k_proj", "v_proj"]):
-                if "q_proj" in hf_key:
-                    reshaped = jnp.reshape(weight, (self.num_heads, self.head_dim_original))
-                    padded = jnp.pad(reshaped, ((0, 0), (0, self.head_dim_pad)))
-                    return jnp.reshape(padded, (self.num_heads * self.head_dim,))
-                else:  # k_proj or v_proj
-                    reshaped = jnp.reshape(weight, (self.num_kv_heads, self.head_dim_original))
-                    padded = jnp.pad(reshaped, ((0, 0), (0, self.head_dim_pad)))
-                    return jnp.reshape(padded, (self.num_kv_heads * self.head_dim,))
-        else:
-            if mapping.reshape is not None:
-                if "o_proj" in hf_key:
-                    padded = jnp.pad(weight, ((0, 0), (0, 0), (0, self.head_dim_pad)))
-                else:
-                    padded = jnp.pad(weight, ((0, 0), (0, self.head_dim_pad), (0, 0)))
-                return padded
-            else:
-                if mapping.transpose:
-                    if "q_proj" in hf_key:
-                        reshaped = jnp.reshape(
-                            weight,
-                            (self.hidden_size, self.num_heads, self.head_dim_original),
-                        )
-                        padded = jnp.pad(reshaped, ((0, 0), (0, 0), (0, self.head_dim_pad)))
-                        return jnp.reshape(
-                            padded, (self.hidden_size, self.num_heads * self.head_dim)
-                        )
-                    elif any(proj in hf_key for proj in ["k_proj", "v_proj"]):
-                        reshaped = jnp.reshape(
-                            weight,
-                            (
-                                self.hidden_size,
-                                self.num_kv_heads,
-                                self.head_dim_original,
-                            ),
-                        )
-                        padded = jnp.pad(reshaped, ((0, 0), (0, 0), (0, self.head_dim_pad)))
-                        return jnp.reshape(
-                            padded,
-                            (self.hidden_size, self.num_kv_heads * self.head_dim),
-                        )
-                    elif "o_proj" in hf_key:
-                        reshaped = jnp.reshape(
-                            weight,
-                            (self.num_heads * self.head_dim_original, self.hidden_size),
-                        )
-                        padded_reshaped = jnp.reshape(
-                            reshaped,
-                            (self.num_heads, self.head_dim_original, self.hidden_size),
-                        )
-                        padded = jnp.pad(padded_reshaped, ((0, 0), (0, self.head_dim_pad), (0, 0)))
-                        return jnp.reshape(
-                            padded, (self.num_heads * self.head_dim, self.hidden_size)
-                        )
-
-        return weight
-
     def _apply_kv_head_padding(self, weight: jax.Array, hf_key: str) -> jax.Array:
         """Apply KV head padding/replication when tp_size > total_kv_heads."""
         if any(
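
Note: _apply_kv_head_padding survives the PR (its body is collapsed in this view); per its docstring it replicates KV heads when tp_size exceeds the KV-head count. One plausible shape of that operation, with every name below assumed rather than taken from the file:

import jax.numpy as jnp

def replicate_kv_heads(w, num_kv_heads, head_dim, tp_size):
    # With fewer KV heads than tensor-parallel shards, every shard still
    # needs a head, so each head is repeated until the count covers tp_size.
    reps = -(-tp_size // num_kv_heads)  # ceiling division
    w = jnp.reshape(w, (num_kv_heads, head_dim, -1))
    w = jnp.repeat(w, reps, axis=0)     # (num_kv_heads * reps, head_dim, ...)
    return jnp.reshape(w, (num_kv_heads * reps * head_dim, -1))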