
all_reduce does not apply scale when xr.world_size == 1 #9670

@afzalxo


❓ Questions and Help

Hi, I have noticed that when world_size == 1, all_reduce is a no-op and does not apply scale:

In torch_xla.core.xla_model, inside def all_reduce:

  # No-op if there is only one device
  if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
                                                    bool, False):
    if isinstance(inputs, torch.Tensor):
      return inputs.clone()
    else:
      return inputs
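
In practice this means that on a single device the returned value is just a clone of the input, and the scale argument is silently dropped unless XLA_ALWAYS_ALLREDUCE is set. A minimal sketch of that behavior, assuming a single XLA device (e.g. PJRT_DEVICE=CPU) and XLA_ALWAYS_ALLREDUCE unset:

import torch
import torch_xla.core.xla_model as xm

t = torch.ones(4, device=xm.xla_device())

# With one device and XLA_ALWAYS_ALLREDUCE unset, the scale argument is
# ignored: a clone of the input comes back instead of 0.5 * t.
out = xm.all_reduce(xm.REDUCE_SUM, t, scale=0.5)
print(out)  # prints all ones; the expected 0.5 * t is never applied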

Is this intended behavior? If it is, it makes the use of all_reduce inconsistent between world_size == 1 and world_size > 1. The issue manifests, for example, when logging a running-average loss value:

epoch_loss = xm.all_reduce(xm.REDUCE_SUM, loss_accum, scale=1.0 / ((idx + 1) * world_size))
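
With world_size > 1 this yields the running-average loss across devices, but with world_size == 1 the scale is dropped and epoch_loss is simply loss_accum. A hedged workaround sketch until this is settled: either set XLA_ALWAYS_ALLREDUCE=1 (per the check quoted above), or apply the scale manually on the single-device path. The helper name all_reduce_scaled below is hypothetical, not part of torch_xla, and it assumes XLA_ALWAYS_ALLREDUCE is unset:

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def all_reduce_scaled(reduce_type, tensor, scale=1.0):
  # Hypothetical helper (assumes XLA_ALWAYS_ALLREDUCE is unset): keep the
  # single-device path consistent with the multi-device one by applying
  # `scale` manually when xm.all_reduce skips the reduction entirely.
  out = xm.all_reduce(reduce_type, tensor, scale=scale)
  if xr.world_size() == 1 and scale != 1.0:
    out = out * scale
  return out


# The running-average example then scales correctly on one device as well:
# epoch_loss = all_reduce_scaled(xm.REDUCE_SUM, loss_accum,
#                                scale=1.0 / ((idx + 1) * world_size))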
