|
36 | 36 | "from flax.nnx import bridge\n", |
37 | 37 | "import jax\n", |
38 | 38 | "from jax import numpy as jnp\n", |
39 | | - "from jax.experimental import mesh_utils\n", |
| 39 | + "from jax.sharding import PartitionSpec as P, NamedSharding, AxisType\n", |
40 | 40 | "from typing import *" |
41 | 41 | ] |
42 | 42 | }, |
|
686 | 686 | "\n", |
687 | 687 | "\n", |
688 | 688 | "print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')\n", |
689 | | - "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),\n", |
690 | | - " axis_names=('in', 'out'))\n", |
| 689 | + "mesh = jax.make_mesh((2, 4), ('in', 'out'), axis_types=(AxisType.Auto, AxisType.Auto))\n", |
691 | 690 | "x = jax.random.normal(jax.random.key(42), (4, 32))\n", |
692 | | - "with mesh:\n", |
| 691 | + "with jax.set_mesh(mesh):\n", |
693 | 692 | " model = create_sharded_nnx_module(x)\n", |
694 | 693 | "\n", |
695 | | - "print(type(model.w)) # `nnx.Param`\n", |
696 | | - "print(model.w.sharding) # The partition annotation attached with `w`\n", |
697 | | - "print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh" |
| 694 | + "print(type(model.w)) # `nnx.Param`\n", |
| 695 | + "print(model.w.sharding) # The partition annotation attached with `w`\n", |
| 696 | + "print(model.w.get_value().sharding) # The underlying JAX array is sharded across the 2x4 mesh" |
698 | 697 | ] |
699 | 698 | }, |
700 | 699 | { |
|
703 | 702 | "metadata": {}, |
704 | 703 | "source": [ |
705 | 704 | " We have 8 fake JAX devices now to partition this model...\n", |
706 | | - " <class 'flax.nnx.variables.Param'>\n", |
707 | | - " ('in', 'out')\n", |
708 | | - " GSPMDSharding({devices=[2,4]<=[8]})" |
| 705 | + " <class 'flax.nnx.variablelib.Param'>\n", |
| 706 | + " NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)\n", |
| 707 | + " NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)" |
709 | 708 | ] |
710 | 709 | }, |
711 | 710 | { |
|
737 | 736 | "source": [ |
738 | 737 | "class NNXDotWithParititioning(nnx.Module):\n", |
739 | 738 | " def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):\n", |
740 | | - " init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))\n", |
741 | | - " self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))\n", |
| 739 | + " init_fn = nnx.initializers.lecun_normal()\n", |
| 740 | + " self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)),\n", |
| 741 | + " sharding_names=('in', 'out'))\n", |
742 | 742 | " def __call__(self, x: jax.Array):\n", |
743 | 743 | " return x @ self.w\n", |
744 | 744 | "\n", |
|
751 | 751 | " # A `NNXMeta` wrapper of the underlying `nnx.Param`\n", |
752 | 752 | " assert type(variables['params']['w']) == bridge.NNXMeta\n", |
753 | 753 | " # The annotation coming from the `nnx.Param` => (in, out)\n", |
754 | | - " assert variables['params']['w'].metadata['sharding'] == ('in', 'out')\n", |
| 754 | + " assert variables['params']['w'].metadata['sharding_names'] == ('in', 'out')\n", |
755 | 755 | "\n", |
756 | 756 | " unboxed_variables = nn.unbox(variables)\n", |
757 | 757 | " variable_pspecs = nn.get_partition_spec(variables)\n", |
|
763 | 763 | " nn.get_partition_spec(variables))\n", |
764 | 764 | " return sharded_vars\n", |
765 | 765 | "\n", |
766 | | - "with mesh:\n", |
| 766 | + "with jax.set_mesh(mesh):\n", |
767 | 767 | " variables = create_sharded_variables(jax.random.key(0), x)\n", |
768 | 768 | "\n", |
769 | 769 | "# The underlying JAX array is sharded across the 2x4 mesh\n", |
|
774 | 774 | "cell_type": "markdown", |
775 | 775 | "metadata": {}, |
776 | 776 | "source": [ |
777 | | - " GSPMDSharding({devices=[2,4]<=[8]})" |
| 777 | + " NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)" |
778 | 778 | ] |
779 | 779 | }, |
780 | 780 | { |
|