Documentation's example using Enzyme does not work on GPU #2621

@camilodlt

Description

Hello,

I'm having trouble running the Enzyme example from the documentation. The only change relative to the docs is that I move everything (the model and the inputs) to the GPU.

using CUDA
using Flux
using Enzyme

model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);
dup_model = Duplicated(model |> gpu);  # move the model to the GPU, then wrap it for Enzyme

x1 = randn32(28 * 28, 1) |> gpu;   # dummy input, on the GPU
y1 = [i == 3 for i in 0:9] |> gpu; # one-hot-style target, on the GPU
grads_f = Flux.gradient((m, x, y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))
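
For comparison, here is a sketch of the CPU version from the docs that I started from (reconstructed from the snippet above by simply dropping the gpu moves; as the title says, the problem only shows up on GPU):

using Flux
using Enzyme

model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);
dup_model = Duplicated(model);  # CPU model wrapped for Enzyme

x1 = randn32(28 * 28, 1);
y1 = [i == 3 for i in 0:9];
grads_f = Flux.gradient((m, x, y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))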

The last call takes a long time and eventually throws an error:

ERROR: "Error cannot store inactive but differentiable variable Float32[0.42446053; -0.43670258; -0.22128746; -0.09921402; -1.5923102; -0.50225735; 0.7328375; -1.5166384; -0.22234721; 0.96836793; 1.4810076; -0.28374726; 2.0655832; 0.22402526; -2.1271694; 0.96447814; -0.8850093; -1.1225328; 
[...]
-0.84058493;;] into active tuple"
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:15 [inlined]
  [2] create_shadow_ret
    @ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:3 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:85 [inlined]
  [4] runtime_newstruct_augfwd(::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Ptr{…}, ::Nothing, ::Char, ::Nothing, ::Char, ::Nothing, ::Int64, ::Nothing, ::Int64, ::Nothing, ::Int64, ::Nothing, ::CUDA.CuRefValue{…}, ::CUDA.CuRefValue{…}, ::CuArray{…}, ::CuArray{…}, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CuArray{…}, ::Nothing, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CUDA.CuRefValue{…}, ::CUDA.CuRefValue{…}, ::CuArray{…}, ::CuArray{…}, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CUDA.CUBLAS.cublasComputeType_t, ::CUDA.CUBLAS.cublasComputeType_t, ::CUDA.CUBLAS.cublasGemmAlgo_t, ::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:357
  [5] cublasGemmEx
    @ ~/.julia/packages/GPUToolbox/XaIIx/src/ccalls.jl:33 [inlined]
  [6] #gemmEx!#1222
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:1251 [inlined]
  [7] augmented_julia__gemmEx__1222_386193wrap
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:0
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5713 [inlined]
  [9] enzyme_call
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5247 [inlined]
 [10] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5186 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/rules/jitrules.jl:447 [inlined]
 [12] runtime_generic_augfwd(::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::CUDA.CUBLAS.var"##gemmEx!#1222", ::Nothing, ::CUDA.CUBLAS.cublasGemmAlgo_t, ::Nothing, ::typeof(CUDA.CUBLAS.gemmEx!), ::Nothing, ::Char, ::Nothing, ::Char, ::Nothing, ::Bool, ::Nothing, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::Nothing, ::Bool, ::Nothing, ::CuArray{…}, ::CuArray{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ez9it/src/rules/jitrules.jl:574
 [13] gemmEx!
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:1230 [inlined]
 [14] generic_matmatmul!
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/linalg.jl:251
 [15] generic_matmatmul!
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/linalg.jl:226 [inlined]
 [16] _mul!
    @ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
 [17] mul!
    @ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [18] mul!
    @ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [19] *
    @ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:124 [inlined]
 [20] Dense
    @ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:199
 [21] macro expansion
    @ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:68 [inlined]
 [22] _applychain
    @ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:68 [inlined]
 [23] Chain
    @ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:65 [inlined]
 [24] #3
    @ ./REPL[8]:1 [inlined]
 [25] diffejulia__3_32568_inner_652wrap
    @ ./REPL[8]:0
 [26] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5713 [inlined]
 [27] enzyme_call
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5247 [inlined]
 [28] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5122 [inlined]
 [29] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/ez9it/src/Enzyme.jl:517
 [30] _enzyme_gradient(::Function, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
    @ FluxEnzymeExt ~/.julia/packages/Flux/uRn8o/ext/FluxEnzymeExt/FluxEnzymeExt.jl:50
 [31] gradient(::Function, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
    @ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:122
 [32] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

Neither dup_model = Duplicated(model |> gpu) nor dup_model = Duplicated(model) |> gpu works.
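
To be explicit, these are the two Duplicated placements I tried; both fail with the same error:

# Variant 1: move the model to the GPU first, then wrap it
dup_model = Duplicated(model |> gpu)

# Variant 2: wrap the model first, then move the Duplicated pair to the GPU
dup_model = Duplicated(model) |> gpu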

Julia info:

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × Intel(R) Xeon(R) W-11955M CPU @ 2.60GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, tigerlake)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  LD_GOLD = /home/camilo/miniconda3/envs/pytorch3/bin/x86_64-conda-linux-gnu-ld.gold
  JULIA_CONDAPKG_BACKEND = Current
  JULIA_CONDAPKG_OFFLINE = true

CUDA info:

CUDA toolchain:
- runtime 13.0, artifact installation
- driver 550.163.1 for 13.0
- compiler 13.0

CUDA libraries:
- CUBLAS: 13.0.2
- CURAND: 10.4.0
- CUFFT: 12.0.0
- CUSOLVER: 12.0.4
- CUSPARSE: 12.6.3
- CUPTI: 2025.3.1 (API 130001.0.0)
- NVML: 12.0.0+550.163.1

Julia packages:
- CUDA: 5.8.3
- CUDA_Driver_jll: 13.0.1+0
- CUDA_Compiler_jll: 0.2.1+0
- CUDA_Runtime_jll: 0.19.1+0

Toolchain:
- Julia: 1.11.6
- LLVM: 16.0.6

1 device:
  0: NVIDIA RTX A5000 Laptop GPU (sm_86, 12.816 GiB / 16.000 GiB available)

Flux info: Flux v0.16.5
