Skip to content

Commit ab11905

Browse files
Merge pull request #33191 from jax-ml:shmap
PiperOrigin-RevId: 829631170
2 parents a6a85c8 + ace7797 commit ab11905

File tree

2 files changed

+93
-45
lines changed

2 files changed

+93
-45
lines changed

docs/notebooks/shard_map.ipynb

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@
5555
"import jax\n",
5656
"import jax.numpy as jnp\n",
5757
"\n",
58-
"from jax.sharding import Mesh, PartitionSpec as P"
58+
"from jax.sharding import Mesh, PartitionSpec as P\n",
59+
"Explicit = jax.sharding.AxisType.Explicit\n",
60+
"Auto = jax.sharding.AxisType.Auto"
5961
]
6062
},
6163
{
@@ -65,12 +67,13 @@
6567
"metadata": {},
6668
"outputs": [],
6769
"source": [
68-
"mesh = jax.make_mesh((4, 2), ('x', 'y'))\n",
70+
"mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)\n",
71+
"jax.set_mesh(mesh)\n",
6972
"\n",
7073
"a = jnp.arange( 8 * 16.).reshape(8, 16)\n",
7174
"b = jnp.arange(16 * 4.).reshape(16, 4)\n",
7275
"\n",
73-
"@jax.shard_map(mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n",
76+
"@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),\n",
7477
" out_specs=P('x', None))\n",
7578
"def matmul_basic(a_block, b_block):\n",
7679
" # a_block: f32[2, 8]\n",
@@ -148,16 +151,15 @@
148151
"source": [
149152
"from jax.sharding import NamedSharding\n",
150153
"\n",
151-
"a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n",
152-
"b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n",
154+
"a = jax.device_put(a, P('x', 'y'))\n",
155+
"b = jax.device_put(b, P('y', None))\n",
153156
"\n",
154157
"@jax.jit\n",
155158
"def matmul_reference(a, b):\n",
156-
" c = jnp.dot(a, b)\n",
157-
" return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n",
159+
" return jnp.dot(a, b, out_sharding=P('x', None))\n",
158160
"\n",
159161
"c_ref = matmul_reference(a, b)\n",
160-
"allclose(c_ref, jnp.dot(a, b))"
162+
"allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))"
161163
]
162164
},
163165
{
@@ -245,7 +247,8 @@
245247
"source": [
246248
"import numpy as np\n",
247249
"devices = np.array(jax.devices()[:4])\n",
248-
"mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n",
250+
"mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4\n",
251+
"jax.set_mesh(mesh)\n",
249252
"\n",
250253
"def check_shmap(f, y):\n",
251254
" ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n",
@@ -293,7 +296,8 @@
293296
"metadata": {},
294297
"outputs": [],
295298
"source": [
296-
"mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
299+
"mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2)\n",
300+
"jax.set_mesh(mesh)\n",
297301
"\n",
298302
"@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n",
299303
"def f1(x_block):\n",
@@ -494,7 +498,8 @@
494498
"metadata": {},
495499
"outputs": [],
496500
"source": [
497-
"mesh = jax.make_mesh((2,), ('i',))\n",
501+
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
502+
"jax.set_mesh(mesh)\n",
498503
"\n",
499504
"@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n",
500505
"def f(x):\n",
@@ -607,7 +612,8 @@
607612
"metadata": {},
608613
"outputs": [],
609614
"source": [
610-
"mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
615+
"mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2)\n",
616+
"jax.set_mesh(mesh)\n",
611617
"\n",
612618
"@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n",
613619
"def f(x):\n",
@@ -645,7 +651,8 @@
645651
"metadata": {},
646652
"outputs": [],
647653
"source": [
648-
"mesh = jax.make_mesh((2,), ('i',))\n",
654+
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
655+
"jax.set_mesh(mesh)\n",
649656
"\n",
650657
"x = jnp.arange(6.)\n",
651658
"try:\n",
@@ -741,7 +748,8 @@
741748
"metadata": {},
742749
"outputs": [],
743750
"source": [
744-
"mesh = jax.make_mesh((2,), ('i',))\n",
751+
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
752+
"jax.set_mesh(mesh)\n",
745753
"\n",
746754
"@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n",
747755
"def f(x, y):\n",
@@ -776,7 +784,8 @@
776784
"metadata": {},
777785
"outputs": [],
778786
"source": [
779-
"mesh = jax.make_mesh((2,), ('i',))\n",
787+
"mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n",
788+
"jax.set_mesh(mesh)\n",
780789
"\n",
781790
"@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n",
782791
"def f(x, y):\n",
@@ -946,6 +955,7 @@
946955
"outputs": [],
947956
"source": [
948957
"mesh1d = Mesh(jax.devices()[:4], ('i',))\n",
958+
"jax.set_mesh(mesh1d)\n",
949959
"\n",
950960
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n",
951961
"def f1(x_block):\n",
@@ -1002,6 +1012,7 @@
10021012
"outputs": [],
10031013
"source": [
10041014
"mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n",
1015+
"jax.set_mesh(mesh2d)\n",
10051016
"\n",
10061017
"@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n",
10071018
"def f2(x_block):\n",
@@ -1071,6 +1082,8 @@
10711082
"metadata": {},
10721083
"outputs": [],
10731084
"source": [
1085+
"jax.set_mesh(mesh1d)\n",
1086+
"\n",
10741087
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
10751088
"def f4(x_block):\n",
10761089
" print('BEFORE:\\n', x_block)\n",
@@ -1153,7 +1166,7 @@
11531166
"metadata": {},
11541167
"outputs": [],
11551168
"source": [
1156-
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
1169+
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
11571170
"def f6(x_block):\n",
11581171
" print('BEFORE:\\n', x_block)\n",
11591172
" y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n",
@@ -1229,7 +1242,7 @@
12291242
"metadata": {},
12301243
"outputs": [],
12311244
"source": [
1232-
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
1245+
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
12331246
"def f7(x_block):\n",
12341247
" sz = jax.lax.axis_size('i')\n",
12351248
" print('BEFORE:\\n', x_block)\n",
@@ -1307,7 +1320,7 @@
13071320
"metadata": {},
13081321
"outputs": [],
13091322
"source": [
1310-
"@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
1323+
"@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n",
13111324
"def f8(x_block):\n",
13121325
" print('BEFORE:\\n', x_block)\n",
13131326
" y_block = psum_scatter(x_block, 'i', tiled=True)\n",
@@ -1438,6 +1451,7 @@
14381451
"outputs": [],
14391452
"source": [
14401453
"mesh = Mesh(jax.devices()[:4], ('i',))\n",
1454+
"jax.set_mesh(mesh)\n",
14411455
"\n",
14421456
"def device_put(x, pspec):\n",
14431457
" return jax.device_put(x, NamedSharding(mesh, pspec))"
@@ -1884,7 +1898,8 @@
18841898
"source": [
18851899
"from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n",
18861900
"\n",
1887-
"mesh = jax.make_mesh((8,), ('batch',))\n",
1901+
"mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,))\n",
1902+
"jax.set_mesh(mesh)\n",
18881903
"\n",
18891904
"# replicate initial params on all devices, shard data batch over devices\n",
18901905
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
@@ -1982,6 +1997,7 @@
19821997
"source": [
19831998
"# shard data batch *and params* over devices\n",
19841999
"mesh = Mesh(devices, ('batch',))\n",
2000+
"jax.set_mesh(mesh)\n",
19852001
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
19862002
"params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n",
19872003
"\n",
@@ -2055,7 +2071,8 @@
20552071
"metadata": {},
20562072
"outputs": [],
20572073
"source": [
2058-
"mesh = jax.make_mesh((8,), ('feats',))\n",
2074+
"mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,))\n",
2075+
"jax.set_mesh(mesh)\n",
20592076
"\n",
20602077
"batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n",
20612078
"params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n",
@@ -2096,7 +2113,8 @@
20962113
"metadata": {},
20972114
"outputs": [],
20982115
"source": [
2099-
"mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n",
2116+
"mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2)\n",
2117+
"jax.set_mesh(mesh)\n",
21002118
"\n",
21012119
"batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n",
21022120
"params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n",
@@ -2285,7 +2303,17 @@
22852303
"metadata": {},
22862304
"outputs": [],
22872305
"source": [
2288-
"print(jax.jit(loss)(params, batch))\n",
2306+
"print(jax.jit(loss)(params, batch))"
2307+
]
2308+
},
2309+
{
2310+
"cell_type": "code",
2311+
"execution_count": null,
2312+
"id": "9ff83661",
2313+
"metadata": {},
2314+
"outputs": [],
2315+
"source": [
2316+
"jax.set_mesh(mesh)\n",
22892317
"print(jax.jit(loss_pp)(params_, batch_))"
22902318
]
22912319
},
@@ -2303,8 +2331,7 @@
23032331
"metadata": {
23042332
"jupytext": {
23052333
"cell_metadata_filter": "-all",
2306-
"formats": "ipynb,md:myst",
2307-
"main_language": "python"
2334+
"formats": "ipynb,md:myst"
23082335
},
23092336
"kernelspec": {
23102337
"display_name": "Python 3",

docs/notebooks/shard_map.md

Lines changed: 42 additions & 21 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)