diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb
index 821bb476141a..994df964ebe5 100644
--- a/docs/notebooks/shard_map.ipynb
+++ b/docs/notebooks/shard_map.ipynb
@@ -55,7 +55,9 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
- "from jax.sharding import Mesh, PartitionSpec as P"
+ "from jax.sharding import Mesh, PartitionSpec as P\n",
+ "Explicit = jax.sharding.AxisType.Explicit\n",
+ "Auto = jax.sharding.AxisType.Auto"
]
},
{
@@ -65,12 +67,13 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((4, 2), ('x', 'y'))\n",
+ "mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"a = jnp.arange( 8 * 16.).reshape(8, 16)\n",
"b = jnp.arange(16 * 4.).reshape(16, 4)\n",
"\n",
- "@jax.shard_map(mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n",
+ "@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),\n",
" out_specs=P('x', None))\n",
"def matmul_basic(a_block, b_block):\n",
" # a_block: f32[2, 8]\n",
@@ -148,16 +151,15 @@
"source": [
"from jax.sharding import NamedSharding\n",
"\n",
- "a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n",
- "b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n",
+ "a = jax.device_put(a, P('x', 'y'))\n",
+ "b = jax.device_put(b, P('y', None))\n",
"\n",
"@jax.jit\n",
"def matmul_reference(a, b):\n",
- " c = jnp.dot(a, b)\n",
- " return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n",
+ " return jnp.dot(a, b, out_sharding=P('x', None))\n",
"\n",
"c_ref = matmul_reference(a, b)\n",
- "allclose(c_ref, jnp.dot(a, b))"
+ "allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))"
]
},
{
@@ -245,7 +247,8 @@
"source": [
"import numpy as np\n",
"devices = np.array(jax.devices()[:4])\n",
- "mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n",
+ "mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"def check_shmap(f, y):\n",
" ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n",
@@ -293,7 +296,8 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
+ "mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2)\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n",
"def f1(x_block):\n",
@@ -494,7 +498,8 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((2,), ('i',))\n",
+ "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n",
"def f(x):\n",
@@ -607,7 +612,8 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
+ "mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2)\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n",
"def f(x):\n",
@@ -645,7 +651,8 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((2,), ('i',))\n",
+ "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"x = jnp.arange(6.)\n",
"try:\n",
@@ -741,7 +748,8 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((2,), ('i',))\n",
+ "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n",
"def f(x, y):\n",
@@ -776,7 +784,8 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((2,), ('i',))\n",
+ "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n",
"def f(x, y):\n",
@@ -946,6 +955,7 @@
"outputs": [],
"source": [
"mesh1d = Mesh(jax.devices()[:4], ('i',))\n",
+ "jax.set_mesh(mesh1d)\n",
"\n",
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n",
"def f1(x_block):\n",
@@ -1002,6 +1012,7 @@
"outputs": [],
"source": [
"mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n",
+ "jax.set_mesh(mesh2d)\n",
"\n",
"@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n",
"def f2(x_block):\n",
@@ -1071,6 +1082,8 @@
"metadata": {},
"outputs": [],
"source": [
+ "jax.set_mesh(mesh1d)\n",
+ "\n",
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
"def f4(x_block):\n",
" print('BEFORE:\\n', x_block)\n",
@@ -1153,7 +1166,7 @@
"metadata": {},
"outputs": [],
"source": [
- "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
"def f6(x_block):\n",
" print('BEFORE:\\n', x_block)\n",
" y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n",
@@ -1229,7 +1242,7 @@
"metadata": {},
"outputs": [],
"source": [
- "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
"def f7(x_block):\n",
" sz = jax.lax.axis_size('i')\n",
" print('BEFORE:\\n', x_block)\n",
@@ -1307,7 +1320,7 @@
"metadata": {},
"outputs": [],
"source": [
- "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
"def f8(x_block):\n",
" print('BEFORE:\\n', x_block)\n",
" y_block = psum_scatter(x_block, 'i', tiled=True)\n",
@@ -1438,6 +1451,7 @@
"outputs": [],
"source": [
"mesh = Mesh(jax.devices()[:4], ('i',))\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"def device_put(x, pspec):\n",
" return jax.device_put(x, NamedSharding(mesh, pspec))"
@@ -1884,7 +1898,8 @@
"source": [
"from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n",
"\n",
- "mesh = jax.make_mesh((8,), ('batch',))\n",
+ "mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,))\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"# replicate initial params on all devices, shard data batch over devices\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
@@ -1982,6 +1997,7 @@
"source": [
"# shard data batch *and params* over devices\n",
"mesh = Mesh(devices, ('batch',))\n",
+ "jax.set_mesh(mesh)\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
"params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n",
"\n",
@@ -2055,7 +2071,8 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((8,), ('feats',))\n",
+ "mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,))\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n",
"params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n",
@@ -2096,7 +2113,8 @@
"metadata": {},
"outputs": [],
"source": [
- "mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n",
+ "mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2)\n",
+ "jax.set_mesh(mesh)\n",
"\n",
"batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n",
"params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n",
@@ -2285,7 +2303,17 @@
"metadata": {},
"outputs": [],
"source": [
- "print(jax.jit(loss)(params, batch))\n",
+ "print(jax.jit(loss)(params, batch))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9ff83661",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "jax.set_mesh(mesh)\n",
"print(jax.jit(loss_pp)(params_, batch_))"
]
},
@@ -2303,8 +2331,7 @@
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
- "formats": "ipynb,md:myst",
- "main_language": "python"
+ "formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md
index e14cd930b7cb..0a09798f5bdd 100644
--- a/docs/notebooks/shard_map.md
+++ b/docs/notebooks/shard_map.md
@@ -2,7 +2,6 @@
jupytext:
cell_metadata_filter: -all
formats: ipynb,md:myst
- main_language: python
text_representation:
extension: .md
format_name: myst
@@ -46,15 +45,18 @@ import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
+Explicit = jax.sharding.AxisType.Explicit
+Auto = jax.sharding.AxisType.Auto
```
```{code-cell}
-mesh = jax.make_mesh((4, 2), ('x', 'y'))
+mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)
+jax.set_mesh(mesh)
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
-@jax.shard_map(mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
+@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),
out_specs=P('x', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
@@ -96,16 +98,15 @@ The above code is performing the same computation as this `jax.jit` automatic pa
```{code-cell}
from jax.sharding import NamedSharding
-a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
-b = jax.device_put(b, NamedSharding(mesh, P('y', None)))
+a = jax.device_put(a, P('x', 'y'))
+b = jax.device_put(b, P('y', None))
@jax.jit
def matmul_reference(a, b):
- c = jnp.dot(a, b)
- return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))
+ return jnp.dot(a, b, out_sharding=P('x', None))
c_ref = matmul_reference(a, b)
-allclose(c_ref, jnp.dot(a, b))
+allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))
```
We can think of `shard_map` as performing a `device_put` or
@@ -157,7 +158,8 @@ when collectives aren't involved):
```{code-cell}
import numpy as np
devices = np.array(jax.devices()[:4])
-mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
+mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4
+jax.set_mesh(mesh)
def check_shmap(f, y):
ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)
@@ -193,7 +195,8 @@ input array axis size.) If an input's pspec does not mention a mesh axis name,
then there's no splitting over that mesh axis. For example:
```{code-cell}
-mesh = jax.make_mesh((4, 2), ('i', 'j'))
+mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2)
+jax.set_mesh(mesh)
@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
@@ -334,7 +337,8 @@ the same. For example, when we use `in_specs` to split an argument over a mesh
axis, each function instance along that mesh axis gets a different value:
```{code-cell}
-mesh = jax.make_mesh((2,), ('i',))
+mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))
+jax.set_mesh(mesh)
@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(x):
@@ -399,7 +403,8 @@ In general, the VMA type of a value can include any subset of the manual mesh
axes over which the `shard_map` is acting:
```{code-cell}
-mesh = jax.make_mesh((4, 2), ('i', 'j'))
+mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2)
+jax.set_mesh(mesh)
@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))
def f(x):
@@ -425,7 +430,8 @@ For example, this `out_specs` bug is caught with `check_vma=True`, but uncaught
without it:
```{code-cell}
-mesh = jax.make_mesh((2,), ('i',))
+mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))
+jax.set_mesh(mesh)
x = jnp.arange(6.)
try:
@@ -485,7 +491,8 @@ In some cases, like with `jax.lax.scan`, you might need to apply
this code raises an error:
```{code-cell}
-mesh = jax.make_mesh((2,), ('i',))
+mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))
+jax.set_mesh(mesh)
@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
@@ -508,7 +515,8 @@ To make the types match, we need to apply `jax.lax.pvary` to some arguments to
the `scan`:
```{code-cell}
-mesh = jax.make_mesh((2,), ('i',))
+mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))
+jax.set_mesh(mesh)
@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
@@ -660,6 +668,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
```{code-cell}
mesh1d = Mesh(jax.devices()[:4], ('i',))
+jax.set_mesh(mesh1d)
@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
@@ -698,6 +707,7 @@ each one separately, or over multiple axes at once:
```{code-cell}
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
+jax.set_mesh(mesh2d)
@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
@@ -743,6 +753,8 @@ each function application has a full copy of the data along that axis:
```{code-cell}
+jax.set_mesh(mesh1d)
+
@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
print('BEFORE:\n', x_block)
@@ -801,7 +813,7 @@ The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like
```{code-cell}
-@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+@jax.shard_map(in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
print('BEFORE:\n', x_block)
y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
@@ -865,7 +877,7 @@ that mesh axis, `ppermute` sends its argument value from each source function
instance to each destination:
```{code-cell}
-@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+@jax.shard_map(in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
sz = jax.lax.axis_size('i')
print('BEFORE:\n', x_block)
@@ -925,7 +937,7 @@ def psum_scatter(x, axis_name, *, tiled=False):
```
```{code-cell}
-@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+@jax.shard_map(in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
print('BEFORE:\n', x_block)
y_block = psum_scatter(x_block, 'i', tiled=True)
@@ -1026,6 +1038,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
```{code-cell}
mesh = Mesh(jax.devices()[:4], ('i',))
+jax.set_mesh(mesh)
def device_put(x, pspec):
return jax.device_put(x, NamedSharding(mesh, pspec))
@@ -1298,7 +1311,8 @@ all-reduce-sums of parameter gradients in the backward pass.)
```{code-cell}
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
-mesh = jax.make_mesh((8,), ('batch',))
+mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,))
+jax.set_mesh(mesh)
# replicate initial params on all devices, shard data batch over devices
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
@@ -1366,6 +1380,7 @@ to [weight update sharding (WUS)](https://arxiv.org/abs/2004.13336) and
```{code-cell}
# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
+jax.set_mesh(mesh)
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))
@@ -1415,7 +1430,8 @@ multiplications followed by a `psum_scatter` to sum the local results and
efficiently scatter the result's shards.
```{code-cell}
-mesh = jax.make_mesh((8,), ('feats',))
+mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,))
+jax.set_mesh(mesh)
batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
@@ -1444,7 +1460,8 @@ def loss_tp(params, batch):
We can compose these strategies together, using multiple axes of parallelism.
```{code-cell}
-mesh = jax.make_mesh((4, 2), ('batch', 'feats'))
+mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2)
+jax.set_mesh(mesh)
batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
@@ -1586,6 +1603,10 @@ batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
```{code-cell}
print(jax.jit(loss)(params, batch))
+```
+
+```{code-cell}
+jax.set_mesh(mesh)
print(jax.jit(loss_pp)(params_, batch_))
```