Skip to content

Commit 240e7e8

Browse files
authored
fix: disallow complex basis vectors for now (#669)
* fix: disallow complex basis vectors for now * Skip complex tests * Fix fromprimitive * Fix basis
1 parent 9b17e3e commit 240e7e8

File tree

6 files changed

+26
-22
lines changed

6 files changed

+26
-22
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.28"
4+
version = "0.6.29"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using ReverseDiff:
2626

2727
DI.check_available(::AutoReverseDiff) = true
2828

29-
function DI.basis(::AutoReverseDiff, a::AbstractArray{T}, i) where {T}
29+
function DI.basis(::AutoReverseDiff, a::AbstractArray{T}, i) where {T<:Real}
3030
return DI.OneElement(i, one(T), a)
3131
end
3232

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ end
7474
function prepare_derivative(
7575
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
7676
) where {F,C}
77-
pushforward_prep = prepare_pushforward(f, backend, x, (one(x),), contexts...)
77+
pushforward_prep = prepare_pushforward(f, backend, x, (realone(x),), contexts...)
7878
return PushforwardDerivativePrep(pushforward_prep)
7979
end
8080

8181
function prepare_derivative(
8282
f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}
8383
) where {F,C}
84-
pushforward_prep = prepare_pushforward(f!, y, backend, x, (one(x),), contexts...)
84+
pushforward_prep = prepare_pushforward(f!, y, backend, x, (realone(x),), contexts...)
8585
return PushforwardDerivativePrep(pushforward_prep)
8686
end
8787

@@ -95,7 +95,7 @@ function value_and_derivative(
9595
contexts::Vararg{Context,C},
9696
) where {F,C}
9797
y, ty = value_and_pushforward(
98-
f, prep.pushforward_prep, backend, x, (one(x),), contexts...
98+
f, prep.pushforward_prep, backend, x, (realone(x),), contexts...
9999
)
100100
return y, only(ty)
101101
end
@@ -109,7 +109,7 @@ function value_and_derivative!(
109109
contexts::Vararg{Context,C},
110110
) where {F,C}
111111
y, _ = value_and_pushforward!(
112-
f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
112+
f, (der,), prep.pushforward_prep, backend, x, (realone(x),), contexts...
113113
)
114114
return y, der
115115
end
@@ -121,7 +121,7 @@ function derivative(
121121
x,
122122
contexts::Vararg{Context,C},
123123
) where {F,C}
124-
ty = pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...)
124+
ty = pushforward(f, prep.pushforward_prep, backend, x, (realone(x),), contexts...)
125125
return only(ty)
126126
end
127127

@@ -133,7 +133,7 @@ function derivative!(
133133
x,
134134
contexts::Vararg{Context,C},
135135
) where {F,C}
136-
pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
136+
pushforward!(f, (der,), prep.pushforward_prep, backend, x, (realone(x),), contexts...)
137137
return der
138138
end
139139

@@ -148,7 +148,7 @@ function value_and_derivative(
148148
contexts::Vararg{Context,C},
149149
) where {F,C}
150150
y, ty = value_and_pushforward(
151-
f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...
151+
f!, y, prep.pushforward_prep, backend, x, (realone(x),), contexts...
152152
)
153153
return y, only(ty)
154154
end
@@ -163,7 +163,7 @@ function value_and_derivative!(
163163
contexts::Vararg{Context,C},
164164
) where {F,C}
165165
y, _ = value_and_pushforward!(
166-
f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
166+
f!, y, (der,), prep.pushforward_prep, backend, x, (realone(x),), contexts...
167167
)
168168
return y, der
169169
end
@@ -176,7 +176,7 @@ function derivative(
176176
x,
177177
contexts::Vararg{Context,C},
178178
) where {F,C}
179-
ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...)
179+
ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (realone(x),), contexts...)
180180
return only(ty)
181181
end
182182

@@ -189,7 +189,9 @@ function derivative!(
189189
x,
190190
contexts::Vararg{Context,C},
191191
) where {F,C}
192-
pushforward!(f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
192+
pushforward!(
193+
f!, y, (der,), prep.pushforward_prep, backend, x, (realone(x),), contexts...
194+
)
193195
return der
194196
end
195197

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
abstract type FromPrimitive <: AbstractADType end
22

3-
function basis(fromprim::FromPrimitive, x::AbstractArray, i)
3+
function basis(fromprim::FromPrimitive, x::AbstractArray{<:Real}, i)
44
return basis(fromprim.backend, x, i)
55
end
66

7-
function multibasis(fromprim::FromPrimitive, x::AbstractArray, inds)
7+
function multibasis(fromprim::FromPrimitive, x::AbstractArray{<:Real}, inds)
88
return multibasis(fromprim.backend, x, inds)
99
end
1010

1111
check_available(fromprim::FromPrimitive) = check_available(fromprim.backend)
1212
inplace_support(fromprim::FromPrimitive) = inplace_support(fromprim.backend)
1313

14-
function BatchSizeSettings(fromprim::FromPrimitive, x::AbstractArray)
14+
function BatchSizeSettings(fromprim::FromPrimitive, x::AbstractArray{<:Real})
1515
return BatchSizeSettings(fromprim.backend, x)
1616
end
1717

DifferentiationInterface/src/utils/basis.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Construct the `i`-th standard basis array in the vector space of `a` with elemen
4646
If an AD backend benefits from a more specialized basis array implementation,
4747
this function can be extended on the backend type.
4848
"""
49-
basis(::AbstractADType, a::AbstractArray, i) = basis(a, i)
49+
basis(::AbstractADType, a::AbstractArray{<:Real}, i) = basis(a, i)
5050

5151
"""
5252
multibasis(backend, a::AbstractArray, inds::AbstractVector)
@@ -58,16 +58,18 @@ Construct the sum of the `i`-th standard basis arrays in the vector space of `a`
5858
If an AD backend benefits from a more specialized basis array implementation,
5959
this function can be extended on the backend type.
6060
"""
61-
multibasis(::AbstractADType, a::AbstractArray, inds) = multibasis(a, inds)
61+
multibasis(::AbstractADType, a::AbstractArray{<:Real}, inds) = multibasis(a, inds)
6262

63-
function basis(a::AbstractArray{T,N}, i) where {T,N}
63+
function basis(a::AbstractArray{T,N}, i) where {T<:Real,N}
6464
return zero(a) + OneElement(i, one(T), a)
6565
end
6666

67-
function multibasis(a::AbstractArray{T,N}, inds::AbstractVector) where {T,N}
67+
function multibasis(a::AbstractArray{T,N}, inds::AbstractVector) where {T<:Real,N}
6868
seed = zero(a)
6969
for i in inds
7070
seed += OneElement(i, one(T), a)
7171
end
7272
return seed
7373
end
74+
75+
realone(x::Real) = one(x)

DifferentiationInterface/test/Back/FiniteDiff/test.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ test_differentiation(
2626
@testset verbose = true "Complex number support" begin
2727
backend = AutoSparse(AutoFiniteDiff(); coloring_algorithm=GreedyColoringAlgorithm())
2828
x = float.(1:3) .+ im
29-
@test_nowarn jacobian(identity, backend, x)
30-
@test_nowarn jacobian(copyto!, similar(x), backend, x)
31-
@test_nowarn hessian(sum, backend, x)
29+
@test_skip @test_nowarn jacobian(identity, backend, x)
30+
@test_skip @test_nowarn jacobian(copyto!, similar(x), backend, x)
31+
@test_skip @test_nowarn hessian(sum, backend, x)
3232
end

0 commit comments

Comments
 (0)