Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ julia = "1"
[extras]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"

[targets]
test = ["InteractiveUtils", "Test"]
test = ["InteractiveUtils", "Test", "BenchmarkTools"]
1 change: 1 addition & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ include("broadcast.jl")
include("mapreduce.jl")
include("arraymath.jl")
include("linalg.jl")
include("matrix_multiply_add.jl")
include("matrix_multiply.jl")
include("det.jl")
include("inv.jl")
Expand Down
213 changes: 6 additions & 207 deletions src/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,6 @@ import LinearAlgebra: BlasFloat, matprod, mul!
# Products whose left operand is an N×1 static matrix and whose right operand is an
# adjoint or transposed static vector: flatten the left operand with `vec` and forward
# to `vec(A) * B`.
@inline function *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N}
    return vec(A) * B
end
@inline function *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N}
    return vec(A) * B
end

# In-place products into `dest`. Every method forwards to the size-specialised `_mul!`
# kernels, passing the static sizes of the destination and both operands.

# matrix * vector
@inline function mul!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticVector)
    return _mul!(Size(dest), dest, Size(A), Size(B), A, B)
end

# matrix * matrix
@inline function mul!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticMatrix)
    return _mul!(Size(dest), dest, Size(A), Size(B), A, B)
end

# vector * matrix: view the vector as an N×1 matrix and reuse the matrix * matrix method
@inline function mul!(dest::StaticVecOrMat, A::StaticVector, B::StaticMatrix)
    return mul!(dest, reshape(A, Size(Size(A)[1], 1)), B)
end

# vector * transposed vector
@inline function mul!(dest::StaticVecOrMat, A::StaticVector, B::Transpose{<:Any, <:StaticVector})
    return _mul!(Size(dest), dest, Size(A), Size(B), A, B)
end

# vector * adjoint vector
@inline function mul!(dest::StaticVecOrMat, A::StaticVector, B::Adjoint{<:Any, <:StaticVector})
    return _mul!(Size(dest), dest, Size(A), Size(B), A, B)
end
#@inline *{TA<:LinearAlgebra.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})



# Implementations

Expand Down Expand Up @@ -97,10 +89,14 @@ end

# Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling)
if sa[1]*sa[2]*sb[2] >= 14*14*14
Sa = TSize{size(S),false}()
Sb = TSize{sa,false}()
Sc = TSize{sb,false}()
_add = MulAddMul(true,false)
return quote
@_inline_meta
C = similar(a, T, $S)
mul_blas!($S, C, Sa, Sb, a, b)
mul_blas!($Sa, C, $Sa, $Sb, a, b, $_add)
return C
end
elseif sa[1]*sa[2]*sb[2] < 8*8*8
Expand Down Expand Up @@ -177,7 +173,7 @@ end
# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than (possibly) a mutable type. Avoids allocation == faster
tmp_type_in = :(SVector{$(sb[1]), T})
tmp_type_out = :(SVector{$(sa[1]), T})
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(Size(a), Size($(sb[1])), a,
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(TSize(a), TSize($(sb[1])), a,
$(Expr(:call, tmp_type_in, [Expr(:ref, :b, LinearIndices(sb)[i, k2]) for i = 1:sb[1]]...)))::$tmp_type_out)
for k2 = 1:sb[2]]

Expand All @@ -193,201 +189,4 @@ end
end
end

# Fully unrolled product of the static matrix `a` (size `sa`) with the static array `b`
# (first dimension `sb[1]`), returned as an `SVector` with `sa[1]` entries.
#
# Entry `k` is the left-to-right sum of `a[k, j] * b[j]` over `j = 1:sa[2]`, with `a`
# addressed by linear index. When the inner dimension is zero every entry is
# `zero(promote_op(matprod, Ta, Tb))`.
#
# Throws `DimensionMismatch` (while the method body is generated) if `sa[2] != sb[1]`.
# The generated body is marked `noinline`.
@generated function partly_unrolled_multiply(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticArray{<:Tuple, Tb}) where {sa, sb, Ta, Tb}
    sa[2] == sb[1] || throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))

    nrows = sa[1]
    ninner = sa[2]

    row_exprs = if ninner == 0
        # Nothing to sum over: each entry is the zero of the promoted product type.
        [:(zero(promote_op(matprod,Ta,Tb))) for row = 1:nrows]
    else
        index_a = LinearIndices(sa)
        add_terms = (lhs, rhs) -> :(+($lhs,$rhs))
        [reduce(add_terms, [:(a[$(index_a[row, col])]*b[$col]) for col = 1:ninner]) for row = 1:nrows]
    end

    return quote
        $(Expr(:meta,:noinline))
        @inbounds return SVector(tuple($(row_exprs...)))
    end
end

# In-place matrix-vector product `c = a * b` for statically sized arrays, fully unrolled.
#
# Arguments
# - `::Size{sc}`, `::Size{sa}`, `::Size{sb}`: static sizes of `c`, `a` and `b`.
# - `c`: destination vector, overwritten and returned.
# - `a`, `b`: the matrix and vector operands.
#
# Throws `DimensionMismatch` (while the method body is generated) unless
# `sb[1] == sa[2]` and `sc[1] == sa[1]`.
#
# Every row result is first computed into a local temporary and only then stored into
# `c`. Storing `c[k]` straight away would corrupt later rows when `c === b`, because
# those rows would read entries of `b` that had already been overwritten. This mirrors
# what `mul_unrolled_chunks!` does for matrix-matrix products.
@generated function _mul!(::Size{sc}, c::StaticVector, ::Size{sa}, ::Size{sb}, a::StaticMatrix, b::StaticVector) where {sa, sb, sc}
    if sb[1] != sa[2] || sc[1] != sa[1]
        throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
    end

    if sa[2] != 0
        index_a = LinearIndices(sa)
        tmp_names = [Symbol("tmp_$k") for k = 1:sa[1]]
        # Phase 1: read-only. Evaluate each row's dot product into its own temporary.
        compute_exprs = [:($(tmp_names[k]) = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(index_a[k, j])]*b[$j]) for j = 1:sa[2]]))) for k = 1:sa[1]]
        # Phase 2: write-only. Store the finished rows into the destination.
        store_exprs = [:(c[$k] = $(tmp_names[k])) for k = 1:sa[1]]
    else
        # Empty inner dimension: nothing is read, so the result is all zeros.
        compute_exprs = Expr[]
        store_exprs = [:(c[$k] = zero(eltype(c))) for k = 1:sa[1]]
    end

    return quote
        @_inline_meta
        @inbounds $(Expr(:block, compute_exprs...))
        @inbounds $(Expr(:block, store_exprs...))
        return c
    end
end

# In-place product of a static vector `a` with a transposed or adjoint static vector
# `b`, written into the matrix `c`: `c[i, j] = a[i] * b[j]`, fully unrolled.
#
# Throws `DimensionMismatch` (while the method body is generated) unless
# `sa[1] == sc[1]` and `sb[2] == sc[2]`. Returns `c`.
@generated function _mul!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb}, a::StaticVector,
        b::Union{Transpose{<:Any, <:StaticVector}, Adjoint{<:Any, <:StaticVector}}) where {sa, sb, sc}
    sizes_agree = sa[1] == sc[1] && sb[2] == sc[2]
    sizes_agree || throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))

    # `c` is addressed by linear index; the comprehension runs with `row` varying fastest.
    index_c = LinearIndices(sc)
    store_exprs = [:(c[$(index_c[row, col])] = a[$row] * b[$col]) for row = 1:sa[1], col = 1:sb[2]]

    return quote
        @_inline_meta
        @inbounds $(Expr(:block, store_exprs...))
        return c
    end
end

# In-place matrix-matrix product `c = a * b` for statically sized matrices.
#
# This method only selects a kernel from the static sizes and element types, then
# forwards all arguments to it and returns `c`:
#   - `mul_unrolled!`        when `sa[1]*sa[2]*sb[2] < 4*4*4`;
#   - `mul_blas!`            when all three element types are the same `BlasFloat` and
#                            `sa[1]*sa[2]*sb[2] >= 14*14*14`;
#   - `mul_unrolled_chunks!` in every other case.
@generated function _mul!(Sc::Size{sc}, c::StaticMatrix{<:Any, <:Any, Tc}, Sa::Size{sa}, Sb::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, sc, Ta, Tb, Tc}
    blas_eligible = Tc == Ta && Tc == Tb && Tc <: BlasFloat
    work = sa[1] * sa[2] * sb[2]

    kernel = if work < 4*4*4
        :mul_unrolled!
    elseif blas_eligible && work >= 14*14*14
        :mul_blas!
    else
        # Original note on the chunked kernel: something seems broken for this one with
        # large matrices (becomes allocating).
        :mul_unrolled_chunks!
    end

    return quote
        @_inline_meta
        $kernel(Sc, c, Sa, Sb, a, b)
        return c
    end
end


# In-place matrix-matrix product `c = a * b` through a direct `ccall` into the BLAS
# `gemm` routine matching the element type `T` (all three matrices share the same
# `T <: BlasFloat`).
#
# Arguments
# - `::Size{s}`, `::Size{sa}`, `::Size{sb}`: static sizes of `c`, `a` and `b`.
# - `c`: destination matrix, passed to BLAS as `Ptr{T}`; returned.
# - `a`, `b`: operands, also passed to BLAS as `Ptr{T}`.
#
# Throws `DimensionMismatch` (while the method body is generated) when the sizes do not
# agree, or when any of `sa[1]`, `sa[2]`, `sb[2]` is zero.
@generated function mul_blas!(::Size{s}, c::StaticMatrix{<:Any, <:Any, T}, ::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, T}, b::StaticMatrix{<:Any, <:Any, T}) where {s,sa,sb, T <: BlasFloat}
if sb[1] != sa[2] || sa[1] != s[1] || sb[2] != s[2]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $s"))
end

if sa[1] > 0 && sa[2] > 0 && sb[2] > 0
# This code adapted from `gemm!()` in base/linalg/blas.jl

# Pick the gemm symbol for the element type; the final `else` covers the one
# remaining case, Complex{Float32}.
if T == Float64
gemm = :dgemm_
elseif T == Float32
gemm = :sgemm_
elseif T == Complex{Float64}
gemm = :zgemm_
else # T == Complex{Float32}
gemm = :cgemm_
end

# The call expression. Its variables (transA, m, alpha, ...) are bound in the quote
# returned below, so the order of the argument-type tuple and the argument list must
# stay in step.
blascall = quote
ccall((LinearAlgebra.BLAS.@blasfunc($gemm), LinearAlgebra.BLAS.libblas), Nothing,
(Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt},
Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$T}, Ptr{$T}, Ref{LinearAlgebra.BLAS.BlasInt},
Ptr{$T}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$T}, Ptr{$T},
Ref{LinearAlgebra.BLAS.BlasInt}),
transA, transB, m, n,
ka, alpha, a, strideA,
b, strideB, beta, c,
strideC)
end

return quote
# alpha = 1 and beta = 0; presumably (BLAS gemm convention, to be confirmed) this
# makes the call overwrite `c` with the plain product.
alpha = one(T)
beta = zero(T)
# 'N' for both operands; presumably "no transpose" in the BLAS convention.
transA = 'N'
transB = 'N'
m = $(sa[1])
ka = $(sa[2])
# NOTE(review): `kb` is assigned but never passed to the ccall above.
kb = $(sb[1])
n = $(sb[2])
# The stride arguments are the first-dimension lengths of `a`, `b` and `c`.
strideA = $(sa[1])
strideB = $(sb[1])
strideC = $(s[1])

$blascall

return c
end
else
throw(DimensionMismatch("Cannot call BLAS gemm with zero-dimension arrays, attempted $sa * $sb -> $s."))
end
end


# In-place matrix-matrix product `c = a * b` for statically sized matrices, with every
# scalar multiply-add unrolled at code-generation time.
#
# Arguments
# - `::Size{sc}`, `::Size{sa}`, `::Size{sb}`: static sizes of `c`, `a` and `b`.
# - `c`: destination matrix, overwritten.
# - `a`, `b`: the operands.
#
# Throws `DimensionMismatch` (while the method body is generated) unless
# `sb[1] == sa[2]`, `sa[1] == sc[1]` and `sb[2] == sc[2]`.
#
# All entries are evaluated into local temporaries before anything is stored into `c`.
# Storing each entry immediately would give wrong results when `c` aliases `a` or `b`,
# since later entries would read operand values that had already been overwritten.
# `mul_unrolled_chunks!` already works this way; this kernel now matches it.
@generated function mul_unrolled!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb}, a::StaticMatrix, b::StaticMatrix) where {sa, sb, sc}
    if sb[1] != sa[2] || sa[1] != sc[1] || sb[2] != sc[2]
        throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
    end

    index_c = LinearIndices(sc)

    if sa[2] != 0
        index_a = LinearIndices(sa)
        index_b = LinearIndices(sb)
        tmp_names = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:sa[1], k2 = 1:sb[2]]
        # Phase 1: read-only. One temporary per output entry.
        compute_exprs = [:($(tmp_names[k1, k2]) = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(index_a[k1, j])]*b[$(index_b[j, k2])]) for j = 1:sa[2]]))) for k1 = 1:sa[1], k2 = 1:sb[2]]
        # Phase 2: write-only. Store the finished entries.
        store_exprs = [:(c[$(index_c[k1, k2])] = $(tmp_names[k1, k2])) for k1 = 1:sa[1], k2 = 1:sb[2]]
    else
        # Empty inner dimension: nothing is read, so the result is all zeros.
        compute_exprs = Expr[]
        store_exprs = [:(c[$(index_c[k1, k2])] = zero(eltype(c))) for k1 = 1:sa[1], k2 = 1:sb[2]]
    end

    return quote
        @_inline_meta
        @inbounds $(Expr(:block, compute_exprs...))
        @inbounds $(Expr(:block, store_exprs...))
    end
end

# In-place matrix-matrix product `c = a * b` for statically sized matrices, computed one
# column of `b` at a time.
#
# For each column `k2` of `b`, the column is rebuilt as an `SVector` from the entries of
# `b` (an isbits value, which avoids allocating a mutable column) and multiplied by `a`
# with `partly_unrolled_multiply`. After all column products exist, their entries are
# copied into `c`.
#
# Throws `DimensionMismatch` (while the method body is generated) unless
# `sb[1] == sa[2]`, `sa[1] == sc[1]` and `sb[2] == sc[2]`.
@generated function mul_unrolled_chunks!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb}, a::StaticMatrix, b::StaticMatrix) where {sa, sb, sc}
    sizes_agree = sb[1] == sa[2] && sa[1] == sc[1] && sb[2] == sc[2]
    sizes_agree || throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))

    column_type = SVector{sb[1], eltype(c)}
    index_b = LinearIndices(sb)
    index_c = LinearIndices(sc)
    column_names = [Symbol("tmp_$col") for col = 1:sb[2]]

    # One `tmp_<col> = partly_unrolled_multiply(...)` statement per column of `b`.
    column_exprs = Expr[]
    for col = 1:sb[2]
        entries = [Expr(:ref, :b, index_b[row, col]) for row = 1:sb[1]]
        column = Expr(:call, column_type, entries...)
        push!(column_exprs, :($(column_names[col]) = partly_unrolled_multiply($(Size(sa)), $(Size(sb[1])), a, $column)))
    end

    # Copy every entry of every column product into its slot in `c`.
    store_exprs = [:(c[$(index_c[row, col])] = $(column_names[col])[$row]) for row = 1:sa[1], col = 1:sb[2]]

    return quote
        @_inline_meta
        @inbounds $(Expr(:block, column_exprs...))
        @inbounds $(Expr(:block, store_exprs...))
    end
end

#function mul_blas(a, b, c, A, B)
#q
#end

# The idea here is to get pointers to stack variables and call BLAS.
# This saves an awful lot of time compared to copying SArrays to Ref{SArray{...}}
# and using BLAS should be fastest for (very) large SArrays

# Here is an LLVM function that gets the pointer to its input, %x
# After this we would make the ccall above.
#
# define i8* @f(i32 %x) #0 {
# %1 = alloca i32, align 4
# store i32 %x, i32* %1, align 4
# ret i32* %1
# }
Loading