diff --git a/docs/api/dt/nth.rst b/docs/api/dt/nth.rst new file mode 100644 index 0000000000..51ce105d7c --- /dev/null +++ b/docs/api/dt/nth.rst @@ -0,0 +1,79 @@ + +.. xfunction:: datatable.nth + :src: src/core/expr/fexpr_nth.cc pyfn_nth + :cvar: doc_dt_nth + :tests: tests/dt/test-nth.py + :signature: nth(cols, n=0, skipna=None) + + Return the ``nth`` row for an ``Expr``. + + Parameters + ---------- + cols: FExpr | iterable + Input columns or an iterable. + + n: int + Zero-based index of the row to be returned. Negative values count from the end. + + skipna: None | "any" | "all" + Skip rows with nulls before counting which row is the nth row: ``"any"`` skips a row if any of its values in `cols` is null, ``"all"`` only if all of them are. + Must be ``None`` (do not skip), ``"any"``, or ``"all"``. + + return: Expr | ... + One-row f-expression that has the same names, stypes and + number of columns as `cols`. + + Examples + -------- + .. code-block:: python + + >>> from datatable import dt, f, by + >>> + >>> df = dt.Frame({'A': [1, 1, 2, 1, 2], + ... 'B': [None, 2, 3, 4, 5], + ... 'C': [1, 2, 1, 1, 2]}) + >>> df + | A B C + | int32 int32 int32 + -- + ----- ----- ----- + 0 | 1 NA 1 + 1 | 1 2 2 + 2 | 2 3 1 + 3 | 1 4 1 + 4 | 2 5 2 + [5 rows x 3 columns] + + Get the third row of column A:: + + >>> df[:, dt.nth(f.A, n=2)] + | A + | int32 + -- + ----- + 0 | 2 + [1 row x 1 column] + + Get the third row for multiple columns:: + + >>> df[:, dt.nth(f[:], n=2)] + | A B C + | int32 int32 int32 + -- + ----- ----- ----- + 0 | 2 3 1 + [1 row x 3 columns] + + In the presence of :func:`by()`, it returns the nth row of the specified columns per group:: + + >>> df[:, [dt.nth(f.A, n = 2), dt.nth(f.B, n = -1)], by(f.C)] + | C A B + | int32 int32 int32 + -- + ----- ----- ----- + 0 | 1 1 4 + 1 | 2 NA 5 + [2 rows x 3 columns] + + + + See Also + -------- + - :func:`first()` -- function that returns the first row. + - :func:`last()` -- function that returns the last row. 
diff --git a/docs/api/fexpr.rst b/docs/api/fexpr.rst index da6fe0448a..e1063a90d1 100644 --- a/docs/api/fexpr.rst +++ b/docs/api/fexpr.rst @@ -327,6 +327,7 @@ .median() .min() .nunique() + .nth() .prod() .re_match() .remove() diff --git a/docs/api/fexpr/nth.rst b/docs/api/fexpr/nth.rst new file mode 100644 index 0000000000..6b824f54f3 --- /dev/null +++ b/docs/api/fexpr/nth.rst @@ -0,0 +1,7 @@ + +.. xmethod:: datatable.FExpr.nth + :src: src/core/expr/fexpr.cc PyFExpr::nth + :cvar: doc_FExpr_nth + :signature: nth(n=0, skipna=None) + + Equivalent to :func:`dt.nth(cols, n, skipna)`. diff --git a/docs/api/index-api.rst b/docs/api/index-api.rst index b9cffbfa03..3fab27d32c 100644 --- a/docs/api/index-api.rst +++ b/docs/api/index-api.rst @@ -189,6 +189,8 @@ Functions - Find the smallest element per column * - :func:`ngroup()` - Number each group + * - :func:`nth()` + - Return the nth row * - :func:`nunique()` - Count the number of unique values per column * - :func:`prod()` @@ -273,6 +275,7 @@ Other median()
min()
ngroup()
+ nth()
nunique()
prod()
qcut()
diff --git a/docs/releases/v1.1.0.rst b/docs/releases/v1.1.0.rst index 68c47ca3be..5528904df1 100644 --- a/docs/releases/v1.1.0.rst +++ b/docs/releases/v1.1.0.rst @@ -114,6 +114,8 @@ -[new] Added reducer functions :func:`dt.countna()` and :func:`dt.nunique()`. [#2999] + -[new] Added function :func:`dt.nth()` to retrieve the n-th row. [#3128] + -[new] Class :class:`dt.FExpr` now has method :meth:`.nunique()`, which behaves exactly as the equivalent base level function :func:`dt.nunique()`. diff --git a/src/core/column/nth.h b/src/core/column/nth.h new file mode 100644 index 0000000000..a2c13868ed --- /dev/null +++ b/src/core/column/nth.h @@ -0,0 +1,86 @@ +//------------------------------------------------------------------------------ +// Copyright 2022 H2O.ai +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. 
+//------------------------------------------------------------------------------ +#ifndef dt_NTH_h +#define dt_NTH_h +#include "column/virtual.h" +#include "stype.h" +namespace dt { + +template +class Nth_ColumnImpl : public Virtual_ColumnImpl { + private: + Column col_; + Groupby gby_; + bool is_grouped_; + int32_t n_; + size_t : 32; + + public: + Nth_ColumnImpl(Column&& col, const Groupby& gby, bool is_grouped, int32_t n) + : Virtual_ColumnImpl(gby.size(), col.stype()), + col_(std::move(col)), + gby_(gby), + is_grouped_(is_grouped), + n_(n) + { + xassert(col_.can_be_read_as()); + } + + + ColumnImpl* clone() const override { + return new Nth_ColumnImpl(Column(col_), gby_, is_grouped_, n_); + } + + + size_t n_children() const noexcept override { + return 1; + } + + + const Column& child(size_t i) const override { + xassert(i == 0); (void)i; + return col_; + } + + bool get_element(size_t i, T* out) const override { + xassert(i < gby_.size()); + size_t i0, i1; + gby_.get_group(i, &i0, &i1); + + // Note, when `n_` is negative it is cast to `size_t`, that is + // an unsigned type. Then, when adding `i1`, we rely on `size_t` + // wrap-around. + size_t ni = (n_ >= 0)? 
static_cast(n_) + i0 + : static_cast(n_) + i1; + bool isvalid = false; + if (ni >= i0 && ni < i1){ + ni = is_grouped_?i:ni; + isvalid = col_.get_element(ni, out); + } + return isvalid; + } +}; + +} // namespace dt +#endif + + diff --git a/src/core/documentation.h b/src/core/documentation.h index 2cc4f743ad..3477932556 100644 --- a/src/core/documentation.h +++ b/src/core/documentation.h @@ -52,6 +52,7 @@ extern const char* doc_dt_mean; extern const char* doc_dt_median; extern const char* doc_dt_min; extern const char* doc_dt_ngroup; +extern const char* doc_dt_nth; extern const char* doc_dt_nunique; extern const char* doc_dt_qcut; extern const char* doc_dt_rbind; @@ -302,6 +303,7 @@ extern const char* doc_FExpr_max; extern const char* doc_FExpr_mean; extern const char* doc_FExpr_median; extern const char* doc_FExpr_min; +extern const char* doc_FExpr_nth; extern const char* doc_FExpr_nunique; extern const char* doc_FExpr_prod; extern const char* doc_FExpr_remove; diff --git a/src/core/expr/fexpr.cc b/src/core/expr/fexpr.cc index dfd59391cc..1c23ea139e 100644 --- a/src/core/expr/fexpr.cc +++ b/src/core/expr/fexpr.cc @@ -531,6 +531,20 @@ DECLARE_METHOD(&PyFExpr::min) ->name("min") ->docs(dt::doc_FExpr_min); +oobj PyFExpr::nth(const XArgs& args) { + auto nthFn = oobj::import("datatable", "nth"); + oobj n = args[0].to_oobj() ? 
args[0].to_oobj() + : py::oint(0); + oobj skipna = args[1].to_oobj_or_none(); + return nthFn.call({this, n, skipna}); +} + +DECLARE_METHOD(&PyFExpr::nth) + ->name("nth") + ->arg_names({"n", "skipna"}) + ->n_positional_or_keyword_args(2) + ->docs(dt::doc_FExpr_nth); + oobj PyFExpr::nunique(const XArgs&) { auto nuniqueFn = oobj::import("datatable", "nunique"); diff --git a/src/core/expr/fexpr.h b/src/core/expr/fexpr.h index b292bbb1d8..8a64df836c 100644 --- a/src/core/expr/fexpr.h +++ b/src/core/expr/fexpr.h @@ -197,6 +197,7 @@ class PyFExpr : public py::XObject { py::oobj mean(const py::XArgs&); py::oobj median(const py::XArgs&); py::oobj min(const py::XArgs&); + py::oobj nth(const py::XArgs&); py::oobj nunique(const py::XArgs&); py::oobj prod(const py::XArgs&); py::oobj remove(const py::XArgs&); diff --git a/src/core/expr/fexpr_nth.cc b/src/core/expr/fexpr_nth.cc new file mode 100644 index 0000000000..177c1b744c --- /dev/null +++ b/src/core/expr/fexpr_nth.cc @@ -0,0 +1,302 @@ +//------------------------------------------------------------------------------ +// Copyright 2022 H2O.ai +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. +//------------------------------------------------------------------------------ +#include "column/const.h" +#include "column/func_nary.h" +#include "column/latent.h" +#include "column/isna.h" +#include "column/nth.h" +#include "documentation.h" +#include "expr/fnary/fnary.h" +#include "expr/fexpr_func.h" +#include "expr/eval_context.h" +#include "parallel/api.h" +#include "python/xargs.h" +namespace dt { +namespace expr { + + +template +class FExpr_Nth : public FExpr_Func { + private: + ptrExpr arg_; + int32_t n_; + + public: + FExpr_Nth(ptrExpr&& arg, py::oobj n) + : arg_(std::move(arg)) + {n_ = n.to_int32_strict();} + + std::string repr() const override { + std::string out = "nth"; + out += '('; + out += arg_->repr(); + out += ", n="; + out += std::to_string(n_); + if (SKIPNA == 0) { + out += ", skipna=None"; + } else if (SKIPNA == 1) { + out += ", skipna=any"; + } else if (SKIPNA == 2) { + out += ", skipna=all"; + } + out += ')'; + return out; + } + + static Column make_isna_col(Column&& col) { + switch (col.stype()) { + case SType::VOID: return Const_ColumnImpl::make_bool_column(col.nrows(), true); + case SType::BOOL: + case SType::INT8: return Column(new Isna_ColumnImpl(std::move(col))); + case SType::INT16: return Column(new Isna_ColumnImpl(std::move(col))); + case SType::DATE32: + case SType::INT32: return Column(new Isna_ColumnImpl(std::move(col))); + case SType::TIME64: + case SType::INT64: return Column(new Isna_ColumnImpl(std::move(col))); + case SType::FLOAT32: return Column(new Isna_ColumnImpl(std::move(col))); + case SType::FLOAT64: return Column(new Isna_ColumnImpl(std::move(col))); + case SType::STR32: + case SType::STR64: return Column(new Isna_ColumnImpl(std::move(col))); + 
default: throw RuntimeError(); + } + } + + static bool op_rowany(size_t i, int8_t* out, const colvec& columns) { + for (const auto& col : columns) { + int8_t x; + bool xvalid = col.get_element(i, &x); + if (xvalid && x) { + *out = 1; + return true; + } + } + *out = 0; + return true; + } + + static bool op_rowall(size_t i, int8_t* out, const colvec& columns) { + for (const auto& col : columns) { + int8_t x; + bool xvalid = col.get_element(i, &x); + if (!xvalid || x == 0) { + *out = 0; + return true; + } + } + *out = 1; + return true; + } + + static Column make_boolean_column(colvec&& columns, const size_t nrows, const size_t ncols) { + if (SKIPNA == 1) { + return Column(new FuncNary_ColumnImpl( + std::move(columns), op_rowany, nrows, SType::BOOL)); + } + return Column(new FuncNary_ColumnImpl( + std::move(columns), op_rowall, nrows, SType::BOOL)); + + } + + template + static RowIndex rowindex_nth(Column& col, const Groupby& gby) { + Buffer buf = Buffer::mem(col.nrows() * sizeof(int32_t)); + auto indices = static_cast(buf.xptr()); + Latent_ColumnImpl::vivify(col); + + dt::parallel_for_dynamic( + gby.size(), + [&](size_t gi) { + size_t i1, i2; + gby.get_group(gi, &i1, &i2); + size_t n = POSITIVE? 
i1: i2 - 1; + int8_t value; + bool is_valid; + + if (POSITIVE) { + for (size_t i = i1; i < i2; ++i) { + is_valid = col.get_element(i, &value); + if (value==0 && is_valid) { + indices[n] = static_cast(i); + n += 1; + } + } + for (size_t j = n; j < i2; ++j){ + indices[j] = RowIndex::NA; + } + } else { + for (size_t i = i2; i-- > i1;) { + is_valid = col.get_element(i, &value); + if (value==0 && is_valid) { + indices[n] = static_cast(i); + n -= 1; + } + } + for (size_t j = n+1; j-- > i1;){ + indices[j] = RowIndex::NA; + } + } + } + ); + + return RowIndex(std::move(buf), RowIndex::ARR32|RowIndex::SORTED); + } + + Workframe evaluate_n(EvalContext &ctx) const override { + Workframe wf = arg_->evaluate_n(ctx); + Workframe outputs(ctx); + Groupby gby = ctx.get_groupby(); + + // Check if the input frame is grouped as `GtoONE` + bool is_wf_grouped = (wf.get_grouping_mode() == Grouping::GtoONE); + + if (is_wf_grouped) { + // Check if the input frame columns are grouped + bool are_cols_grouped = ctx.has_group_column( + wf.get_frame_id(0), + wf.get_column_id(0) + ); + + if (!are_cols_grouped) { + // When the input frame is `GtoONE`, but columns are not grouped, + // it means we are dealing with the output of another reducer. + // In such a case we create a new groupby, that has one element + // per a group. This may not be optimal performance-wise, + // but chained reducers is a very rare scenario. 
+ xassert(gby.size() == wf.nrows()); + gby = Groupby::nrows_groups(gby.size()); + } + } + + if (wf.nrows() == 0) { + for (size_t i = 0; i < wf.ncols(); ++i) { + Column coli = Column::new_na_column(1, wf.retrieve_column(i).stype()); + outputs.add_column(std::move(coli), wf.retrieve_name(i), Grouping::GtoONE); + } + return outputs; + } + + RowIndex ri; + if (SKIPNA > 0) { + Workframe wf_skipna = arg_->evaluate_n(ctx); + colvec columns; + size_t ncols = wf_skipna.ncols(); + size_t nrows = wf_skipna.nrows(); + columns.reserve(ncols); + for (size_t i = 0; i < ncols; ++i) { + Column coli = make_isna_col(wf_skipna.retrieve_column(i)); + columns.push_back(std::move(coli)); + } + Column bool_column = make_boolean_column(std::move(columns), nrows, ncols); + ri = n_ < 0 ? rowindex_nth(bool_column, gby) + : rowindex_nth(bool_column, gby); + } + for (size_t i = 0; i < wf.ncols(); ++i) { + bool is_grouped = ctx.has_group_column( + wf.get_frame_id(i), + wf.get_column_id(i) + ); + Column coli = wf.retrieve_column(i); + if (SKIPNA > 0) coli.apply_rowindex(ri); + coli = evaluate1(std::move(coli), gby, is_grouped, n_); + outputs.add_column(std::move(coli), wf.retrieve_name(i), Grouping::GtoONE); + } + return outputs; + } + + + Column evaluate1(Column&& col, const Groupby& gby, bool is_grouped, const int32_t n) const { + SType stype = col.stype(); + switch (stype) { + case SType::VOID: return Column(new ConstNa_ColumnImpl(gby.size())); + case SType::BOOL: + case SType::INT8: return make(std::move(col), gby, is_grouped, n); + case SType::INT16: return make(std::move(col), gby, is_grouped, n); + case SType::DATE32: + case SType::INT32: return make(std::move(col), gby, is_grouped, n); + case SType::TIME64: + case SType::INT64: return make(std::move(col), gby, is_grouped, n); + case SType::FLOAT32: return make(std::move(col), gby, is_grouped, n); + case SType::FLOAT64: return make(std::move(col), gby, is_grouped, n); + case SType::STR32: return make(std::move(col), gby, is_grouped, n); + 
case SType::STR64: return make(std::move(col), gby, is_grouped,n); + default: + throw TypeError() + << "Invalid column of type `" << stype << "` in " << repr(); + } + } + + + template + Column make(Column&& col, const Groupby& gby, bool is_grouped, int32_t n) const { + return Column(new Nth_ColumnImpl(std::move(col), gby, is_grouped, n)); + } + +}; + + +static py::oobj pyfn_nth(const py::XArgs& args) { + auto arg = args[0].to_oobj(); + auto n = args[1].to(py::oint(0)); + auto skipna = args[2].to_oobj_or_none(); + if (!skipna.is_none()) { + if (!skipna.is_string()) { + throw TypeError() << "The argument for the `skipna` parameter " + <<"in function datatable.nth() should either be None, " + <<"or a string, instead got "<(as_fexpr(arg), n)); + } + if (skip_na == "all") { + return PyFExpr::make(new FExpr_Nth<2>(as_fexpr(arg), n)); + } + + } + return PyFExpr::make(new FExpr_Nth<0>(as_fexpr(arg), n)); +} + + +DECLARE_PYFN(&pyfn_nth) + ->name("nth") + ->docs(doc_dt_nth) + ->arg_names({"cols", "n", "skipna"}) + ->n_positional_args(1) + ->n_positional_or_keyword_args(2) + ->n_required_args(1); + + +}} // dt::expr diff --git a/src/datatable/__init__.py b/src/datatable/__init__.py index 8010f4674a..987c99f119 100644 --- a/src/datatable/__init__.py +++ b/src/datatable/__init__.py @@ -46,6 +46,7 @@ mean, Namespace, ngroup, + nth, prod, qcut, rbind, @@ -131,6 +132,7 @@ "mean", "median", "ngroup", + "nth", "obj64", "options", "prod", diff --git a/tests/dt/test-nth.py b/tests/dt/test-nth.py new file mode 100644 index 0000000000..ddb5124ab0 --- /dev/null +++ b/tests/dt/test-nth.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------- +# Copyright 2022 H2O.ai +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation 
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ------------------------------------------------------------------------------- +import pytest +import re +from datatable import dt, f, nth, FExpr, by +from tests import assert_equals + +# ------------------------------------------------------------------------------- +# Errors +# ------------------------------------------------------------------------------- + + +def test_nth_parameter_not_int(): + msg = ( + "The argument for the nth parameter in function datatable.nth() " + "should be an integer, instead got " + ) + DT = dt.Frame([1, 2, None, 4, 5]) + with pytest.raises(TypeError, match=re.escape(msg)): + DT[:, nth(f[0], "1")] + + +def test_nth_no_argument(): + msg = ( + r"Function datatable.nth\(\) requires at least 1 positional " + "argument, but none were given" + ) + with pytest.raises(TypeError, match=msg): + nth() + + +# ------------------------------------------------------------------------------- +# Normal +# ------------------------------------------------------------------------------- + + +def test_nth_str(): + assert str(nth(f.A, n=1)) == "FExpr<" + nth.__name__ + "(f.A, n=1, skipna=None)>" + assert ( + str(nth(f.A, n=1, 
skipna="all") + 1) + == "FExpr<" + nth.__name__ + "(f.A, n=1, skipna=all) + 1>" + ) + assert ( + str(nth(f.A + f.B, n=1)) + == "FExpr<" + nth.__name__ + "(f.A + f.B, n=1, skipna=None)>" + ) + assert ( + str(nth(f.B, 1, "any")) == "FExpr<" + nth.__name__ + "(f.B, n=1, skipna=any)>" + ) + assert str(nth(f[:2], 1)) == "FExpr<" + nth.__name__ + "(f[:2], n=1, skipna=None)>" + + +def test_nth_empty_frame(): + DT = dt.Frame() + expr_nth = nth(DT, 1) + assert isinstance(expr_nth, FExpr) + assert_equals(DT[:, nth(f[:], 1)], DT) + + +def test_nth_empty_frame_skipna(): + DT = dt.Frame() + expr_nth = nth(DT, 1) + assert isinstance(expr_nth, FExpr) + assert_equals(DT[:, nth(f[:], 1)], DT) + + +def test_nth_void(): + DT = dt.Frame([None, None, None]) + DT_nth = DT[:, nth(f[:], 0)] + assert_equals(DT_nth, DT[0, :]) + + +def test_nth_void_skipna(): + DT = dt.Frame([None, None, None]) + DT_nth = DT[:, nth(f[:], 0, None)] + assert_equals(DT_nth, DT[0, :]) + + +def test_nth_trivial(): + DT = dt.Frame([0] / dt.int64) + nth_fexpr = nth(f[:], n=-1) + DT_nth = DT[:, nth_fexpr] + assert isinstance(nth_fexpr, FExpr) + assert_equals(DT, DT_nth) + + +def test_nth_trivial_skipna(): + DT = dt.Frame([0] / dt.int64) + nth_fexpr = nth(f[:], n=-1, skipna=None) + DT_nth = DT[:, nth_fexpr] + assert isinstance(nth_fexpr, FExpr) + assert_equals(DT, DT_nth) + + +def test_nth_bool(): + DT = dt.Frame([None, False, None, True, False, True]) + DT_nth = DT[:, [nth(f[:], n=1), nth(f[:], n=-1), nth(f[:], n=24)]] + DT_ref = dt.Frame([[False], [True], [None] / dt.bool8]) + assert_equals(DT_nth, DT_ref) + + +def test_nth_bool_skipna(): + DT = dt.Frame([None, False, None, True, False, True]) + DT_nth = DT[ + :, + [ + nth(f[:], n=0, skipna="all"), + nth(f[:], n=-1, skipna="any"), + nth(f[:], n=2, skipna="any"), + ], + ] + DT_ref = dt.Frame([[False], [True], [False]]) + + assert_equals(DT_nth, DT_ref) + + +def test_nth_small(): + DT = dt.Frame([None, 3, None, 4]) + DT_nth = DT[:, [nth(f[:], n=1), nth(f[:], n=-5)]] 
+ DT_ref = dt.Frame([[3] / dt.int32, [None] / dt.int32]) + assert_equals(DT_nth, DT_ref) + + +def test_nth_string(): + DT = dt.Frame(["d", "a", "z", "b"]) + DT_nth = DT[:, [nth(f[:], 0), nth(f[:], n=-1)]] + DT_ref = dt.Frame([["d"], ["b"]]) + assert_equals(DT_nth, DT_ref) + + +def test_nth_grouped(): + DT = dt.Frame( + [ + [15, None, 136, 93, 743, None, None, 91], + ["a", "a", "a", "b", "b", "c", "c", "c"], + ] + ) + DT_nth = DT[:, [nth(f[:], n=0), nth(f[:], n=2)], by(f[-1])] + DT_ref = dt.Frame( + { + "C1": [ + "a", + "b", + "c", + ], + "C0": [15, 93, None], + "C2": [136, None, 91], + } + ) + assert_equals(DT_nth, DT_ref) + + +def test_positive_nth_grouped_skipna(): + DT = dt.Frame( + [ + [15, None, 136, 93, 743, None, None, 91], + ["a", "a", "a", "b", "b", "c", "c", "c"], + ] + ) + DT_nth = DT[ + :, [nth(f[:], n=0, skipna="all"), nth(f[:], n=1, skipna="any")], by(f[-1]) + ] + DT_ref = dt.Frame( + { + "C1": [ + "a", + "b", + "c", + ], + "C0": [15, 93, 91], + "C2": [136, 743, None], + } + ) + assert_equals(DT_nth, DT_ref) + + +def test_negative_nth_grouped_skipna(): + DT = dt.Frame( + [ + [15, None, 136, 93, 743, None, None, 91], + ["a", "a", "a", "b", "b", "c", "c", "c"], + ] + ) + DT_nth = DT[ + :, [nth(f[:], n=-1, skipna="all"), nth(f[:], n=-2, skipna="any")], by(f[-1]) + ] + DT_ref = dt.Frame( + { + "C1": [ + "a", + "b", + "c", + ], + "C0": [136, 743, 91], + "C2": [15, 93, None], + } + ) + assert_equals(DT_nth, DT_ref) + + +def test_nth_grouped_column(): + DT = dt.Frame([0, 1, 0]) + DT_nth = DT[:, dt.nth(f.C0, 0), by(f.C0)] + DT_ref = dt.Frame({"C0": [0, 1], "C1": [0, 1]}) + assert_equals(DT_nth, DT_ref) + + +def test_nth_multiple_columns_skipna_any(): + DT = dt.Frame( + { + "building": ["a", "a", "b", "b", "a", "a", "b", "b"], + "var1": [1.5, None, 2.1, 2.2, 1.2, 1.3, 2.4, None], + "var2": [100, 110, 105, None, 102, None, 103, 107], + "var3": [10, 11, None, None, None, None, None, None], + "var4": [1, 2, 3, 4, 5, 6, 7, 8], + } + ) + DT_nth = DT[:, dt.nth(f[:], 
skipna="any"), by(f.building)] + DT_ref = dt.Frame( + { + "building": ["a", "b"], + "var1": [1.5, None]/dt.float64, + "var2": [100.0, None]/dt.int32, + "var3": [10.0, None]/dt.int32, + "var4": [1.0, None]/dt.int32 + } + ) + assert_equals(DT_nth, DT_ref) + + +def test_nth_multiple_columns_skipna_all(): + DT = dt.Frame( + { + "building": ["a", "a", "b", "b", "a", "a", "b", "b"], + "var1": [1.5, None, 2.1, 2.2, 1.2, 1.3, 2.4, None], + "var2": [100, 110, 105, None, 102, None, 103, 107], + "var3": [10, 11, None, None, None, None, None, None], + "var4": [1, 2, 3, 4, 5, 6, 7, 8], + } + ) + DT_nth = DT[:, dt.nth(f[:], skipna="all"), by(f.building)] + DT_ref = dt.Frame( + { + "building": ["a", "b"], + "var1": [1.5, 2.1], + "var2": [100, 105]/dt.int32, + "var3": [10.0, None]/dt.int32, + "var4": [1, 3], + } + ) + assert_equals(DT_nth, DT_ref) + +def test_nth_zero_rows(): + DT = dt.Frame() + assert_equals(DT[:, nth(f[:])], DT) \ No newline at end of file diff --git a/tests/test-f.py b/tests/test-f.py index 492e108fac..5a70ef0540 100644 --- a/tests/test-f.py +++ b/tests/test-f.py @@ -499,3 +499,13 @@ def test_codes(): type = dt.Type.cat8(dt.Type.str32)) assert_equals(DT[:, f.A.codes()], DT[:, dt.codes(f.A)]) + +def test_nth(): + assert str(dt.nth(f.A, n=0)) == str(f.A.nth(n=0, skipna=None)) + assert str(dt.nth(f.A, n=1, skipna="any")) == str(f.A.nth(n=1, skipna="any")) + assert str(dt.nth(f[:], -1, skipna=None)) == str(f[:].nth(-1, None)) + DT = dt.Frame(A = [9, 8, 2, 3, None, None, 3, 0, 5, 5, 8, None, 1]) + assert_equals(DT[:, f.A.nth(n=1, skipna=None)], DT[:, dt.nth(f.A, 1, None)]) + assert_equals(DT[:, f.A.nth(n=0, skipna="any")], DT[:, dt.nth(f.A, 0, "any")]) + assert_equals(DT[:, f.A.nth(n=0, skipna="all")], DT[:, dt.nth(f.A, 0, skipna="all")]) +