diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 821bb476141a..994df964ebe5 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -55,7 +55,9 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "from jax.sharding import Mesh, PartitionSpec as P" + "from jax.sharding import Mesh, PartitionSpec as P\n", + "Explicit = jax.sharding.AxisType.Explicit\n", + "Auto = jax.sharding.AxisType.Auto" ] }, { @@ -65,12 +67,13 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((4, 2), ('x', 'y'))\n", + "mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)\n", + "jax.set_mesh(mesh)\n", "\n", "a = jnp.arange( 8 * 16.).reshape(8, 16)\n", "b = jnp.arange(16 * 4.).reshape(16, 4)\n", "\n", - "@jax.shard_map(mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n", + "@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),\n", " out_specs=P('x', None))\n", "def matmul_basic(a_block, b_block):\n", " # a_block: f32[2, 8]\n", @@ -148,16 +151,15 @@ "source": [ "from jax.sharding import NamedSharding\n", "\n", - "a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n", - "b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n", + "a = jax.device_put(a, P('x', 'y'))\n", + "b = jax.device_put(b, P('y', None))\n", "\n", "@jax.jit\n", "def matmul_reference(a, b):\n", - " c = jnp.dot(a, b)\n", - " return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n", + " return jnp.dot(a, b, out_sharding=P('x', None))\n", "\n", "c_ref = matmul_reference(a, b)\n", - "allclose(c_ref, jnp.dot(a, b))" + "allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))" ] }, { @@ -245,7 +247,8 @@ "source": [ "import numpy as np\n", "devices = np.array(jax.devices()[:4])\n", - "mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n", + "mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4\n", + "jax.set_mesh(mesh)\n", "\n", "def check_shmap(f, y):\n", " ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", @@ -293,7 +296,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", + "mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2)\n", + "jax.set_mesh(mesh)\n", "\n", "@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", "def f1(x_block):\n", @@ -494,7 +498,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((2,), ('i',))\n", + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", + "jax.set_mesh(mesh)\n", "\n", "@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", "def f(x):\n", @@ -607,7 +612,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", + "mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2)\n", + "jax.set_mesh(mesh)\n", "\n", "@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n", "def f(x):\n", @@ -645,7 +651,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((2,), ('i',))\n", + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", + "jax.set_mesh(mesh)\n", "\n", "x = jnp.arange(6.)\n", "try:\n", @@ -741,7 +748,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((2,), ('i',))\n", + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", + "jax.set_mesh(mesh)\n", "\n", "@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", "def f(x, y):\n", @@ -776,7 +784,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = 
jax.make_mesh((2,), ('i',))\n", + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", + "jax.set_mesh(mesh)\n", "\n", "@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", "def f(x, y):\n", @@ -946,6 +955,7 @@ "outputs": [], "source": [ "mesh1d = Mesh(jax.devices()[:4], ('i',))\n", + "jax.set_mesh(mesh1d)\n", "\n", "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", "def f1(x_block):\n", @@ -1002,6 +1012,7 @@ "outputs": [], "source": [ "mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n", + "jax.set_mesh(mesh2d)\n", "\n", "@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", "def f2(x_block):\n", @@ -1071,6 +1082,8 @@ "metadata": {}, "outputs": [], "source": [ + "jax.set_mesh(mesh1d)\n", + "\n", "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f4(x_block):\n", " print('BEFORE:\\n', x_block)\n", @@ -1153,7 +1166,7 @@ "metadata": {}, "outputs": [], "source": [ - "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", "def f6(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n", @@ -1229,7 +1242,7 @@ "metadata": {}, "outputs": [], "source": [ - "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", "def f7(x_block):\n", " sz = jax.lax.axis_size('i')\n", " print('BEFORE:\\n', x_block)\n", @@ -1307,7 +1320,7 @@ "metadata": {}, "outputs": [], "source": [ - "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", "def f8(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = psum_scatter(x_block, 'i', tiled=True)\n", @@ -1438,6 +1451,7 @@ "outputs": [], "source": [ "mesh = Mesh(jax.devices()[:4], ('i',))\n", + "jax.set_mesh(mesh)\n", "\n", "def device_put(x, pspec):\n", " return jax.device_put(x, NamedSharding(mesh, pspec))" @@ -1884,7 +1898,8 @@ "source": [ "from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n", "\n", - "mesh = jax.make_mesh((8,), ('batch',))\n", + "mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,))\n", + "jax.set_mesh(mesh)\n", "\n", "# replicate initial params on all devices, shard data batch over devices\n", "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", @@ -1982,6 +1997,7 @@ "source": [ "# shard data batch *and params* over devices\n", "mesh = Mesh(devices, ('batch',))\n", + "jax.set_mesh(mesh)\n", "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", "params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n", "\n", @@ -2055,7 +2071,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((8,), ('feats',))\n", + "mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,))\n", + "jax.set_mesh(mesh)\n", "\n", "batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n", "params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n", @@ -2096,7 +2113,8 @@ "metadata": {}, "outputs": [], "source": [ - "mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n", + "mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2)\n", + "jax.set_mesh(mesh)\n", "\n", "batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n", "params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n", @@ -2285,7 +2303,17 @@ "metadata": {}, "outputs": [], "source": [ - 
"print(jax.jit(loss)(params, batch))\n", + "print(jax.jit(loss)(params, batch))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ff83661", + "metadata": {}, + "outputs": [], + "source": [ + "jax.set_mesh(mesh)\n", "print(jax.jit(loss_pp)(params_, batch_))" ] }, @@ -2303,8 +2331,7 @@ "metadata": { "jupytext": { "cell_metadata_filter": "-all", - "formats": "ipynb,md:myst", - "main_language": "python" + "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index e14cd930b7cb..0a09798f5bdd 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -2,7 +2,6 @@ jupytext: cell_metadata_filter: -all formats: ipynb,md:myst - main_language: python text_representation: extension: .md format_name: myst @@ -46,15 +45,18 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P +Explicit = jax.sharding.AxisType.Explicit +Auto = jax.sharding.AxisType.Auto ``` ```{code-cell} -mesh = jax.make_mesh((4, 2), ('x', 'y')) +mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2) +jax.set_mesh(mesh) a = jnp.arange( 8 * 16.).reshape(8, 16) b = jnp.arange(16 * 4.).reshape(16, 4) -@jax.shard_map(mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), +@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)), out_specs=P('x', None)) def matmul_basic(a_block, b_block): # a_block: f32[2, 8] @@ -96,16 +98,15 @@ The above code is performing the same computation as this `jax.jit` automatic pa ```{code-cell} from jax.sharding import NamedSharding -a = jax.device_put(a, NamedSharding(mesh, P('x', 'y'))) -b = jax.device_put(b, NamedSharding(mesh, P('y', None))) +a = jax.device_put(a, P('x', 'y')) +b = jax.device_put(b, P('y', None)) @jax.jit def matmul_reference(a, b): - c = jnp.dot(a, b) - return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None))) + return jnp.dot(a, b, out_sharding=P('x', None)) c_ref = matmul_reference(a, b) -allclose(c_ref, jnp.dot(a, b)) +allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None))) ``` We can think of `shard_map` as performing a `device_put` or @@ -157,7 +158,8 @@ when collectives aren't involved): ```{code-cell} import numpy as np devices = np.array(jax.devices()[:4]) -mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4 +mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4 +jax.set_mesh(mesh) def check_shmap(f, y): ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y) @@ -193,7 +195,8 @@ input array axis size.) If an input's pspec does not mention a mesh axis name, then there's no splitting over that mesh axis. For example: ```{code-cell} -mesh = jax.make_mesh((4, 2), ('i', 'j')) +mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2) +jax.set_mesh(mesh) @jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) def f1(x_block): @@ -334,7 +337,8 @@ the same. 
For example, when we use `in_specs` to split an argument over a mesh axis, each function instance along that mesh axis gets a different value: ```{code-cell} -mesh = jax.make_mesh((2,), ('i',)) +mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,)) +jax.set_mesh(mesh) @jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): @@ -399,7 +403,8 @@ In general, the VMA type of a value can include any subset of the manual mesh axes over which the `shard_map` is acting: ```{code-cell} -mesh = jax.make_mesh((4, 2), ('i', 'j')) +mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2) +jax.set_mesh(mesh) @jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i')) def f(x): @@ -425,7 +430,8 @@ For example, this `out_specs` bug is caught with `check_vma=True`, but uncaught without it: ```{code-cell} -mesh = jax.make_mesh((2,), ('i',)) +mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,)) +jax.set_mesh(mesh) x = jnp.arange(6.) try: @@ -485,7 +491,8 @@ In some cases, like with `jax.lax.scan`, you might need to apply this code raises an error: ```{code-cell} -mesh = jax.make_mesh((2,), ('i',)) +mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,)) +jax.set_mesh(mesh) @jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) def f(x, y): @@ -508,7 +515,8 @@ To make the types match, we need to apply `jax.lax.pvary` to some arguments to the `scan`: ```{code-cell} -mesh = jax.make_mesh((2,), ('i',)) +mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,)) +jax.set_mesh(mesh) @jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) def f(x, y): @@ -660,6 +668,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P ```{code-cell} mesh1d = Mesh(jax.devices()[:4], ('i',)) +jax.set_mesh(mesh1d) @jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) def f1(x_block): @@ -698,6 +707,7 @@ each one separately, or over multiple axes at once: ```{code-cell} mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j')) +jax.set_mesh(mesh2d) @jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f2(x_block): @@ -743,6 +753,8 @@ each function application has a full copy of the data along that axis: Illustration of an all_gather computation. ```{code-cell} +jax.set_mesh(mesh1d) + @jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f4(x_block): print('BEFORE:\n', x_block) @@ -801,7 +813,7 @@ The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like Illustration of a psum_scatter computation. 
```{code-cell} -@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@jax.shard_map(in_specs=P('i'), out_specs=P('i')) def f6(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True) @@ -865,7 +877,7 @@ that mesh axis, `ppermute` sends its argument value from each source function instance to each destination: ```{code-cell} -@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@jax.shard_map(in_specs=P('i'), out_specs=P('i')) def f7(x_block): sz = jax.lax.axis_size('i') print('BEFORE:\n', x_block) @@ -925,7 +937,7 @@ def psum_scatter(x, axis_name, *, tiled=False): ``` ```{code-cell} -@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@jax.shard_map(in_specs=P('i'), out_specs=P('i')) def f8(x_block): print('BEFORE:\n', x_block) y_block = psum_scatter(x_block, 'i', tiled=True) @@ -1026,6 +1038,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P ```{code-cell} mesh = Mesh(jax.devices()[:4], ('i',)) +jax.set_mesh(mesh) def device_put(x, pspec): return jax.device_put(x, NamedSharding(mesh, pspec)) @@ -1298,7 +1311,8 @@ all-reduce-sums of parameter gradients in the backward pass.) ```{code-cell} from jax.sharding import NamedSharding, Mesh, PartitionSpec as P -mesh = jax.make_mesh((8,), ('batch',)) +mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,)) +jax.set_mesh(mesh) # replicate initial params on all devices, shard data batch over devices batch = jax.device_put(batch, NamedSharding(mesh, P('batch'))) @@ -1366,6 +1380,7 @@ to [weight update sharding (WUS)](https://arxiv.org/abs/2004.13336) and ```{code-cell} # shard data batch *and params* over devices mesh = Mesh(devices, ('batch',)) +jax.set_mesh(mesh) batch = jax.device_put(batch, NamedSharding(mesh, P('batch'))) params = jax.device_put(params, NamedSharding(mesh, P('batch'))) @@ -1415,7 +1430,8 @@ multiplications followed by a `psum_scatter` to sum the local results and efficiently scatter the result's shards. ```{code-cell} -mesh = jax.make_mesh((8,), ('feats',)) +mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,)) +jax.set_mesh(mesh) batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats'))) params = jax.device_put(params, NamedSharding(mesh, P('feats'))) @@ -1444,7 +1460,8 @@ def loss_tp(params, batch): We can compose these strategies together, using multiple axes of parallelism. ```{code-cell} -mesh = jax.make_mesh((4, 2), ('batch', 'feats')) +mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2) +jax.set_mesh(mesh) batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats'))) params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats')))) @@ -1586,6 +1603,10 @@ batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages'))) ```{code-cell} print(jax.jit(loss)(params, batch)) +``` + +```{code-cell} +jax.set_mesh(mesh) print(jax.jit(loss_pp)(params_, batch_)) ```
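The change above is mechanical but touches every example, so here is the pattern it applies, collected in one place: install a mesh once with `jax.set_mesh` (spelling out the axes' `AxisType`s where it matters), after which `jax.shard_map` can omit the `mesh=` argument, `jax.device_put` accepts a bare `PartitionSpec`, and `jnp.dot` takes `out_sharding=` in place of `with_sharding_constraint`. The sketch below is not a cell lifted verbatim from the notebook; it condenses the opening matmul example, and it assumes 8 devices are visible (e.g. by running with `XLA_FLAGS=--xla_force_host_platform_device_count=8` on CPU).

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

Explicit = jax.sharding.AxisType.Explicit

# Explicit axis types make shardings part of array types; Auto keeps the old
# propagation behaviour and is what the batch-/feature-parallel examples use.
mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)
jax.set_mesh(mesh)  # subsequent sharded operations pick this mesh up implicitly

# With a mesh set, a bare PartitionSpec is enough for device_put.
a = jax.device_put(jnp.arange(8 * 16.).reshape(8, 16), P('x', 'y'))
b = jax.device_put(jnp.arange(16 * 4.).reshape(16, 4), P('y', None))

# No mesh= argument: shard_map uses the mesh installed by jax.set_mesh.
@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)), out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8], b_block: f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  return jax.lax.psum(c_partialsum, 'y')  # c_block: f32[2, 4]

c = matmul_basic(a, b)

# out_sharding= on jnp.dot replaces the old with_sharding_constraint reference.
c_ref = jax.jit(lambda a, b: jnp.dot(a, b, out_sharding=P('x', None)))(a, b)
print(jnp.allclose(c, c_ref, atol=1e-3, rtol=1e-3))
```

The examples that still rely on automatic sharding propagation (the 8-way batch-parallel, feature-parallel, and combined losses near the end of the notebook) follow the same structure but construct their meshes with `axis_types=(Auto,)`, as shown in the corresponding hunks above.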