Hello,
I'm having trouble running the Enzyme example from the documentation. The only difference from the docs is that I move everything to the GPU (both the model and the inputs).
using CUDA
using Flux
using Enzyme

model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);
dup_model = Duplicated(model |> gpu);   # move the model to the GPU, then wrap it for Enzyme
x1 = randn32(28 * 28, 1) |> gpu;        # a single MNIST-sized input, on the GPU
y1 = [i == 3 for i in 0:9] |> gpu;      # one-hot target, on the GPU
grads_f = Flux.gradient((m, x, y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))
The last call takes a long time and eventually throws an error:
ERROR: "Error cannot store inactive but differentiable variable Float32[0.42446053; -0.43670258; -0.22128746; -0.09921402; -1.5923102; -0.50225735; 0.7328375; -1.5166384; -0.22234721; 0.96836793; 1.4810076; -0.28374726; 2.0655832; 0.22402526; -2.1271694; 0.96447814; -0.8850093; -1.1225328;
[...]
-0.84058493;;] into active tuple"
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:15 [inlined]
[2] create_shadow_ret
@ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:3 [inlined]
[3] macro expansion
@ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:85 [inlined]
[4] runtime_newstruct_augfwd(::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Ptr{…}, ::Nothing, ::Char, ::Nothing, ::Char, ::Nothing, ::Int64, ::Nothing, ::Int64, ::Nothing, ::Int64, ::Nothing, ::CUDA.CuRefValue{…}, ::CUDA.CuRefValue{…}, ::CuArray{…}, ::CuArray{…}, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CuArray{…}, ::Nothing, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CUDA.CuRefValue{…}, ::CUDA.CuRefValue{…}, ::CuArray{…}, ::CuArray{…}, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CUDA.CUBLAS.cublasComputeType_t, ::CUDA.CUBLAS.cublasComputeType_t, ::CUDA.CUBLAS.cublasGemmAlgo_t, ::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:357
[5] cublasGemmEx
@ ~/.julia/packages/GPUToolbox/XaIIx/src/ccalls.jl:33 [inlined]
[6] #gemmEx!#1222
@ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:1251 [inlined]
[7] augmented_julia__gemmEx__1222_386193wrap
@ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:0
[8] macro expansion
@ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5713 [inlined]
[9] enzyme_call
@ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5247 [inlined]
[10] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5186 [inlined]
[11] macro expansion
@ ~/.julia/packages/Enzyme/ez9it/src/rules/jitrules.jl:447 [inlined]
[12] runtime_generic_augfwd(::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::CUDA.CUBLAS.var"##gemmEx!#1222", ::Nothing, ::CUDA.CUBLAS.cublasGemmAlgo_t, ::Nothing, ::typeof(CUDA.CUBLAS.gemmEx!), ::Nothing, ::Char, ::Nothing, ::Char, ::Nothing, ::Bool, ::Nothing, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::Nothing, ::Bool, ::Nothing, ::CuArray{…}, ::CuArray{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/ez9it/src/rules/jitrules.jl:574
[13] gemmEx!
@ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:1230 [inlined]
[14] generic_matmatmul!
@ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/linalg.jl:251
[15] generic_matmatmul!
@ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/linalg.jl:226 [inlined]
[16] _mul!
@ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
[17] mul!
@ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
[18] mul!
@ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
[19] *
@ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:124 [inlined]
[20] Dense
@ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:199
[21] macro expansion
@ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:68 [inlined]
[22] _applychain
@ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:68 [inlined]
[23] Chain
@ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:65 [inlined]
[24] #3
@ ./REPL[8]:1 [inlined]
[25] diffejulia__3_32568_inner_652wrap
@ ./REPL[8]:0
[26] macro expansion
@ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5713 [inlined]
[27] enzyme_call
@ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5247 [inlined]
[28] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5122 [inlined]
[29] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…}, ::Const{…})
@ Enzyme ~/.julia/packages/Enzyme/ez9it/src/Enzyme.jl:517
[30] _enzyme_gradient(::Function, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
@ FluxEnzymeExt ~/.julia/packages/Flux/uRn8o/ext/FluxEnzymeExt/FluxEnzymeExt.jl:50
[31] gradient(::Function, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
@ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:122
[32] top-level scope
@ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.
Neither dup_model = Duplicated(model |> gpu) nor dup_model = Duplicated(model) |> gpu works.
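In case it helps with triage: my understanding is that runtime activity is what Enzyme usually suggests for mixed-activity errors like "cannot store inactive but differentiable variable", so the sketch below is the variant I would try next (untested, and I'm not sure it applies here). It calls Enzyme.gradient directly, since as far as I can tell Flux.gradient does not expose the Enzyme mode:

using CUDA, Flux, Enzyme

loss(m, x, y) = sum(abs2, m(x) .- y)

gpu_model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax) |> gpu
x1 = randn32(28 * 28, 1) |> gpu
y1 = [i == 3 for i in 0:9] |> gpu

# Hypothetical workaround, untested on this setup: set_runtime_activity(Reverse)
# enables runtime activity tracking; Const marks x1/y1 as non-differentiated inputs.
grads = Enzyme.gradient(set_runtime_activity(Reverse), loss, gpu_model, Const(x1), Const(y1))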
Julia info:
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 16 × Intel(R) Xeon(R) W-11955M CPU @ 2.60GHz
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, tigerlake)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
LD_GOLD = /home/camilo/miniconda3/envs/pytorch3/bin/x86_64-conda-linux-gnu-ld.gold
JULIA_CONDAPKG_BACKEND = Current
JULIA_CONDAPKG_OFFLINE = true
CUDA info:
CUDA toolchain:
- runtime 13.0, artifact installation
- driver 550.163.1 for 13.0
- compiler 13.0
CUDA libraries:
- CUBLAS: 13.0.2
- CURAND: 10.4.0
- CUFFT: 12.0.0
- CUSOLVER: 12.0.4
- CUSPARSE: 12.6.3
- CUPTI: 2025.3.1 (API 130001.0.0)
- NVML: 12.0.0+550.163.1
Julia packages:
- CUDA: 5.8.3
- CUDA_Driver_jll: 13.0.1+0
- CUDA_Compiler_jll: 0.2.1+0
- CUDA_Runtime_jll: 0.19.1+0
Toolchain:
- Julia: 1.11.6
- LLVM: 16.0.6
1 device:
0: NVIDIA RTX A5000 Laptop GPU (sm_86, 12.816 GiB / 16.000 GiB available)
Flux info: Flux v0.16.5