Discussion: What is the functionality of `pygrain.ShardOptions`, like `pygrain.ShardByJaxProcess`?
From what I understand, each process loads a different batch. At the end, we need to use `jax.make_array_from_process_local_data` to combine the per-process batches into a single global batch with a global sharding (across multiple devices in multiple processes), especially if the model is also sharded across `global_devices`.
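To make the second step concrete, here is a minimal single-process sketch of assembling a global batch from process-local data with `jax.make_array_from_process_local_data`. The batch shape and values are my own; with one process this just wraps the local array, but the same call is what you would use in a multi-process run.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Simulated per-process batch (shape and values are made up for illustration).
local_batch = np.arange(8.0).reshape(8, 1) + 100.0 * jax.process_index()

# Mesh over all global devices; shard the batch dimension across it.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

# Assemble one global array from each process's local batch.
global_batch = jax.make_array_from_process_local_data(sharding, local_batch)
print(global_batch.shape)  # (8, 1) when running with a single process
```

In a real multi-process run, each process would pass only its own `local_batch`, and the resulting array's global shape grows with the number of processes.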
Correct me if I'm wrong, but with this approach each process still needs to load the entire dataset. This is where I'm confused, because there is another way to handle this: since every process loads all the data, we can shard the array across `global_devices` as long as the batch is consistent across processes.
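For comparison, my understanding of what per-process sharding buys you can be modeled with a simple interleaved index scheme (this is my own simplified sketch, not grain's actual implementation): each shard only ever touches its own subset of record indices, so a process does not read the records owned by other shards.

```python
def shard_indices(num_records, shard_index, shard_count):
    # Shard i keeps every shard_count-th record index starting at i,
    # so the shards partition the dataset without overlap.
    return list(range(shard_index, num_records, shard_count))

# With 2 processes and 10 records, each process touches only 5 records:
print(shard_indices(10, 0, 2))  # [0, 2, 4, 6, 8]
print(shard_indices(10, 1, 2))  # [1, 3, 5, 7, 9]
```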
Even when I want to load different data for each process (for example, when I have multiple files like `data_#number.json`), I don't think we need `pygrain.ShardOptions`. Instead, we can load the local batch and create a global batch using `jax.make_array_from_process_local_data`.
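A small sketch of that per-file alternative, with made-up file names and contents (in a real run, process `p` would simply open its own `data_{p}.json`; here two "processes" are simulated in one script):

```python
import json
import os
import tempfile

# Create hypothetical per-process shard files (names and contents made up).
tmpdir = tempfile.mkdtemp()
for p in range(2):
    with open(os.path.join(tmpdir, f"data_{p}.json"), "w") as f:
        json.dump({"records": list(range(p * 5, p * 5 + 5))}, f)

def load_local(process_index):
    # Each process reads only its own file; in a real JAX program this
    # index would come from jax.process_index().
    with open(os.path.join(tmpdir, f"data_{process_index}.json")) as f:
        return json.load(f)["records"]

print(load_local(0))  # [0, 1, 2, 3, 4]
print(load_local(1))  # [5, 6, 7, 8, 9]
```

The local batches loaded this way would then be combined with `jax.make_array_from_process_local_data` exactly as above.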
I feel like `pygrain.ShardOptions` is inspired by `torch.utils.data.DistributedSampler`, but in the Torch case it makes sense, because there is no well-known API for assembling a globally sharded array from per-process data.
Is there any scenario where using `pygrain.ShardOptions` in the `DataLoader` is the only way?