-
-
Notifications
You must be signed in to change notification settings - Fork 8
Open
Description
`onecold` doesn't work with Reactant arrays and throws an error instead (Julia version 1.11.4).
In a fresh environment:
julia> using Pkg
julia> Pkg.add(["MLDataDevices", "Reactant", "OneHotArrays"])
julia> using Reactant, MLDataDevices, OneHotArrays
julia> onecold([true false false; false true true])
3-element Vector{Int64}:
1
2
2
julia> const dev = reactant_device()
(::ReactantDevice{Missing, Missing, Missing}) (generic function with 1 method)
julia> onecold([true false false; false true true]|>dev)
ERROR: MethodError: no method matching vec(::Tuple{Int64})
The function `vec` exists, but no method is defined for this combination of argument types.
Closest candidates are:
vec(::StaticArraysCore.SizedArray{S, T, N, M} where {T, N, M}) where S
@ StaticArrays ~/.julia/packages/StaticArrays/LSPcF/src/SizedArray.jl:171
vec(::SparseArrays.AbstractSparseVector)
@ SparseArrays ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:1128
vec(::LinearAlgebra.Adjoint{<:Real, <:AbstractVector})
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:374
...
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Reactant/AebXg/src/utils.jl:0 [inlined]
[2] call_with_reactant(::typeof(vec), ::Tuple{Int64})
@ Reactant ~/.julia/packages/Reactant/AebXg/src/utils.jl:790
[3] getindex
@ ~/.julia/packages/Reactant/AebXg/src/TracedRArray.jl:176 [inlined]
[4] getindex(none::Reactant.TracedRArray{Bool, 2}, none::Tuple{Int64, CartesianIndex{1}})
@ Reactant ./<missing>:0
[5] getindex
@ ~/.julia/packages/Reactant/AebXg/src/TracedRArray.jl:167 [inlined]
[6] call_with_reactant(::typeof(getindex), ::Reactant.TracedRArray{Bool, 2}, ::Int64, ::CartesianIndex{1})
@ Reactant ~/.julia/packages/Reactant/AebXg/src/utils.jl:0
[7] make_mlir_fn(f::typeof(getindex), args::Tuple{ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Int64, CartesianIndex{1}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/AebXg/src/TracedUtils.jl:303
[8] make_mlir_fn
@ ~/.julia/packages/Reactant/AebXg/src/TracedUtils.jl:178 [inlined]
[9] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(getindex), args::Tuple{ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Int64, CartesianIndex{1}}, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::IdDict{Reactant.Sharding.Mesh, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation}}; optimize::Bool, shardy_passes::Symbol, no_nan::Bool, backend::String, fn_kwargs::Tuple{}, raise::Bool, input_shardings::Nothing, output_shardings::Nothing, do_transpose::Bool, runtime::Val{:PJRT})
@ Reactant.Compiler ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:840
[10] compile_mlir!
@ ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:784 [inlined]
[11] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Int64, CartesianIndex{1}}; client::Nothing, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:1992
[12] compile_xla
@ ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:1974 [inlined]
[13] compile(f::Function, args::Tuple{ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Int64, CartesianIndex{1}}; sync::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:2040
[14] compile
@ ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:2039 [inlined]
[15] getindex(::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ::Int64, ::CartesianIndex{1})
@ Reactant ~/.julia/packages/Reactant/AebXg/src/ConcreteRArray.jl:271
[16] findminmax!(f::typeof(identity), op::typeof(isless), Rval::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Rind::Matrix{CartesianIndex{2}}, A::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}})
@ Base ./reducedim.jl:1039
[17] _findmax(f::typeof(identity), A::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, region::Int64)
@ Base ./reducedim.jl:1209
[18] _findmax
@ ./reducedim.jl:1176 [inlined]
[19] findmax
@ ./reducedim.jl:1175 [inlined]
[20] argmax
@ ./reducedim.jl:1274 [inlined]
[21] _fast_argmax
@ ~/.julia/packages/OneHotArrays/rXTnu/src/onehot.jl:167 [inlined]
[22] onecold(y::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, labels::UnitRange{Int64})
@ OneHotArrays ~/.julia/packages/OneHotArrays/rXTnu/src/onehot.jl:161
[23] onecold(y::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}})
@ OneHotArrays ~/.julia/packages/OneHotArrays/rXTnu/src/onehot.jl:158
[24] top-level scope
@ REPL[9]:1
This might be an issue with Reactant rather than OneHotArrays, but I'm new to this part of the ecosystem, so let me know if I should file an issue in the Reactant repository instead.
Metadata
Metadata
Assignees
Labels
No labels