Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ include("mapreduce.jl")
include("arraymath.jl")
include("linalg.jl")
include("matrix_multiply.jl")
include("matrix_multiply_add.jl")
include("det.jl")
include("inv.jl")
include("solve.jl")
Expand Down
31 changes: 27 additions & 4 deletions src/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ import LinearAlgebra: BlasFloat, matprod, mul!
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::Transpose{<:Any, <:StaticVector}) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::Adjoint{<:Any, <:StaticVector}) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
#@inline *{TA<:LinearAlgebra.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
@inline mul!(dest::StaticMatrix, A::Transpose{<:Any,<:StaticMatrix}, B::StaticMatrix) =
_mul!(Size(dest), dest, Size(A), Size(B), A, B, true, false)
@inline mul!(dest::StaticVector, A::Transpose{<:Any,<:StaticMatrix}, B::StaticVector) =
_tmul!(Size(dest), dest, Size(A.parent), Size(B), A.parent, B, true, false)



Expand Down Expand Up @@ -210,6 +214,7 @@ end
end
end


# TODO aliasing problems if c === b?
@generated function _mul!(::Size{sc}, c::StaticVector, ::Size{sa}, ::Size{sb}, a::StaticMatrix, b::StaticVector) where {sa, sb, sc}
if sb[1] != sa[2] || sc[1] != sa[1]
Expand Down Expand Up @@ -284,8 +289,28 @@ end
end
end


@generated function mul_blas!(::Size{s}, c::StaticMatrix{<:Any, <:Any, T}, ::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, T}, b::StaticMatrix{<:Any, <:Any, T}) where {s,sa,sb, T <: BlasFloat}
@inline mul_blas!(Sc::Size{s}, c::SizedArray{<:Any,T},
Sa::Size{sa}, Sb::Size{sb},
a::SizedArray{<:Any,T}, b::SizedArray{<:Any,T},
α::Real=one(T), β::Real=zero(T)) where {s,sa,sb, T <: BlasFloat} =
BLAS.gemm!('N','N', α, a.data, b.data, β , c.data)

@inline mul_blas!(Sc::Size{s}, c::SizedArray{<:Any,T},
Sa::Size{sa}, Sb::Size{sb},
a::Transpose{<:Any, <:SizedArray}, b::SizedArray{<:Any,T},
α::Real=one(T), β::Real=zero(T)) where {s,sa,sb, T <: BlasFloat} =
BLAS.gemm!('T','N', α, a.parent.data, b.data, β , c.data)

@inline mul_blas!(Sc::Size{s}, c::StaticMatrix{<:Any, <:Any, T},
Sa::Size{sa}, Sb::Size{sb},
a::Transpose{<:Any, <:StaticMatrix}, b::StaticMatrix{<:Any, <:Any, T},
α::Real=one(T), β::Real=zero(T)) where {s,sa,sb, T <: BlasFloat} =
BLAS.gemm!('T','N', α, a.parent, b, β , c)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rather a lot of duplication which will be hard to maintain.

I think all you need is a utility function to "get the data pointer and transposed-ness" out of a, b, c? Then you can collapse all these down again into one function.


@generated function mul_blas!(::Size{s}, c::StaticMatrix{<:Any, <:Any, T},
::Size{sa}, ::Size{sb},
a::StaticMatrix{<:Any, <:Any, T}, b::StaticMatrix{<:Any, <:Any, T},
alpha::Real=one(T), beta::Real=zero(T)) where {s,sa,sb, T <: BlasFloat}
if sb[1] != sa[2] || sa[1] != s[1] || sb[2] != s[2]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $s"))
end
Expand Down Expand Up @@ -316,8 +341,6 @@ end
end

return quote
alpha = one(T)
beta = zero(T)
transA = 'N'
transB = 'N'
m = $(sa[1])
Expand Down
263 changes: 263 additions & 0 deletions src/matrix_multiply_add.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@

# 5-argument matrix multiplication
@inline LinearAlgebra.mul!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticVecOrMat, α::Real, β::Real) =
_mul!(Size(dest), dest, Size(A), Size(B), A, B, α, β)
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::Transpose{<:Any, <:StaticVector}, α::Real, β::Real) =
_mul!(Size(dest), dest, Size(A), Size(B), A, B, α, β)
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::Adjoint{<:Any, <:StaticVector}, α::Real, β::Real) =
_mul!(Size(dest), dest, Size(A), Size(B), A, B, α, β)

@inline mul!(dest::StaticVector, A::Transpose{<:Any, <:StaticMatrix}, B::StaticVector, α::Real, β::Real) =
_tmul!(Size(dest), dest, Size(A.parent), Size(B), A.parent, B, α, β)
@inline mul!(dest::StaticMatrix, A::Transpose{<:Any, <:StaticMatrix}, B::StaticMatrix, α::Real, β::Real) =
_mul!(Size(dest), dest, Size(A), Size(B), A, B, α, β)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to arrange this dispatch a bit more cleanly rather than having quite so much code duplication here?

For example, you can define a type alias similar to StaticMatrixLike which captures both StaticMatrix, Transpose (and possibly Adjoint).

Then a single overload of mul! may be sufficient.


@inline multiplied_dimension(::Type{A}, ::Type{B}) where {
A<:Union{StaticMatrix,Transpose{<:Any,<:StaticMatrix}}, B<:StaticMatrix} =
prod(size(A)) * size(B,2)

@inline multiplied_dimension(::Type{A}, ::Type{B}) where {
A<:Union{StaticMatrix,Transpose{<:Any,<:StaticMatrix}},
B<:Transpose{<:Any, <:StaticMatrix}} =
prod(size(A)) * size(B,2)

# Matrix-matrix multiplication
@generated function _mul!(Sc::Size{sc}, c::AbstractMatrix,
Sa::Size{sa}, Sb::Size{sb},
a::AbstractMatrix, b::AbstractMatrix,
α::Real, β::Real) where {sa, sb, sc}
Ta,Tb,Tc = eltype(a), eltype(b), eltype(c)
can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat

mult_dim = multiplied_dimension(a,b)
if sa[1] * sa[2] * sb[2] < 4*4*4
return quote
@_inline_meta
muladd_unrolled!(Sc, c, Sa, Sb, a, b, α, β)
return c
end
elseif sa[1] * sa[2] * sb[2] < 14*14*14 # Something seems broken for this one with large matrices (becomes allocating)
return quote
@_inline_meta
muladd_unrolled_chunks!(Sc, c, Sa, Sb, a, b, α, β)
return c
end
else
if can_blas
return quote
@_inline_meta
mul_blas!(Sc, c, Sa, Sb, a, b, α, β)
return c
end
else
return quote
@_inline_meta
muladd_unrolled_chunks!(Sc, c, Sa, Sb, a, b, α, β)
return c
end
end
end
end

# Matrix-vector multiplication
@generated function _mul!(::Size{sc}, c::StaticVecOrMat, ::Size{sa}, ::Size{sb},
a::StaticMatrix, b::StaticVector, α::Real, β::Real, ::Val{col}=Val(1)) where {sa, sb, sc,col}
if sb[1] != sa[2] || sc[1] != sa[1]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
end

if sa[2] != 0
exprs = [:(c[$(LinearIndices(sc)[k,col])] = β * c[$(LinearIndices(sc)[k,col])]
+ α * $(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
[:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]))) for k = 1:sa[1]]
else
exprs = [:(c[$k] = zero(eltype(c))) for k = 1:sa[1]]
end

return quote
@_inline_meta
@inbounds $(Expr(:block, exprs...))
return c
end
end

@generated function _tmul!(::Size{sc}, c::StaticVecOrMat, ::Size{sa}, ::Size{sb},
a::StaticMatrix, b::StaticVector, α::Real, β::Real, ::Val{col}=Val(1)) where {sa, sb, sc,col}
if sb[1] != sa[1] || sc[1] != sa[2]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
end

if sa[2] != 0
exprs = [:(c[$(LinearIndices(sc)[k,col])] = β * c[$(LinearIndices(sc)[k,col])]
+ α * $(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
[:(a[$(LinearIndices(sa)[j, k])]*b[$j]) for j = 1:sa[1]]))) for k = 1:sa[2]]
else
exprs = [:(c[$k] = zero(eltype(c))) for k = 1:sa[1]]
end

return quote
@_inline_meta
@inbounds $(Expr(:block, exprs...))
return c
end
end

# Outer product
@generated function _mul!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb}, a::StaticVector,
b::Union{Transpose{<:Any, <:StaticVector}, Adjoint{<:Any, <:StaticVector}},
α::Real, β::Real) where {sa, sb, sc}
if sa[1] != sc[1] || sb[2] != sc[2]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
end

exprs = [:(c[$(LinearIndices(sc)[i, j])] = β * c[$(LinearIndices(sc)[i, j])] + α *
a[$i] * b[$j]) for i = 1:sa[1], j = 1:sb[2]]

return quote
@_inline_meta
@inbounds $(Expr(:block, exprs...))
return c
end
end


@inline muladd_unrolled!(Sc::Size{sc}, c::StaticMatrix, Sa::Size{sa}, Sb::Size{sb},
a::StaticMatrix, b::StaticMatrix, α::Real, β::Real) where {sa, sb, sc} =
_muladd_unrolled!(Sc, c, Sa, Sb, a, b, α, β)

@inline muladd_unrolled!(Sc::Size{sc}, c::StaticMatrix, Sa::Size{sa}, Sb::Size{sb},
a::Transpose{<:Any, <:StaticMatrix}, b::StaticMatrix, α::Real, β::Real) where {sa, sb, sc} =
_tmuladd_unrolled!(Sc, c, Size(a.parent), Sb, a.parent, b, α, β)

@generated function _muladd_unrolled!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb},
a::StaticMatrix, b::StaticMatrix, α::Real, β::Real) where {sa, sb, sc}
if sb[1] != sa[2] || sa[1] != sc[1] || sb[2] != sc[2]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
end

if sa[2] != 0
exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = β*c[$(LinearIndices(sc)[k1, k2])] + α *
$(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
[:(a[$(LinearIndices(sa)[k1, j])]*b[$(LinearIndices(sb)[j, k2])]) for j = 1:sa[2]]
))) for k1 = 1:sa[1], k2 = 1:sb[2]]
else
exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = zero(eltype(c))) for k1 = 1:sa[1], k2 = 1:sb[2]]
end

return quote
@_inline_meta
@inbounds $(Expr(:block, exprs...))
end
end

@generated function _tmuladd_unrolled!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb},
a::StaticMatrix, b::StaticMatrix, α::Real, β::Real) where {sa, sb, sc}
if sb[1] != sa[1] || sa[2] != sc[1] || sb[2] != sc[2]
throw(DimensionMismatch("Tried to multiply arrays of size $(reverse(sa)) and $sb and assign to array of size $sc"))
end

if sa[2] != 0
exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = β*c[$(LinearIndices(sc)[k1, k2])] + α *
$(reduce((ex1,ex2) -> :(+($ex1,$ex2)),
[:(a[$(LinearIndices(sa)[j, k1])]*b[$(LinearIndices(sb)[j, k2])]) for j = 1:sa[1]]
))) for k1 = 1:sa[2], k2 = 1:sb[2]]
else
exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = zero(eltype(c))) for k1 = 1:sa[1], k2 = 1:sb[2]]
end

return quote
@_inline_meta
@inbounds $(Expr(:block, exprs...))
end
end


@inline muladd_unrolled_chunks!(Sc::Size{sc}, c::StaticMatrix, Sa::Size{sa}, Sb::Size{sb},
a::StaticMatrix, b::StaticMatrix, α::Real, β::Real) where {sa, sb, sc} =
_muladd_unrolled_chunks!(Sc, c, Sa, Sb, a, b, α, β)

@inline muladd_unrolled_chunks!(Sc::Size{sc}, c::StaticMatrix, Sa::Size{sa}, Sb::Size{sb},
a::Transpose{<:Any, <:StaticMatrix}, b::StaticMatrix, α::Real, β::Real) where {sa, sb, sc} =
_tmuladd_unrolled_chunks!(Sc, c, Size(a.parent), Sb, a.parent, b, α, β)

@generated function _muladd_unrolled_chunks!(::Size{sc}, c::StaticMatrix,
::Size{sa}, ::Size{sb},
a::Union{StaticMatrix, Transpose{<:Any, <:StaticMatrix}}, b::StaticMatrix,
α::Real, β::Real) where {sa, sb, sc}
if sb[1] != sa[2] || sa[1] != sc[1] || sb[2] != sc[2]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
end

#vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]

# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster
tmp_type = SVector{sb[1], eltype(c)}

col_mult = [:($(Symbol("tmp_$k2")) =
_mul!($(Size(sc)), c, $(Size(sa)), $(Size(sb[1])), a,
$(Expr(:call, tmp_type,
[Expr(:ref, :b, LinearIndices(sb)[i, k2]) for i = 1:sb[1]]...)),α,β,Val($k2))) for k2 = 1:sb[2]]

return quote
@_inline_meta
return $(Expr(:block, col_mult...))
end
end

@generated function _tmuladd_unrolled_chunks!(::Size{sc}, c::StaticMatrix,
::Size{sa}, ::Size{sb},
a::Union{StaticMatrix, Transpose{<:Any, <:StaticMatrix}}, b::StaticMatrix,
α::Real, β::Real) where {sa, sb, sc}
if sb[1] != sa[1] || sa[2] != sc[1] || sb[2] != sc[2]
throw(DimensionMismatch("Tried to multiply arrays of size $(reverse(sa)) and $sb and assign to array of size $sc"))
end

#vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]

# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster
tmp_type = SVector{sb[1], eltype(c)}

col_mult = [:($(Symbol("tmp_$k2")) =
_tmul!($(Size(sc)), c, $(Size(sa)), $(Size(sb[1])), a,
$(Expr(:call, tmp_type,
[Expr(:ref, :b, LinearIndices(sb)[i, k2]) for i = 1:sb[1]]...)),α,β,Val($k2))) for k2 = 1:sb[2]]

return quote
@_inline_meta
return $(Expr(:block, col_mult...))
end
end

# Special-case SizedMatrix
# @inline mul!(dest::SizedMatrix{<:Any, <:Any, Tc}, A::SizedMatrix{<:Any, <:Any, Ta},
# B::SizedMatrix{<:Any, <:Any, Tb}) where {Ta,Tb,Tc} =
# _mul!(Size(dest), dest, Size(A), Size(B), A, B, one(Ta), zero(Ta))
#
# @generated function _mul!(Sc::Size{sc}, c::SizedMatrix{<:Any, <:Any, Tc},
# Sa::Size{sa}, Sb::Size{sb},
# a::SizedMatrix{<:Any, <:Any, Ta}, b::SizedMatrix{<:Any, <:Any, Tb},
# α::Real, β::Real) where {sa, sb, sc, Ta, Tb, Tc}
# can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat
#
# if can_blas
# if sa[1] * sa[2] * sb[2] < 4*4*4
# return quote
# @_inline_meta
# muladd_unrolled!(Sc, c, Sa, Sb, a, b, α, β)
# return c
# end
# elseif sa[1] * sa[2] * sb[2] < 14*14*14 # Something seems broken for this one with large matrices (becomes allocating)
# return quote
# @_inline_meta
# muladd_unrolled_chunks!(Sc, c, Sa, Sb, a, b, α, β)
# return c
# end
# else
# return quote
# # @_inline_meta
# BLAS.gemm!('N','N', α, a.data, b.data, β, c.data)
# return c
# end
# end
# end
# end
Loading