@@ -86,7 +86,7 @@ group_size = 3
8686continuity_term = 200
8787
8888l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
89- Tsit5(), group_size; continuity_term)
89+ Tsit5(), group_size; continuity_term)
9090
9191function loss_function(data, pred)
9292 return sum(abs2, data - pred)
@@ -124,14 +124,16 @@ pd, pax = getdata(ps), getaxes(ps)
124124
125125function loss_single_shooting(p)
126126 ps = ComponentArray(p, pax)
127- return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
127+ loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
128128 Tsit5(), group_size; continuity_term)
129+ global preds = currpred
130+ return loss
129131end
130132
131133adtype = Optimization.AutoZygote()
132134optf = Optimization.OptimizationFunction((x, p) -> loss_single_shooting(x), adtype)
133135optprob = Optimization.OptimizationProblem(optf, pd)
134- res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)
136+ res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000 )
135137gif(anim, "single_shooting.gif"; fps = 15)
136138```
137139
0 commit comments