
Is Enzyme.jl compatible at all? #609

@alonsoC1s

Description

Trying to reproduce bug 2472 on Enzyme, I kept reducing my examples until I found that the most trivial use of GCNConv results in Enzyme throwing an error about "Constant memory is stored (or returned) to a differentiable variable", and sometimes even in segfaults. Here's the example:

using Lux, GNNLux, Random, Enzyme, Optimisers, SparseArrays  # SparseArrays provides sprand below

Lux.@concrete struct GraphPolicyLux <: GNNContainerLayer{(:conv1, :conv2, :dense)}
    conv1
    conv2
    dense
end

function GraphPolicyLux(nin::Int, d::Int, n_nodes::Int)
    conv1 = GCNConv(nin => d)
    conv2 = GCNConv(d => d)
    dense = Dense(d => n_nodes, sigmoid)
    return GraphPolicyLux(conv1, conv2, dense)
end

function (model::GraphPolicyLux)(g::GNNGraph, x, ps, st)
    dense = StatefulLuxLayer{true}(model.dense, ps.dense, GNNLux._getstate(st, :dense))
    x, st_c1 = model.conv1(g, x, ps.conv1, st.conv1)
    x = tanh.(x)
    x, st_c2 = model.conv2(g, x, ps.conv2, st.conv2)
    x = relu.(x)
    new_weights = dense(x)
    return new_weights, (conv1 = st_c1, conv2 = st_c2, dense = dense.st)
end


# Dummy graph construction
n = 100
D = 2
A = sprand(Float32, n, n, 0.2)
X = rand(Float32, D, n)
g = GNNGraph(A; ndata = (; x = X))  # NamedTuple, so g.ndata.x works below
# Setup model and parameters
model = GraphPolicyLux(2, 50, g.num_nodes)
ps, st = Lux.setup(Random.default_rng(), model)
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

function compute_loss(model, ps, st, (g, X))
    prop_A, st = model(g, X, ps, st)
    Reward = 0.0f0
    # This is where a simulation with A prop_A would go
    for t in 1:100
        Reward -= t * 0.01f0  # keep the accumulator Float32
        # Reward -= rand(Float32)
    end
    return Reward, st, 0
end

function train_trigger_bug!(model, ps, st, g, nepochs::Int=100)
    tstate = Training.TrainState(model, ps, st, Adam(0.01f0))
    data = (g, g.ndata.x)
    for epoch in 1:nepochs
        _, loss, _, tstate = Training.single_train_step!(AutoEnzyme(), compute_loss, data, tstate)
        @show epoch
    end
end
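
Calling the function is enough to hit the error; for completeness, the call below just reuses the objects defined above:

train_trigger_bug!(model, ps, st, g, 10)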

To make sure I wasn't doing something wrong, I copy-pasted the (only) Lux.jl example from the documentation. The problem is not limited to Lux.jl either; the same thing happens with Flux.jl. By changing a single line from

 _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, g.x, g.y), train_state)

to

 _, loss, _, train_state = Lux.Training.single_train_step!(AutoEnzyme(), custom_loss, (g, g.x, g.y), train_state)

I started seeing the cryptic errors again:

ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for:   store {} addrspace(10)* %157, {} addrspace(10)** %sret_return.repack.repack, align 8, !dbg !234, !noalias !244 const val:   %157 = call fastcc nonnull {} addrspace(10)* @julia_vcat_89903({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %15, {} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %59) #77, !dbg !228
Type tree: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Integer, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}
 llvalue=  %157 = call fastcc nonnull {} addrspace(10)* @julia_vcat_89903({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %15, {} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %59) #77, !dbg !228

Stacktrace:
  [1] add_self_loops
    @ /scratch/htc/amartine/julia/packages/GNNGraphs/Ldvz4/src/transform.jl:24
  [2] #_#10
    @ /scratch/htc/amartine/julia/packages/GNNLux/AHHiN/src/layers/conv.jl:137 [inlined]
  [3] GCNConv
    @ /scratch/htc/amartine/julia/packages/GNNLux/AHHiN/src/layers/conv.jl:131 [inlined]
  [4] GCNConv
    @ /scratch/htc/amartine/julia/packages/GNNLux/AHHiN/src/layers/conv.jl:128 [inlined]
  [5] GCN
    @ ./REPL[21]:3
  [6] custom_loss
    @ ./REPL[29]:4
  [7] #4
    @ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:257 [inlined]
  [8] augmented_julia__4_89363_inner_33wrap
    @ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:0
  [9] macro expansion
    @ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5610 [inlined]
 [10] enzyme_call
    @ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5144 [inlined]
 [11] AugmentedForwardThunk
    @ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5083 [inlined]
 [12] autodiff
    @ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/Enzyme.jl:408 [inlined]
 [13] compute_gradients_impl
    @ /scratch/htc/amartine/julia/packages/Lux/ptjU6/ext/LuxEnzymeExt/training.jl:10 [inlined]
 [14] compute_gradients
    @ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:200 [inlined]
 [15] single_train_step_impl!
    @ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:327 [inlined]
 [16] #single_train_step!#6
    @ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:292 [inlined]
 [17] single_train_step!(backend::AutoEnzyme{…}, obj_fn::typeof(custom_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:285
 [18] train_model!(gcn::GCN{…}, ps::@NamedTuple{}, st::@NamedTuple{}, g::GNNGraph{…})
    @ Main ./REPL[35]:4
 [19] top-level scope
    @ REPL[36]:1
Some type information was truncated. Use `show(err)` to see complete types.
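
As an aside, workaround (b) from the error message can be requested through the AD backend itself. The snippet below is only a sketch of that suggestion; I have not verified that it actually makes GCNConv differentiate correctly with Enzyme:

using ADTypes, Enzyme  # AutoEnzyme comes from ADTypes (re-exported by Lux); Enzyme provides set_runtime_activity

# Ask for reverse mode with runtime activity enabled, as suggested under option (b) above.
backend = AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse))
# _, loss, _, tstate = Training.single_train_step!(backend, compute_loss, (g, g.ndata.x), tstate)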

Now that I have come across #389, I realize why. I think it would be a good idea to include a warning in the documentation that Enzyme.jl is not supported. Better yet, catch it and let the user know before they hit the scary-looking exception full of IR bits, so they are less likely to spend a lot of time wondering why it just isn't working.
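
To illustrate the second point, a guard roughly like the one below (check_enzyme_backend is a hypothetical helper, not an existing GNNLux function) would fail fast with a readable message instead of an IR dump:

# Hypothetical sketch: reject the unsupported backend before Enzyme ever runs.
function check_enzyme_backend(backend)
    if backend isa AutoEnzyme
        error("GNNLux/GNNGraphs layers currently do not support Enzyme.jl (see #389); " *
              "use AutoZygote() instead.")
    end
    return backend
end

check_enzyme_backend(AutoEnzyme())  # throws a friendly error instead of the exception above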
