-
Notifications
You must be signed in to change notification settings - Fork 38
Open
Description
using Lux, Reactant, Enzyme, Random
# Reactant device handle. Arrays and parameters piped through it (`|> xdev`) below are
# moved onto the Reactant device before being passed to `@jit`.
xdev = reactant_device()
"""
    enzyme_grad(model, x, ps, st)

Differentiate `sum(first(Lux.apply(model, x, ps, st)))` with Enzyme in reverse mode and
return the gradient with respect to the input `x`.

`model` and `st` are wrapped in `Const`, so they are not differentiated. `ps` is
differentiated, but its gradient is discarded here.
"""
function enzyme_grad(model, x, ps, st)
    # Scalar loss: apply the model, keep the output (drop the returned state), sum it.
    scalar_loss = sum ∘ first ∘ Lux.apply

    # Enzyme returns one entry per argument after the function, in call order:
    # (model, x, ps, st). The second entry is the gradient for `x`.
    _, dx = Enzyme.gradient(
        Enzyme.Reverse, scalar_loss, Const(model), x, ps, Const(st)
    )
    return dx
end
# Build the same RNN recurrence twice: once with `checkpointing=true` and once with the
# default setting, so the compiled gradients of the two variants can be compared.
model = Recurrence(RNNCell(4 => 4); ordering=BatchLastIndex(), checkpointing=true)
model_no_checkpoint = Recurrence(RNNCell(4 => 4); ordering=BatchLastIndex())

# Seeded RNG so every run of the reproducer sees the same inputs and parameters.
rng = Random.Xoshiro(0)

# Input array of size (4, 16, 12), moved to the Reactant device.
x_ra = randn(rng, Float32, 4, 16, 12) |> xdev

# Parameters/state are created from `model` and reused for `model_no_checkpoint`.
# Both wrap an identical `RNNCell(4 => 4)`. NOTE(review): this presumes the parameter
# layout does not depend on `checkpointing` -- confirm.
ps_ra, st_ra = Lux.setup(rng, model) |> xdev

# Compile and run the input gradient for each variant.
@jit enzyme_grad(model, x_ra, ps_ra, st_ra)
@jit enzyme_grad(model_no_checkpoint, x_ra, ps_ra, st_ra)
# xref LuxDL/Lux.jl#1561 (comment)
This requires Lux installed from the `ap/rnn_testing` branch (e.g. `] add Lux#ap/rnn_testing`).
Metadata
Assignees
Labels
No labels