Skip to content

Commit 29d0f03

Browse files
Get all passing
1 parent 8191206 commit 29d0f03

File tree

5 files changed

+318
-263
lines changed

5 files changed

+318
-263
lines changed

lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreEnzymeCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function EnzymeCore.EnzymeRules.inactive_noinl(
66
true
77
end
88
function EnzymeCore.EnzymeRules.inactive_noinl(
9-
::typeof(OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!), args...)
9+
::typeof(OrdinaryDiffEqCore.fixed_t_for_tstop_error!), args...)
1010
true
1111
end
1212
function EnzymeCore.EnzymeRules.inactive_noinl(

lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
88
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
99
typeof(OrdinaryDiffEqCore.SciMLBase.check_error), Any}
1010
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
11-
typeof(OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!), Any, Any}
11+
typeof(OrdinaryDiffEqCore.fixed_t_for_tstop_error!), Any, Any}
1212
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
1313
typeof(OrdinaryDiffEqCore.final_progress), Any}
1414

lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ function last_step_failed(integrator::ODEIntegrator)
7373
end
7474

7575
function modify_dt_for_tstops!(integrator)
76+
integrator.t, integrator.dt
77+
tdir_tstop = first_tstop(integrator)
78+
distance_to_tstop = abs(tdir_tstop - integrator.tdir * integrator.t)
79+
7680
if has_tstop(integrator)
7781
tdir_t = integrator.tdir * integrator.t
7882
tdir_tstop = first_tstop(integrator)
@@ -82,6 +86,7 @@ function modify_dt_for_tstops!(integrator)
8286
original_dt = abs(integrator.dt)
8387

8488
if integrator.opts.adaptive
89+
integrator.dtpropose = original_dt
8590
if original_dt < distance_to_tstop
8691
# Normal step, no tstop interference
8792
integrator.next_step_tstop = false
@@ -124,7 +129,7 @@ function handle_tstop_step!(integrator)
124129
perform_step!(integrator, integrator.cache)
125130
end
126131

127-
# Flag will be reset in fixed_t_for_floatingpoint_error! when t is updated
132+
# Flag will be reset in fixed_t_for_tstop_error! when t is updated
128133
end
129134

130135
# Want to extend savevalues! for DDEIntegrator
@@ -186,7 +191,6 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool}
186191
end
187192
if force_save || (integrator.opts.save_everystep &&
188193
(isempty(integrator.sol.t) ||
189-
(integrator.t !== integrator.sol.t[end]) &&
190194
(integrator.opts.save_end || integrator.t !== integrator.sol.prob.tspan[2])
191195
))
192196
integrator.saveiter += 1
@@ -311,12 +315,20 @@ function _loopfooter!(integrator)
311315
if integrator.accept_step # Accept
312316
increment_accept!(integrator.stats)
313317
integrator.last_stepfail = false
318+
integrator.tprev = integrator.t
319+
320+
if integrator.next_step_tstop
321+
# Step controller dt is overly pessimistic, since dt = time to tstop
322+
# For example, if super dense time, dt = eps(t), so the next step is tiny
323+
# Thus if snap to tstop, let the step controller assume dt was the pre-fixed version
324+
integrator.dt = integrator.dtpropose
325+
end
326+
integrator.t = fixed_t_for_tstop_error!(integrator, ttmp)
327+
314328
dtnew = DiffEqBase.value(step_accept_controller!(integrator,
315329
integrator.alg,
316330
q)) *
317331
oneunit(integrator.dt)
318-
integrator.tprev = integrator.t
319-
integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp)
320332
calc_dt_propose!(integrator, dtnew)
321333
handle_callbacks!(integrator)
322334
else # Reject
@@ -325,7 +337,7 @@ function _loopfooter!(integrator)
325337
elseif !integrator.opts.adaptive #Not adaptive
326338
increment_accept!(integrator.stats)
327339
integrator.tprev = integrator.t
328-
integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp)
340+
integrator.t = fixed_t_for_tstop_error!(integrator, ttmp)
329341
integrator.last_stepfail = false
330342
integrator.accept_step = true
331343
integrator.dtpropose = integrator.dt
@@ -364,7 +376,7 @@ function log_step!(progress_name, progress_id, progress_message, dt, u, p, t, ts
364376
progress=(t-t1)/(t2-t1))
365377
end
366378

367-
function fixed_t_for_floatingpoint_error!(integrator, ttmp)
379+
function fixed_t_for_tstop_error!(integrator, ttmp)
368380
# If we're in tstop snap mode, use exact tstop target
369381
if integrator.next_step_tstop
370382
# Reset the flag now that we're snapping to tstop
@@ -501,10 +513,7 @@ function handle_tstop!(integrator)
501513
tdir_t = integrator.tdir * integrator.t
502514
tdir_tstop = first_tstop(integrator)
503515
if tdir_t == tdir_tstop
504-
while tdir_t == tdir_tstop #remove all redundant copies
505-
res = pop_tstop!(integrator)
506-
has_tstop(integrator) ? (tdir_tstop = first_tstop(integrator)) : break
507-
end
516+
res = pop_tstop!(integrator)
508517
integrator.just_hit_tstop = true
509518
elseif tdir_t > tdir_tstop
510519
if !integrator.dtchangeable

lib/OrdinaryDiffEqCore/src/solve.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ function SciMLBase.__init(
506506
next_step_tstop = false
507507
tstop_target = zero(t)
508508
isout = false
509-
accept_step = false
509+
accept_step = true
510510
force_stepfail = false
511511
last_stepfail = false
512512
do_error_check = true
@@ -605,7 +605,8 @@ end
605605

606606
function SciMLBase.solve!(integrator::ODEIntegrator)
607607
@inbounds while !isempty(integrator.opts.tstops)
608-
while integrator.tdir * integrator.t < first(integrator.opts.tstops)
608+
first_tstop = first(integrator.opts.tstops)
609+
while integrator.tdir * integrator.t <= first_tstop
609610
loopheader!(integrator)
610611
if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success
611612
return integrator.sol
@@ -617,9 +618,11 @@ function SciMLBase.solve!(integrator::ODEIntegrator)
617618
else
618619
perform_step!(integrator, integrator.cache)
619620
end
620-
621+
622+
should_exit = integrator.next_step_tstop
623+
621624
loopfooter!(integrator)
622-
if isempty(integrator.opts.tstops)
625+
if isempty(integrator.opts.tstops) || should_exit
623626
break
624627
end
625628
end
@@ -671,11 +674,11 @@ end
671674

672675
for t in tstops
673676
tdir_t = tdir * t
674-
tdir_t0 < tdir_t tdir_tf && push!(tstops_internal, tdir_t)
677+
tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t)
675678
end
676679
for t in d_discontinuities
677680
tdir_t = tdir * t
678-
tdir_t0 < tdir_t tdir_tf && push!(tstops_internal, tdir_t)
681+
tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t)
679682
end
680683
push!(tstops_internal, tdir_tf)
681684

0 commit comments

Comments
 (0)