diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py
index 77f168d4474b..ff2f0c704ec6 100644
--- a/ibis/expr/tests/test_newrels.py
+++ b/ibis/expr/tests/test_newrels.py
@@ -1861,3 +1861,39 @@ def test_analytic_dereference():
     assert expr.op().predicates == (
         ops.Equals(ops.WindowFunction(ops.RowNumber()), ops.Literal(5, dtype="int8")),
     )
+
+
+def test_drop_null_schema_change():
+    """`drop_null(how="any")` marks the checked columns as non-nullable."""
+    orig_schema = ibis.schema({"a": "int64", "b": "string", "c": "!float64"})
+    t = ibis.table(orig_schema)
+
+    # how="any" (the default): every checked column becomes non-nullable
+    expr = t.drop_null()
+    expected_schema = ibis.schema({"a": "!int64", "b": "!string", "c": "!float64"})
+    assert expr.schema() == expected_schema
+
+    expr = t.drop_null("a")
+    expected_schema = ibis.schema({"a": "!int64", "b": "string", "c": "!float64"})
+    assert expr.schema() == expected_schema
+
+    expr = t.drop_null(["a", "c"])
+    expected_schema = ibis.schema({"a": "!int64", "b": "string", "c": "!float64"})
+    assert expr.schema() == expected_schema
+
+    # how="all": the schema is expected to stay exactly as it was
+    expr = t.drop_null(how="all")
+    expected_schema = orig_schema
+    assert expr.schema() == expected_schema
+    assert expr.op().subset is None
+
+    expr = t.drop_null("a", how="all")
+    expected_schema = orig_schema
+    assert expr.schema() == expected_schema
+    # The requested subset must be forwarded to the underlying DropNull op
+    assert expr.op().subset == (t.a.op(),)
+
+    expr = t.drop_null(["a", "c"], how="all")
+    expected_schema = orig_schema
+    assert expr.schema() == expected_schema
+    assert expr.op().subset == (t.a.op(), t.c.op())
diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py
index ef8f4a4f5908..b5c97c4f79d0 100644
--- a/ibis/expr/types/relations.py
+++ b/ibis/expr/types/relations.py
@@ -3061,9 +3061,18 @@ def drop_null(
         │ 344 │
         └─────┘
         """
-        if subset is not None:
-            subset = tuple(self.bind(subset))
-        return ops.DropNull(self, how, subset).to_expr()
+        # Bind only when a subset was given; `None` is passed through unchanged
+        subset_columns = None if subset is None else tuple(self.bind(subset))
+        result = ops.DropNull(self, how, subset_columns).to_expr()
+        if how == "any":
+            # We now know that all columns in `subset` are non-nullable
+            schema = self.schema()
+            subset_names = (
+                schema.names if subset is None else util.promote_tuple(subset)
+            )
+            new_types = {col: schema[col].copy(nullable=False) for col in subset_names}
+            result = result.cast(new_types)
+        return result
 
     def fill_null(self, replacements: ir.Scalar | Mapping[str, ir.Scalar], /) -> Table:
         """Fill null values in a table expression.