Skip to content

Commit 161dcda

Browse files
committed
start adjusting at-layer for auto-functor
1 parent 32db5d4 commit 161dcda

File tree

3 files changed

+49
-34
lines changed

3 files changed

+49
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ CUDA = "4, 5"
4747
ChainRulesCore = "1.12"
4848
Compat = "4.10.0"
4949
Enzyme = "0.12, 0.13"
50-
Functors = "0.4"
50+
Functors = "0.5"
5151
MLDataDevices = "1.4.2"
5252
MLUtils = "0.4"
5353
MPI = "0.20.19"

src/deprecations.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,26 @@ end
158158
# where `loss_mxy` accepts the model as its first argument.
159159
# """
160160
# ))
161+
162+
"""
163+
@functor MyLayer
164+
165+
Flux used to require the use of `Functors.@functor` to mark any new layer-like struct.
166+
This allowed it to explore inside the struct, and update any trainable parameters within.
167+
168+
[email protected] removes this requirement. This is because [email protected] changed ist behaviour
169+
to be opt-out instead of opt-in. Arbitrary structs will now be explored without special marking.
170+
Hence calling `@functor` is no longer required.
171+
172+
Calling `Flux.@layer MyLayer` is, however, still recommended. This adds various convenience methods
173+
for your layer type, such as pretty printing.
174+
"""
175+
macro functor(ex)
176+
Base.depwarn("""The macro `@functor` is deprecated.
177+
Most likely, you should write `Flux.@layer MyLayer` which will add various convenience methods for your type,
178+
such as pretty-printing, and use with Adapt.jl.
179+
However, this is not required. Flux.jl v0.15 uses Functors.jl v0.5, which makes exploration of most nested `struct`s
180+
opt-out instead of opt-in... so Flux will automatically see inside any custom struct definitions.
181+
""", Symbol("@functor")
182+
_layer_macro(ex)
183+
end

src/layers/macro.jl

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
@layer :expand Chain
55
@layer BatchNorm trainable=(β,γ)
66
7+
NEEDS NEW DOCS!
8+
79
This macro replaces most uses of `@functor`. Its basic purpose is the same:
810
When you define a new layer, this tells Flux to explore inside it
911
to see the parameters it trains, and also to move them to the GPU, change precision, etc.
@@ -48,11 +50,15 @@ Trio(
4850
4951
"""
5052
macro layer(exs...)
53+
_layer_macro(exs...)
54+
end
55+
56+
function _layer_macro(exs...)
5157
out = quote end
5258

5359
# These functions are defined in show.jl, and each return an expression overloading Base.show
5460
type, rest... = if exs[1] == QuoteNode(:expand)
55-
push!(out.args, _macro_big_show(esc(exs[2])))
61+
push!(out.args, _macro_big_show(esc(exs[2])))
5662
exs[2:end]
5763
elseif exs[1] == QuoteNode(:ignore)
5864
exs[2:end]
@@ -63,9 +69,6 @@ macro layer(exs...)
6369
exs
6470
end
6571

66-
# This function exists only for depwarns when you use @functor directly
67-
push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))
68-
6972
push!(out.args, _macro_functor(esc(type)))
7073

7174
for j in 1:length(rest)
@@ -85,23 +88,12 @@ macro layer(exs...)
8588
out
8689
end
8790

88-
# Temporary depwarn function, called within `params`, is also called by `show`.
89-
90-
function _check_new_macro(x::T) where T
91-
Functors.isleaf(x) && return
92-
Base.depwarn(LazyString("This type should probably now use `Flux.@layer` instead of `@functor`: ", T), Symbol("@functor"))
93-
end
94-
_check_new_macro(::Tuple) = nothing # defined by Functors.jl, not by users
95-
_check_new_macro(::NamedTuple) = nothing
96-
_check_new_macro(::AbstractArray) = nothing
97-
_check_new_macro(::Ref) = nothing
98-
9991
# @layer's code for Functors & Adapt
10092
# Unlike @functor, _default_functor doesn't need to eval anything
10193

10294
function _macro_functor(type)
10395
quote
104-
Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x)
96+
# Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x)
10597
Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
10698
end
10799
end
@@ -110,28 +102,28 @@ function _macro_functor(type, fields)
110102
Meta.isexpr(fields, :tuple) || error("expected a tuple of field names")
111103
symbols = Tuple(map(_noquotenode, fields.args))
112104
quote
113-
Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols))
105+
# Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols))
114106
Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
115107
end
116108
end
117109
_macro_functor(type, field::Union{Symbol,QuoteNode}) = _macro_functor(type, :(($field,))) # lets you forget a comma
118110

119-
function _default_functor(::Type{T}, x) where {T}
120-
if @generated
121-
F = fieldnames(T)
122-
args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
123-
C = Base.typename(T).wrapper # constructor
124-
# recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
125-
recon = :(Base.splat($C))
126-
:((NamedTuple{$F}(($(args...),)), $recon))
127-
else
128-
# Getting this parameterless type takes about 2μs, every time:
129-
# spl = VERSION > v"1.9-" ? Splat : Base.splat
130-
spl = Base.splat
131-
namedtuple(x), spl(Base.typename(T).wrapper)
132-
end
133-
end
134-
111+
# function _default_functor(::Type{T}, x) where {T}
112+
# if @generated
113+
# F = fieldnames(T)
114+
# args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
115+
# C = Base.typename(T).wrapper # constructor
116+
# # recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
117+
# recon = :(Base.splat($C))
118+
# :((NamedTuple{$F}(($(args...),)), $recon))
119+
# else
120+
# # Getting this parameterless type takes about 2μs, every time:
121+
# # spl = VERSION > v"1.9-" ? Splat : Base.splat
122+
# spl = Base.splat
123+
# namedtuple(x), spl(Base.typename(T).wrapper)
124+
# end
125+
# end
126+
135127
function namedtuple(x::T) where T
136128
F = fieldnames(T)
137129
NamedTuple{F}(map(sy -> getfield(x, sy), F))

0 commit comments

Comments
 (0)