
Rethinking custom layer dispatches #1324

@avik-pal


A common pain point for users is manually threading states through the forward pass. We already have StatefulLuxLayer, but it isn't used very often. I propose adding the following and promoting it as the new user-facing API:

In LuxCore

states_type_remains_constant(::AbstractLuxLayer) = Val(true)

# This generic method is not defined yet. Methods that users already define on
# their concrete layer types (the current official API) are more specific, so
# they keep taking precedence. No breaking change whatsoever.
function (m::AbstractLuxLayer)(x, ps, st)
    constant_state_type = states_type_remains_constant(m)
    smodel = Lux.StatefulLuxLayer{constant_state_type}(m, ps, st)
    # Splat tuple inputs with a runtime check instead of a separate `x::Tuple`
    # method, to avoid method ambiguities
    xs = x isa Tuple ? x : (x,)
    res = LuxCore.apply(typeof(m), smodel, xs...)
    # `Val(true)` is not a `Bool`, so compare explicitly
    if constant_state_type === Val(true)
        return res, smodel.st
    else
        return res, smodel.st_any
    end
end
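The "no breaking change" claim rests on Julia's method specificity: a call method defined on a concrete layer type always wins over one defined on the abstract supertype. The sketch below illustrates that mechanism without depending on Lux. AbstractToyLayer, ToyStateful, and toy_apply are hypothetical stand-ins for AbstractLuxLayer, StatefulLuxLayer, and the proposed LuxCore.apply(::Type, ...) method; this is not how Lux implements anything today.

```julia
# Lux-free sketch of the proposed fallback dispatch. All names are hypothetical
# stand-ins: AbstractToyLayer ~ AbstractLuxLayer, ToyStateful ~ StatefulLuxLayer,
# toy_apply ~ the proposed LuxCore.apply(::Type, smodel, xs...).
abstract type AbstractToyLayer end

# Wrapper holding the layer plus its parameters and (mutable) states.
mutable struct ToyStateful{L}
    layer::L
    ps
    st
end

# Hook that new-style layers implement instead of (x, ps, st).
function toy_apply end

# Generic fallback on the abstract type: only reached when a layer has no
# more specific (x, ps, st) method of its own.
function (m::AbstractToyLayer)(x, ps, st)
    smodel = ToyStateful(m, ps, st)
    xs = x isa Tuple ? x : (x,)   # a tuple input becomes multiple arguments
    res = toy_apply(typeof(m), smodel, xs...)
    return res, smodel.st
end

# Old style: the concrete method is more specific, so it bypasses the fallback.
struct OldScale <: AbstractToyLayer end
(m::OldScale)(x, ps, st) = (ps.w .* x, st)

# New style: only the hook is defined; the wrapper carries the state.
struct NewCounter <: AbstractToyLayer end
function toy_apply(::Type{<:NewCounter}, smodel, x)
    smodel.st = (; count = smodel.st.count + 1)
    return smodel.ps.w .* x
end

# Multi-input new-style layer: receives the splatted tuple elements.
struct NewAdd <: AbstractToyLayer end
toy_apply(::Type{<:NewAdd}, smodel, x1, x2) = x1 .+ x2

y1, st1 = OldScale()([1.0, 2.0], (; w = 2.0), NamedTuple())          # y1 == [2.0, 4.0]
y2, st2 = NewCounter()([1.0, 2.0], (; w = 3.0), (; count = 0))       # st2 == (; count = 1)
y3, _ = NewAdd()(([1, 2], [10, 20]), NamedTuple(), NamedTuple())     # y3 == [11, 22]
```

One consequence of the x isa Tuple check is that a layer whose single input is itself a tuple cannot receive that tuple unsplatted through this path.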

User Code

Old

This continues to work unchanged:

function (m::MyCustomLayer)(x, ps, st)
    y, st2 = m.layer_1(x, ps.layer_1, st.layer_1)
    return y, (; layer_1 = st2)
end

New

# Note that `model` here is not a MyCustomLayer but the `StatefulLuxLayer` wrapping it
LuxCore.apply(::Type{<:MyCustomLayer}, model, x) = model.layer_1(x)

# Multi-input layers take one positional argument per input (a tuple input is splatted)
LuxCore.apply(::Type{<:MyCustomLayer}, model, x1, x2) = model.layer_1(x1)

Updates to StatefulLuxLayer

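  • Property access for every field other than :ps, :st, and :st_any is forwarded to the wrapped model.
  • Introduce an API to access the wrapped model that calls getfield(..., :model); plain .model can no longer be used, since it would itself be forwarded to the model.
  • If the forwarded property is a layer, it is returned as a StatefulLuxLayer wrapping that layer with its own parameters and states, so that model.layer_1(x) works.
  • Questions:
    • How should custom states such as st.my_state be accessed? We could enforce a lookup order: first model, then ps, then st. This is error-prone, because a name that exists in more than one place silently resolves to the first match. Instead, we could introduce get_state and get_parameter for accessing these fields.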
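A rough, Lux-free sketch of how the forwarding rules above could behave. SWrap, ToyScale, apply_layer, get_model, get_parameter, and get_state are hypothetical names, and the parent/key back-reference that lets a wrapped sub-layer write its updated state into the enclosing wrapper is an assumption of this sketch: the proposal does not say how nested state updates propagate.

```julia
# Lux-free sketch of the proposed StatefulLuxLayer forwarding. SWrap, ToyScale,
# apply_layer, get_model, get_parameter and get_state are hypothetical names.
struct ToyScale end                     # leaf layer: y = w .* x, counts its calls
struct MyCustomLayer
    layer_1::ToyScale
end

apply_layer(::ToyScale, x, ps, st) = (ps.w .* x, (; calls = st.calls + 1))

# `parent`/`key` (an assumption of this sketch) let a wrapped sub-layer write
# its updated state back into the enclosing wrapper.
mutable struct SWrap
    model
    ps
    st
    parent::Union{Nothing, SWrap}
    key::Union{Nothing, Symbol}
end
SWrap(model, ps, st) = SWrap(model, ps, st, nothing, nothing)

# Explicit accessors: `.model` itself is forwarded, so use getfield.
get_model(s::SWrap) = getfield(s, :model)
get_parameter(s::SWrap, name::Symbol) = getproperty(getfield(s, :ps), name)
get_state(s::SWrap, name::Symbol) = getproperty(getfield(s, :st), name)

function Base.getproperty(s::SWrap, name::Symbol)
    name in (:ps, :st) && return getfield(s, name)   # not forwarded
    inner = getproperty(get_model(s), name)           # everything else is
    if inner isa ToyScale   # stand-in for `inner isa AbstractLuxLayer`
        ps, st = getfield(s, :ps), getfield(s, :st)
        return SWrap(inner, getproperty(ps, name), getproperty(st, name), s, name)
    end
    return inner
end

function Base.setproperty!(s::SWrap, name::Symbol, v)
    setfield!(s, name, v)
    parent = getfield(s, :parent)
    if name === :st && parent !== nothing
        # Replace this sub-layer's entry in the parent's state NamedTuple.
        key = getfield(s, :key)
        parent.st = merge(getfield(parent, :st), NamedTuple{(key,)}((v,)))
    end
    return v
end

# Calling a wrapper runs the wrapped layer and stores the new state.
function (s::SWrap)(x)
    y, st_new = apply_layer(get_model(s), x, s.ps, s.st)
    s.st = st_new
    return y
end

# New-style forward pass, mirroring `LuxCore.apply(::Type{<:MyCustomLayer}, model, x)`.
my_custom_forward(model, x) = model.layer_1(x)

s = SWrap(MyCustomLayer(ToyScale()), (; layer_1 = (; w = 2.0)), (; layer_1 = (; calls = 0)))
y = my_custom_forward(s, [1.0, 2.0])   # y == [2.0, 4.0]
# afterwards: s.st == (; layer_1 = (; calls = 1))
```

Because :ps and :st are never forwarded in this sketch, a sub-layer literally named ps or st would be unreachable through dot access, which is one more argument for the explicit get_state / get_parameter accessors.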
