Skip to content

Commit 841dc02

Browse files
committed
Updated bridge guide
1 parent 676b1e7 commit 841dc02

File tree

2 files changed

+30
-30
lines changed

2 files changed

+30
-30
lines changed

docs_nnx/guides/bridge_guide.ipynb

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"from flax.nnx import bridge\n",
3737
"import jax\n",
3838
"from jax import numpy as jnp\n",
39-
"from jax.experimental import mesh_utils\n",
39+
"from jax.sharding import PartitionSpec as P, NamedSharding, AxisType\n",
4040
"from typing import *"
4141
]
4242
},
@@ -686,15 +686,14 @@
686686
"\n",
687687
"\n",
688688
"print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')\n",
689-
"mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),\n",
690-
" axis_names=('in', 'out'))\n",
689+
"mesh = jax.make_mesh((2, 4), ('in', 'out'), axis_types=(AxisType.Auto, AxisType.Auto))\n",
691690
"x = jax.random.normal(jax.random.key(42), (4, 32))\n",
692-
"with mesh:\n",
691+
"with jax.set_mesh(mesh):\n",
693692
" model = create_sharded_nnx_module(x)\n",
694693
"\n",
695-
"print(type(model.w)) # `nnx.Param`\n",
696-
"print(model.w.sharding) # The partition annotation attached with `w`\n",
697-
"print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh"
694+
"print(type(model.w)) # `nnx.Param`\n",
695+
"print(model.w.sharding) # The partition annotation attached with `w`\n",
696+
"print(model.w.get_value().sharding) # The underlying JAX array is sharded across the 2x4 mesh"
698697
]
699698
},
700699
{
@@ -703,9 +702,9 @@
703702
"metadata": {},
704703
"source": [
705704
" We have 8 fake JAX devices now to partition this model...\n",
706-
" <class 'flax.nnx.variables.Param'>\n",
707-
" ('in', 'out')\n",
708-
" GSPMDSharding({devices=[2,4]<=[8]})"
705+
" <class 'flax.nnx.variablelib.Param'>\n",
706+
" NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)\n",
707+
" NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)"
709708
]
710709
},
711710
{
@@ -737,8 +736,9 @@
737736
"source": [
738737
"class NNXDotWithPartitioning(nnx.Module):\n",
739738
" def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):\n",
740-
" init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))\n",
741-
" self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))\n",
739+
" init_fn = nnx.initializers.lecun_normal()\n",
740+
" self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)),\n",
741+
" sharding_names=('in', 'out'))\n",
742742
" def __call__(self, x: jax.Array):\n",
743743
" return x @ self.w\n",
744744
"\n",
@@ -751,7 +751,7 @@
751751
" # A `NNXMeta` wrapper of the underlying `nnx.Param`\n",
752752
" assert type(variables['params']['w']) == bridge.NNXMeta\n",
753753
" # The annotation coming from the `nnx.Param` => (in, out)\n",
754-
" assert variables['params']['w'].metadata['sharding'] == ('in', 'out')\n",
754+
" assert variables['params']['w'].metadata['sharding_names'] == ('in', 'out')\n",
755755
"\n",
756756
" unboxed_variables = nn.unbox(variables)\n",
757757
" variable_pspecs = nn.get_partition_spec(variables)\n",
@@ -763,7 +763,7 @@
763763
" nn.get_partition_spec(variables))\n",
764764
" return sharded_vars\n",
765765
"\n",
766-
"with mesh:\n",
766+
"with jax.set_mesh(mesh):\n",
767767
" variables = create_sharded_variables(jax.random.key(0), x)\n",
768768
"\n",
769769
"# The underlying JAX array is sharded across the 2x4 mesh\n",
@@ -774,7 +774,7 @@
774774
"cell_type": "markdown",
775775
"metadata": {},
776776
"source": [
777-
" GSPMDSharding({devices=[2,4]<=[8]})"
777+
" NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)"
778778
]
779779
},
780780
{

docs_nnx/guides/bridge_guide.md

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ from flax import linen as nn
3636
from flax.nnx import bridge
3737
import jax
3838
from jax import numpy as jnp
39-
from jax.experimental import mesh_utils
39+
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
4040
from typing import *
4141
```
4242

@@ -367,21 +367,20 @@ def create_sharded_nnx_module(x):
367367
368368
369369
print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')
370-
mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),
371-
axis_names=('in', 'out'))
370+
mesh = jax.make_mesh((2, 4), ('in', 'out'), axis_types=(AxisType.Auto, AxisType.Auto))
372371
x = jax.random.normal(jax.random.key(42), (4, 32))
373-
with mesh:
372+
with jax.set_mesh(mesh):
374373
model = create_sharded_nnx_module(x)
375374
376-
print(type(model.w)) # `nnx.Param`
377-
print(model.w.sharding) # The partition annotation attached with `w`
378-
print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh
375+
print(type(model.w)) # `nnx.Param`
376+
print(model.w.sharding) # The partition annotation attached with `w`
377+
print(model.w.get_value().sharding) # The underlying JAX array is sharded across the 2x4 mesh
379378
```
380379

381380
We have 8 fake JAX devices now to partition this model...
382-
<class 'flax.nnx.variables.Param'>
383-
('in', 'out')
384-
GSPMDSharding({devices=[2,4]<=[8]})
381+
<class 'flax.nnx.variablelib.Param'>
382+
NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)
383+
NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)
385384

386385
+++
387386

@@ -396,8 +395,9 @@ Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the ra
396395
```{code-cell} ipython3
397396
class NNXDotWithPartitioning(nnx.Module):
398397
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
399-
init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))
400-
self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))
398+
init_fn = nnx.initializers.lecun_normal()
399+
self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)),
400+
sharding_names=('in', 'out'))
401401
def __call__(self, x: jax.Array):
402402
return x @ self.w
403403
@@ -410,7 +410,7 @@ def create_sharded_variables(key, x):
410410
# A `NNXMeta` wrapper of the underlying `nnx.Param`
411411
assert type(variables['params']['w']) == bridge.NNXMeta
412412
# The annotation coming from the `nnx.Param` => (in, out)
413-
assert variables['params']['w'].metadata['sharding'] == ('in', 'out')
413+
assert variables['params']['w'].metadata['sharding_names'] == ('in', 'out')
414414
415415
unboxed_variables = nn.unbox(variables)
416416
variable_pspecs = nn.get_partition_spec(variables)
@@ -422,14 +422,14 @@ def create_sharded_variables(key, x):
422422
nn.get_partition_spec(variables))
423423
return sharded_vars
424424
425-
with mesh:
425+
with jax.set_mesh(mesh):
426426
variables = create_sharded_variables(jax.random.key(0), x)
427427
428428
# The underlying JAX array is sharded across the 2x4 mesh
429429
print(variables['params']['w'].sharding)
430430
```
431431

432-
GSPMDSharding({devices=[2,4]<=[8]})
432+
NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)
433433

434434
+++
435435

0 commit comments

Comments
 (0)