Skip to content

Commit d3d6a7a

Browse files
committed
Move shard_map doc to use Explicit mode and Auto mode. We should transition the doc to full Explicit mode in a follow-up.
1 parent 08fcdb7 commit d3d6a7a

File tree

3 files changed

+1725
-47
lines changed

3 files changed

+1725
-47
lines changed

docs/notebooks/shard_map.ipynb

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@
5555
"import jax\n",
5656
"import jax.numpy as jnp\n",
5757
"\n",
58-
"from jax.sharding import Mesh, PartitionSpec as P"
58+
"from jax.sharding import Mesh, PartitionSpec as P\n",
59+
"Explicit = jax.sharding.AxisType.Explicit\n",
60+
"Auto = jax.sharding.AxisType.Auto"
5961
]
6062
},
6163
{
@@ -65,12 +67,13 @@
6567
"metadata": {},
6668
"outputs": [],
6769
"source": [
68-
"mesh = jax.make_mesh((4, 2), ('x', 'y'))\n",
70+
"mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)\n",
71+
"jax.set_mesh(mesh)\n",
6972
"\n",
7073
"a = jnp.arange( 8 * 16.).reshape(8, 16)\n",
7174
"b = jnp.arange(16 * 4.).reshape(16, 4)\n",
7275
"\n",
73-
"@jax.shard_map(mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n",
76+
"@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),\n",
7477
" out_specs=P('x', None))\n",
7578
"def matmul_basic(a_block, b_block):\n",
7679
" # a_block: f32[2, 8]\n",
@@ -148,16 +151,15 @@
148151
"source": [
149152
"from jax.sharding import NamedSharding\n",
150153
"\n",
151-
"a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n",
152-
"b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n",
154+
"a = jax.device_put(a, P('x', 'y'))\n",
155+
"b = jax.device_put(b, P('y', None))\n",
153156
"\n",
154157
"@jax.jit\n",
155158
"def matmul_reference(a, b):\n",
156-
" c = jnp.dot(a, b)\n",
157-
" return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n",
159+
" return jnp.dot(a, b, out_sharding=P('x', None))\n",
158160
"\n",
159161
"c_ref = matmul_reference(a, b)\n",
160-
"allclose(c_ref, jnp.dot(a, b))"
162+
"allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))"
161163
]
162164
},
163165
{
@@ -245,7 +247,8 @@
245247
"source": [
246248
"import numpy as np\n",
247249
"devices = np.array(jax.devices()[:4])\n",
248-
"mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n",
250+
"mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4\n",
251+
"jax.set_mesh(mesh)\n",
249252
"\n",
250253
"def check_shmap(f, y):\n",
251254
" ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n",
@@ -293,7 +296,8 @@
293296
"metadata": {},
294297
"outputs": [],
295298
"source": [
296-
"mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
299+
"mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2)\n",
300+
"jax.set_mesh(mesh)\n",
297301
"\n",
298302
"@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n",
299303
"def f1(x_block):\n",
@@ -494,7 +498,8 @@
494498
"metadata": {},
495499
"outputs": [],
496500
"source": [
497-
"mesh = jax.make_mesh((2,), ('i',))\n",
501+
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
502+
"jax.set_mesh(mesh)\n",
498503
"\n",
499504
"@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n",
500505
"def f(x):\n",
@@ -607,7 +612,8 @@
607612
"metadata": {},
608613
"outputs": [],
609614
"source": [
610-
"mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
615+
"mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2)\n",
616+
"jax.set_mesh(mesh)\n",
611617
"\n",
612618
"@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n",
613619
"def f(x):\n",
@@ -645,7 +651,8 @@
645651
"metadata": {},
646652
"outputs": [],
647653
"source": [
648-
"mesh = jax.make_mesh((2,), ('i',))\n",
654+
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
655+
"jax.set_mesh(mesh)\n",
649656
"\n",
650657
"x = jnp.arange(6.)\n",
651658
"try:\n",
@@ -741,7 +748,8 @@
741748
"metadata": {},
742749
"outputs": [],
743750
"source": [
744-
"mesh = jax.make_mesh((2,), ('i',))\n",
751+
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
752+
"jax.set_mesh(mesh)\n",
745753
"\n",
746754
"@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n",
747755
"def f(x, y):\n",
@@ -776,7 +784,8 @@
776784
"metadata": {},
777785
"outputs": [],
778786
"source": [
779-
"mesh = jax.make_mesh((2,), ('i',))\n",
787+
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
788+
"jax.set_mesh(mesh)\n",
780789
"\n",
781790
"@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n",
782791
"def f(x, y):\n",
@@ -946,6 +955,7 @@
946955
"outputs": [],
947956
"source": [
948957
"mesh1d = Mesh(jax.devices()[:4], ('i',))\n",
958+
"jax.set_mesh(mesh1d)\n",
949959
"\n",
950960
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n",
951961
"def f1(x_block):\n",
@@ -1002,6 +1012,7 @@
10021012
"outputs": [],
10031013
"source": [
10041014
"mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n",
1015+
"jax.set_mesh(mesh2d)\n",
10051016
"\n",
10061017
"@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n",
10071018
"def f2(x_block):\n",
@@ -1031,7 +1042,9 @@
10311042
"cell_type": "code",
10321043
"execution_count": null,
10331044
"id": "2919056c",
1034-
"metadata": {},
1045+
"metadata": {
1046+
"lines_to_end_of_cell_marker": 2
1047+
},
10351048
"outputs": [],
10361049
"source": [
10371050
"@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n",
@@ -1071,6 +1084,8 @@
10711084
"metadata": {},
10721085
"outputs": [],
10731086
"source": [
1087+
"jax.set_mesh(mesh1d)\n",
1088+
"\n",
10741089
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
10751090
"def f4(x_block):\n",
10761091
" print('BEFORE:\\n', x_block)\n",
@@ -1153,7 +1168,7 @@
11531168
"metadata": {},
11541169
"outputs": [],
11551170
"source": [
1156-
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
1171+
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
11571172
"def f6(x_block):\n",
11581173
" print('BEFORE:\\n', x_block)\n",
11591174
" y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n",
@@ -1229,7 +1244,7 @@
12291244
"metadata": {},
12301245
"outputs": [],
12311246
"source": [
1232-
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
1247+
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
12331248
"def f7(x_block):\n",
12341249
" sz = jax.lax.axis_size('i')\n",
12351250
" print('BEFORE:\\n', x_block)\n",
@@ -1307,7 +1322,7 @@
13071322
"metadata": {},
13081323
"outputs": [],
13091324
"source": [
1310-
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
1325+
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
13111326
"def f8(x_block):\n",
13121327
" print('BEFORE:\\n', x_block)\n",
13131328
" y_block = psum_scatter(x_block, 'i', tiled=True)\n",
@@ -1438,6 +1453,7 @@
14381453
"outputs": [],
14391454
"source": [
14401455
"mesh = Mesh(jax.devices()[:4], ('i',))\n",
1456+
"jax.set_mesh(mesh)\n",
14411457
"\n",
14421458
"def device_put(x, pspec):\n",
14431459
" return jax.device_put(x, NamedSharding(mesh, pspec))"
@@ -1884,7 +1900,8 @@
18841900
"source": [
18851901
"from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n",
18861902
"\n",
1887-
"mesh = jax.make_mesh((8,), ('batch',))\n",
1903+
"mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,))\n",
1904+
"jax.set_mesh(mesh)\n",
18881905
"\n",
18891906
"# replicate initial params on all devices, shard data batch over devices\n",
18901907
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
@@ -1982,6 +1999,7 @@
19821999
"source": [
19832000
"# shard data batch *and params* over devices\n",
19842001
"mesh = Mesh(devices, ('batch',))\n",
2002+
"jax.set_mesh(mesh)\n",
19852003
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
19862004
"params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n",
19872005
"\n",
@@ -2055,7 +2073,8 @@
20552073
"metadata": {},
20562074
"outputs": [],
20572075
"source": [
2058-
"mesh = jax.make_mesh((8,), ('feats',))\n",
2076+
"mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,))\n",
2077+
"jax.set_mesh(mesh)\n",
20592078
"\n",
20602079
"batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n",
20612080
"params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n",
@@ -2096,7 +2115,8 @@
20962115
"metadata": {},
20972116
"outputs": [],
20982117
"source": [
2099-
"mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n",
2118+
"mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2)\n",
2119+
"jax.set_mesh(mesh)\n",
21002120
"\n",
21012121
"batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n",
21022122
"params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n",
@@ -2285,7 +2305,17 @@
22852305
"metadata": {},
22862306
"outputs": [],
22872307
"source": [
2288-
"print(jax.jit(loss)(params, batch))\n",
2308+
"print(jax.jit(loss)(params, batch))"
2309+
]
2310+
},
2311+
{
2312+
"cell_type": "code",
2313+
"execution_count": null,
2314+
"id": "9ff83661",
2315+
"metadata": {},
2316+
"outputs": [],
2317+
"source": [
2318+
"jax.set_mesh(mesh)\n",
22892319
"print(jax.jit(loss_pp)(params_, batch_))"
22902320
]
22912321
},
@@ -2303,8 +2333,7 @@
23032333
"metadata": {
23042334
"jupytext": {
23052335
"cell_metadata_filter": "-all",
2306-
"formats": "ipynb,md:myst",
2307-
"main_language": "python"
2336+
"formats": "ipynb,md:myst,py"
23082337
},
23092338
"kernelspec": {
23102339
"display_name": "Python 3",

docs/notebooks/shard_map.md

Lines changed: 43 additions & 22 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)