@@ -3,15 +3,24 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike)
33 size (A, 2 ) == size (B, 1 ) || throw (DimensionMismatch (" Matrix column must correspond with OneHot size: $(size (A, 2 )) != $(size (B, 1 )) " ))
44 return A[:, onecold (B)]
55end
6-
6+
77function Base.:(* )(A:: AbstractMatrix , B:: OneHotLike{<:Any, 1} )
88 _isonehot (B) || return invoke (* , Tuple{AbstractMatrix, AbstractMatrix}, A, B)
99 size (A, 2 ) == size (B, 1 ) || throw (DimensionMismatch (" Matrix column must correspond with OneHot size: $(size (A, 2 )) != $(size (B, 1 )) " ))
1010 return NNlib. gather (A, _indices (B))
1111end
1212
13+ function Base.:(* )(A:: OneHotMatrix , B:: AbstractMatrix{<:Number} )
14+ size (A, 2 ) == size (B, 1 ) || throw (DimensionMismatch (" A has dimensions $(size (A)) but B has dimensions $(size (B)) " )) # probably caught later anyway
15+ T = promote_type (eltype (B), Int)
16+ Y = similar (B, T, size (A, 1 ), size (B, 2 ))
17+ LinearAlgebra. mul! (transpose (Y), transpose (B), transpose (A)) # uses matrix * wrapper(OneHotMatrix) method below
18+ return Y
19+ end
20+
1321for wrapper in [:Adjoint , :Transpose ]
1422 @eval begin
23+ # Adjoint * OneHotVector cases
1524 function Base.:* (A:: $wrapper{<:Any, <:AbstractMatrix{T}} , b:: OneHotVector ) where T
1625 size (A, 2 ) == length (b) ||
1726 throw (DimensionMismatch (" Matrix column must correspond with OneHot size: $(size (A, 2 )) != $(length (b)) " ))
@@ -26,17 +35,25 @@ for wrapper in [:Adjoint, :Transpose]
2635 return A[onecold (b)]
2736 end
2837
29- # note that the fill! is the same thing done by NNlib.scatter so it is not more expensive
38+ # Matrix * Adjoint(OneHotVector)
3039 function LinearAlgebra. mul! (Y:: AbstractMatrix , A:: AbstractMatrix , B:: $wrapper{Bool,<:OneHotMatrix} )
3140 if size (A,2 ) ≠ size (B,1 )
3241 throw (DimensionMismatch (" Matrix column must correspond with the OneHot Size $(size (A,2 )) ≠ $(size (B,1 )) " ))
3342 end
3443 if ! (size (Y,1 ) == size (A,1 ) && size (Y,2 ) == size (B,2 ))
3544 throw (DimensionMismatch (" Invalid output matrix size for multiplication of matrix sizes $(size (A)) and $(size (B)) " ))
3645 end
46+ # note that the fill! is the same thing done by NNlib.scatter so it is not more expensive
3747 fill! (Y, zero (eltype (Y)))
3848 return NNlib. scatter! (+ , Y, A, _indices (parent (B)))
3949 end
50+
51+ # Adjoint(OneHotVector) * Matrix
52+ function LinearAlgebra. mul! (Y:: AbstractVecOrMat , A:: $wrapper{Bool,<:OneHotMatrix} , B:: AbstractMatrix )
53+ op = LinearAlgebra. wrapperop (A)
54+ LinearAlgebra. mul! (op (Y), op (B), A. parent) # uses OneHotMatrix method below
55+ return Y
56+ end
4057 end
4158end
4259
@@ -56,3 +73,7 @@ function LinearAlgebra.mul!(Y::AbstractVecOrMat, A::AbstractMatrix, B::OneHotLik
5673 end
5774end
5875
76+ function LinearAlgebra. mul! (Y:: AbstractVecOrMat{<:Real} , A:: OneHotMatrix , B:: AbstractVecOrMat )
77+ LinearAlgebra. mul! (transpose (Y), transpose (B), transpose (A)) # uses matrix * wrapper(OneHotMatrix) method above
78+ return Y
79+ end
0 commit comments