When using Distributed Data Parallel (DDP) with two AMD GPUs communicating via ROCm-aware MPI, AMDGPU.synchronize()
is necessary at several steps; otherwise the state of the optimizer becomes inconsistent or the averaged gradients are wrong.
This is a follow-up to this discussion:
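For background, my understanding of the failure mode is that AMDGPU.jl launches kernels asynchronously, so a ROCm-aware MPI call can read a device buffer before the kernel writing it has finished. A minimal sketch of the safe pattern (compute_gradients is a hypothetical placeholder for any GPU computation):
using MPI, AMDGPU
MPI.Init()
comm = MPI.COMM_WORLD
grads = compute_gradients()           # hypothetical: enqueues GPU kernels and returns immediately
AMDGPU.synchronize()                  # block until the kernels writing `grads` have finished
MPI.Allreduce!(grads, MPI.SUM, comm)  # MPI now reads a fully written device buffer
grads ./= MPI.Comm_size(comm)         # average the gradients across ranks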
The serial code (using SERIAL=true) works as expected:
serial = get(ENV,"SERIAL","false") == "true"
import AMDGPU
using Flux
using Optimisers
using Zygote
using Statistics
using Random
if !serial
import MPI
end
Random.seed!(42)
function pprintln(backend,args...)
MPI.Barrier(backend.comm)
print("rank ",DistributedUtils.local_rank(backend),": ")
println(args...)
end
pprintln(::Nothing,args...) = println(args...)
AMDGPU.allowscalar(false)
@show Flux.MPI_ROCM_AWARE
if !serial
const backend_type = MPIBackend
DistributedUtils.initialize(backend_type)
backend = DistributedUtils.get_distributed_backend(backend_type)
else
backend = nothing
end
T = Float32
device = gpu
x = randn(T,256,256,32,16*2) |> device
channels = 2 .^ vcat(5:7,6:-1:5) # channel sizes: [32, 64, 128, 64, 32]
model = Chain(
[Conv((3,3),channels[i] => channels[i+1],pad=SamePad(),selu) for i in 1:length(channels)-1]...
)
losses = T[]
model = model |> device
loss(x,y) = mean((x-y).^2)
opt_s = Optimisers.Adam(1f-4)
#opt_s = Optimisers.Descent(0.01f0) # ok
if !serial
# shard the data across ranks, broadcast the model parameters from rank 0,
# and wrap the optimizer so that gradients are averaged across ranks before each update
data = DistributedUtils.DistributedDataContainer(backend, x)
model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
opt = DistributedUtils.DistributedOptimizer(backend, opt_s)
else
data = x
opt = opt_s
end
opt_state = Optimisers.setup(opt, model)
AMDGPU.synchronize() # necessary
if !serial
opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
end
dl = Flux.DataLoader(data,batchsize=16)
for i = 1:1000
global model, opt_state
for (j,x_batch) in enumerate(dl)
val, grads = Flux.withgradient(model) do m
loss(x_batch,m(x_batch))
end
AMDGPU.synchronize() # necessary
push!(losses, val)
opt_state, model = Optimisers.update(opt_state, model, grads[1])
# pprintln(backend,"update ",i," ",model.layers[1].weight[1:1])
end
end
pprintln(backend,"losses ",losses)
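For reference, I launch the script with MPI.jl's mpiexecjl wrapper (ddp_flux.jl is a placeholder name; mpiexecjl is installed via MPI.install_mpiexecjl()):
SERIAL=true julia ddp_flux.jl     # serial baseline
mpiexecjl -n 2 julia ddp_flux.jl  # two ranks, one GPU per rank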
Without the AMDGPU.synchronize() calls, the output of this program (run with four outer iterations and the pprintln call in the loop uncommented) is:
Flux.MPI_ROCM_AWARE = true
Flux.MPI_ROCM_AWARE = true
rank 1: update 1 Float32[0.005779413]
rank 0: update 1 Float32[0.005779413]
rank 1: update 2 Float32[NaN]
rank 0: update 2 Float32[0.0056868508]
rank 1: update 3 Float32[NaN]
rank 0: update 3 Float32[0.005596291]
rank 1: update 4 Float32[NaN]
rank 0: update 4 Float32[0.0055066617]
rank 1: losses Float32[2.0405662, NaN, NaN, NaN]
rank 0: losses Float32[2.040605, 2.0056882, 1.9737886, 1.9429569]
My environment:
julia 1.11.2
⌃ [21141c5a] AMDGPU v1.2.2
[0a1fb500] BlockDiagonals v0.1.42
[052768ef] CUDA v5.6.1
⌃ [13f3f980] CairoMakie v0.12.18
⌃ [b0b7db55] ComponentArrays v0.15.22
[efc8151c] DIVAnd v2.7.12
[cf87cc76] DataAssim v0.4.1
[8bb1440f] DelimitedFiles v1.9.1
[4e2335b7] FlowMatching v0.1.0 `..`
[587475ba] Flux v0.16.3 `~/.julia/dev/Flux`
[db073c08] GeoMakie v0.7.10
[033835bb] JLD2 v0.5.11
[f1d291b0] MLUtils v0.4.7
[da04e1cc] MPI v0.20.22
[3da0fdf6] MPIPreferences v0.1.11
[85f8d34a] NCDatasets v0.14.6
[3bd65402] Optimisers v0.4.4
[21216c6a] Preferences v1.4.3
[10745b16] Statistics v1.11.1
⌃ [e88e6eb3] Zygote v0.7.3 (also tried with v0.7.4)
[02a925ec] cuDNN v1.4.1
[ade2ca70] Dates v1.11.0
[de0858da] Printf v1.11.0
[8dfed614] Test v1.11.0
Using just MPI and AMDGPU, we can see that without AMDGPU.synchronize() the sent message is wrong in this example:
using MPI
using AMDGPU
MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
# select device
comm_l = MPI.Comm_split_type(comm, MPI.COMM_TYPE_SHARED, rank)
rank_l = MPI.Comm_rank(comm_l)
device = AMDGPU.device_id!(rank_l+1)
gpu_id = AMDGPU.device_id(AMDGPU.device())
# ring exchange: send to the next rank, receive from the previous one
size = MPI.Comm_size(comm)
dst = mod(rank+1, size)
src = mod(rank-1, size)
println("rank=$rank rank_loc=$rank_l (gpu_id=$gpu_id - $device), size=$size, dst=$dst, src=$src")
N = 4
send_mesg = ROCArray{Float64}(undef, N)
recv_mesg = ROCArray{Float64}(undef, N)
fill!(send_mesg, Float64(rank))
send_mesg .+= 1
AMDGPU.synchronize() # necessary
MPI.Sendrecv!(send_mesg, dst, 0, recv_mesg, src, 0, comm)
if rank == 0
println("got ",Array(recv_mesg))
println("correct: ",all(Array(recv_mesg) .== (src+1)))
end
Rank zero gets the correct message in only 2 out of 20 tries. With AMDGPU.synchronize(),
all received messages are correct.
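The 2-out-of-20 count comes from simply repeating the exchange; a small harness along these lines (a sketch appended to the script above; ntrials is my own name) makes the race easy to count once the synchronize call is commented out:
ntrials = 20
correct = 0
for trial in 1:ntrials
    global correct
    fill!(send_mesg, Float64(rank))
    send_mesg .+= 1
    AMDGPU.synchronize()  # comment this out to observe corrupted messages
    MPI.Sendrecv!(send_mesg, dst, 0, recv_mesg, src, 0, comm)
    if rank == 0
        correct += all(Array(recv_mesg) .== (src + 1))
    end
end
rank == 0 && println("correct: $correct out of $ntrials")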
Thanks to @pxl-th for suggesting that this is a synchronization issue.