Skip to content

Commit babf94f

Browse files
avik-paldevmotion
andcommitted
fix: regression in non-fast scalar indexing support
fix: project toml for julia pre 1.9 fix: support gradient + more test coverage chore: relax version chore: remove 1.6 support and bump min version to 1.10 fix: apply suggestions from code review Co-authored-by: David Widmann <[email protected]> fix: use a struct instead of closure fix: sizecheck chore: remove GPUArraysCore Co-authored-by: David Widmann <[email protected]> fix: revert _take chore: remove 1.8 checks chore: remove 0.1 Co-authored-by: David Widmann <[email protected]>
1 parent 7decc58 commit babf94f

File tree

6 files changed

+107
-5
lines changed

6 files changed

+107
-5
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ForwardDiff"
22
uuid = "f6369f11-7733-5829-9624-2563aa707210"
3-
version = "1.1.0"
3+
version = "1.1.1"
44

55
[deps]
66
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
@@ -15,9 +15,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1616

1717
[weakdeps]
18+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1819
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1920

2021
[extensions]
22+
ForwardDiffGPUArraysCoreExt = "GPUArraysCore"
2123
ForwardDiffStaticArraysExt = "StaticArrays"
2224

2325
[compat]
@@ -26,6 +28,7 @@ CommonSubexpressions = "0.3"
2628
DiffResults = "1.1"
2729
DiffRules = "1.4"
2830
DiffTests = "0.1"
31+
GPUArraysCore = "0.2"
2932
IrrationalConstants = "0.1, 0.2"
3033
LogExpFunctions = "0.3"
3134
NaNMath = "1"
@@ -39,9 +42,10 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
3942
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
4043
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
4144
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
45+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
4246
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4347
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4448
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4549

4650
[targets]
47-
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"]
51+
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils", "JLArrays"]

ext/ForwardDiffGPUArraysCoreExt.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
module ForwardDiffGPUArraysCoreExt
2+
3+
using GPUArraysCore: AbstractGPUArray
4+
using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials
5+
6+
struct PartialsFn{T,D<:Dual}
7+
dual::D
8+
end
9+
PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
10+
11+
(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)
12+
13+
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
14+
seed::Partials{N,V}) where {T,V,N}
15+
idxs = collect(ForwardDiff.structural_eachindex(duals, x))
16+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
17+
return duals
18+
end
19+
20+
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
21+
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
22+
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(duals, x), N))
23+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
24+
return duals
25+
end
26+
27+
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
28+
seed::Partials{N,V}) where {T,V,N}
29+
offset = index - 1
30+
idxs = collect(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset))
31+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
32+
return duals
33+
end
34+
35+
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
36+
seeds::NTuple{N,Partials{N,V}}, chunksize) where {T,V,N}
37+
offset = index - 1
38+
idxs = collect(
39+
Iterators.take(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset), chunksize)
40+
)
41+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
42+
return duals
43+
end
44+
45+
# gradient
46+
function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
47+
dual::Dual) where {T}
48+
fn = PartialsFn{T}(dual)
49+
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual)))
50+
result[idxs] .= fn.(1:length(idxs))
51+
return result
52+
end
53+
54+
function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual,
55+
index, chunksize) where {T}
56+
fn = PartialsFn{T}(dual)
57+
offset = index - 1
58+
idxs = collect(
59+
Iterators.take(Iterators.drop(ForwardDiff.structural_eachindex(result), offset), chunksize)
60+
)
61+
result[idxs] .= fn.(1:length(idxs))
62+
return result
63+
end
64+
65+
end

src/dual.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,6 @@ Base.copy(d::Dual) = d
298298
Base.eps(d::Dual) = eps(value(d))
299299
Base.eps(::Type{D}) where {D<:Dual} = eps(valtype(D))
300300

301-
# The `base` keyword was added in Julia 1.8:
302-
# https://github.com/JuliaLang/julia/pull/42428
303301
Base.precision(d::Dual; base::Integer=2) = precision(value(d); base=base)
304302
function Base.precision(::Type{D}; base::Integer=2) where {D<:Dual}
305303
precision(valtype(D); base=base)

test/DualTest.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
118118
@test precision(typeof(FDNUM)) === precision(V)
119119
@test precision(NESTED_FDNUM) === precision(PRIMAL)
120120
@test precision(typeof(NESTED_FDNUM)) === precision(V)
121-
122121
@test precision(FDNUM; base=10) === precision(PRIMAL; base=10)
123122
@test precision(typeof(FDNUM); base=10) === precision(V; base=10)
124123
@test precision(NESTED_FDNUM; base=10) === precision(PRIMAL; base=10)

test/GradientTest.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using ForwardDiff
99
using ForwardDiff: Dual, Tag
1010
using StaticArrays
1111
using DiffTests
12+
using JLArrays
1213

1314
include(joinpath(dirname(@__FILE__), "utils.jl"))
1415

@@ -255,4 +256,25 @@ end
255256
end
256257
end
257258

259+
@testset "GPUArraysCore" begin
260+
fn(x) = sum(x .^ 2 ./ 2)
261+
262+
x = [1.0, 2.0, 3.0]
263+
x_jl = JLArray(x)
264+
265+
grad = ForwardDiff.gradient(fn, x)
266+
grad_jl = ForwardDiff.gradient(fn, x_jl)
267+
268+
@test grad_jl isa JLArray
269+
@test Array(grad_jl) grad
270+
271+
cfg = ForwardDiff.GradientConfig(
272+
fn, x_jl, ForwardDiff.Chunk{2}(), ForwardDiff.Tag(fn, eltype(x))
273+
)
274+
grad_jl = ForwardDiff.gradient(fn, x_jl, cfg)
275+
276+
@test grad_jl isa JLArray
277+
@test Array(grad_jl) grad
278+
end
279+
258280
end # module

test/JacobianTest.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using ForwardDiff: Dual, Tag, JacobianConfig
88
using StaticArrays
99
using DiffTests
1010
using LinearAlgebra
11+
using JLArrays
1112

1213
include(joinpath(dirname(@__FILE__), "utils.jl"))
1314

@@ -308,4 +309,17 @@ end
308309
end
309310
end
310311

312+
@testset "GPUArraysCore" begin
313+
f(x) = x .^ 2 ./ 2
314+
315+
x = [1.0, 2.0, 3.0]
316+
x_jl = JLArray(x)
317+
318+
jac = ForwardDiff.jacobian(f, x)
319+
jac_jl = ForwardDiff.jacobian(f, x_jl)
320+
321+
@test jac_jl isa JLArray
322+
@test Array(jac_jl) jac
323+
end
324+
311325
end # module

0 commit comments

Comments
 (0)