@@ -109,25 +109,32 @@ _onehot_bool_type(::OneHotLike{<:Any, <:Any, var"N+1", <:AbstractGPUArray}) wher
109109_notall_onehot (x:: OneHotArray , xs:: OneHotArray... ) = false
110110_notall_onehot (x:: OneHotLike , xs:: OneHotLike... ) = any (x -> ! _isonehot (x), (x, xs... ))
111111
112- function Base. cat ( x:: OneHotLike{<:Any, <:Any, N} , xs:: OneHotLike... ; dims :: Int ) where N
112+ function Base. _cat (dims :: Int , x:: OneHotLike{<:Any, <:Any, N} , xs:: OneHotLike... ) where N
113113 if isone (dims) || _notall_onehot (x, xs... )
114- return cat (map (x -> convert (_onehot_bool_type (x), x), (x, xs... ))... ; dims = dims)
114+ # return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
115+ return invoke (Base. _cat, Tuple{Int, Vararg{AbstractArray{Bool}}}, dims, x, xs... )
115116 else
116117 L = _nlabels (x, xs... )
117-
118118 return OneHotArray (cat (_indices (x), _indices .(xs)... ; dims = dims - 1 ), L)
119119 end
120120end
121+ function Base. _cat (:: Val{dims} , x:: OneHotLike{<:Any, <:Any, N} , xs:: OneHotLike... ) where {N,dims}
122+ if ! (dims isa Integer) || isone (dims) || _notall_onehot (x, xs... )
123+ # return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
124+ return invoke (Base. _cat, Tuple{Val{dims}, Vararg{AbstractArray{Bool}}}, Val (dims), x, xs... )
125+ else
126+ L = _nlabels (x, xs... )
127+ return OneHotArray (cat (_indices (x), _indices .(xs)... ; dims = Val (dims - 1 )), L)
128+ end
129+ end
121130
122- Base. hcat (x:: OneHotLike , xs:: OneHotLike... ) = cat (x, xs... ; dims = 2 )
123- Base. vcat (x:: OneHotLike , xs:: OneHotLike... ) =
124- vcat (map (x -> convert (_onehot_bool_type (x), x), (x, xs... ))... )
131+ Base. hcat (x:: OneHotLike , xs:: OneHotLike... ) = cat (x, xs... ; dims = Val (2 ))
125132
126133# optimized concatenation for matrices and vectors of same parameters
127134Base. hcat (x:: OneHotMatrix , xs:: OneHotMatrix... ) =
128- OneHotMatrix (reduce ( vcat, _indices .(xs); init = _indices (x) ), _nlabels (x, xs... ))
135+ OneHotMatrix (vcat ( _indices (x) , _indices .(xs)... ), _nlabels (x, xs... ))
129136Base. hcat (x:: OneHotVector , xs:: OneHotVector... ) =
130- OneHotMatrix (reduce (vcat , _indices .(xs); init = _indices (x)) , _nlabels (x, xs... ))
137+ OneHotMatrix (UInt32[ _indices (x) , _indices .(xs)... ] , _nlabels (x, xs... ))
131138
132139if isdefined (Base, :stack )
133140 import Base: _stack
@@ -140,6 +147,13 @@ function _stack(::Colon, xs::AbstractArray{<:OneHotArray})
140147 OneHotArray (Compat. stack (_indices, xs), n)
141148end
142149
150+ Base. reduce (:: typeof (hcat), xs:: AbstractVector{<:OneHotArray{<:Any, 0, 1}} ) = Compat. stack (xs)
151+ function Base. reduce (:: typeof (hcat), xs:: AbstractVector{<:OneHotMatrix} )
152+ n = _nlabels (first (xs))
153+ all (x -> _nlabels (x)== n, xs) || throw (DimensionMismatch (" The number of labels are not the same for all one-hot arrays." ))
154+ OneHotArray (reduce (vcat, _indices .(xs)), n)
155+ end
156+
143157Adapt. adapt_structure (T, x:: OneHotArray ) = OneHotArray (adapt (T, _indices (x)), x. nlabels)
144158
145159function Base. BroadcastStyle (:: Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}} ) where {var"N+1" , T <: AbstractGPUArray }
0 commit comments