
ROCM-Aware MPI requires AMDGPU.synchronize() #2591

@Alexander-Barth

Description

When using Distributed Data Parallel (DDP) training with two AMD GPUs communicating via ROCm-aware MPI, AMDGPU.synchronize() has to be called at several points; otherwise the optimizer state becomes inconsistent across ranks or the averaged gradients are wrong.
This is a follow-up to this discussion:

https://discourse.julialang.org/t/distributed-data-parallel-training-with-2-gpus-fails-with-flux-jl-on-amd-gpus/125993/6

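The serial version of the program below (run with SERIAL=true) works as expected. The distributed version only works with the AMDGPU.synchronize() calls marked `# necessary`: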

```julia
serial = get(ENV,"SERIAL","false") == "true"

import AMDGPU
using Flux
using Optimisers
using Zygote
using Statistics
using Random
if !serial
    import MPI
end

Random.seed!(42)

# print prefixed with the rank; the barrier keeps the ranks in step
function pprintln(backend,args...)
    MPI.Barrier(backend.comm)
    print("rank ",DistributedUtils.local_rank(backend),": ")
    println(args...)
end
pprintln(::Nothing,args...) = println(args...)

AMDGPU.allowscalar(false)

@show Flux.MPI_ROCM_AWARE

if !serial
    const backend_type = MPIBackend
    DistributedUtils.initialize(backend_type)
    backend = DistributedUtils.get_distributed_backend(backend_type)
else
    backend = nothing
end

T = Float32
device = gpu

x = randn(T,256,256,32,16*2) |> device

channels = 2 .^ vcat(5:7,6:-1:5)

model = Chain(
    [Conv((3,3),channels[i] => channels[i+1],pad=SamePad(),selu) for i in 1:length(channels)-1]...
)

losses = T[]
model = model |> device

loss(x,y) = mean((x-y).^2)

opt_s = Optimisers.Adam(1f-4)
#opt_s = Optimisers.Descent(0.01f0) # ok

if !serial
    data = DistributedUtils.DistributedDataContainer(backend, x)
    model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
    opt = DistributedUtils.DistributedOptimizer(backend, opt_s)
else
    data = x
    opt = opt_s
end

opt_state = Optimisers.setup(opt, model)

# necessary: wait for pending GPU work before MPI broadcasts the optimizer state
AMDGPU.synchronize()

if !serial
    opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
end

dl = Flux.DataLoader(data,batchsize=16)


for i = 1:1000
    global model, opt_state
    for (j,x_batch) in enumerate(dl)
        val, grads = Flux.withgradient(model) do m
            loss(x_batch,m(x_batch))
        end

        # necessary: gradients must be computed on the GPU before the
        # DistributedOptimizer averages them over MPI in Optimisers.update
        AMDGPU.synchronize()

        push!(losses, val)
        opt_state, model = Optimisers.update(opt_state, model, grads[1])
#        pprintln(backend,"update ",i," ",model.layers[1].weight[1:1])
    end
end

pprintln(backend,"losses ",losses)
```

The output of this program without the AMDGPU.synchronize() calls (and with the pprintln line in the training loop enabled) is:

```
Flux.MPI_ROCM_AWARE = true
Flux.MPI_ROCM_AWARE = true
rank 1: update 1 Float32[0.005779413]
rank 0: update 1 Float32[0.005779413]
rank 1: update 2 Float32[NaN]
rank 0: update 2 Float32[0.0056868508]
rank 1: update 3 Float32[NaN]
rank 0: update 3 Float32[0.005596291]
rank 1: update 4 Float32[NaN]
rank 0: update 4 Float32[0.0055066617]
rank 1: lossesFloat32[2.0405662, NaN, NaN, NaN]
rank 0: lossesFloat32[2.040605, 2.0056882, 1.9737886, 1.9429569]
```

My environment: Julia 1.11.2, with the following packages:

```
⌃ [21141c5a] AMDGPU v1.2.2
  [0a1fb500] BlockDiagonals v0.1.42
  [052768ef] CUDA v5.6.1
⌃ [13f3f980] CairoMakie v0.12.18
⌃ [b0b7db55] ComponentArrays v0.15.22
  [efc8151c] DIVAnd v2.7.12
  [cf87cc76] DataAssim v0.4.1
  [8bb1440f] DelimitedFiles v1.9.1
  [4e2335b7] FlowMatching v0.1.0 `..`
  [587475ba] Flux v0.16.3 `~/.julia/dev/Flux`
  [db073c08] GeoMakie v0.7.10
  [033835bb] JLD2 v0.5.11
  [f1d291b0] MLUtils v0.4.7
  [da04e1cc] MPI v0.20.22
  [3da0fdf6] MPIPreferences v0.1.11
  [85f8d34a] NCDatasets v0.14.6
  [3bd65402] Optimisers v0.4.4
  [21216c6a] Preferences v1.4.3
  [10745b16] Statistics v1.11.1
⌃ [e88e6eb3] Zygote v0.7.3 or Zygote v0.7.4
  [02a925ec] cuDNN v1.4.1
  [ade2ca70] Dates v1.11.0
  [de0858da] Printf v1.11.0
  [8dfed614] Test v1.11.0
```

Using only MPI and AMDGPU, the following example shows that without AMDGPU.synchronize() the data that is sent is wrong:

```julia
using MPI
using AMDGPU
MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
# select one GPU per rank on each node
comm_l = MPI.Comm_split_type(comm, MPI.COMM_TYPE_SHARED, rank)
rank_l = MPI.Comm_rank(comm_l)
device = AMDGPU.device_id!(rank_l+1)
gpu_id = AMDGPU.device_id(AMDGPU.device())
# ring: send to the next rank, receive from the previous one
nprocs = MPI.Comm_size(comm)  # number of ranks (avoids shadowing Base.size)
dst  = mod(rank+1, nprocs)
src  = mod(rank-1, nprocs)
println("rank=$rank rank_loc=$rank_l (gpu_id=$gpu_id - $device), size=$nprocs, dst=$dst, src=$src")
N = 4
send_mesg = ROCArray{Float64}(undef, N)
recv_mesg = ROCArray{Float64}(undef, N)
# these GPU kernels are launched asynchronously
fill!(send_mesg, Float64(rank))
send_mesg .+= 1

AMDGPU.synchronize() # necessary: send_mesg must be filled before MPI reads it

MPI.Sendrecv!(send_mesg, dst, 0, recv_mesg, src, 0, comm)

if rank == 0
    println("got ",Array(recv_mesg))
    println("correct: ",all(Array(recv_mesg) .== (src+1)))
end
```

Without AMDGPU.synchronize(), rank 0 receives the correct message in only 2 out of 20 runs. With AMDGPU.synchronize(), all received messages are correct.
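The CPU-only sketch below illustrates why the explicit synchronization matters. It is an analogy, not AMDGPU or MPI code: a spawned Julia task stands in for an asynchronously launched GPU kernel that fills a buffer, `copy` stands in for ROCm-aware MPI reading that buffer directly, and `wait(t)` stands in for `AMDGPU.synchronize()`. The names `produce!` and `exchange` are hypothetical.

```julia
# CPU-only analogy (hypothetical names; not AMDGPU/MPI code).

# Stand-in for a GPU kernel: fills `buf` with `value`, slowly enough that a
# reader that does not wait can observe the old contents.
function produce!(buf, value)
    for i in eachindex(buf)
        sleep(0.001)
        buf[i] = value
    end
    return buf
end

# Stand-in for "launch a kernel, then hand the buffer to MPI".
function exchange(value; synchronize::Bool)
    buf = zeros(Float64, 4)                  # old contents: zeros
    t = Threads.@spawn produce!(buf, value)  # returns immediately, like a kernel launch
    if synchronize
        wait(t)                              # analogue of AMDGPU.synchronize()
    end
    received = copy(buf)                     # analogue of MPI reading the buffer
    wait(t)                                  # let the task finish before returning
    return received
end

println("with synchronization:    ", exchange(2.0; synchronize=true))
println("without synchronization: ", exchange(2.0; synchronize=false))  # may contain stale zeros
```

With synchronization the copy always sees the fully written buffer; without it, the reader can run before the producer has written anything, which mirrors rank 0 receiving stale data in the MPI example.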

Thanks to @pxl-th for suggesting that this is a synchronization issue.
