Skip to content

Commit df704df

Browse files
authored
Merge pull request #787 from JuliaAI/dev
For a 0.20.5 release
2 parents 1a89215 + adb341f commit df704df

File tree

4 files changed

+29
-13
lines changed

4 files changed

+29
-13
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.20.4"
4+
version = "0.20.5"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -49,6 +49,7 @@ Tables = "0.2, 1.0"
4949
julia = "1.6"
5050

5151
[extras]
52+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
5253
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
5354
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
5455
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -59,4 +60,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5960
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
6061

6162
[targets]
62-
test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
63+
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]

src/interface/data_utils.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,18 @@ function MMI.selectrows(::FI, ::Val{:table}, X, r)
9797
end
9898

9999
function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer})
100-
cols = Tables.columntable(X) # named tuple of vectors
101-
return cols[c]
100+
cols = Tables.columns(X)
101+
return Tables.getcolumn(cols, c)
102102
end
103103

104-
function MMI.selectcols(::FI, ::Val{:table}, X, c::AbstractArray)
105-
cols = Tables.columntable(X) # named tuple of vectors
106-
newcols = project(cols, c)
107-
return Tables.materializer(X)(newcols)
104+
function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Colon, AbstractArray})
105+
if isdataframe(X)
106+
return X[!, c]
107+
else
108+
cols = Tables.columntable(X) # named tuple of vectors
109+
newcols = project(cols, c)
110+
return Tables.materializer(X)(newcols)
111+
end
108112
end
109113

110114
# -------------------------------
@@ -124,7 +128,7 @@ function project(t::NamedTuple, indices::AbstractArray{<:Integer})
124128
end
125129

126130
# utils for selectrows
127-
typename(X) = split(string(supertype(typeof(X)).name), '.')[end]
131+
typename(X) = split(string(supertype(typeof(X))), '.')[end]
128132
isdataframe(X) = typename(X) == "AbstractDataFrame"
129133

130134
# ----------------------------------------------------------------

src/machines.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
## SCITYPE CHECK LEVEL
22

33
"""
4-
default_scitype_check_level()
4+
default_scitype_check_level()
55
66
Return the current global default value for scientific type checking
77
when constructing machines.
88
9-
default_scitype_check_level(i::Integer)
9+
default_scitype_check_level(i::Integer)
1010
1111
Set the global default value for scientific type checking to `i`.
1212

test/interface/data_utils.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import DataFrames
2+
13
rng = StableRNGs.StableRNG(123)
24

35
@testset "categorical" begin
@@ -23,7 +25,7 @@ end
2325
b = categorical(["a", "b", "c"])
2426
c = categorical(["a", "b", "c"]; ordered=true)
2527
X = (x1=x, x2=z, x3=b, x4=c)
26-
@test MLJModelInterface.scitype(x) == ST.scitype(x)
28+
@test MLJModelInterface.scitype(x) == ST.scitype(x)
2729
@test MLJModelInterface.scitype(y) == ST.scitype(y)
2830
@test MLJModelInterface.scitype(z) == ST.scitype(z)
2931
@test MLJModelInterface.scitype(a) == ST.scitype(a)
@@ -39,7 +41,7 @@ end
3941
b = categorical(["a", "b", "c"])
4042
c = categorical(["a", "b", "c"]; ordered=true)
4143
X = (x1=x, x2=z, x3=b, x4=c)
42-
@test_throws ArgumentError MLJModelInterface.schema(x)
44+
@test_throws ArgumentError MLJModelInterface.schema(x)
4345
@test MLJModelInterface.schema(X) == ST.schema(X)
4446
end
4547

@@ -197,4 +199,13 @@ end
197199
@test selectcols(tt, :w) == v
198200
end
199201

202+
# https://github.com/JuliaAI/MLJBase.jl/issues/784
203+
@testset "typename and dataframes" begin
204+
df = DataFrames.DataFrame(x=[1,2,3], y=[2,3,4], z=[4,5,6])
205+
@test MLJBase.typename(df) == "AbstractDataFrame"
206+
@test MLJBase.isdataframe(df)
207+
@test selectrows(df, 2:3) == df[2:3, :]
208+
@test selectcols(df, [:x, :z]) == df[!, [:x, :z]]
209+
end
210+
200211
true

0 commit comments

Comments
 (0)