Skip to content

Commit 14068fe

Browse files
committed
Make grad_J_a act not-in-place
This is much easier for a user to deal with. It really shouldn't make any difference for performance, and in fact, may enhance it (the old implementation was using `copyto!`). In situations where it would make a difference, like for insanely large pulses, there's always the option of using a non-allocating functor.
1 parent c85edec commit 14068fe

File tree

4 files changed

+32
-33
lines changed

4 files changed

+32
-33
lines changed

ext/QuantumControlFiniteDifferencesExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ end
7272

7373

7474
function make_automatic_grad_J_a(J_a, tlist, ::Val{:FiniteDifferences})
75-
function automatic_grad_J_a!(∇J_a, pulsevals, tlist)
75+
function automatic_grad_J_a(pulsevals, tlist)
7676
func = pulsevals -> J_a(pulsevals, tlist)
7777
fdm = FiniteDifferences.central_fdm(5, 1)
7878
∇J_a_fdm = FiniteDifferences.grad(fdm, func, pulsevals)[1]
79-
copyto!(∇J_a, ∇J_a_fdm)
79+
return ∇J_a_fdm
8080
end
81-
return automatic_grad_J_a!
81+
return automatic_grad_J_a
8282
end
8383

8484
function make_gate_chi(J_T_U, trajectories, ::Val{:FiniteDifferences}; kwargs...)

ext/QuantumControlZygoteExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ end
7373

7474

7575
function make_automatic_grad_J_a(J_a, tlist, ::Val{:Zygote})
76-
function automatic_grad_J_a!(∇J_a, pulsevals, tlist)
76+
function automatic_grad_J_a(pulsevals, tlist)
7777
func = pulsevals -> J_a(pulsevals, tlist)
7878
∇J_a_zygote = Zygote.gradient(func, pulsevals)[1]
79-
copyto!(∇J_a, ∇J_a_zygote)
79+
return ∇J_a_zygote
8080
end
81-
return automatic_grad_J_a!
81+
return automatic_grad_J_a
8282
end
8383

8484

src/functionals.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -317,18 +317,18 @@ end
317317
Return a function to evaluate ``∂J_a/∂ϵ_{ln}`` for a pulse value running cost.
318318
319319
```julia
320-
grad_J_a! = make_grad_J_a(
320+
grad_J_a = make_grad_J_a(
321321
J_a,
322322
tlist;
323323
mode=:any,
324324
automatic=:default,
325325
)
326326
```
327327
328-
returns a function so that `grad_J_a!(∇J_a, pulsevals, tlist)` sets
329-
``∂J_a/∂ϵ_{ln}`` as the elements of the (vectorized) `∇J_a`. The function `J_a`
330-
must have the interface `J_a(pulsevals, tlist)`, see, e.g.,
331-
`J_a_fluence`.
328+
returns a function so that `∇J_a = grad_J_a(pulsevals, tlist)` sets
329+
that retrurns a vector `∇J_a` containing the vectorized elements
330+
``∂J_a/∂ϵ_{ln}``. The function `J_a` must have the interface `J_a(pulsevals,
331+
tlist)`, see, e.g., [`J_a_fluence`](@ref).
332332
333333
The parameters `mode` and `automatic` are handled as in [`make_chi`](@ref),
334334
where `mode` is one of `:any`, `:analytic`, `:automatic`, and `automatic` is
@@ -341,10 +341,11 @@ refers to the framework set with `QuantumControl.set_default_ad_framework`.
341341
new `J_a` function, define a new method `make_analytic_grad_J_a` like so:
342342
343343
```julia
344-
make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence!
344+
make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence
345345
```
346346
347-
which links `make_grad_J_a` for `J_a_fluence` to `grad_J_a_fluence!`.
347+
which links `make_grad_J_a` for [`J_a_fluence`](@ref) to
348+
[`grad_J_a_fluence`](@ref).
348349
"""
349350
function make_grad_J_a(J_a, tlist; mode=:any, automatic=:default)
350351
if mode == :any
@@ -890,19 +891,20 @@ end
890891
"""Analytic derivative for [`J_a_fluence`](@ref).
891892
892893
```julia
893-
grad_J_a_fluence!(∇J_a, pulsevals, tlist)
894+
∇J_a = grad_J_a_fluence(pulsevals, tlist)
894895
```
895896
896-
sets the (vectorized) elements of `∇J_a` to ``2 ϵ_{nl} dt``, where
897-
``ϵ_{nl}`` are the (vectorized) elements of `pulsevals` and ``dt`` is the time
898-
step, taken from the first time interval of `tlist` and assumed to be uniform.
897+
returns the `∇J_a`, which contains the (vectorized) elements ``2 ϵ_{nl} dt``,
898+
where ``ϵ_{nl}`` are the (vectorized) elements of `pulsevals` and ``dt`` is the
899+
time step, taken from the first time interval of `tlist` and assumed to be
900+
uniform.
899901
"""
900-
function grad_J_a_fluence!(∇J_a, pulsevals, tlist)
902+
function grad_J_a_fluence(pulsevals, tlist)
901903
dt = tlist[begin+1] - tlist[begin]
902-
axpy!(2 * dt, pulsevals, ∇J_a)
904+
return (2 * dt) * pulsevals
903905
end
904906

905907

906-
make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence!
908+
make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence
907909

908910
end

test/test_functionals.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using QuantumControl.Functionals:
66
J_T_re,
77
J_T_ss,
88
J_a_fluence,
9-
grad_J_a_fluence!,
9+
grad_J_a_fluence,
1010
make_grad_J_a,
1111
make_chi,
1212
chi_re,
@@ -101,20 +101,17 @@ end
101101
J_a_val = J_a_fluence(pulsevals, tlist)
102102
@test J_a_val > 0.0
103103

104-
G1 = copy(wrk.grad_J_a)
105-
grad_J_a_fluence!(G1, pulsevals, tlist)
104+
G1 = grad_J_a_fluence(pulsevals, tlist)
106105

107-
grad_J_a_zygote! = make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=Zygote)
108-
@test grad_J_a_zygote! grad_J_a_fluence!
109-
G2 = copy(wrk.grad_J_a)
110-
grad_J_a_zygote!(G2, pulsevals, tlist)
106+
grad_J_a_zygote = make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=Zygote)
107+
@test grad_J_a_zygote grad_J_a_fluence
108+
G2 = grad_J_a_zygote(pulsevals, tlist)
111109

112-
grad_J_a_fdm! =
110+
grad_J_a_fdm =
113111
make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=FiniteDifferences)
114-
@test grad_J_a_fdm! grad_J_a_fluence!
115-
@test grad_J_a_fdm! grad_J_a_zygote!
116-
G3 = copy(wrk.grad_J_a)
117-
grad_J_a_fdm!(G3, pulsevals, tlist)
112+
@test grad_J_a_fdm grad_J_a_fluence
113+
@test grad_J_a_fdm grad_J_a_zygote
114+
G3 = grad_J_a_fdm(pulsevals, tlist)
118115

119116
@test 0.0 norm(G2 - G1) < 1e-12 # zygote can be exact
120117
@test 0.0 < norm(G3 - G1) < 1e-12 # fdm should not be exact
@@ -324,7 +321,7 @@ end
324321
end
325322
grad_J_a = capture.value
326323
@test_throws DomainError begin
327-
grad_J_a(1, 1, tlist)
324+
grad_J_a(1, tlist)
328325
end
329326

330327
QuantumControl.set_default_ad_framework(nothing; quiet=true)

0 commit comments

Comments
 (0)