Skip to content

Commit ebed6ab

Browse files
authored
Some cat improvements (#52)
* improve cat functions * handle reduce hcat fast cases
1 parent 0b49d1b commit ebed6ab

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

src/array.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,25 +109,32 @@ _onehot_bool_type(::OneHotLike{<:Any, <:Any, var"N+1", <:AbstractGPUArray}) wher
109109
_notall_onehot(x::OneHotArray, xs::OneHotArray...) = false
110110
_notall_onehot(x::OneHotLike, xs::OneHotLike...) = any(x -> !_isonehot(x), (x, xs...))
111111

112-
function Base.cat(x::OneHotLike{<:Any, <:Any, N}, xs::OneHotLike...; dims::Int) where N
112+
function Base._cat(dims::Int, x::OneHotLike{<:Any, <:Any, N}, xs::OneHotLike...) where N
113113
if isone(dims) || _notall_onehot(x, xs...)
114-
return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
114+
# return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
115+
return invoke(Base._cat, Tuple{Int, Vararg{AbstractArray{Bool}}}, dims, x, xs...)
115116
else
116117
L = _nlabels(x, xs...)
117-
118118
return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
119119
end
120120
end
121+
function Base._cat(::Val{dims}, x::OneHotLike{<:Any, <:Any, N}, xs::OneHotLike...) where {N,dims}
122+
if !(dims isa Integer) || isone(dims) || _notall_onehot(x, xs...)
123+
# return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
124+
return invoke(Base._cat, Tuple{Val{dims}, Vararg{AbstractArray{Bool}}}, Val(dims), x, xs...)
125+
else
126+
L = _nlabels(x, xs...)
127+
return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = Val(dims - 1)), L)
128+
end
129+
end
121130

122-
Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
123-
Base.vcat(x::OneHotLike, xs::OneHotLike...) =
124-
vcat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...)
131+
Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = Val(2))
125132

126133
# optimized concatenation for matrices and vectors of same parameters
127134
Base.hcat(x::OneHotMatrix, xs::OneHotMatrix...) =
128-
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...))
135+
OneHotMatrix(vcat(_indices(x), _indices.(xs)...), _nlabels(x, xs...))
129136
Base.hcat(x::OneHotVector, xs::OneHotVector...) =
130-
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...))
137+
OneHotMatrix(UInt32[_indices(x), _indices.(xs)...], _nlabels(x, xs...))
131138

132139
if isdefined(Base, :stack)
133140
import Base: _stack
@@ -140,6 +147,13 @@ function _stack(::Colon, xs::AbstractArray{<:OneHotArray})
140147
OneHotArray(Compat.stack(_indices, xs), n)
141148
end
142149

150+
Base.reduce(::typeof(hcat), xs::AbstractVector{<:OneHotArray{<:Any, 0, 1}}) = Compat.stack(xs)
151+
function Base.reduce(::typeof(hcat), xs::AbstractVector{<:OneHotMatrix})
152+
n = _nlabels(first(xs))
153+
all(x -> _nlabels(x)==n, xs) || throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays."))
154+
OneHotArray(reduce(vcat, _indices.(xs)), n)
155+
end
156+
143157
Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels)
144158

145159
function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}

test/array.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,22 @@ end
5252
@test hcat(ov, ov) isa OneHotMatrix
5353
@test vcat(ov, ov) == vcat(collect(ov), collect(ov))
5454
@test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)
55+
@test cat(ov, ov; dims = 3) isa OneHotArray
56+
@test cat(ov, ov; dims = Val(3)) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)
57+
@test cat(ov, ov; dims = Val(3)) isa OneHotArray
58+
@test cat(ov, ov; dims = (1, 2)) == cat(collect(ov), collect(ov); dims = (1, 2))
5559

5660
# matrix cat
5761
@test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10)
5862
@test hcat(om, om) isa OneHotMatrix
59-
@test vcat(om, om) == vcat(collect(om), collect(om))
63+
@test vcat(om, om) == vcat(collect(om), collect(om)) # not one-hot!
64+
@test cat(om, om; dims = 1) == vcat(collect(om), collect(om))
65+
@test cat(om, om; dims = Val(1)) == vcat(collect(om), collect(om))
6066
@test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)
67+
@test cat(om, om; dims = 3) isa OneHotArray
68+
@test cat(om, om; dims = Val(3)) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)
69+
@test cat(om, om; dims = Val(3)) isa OneHotArray
70+
@test cat(om, om; dims = (1, 2)) == cat(collect(om), collect(om); dims = (1, 2))
6171

6272
# array cat
6373
@test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10)
@@ -71,9 +81,19 @@ end
7181
@test stack([om, om]) isa OneHotArray
7282
@test stack([oa, oa, oa, oa]) isa OneHotArray
7383

84+
# reduce(hcat)
85+
@test reduce(hcat, [ov, ov]) == hcat(ov, ov)
86+
@test reduce(hcat, [ov, ov]) isa OneHotMatrix
87+
@test reduce(hcat, [onehotbatch(1, 1:3), onehotbatch(1, 1:3)]) == [1 1; 0 0; 0 0]
88+
@test reduce(hcat, [onehotbatch(1, 1:3), onehotbatch(1, 1:3)]) isa OneHotMatrix
89+
@test reduce(hcat, [om, om]) == hcat(om, om)
90+
@test reduce(hcat, [om, om]) isa OneHotMatrix
91+
7492
# proper error handling of inconsistent sizes
7593
@test_throws DimensionMismatch hcat(ov, ov2)
7694
@test_throws DimensionMismatch hcat(om, om2)
95+
@test_throws DimensionMismatch stack([om, om2])
96+
@test_throws DimensionMismatch reduce(hcat, [om, om2])
7797
end
7898

7999
@testset "Base.reshape" begin
@@ -88,6 +108,8 @@ end
88108
@testset "w/ cat" begin
89109
r = reshape(oa, 10, :)
90110
@test hcat(r, r) isa OneHotArray
111+
@test cat(r, r; dims = 2) isa OneHotArray
112+
@test cat(r, r; dims = Val(2)) isa OneHotArray
91113
@test vcat(r, r) isa Array{Bool}
92114
end
93115

0 commit comments

Comments
 (0)