Skip to content

Commit d2e7730

Browse files
committed
feat: support wrapper types as well
1 parent babf94f commit d2e7730

File tree

8 files changed

+96
-110
lines changed

8 files changed

+96
-110
lines changed

Project.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1616

1717
[weakdeps]
18-
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1918
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2019

2120
[extensions]
22-
ForwardDiffGPUArraysCoreExt = "GPUArraysCore"
2321
ForwardDiffStaticArraysExt = "StaticArrays"
2422

2523
[compat]
@@ -28,7 +26,6 @@ CommonSubexpressions = "0.3"
2826
DiffResults = "1.1"
2927
DiffRules = "1.4"
3028
DiffTests = "0.1"
31-
GPUArraysCore = "0.2"
3229
IrrationalConstants = "0.1, 0.2"
3330
LogExpFunctions = "0.3"
3431
NaNMath = "1"

ext/ForwardDiffGPUArraysCoreExt.jl

Lines changed: 0 additions & 65 deletions
This file was deleted.

src/ForwardDiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include("derivative.jl")
2121
include("gradient.jl")
2222
include("jacobian.jl")
2323
include("hessian.jl")
24+
include("utils.jl")
2425

2526
export DiffResults
2627

src/apiutils.jl

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -72,36 +72,46 @@ end
7272

7373
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
7474
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
75-
if isbitstype(V)
76-
for idx in structural_eachindex(duals, x)
77-
duals[idx] = Dual{T,V,N}(x[idx], seed)
78-
end
79-
else
80-
for idx in structural_eachindex(duals, x)
81-
if isassigned(x, idx)
75+
if supports_fast_scalar_indexing(duals)
76+
if isbitstype(V)
77+
for idx in structural_eachindex(duals, x)
8278
duals[idx] = Dual{T,V,N}(x[idx], seed)
83-
else
84-
Base._unsetindex!(duals, idx)
79+
end
80+
else
81+
for idx in structural_eachindex(duals, x)
82+
if isassigned(x, idx)
83+
duals[idx] = Dual{T,V,N}(x[idx], seed)
84+
else
85+
Base._unsetindex!(duals, idx)
86+
end
8587
end
8688
end
89+
else
90+
idxs = collect(structural_eachindex(duals, x))
91+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
8792
end
8893
return duals
8994
end
9095

9196
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
9297
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
93-
if isbitstype(V)
94-
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
95-
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
96-
end
97-
else
98-
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
99-
if isassigned(x, idx)
98+
if supports_fast_scalar_indexing(duals)
99+
if isbitstype(V)
100+
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
100101
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
101-
else
102-
Base._unsetindex!(duals, idx)
102+
end
103+
else
104+
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
105+
if isassigned(x, idx)
106+
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
107+
else
108+
Base._unsetindex!(duals, idx)
109+
end
103110
end
104111
end
112+
else
113+
idxs = collect(Iterators.take(structural_eachindex(duals, x), N))
114+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
105115
end
106116
return duals
107117
end
@@ -110,18 +120,23 @@ function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
110120
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
111121
offset = index - 1
112122
idxs = Iterators.drop(structural_eachindex(duals, x), offset)
113-
if isbitstype(V)
114-
for idx in idxs
115-
duals[idx] = Dual{T,V,N}(x[idx], seed)
116-
end
117-
else
118-
for idx in idxs
119-
if isassigned(x, idx)
123+
if supports_fast_scalar_indexing(duals)
124+
if isbitstype(V)
125+
for idx in idxs
120126
duals[idx] = Dual{T,V,N}(x[idx], seed)
121-
else
122-
Base._unsetindex!(duals, idx)
127+
end
128+
else
129+
for idx in idxs
130+
if isassigned(x, idx)
131+
duals[idx] = Dual{T,V,N}(x[idx], seed)
132+
else
133+
Base._unsetindex!(duals, idx)
134+
end
123135
end
124136
end
137+
else
138+
idxs = collect(idxs)
139+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
125140
end
126141
return duals
127142
end
@@ -130,18 +145,23 @@ function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
130145
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
131146
offset = index - 1
132147
idxs = Iterators.drop(structural_eachindex(duals, x), offset)
133-
if isbitstype(V)
134-
for (i, idx) in zip(1:chunksize, idxs)
135-
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
136-
end
137-
else
138-
for (i, idx) in zip(1:chunksize, idxs)
139-
if isassigned(x, idx)
148+
if supports_fast_scalar_indexing(duals)
149+
if isbitstype(V)
150+
for (i, idx) in zip(1:chunksize, idxs)
140151
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
141-
else
142-
Base._unsetindex!(duals, idx)
152+
end
153+
else
154+
for (i, idx) in zip(1:chunksize, idxs)
155+
if isassigned(x, idx)
156+
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
157+
else
158+
Base._unsetindex!(duals, idx)
159+
end
143160
end
144161
end
162+
else
163+
idxs = collect(Iterators.take(idxs, chunksize))
164+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
145165
end
146166
return duals
147167
end

src/gradient.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,29 @@ end
6565
extract_gradient!(::Type{T}, result::AbstractArray, y::Real) where {T} = fill!(result, zero(y))
6666
function extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}
6767
idxs = structural_eachindex(result)
68-
for (i, idx) in zip(1:npartials(dual), idxs)
69-
result[idx] = partials(T, dual, i)
68+
if supports_fast_scalar_indexing(result)
69+
for (i, idx) in zip(1:npartials(dual), idxs)
70+
result[idx] = partials(T, dual, i)
71+
end
72+
else
73+
fn = PartialsFn{T}(dual)
74+
idxs = collect(Iterators.take(idxs, npartials(dual)))
75+
result[idxs] .= fn.(1:length(idxs))
76+
return result
7077
end
7178
return result
7279
end
7380

7481
function extract_gradient_chunk!(::Type{T}, result, dual, index, chunksize) where {T}
75-
offset = index - 1
76-
idxs = Iterators.drop(structural_eachindex(result), offset)
77-
for (i, idx) in zip(1:chunksize, idxs)
78-
result[idx] = partials(T, dual, i)
82+
idxs = Iterators.drop(structural_eachindex(result), index - 1)
83+
if supports_fast_scalar_indexing(result)
84+
for (i, idx) in zip(1:chunksize, idxs)
85+
result[idx] = partials(T, dual, i)
86+
end
87+
else
88+
fn = PartialsFn{T}(dual)
89+
idxs = collect(Iterators.take(idxs, chunksize))
90+
result[idxs] .= fn.(1:length(idxs))
7991
end
8092
return result
8193
end

src/utils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# overload for array types that
2+
@inline supports_fast_scalar_indexing(::Array) = true
3+
4+
@inline function supports_fast_scalar_indexing(x::AbstractArray)
5+
parent(x) === x && return false
6+
return supports_fast_scalar_indexing(parent(x))
7+
end
8+
9+
# Helper function for broadcasting
10+
struct PartialsFn{T,D<:Dual}
11+
dual::D
12+
end
13+
PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
14+
15+
(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)

test/DualTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
118118
@test precision(typeof(FDNUM)) === precision(V)
119119
@test precision(NESTED_FDNUM) === precision(PRIMAL)
120120
@test precision(typeof(NESTED_FDNUM)) === precision(V)
121+
121122
@test precision(FDNUM; base=10) === precision(PRIMAL; base=10)
122123
@test precision(typeof(FDNUM); base=10) === precision(V; base=10)
123124
@test precision(NESTED_FDNUM; base=10) === precision(PRIMAL; base=10)

test/GradientTest.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,4 +277,9 @@ end
277277
@test Array(grad_jl) grad
278278
end
279279

280+
@testset "Scalar Indexing Checks" begin
281+
@test ForwardDiff.supports_fast_scalar_indexing(UnitLowerTriangular(view(rand(6, 6), 1:3, 1:3)))
282+
@test !ForwardDiff.supports_fast_scalar_indexing(UnitLowerTriangular(view(JLArray(rand(6, 6)), 1:3, 1:3)))
283+
end
284+
280285
end # module

0 commit comments

Comments
 (0)