Skip to content

Commit 3b359ba

Browse files
Merge pull request #28027 from jax-ml:explicit_axes
PiperOrigin-RevId: 747889283
2 parents 4d692d1 + 09edc49 commit 3b359ba

File tree

2 files changed

+79
-3
lines changed

2 files changed

+79
-3
lines changed

docs/notebooks/explicit-sharding.ipynb

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
"import numpy as np\n",
6060
"import jax.numpy as jnp\n",
6161
"from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh\n",
62-
"from jax.experimental.shard import reshard, auto_axes\n",
62+
"from jax.experimental.shard import reshard, auto_axes, explicit_axes\n",
6363
"\n",
6464
"jax.config.update('jax_num_cpu_devices', 8)"
6565
]
@@ -652,7 +652,51 @@
652652
"id": "_3sfJjRq8w9f"
653653
},
654654
"source": [
655-
"As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`."
655+
"As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.\n",
656+
"\n",
657+
"\n",
658+
"You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes."
659+
]
660+
},
661+
{
662+
"cell_type": "code",
663+
"execution_count": null,
664+
"id": "a102e9c7",
665+
"metadata": {},
666+
"outputs": [],
667+
"source": [
668+
"auto_mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n",
669+
" axis_types=(AxisType.Auto, AxisType.Auto))\n",
670+
"\n",
671+
"@functools.partial(explicit_axes, axes=('X', 'Y'))\n",
672+
"def explicit_g(y):\n",
673+
" print(f'mesh inside g: {get_abstract_mesh()}')\n",
674+
" print(f'y.sharding inside g: {jax.typeof(y) = }')\n",
675+
" z = y * 2\n",
676+
" print(f'z.sharding inside g: {jax.typeof(z) = }', end='\\n\\n')\n",
677+
" return z\n",
678+
"\n",
679+
"@jax.jit\n",
680+
"def f(arr1):\n",
681+
" print(f'mesh inside f: {get_abstract_mesh()}', end='\\n\\n')\n",
682+
" x = jnp.sin(arr1)\n",
683+
"\n",
684+
" z = explicit_g(x, in_shardings=P(\"X\", \"Y\"))\n",
685+
"\n",
686+
" return z + 1\n",
687+
"\n",
688+
"with jax.sharding.use_mesh(auto_mesh):\n",
689+
" some_x = jax.device_put(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
690+
" f(some_x)"
691+
]
692+
},
693+
{
694+
"cell_type": "markdown",
695+
"id": "e64d40de",
696+
"metadata": {},
697+
"source": [
698+
"As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`.\n",
699+
"Because of that, sharding is visible on the type of arrays inside `g`."
656700
]
657701
},
658702
{

docs/notebooks/explicit-sharding.md

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ import jax
5656
import numpy as np
5757
import jax.numpy as jnp
5858
from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh
59-
from jax.experimental.shard import reshard, auto_axes
59+
from jax.experimental.shard import reshard, auto_axes, explicit_axes
6060
6161
jax.config.update('jax_num_cpu_devices', 8)
6262
```
@@ -403,6 +403,38 @@ f(some_x)
403403

404404
As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.
405405

406+
407+
You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes.
408+
409+
```{code-cell} ipython3
410+
auto_mesh = jax.make_mesh((2, 4), ("X", "Y"),
411+
axis_types=(AxisType.Auto, AxisType.Auto))
412+
413+
@functools.partial(explicit_axes, axes=('X', 'Y'))
414+
def explicit_g(y):
415+
print(f'mesh inside g: {get_abstract_mesh()}')
416+
print(f'y.sharding inside g: {jax.typeof(y) = }')
417+
z = y * 2
418+
print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n')
419+
return z
420+
421+
@jax.jit
422+
def f(arr1):
423+
print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n')
424+
x = jnp.sin(arr1)
425+
426+
z = explicit_g(x, in_shardings=P("X", "Y"))
427+
428+
return z + 1
429+
430+
with jax.sharding.use_mesh(auto_mesh):
431+
some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y"))
432+
f(some_x)
433+
```
434+
435+
As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`.
436+
Because of that, sharding is visible on the type of arrays inside `g`.
437+
406438
+++ {"id": "sJcWbfAh7UcO"}
407439

408440
## Concrete array shardings can mention `Auto` mesh axis

0 commit comments

Comments
 (0)