Commit 821fcaa

yashk2810 authored and jax authors committed
Make the pjit docs clear about who does local and global communication
PiperOrigin-RevId: 405421833
1 parent 0f47712 commit 821fcaa

File tree

1 file changed

+9
-2
lines changed

docs/jax-101/08-pjit.md

Lines changed: 9 additions & 2 deletions
@@ -458,7 +458,7 @@ mesh = maps.Mesh(devices, ('x', 'y'))
458458

459459
### Input data
460460

461-
In a multi-host environment, all the devices connected to one host have to contain a subslice of a single continuous large slice of the data. In JAX SPMD, there are no direct communications between hosts, so hosts only talk to each other via collective communication between devices. As a result, users need to handle distributed data loading on hosts.
461+
In a multi-host environment, all the devices connected to one host have to contain a subslice of a single continuous large slice of the data. In JAX SPMD, there are no direct communications between hosts, so hosts only talk to each other via collective communication between devices. As a result, users need to handle distributed data loading on hosts.
462462

463463
In this example, the input array of size (32,2) is manually split by the user into quarters of size (8,2) along the `x` axis, and one quarter is assigned to each host.
464464

@@ -473,11 +473,18 @@ else:
473473
input_data = np.arange(48,64).reshape(8,2)
474474
```
475475

476+
Pjit always assumes that the input is the local data chunk of a global array. If the local chunk is to be sharded over multiple local devices and is not partitioned as expected, pjit will put the right slices on the right **local devices** for you. Once all of the local chunks are on the devices on all the
477+
hosts, then XLA will run the computation.
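
The bookkeeping behind "local chunk of a global array" can be sketched in plain Python, without JAX. This is only an illustration under stated assumptions: 4 hosts with 8 local devices each, and local devices that receive consecutive rows of the host's chunk (the real placement follows the mesh's device order). The helper names `host_rows`, `device_rows` and `host_values` are hypothetical and not part of pjit.

```python
# Plain-Python sketch of the row bookkeeping; no JAX is needed to run it.
GLOBAL_ROWS, COLS = 32, 2   # global input shape (32, 2) from the example
NUM_HOSTS = 4               # assumption: one quarter of the rows per host
LOCAL_DEVICES = 8           # assumption: local mesh of size (4, 2) per host

ROWS_PER_HOST = GLOBAL_ROWS // NUM_HOSTS          # 8 rows per host
ROWS_PER_DEVICE = ROWS_PER_HOST // LOCAL_DEVICES  # 1 row per device


def host_rows(host_id):
    """Global row range [start, stop) covered by one host's local chunk."""
    start = host_id * ROWS_PER_HOST
    return start, start + ROWS_PER_HOST


def device_rows(host_id, local_device):
    """Global row range [start, stop) placed on one local device.

    Assumption: local devices are filled in consecutive row order.
    """
    start = host_rows(host_id)[0] + local_device * ROWS_PER_DEVICE
    return start, start + ROWS_PER_DEVICE


def host_values(host_id):
    """Bounds (lo, hi) such that np.arange(lo, hi).reshape(8, 2) is the chunk."""
    start, stop = host_rows(host_id)
    return start * COLS, stop * COLS


for host_id in range(NUM_HOSTS):
    print(host_id, host_rows(host_id), host_values(host_id))
```

For the last host this gives the bounds 48 and 64, which matches the `np.arange(48,64).reshape(8,2)` chunk in the snippet above.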
478+
479+
XLA operates on the global data, so if `in_axis_resources` differs from `out_axis_resources`, XLA redistributes the data across hosts. Global communication therefore does not happen in preparation for the launch of the XLA executable that pjit represents, but only inside the XLA executable itself.
480+
481+
One way to redistribute data across hosts is to use an identity pjit whose `in_axis_resources` differs from its `out_axis_resources`. XLA then does the global data reordering for you via pjit.
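
To see why an identity pjit with different input and output specs implies data movement, the following plain-Python sketch tracks which device owns which elements before and after. It does not call pjit or XLA. It assumes a global mesh of size (16, 2) (inferred from the (2,1) output slices of a (32, 2) array), that `x` is the major axis when a dimension is partitioned over `('x', 'y')`, and that each host owns 4 consecutive `x` positions. All function names are hypothetical.

```python
# Plain-Python ownership sketch; a device is identified by its mesh
# coordinates (i, j) along the 'x' and 'y' axes.
X, Y = 16, 2   # assumption: global mesh axis sizes for 'x' and 'y'
LOCAL_X = 4    # assumption: each host owns 4 consecutive 'x' positions


def in_elements(i, j):
    """Elements on device (i, j) under PartitionSpec(('x', 'y'),)."""
    row = i * Y + j  # 'x' is the major axis, 'y' the minor one
    return {(row, col) for col in range(2)}


def out_elements(i, j):
    """Elements device (i, j) must hold under PartitionSpec('x', 'y')."""
    return {(row, j) for row in range(2 * i, 2 * i + 2)}


def owner(row, col):
    """Device (i, j) that holds element (row, col) before redistribution."""
    return divmod(row, Y)  # inverse of row == i * Y + j; col is unsharded


def sources(i, j):
    """Devices whose input shards contribute to device (i, j)'s output."""
    return {owner(row, col) for row, col in out_elements(i, j)}


def host_of(i, j):
    """Host index of device (i, j) under the assumed host layout."""
    return i // LOCAL_X


print(sorted(sources(0, 0)))
print(sorted(host_of(i, j) for i, j in sources(0, 0)))
```

Every device already holds exactly one of the two elements it needs and must receive the other one from a different device. Comparing `host_of` for a device and its sources shows whether a given pair of specs needs data from another host; the actual collectives are chosen by XLA.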
482+
476483
+++ {"id": "gTS_bgtkdch1"}
477484

478485
### in_axis_resources & out_axis_resources
479486

480-
- `in_axis_resources`: PartitionSpec(('x', 'y'),). This partitions the first dimension of input data over both `x` and `y` axes. This lets Pjit know that the (32, 2) input data is already split evenly across hosts (done by user). Since input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, pjit sends the (8, 2) on each host to its devices based on `in_axis_resources`. Since each host has a logical mesh of size (4, 2) within the entire logical mesh, each device has a (1, 2) slice.
487+
- `in_axis_resources`: PartitionSpec(('x', 'y'),). This partitions the first dimension of input data over both `x` and `y` axes. Since input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, pjit sends the (8, 2) on each host to its devices based on `in_axis_resources`. Since each host has a logical mesh of size (4, 2) within the entire logical mesh, each device has a (1, 2) slice.
481488
- `out_axis_resources`: PartitionSpec('x', 'y'). It specifies that the two dimensions of output data are sharded over `x` and `y` respectively, so each device gets a (2,1) slice.
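
The per-device shapes quoted in the two bullets above follow from dividing each array dimension by the product of the mesh axes it is partitioned over. Here is a minimal sketch of that arithmetic, assuming a global mesh of size (16, 2) and using plain tuples as stand-ins for `PartitionSpec`. The helper `shard_shape` is hypothetical and not a pjit API.

```python
# Plain-Python sketch of the shard-shape arithmetic; no JAX is needed.
MESH = {'x': 16, 'y': 2}  # assumption: global mesh axis sizes


def shard_shape(global_shape, spec):
    """Per-device shape for an array partitioned according to `spec`.

    `spec` is a tuple standing in for a PartitionSpec: each entry is None,
    one axis name, or a tuple of axis names. Missing trailing entries mean
    the dimension is not partitioned.
    """
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    shape = []
    for dim, axes in zip(global_shape, spec):
        if axes is None:
            axes = ()
        elif isinstance(axes, str):
            axes = (axes,)
        parts = 1
        for axis in axes:
            parts *= MESH[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)


in_spec = (('x', 'y'),)  # stand-in for PartitionSpec(('x', 'y'),)
out_spec = ('x', 'y')    # stand-in for PartitionSpec('x', 'y')

print(shard_shape((32, 2), in_spec))
print(shard_shape((32, 2), out_spec))
```

With these assumptions the input spec yields the (1, 2) slice and the output spec the (2, 1) slice described in the bullets.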
482489

483490
**Note**: `in_axis_resources` and `out_axis_resources` are different. Here, `in_axis_resources` shards the input data's first dimension over both `x` and `y`, whereas `out_axis_resources` shards the output data's first dimension only over `x`.
