diff --git a/python/sgl_jax/srt/utils/weight_utils.py b/python/sgl_jax/srt/utils/weight_utils.py index f5a4d5a1c..fa36ca413 100644 --- a/python/sgl_jax/srt/utils/weight_utils.py +++ b/python/sgl_jax/srt/utils/weight_utils.py @@ -78,7 +78,6 @@ def __init__( ) self.head_dim = (self.head_dim_original + 127) // 128 * 128 - self.head_dim_pad = self.head_dim - self.head_dim_original if hasattr(self.mesh, "shape") and "tensor" in self.mesh.shape: self.sharding_size = self.mesh.shape["tensor"] @@ -281,9 +280,6 @@ def _handle_single_weight( if mapping.reshape is not None: processed_weight = jnp.reshape(processed_weight, mapping.reshape) - if mapping.head_dim_padding and self.head_dim_pad > 0: - processed_weight = self._apply_head_dim_padding(processed_weight, hf_key, mapping) - if mapping.kv_head_padding: processed_weight = self._apply_kv_head_padding(processed_weight, hf_key) @@ -313,32 +309,16 @@ def _split_qkv_weight( ): jax_paths = mapping.target_path - if hf_key.endswith(".bias"): - q_dim = self.num_heads * self.head_dim_original - kv_dim = self.num_kv_heads * self.head_dim_original + q_dim = self.num_heads * self.head_dim_original + kv_dim = self.num_kv_heads * self.head_dim_original + if hf_key.endswith(".bias"): q_bias = weight[:q_dim] k_bias = weight[q_dim : q_dim + kv_dim] v_bias = weight[q_dim + kv_dim : q_dim + 2 * kv_dim] - if mapping.head_dim_padding and self.head_dim_pad > 0: - q_bias = jnp.reshape(q_bias, (self.num_heads, self.head_dim_original)) - q_bias = jnp.pad(q_bias, ((0, 0), (0, self.head_dim_pad))) - q_bias = jnp.reshape(q_bias, (self.num_heads * self.head_dim,)) - - k_bias = jnp.reshape(k_bias, (self.num_kv_heads, self.head_dim_original)) - k_bias = jnp.pad(k_bias, ((0, 0), (0, self.head_dim_pad))) - k_bias = jnp.reshape(k_bias, (self.num_kv_heads * self.head_dim,)) - - v_bias = jnp.reshape(v_bias, (self.num_kv_heads, self.head_dim_original)) - v_bias = jnp.pad(v_bias, ((0, 0), (0, self.head_dim_pad))) - v_bias = jnp.reshape(v_bias, (self.num_kv_heads * self.head_dim,)) - splits = [q_bias, k_bias, v_bias] else: - q_dim = self.num_heads * self.head_dim_original - kv_dim = self.num_kv_heads * self.head_dim_original - if mapping.transpose: q_weight = weight[:, :q_dim] k_weight = weight[:, q_dim : q_dim + kv_dim] @@ -348,62 +328,6 @@ def _split_qkv_weight( k_weight = weight[q_dim : q_dim + kv_dim, :] v_weight = weight[q_dim + kv_dim : q_dim + 2 * kv_dim, :] - if mapping.head_dim_padding and self.head_dim_pad > 0: - if mapping.transpose: - q_weight = jnp.reshape( - q_weight, - (self.hidden_size, self.num_heads, self.head_dim_original), - ) - q_weight = jnp.pad(q_weight, ((0, 0), (0, 0), (0, self.head_dim_pad))) - q_weight = jnp.reshape( - q_weight, (self.hidden_size, self.num_heads * self.head_dim) - ) - - k_weight = jnp.reshape( - k_weight, - (self.hidden_size, self.num_kv_heads, self.head_dim_original), - ) - k_weight = jnp.pad(k_weight, ((0, 0), (0, 0), (0, self.head_dim_pad))) - k_weight = jnp.reshape( - k_weight, (self.hidden_size, self.num_kv_heads * self.head_dim) - ) - - v_weight = jnp.reshape( - v_weight, - (self.hidden_size, self.num_kv_heads, self.head_dim_original), - ) - v_weight = jnp.pad(v_weight, ((0, 0), (0, 0), (0, self.head_dim_pad))) - v_weight = jnp.reshape( - v_weight, (self.hidden_size, self.num_kv_heads * self.head_dim) - ) - else: - q_weight = jnp.reshape( - q_weight, - (self.num_heads, self.head_dim_original, self.hidden_size), - ) - q_weight = jnp.pad(q_weight, ((0, 0), (0, self.head_dim_pad), (0, 0))) - q_weight = jnp.reshape( - q_weight, (self.num_heads * self.head_dim, self.hidden_size) - ) - - k_weight = jnp.reshape( - k_weight, - (self.num_kv_heads, self.head_dim_original, self.hidden_size), - ) - k_weight = jnp.pad(k_weight, ((0, 0), (0, self.head_dim_pad), (0, 0))) - k_weight = jnp.reshape( - k_weight, (self.num_kv_heads * self.head_dim, self.hidden_size) - ) - - v_weight = jnp.reshape( - v_weight, - (self.num_kv_heads, self.head_dim_original, self.hidden_size), - ) - v_weight = jnp.pad(v_weight, ((0, 0), (0, self.head_dim_pad), (0, 0))) - v_weight = jnp.reshape( - v_weight, (self.num_kv_heads * self.head_dim, self.hidden_size) - ) - splits = [q_weight, k_weight, v_weight] for split_weight, jax_path in zip(splits, jax_paths): @@ -460,67 +384,6 @@ def _get_param(self, params: nnx.State, path: str) -> nnx.State: return current_level - def _apply_head_dim_padding( - self, weight: jax.Array, hf_key: str, mapping: WeightMapping - ) -> jax.Array: - if hf_key.endswith(".bias"): - if any(proj in hf_key for proj in ["q_proj", "k_proj", "v_proj"]): - if "q_proj" in hf_key: - reshaped = jnp.reshape(weight, (self.num_heads, self.head_dim_original)) - padded = jnp.pad(reshaped, ((0, 0), (0, self.head_dim_pad))) - return jnp.reshape(padded, (self.num_heads * self.head_dim,)) - else: # k_proj or v_proj - reshaped = jnp.reshape(weight, (self.num_kv_heads, self.head_dim_original)) - padded = jnp.pad(reshaped, ((0, 0), (0, self.head_dim_pad))) - return jnp.reshape(padded, (self.num_kv_heads * self.head_dim,)) - else: - if mapping.reshape is not None: - if "o_proj" in hf_key: - padded = jnp.pad(weight, ((0, 0), (0, 0), (0, self.head_dim_pad))) - else: - padded = jnp.pad(weight, ((0, 0), (0, self.head_dim_pad), (0, 0))) - return padded - else: - if mapping.transpose: - if "q_proj" in hf_key: - reshaped = jnp.reshape( - weight, - (self.hidden_size, self.num_heads, self.head_dim_original), - ) - padded = jnp.pad(reshaped, ((0, 0), (0, 0), (0, self.head_dim_pad))) - return jnp.reshape( - padded, (self.hidden_size, self.num_heads * self.head_dim) - ) - elif any(proj in hf_key for proj in ["k_proj", "v_proj"]): - reshaped = jnp.reshape( - weight, - ( - self.hidden_size, - self.num_kv_heads, - self.head_dim_original, - ), - ) - padded = jnp.pad(reshaped, ((0, 0), (0, 0), (0, self.head_dim_pad))) - return jnp.reshape( - padded, - (self.hidden_size, self.num_kv_heads * self.head_dim), - ) - elif "o_proj" in hf_key: - reshaped = jnp.reshape( - weight, - (self.num_heads * self.head_dim_original, self.hidden_size), - ) - padded_reshaped = jnp.reshape( - reshaped, - (self.num_heads, self.head_dim_original, self.hidden_size), - ) - padded = jnp.pad(padded_reshaped, ((0, 0), (0, self.head_dim_pad), (0, 0))) - return jnp.reshape( - padded, (self.num_heads * self.head_dim, self.hidden_size) - ) - - return weight - def _apply_kv_head_padding(self, weight: jax.Array, hf_key: str) -> jax.Array: """Apply KV head padding/replication when tp_size > total_kv_heads.""" if any(