Description
While trying to reproduce bug 2472 on Enzyme, I kept reducing my example until I found that even the most trivial use of GCNConv
makes Enzyme throw an error about "Constant memory is stored (or returned) to a differentiable variable", and sometimes even segfault. Here's the example:
using Lux, GNNLux, Random, Enzyme, Optimisers, SparseArrays
Lux.@concrete struct GraphPolicyLux <: GNNContainerLayer{(:conv1, :conv2, :dense)}
    conv1
    conv2
    dense
end

function GraphPolicyLux(nin::Int, d::Int, n_nodes::Int)
    conv1 = GCNConv(nin => d)
    conv2 = GCNConv(d => d)
    dense = Dense(d => n_nodes, sigmoid)
    return GraphPolicyLux(conv1, conv2, dense)
end

function (model::GraphPolicyLux)(g::GNNGraph, x, ps, st)
    dense = StatefulLuxLayer{true}(model.dense, ps.dense, GNNLux._getstate(st, :dense))
    x, st_c1 = model.conv1(g, x, ps.conv1, st.conv1)
    x = tanh.(x)
    x, st_c2 = model.conv2(g, x, ps.conv2, st.conv2)
    x = relu.(x)
    new_weights = dense(x)
    return new_weights, (conv1 = st_c1, conv2 = st_c2)
end
# Dummy graph construction
n = 100
D = 2
A = sprand(Float32, n, n, 0.2)
X = rand(Float32, D, n)
g = GNNGraph(A; ndata = (; x = X))
# Setup model and parameters
model = GraphPolicyLux(2, 50, g.num_nodes)
ps, st = Lux.setup(Random.default_rng(), model)
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
function compute_loss(model, ps, st, (g, X))
    prop_A, st = model(g, X, ps, st)
    Reward = 0.0f0
    # This is where a simulation with prop_A would go
    for t in 1:100
        Reward -= t * 0.01f0
        # Reward -= rand(Float32)
    end
    return Reward, st, 0
end
function train_trigger_bug!(model, ps, st, g, nepochs::Int=100)
    tstate = Training.TrainState(model, ps, st, Adam(0.01f0))
    data = (g, g.ndata.x)
    for epoch in 1:nepochs
        _, loss, _, tstate = Training.single_train_step!(AutoEnzyme(), compute_loss, data, tstate)
        @show epoch
    end
end
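Calling the function above with the model and graph defined earlier is enough to trigger the error (and, occasionally, a segfault):

train_trigger_bug!(model, ps, st, g)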
To make sure I wasn't doing something wrong, I also copy-pasted the (only) documented example that uses Lux.jl. The problem is not limited to Lux.jl, though; the same thing happens with Flux.jl. By changing a single line from
_, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, g.x, g.y), train_state)
to
_, loss, _, train_state = Lux.Training.single_train_step!(AutoEnzyme(), custom_loss, (g, g.x, g.y), train_state)
I started seeing the cryptic errors again:
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for: store {} addrspace(10)* %157, {} addrspace(10)** %sret_return.repack.repack, align 8, !dbg !234, !noalias !244 const val: %157 = call fastcc nonnull {} addrspace(10)* @julia_vcat_89903({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %15, {} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %59) #77, !dbg !228
Type tree: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Integer, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}
llvalue= %157 = call fastcc nonnull {} addrspace(10)* @julia_vcat_89903({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %15, {} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %59) #77, !dbg !228
Stacktrace:
[1] add_self_loops
@ /scratch/htc/amartine/julia/packages/GNNGraphs/Ldvz4/src/transform.jl:24
Stacktrace:
[1] add_self_loops
@ /scratch/htc/amartine/julia/packages/GNNGraphs/Ldvz4/src/transform.jl:24
[2] #_#10
@ /scratch/htc/amartine/julia/packages/GNNLux/AHHiN/src/layers/conv.jl:137 [inlined]
[3] GCNConv
@ /scratch/htc/amartine/julia/packages/GNNLux/AHHiN/src/layers/conv.jl:131 [inlined]
[4] GCNConv
@ /scratch/htc/amartine/julia/packages/GNNLux/AHHiN/src/layers/conv.jl:128 [inlined]
[5] GCN
@ ./REPL[21]:3
[6] custom_loss
@ ./REPL[29]:4
[7] #4
@ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:257 [inlined]
[8] augmented_julia__4_89363_inner_33wrap
@ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:0
[9] macro expansion
@ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5610 [inlined]
[10] enzyme_call
@ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5144 [inlined]
[11] AugmentedForwardThunk
@ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5083 [inlined]
[12] autodiff
@ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/Enzyme.jl:408 [inlined]
[13] compute_gradients_impl
@ /scratch/htc/amartine/julia/packages/Lux/ptjU6/ext/LuxEnzymeExt/training.jl:10 [inlined]
[14] compute_gradients
@ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:200 [inlined]
[15] single_train_step_impl!
@ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:327 [inlined]
[16] #single_train_step!#6
@ /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:292 [inlined]
[17] single_train_step!(backend::AutoEnzyme{…}, obj_fn::typeof(custom_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ Lux.Training /scratch/htc/amartine/julia/packages/Lux/ptjU6/src/helpers/training.jl:285
[18] train_model!(gcn::GCN{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, g::GNNGraph{…})
@ Main ./REPL[35]:4
[19] top-level scope
@ REPL[36]:1
Some type information was truncated. Use `show(err)` to see complete types.
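For reference, the runtime-activity workaround that the error message suggests can be expressed through the Lux training API by passing an Enzyme mode to the ADTypes backend object. This is only a sketch; I have not checked whether Lux forwards the mode here, or whether it actually avoids the error:

# Workaround (b) from the error message, expressed via ADTypes' AutoEnzyme;
# untested for this reproducer.
using Enzyme, ADTypes
backend = AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse))
_, loss, _, train_state = Lux.Training.single_train_step!(backend, custom_loss, (g, g.x, g.y), train_state)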
Now that I have come across #389, I realize why. I think it would be a good idea to include a warning in the documentation that Enzyme.jl is not supported. Perhaps, better yet, catch this case and let the user know before they hit the scary-looking exception full of IR bits, so that they are less likely to spend a lot of time wondering why it just isn't working.
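As a rough sketch of what such an early check could look like (the helper name below is hypothetical, not an existing GNNLux API):

using ADTypes: AutoEnzyme
# Hypothetical guard, not existing GNNLux API: fail early with a readable
# message instead of surfacing Enzyme's IR-level error.
function check_ad_backend(backend)
    backend isa AutoEnzyme && error(
        "Differentiating GNNLux layers with Enzyme.jl is currently not supported; " *
        "use AutoZygote() instead (see #389).")
    return backend
end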