Bug description
I've experienced the following inconsistency between GPU and CPU gradient computation for sum(abs, _).
julia> using Zygote, CUDA
julia> rl, cplx = [0.0f0], [0.0f0 + 0.0f0im]
(Float32[0.0], ComplexF32[0.0f0 + 0.0f0im])
julia> l1(x) = sum(abs, x)
l1 (generic function with 1 method)
julia> Zygote.gradient(l1, rl)
(Float32[0.0],)
julia> Zygote.gradient(l1, cplx)
(ComplexF32[0.0f0 + 0.0f0im],)
julia> Zygote.gradient(l1, cu(rl))
(Float32[1.0],)
julia> Zygote.gradient(l1, cu(cplx))
(ComplexF32[NaN32 + NaN32*im],)
The last one is particularly problematic, as it leads to NaN values in the gradient that may be hard to understand in a more complex model.
Slack discussion
On Slack, @mcabbott explained to me the most likely cause for this:
- on GPU, in the backward pass Zygote converts sum(abs, x) to sum(abs.(x)), and the broadcasting part is differentiated via ForwardDiff; ForwardDiff is responsible for the different values of the gradient
julia> using ForwardDiff
julia> abs(ForwardDiff.Dual(0,1))
Dual{Nothing}(0,1)
julia> abs(ForwardDiff.Dual(0,1) + 0im)
Dual{Nothing}(0.0,NaN)
- even though DiffRules has a rule for abs (used for real inputs), for complex inputs the computation passes through hypot, and the DiffRules rule for the derivative of hypot at (0, 0) gives NaN (a quick check of this is sketched right after this list)
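To make that last point concrete, here is a small check; the exact rule expression comes from my reading of the DiffRules source, so treat it as an assumption:
using DiffRules, ForwardDiff

# The registered partials of hypot(x, y), which ForwardDiff's Dual methods reuse;
# as far as I can tell they are (x / hypot(x, y), y / hypot(x, y)):
DiffRules.diffrule(:Base, :hypot, :x, :y)

# Evaluating the x-partial at the origin is 0/0:
0.0 / hypot(0.0, 0.0)                            # NaN

# So differentiating through hypot at (0, 0) produces NaN partials,
# while differentiating the real abs gives 1 at 0, matching the two GPU results above:
ForwardDiff.derivative(x -> hypot(x, 0.0), 0.0)  # NaN
ForwardDiff.derivative(abs, 0.0)                 # 1.0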
Not sure what the best fix is here. If DiffRules is open to it, maybe the easiest is to fix their hypot derivative rule?
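For what it's worth, here is a rough sketch of the kind of guarded rule I mean. This isn't actual DiffRules code, and returning 0 at the origin is my assumption (chosen to match the subgradient Zygote already returns on the CPU):
# Hypothetical NaN-free partial of hypot(x, y) with respect to x.
# The existing rule is (effectively) x / hypot(x, y), which is 0/0 at the origin;
# guarding that case returns a (sub)gradient of 0 there instead of NaN.
function hypot_partial_x(x, y)
    h = hypot(x, y)
    return iszero(h) ? zero(h) : x / h
end

# e.g. hypot_partial_x(0.0, 0.0) == 0.0 instead of NaN,
#      hypot_partial_x(3.0, 4.0) == 0.6 as before.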
Version info
I'm on Julia 1.10.5, in a fresh environment with:
(jl_FHvUua) pkg> st
Status `/tmp/jl_FHvUua/Project.toml`
[052768ef] CUDA v5.5.2
[e88e6eb3] Zygote v0.6.71