@@ -20,24 +20,25 @@ export callback
2020end
2121
2222@testitem " Smoke test for FFJORD" setup= [CNFTestSetup] tags= [:advancedneuralde ] begin
23- nn = Chain (Dense (1 , 1 , tanh))
24- tspan = (0.0f0 , 1.0f0 )
25- ffjord_mdl = FFJORD (nn, tspan, (1 ,), Tsit5 ())
26- ps, st = Lux. setup (Xoshiro (0 ), ffjord_mdl)
27- ps = ComponentArray (ps)
23+ nn= Chain (Dense (1 , 1 , tanh))
24+ tspan= (0.0f0 , 1.0f0 )
25+ ffjord_mdl= FFJORD (nn, tspan, (1 ,), Tsit5 ())
26+ ps, st= Lux. setup (Xoshiro (0 ), ffjord_mdl)
27+ ps= ComponentArray (ps)
2828
29- data_dist = Beta (2.0f0 , 2.0f0 )
30- train_data = Float32 .(rand (data_dist, 1 , 100 ))
29+ data_dist= Beta (2.0f0 , 2.0f0 )
30+ train_data= Float32 .(rand (data_dist, 1 , 100 ))
3131
3232 function loss (model, θ)
33- logpx, λ₁, λ₂ = model (train_data, θ)
33+ logpx, λ₁, λ₂= model (train_data, θ)
3434 return - mean (logpx)
3535 end
3636
3737 @testset " ADType: $(adtype) " for adtype in (Optimization. AutoForwardDiff (),
3838 Optimization. AutoReverseDiff (), Optimization. AutoTracker (),
3939 Optimization. AutoZygote (), Optimization. AutoFiniteDiff ())
40- @testset " regularize = $(regularize) & monte_carlo = $(monte_carlo) " for regularize in (
40+ @testset " regularize = $(regularize) & monte_carlo = $(monte_carlo) " for regularize in
41+ (
4142 true , false ),
4243 monte_carlo in (true , false )
4344
@@ -54,177 +55,177 @@ end
5455end
5556
5657@testitem " Smoke test for FFJORDDistribution (sampling & pdf)" setup= [CNFTestSetup] tags= [:advancedneuralde ] begin
57- nn = Chain (Dense (1 , 1 , tanh))
58- tspan = (0.0f0 , 1.0f0 )
59- ffjord_mdl = FFJORD (nn, tspan, (1 ,), Tsit5 ())
60- ps, st = Lux. setup (Xoshiro (0 ), ffjord_mdl)
61- ps = ComponentArray (ps)
58+ nn= Chain (Dense (1 , 1 , tanh))
59+ tspan= (0.0f0 , 1.0f0 )
60+ ffjord_mdl= FFJORD (nn, tspan, (1 ,), Tsit5 ())
61+ ps, st= Lux. setup (Xoshiro (0 ), ffjord_mdl)
62+ ps= ComponentArray (ps)
6263
63- regularize = false
64- monte_carlo = false
64+ regularize= false
65+ monte_carlo= false
6566
66- data_dist = Beta (2.0f0 , 2.0f0 )
67- train_data = Float32 .(rand (data_dist, 1 , 100 ))
67+ data_dist= Beta (2.0f0 , 2.0f0 )
68+ train_data= Float32 .(rand (data_dist, 1 , 100 ))
6869
6970 function loss (model, θ)
70- logpx, λ₁, λ₂ = model (train_data, θ)
71+ logpx, λ₁, λ₂= model (train_data, θ)
7172 return - mean (logpx)
7273 end
7374
74- adtype = Optimization. AutoZygote ()
75+ adtype= Optimization. AutoZygote ()
7576
76- st_ = (; st... , regularize, monte_carlo)
77- model = StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
77+ st_= (; st... , regularize, monte_carlo)
78+ model= StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
7879
79- optf = Optimization. OptimizationFunction ((θ, _) -> loss (model, θ), adtype)
80- optprob = Optimization. OptimizationProblem (optf, ps)
81- res = Optimization. solve (optprob, Adam (0.1 ); callback = callback (adtype), maxiters = 10 )
80+ optf= Optimization. OptimizationFunction ((θ, _)-> loss (model, θ), adtype)
81+ optprob= Optimization. OptimizationProblem (optf, ps)
82+ res= Optimization. solve (optprob, Adam (0.1 ); callback = callback (adtype), maxiters = 10 )
8283
83- ffjord_d = FFJORDDistribution (ffjord_mdl, res. u, st_)
84+ ffjord_d= FFJORDDistribution (ffjord_mdl, res. u, st_)
8485
8586 @test ! isnothing (pdf (ffjord_d, train_data))
8687 @test ! isnothing (rand (ffjord_d))
8788 @test ! isnothing (rand (ffjord_d, 10 ))
8889end
8990
9091@testitem " Test for default base distribution and deterministic trace FFJORD" setup= [CNFTestSetup] tags= [:advancedneuralde ] begin
91- nn = Chain (Dense (1 , 1 , tanh))
92- tspan = (0.0f0 , 1.0f0 )
93- ffjord_mdl = FFJORD (nn, tspan, (1 ,), Tsit5 ())
94- ps, st = Lux. setup (Xoshiro (0 ), ffjord_mdl)
95- ps = ComponentArray (ps)
92+ nn= Chain (Dense (1 , 1 , tanh))
93+ tspan= (0.0f0 , 1.0f0 )
94+ ffjord_mdl= FFJORD (nn, tspan, (1 ,), Tsit5 ())
95+ ps, st= Lux. setup (Xoshiro (0 ), ffjord_mdl)
96+ ps= ComponentArray (ps)
9697
97- regularize = false
98- monte_carlo = false
98+ regularize= false
99+ monte_carlo= false
99100
100- data_dist = Beta (7.0f0 , 7.0f0 )
101- train_data = Float32 .(rand (data_dist, 1 , 100 ))
102- test_data = Float32 .(rand (data_dist, 1 , 100 ))
101+ data_dist= Beta (7.0f0 , 7.0f0 )
102+ train_data= Float32 .(rand (data_dist, 1 , 100 ))
103+ test_data= Float32 .(rand (data_dist, 1 , 100 ))
103104
104105 function loss (model, θ)
105- logpx, λ₁, λ₂ = model (train_data, θ)
106+ logpx, λ₁, λ₂= model (train_data, θ)
106107 return - mean (logpx)
107108 end
108109
109- adtype = Optimization. AutoZygote ()
110- st_ = (; st... , regularize, monte_carlo)
111- model = StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
110+ adtype= Optimization. AutoZygote ()
111+ st_= (; st... , regularize, monte_carlo)
112+ model= StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
112113
113- optf = Optimization. OptimizationFunction ((θ, _) -> loss (model, θ), adtype)
114- optprob = Optimization. OptimizationProblem (optf, ps)
115- res = Optimization. solve (optprob, Adam (0.1 ); callback = callback (adtype), maxiters = 10 )
114+ optf= Optimization. OptimizationFunction ((θ, _)-> loss (model, θ), adtype)
115+ optprob= Optimization. OptimizationProblem (optf, ps)
116+ res= Optimization. solve (optprob, Adam (0.1 ); callback = callback (adtype), maxiters = 10 )
116117
117- actual_pdf = pdf .(data_dist, test_data)
118- learned_pdf = exp .(model (test_data, res. u)[1 ])
118+ actual_pdf= pdf .(data_dist, test_data)
119+ learned_pdf= exp .(model (test_data, res. u)[1 ])
119120
120121 @test ps != res. u
121122 @test totalvariation (learned_pdf, actual_pdf) / size (test_data, 2 ) < 0.9
122123end
123124
124125@testitem " Test for alternative base distribution and deterministic trace FFJORD" setup= [CNFTestSetup] tags= [:advancedneuralde ] begin
125- nn = Chain (Dense (1 , 3 , tanh), Dense (3 , 1 , tanh))
126- tspan = (0.0f0 , 1.0f0 )
127- ffjord_mdl = FFJORD (
126+ nn= Chain (Dense (1 , 3 , tanh), Dense (3 , 1 , tanh))
127+ tspan= (0.0f0 , 1.0f0 )
128+ ffjord_mdl= FFJORD (
128129 nn, tspan, (1 ,), Tsit5 (); basedist = MvNormal ([0.0f0 ], Diagonal ([4.0f0 ])))
129- ps, st = Lux. setup (Xoshiro (0 ), ffjord_mdl)
130- ps = ComponentArray (ps)
130+ ps, st= Lux. setup (Xoshiro (0 ), ffjord_mdl)
131+ ps= ComponentArray (ps)
131132
132- regularize = false
133- monte_carlo = false
133+ regularize= false
134+ monte_carlo= false
134135
135- data_dist = Normal (6.0f0 , 0.7f0 )
136- train_data = Float32 .(rand (data_dist, 1 , 100 ))
137- test_data = Float32 .(rand (data_dist, 1 , 100 ))
136+ data_dist= Normal (6.0f0 , 0.7f0 )
137+ train_data= Float32 .(rand (data_dist, 1 , 100 ))
138+ test_data= Float32 .(rand (data_dist, 1 , 100 ))
138139
139140 function loss (model, θ)
140- logpx, λ₁, λ₂ = model (train_data, θ)
141+ logpx, λ₁, λ₂= model (train_data, θ)
141142 return - mean (logpx)
142143 end
143144
144- adtype = Optimization. AutoZygote ()
145- st_ = (; st... , regularize, monte_carlo)
146- model = StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
145+ adtype= Optimization. AutoZygote ()
146+ st_= (; st... , regularize, monte_carlo)
147+ model= StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
147148
148- optf = Optimization. OptimizationFunction ((θ, _) -> loss (model, θ), adtype)
149- optprob = Optimization. OptimizationProblem (optf, ps)
150- res = Optimization. solve (optprob, Adam (0.1 ); callback = callback (adtype), maxiters = 30 )
149+ optf= Optimization. OptimizationFunction ((θ, _)-> loss (model, θ), adtype)
150+ optprob= Optimization. OptimizationProblem (optf, ps)
151+ res= Optimization. solve (optprob, Adam (0.1 ); callback = callback (adtype), maxiters = 30 )
151152
152- actual_pdf = pdf .(data_dist, test_data)
153- learned_pdf = exp .(model (test_data, res. u)[1 ])
153+ actual_pdf= pdf .(data_dist, test_data)
154+ learned_pdf= exp .(model (test_data, res. u)[1 ])
154155
155156 @test ps != res. u
156157 @test totalvariation (learned_pdf, actual_pdf) / size (test_data, 2 ) < 0.25
157158end
158159
159160@testitem " Test for multivariate distribution and deterministic trace FFJORD" setup= [CNFTestSetup] tags= [:advancedneuralde ] begin
160- nn = Chain (Dense (2 , 2 , tanh))
161- tspan = (0.0f0 , 1.0f0 )
162- ffjord_mdl = FFJORD (nn, tspan, (2 ,), Tsit5 ())
163- ps, st = Lux. setup (Xoshiro (0 ), ffjord_mdl)
164- ps = ComponentArray (ps)
161+ nn= Chain (Dense (2 , 2 , tanh))
162+ tspan= (0.0f0 , 1.0f0 )
163+ ffjord_mdl= FFJORD (nn, tspan, (2 ,), Tsit5 ())
164+ ps, st= Lux. setup (Xoshiro (0 ), ffjord_mdl)
165+ ps= ComponentArray (ps)
165166
166- regularize = false
167- monte_carlo = false
167+ regularize= false
168+ monte_carlo= false
168169
169- μ = ones (Float32, 2 )
170- Σ = Diagonal ([7.0f0 , 7.0f0 ])
171- data_dist = MvNormal (μ, Σ)
172- train_data = Float32 .(rand (data_dist, 100 ))
173- test_data = Float32 .(rand (data_dist, 100 ))
170+ μ= ones (Float32, 2 )
171+ Σ= Diagonal ([7.0f0 , 7.0f0 ])
172+ data_dist= MvNormal (μ, Σ)
173+ train_data= Float32 .(rand (data_dist, 100 ))
174+ test_data= Float32 .(rand (data_dist, 100 ))
174175
175176 function loss (model, θ)
176- logpx, λ₁, λ₂ = model (train_data, θ)
177+ logpx, λ₁, λ₂= model (train_data, θ)
177178 return - mean (logpx)
178179 end
179180
180- adtype = Optimization. AutoZygote ()
181- st_ = (; st... , regularize, monte_carlo)
182- model = StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
181+ adtype= Optimization. AutoZygote ()
182+ st_= (; st... , regularize, monte_carlo)
183+ model= StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
183184
184- optf = Optimization. OptimizationFunction ((θ, _) -> loss (model, θ), adtype)
185- optprob = Optimization. OptimizationProblem (optf, ps)
186- res = Optimization. solve (
185+ optf= Optimization. OptimizationFunction ((θ, _)-> loss (model, θ), adtype)
186+ optprob= Optimization. OptimizationProblem (optf, ps)
187+ res= Optimization. solve (
187188 optprob, Adam (0.01 ); callback = callback (adtype), maxiters = 30 )
188189
189- actual_pdf = pdf (data_dist, test_data)
190- learned_pdf = exp .(model (test_data, res. u)[1 ])
190+ actual_pdf= pdf (data_dist, test_data)
191+ learned_pdf= exp .(model (test_data, res. u)[1 ])
191192
192193 @test ps != res. u
193194 @test totalvariation (learned_pdf, actual_pdf) / size (test_data, 2 ) < 0.25
194195end
195196
196197@testitem " Test for multivariate distribution and FFJORD with regularizers" setup= [CNFTestSetup] tags= [:advancedneuralde ] begin
197- nn = Chain (Dense (2 , 2 , tanh))
198- tspan = (0.0f0 , 1.0f0 )
199- ffjord_mdl = FFJORD (nn, tspan, (2 ,), Tsit5 ())
200- ps, st = Lux. setup (Xoshiro (0 ), ffjord_mdl)
201- ps = ComponentArray (ps) .* 0.001f0
198+ nn= Chain (Dense (2 , 2 , tanh))
199+ tspan= (0.0f0 , 1.0f0 )
200+ ffjord_mdl= FFJORD (nn, tspan, (2 ,), Tsit5 ())
201+ ps, st= Lux. setup (Xoshiro (0 ), ffjord_mdl)
202+ ps= ComponentArray (ps) .* 0.001f0
202203
203- regularize = true
204- monte_carlo = true
204+ regularize= true
205+ monte_carlo= true
205206
206- μ = ones (Float32, 2 )
207- Σ = Diagonal ([7.0f0 , 7.0f0 ])
208- data_dist = MvNormal (μ, Σ)
209- train_data = Float32 .(rand (data_dist, 100 ))
210- test_data = Float32 .(rand (data_dist, 100 ))
207+ μ= ones (Float32, 2 )
208+ Σ= Diagonal ([7.0f0 , 7.0f0 ])
209+ data_dist= MvNormal (μ, Σ)
210+ train_data= Float32 .(rand (data_dist, 100 ))
211+ test_data= Float32 .(rand (data_dist, 100 ))
211212
212213 function loss (model, θ)
213- logpx, λ₁, λ₂ = model (train_data, θ)
214+ logpx, λ₁, λ₂= model (train_data, θ)
214215 return - mean (logpx)
215216 end
216217
217- adtype = Optimization. AutoZygote ()
218- st_ = (; st... , regularize, monte_carlo)
219- model = StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
218+ adtype= Optimization. AutoZygote ()
219+ st_= (; st... , regularize, monte_carlo)
220+ model= StatefulLuxLayer {true} (ffjord_mdl, nothing , st_)
220221
221- optf = Optimization. OptimizationFunction ((θ, _) -> loss (model, θ), adtype)
222- optprob = Optimization. OptimizationProblem (optf, ps)
223- res = Optimization. solve (
222+ optf= Optimization. OptimizationFunction ((θ, _)-> loss (model, θ), adtype)
223+ optprob= Optimization. OptimizationProblem (optf, ps)
224+ res= Optimization. solve (
224225 optprob, Adam (0.01 ); callback = callback (adtype), maxiters = 30 )
225226
226- actual_pdf = pdf (data_dist, test_data)
227- learned_pdf = exp .(model (test_data, res. u)[1 ])
227+ actual_pdf= pdf (data_dist, test_data)
228+ learned_pdf= exp .(model (test_data, res. u)[1 ])
228229
229230 @test ps != res. u
230231 @test totalvariation (learned_pdf, actual_pdf) / size (test_data, 2 ) < 0.25
0 commit comments