Skip to content

Commit 47851e6

Browse files
author
pathfinder-fp
committed
fix nnx.data & copy=true
1 parent bce2119 commit 47851e6

File tree

5 files changed

+6
-6
lines changed

5 files changed

+6
-6
lines changed

python/sgl_jax/srt/layers/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def __init__(
163163
)
164164
)
165165
else:
166-
self.bias = None
166+
self.bias = nnx.data(None)
167167

168168
def tie_weights(self, embed_tokens: Embed):
169169
"""Tie the weights with word embeddings."""

python/sgl_jax/srt/layers/layernorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
if use_scale:
4949
self.scale = nnx.Param(scale_init(jax.random.PRNGKey(0), feature_shape, param_dtype))
5050
else:
51-
self.scale = None
51+
self.scale = nnx.data(None)
5252

5353
self.num_features = num_features
5454
self.epsilon = epsilon

python/sgl_jax/srt/layers/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
)
5656
)
5757
else:
58-
self.bias = None
58+
self.bias = nnx.data(None)
5959

6060
def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]:
6161
"""Forward pass of the linear layer."""

python/sgl_jax/srt/layers/moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
)
3434
)
3535
else:
36-
self.bias = None
36+
self.bias = nnx.data(None)
3737

3838
def __call__(self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array | None]:
3939
logits = hidden_states.astype(self.weight_dtype) @ self.kernel

python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def jitted_run_model(
171171
logits_metadata,
172172
):
173173
model_state = jax.tree_util.tree_unflatten(model_state_def, model_state_leaves)
174-
model = nnx.merge(model_def, model_state)
174+
model = nnx.merge(model_def, model_state, copy=True)
175175
return model(forward_batch, token_to_kv_pool, logits_metadata)
176176

177177
@partial(jax.jit, static_argnames=["sampler_state_def", "mesh", "use_sort_for_toppk_minp"])
@@ -184,7 +184,7 @@ def jitted_sampler(
184184
*args,
185185
):
186186
model_state = jax.tree_util.tree_unflatten(sampler_state_def, sampler_state_leaves)
187-
sampler = nnx.merge(sampler_def, model_state)
187+
sampler = nnx.merge(sampler_def, model_state, copy=True)
188188
return sampler(*args, mesh=mesh, use_sort_for_toppk_minp=use_sort_for_toppk_minp)
189189

190190
def run_model_wrapper(forward_batch, logits_metadata):

0 commit comments

Comments (0)