Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,26 @@ function rmul!(T::Tridiagonal, D::Diagonal)
end
return T
end
for T in [:UpperTriangular, :UnitUpperTriangular,
:LowerTriangular, :UnitLowerTriangular]
@eval rmul!(A::$T{<:Any, <:StridedMatrix}, D::Diagonal) = _rmul!(A, D)
@eval lmul!(D::Diagonal, A::$T{<:Any, <:StridedMatrix}) = _lmul!(D, A)
end
function _rmul!(A::UpperOrLowerTriangular, D::Diagonal)
P = parent(A)
isunit = A isa UnitUpperOrUnitLowerTriangular
isupper = A isa UpperOrUnitUpperTriangular
for col in axes(A,2)
rowstart = isupper ? firstindex(A,1) : col+isunit
rowstop = isupper ? col-isunit : lastindex(A,1)
for row in rowstart:rowstop
P[row, col] *= D.diag[col]
end
end
isunit && _setdiag!(P, identity, D.diag)
TriWrapper = isupper ? UpperTriangular : LowerTriangular
return TriWrapper(P)
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
matmul_size_check(size(D), size(B))
Expand All @@ -388,6 +408,13 @@ function lmul!(D::Diagonal, B::AbstractVecOrMat)
end
return B
end
# A' = D * A' => A = A * D'
# This uses the fact that D' is a Diagonal
function lmul!(D::Diagonal, A::AdjOrTransAbsMat)
f = wrapperop(A)
rmul!(f(A), f(D))
A
end

# in-place multiplication with a diagonal
# T .= D * T
Expand All @@ -402,12 +429,20 @@ function lmul!(D::Diagonal, T::Tridiagonal)
end
return T
end
# A' = D * A' => A = A * D'
# This uses the fact that D' is a Diagonal
function lmul!(D::Diagonal, A::AdjOrTransAbsMat)
f = wrapperop(A)
rmul!(f(A), f(D))
A
function _lmul!(D::Diagonal, A::UpperOrLowerTriangular)
P = parent(A)
isunit = A isa UnitUpperOrUnitLowerTriangular
isupper = A isa UpperOrUnitUpperTriangular
for col in axes(A,2)
rowstart = isupper ? firstindex(A,1) : col+isunit
rowstop = isupper ? col-isunit : lastindex(A,1)
for row in rowstart:rowstop
P[row, col] = D.diag[row] * P[row, col]
end
end
isunit && _setdiag!(P, identity, D.diag)
TriWrapper = isupper ? UpperTriangular : LowerTriangular
return TriWrapper(P)
end

@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number)
Expand Down
4 changes: 2 additions & 2 deletions test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1197,11 +1197,11 @@ end
outTri = similar(TriA)
out = similar(A)
# 2 args
for fun in (*, rmul!, rdiv!, /)
@testset for fun in (*, rmul!, rdiv!, /)
@test fun(copy(TriA), D)::Tri == fun(Matrix(TriA), D)
@test fun(copy(UTriA), D)::Tri == fun(Matrix(UTriA), D)
end
for fun in (*, lmul!, ldiv!, \)
@testset for fun in (*, lmul!, ldiv!, \)
@test fun(D, copy(TriA))::Tri == fun(D, Matrix(TriA))
@test fun(D, copy(UTriA))::Tri == fun(D, Matrix(UTriA))
end
Expand Down