Skip to content

Commit 168dce1

Browse files
committed
cases OneHotMatrix * AbstractMatrix
1 parent f563f48 commit 168dce1

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

src/linalg.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,24 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike)
33
size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))"))
44
return A[:, onecold(B)]
55
end
6-
6+
77
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1})
88
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
99
size(A, 2) == size(B, 1) || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(size(B, 1))"))
1010
return NNlib.gather(A, _indices(B))
1111
end
1212

13+
function Base.:(*)(A::OneHotMatrix, B::AbstractMatrix{<:Number})
14+
size(A, 2) == size(B, 1) || throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))")) # probably caught later anyway
15+
T = promote_type(eltype(B), Int)
16+
Y = similar(B, T, size(A, 1), size(B, 2))
17+
LinearAlgebra.mul!(transpose(Y), transpose(B), transpose(A)) # uses matrix * wrapper(OneHotMatrix) method below
18+
return Y
19+
end
20+
1321
for wrapper in [:Adjoint, :Transpose]
1422
@eval begin
23+
# Adjoint * OneHotVector cases
1524
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T
1625
size(A, 2) == length(b) ||
1726
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $(length(b))"))
@@ -26,17 +35,25 @@ for wrapper in [:Adjoint, :Transpose]
2635
return A[onecold(b)]
2736
end
2837

29-
# note that the fill! is the same thing done by NNlib.scatter so it is not more expensive
38+
# Matrix * Adjoint(OneHotVector)
3039
function LinearAlgebra.mul!(Y::AbstractMatrix, A::AbstractMatrix, B::$wrapper{Bool,<:OneHotMatrix})
3140
if size(A,2) size(B,1)
3241
throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2))$(size(B,1))"))
3342
end
3443
if !(size(Y,1) == size(A,1) && size(Y,2) == size(B,2))
3544
throw(DimensionMismatch("Invalid output matrix size for multiplication of matrix sizes $(size(A)) and $(size(B))"))
3645
end
46+
# note that the fill! is the same thing done by NNlib.scatter so it is not more expensive
3747
fill!(Y, zero(eltype(Y)))
3848
return NNlib.scatter!(+, Y, A, _indices(parent(B)))
3949
end
50+
51+
# Adjoint(OneHotVector) * Matrix
52+
function LinearAlgebra.mul!(Y::AbstractVecOrMat, A::$wrapper{Bool,<:OneHotMatrix}, B::AbstractMatrix)
53+
op = LinearAlgebra.wrapperop(A)
54+
LinearAlgebra.mul!(op(Y), op(B), A.parent) # uses OneHotMatrix method below
55+
return Y
56+
end
4057
end
4158
end
4259

@@ -56,3 +73,7 @@ function LinearAlgebra.mul!(Y::AbstractVecOrMat, A::AbstractMatrix, B::OneHotLik
5673
end
5774
end
5875

76+
function LinearAlgebra.mul!(Y::AbstractVecOrMat{<:Real}, A::OneHotMatrix, B::AbstractVecOrMat)
77+
LinearAlgebra.mul!(transpose(Y), transpose(B), transpose(A)) # uses matrix * wrapper(OneHotMatrix) method above
78+
return Y
79+
end

0 commit comments

Comments
 (0)