diff --git a/src/adjtrans.jl b/src/adjtrans.jl index f7caa9bc..d81aa3ae 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -319,8 +319,8 @@ const AdjointAbsVec{T} = Adjoint{T,<:AbstractVector} const AdjointAbsMat{T} = Adjoint{T,<:AbstractMatrix} const TransposeAbsVec{T} = Transpose{T,<:AbstractVector} const TransposeAbsMat{T} = Transpose{T,<:AbstractMatrix} -const AdjOrTransAbsVec{T} = AdjOrTrans{T,<:AbstractVector} -const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix} +const AdjOrTransAbsVec{T,V<:AbstractVector} = AdjOrTrans{T,V} +const AdjOrTransAbsMat{T,M<:AbstractMatrix} = AdjOrTrans{T,M} # for internal use below wrapperop(_) = identity diff --git a/src/diagonal.jl b/src/diagonal.jl index 27ff3b1f..5618c313 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -330,6 +330,28 @@ function (*)(D::Diagonal, V::AbstractVector) return D.diag .* V end +function _diag_adj_mul(A::AdjOrTransAbsMat, D::Diagonal) + adj = wrapperop(A) + copy(adj(adj(D) * adj(A))) +end +function _diag_adj_mul(A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}, D::Diagonal{<:Number}) + @invoke *(A::AbstractMatrix, D::AbstractMatrix) +end +function _diag_adj_mul(D::Diagonal, A::AdjOrTransAbsMat) + adj = wrapperop(A) + copy(adj(adj(A) * adj(D))) +end +function _diag_adj_mul(D::Diagonal{<:Number}, A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}) + @invoke *(D::AbstractMatrix, A::AbstractMatrix) +end + +function (*)(A::AdjOrTransAbsMat, D::Diagonal) + _diag_adj_mul(A, D) +end +function (*)(D::Diagonal, A::AdjOrTransAbsMat) + _diag_adj_mul(D, A) +end + function rmul!(A::AbstractMatrix, D::Diagonal) matmul_size_check(size(A), size(D)) for I in CartesianIndices(A)