Skip to content

Commit 4ab245c

Browse files
lkdvosJutho
andauthored
Add qr_null_pullback! and lq_null_pullback! (#62)
* add `qr_null_pullback!` and `lq_null_pullback!` functions * mark as public * implement chainrules * include some generic non-differentiable rules * fix type stability * update qr/lq_null pullbacks * remove AbstractArray annotation --------- Co-authored-by: Jutho Haegeman <[email protected]>
1 parent 948fbb4 commit 4ab245c

File tree

5 files changed

+67
-20
lines changed

5 files changed

+67
-20
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ using LinearAlgebra
1111

1212
MatrixAlgebraKit.iszerotangent(::AbstractZero) = true
1313

14+
@non_differentiable MatrixAlgebraKit.select_algorithm(args...)
15+
@non_differentiable MatrixAlgebraKit.initialize_output(args...)
16+
@non_differentiable MatrixAlgebraKit.check_input(args...)
17+
@non_differentiable MatrixAlgebraKit.isisometry(args...)
18+
@non_differentiable MatrixAlgebraKit.isunitary(args...)
19+
1420
function ChainRulesCore.rrule(::typeof(copy_input), f, A)
1521
project = ProjectTo(A)
1622
copy_input_pullback(ΔA) = (NoTangent(), NoTangent(), project(unthunk(ΔA)))
@@ -35,18 +41,12 @@ for qr_f in (:qr_compact, :qr_full)
3541
end
3642
end
3743
end
38-
function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
44+
function ChainRulesCore.rrule(::typeof(qr_null!), A, N, alg)
3945
Ac = copy_input(qr_full, A)
40-
QR = initialize_output(qr_full!, A, alg)
41-
Q, R = qr_full!(Ac, QR, alg)
42-
N = copy!(N, view(Q, 1:size(A, 1), (size(A, 2) + 1):size(A, 1)))
46+
N = qr_null!(Ac, N, alg)
4347
function qr_null_pullback(ΔN)
4448
ΔA = zero(A)
45-
(m, n) = size(A)
46-
minmn = min(m, n)
47-
ΔQ = zero!(similar(A, (m, m)))
48-
view(ΔQ, 1:m, (minmn + 1):m) .= unthunk.(ΔN)
49-
MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ZeroTangent()))
49+
MatrixAlgebraKit.qr_null_pullback!(ΔA, A, N, unthunk(ΔN))
5050
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
5151
end
5252
function qr_null_pullback(::ZeroTangent) # is this extra definition useful?
@@ -73,18 +73,12 @@ for lq_f in (:lq_compact, :lq_full)
7373
end
7474
end
7575
end
76-
function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
76+
function ChainRulesCore.rrule(::typeof(lq_null!), A, Nᴴ, alg)
7777
Ac = copy_input(lq_full, A)
78-
LQ = initialize_output(lq_full!, A, alg)
79-
L, Q = lq_full!(Ac, LQ, alg)
80-
Nᴴ = copy!(Nᴴ, view(Q, (size(A, 1) + 1):size(A, 2), 1:size(A, 2)))
78+
Nᴴ = lq_null!(Ac, Nᴴ, alg)
8179
function lq_null_pullback(ΔNᴴ)
8280
ΔA = zero(A)
83-
(m, n) = size(A)
84-
minmn = min(m, n)
85-
ΔQ = zero!(similar(A, (n, n)))
86-
view(ΔQ, (minmn + 1):n, 1:n) .= unthunk.(ΔNᴴ)
87-
MatrixAlgebraKit.lq_pullback!(ΔA, A, (L, Q), (ZeroTangent(), ΔQ))
81+
MatrixAlgebraKit.lq_null_pullback!(ΔA, A, Nᴴ, unthunk(ΔNᴴ))
8882
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
8983
end
9084
function lq_null_pullback(::ZeroTangent) # is this extra definition useful?

src/MatrixAlgebraKit.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
5555
)
5656
eval(
5757
Expr(
58-
:public, :qr_pullback!, :lq_pullback!, :svd_pullback!, :svd_trunc_pullback!,
58+
:public, :left_polar_pullback!, :right_polar_pullback!,
59+
:qr_pullback!, :qr_null_pullback!, :lq_pullback!, :lq_null_pullback!,
5960
:eig_pullback!, :eig_trunc_pullback!, :eigh_pullback!, :eigh_trunc_pullback!,
60-
:left_polar_pullback!, :right_polar_pullback!
61+
:svd_pullback!, :svd_trunc_pullback!
6162
)
6263
)
6364
end

src/common/pullbacks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ ecosystems
99
function iszerotangent end
1010

1111
iszerotangent(::Any) = false
12+
iszerotangent(::Nothing) = true

src/pullbacks/lq.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,28 @@ function lq_pullback!(
103103
ΔA1 .+= ΔQ̃
104104
return ΔA
105105
end
106+
107+
"""
108+
lq_null_pullback(ΔA, A, Nᴴ, ΔNᴴ)
109+
110+
Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis
111+
`Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`.
112+
113+
See also [`lq_pullback!`](@ref).
114+
"""
115+
function lq_null_pullback!(
116+
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
117+
tol::Real = default_pullback_gaugetol(A),
118+
gauge_atol::Real = tol
119+
)
120+
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
121+
NᴴΔN = Nᴴ * ΔNᴴ'
122+
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
123+
Δgauge < tol ||
124+
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
125+
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
126+
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
127+
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)
128+
end
129+
return ΔA
130+
end

src/pullbacks/qr.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,29 @@ function qr_pullback!(
102102
ΔA1 .+= ΔQ̃
103103
return ΔA
104104
end
105+
106+
"""
107+
qr_null_pullback(ΔA, A, N, ΔN)
108+
109+
Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis
110+
`N` and its cotangent `ΔN` of `qr_null(A)`.
111+
112+
See also [`qr_pullback!`](@ref).
113+
"""
114+
function qr_null_pullback!(
115+
ΔA::AbstractMatrix, A, N, ΔN;
116+
tol::Real = default_pullback_gaugetol(A),
117+
gauge_atol::Real = tol
118+
)
119+
if !iszerotangent(ΔN) && size(N, 2) > 0
120+
NᴴΔN = N' * ΔN
121+
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
122+
Δgauge < tol ||
123+
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
124+
125+
Q, R = qr_compact(A; positive = true)
126+
X = rdiv!(ΔN' * Q, UpperTriangular(R)')
127+
ΔA = mul!(ΔA, N, X, -1, 1)
128+
end
129+
return ΔA
130+
end

0 commit comments

Comments
 (0)