In a multi-host environment, the devices connected to each host must hold a subslice of a single contiguous large slice of the data. In JAX SPMD, there is no direct communication between hosts; hosts only talk to each other via collective communication between devices. As a result, users need to handle distributed data loading on the hosts.
In this example, the input array of shape (32, 2) is manually split by the user into quarters of shape (8, 2) along the `x` axis, and each quarter is assigned to one host.
  input_data = np.arange(48,64).reshape(8,2)
```
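
The hard-coded quarters above can also be derived programmatically. Below is a minimal sketch, assuming 4 hosts and using `jax.process_index()` to identify the current host; `global_data` is a stand-in name for the full (32, 2) array:

```python
import numpy as np
import jax

# Sketch: compute this host's (8, 2) quarter of the global (32, 2) array.
global_data = np.arange(64).reshape(32, 2)
rows_per_host = global_data.shape[0] // 4              # 4 hosts -> 8 rows each
start = jax.process_index() * rows_per_host            # this host's row offset
input_data = global_data[start:start + rows_per_host]  # (8, 2) local chunk
```

On host 3 this yields `np.arange(48,64).reshape(8,2)`, matching the snippet above.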
Pjit always assumes that the input is the local data chunk of a global array. If the local chunk is to be sharded over multiple local devices and is not already partitioned as expected, pjit will put the right slices on the right **local devices** for you. Once all of the local chunks are on the devices of all the hosts, XLA will run the computation.
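For concreteness, here is a minimal sketch of such a call, written against the legacy `pjit` API with `in_axis_resources`/`out_axis_resources` used in this tutorial; the `mesh` (the global logical mesh) and `input_data` (this host's (8, 2) chunk) are assumed from the surrounding example:

```python
from jax.experimental.pjit import pjit
from jax.experimental import PartitionSpec as P

# Sketch: each host passes its unsharded (8, 2) local chunk; pjit places the
# right slices on the right local devices before XLA runs the computation.
f = pjit(lambda x: x * 2,
         in_axis_resources=P(('x', 'y'),),
         out_axis_resources=P('x', 'y'))

with mesh:  # the global Mesh assumed to be defined earlier in the example
    out = f(input_data)
```
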
XLA operates on the global data, so if `in_axis_resources` is different from `out_axis_resources`, XLA will perform cross-host data redistribution. In other words, global communication does not happen in preparation for the launch of the XLA executable that pjit represents, but only inside the XLA executable itself.
One way to perform cross-host data redistribution is to use an identity pjit with `in_axis_resources` different from `out_axis_resources`: XLA will do the global data reordering for you via pjit.
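A minimal sketch of that resharding trick, under the same assumptions as the sketch above:

```python
from jax.experimental.pjit import pjit
from jax.experimental import PartitionSpec as P

# Sketch: an identity pjit used purely for resharding. The computation is a
# no-op, but because the input and output specs differ, the compiled XLA
# executable performs the cross-host data reordering.
reshard = pjit(lambda x: x,
               in_axis_resources=P(('x', 'y'),),
               out_axis_resources=P('x', 'y'))

with mesh:  # the global Mesh assumed to be defined earlier in the example
    resharded = reshard(input_data)
```
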
+++ {"id": "gTS_bgtkdch1"}
### in_axis_resources & out_axis_resources
- `in_axis_resources`: `PartitionSpec(('x', 'y'),)`. This partitions the first dimension of the input data over both the `x` and `y` axes. Since input argument dimensions partitioned over multi-process mesh axes must be of a size equal to the corresponding local mesh axis size, pjit sends each host's (8, 2) chunk to its devices based on `in_axis_resources`. Since each host has a logical mesh of size (4, 2) within the entire logical mesh, each device gets a (1, 2) slice.
- `out_axis_resources`: `PartitionSpec('x', 'y')`. This specifies that the two dimensions of the output data are sharded over `x` and `y` respectively, so each device gets a (2, 1) slice.
**Note**: `in_axis_resources` and `out_axis_resources` are different. Here, `in_axis_resources` shards the input data's first dimension over both `x` and `y`, whereas `out_axis_resources` shards it only over `x`.
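
As a quick arithmetic check of the slice shapes above (plain Python; the 16x2 global mesh shape is an assumption derived from 4 hosts each holding a (4, 2) local mesh):

```python
global_shape = (32, 2)
mesh_axes = {'x': 16, 'y': 2}  # assumed global logical mesh: 4 hosts x (4, 2)

# in_axis_resources=P(('x', 'y'),): dim 0 is split over all x * y devices.
in_slice = (global_shape[0] // (mesh_axes['x'] * mesh_axes['y']),
            global_shape[1])
assert in_slice == (1, 2)   # per-device input slice

# out_axis_resources=P('x', 'y'): dim 0 split over x, dim 1 over y.
out_slice = (global_shape[0] // mesh_axes['x'],
             global_shape[1] // mesh_axes['y'])
assert out_slice == (2, 1)  # per-device output slice
```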