A common pain point for users is manually managing states when writing the forward pass. We already have `StatefulLuxLayer`, but it isn't used very often. I think we can add the following and promote it as the new user-facing API:
In `LuxCore`:

```julia
states_type_remains_constant(::AbstractLuxLayer) = Val(true)

# This generic method does not exist today. Layers that already define their own
# `(m::MyLayer)(x, ps, st)` method (the current official API) keep dispatching to
# that more specific method, so this is not a breaking change.
function (m::AbstractLuxLayer)(x, ps, st)
    fixed_st_type = states_type_remains_constant(m)  # Val(true) or Val(false)
    smodel = Lux.StatefulLuxLayer{fixed_st_type}(m, ps, st)
    # This is not done via dispatch to avoid ambiguities
    xs = x isa Tuple ? x : (x,)
    res = LuxCore.apply(typeof(m), smodel, xs...)
    if fixed_st_type isa Val{true}
        return res, smodel.st
    else
        return res, smodel.st_any
    end
end
```

User Code
Old

This will continue to work:

```julia
function (m::MyCustomLayer)(x, ps, st)
    y, st2 = m.layer_1(x, ps.layer_1, st.layer_1)
    return y, (; layer_1 = st2)
end
```

New
```julia
# Note that `model` is not a `MyCustomLayer`. Instead, it is a `StatefulLuxLayer`
LuxCore.apply(::Type{<:MyCustomLayer}, model, x) = model.layer_1(x)

# For multi-input layers, we can now have
LuxCore.apply(::Type{<:MyCustomLayer}, model, x1, x2) = model.layer_1(x1)
```

Updates to `StatefulLuxLayer`
- All fields apart from `:ps`, `:st`, and `:st_any` are forwarded to `model`.
- Since `.model` is itself forwarded, introduce an API to access the wrapped model that calls `getfield(..., :model)`.
- If a layer inside `model` is accessed, it is returned as a `StatefulLuxLayer`.
- Questions:
  - How to support `st.my_state`? We could enforce an access order: first look in `model`, then `ps`, then `st`. This is error-prone. Instead, we can introduce `get_state` and `get_parameter` for accessing these fields.
  - How to support
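To make the mechanics of this proposal concrete, here is a small, self-contained toy in plain Julia with no Lux dependency. It is only a sketch under stated assumptions: `AbstractToyLayer`, `ToyStateful` (standing in for `StatefulLuxLayer`), `apply` (standing in for the proposed type-dispatched `LuxCore.apply`), `get_model`, and the `Scale`/`Chain2` layers are all hypothetical, and `get_state`/`get_parameter` only borrow the names floated in the questions above, with signatures assumed here. State updates from nested layers are written back to the parent wrapper through a callback, which is one possible design, not necessarily how Lux would implement it.

```julia
# Toy sketch of the proposal; none of these names are real Lux/LuxCore API.
abstract type AbstractToyLayer end

# Hypothetical stand-in for the proposed type-dispatched `LuxCore.apply`.
function apply end

# Hypothetical stand-in for `StatefulLuxLayer`: wraps a layer with its ps and st.
mutable struct ToyStateful{L<:AbstractToyLayer}
    model::L
    ps::Any
    st::Any
    on_update::Function  # pushes this wrapper's new state into its parent wrapper
end
ToyStateful(model, ps, st) = ToyStateful(model, ps, st, _ -> nothing)

# Field forwarding: everything except `:ps`/`:st` is looked up on the wrapped model,
# and nested layers come back wrapped so that calling them updates the parent's state.
function Base.getproperty(s::ToyStateful, name::Symbol)
    name in (:ps, :st) && return getfield(s, name)
    val = getproperty(getfield(s, :model), name)
    val isa AbstractToyLayer || return val
    on_update = function (new_child_st)
        parent_st = merge(getfield(s, :st), NamedTuple{(name,)}((new_child_st,)))
        setfield!(s, :st, parent_st)
        getfield(s, :on_update)(parent_st)  # keep propagating to outer wrappers
    end
    child_ps = getproperty(getfield(s, :ps), name)
    child_st = getproperty(getfield(s, :st), name)
    return ToyStateful(val, child_ps, child_st, on_update)
end

# Explicit accessors instead of guessing between model, ps and st.
get_model(s::ToyStateful) = getfield(s, :model)  # hypothetical name
get_parameter(s::ToyStateful, name::Symbol) = getproperty(getfield(s, :ps), name)
get_state(s::ToyStateful, name::Symbol) = getproperty(getfield(s, :st), name)

# Calling a wrapper runs the old-style `(x, ps, st)` method and stores the new state.
function (s::ToyStateful)(x)
    y, new_st = getfield(s, :model)(x, getfield(s, :ps), getfield(s, :st))
    setfield!(s, :st, new_st)
    getfield(s, :on_update)(new_st)
    return y
end

# Generic forward pass: wrap, dispatch `apply` on the layer type, return the final state.
function (m::AbstractToyLayer)(x, ps, st)
    smodel = ToyStateful(m, ps, st)
    xs = x isa Tuple ? x : (x,)
    y = apply(typeof(m), smodel, xs...)
    return y, getfield(smodel, :st)
end

# Old-style leaf layer: manages its own state explicitly (counts its calls).
struct Scale <: AbstractToyLayer end
(m::Scale)(x, ps, st) = (ps.w .* x, (; ncalls = st.ncalls + 1))

# New-style container: no manual state threading.
struct Chain2{A,B} <: AbstractToyLayer
    layer_1::A
    layer_2::B
end
apply(::Type{<:Chain2}, model, x) = model.layer_2(model.layer_1(x))

model = Chain2(Scale(), Scale())
ps = (; layer_1 = (; w = 2.0), layer_2 = (; w = 3.0))
st = (; layer_1 = (; ncalls = 0), layer_2 = (; ncalls = 0))
y, st2 = model(1.5, ps, st)  # y == 9.0; both ncalls become 1
```

Here `Chain2` never touches `st` directly, while `Scale` keeps the explicit `(x, ps, st)` signature, so both styles coexist. As in the proposal, `s.model` on a wrapper is forwarded to the wrapped layer, which is why an explicit accessor such as `get_model` is needed to reach the wrapped model itself.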