44 @layer :expand Chain
55 @layer BatchNorm trainable=(β,γ)
66
7+ NEEDS NEW DOCS!
8+
79This macro replaces most uses of `@functor`. Its basic purpose is the same:
810When you define a new layer, this tells Flux to explore inside it
911to see the parameters it trains, and also to move them to the GPU, change precision, etc.
@@ -48,11 +50,15 @@ Trio(
4850
4951"""
5052macro layer (exs... )
53+ _layer_macro (exs... )
54+ end
55+
56+ function _layer_macro (exs... )
5157 out = quote end
5258
5359 # These functions are defined in show.jl, and each return an expression overloading Base.show
5460 type, rest... = if exs[1 ] == QuoteNode (:expand )
55- push! (out. args, _macro_big_show (esc (exs[2 ])))
61+ push! (out. args, _macro_big_show (esc (exs[2 ])))
5662 exs[2 : end ]
5763 elseif exs[1 ] == QuoteNode (:ignore )
5864 exs[2 : end ]
@@ -63,9 +69,6 @@ macro layer(exs...)
6369 exs
6470 end
6571
66- # This function exists only for depwarns when you use @functor directly
67- push! (out. args, :(Flux. _check_new_macro (:: $ (esc (type))) = nothing ))
68-
6972 push! (out. args, _macro_functor (esc (type)))
7073
7174 for j in 1 : length (rest)
@@ -85,23 +88,12 @@ macro layer(exs...)
8588 out
8689end
8790
88- # Temporary depwarn function, called within `params`, is also called by `show`.
89-
90- function _check_new_macro (x:: T ) where T
91- Functors. isleaf (x) && return
92- Base. depwarn (LazyString (" This type should probably now use `Flux.@layer` instead of `@functor`: " , T), Symbol (" @functor" ))
93- end
94- _check_new_macro (:: Tuple ) = nothing # defined by Functors.jl, not by users
95- _check_new_macro (:: NamedTuple ) = nothing
96- _check_new_macro (:: AbstractArray ) = nothing
97- _check_new_macro (:: Ref ) = nothing
98-
9991# @layer's code for Functors & Adapt
10092# Unlike @functor, _default_functor doesn't need to eval anything
10193
10294function _macro_functor (type)
10395 quote
104- Functors. functor (:: Type{T} , x) where {T<: $type } = $ _default_functor (T, x)
96+ # Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x)
10597 Adapt. adapt_structure (to, layer:: $type ) = $ fmap ($ adapt (to), layer)
10698 end
10799end
@@ -110,28 +102,28 @@ function _macro_functor(type, fields)
110102 Meta. isexpr (fields, :tuple ) || error (" expected a tuple of field names" )
111103 symbols = Tuple (map (_noquotenode, fields. args))
112104 quote
113- Functors. functor (:: Type{T} , x) where {T<: $type } = $ _custom_functor (T, x, Val ($ symbols))
105+ # Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols))
114106 Adapt. adapt_structure (to, layer:: $type ) = $ fmap ($ adapt (to), layer)
115107 end
116108end
117109_macro_functor (type, field:: Union{Symbol,QuoteNode} ) = _macro_functor (type, :(($ field,))) # lets you forget a comma
118110
119- function _default_functor (:: Type{T} , x) where {T}
120- if @generated
121- F = fieldnames (T)
122- args = map (sy -> :(getfield (x, $ (QuoteNode (sy)))), F)
123- C = Base. typename (T). wrapper # constructor
124- # recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
125- recon = :(Base. splat ($ C))
126- :((NamedTuple {$F} (($ (args... ),)), $ recon))
127- else
128- # Getting this parameterless type takes about 2μs, every time:
129- # spl = VERSION > v"1.9-" ? Splat : Base.splat
130- spl = Base. splat
131- namedtuple (x), spl (Base. typename (T). wrapper)
132- end
133- end
134-
111+ # function _default_functor(::Type{T}, x) where {T}
112+ # if @generated
113+ # F = fieldnames(T)
114+ # args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
115+ # C = Base.typename(T).wrapper # constructor
116+ # # recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
117+ # recon = :(Base.splat($C))
118+ # :((NamedTuple{$F}(($(args...),)), $recon))
119+ # else
120+ # # Getting this parameterless type takes about 2μs, every time:
121+ # # spl = VERSION > v"1.9-" ? Splat : Base.splat
122+ # spl = Base.splat
123+ # namedtuple(x), spl(Base.typename(T).wrapper)
124+ # end
125+ # end
126+
135127function namedtuple (x:: T ) where T
136128 F = fieldnames (T)
137129 NamedTuple {F} (map (sy -> getfield (x, sy), F))
0 commit comments