|
55 | 55 | "import jax\n", |
56 | 56 | "import jax.numpy as jnp\n", |
57 | 57 | "\n", |
58 | | - "from jax.sharding import Mesh, PartitionSpec as P" |
| 58 | + "from jax.sharding import Mesh, PartitionSpec as P\n", |
| 59 | + "Explicit = jax.sharding.AxisType.Explicit\n", |
| 60 | + "Auto = jax.sharding.AxisType.Auto" |
59 | 61 | ] |
60 | 62 | }, |
61 | 63 | { |
|
65 | 67 | "metadata": {}, |
66 | 68 | "outputs": [], |
67 | 69 | "source": [ |
68 | | - "mesh = jax.make_mesh((4, 2), ('x', 'y'))\n", |
| 70 | + "mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(Explicit,) * 2)\n", |
| 71 | + "jax.set_mesh(mesh)\n", |
69 | 72 | "\n", |
70 | 73 | "a = jnp.arange( 8 * 16.).reshape(8, 16)\n", |
71 | 74 | "b = jnp.arange(16 * 4.).reshape(16, 4)\n", |
72 | 75 | "\n", |
73 | | - "@jax.shard_map(mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n", |
| 76 | + "@jax.shard_map(in_specs=(P('x', 'y'), P('y', None)),\n", |
74 | 77 | " out_specs=P('x', None))\n", |
75 | 78 | "def matmul_basic(a_block, b_block):\n", |
76 | 79 | " # a_block: f32[2, 8]\n", |
|
148 | 151 | "source": [ |
149 | 152 | "from jax.sharding import NamedSharding\n", |
150 | 153 | "\n", |
151 | | - "a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n", |
152 | | - "b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n", |
| 154 | + "a = jax.device_put(a, P('x', 'y'))\n", |
| 155 | + "b = jax.device_put(b, P('y', None))\n", |
153 | 156 | "\n", |
154 | 157 | "@jax.jit\n", |
155 | 158 | "def matmul_reference(a, b):\n", |
156 | | - " c = jnp.dot(a, b)\n", |
157 | | - " return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n", |
| 159 | + " return jnp.dot(a, b, out_sharding=P('x', None))\n", |
158 | 160 | "\n", |
159 | 161 | "c_ref = matmul_reference(a, b)\n", |
160 | | - "allclose(c_ref, jnp.dot(a, b))" |
| 162 | + "allclose(c_ref, jnp.dot(a, b, out_sharding=P('x', None)))" |
161 | 163 | ] |
162 | 164 | }, |
163 | 165 | { |
|
245 | 247 | "source": [ |
246 | 248 | "import numpy as np\n", |
247 | 249 | "devices = np.array(jax.devices()[:4])\n", |
248 | | - "mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n", |
| 250 | + "mesh = Mesh(devices, ('i',), axis_types=(Explicit,)) # mesh.shape['i'] = 4\n", |
| 251 | + "jax.set_mesh(mesh)\n", |
249 | 252 | "\n", |
250 | 253 | "def check_shmap(f, y):\n", |
251 | 254 | " ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", |
|
293 | 296 | "metadata": {}, |
294 | 297 | "outputs": [], |
295 | 298 | "source": [ |
296 | | - "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", |
| 299 | + "mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Auto,) * 2)\n", |
| 300 | + "jax.set_mesh(mesh)\n", |
297 | 301 | "\n", |
298 | 302 | "@jax.shard_map(mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", |
299 | 303 | "def f1(x_block):\n", |
|
494 | 498 | "metadata": {}, |
495 | 499 | "outputs": [], |
496 | 500 | "source": [ |
497 | | - "mesh = jax.make_mesh((2,), ('i',))\n", |
| 501 | + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", |
| 502 | + "jax.set_mesh(mesh)\n", |
498 | 503 | "\n", |
499 | 504 | "@jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", |
500 | 505 | "def f(x):\n", |
|
607 | 612 | "metadata": {}, |
608 | 613 | "outputs": [], |
609 | 614 | "source": [ |
610 | | - "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", |
| 615 | + "mesh = jax.make_mesh((4, 2), ('i', 'j'), axis_types=(Explicit,) * 2)\n", |
| 616 | + "jax.set_mesh(mesh)\n", |
611 | 617 | "\n", |
612 | 618 | "@jax.shard_map(mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n", |
613 | 619 | "def f(x):\n", |
|
645 | 651 | "metadata": {}, |
646 | 652 | "outputs": [], |
647 | 653 | "source": [ |
648 | | - "mesh = jax.make_mesh((2,), ('i',))\n", |
| 654 | + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", |
| 655 | + "jax.set_mesh(mesh)\n", |
649 | 656 | "\n", |
650 | 657 | "x = jnp.arange(6.)\n", |
651 | 658 | "try:\n", |
|
741 | 748 | "metadata": {}, |
742 | 749 | "outputs": [], |
743 | 750 | "source": [ |
744 | | - "mesh = jax.make_mesh((2,), ('i',))\n", |
| 751 | + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", |
| 752 | + "jax.set_mesh(mesh)\n", |
745 | 753 | "\n", |
746 | 754 | "@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", |
747 | 755 | "def f(x, y):\n", |
|
776 | 784 | "metadata": {}, |
777 | 785 | "outputs": [], |
778 | 786 | "source": [ |
779 | | - "mesh = jax.make_mesh((2,), ('i',))\n", |
| 787 | + "mesh = jax.make_mesh((2,), ('i',), axis_types=(Explicit,))\n", |
| 788 | + "jax.set_mesh(mesh)\n", |
780 | 789 | "\n", |
781 | 790 | "@jax.shard_map(mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", |
782 | 791 | "def f(x, y):\n", |
|
946 | 955 | "outputs": [], |
947 | 956 | "source": [ |
948 | 957 | "mesh1d = Mesh(jax.devices()[:4], ('i',))\n", |
| 958 | + "jax.set_mesh(mesh1d)\n", |
949 | 959 | "\n", |
950 | 960 | "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", |
951 | 961 | "def f1(x_block):\n", |
|
1002 | 1012 | "outputs": [], |
1003 | 1013 | "source": [ |
1004 | 1014 | "mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n", |
| 1015 | + "jax.set_mesh(mesh2d)\n", |
1005 | 1016 | "\n", |
1006 | 1017 | "@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", |
1007 | 1018 | "def f2(x_block):\n", |
|
1031 | 1042 | "cell_type": "code", |
1032 | 1043 | "execution_count": null, |
1033 | 1044 | "id": "2919056c", |
1034 | | - "metadata": {}, |
| 1045 | + "metadata": { |
| 1046 | + "lines_to_end_of_cell_marker": 2 |
| 1047 | + }, |
1035 | 1048 | "outputs": [], |
1036 | 1049 | "source": [ |
1037 | 1050 | "@jax.shard_map(mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n", |
|
1071 | 1084 | "metadata": {}, |
1072 | 1085 | "outputs": [], |
1073 | 1086 | "source": [ |
| 1087 | + "jax.set_mesh(mesh1d)\n", |
| 1088 | + "\n", |
1074 | 1089 | "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", |
1075 | 1090 | "def f4(x_block):\n", |
1076 | 1091 | " print('BEFORE:\\n', x_block)\n", |
|
1153 | 1168 | "metadata": {}, |
1154 | 1169 | "outputs": [], |
1155 | 1170 | "source": [ |
1156 | | - "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", |
| 1171 | + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", |
1157 | 1172 | "def f6(x_block):\n", |
1158 | 1173 | " print('BEFORE:\\n', x_block)\n", |
1159 | 1174 | " y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n", |
|
1229 | 1244 | "metadata": {}, |
1230 | 1245 | "outputs": [], |
1231 | 1246 | "source": [ |
1232 | | - "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", |
| 1247 | + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", |
1233 | 1248 | "def f7(x_block):\n", |
1234 | 1249 | " sz = jax.lax.axis_size('i')\n", |
1235 | 1250 | " print('BEFORE:\\n', x_block)\n", |
|
1307 | 1322 | "metadata": {}, |
1308 | 1323 | "outputs": [], |
1309 | 1324 | "source": [ |
1310 | | - "@jax.shard_map(mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", |
| 1325 | + "@jax.shard_map(in_specs=P('i'), out_specs=P('i'))\n", |
1311 | 1326 | "def f8(x_block):\n", |
1312 | 1327 | " print('BEFORE:\\n', x_block)\n", |
1313 | 1328 | " y_block = psum_scatter(x_block, 'i', tiled=True)\n", |
|
1438 | 1453 | "outputs": [], |
1439 | 1454 | "source": [ |
1440 | 1455 | "mesh = Mesh(jax.devices()[:4], ('i',))\n", |
| 1456 | + "jax.set_mesh(mesh)\n", |
1441 | 1457 | "\n", |
1442 | 1458 | "def device_put(x, pspec):\n", |
1443 | 1459 | " return jax.device_put(x, NamedSharding(mesh, pspec))" |
|
1884 | 1900 | "source": [ |
1885 | 1901 | "from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n", |
1886 | 1902 | "\n", |
1887 | | - "mesh = jax.make_mesh((8,), ('batch',))\n", |
| 1903 | + "mesh = jax.make_mesh((8,), ('batch',), axis_types=(Auto,))\n", |
| 1904 | + "jax.set_mesh(mesh)\n", |
1888 | 1905 | "\n", |
1889 | 1906 | "# replicate initial params on all devices, shard data batch over devices\n", |
1890 | 1907 | "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", |
|
1982 | 1999 | "source": [ |
1983 | 2000 | "# shard data batch *and params* over devices\n", |
1984 | 2001 | "mesh = Mesh(devices, ('batch',))\n", |
| 2002 | + "jax.set_mesh(mesh)\n", |
1985 | 2003 | "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", |
1986 | 2004 | "params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n", |
1987 | 2005 | "\n", |
|
2055 | 2073 | "metadata": {}, |
2056 | 2074 | "outputs": [], |
2057 | 2075 | "source": [ |
2058 | | - "mesh = jax.make_mesh((8,), ('feats',))\n", |
| 2076 | + "mesh = jax.make_mesh((8,), ('feats',), axis_types=(Auto,))\n", |
| 2077 | + "jax.set_mesh(mesh)\n", |
2059 | 2078 | "\n", |
2060 | 2079 | "batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n", |
2061 | 2080 | "params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n", |
|
2096 | 2115 | "metadata": {}, |
2097 | 2116 | "outputs": [], |
2098 | 2117 | "source": [ |
2099 | | - "mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n", |
| 2118 | + "mesh = jax.make_mesh((4, 2), ('batch', 'feats'), axis_types=(Auto,) * 2)\n", |
| 2119 | + "jax.set_mesh(mesh)\n", |
2100 | 2120 | "\n", |
2101 | 2121 | "batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n", |
2102 | 2122 | "params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n", |
|
2285 | 2305 | "metadata": {}, |
2286 | 2306 | "outputs": [], |
2287 | 2307 | "source": [ |
2288 | | - "print(jax.jit(loss)(params, batch))\n", |
| 2308 | + "print(jax.jit(loss)(params, batch))" |
| 2309 | + ] |
| 2310 | + }, |
| 2311 | + { |
| 2312 | + "cell_type": "code", |
| 2313 | + "execution_count": null, |
| 2314 | + "id": "9ff83661", |
| 2315 | + "metadata": {}, |
| 2316 | + "outputs": [], |
| 2317 | + "source": [ |
| 2318 | + "jax.set_mesh(mesh)\n", |
2289 | 2319 | "print(jax.jit(loss_pp)(params_, batch_))" |
2290 | 2320 | ] |
2291 | 2321 | }, |
|
2303 | 2333 | "metadata": { |
2304 | 2334 | "jupytext": { |
2305 | 2335 | "cell_metadata_filter": "-all", |
2306 | | - "formats": "ipynb,md:myst", |
2307 | | - "main_language": "python" |
| 2336 | + "formats": "ipynb,md:myst,py" |
2308 | 2337 | }, |
2309 | 2338 | "kernelspec": { |
2310 | 2339 | "display_name": "Python 3", |
|
0 commit comments