|
59 | 59 | "import numpy as np\n",
|
60 | 60 | "import jax.numpy as jnp\n",
|
61 | 61 | "from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh\n",
|
62 |
| - "from jax.experimental.shard import reshard, auto_axes\n", |
| 62 | + "from jax.experimental.shard import reshard, auto_axes, explicit_axes\n", |
63 | 63 | "\n",
|
64 | 64 | "jax.config.update('jax_num_cpu_devices', 8)"
|
65 | 65 | ]
|
|
652 | 652 | "id": "_3sfJjRq8w9f"
|
653 | 653 | },
|
654 | 654 | "source": [
|
655 |
| - "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`." |
| 655 | + "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.\n", |
| 656 | + "\n", |
| 657 | + "\n", |
| 658 | + "You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes." |
| 659 | + ] |
| 660 | + }, |
| 661 | + { |
| 662 | + "cell_type": "code", |
| 663 | + "execution_count": null, |
| 664 | + "id": "a102e9c7", |
| 665 | + "metadata": {}, |
| 666 | + "outputs": [], |
| 667 | + "source": [ |
| 668 | + "auto_mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", |
| 669 | + " axis_types=(AxisType.Auto, AxisType.Auto))\n", |
| 670 | + "\n", |
| 671 | + "@functools.partial(explicit_axes, axes=('X', 'Y'))\n", |
| 672 | + "def explicit_g(y):\n", |
| 673 | + " print(f'mesh inside g: {get_abstract_mesh()}')\n", |
| 674 | + " print(f'y.sharding inside g: {jax.typeof(y) = }')\n", |
| 675 | + " z = y * 2\n", |
| 676 | + " print(f'z.sharding inside g: {jax.typeof(z) = }', end='\\n\\n')\n", |
| 677 | + " return z\n", |
| 678 | + "\n", |
| 679 | + "@jax.jit\n", |
| 680 | + "def f(arr1):\n", |
| 681 | + " print(f'mesh inside f: {get_abstract_mesh()}', end='\\n\\n')\n", |
| 682 | + " x = jnp.sin(arr1)\n", |
| 683 | + "\n", |
| 684 | + " z = explicit_g(x, in_shardings=P(\"X\", \"Y\"))\n", |
| 685 | + "\n", |
| 686 | + " return z + 1\n", |
| 687 | + "\n", |
| 688 | + "with jax.sharding.use_mesh(auto_mesh):\n", |
| 689 | + " some_x = jax.device_put(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n", |
| 690 | + " f(some_x)" |
| 691 | + ] |
| 692 | + }, |
| 693 | + { |
| 694 | + "cell_type": "markdown", |
| 695 | + "id": "e64d40de", |
| 696 | + "metadata": {}, |
| 697 | + "source": [ |
| 698 | + "As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`.\n", |
| 699 | + "Because of that, sharding is visible on the type of arrays inside `g`." |
656 | 700 | ]
|
657 | 701 | },
|
658 | 702 | {
|
|
0 commit comments