@@ -11,7 +11,7 @@ using LinearAlgebra
1111
1212MatrixAlgebraKit. iszerotangent (:: AbstractZero ) = true
1313
14- function ChainRulesCore. rrule (:: typeof (copy_input), f, A:: AbstractMatrix )
14+ function ChainRulesCore. rrule (:: typeof (copy_input), f, A)
1515 project = ProjectTo (A)
1616 copy_input_pullback (ΔA) = (NoTangent (), NoTangent (), project (unthunk (ΔA)))
1717 return copy_input (f, A), copy_input_pullback
2020for qr_f in (:qr_compact , :qr_full )
2121 qr_f! = Symbol (qr_f, ' !' )
2222 @eval begin
23- function ChainRulesCore. rrule (:: typeof ($ qr_f!), A:: AbstractMatrix , QR, alg)
23+ function ChainRulesCore. rrule (:: typeof ($ qr_f!), A, QR, alg)
2424 Ac = copy_input ($ qr_f, A)
2525 QR = $ (qr_f!)(Ac, QR, alg)
2626 function qr_pullback (ΔQR)
5858for lq_f in (:lq_compact , :lq_full )
5959 lq_f! = Symbol (lq_f, ' !' )
6060 @eval begin
61- function ChainRulesCore. rrule (:: typeof ($ lq_f!), A:: AbstractMatrix , LQ, alg)
61+ function ChainRulesCore. rrule (:: typeof ($ lq_f!), A, LQ, alg)
6262 Ac = copy_input ($ lq_f, A)
6363 LQ = $ (lq_f!)(Ac, LQ, alg)
6464 function lq_pullback (ΔLQ)
@@ -102,7 +102,7 @@ for eig in (:eig, :eigh)
102102 eig_t_pb = Symbol (eig, " _trunc_pullback" )
103103 _make_eig_t_pb = Symbol (" _make_" , eig_t_pb)
104104 @eval begin
105- function ChainRulesCore. rrule (:: typeof ($ eig_f!), A:: AbstractMatrix , DV, alg)
105+ function ChainRulesCore. rrule (:: typeof ($ eig_f!), A, DV, alg)
106106 Ac = copy_input ($ eig_f, A)
107107 DV = $ (eig_f!)(Ac, DV, alg)
108108 function $eig_pb (ΔDV)
@@ -115,10 +115,7 @@ for eig in (:eig, :eigh)
115115 end
116116 return DV, $ eig_pb
117117 end
118- function ChainRulesCore. rrule (
119- :: typeof ($ eig_t!), A:: AbstractMatrix , DV,
120- alg:: TruncatedAlgorithm
121- )
118+ function ChainRulesCore. rrule (:: typeof ($ eig_t!), A, DV, alg:: TruncatedAlgorithm )
122119 Ac = copy_input ($ eig_f, A)
123120 DV = $ (eig_f!)(Ac, DV, alg. alg)
124121 DV′, ind = MatrixAlgebraKit. truncate ($ eig_t!, DV, alg. trunc)
141138for svd_f in (:svd_compact , :svd_full )
142139 svd_f! = Symbol (svd_f, " !" )
143140 @eval begin
144- function ChainRulesCore. rrule (:: typeof ($ svd_f!), A:: AbstractMatrix , USVᴴ, alg)
141+ function ChainRulesCore. rrule (:: typeof ($ svd_f!), A, USVᴴ, alg)
145142 Ac = copy_input ($ svd_f, A)
146143 USVᴴ = $ (svd_f!)(Ac, USVᴴ, alg)
147144 function svd_pullback (ΔUSVᴴ)
@@ -157,10 +154,7 @@ for svd_f in (:svd_compact, :svd_full)
157154 end
158155end
159156
160- function ChainRulesCore. rrule (
161- :: typeof (svd_trunc!), A:: AbstractMatrix , USVᴴ,
162- alg:: TruncatedAlgorithm
163- )
157+ function ChainRulesCore. rrule (:: typeof (svd_trunc!), A, USVᴴ, alg:: TruncatedAlgorithm )
164158 Ac = copy_input (svd_compact, A)
165159 USVᴴ = svd_compact! (Ac, USVᴴ, alg. alg)
166160 USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
@@ -178,7 +172,7 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind)
178172 return svd_trunc_pullback
179173end
180174
181- function ChainRulesCore. rrule (:: typeof (left_polar!), A:: AbstractMatrix , WP, alg)
175+ function ChainRulesCore. rrule (:: typeof (left_polar!), A, WP, alg)
182176 Ac = copy_input (left_polar, A)
183177 WP = left_polar! (Ac, WP, alg)
184178 function left_polar_pullback (ΔWP)
@@ -192,7 +186,7 @@ function ChainRulesCore.rrule(::typeof(left_polar!), A::AbstractMatrix, WP, alg)
192186 return WP, left_polar_pullback
193187end
194188
195- function ChainRulesCore. rrule (:: typeof (right_polar!), A:: AbstractMatrix , PWᴴ, alg)
189+ function ChainRulesCore. rrule (:: typeof (right_polar!), A, PWᴴ, alg)
196190 Ac = copy_input (left_polar, A)
197191 PWᴴ = right_polar! (Ac, PWᴴ, alg)
198192 function right_polar_pullback (ΔPWᴴ)
0 commit comments