diff --git a/docs_nnx/guides/bridge_guide.ipynb b/docs_nnx/guides/bridge_guide.ipynb index 98182d3c9..49bda3dbf 100644 --- a/docs_nnx/guides/bridge_guide.ipynb +++ b/docs_nnx/guides/bridge_guide.ipynb @@ -36,7 +36,7 @@ "from flax.nnx import bridge\n", "import jax\n", "from jax import numpy as jnp\n", - "from jax.experimental import mesh_utils\n", + "from jax.sharding import PartitionSpec as P, NamedSharding, AxisType\n", "from typing import *" ] }, @@ -638,7 +638,7 @@ "\n", "In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.\n", "\n", - "The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX).\n", + "The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen).\n", "\n", "### Linen to NNX\n", "\n", @@ -686,15 +686,14 @@ "\n", "\n", "print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')\n", - "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),\n", - " axis_names=('in', 'out'))\n", + "mesh = jax.make_mesh((2, 4), ('in', 'out'), axis_types=(AxisType.Auto, AxisType.Auto))\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "with mesh:\n", + "with jax.set_mesh(mesh):\n", " model = create_sharded_nnx_module(x)\n", "\n", - "print(type(model.w)) # `nnx.Param`\n", - "print(model.w.sharding) # The partition annotation attached with `w`\n", - "print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh" + "print(type(model.w)) # `nnx.Param`\n", + "print(model.w.sharding) # The partition annotation attached with `w`\n", + "print(model.w.get_value().sharding) # The underlying JAX array is sharded across the 2x4 mesh" ] }, { @@ -703,9 +702,9 @@ "metadata": {}, "source": [ " We have 8 fake JAX devices now to partition this model...\n", - " \n", - " ('in', 'out')\n", - " GSPMDSharding({devices=[2,4]<=[8]})" + " \n", + " NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)\n", + " NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)" ] }, { @@ -737,8 +736,9 @@ "source": [ "class NNXDotWithParititioning(nnx.Module):\n", " def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):\n", - " init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))\n", - " self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))\n", + " init_fn = nnx.initializers.lecun_normal()\n", + " self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)),\n", + " sharding_names=('in', 'out'))\n", " def __call__(self, x: jax.Array):\n", " return x @ self.w\n", "\n", @@ -751,7 +751,7 @@ " # A `NNXMeta` wrapper of the underlying `nnx.Param`\n", " assert type(variables['params']['w']) == bridge.NNXMeta\n", " # The annotation coming from the `nnx.Param` => (in, out)\n", - " assert variables['params']['w'].metadata['sharding'] == ('in', 'out')\n", + " assert variables['params']['w'].metadata['sharding_names'] == 
('in', 'out')\n", "\n", " unboxed_variables = nn.unbox(variables)\n", " variable_pspecs = nn.get_partition_spec(variables)\n", @@ -763,7 +763,7 @@ " nn.get_partition_spec(variables))\n", " return sharded_vars\n", "\n", - "with mesh:\n", + "with jax.set_mesh(mesh):\n", " variables = create_sharded_variables(jax.random.key(0), x)\n", "\n", "# The underlying JAX array is sharded across the 2x4 mesh\n", @@ -774,7 +774,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - " GSPMDSharding({devices=[2,4]<=[8]})" + " NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)" ] }, { diff --git a/docs_nnx/guides/bridge_guide.md b/docs_nnx/guides/bridge_guide.md index e0f347504..e1fc74b57 100644 --- a/docs_nnx/guides/bridge_guide.md +++ b/docs_nnx/guides/bridge_guide.md @@ -36,7 +36,7 @@ from flax import linen as nn from flax.nnx import bridge import jax from jax import numpy as jnp -from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec as P, NamedSharding, AxisType from typing import * ``` @@ -336,7 +336,7 @@ Flax uses a metadata wrapper box over the raw JAX array to annotate how a variab In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too. -The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX). +The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen). ### Linen to NNX @@ -367,21 +367,20 @@ def create_sharded_nnx_module(x): print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...') -mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), - axis_names=('in', 'out')) +mesh = jax.make_mesh((2, 4), ('in', 'out'), axis_types=(AxisType.Auto, AxisType.Auto)) x = jax.random.normal(jax.random.key(42), (4, 32)) -with mesh: +with jax.set_mesh(mesh): model = create_sharded_nnx_module(x) -print(type(model.w)) # `nnx.Param` -print(model.w.sharding) # The partition annotation attached with `w` -print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh +print(type(model.w)) # `nnx.Param` +print(model.w.sharding) # The partition annotation attached with `w` +print(model.w.get_value().sharding) # The underlying JAX array is sharded across the 2x4 mesh ``` We have 8 fake JAX devices now to partition this model... 
- - ('in', 'out') - GSPMDSharding({devices=[2,4]<=[8]}) + + NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device) + NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device) +++ @@ -396,8 +395,9 @@ Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the ra ```{code-cell} ipython3 class NNXDotWithParititioning(nnx.Module): def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs): - init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out')) - self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim))) + init_fn = nnx.initializers.lecun_normal() + self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)), + sharding_names=('in', 'out')) def __call__(self, x: jax.Array): return x @ self.w @@ -410,7 +410,7 @@ def create_sharded_variables(key, x): # A `NNXMeta` wrapper of the underlying `nnx.Param` assert type(variables['params']['w']) == bridge.NNXMeta # The annotation coming from the `nnx.Param` => (in, out) - assert variables['params']['w'].metadata['sharding'] == ('in', 'out') + assert variables['params']['w'].metadata['sharding_names'] == ('in', 'out') unboxed_variables = nn.unbox(variables) variable_pspecs = nn.get_partition_spec(variables) @@ -422,14 +422,14 @@ def create_sharded_variables(key, x): nn.get_partition_spec(variables)) return sharded_vars -with mesh: +with jax.set_mesh(mesh): variables = create_sharded_variables(jax.random.key(0), x) # The underlying JAX array is sharded across the 2x4 mesh print(variables['params']['w'].sharding) ``` - GSPMDSharding({devices=[2,4]<=[8]}) + NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device) +++ diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index e833df534..d4d1348a1 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -69,7 +69,11 @@ "outputs": [], "source": [ "# Create an auto-mode mesh of two dimensions and annotate each axis with a name.\n", - "auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))" + "auto_mesh = jax.make_mesh(\n", + " (2, 4),\n", + " ('data', 'model'),\n", + " axis_types=(AxisType.Auto, AxisType.Auto),\n", + ")" ] }, { @@ -203,7 +207,7 @@ "source": [ "### Initialize with style\n", "\n", - "When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n", + "When using existing modules, you can use `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n", "\n", "Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out." 
] @@ -216,10 +220,9 @@ "source": [ "@jax.jit\n", "def init_sharded_linear(key):\n", - " init_fn = nnx.nn.linear.default_kernel_init\n", " # Shard your parameter along `model` dimension, as in model/tensor parallelism\n", " return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),\n", - " kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))\n", + " kernel_metadata={'sharding_names': (None, 'model')})\n", "\n", "with jax.set_mesh(auto_mesh):\n", " key= rngs()\n", @@ -328,12 +331,12 @@ " init_fn = nnx.initializers.lecun_normal()\n", " self.dot1 = nnx.Linear(\n", " depth, depth,\n", - " kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),\n", - " use_bias=False, # or use `bias_init` to give it annotation too\n", + " kernel_metadata={'sharding_names': (None, 'model')},\n", + " use_bias=False, # or use `bias_metadata` to give it annotation too\n", " rngs=rngs)\n", " self.w2 = nnx.Param(\n", " init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n", - " sharding=('model', None),\n", + " sharding=('model', None), # same as sharding_names=('model', None)\n", " )\n", "\n", " def __call__(self, x: jax.Array):\n", @@ -512,8 +515,8 @@ " init_fn = nnx.initializers.lecun_normal()\n", " self.dot1 = nnx.Linear(\n", " depth, depth,\n", - " kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),\n", - " use_bias=False, # or use `bias_init` to give it annotation too\n", + " kernel_metadata={'sharding_names': ('embed', 'hidden')},\n", + " use_bias=False, # or use `bias_metadata` to give it annotation too\n", " rngs=rngs)\n", " self.w2 = nnx.Param(\n", " init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n", diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index f5b8a0661..528858303 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -44,7 +44,11 @@ In this guide we use a standard FSDP layout and shard our devices on two axes - ```{code-cell} ipython3 # Create an auto-mode mesh of two dimensions and annotate each axis with a name. -auto_mesh = jax.make_mesh((2, 4), ('data', 'model')) +auto_mesh = jax.make_mesh( + (2, 4), + ('data', 'model'), + axis_types=(AxisType.Auto, AxisType.Auto), +) ``` > Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function. @@ -89,17 +93,16 @@ with jax.set_mesh(auto_mesh): ### Initialize with style -When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight. +When using existing modules, you can use `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight. Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out. 
```{code-cell} ipython3 @jax.jit def init_sharded_linear(key): - init_fn = nnx.nn.linear.default_kernel_init # Shard your parameter along `model` dimension, as in model/tensor parallelism return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key), - kernel_init=nnx.with_partitioning(init_fn, (None, 'model'))) + kernel_metadata={'sharding_names': (None, 'model')}) with jax.set_mesh(auto_mesh): key= rngs() @@ -144,12 +147,12 @@ class DotReluDot(nnx.Module): init_fn = nnx.initializers.lecun_normal() self.dot1 = nnx.Linear( depth, depth, - kernel_init=nnx.with_partitioning(init_fn, (None, 'model')), - use_bias=False, # or use `bias_init` to give it annotation too + kernel_metadata={'sharding_names': (None, 'model')}, + use_bias=False, # or use `bias_metadata` to give it annotation too rngs=rngs) self.w2 = nnx.Param( init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation - sharding=('model', None), + sharding=('model', None), # same as sharding_names=('model', None) ) def __call__(self, x: jax.Array): @@ -258,8 +261,8 @@ class LogicalDotReluDot(nnx.Module): init_fn = nnx.initializers.lecun_normal() self.dot1 = nnx.Linear( depth, depth, - kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')), - use_bias=False, # or use `bias_init` to give it annotation too + kernel_metadata={'sharding_names': ('embed', 'hidden')}, + use_bias=False, # or use `bias_metadata` to give it annotation too rngs=rngs) self.w2 = nnx.Param( init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation
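Below is a self-contained sketch that pulls the pieces of the new annotation style from the hunks above into one runnable script. It is illustrative only and not part of the patched docs: it assumes a recent JAX that provides `jax.make_mesh`, `jax.set_mesh` and `jax.sharding.AxisType`, and a Flax release (`flax>=0.12`, per these guides) where `nnx.Param` accepts `sharding_names=`, `nnx.Linear` accepts `kernel_metadata=`, and eager sharding is enabled by default. The module and axis names are made up for the example, and the `XLA_FLAGS` line is just the usual trick for faking 8 CPU devices.

```python
# Sketch of the updated sharding-annotation style shown in the docs above.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # 8 fake CPU devices

import jax
from flax import nnx
from jax.sharding import AxisType

# Explicitly auto-typed mesh axes, as in the updated guides.
mesh = jax.make_mesh((2, 4), ('in', 'out'),
                     axis_types=(AxisType.Auto, AxisType.Auto))


class DotWithSharding(nnx.Module):
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()
    # Annotate the variable directly instead of wrapping the initializer
    # with nnx.with_partitioning.
    self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)),
                       sharding_names=('in', 'out'))

  def __call__(self, x: jax.Array):
    return x @ self.w


with jax.set_mesh(mesh):
  # With eager sharding and a mesh set, variables come out sharded at creation.
  model = DotWithSharding(4, 32, rngs=nnx.Rngs(0))
  # The same annotation for a built-in layer goes through its metadata argument.
  linear = nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(1),
                      kernel_metadata={'sharding_names': (None, 'out')})

# Both should print a NamedSharding over the 2x4 mesh.
print(model.w.get_value().sharding)
print(linear.kernel.get_value().sharding)
```

The common thread of the migration in both guides: the sharding annotation moves off the initializer wrapper (`nnx.with_partitioning`) and onto the variable or the layer's metadata argument, so a mesh set via `jax.set_mesh` can shard parameters as they are created.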