@@ -20,16 +20,16 @@ function _macro_big_show(ex)
2020end
2121
2222function _big_show (io:: IO , obj, indent:: Int = 0 , name= nothing )
23- pre, post = obj isa Chain{ <: AbstractVector } ? ( " ([ " , " ]) " ) : ( " ( " , " ) " )
23+ pre, post = _show_pre_post (obj )
2424 children = _show_children (obj)
2525 if all (_show_leaflike, children)
2626 # This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids,
2727 # but once all layers use @layer, they stop the recursion by defining a method for _big_show.
2828 _layer_show (io, obj, indent, name)
2929 else
30- println (io, " " ^ indent, isnothing (name) ? " " : " $name = " , nameof ( typeof (obj)), pre)
31- if obj isa Chain{<: NamedTuple } && children == getfield ( obj, :layers )
32- # then we insert names -- can this be done more generically?
30+ println (io, " " ^ indent, isnothing (name) ? " " : " $name = " , pre)
31+ if obj isa Chain{<: NamedTuple } || obj isa NamedTuple
32+ # then we insert names -- can this be done more generically?
3333 for k in Base. keys (obj)
3434 _big_show (io, obj[k], indent+ 2 , k)
3535 end
@@ -52,6 +52,20 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
5252 end
5353end
5454
55+ for Fix in (:Fix1 , :Fix2 )
56+ pre = string (Fix, " (" )
57+ @eval function _big_show (io:: IO , obj:: Base. $ Fix, indent:: Int = 0 , name= nothing )
58+ println (io, " " ^ indent, isnothing (name) ? " " : " $name = " , $ pre)
59+ _big_show (io, obj. f, indent+ 2 )
60+ _big_show (io, obj. x, indent+ 2 )
61+ println (io, " " ^ indent, " )" , " ," )
62+ end
63+ end
64+
65+ _show_pre_post (obj) = string (nameof (typeof (obj)), " (" ), " )"
66+ _show_pre_post (:: AbstractVector ) = " [" , " ]"
67+ _show_pre_post (:: NamedTuple ) = " (;" , " )"
68+
5569_show_leaflike (x) = isleaf (x) # mostly follow Functors, except for:
5670
5771# note the covariance of tuple, using <:T causes warning or error
88102
89103function _layer_show (io:: IO , layer, indent:: Int = 0 , name= nothing )
90104 _str = isnothing (name) ? " " : " $name = "
91- str = _str * sprint (show , layer, context = io )
105+ str = _str * _layer_string (io , layer)
92106 print (io, " " ^ indent, str, indent== 0 ? " " : " ," )
93107 if ! isempty (params (layer))
94108 print (io, " " ^ max (2 , (indent== 0 ? 20 : 39 ) - indent - length (str)))
@@ -103,6 +117,15 @@ color=:light_black)
103117 indent== 0 || println (io)
104118end
105119
120+ _layer_string (io:: IO , layer) = sprint (show, layer, context= io)
121+ # _layer_string(::IO, a::AbstractArray) = summary(layer) # sometimes too long e.g. CuArray
122+ function _layer_string (:: IO , a:: AbstractArray )
123+ full = string (typeof (a))
124+ comma = findfirst (' ,' , full)
125+ short = isnothing (comma) ? full : full[1 : comma] * " ...}"
126+ Base. dims2string (size (a)) * " " * short
127+ end
128+
106129function _big_finale (io:: IO , m)
107130 ps = params (m)
108131 if length (ps) > 2
@@ -150,3 +173,43 @@ _any(f, x::Number) = f(x)
150173# _any(f, x) = false
151174
152175_all (f, xs) = ! _any (! f, xs)
176+
177+ #=
178+
179+ julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
180+
181+ # Before, notice Array(), NamedTuple(), and values
182+
183+ julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
184+ Chain(
185+ Tmp2(
186+ Array(
187+ Dense(2 => 3), # 9 parameters
188+ [0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
189+ ),
190+ NamedTuple(
191+ 1:3, # 3 parameters
192+ Dense(3 => 4), # 16 parameters
193+ [0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
194+ ),
195+ ),
196+ ) # Total: 7 arrays, 43 parameters, 644 bytes.
197+
198+ # After, (; x=, y=, z=) and "3-element Array"
199+
200+ julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
201+ Chain(
202+ Tmp2(
203+ [
204+ Dense(2 => 3), # 9 parameters
205+ 4×3 Adjoint, # 12 parameters
206+ ],
207+ (;
208+ x = 3-element UnitRange, # 3 parameters
209+ y = Dense(3 => 4), # 16 parameters
210+ z = 3-element Array, # 3 parameters
211+ ),
212+ ),
213+ ) # Total: 7 arrays, 43 parameters, 644 bytes.
214+
215+ =#
0 commit comments