Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 51 additions & 24 deletions docs/notebooks/shard_map.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from jax.sharding import Mesh, PartitionSpec as P"
"from jax.sharding import Mesh, PartitionSpec as P\n",
"Explicit = jax.sharding.AxisType.Explicit\n",
"Auto = jax.sharding.AxisType.Auto"
]
},
{
Expand All @@ -65,12 +67,13 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((4, 2), ('x', 'y'))\n",
"mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)\n",
"jax.set_mesh(mesh)\n",
"\n",
"a = jnp.arange( 8 * 16.).reshape(8, 16)\n",
"b = jnp.arange(16 * 4.).reshape(16, 4)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n",
"@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),\n",
" out_specs=P('x', None))\n",
"def matmul_basic(a_block, b_block):\n",
" # a_block: f32[2, 8]\n",
Expand Down Expand Up @@ -148,16 +151,15 @@
"source": [
"from jax.sharding import NamedSharding\n",
"\n",
"a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n",
"b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n",
"a = jax.device_put(a, P('x', 'y'))\n",
"b = jax.device_put(b, P('y', None))\n",
"\n",
"@jax.jit\n",
"def matmul_reference(a, b):\n",
" c = jnp.dot(a, b)\n",
" return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n",
" return jnp.dot(a, b, out_sharding=P('x', None))\n",
"\n",
"c_ref = matmul_reference(a, b)\n",
"allclose(c_ref, jnp.dot(a, b))"
"allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))"
]
},
{
Expand Down Expand Up @@ -245,7 +247,8 @@
"source": [
"import numpy as np\n",
"devices = np.array(jax.devices()[:4])\n",
"mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n",
"mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4\n",
"jax.set_mesh(mesh)\n",
"\n",
"def check_shmap(f, y):\n",
" ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n",
Expand Down Expand Up @@ -293,7 +296,8 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
"mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2)\n",
"jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n",
"def f1(x_block):\n",
Expand Down Expand Up @@ -494,7 +498,8 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((2,), ('i',))\n",
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
"jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n",
"def f(x):\n",
Expand Down Expand Up @@ -607,7 +612,8 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
"mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2)\n",
"jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n",
"def f(x):\n",
Expand Down Expand Up @@ -645,7 +651,8 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((2,), ('i',))\n",
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
"jax.set_mesh(mesh)\n",
"\n",
"x = jnp.arange(6.)\n",
"try:\n",
Expand Down Expand Up @@ -741,7 +748,8 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((2,), ('i',))\n",
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
"jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n",
"def f(x, y):\n",
Expand Down Expand Up @@ -776,7 +784,8 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((2,), ('i',))\n",
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
"jax.set_mesh(mesh)\n",
"\n",
"@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n",
"def f(x, y):\n",
Expand Down Expand Up @@ -946,6 +955,7 @@
"outputs": [],
"source": [
"mesh1d = Mesh(jax.devices()[:4], ('i',))\n",
"jax.set_mesh(mesh1d)\n",
"\n",
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n",
"def f1(x_block):\n",
Expand Down Expand Up @@ -1002,6 +1012,7 @@
"outputs": [],
"source": [
"mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n",
"jax.set_mesh(mesh2d)\n",
"\n",
"@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n",
"def f2(x_block):\n",
Expand Down Expand Up @@ -1071,6 +1082,8 @@
"metadata": {},
"outputs": [],
"source": [
"jax.set_mesh(mesh1d)\n",
"\n",
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
"def f4(x_block):\n",
" print('BEFORE:\\n', x_block)\n",
Expand Down Expand Up @@ -1153,7 +1166,7 @@
"metadata": {},
"outputs": [],
"source": [
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
"def f6(x_block):\n",
" print('BEFORE:\\n', x_block)\n",
" y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n",
Expand Down Expand Up @@ -1229,7 +1242,7 @@
"metadata": {},
"outputs": [],
"source": [
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
"def f7(x_block):\n",
" sz = jax.lax.axis_size('i')\n",
" print('BEFORE:\\n', x_block)\n",
Expand Down Expand Up @@ -1307,7 +1320,7 @@
"metadata": {},
"outputs": [],
"source": [
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
"def f8(x_block):\n",
" print('BEFORE:\\n', x_block)\n",
" y_block = psum_scatter(x_block, 'i', tiled=True)\n",
Expand Down Expand Up @@ -1438,6 +1451,7 @@
"outputs": [],
"source": [
"mesh = Mesh(jax.devices()[:4], ('i',))\n",
"jax.set_mesh(mesh)\n",
"\n",
"def device_put(x, pspec):\n",
" return jax.device_put(x, NamedSharding(mesh, pspec))"
Expand Down Expand Up @@ -1884,7 +1898,8 @@
"source": [
"from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n",
"\n",
"mesh = jax.make_mesh((8,), ('batch',))\n",
"mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,))\n",
"jax.set_mesh(mesh)\n",
"\n",
"# replicate initial params on all devices, shard data batch over devices\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
Expand Down Expand Up @@ -1982,6 +1997,7 @@
"source": [
"# shard data batch *and params* over devices\n",
"mesh = Mesh(devices, ('batch',))\n",
"jax.set_mesh(mesh)\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
"params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n",
"\n",
Expand Down Expand Up @@ -2055,7 +2071,8 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((8,), ('feats',))\n",
"mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,))\n",
"jax.set_mesh(mesh)\n",
"\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n",
"params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n",
Expand Down Expand Up @@ -2096,7 +2113,8 @@
"metadata": {},
"outputs": [],
"source": [
"mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n",
"mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2)\n",
"jax.set_mesh(mesh)\n",
"\n",
"batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n",
"params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n",
Expand Down Expand Up @@ -2285,7 +2303,17 @@
"metadata": {},
"outputs": [],
"source": [
"print(jax.jit(loss)(params, batch))\n",
"print(jax.jit(loss)(params, batch))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ff83661",
"metadata": {},
"outputs": [],
"source": [
"jax.set_mesh(mesh)\n",
"print(jax.jit(loss_pp)(params_, batch_))"
]
},
Expand All @@ -2303,8 +2331,7 @@
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"formats": "ipynb,md:myst",
"main_language": "python"
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
Expand Down
63 changes: 42 additions & 21 deletions docs/notebooks/shard_map.md

Large diffs are not rendered by default.

Loading