Skip to content

Commit 948fbb4

Browse files
authored
Loosen AbstractArray restrictions where these are not necessary (#63)
* remove `AbstractArray` annotations for rrules * remove `AbstractArray` annotations for implementations
1 parent ab9f6f3 commit 948fbb4

File tree

9 files changed

+33
-51
lines changed

9 files changed

+33
-51
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using LinearAlgebra
1111

1212
MatrixAlgebraKit.iszerotangent(::AbstractZero) = true
1313

14-
function ChainRulesCore.rrule(::typeof(copy_input), f, A::AbstractMatrix)
14+
function ChainRulesCore.rrule(::typeof(copy_input), f, A)
1515
project = ProjectTo(A)
1616
copy_input_pullback(ΔA) = (NoTangent(), NoTangent(), project(unthunk(ΔA)))
1717
return copy_input(f, A), copy_input_pullback
@@ -20,7 +20,7 @@ end
2020
for qr_f in (:qr_compact, :qr_full)
2121
qr_f! = Symbol(qr_f, '!')
2222
@eval begin
23-
function ChainRulesCore.rrule(::typeof($qr_f!), A::AbstractMatrix, QR, alg)
23+
function ChainRulesCore.rrule(::typeof($qr_f!), A, QR, alg)
2424
Ac = copy_input($qr_f, A)
2525
QR = $(qr_f!)(Ac, QR, alg)
2626
function qr_pullback(ΔQR)
@@ -58,7 +58,7 @@ end
5858
for lq_f in (:lq_compact, :lq_full)
5959
lq_f! = Symbol(lq_f, '!')
6060
@eval begin
61-
function ChainRulesCore.rrule(::typeof($lq_f!), A::AbstractMatrix, LQ, alg)
61+
function ChainRulesCore.rrule(::typeof($lq_f!), A, LQ, alg)
6262
Ac = copy_input($lq_f, A)
6363
LQ = $(lq_f!)(Ac, LQ, alg)
6464
function lq_pullback(ΔLQ)
@@ -102,7 +102,7 @@ for eig in (:eig, :eigh)
102102
eig_t_pb = Symbol(eig, "_trunc_pullback")
103103
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
104104
@eval begin
105-
function ChainRulesCore.rrule(::typeof($eig_f!), A::AbstractMatrix, DV, alg)
105+
function ChainRulesCore.rrule(::typeof($eig_f!), A, DV, alg)
106106
Ac = copy_input($eig_f, A)
107107
DV = $(eig_f!)(Ac, DV, alg)
108108
function $eig_pb(ΔDV)
@@ -115,10 +115,7 @@ for eig in (:eig, :eigh)
115115
end
116116
return DV, $eig_pb
117117
end
118-
function ChainRulesCore.rrule(
119-
::typeof($eig_t!), A::AbstractMatrix, DV,
120-
alg::TruncatedAlgorithm
121-
)
118+
function ChainRulesCore.rrule(::typeof($eig_t!), A, DV, alg::TruncatedAlgorithm)
122119
Ac = copy_input($eig_f, A)
123120
DV = $(eig_f!)(Ac, DV, alg.alg)
124121
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
@@ -141,7 +138,7 @@ end
141138
for svd_f in (:svd_compact, :svd_full)
142139
svd_f! = Symbol(svd_f, "!")
143140
@eval begin
144-
function ChainRulesCore.rrule(::typeof($svd_f!), A::AbstractMatrix, USVᴴ, alg)
141+
function ChainRulesCore.rrule(::typeof($svd_f!), A, USVᴴ, alg)
145142
Ac = copy_input($svd_f, A)
146143
USVᴴ = $(svd_f!)(Ac, USVᴴ, alg)
147144
function svd_pullback(ΔUSVᴴ)
@@ -157,10 +154,7 @@ for svd_f in (:svd_compact, :svd_full)
157154
end
158155
end
159156

160-
function ChainRulesCore.rrule(
161-
::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ,
162-
alg::TruncatedAlgorithm
163-
)
157+
function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm)
164158
Ac = copy_input(svd_compact, A)
165159
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
166160
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
@@ -178,7 +172,7 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind)
178172
return svd_trunc_pullback
179173
end
180174

181-
function ChainRulesCore.rrule(::typeof(left_polar!), A::AbstractMatrix, WP, alg)
175+
function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg)
182176
Ac = copy_input(left_polar, A)
183177
WP = left_polar!(Ac, WP, alg)
184178
function left_polar_pullback(ΔWP)
@@ -192,7 +186,7 @@ function ChainRulesCore.rrule(::typeof(left_polar!), A::AbstractMatrix, WP, alg)
192186
return WP, left_polar_pullback
193187
end
194188

195-
function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, alg)
189+
function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg)
196190
Ac = copy_input(left_polar, A)
197191
PWᴴ = right_polar!(Ac, PWᴴ, alg)
198192
function right_polar_pullback(ΔPWᴴ)

src/implementations/eig.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
function copy_input(::typeof(eig_full), A::AbstractMatrix)
44
return copy!(similar(A, float(eltype(A))), A)
55
end
6-
function copy_input(::typeof(eig_vals), A)
7-
return copy_input(eig_full, A)
8-
end
6+
copy_input(::typeof(eig_vals), A) = copy_input(eig_full, A)
97
copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)
108

119
copy_input(::typeof(eig_full), A::Diagonal) = copy(A)
@@ -67,7 +65,7 @@ function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::AbstractAlg
6765
D = similar(A, Tc, n)
6866
return D
6967
end
70-
function initialize_output(::typeof(eig_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm)
68+
function initialize_output(::typeof(eig_trunc!), A, alg::TruncatedAlgorithm)
7169
return initialize_output(eig_full!, A, alg.alg)
7270
end
7371

@@ -108,7 +106,7 @@ function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm)
108106
return D
109107
end
110108

111-
function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
109+
function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
112110
D, V = eig_full!(A, DV, alg.alg)
113111
return first(truncate(eig_trunc!, (D, V), alg.trunc))
114112
end
@@ -131,7 +129,7 @@ end
131129

132130
# GPU logic
133131
# ---------
134-
_gpu_geev!(A::AbstractMatrix, D, V) = throw(MethodError(_gpu_geev!, (A, D, V)))
132+
_gpu_geev!(A, D, V) = throw(MethodError(_gpu_geev!, (A, D, V)))
135133

136134
function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm)
137135
check_input(eig_full!, A, DV, alg)

src/implementations/eigh.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
function copy_input(::typeof(eigh_full), A::AbstractMatrix)
44
return copy!(similar(A, float(eltype(A))), A)
55
end
6-
function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
7-
return copy_input(eigh_full, A)
8-
end
6+
copy_input(::typeof(eigh_vals), A) = copy_input(eigh_full, A)
97
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
108

119
copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)
@@ -65,7 +63,7 @@ function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::AbstractAl
6563
D = similar(A, real(eltype(A)), n)
6664
return D
6765
end
68-
function initialize_output(::typeof(eigh_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm)
66+
function initialize_output(::typeof(eigh_trunc!), A, alg::TruncatedAlgorithm)
6967
return initialize_output(eigh_full!, A, alg.alg)
7068
end
7169

@@ -111,7 +109,7 @@ function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
111109
return D
112110
end
113111

114-
function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
112+
function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
115113
D, V = eigh_full!(A, DV, alg.alg)
116114
return first(truncate(eigh_trunc!, (D, V), alg.trunc))
117115
end

src/implementations/gen_eig.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
function copy_input(::typeof(gen_eig_full), A::AbstractMatrix, B::AbstractMatrix)
44
return copy!(similar(A, float(eltype(A))), A), copy!(similar(B, float(eltype(B))), B)
55
end
6-
function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix)
7-
return copy_input(gen_eig_full, A, B)
8-
end
6+
copy_input(::typeof(gen_eig_vals), A, B) = copy_input(gen_eig_full, A, B)
97

108
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm)
119
ma, na = size(A)

src/implementations/lq.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# Inputs
22
# ------
3-
for f in (:lq_full, :lq_compact, :lq_null)
4-
@eval function copy_input(::typeof($f), A)
5-
return copy!(similar(A, float(eltype(A))), A)
6-
end
7-
@eval copy_input(::typeof($f), A::Diagonal) = copy(A)
8-
end
3+
copy_input(::typeof(lq_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A)
4+
copy_input(::typeof(lq_compact), A) = copy_input(lq_full, A)
5+
copy_input(::typeof(lq_null), A) = copy_input(lq_full, A)
6+
7+
copy_input(::typeof(lq_full), A::Diagonal) = copy(A)
98

109
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
1110
m, n = size(A)

src/implementations/qr.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# Inputs
22
# ------
3-
for f in (:qr_full, :qr_compact, :qr_null)
4-
@eval function copy_input(::typeof($f), A)
5-
return copy!(similar(A, float(eltype(A))), A)
6-
end
7-
@eval copy_input(::typeof($f), A::Diagonal) = copy(A)
8-
end
3+
copy_input(::typeof(qr_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A)
4+
copy_input(::typeof(qr_compact), A) = copy_input(qr_full, A)
5+
copy_input(::typeof(qr_null), A) = copy_input(qr_full, A)
6+
7+
copy_input(::typeof(qr_full), A::Diagonal) = copy(A)
98

109
function check_input(::typeof(qr_full!), A::AbstractMatrix, QR, ::AbstractAlgorithm)
1110
m, n = size(A)

src/implementations/schur.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Inputs
22
# ------
3-
copy_input(::typeof(schur_full), A::AbstractMatrix) = copy_input(eig_full, A)
4-
copy_input(::typeof(schur_vals), A::AbstractMatrix) = copy_input(eig_vals, A)
3+
copy_input(::typeof(schur_full), A) = copy_input(eig_full, A)
4+
copy_input(::typeof(schur_vals), A) = copy_input(eig_vals, A)
55

66
# check input
77
function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::AbstractAlgorithm)

src/implementations/svd.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Inputs
22
# ------
3-
function copy_input(::typeof(svd_full), A)
4-
return copy!(similar(A, float(eltype(A))), A)
5-
end
3+
copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A)
64
copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
75
copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A)
86
copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A)
@@ -238,7 +236,7 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
238236
return S
239237
end
240238

241-
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm)
239+
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
242240
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
243241
return first(truncate(svd_trunc!, USVᴴ′, alg.trunc))
244242
end
@@ -270,7 +268,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
270268

271269
return U, S, Vᴴ
272270
end
273-
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
271+
function svd_compact!(A, USVᴴ, alg::DiagonalAlgorithm)
274272
return svd_full!(A, USVᴴ, alg)
275273
end
276274
function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm)

src/implementations/truncation.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ end
2929
# findtruncated
3030
# -------------
3131
# Generic fallback
32-
function findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy)
33-
return findtruncated(values, strategy)
34-
end
32+
findtruncated_svd(values, strategy::TruncationStrategy) = findtruncated(values, strategy)
3533

3634
# specific implementations for finding truncated values
3735
findtruncated(values::AbstractVector, ::NoTruncation) = Colon()

0 commit comments

Comments
 (0)