Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/polars-python/src/c_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ pub fn _polars_runtime_64(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
.unwrap();
m.add_wrapped(wrap_pyfunction!(testing::assert_dataframe_equal_py))
.unwrap();
m.add_wrapped(wrap_pyfunction!(testing::assert_schema_equal_py))
.unwrap();

// Exceptions - Errors
m.add("PolarsError", py.get_type::<exceptions::PolarsError>())
Expand Down
27 changes: 25 additions & 2 deletions crates/polars-python/src/testing/frame.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use polars_testing::asserts::{DataFrameEqualOptions, assert_dataframe_equal};
use std::sync::Arc;

use polars_core::schema::SchemaRef;
use polars_testing::asserts::{DataFrameEqualOptions, assert_dataframe_equal, assert_schema_equal};
use pyo3::prelude::*;

use crate::PyDataFrame;
use crate::error::PyPolarsErr;
use crate::{PyDataFrame, PySchema};

#[pyfunction]
#[pyo3(signature = (left, right, *, check_row_order, check_column_order, check_dtypes, check_exact, rel_tol, abs_tol, categorical_as_str))]
Expand Down Expand Up @@ -32,3 +35,23 @@ pub fn assert_dataframe_equal_py(

assert_dataframe_equal(left_df, right_df, options).map_err(|e| PyPolarsErr::from(e).into())
}

#[pyfunction]
#[pyo3(signature = (left_schema, right_schema, check_dtypes, check_column_order))]
pub fn assert_schema_equal_py(
left_schema: PySchema,
right_schema: PySchema,
check_dtypes: bool,
check_column_order: bool,
) -> PyResult<()> {
let left_schema_ref: SchemaRef = Arc::new(left_schema.0);
let right_schema_ref: SchemaRef = Arc::new(right_schema.0);

assert_schema_equal(
&left_schema_ref,
&right_schema_ref,
check_dtypes,
check_column_order,
)
.map_err(|e| PyPolarsErr::from(e).into())
}
3 changes: 2 additions & 1 deletion crates/polars-testing/src/asserts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ pub mod series;
mod utils;

pub use utils::{
DataFrameEqualOptions, SeriesEqualOptions, assert_dataframe_equal, assert_series_equal,
DataFrameEqualOptions, SeriesEqualOptions, assert_dataframe_equal, assert_schema_equal,
assert_series_equal,
};
37 changes: 20 additions & 17 deletions crates/polars-testing/src/asserts/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ops::Not;

use polars_core::datatypes::unpack_dtypes;
use polars_core::prelude::*;
use polars_core::schema::SchemaRef;
use polars_ops::series::is_close;

/// Configuration options for comparing Series equality.
Expand Down Expand Up @@ -695,17 +696,14 @@ impl DataFrameEqualOptions {
/// - When `check_column_order` is false, compares data type sets for equality
/// - When `check_column_order` is true, performs more precise type checking
///
fn assert_dataframe_schema_equal(
left: &DataFrame,
right: &DataFrame,
pub fn assert_schema_equal(
left_schema: &SchemaRef,
right_schema: &SchemaRef,
check_dtypes: bool,
check_column_order: bool,
) -> PolarsResult<()> {
let left_schema = left.schema();
let right_schema = right.schema();

let ordered_left_cols = left.get_column_names();
let ordered_right_cols = right.get_column_names();
let ordered_left_cols: Vec<&PlSmallStr> = left_schema.iter_names().collect();
let ordered_right_cols: Vec<&PlSmallStr> = right_schema.iter_names().collect();

let left_set: PlHashSet<&PlSmallStr> = ordered_left_cols.iter().copied().collect();
let right_set: PlHashSet<&PlSmallStr> = ordered_right_cols.iter().copied().collect();
Expand All @@ -718,7 +716,8 @@ fn assert_dataframe_schema_equal(
if left_set != right_set {
let left_not_right: Vec<_> = left_set
.iter()
.filter(|col| !right_set.contains(*col))
.copied()
.filter(|col| !right_set.contains(col))
.collect();

if !left_not_right.is_empty() {
Expand All @@ -734,7 +733,8 @@ fn assert_dataframe_schema_equal(
} else {
let right_not_left: Vec<_> = right_set
.iter()
.filter(|col| !left_set.contains(*col))
.copied()
.filter(|col| !left_set.contains(col))
.collect();

return Err(polars_err!(
Expand All @@ -760,8 +760,8 @@ fn assert_dataframe_schema_equal(

if check_dtypes {
if check_column_order {
let left_dtypes_ordered = left.dtypes();
let right_dtypes_ordered = right.dtypes();
let left_dtypes_ordered: Vec<&DataType> = left_schema.iter_values().collect();
let right_dtypes_ordered: Vec<&DataType> = right_schema.iter_values().collect();
if left_dtypes_ordered != right_dtypes_ordered {
return Err(polars_err!(
assertion_error = "DataFrames",
Expand All @@ -771,8 +771,8 @@ fn assert_dataframe_schema_equal(
));
}
} else {
let left_dtypes: PlHashSet<DataType> = left.dtypes().into_iter().collect();
let right_dtypes: PlHashSet<DataType> = right.dtypes().into_iter().collect();
let left_dtypes: PlHashSet<&DataType> = left_schema.iter_values().collect();
let right_dtypes: PlHashSet<&DataType> = right_schema.iter_values().collect();
if left_dtypes != right_dtypes {
return Err(polars_err!(
assertion_error = "DataFrames",
Expand Down Expand Up @@ -836,9 +836,12 @@ pub fn assert_dataframe_equal(
return Ok(());
}

assert_dataframe_schema_equal(
left,
right,
let left_schema = left.schema();
let right_schema = right.schema();

assert_schema_equal(
left_schema,
right_schema,
options.check_dtypes,
options.check_column_order,
)?;
Expand Down
7 changes: 7 additions & 0 deletions py-polars/src/polars/_plr.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2485,6 +2485,13 @@ def assert_dataframe_equal_py(
abs_tol: float,
categorical_as_str: bool,
) -> None: ...
def assert_schema_equal_py(
left: Schema,
right: Schema,
*,
check_column_order: bool,
check_dtypes: bool,
) -> None: ...

# datatypes
def _get_dtype_max(dt: DataType) -> PyExpr: ...
Expand Down
2 changes: 2 additions & 0 deletions py-polars/src/polars/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from polars.testing.asserts import (
assert_frame_equal,
assert_frame_not_equal,
assert_schema_equal,
assert_series_equal,
assert_series_not_equal,
)
Expand All @@ -10,4 +11,5 @@
"assert_frame_not_equal",
"assert_series_equal",
"assert_series_not_equal",
"assert_schema_equal",
]
7 changes: 6 additions & 1 deletion py-polars/src/polars/testing/asserts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from polars.testing.asserts.frame import assert_frame_equal, assert_frame_not_equal
from polars.testing.asserts.frame import (
assert_frame_equal,
assert_frame_not_equal,
assert_schema_equal,
)
from polars.testing.asserts.series import assert_series_equal, assert_series_not_equal

__all__ = [
"assert_frame_equal",
"assert_frame_not_equal",
"assert_series_equal",
"assert_series_not_equal",
"assert_schema_equal",
]
54 changes: 52 additions & 2 deletions py-polars/src/polars/testing/asserts/frame.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import contextlib
from typing import cast
from typing import TYPE_CHECKING, cast

from polars._utils.deprecation import deprecate_renamed_parameter
from polars.dataframe import DataFrame
from polars.lazyframe import LazyFrame
from polars.testing.asserts.utils import raise_assertion_error

if TYPE_CHECKING:
from polars import Schema

with contextlib.suppress(ImportError): # Module not available when building docs
from polars._plr import assert_dataframe_equal_py
from polars._plr import assert_dataframe_equal_py, assert_schema_equal_py


def _assert_correct_input_type(
Expand Down Expand Up @@ -229,3 +232,50 @@ def assert_frame_not_equal(
objects = "LazyFrames" if lazy else "DataFrames"
msg = f"{objects} are equal (but are expected not to be)"
raise AssertionError(msg)


def assert_schema_equal(
left_schema: Schema,
right_schema: Schema,
*,
check_column_order: bool = True,
check_dtypes: bool = True,
) -> None:
"""
Assert that the schema of the left and right frame are equal.

Raises a detailed `AssertionError` if the schemas of the frames differ.
This function is intended for use in unit tests.

Parameters
----------
left_schema
The first DataFrame or LazyFrame to compare.
right_schema
The second DataFrame or LazyFrame to compare.
check_column_order
Requires column order to match.
check_dtypes
Requires data types to match.

Examples
--------
>>> import polars as pl
>>> from polars.testing import assert_schema_equal
>>> df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]})
>>> df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
>>> assert_schema_equal(df1.schema, df2.schema)
Traceback (most recent call last):
...
AssertionError: DataFrames are different (columns are not in the same order)
[left]: ["b", "a"]
[right]: ["a", "b"]
"""
# Tell type checker these are now DataFrames to prevent type errors

assert_schema_equal_py(
left_schema,
right_schema,
check_column_order=check_column_order,
check_dtypes=check_dtypes,
)
30 changes: 29 additions & 1 deletion py-polars/tests/unit/testing/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_frame_not_equal
from polars.testing import (
assert_frame_equal,
assert_frame_not_equal,
assert_schema_equal,
)
from polars.testing.parametric import dataframes

nan = float("nan")
Expand Down Expand Up @@ -423,6 +427,30 @@ def test_assert_dataframe_equal_all_nulls_fails_when_checking_dtypes() -> None:
assert_frame_equal(x, y, check_dtypes=True)


def test_assert_schema_equal_column_mismatch_order() -> None:
df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]})

df1_schema = df1.schema
df2_schema = df2.schema
with pytest.raises(AssertionError, match="columns are not in the same order"):
assert_schema_equal(df1_schema, df2_schema)

assert_schema_equal(df1_schema, df2_schema, check_column_order=False)


def test_assert_schema_equal_dtypes_mismatch() -> None:
data = {"a": [1, 2], "b": [3, 4]}
df1 = pl.DataFrame(data, schema={"a": pl.Int8, "b": pl.Int16})
df2 = pl.DataFrame(data, schema={"b": pl.Int16, "a": pl.Int16})

df1_schema = df1.schema
df2_schema = df2.schema

with pytest.raises(AssertionError, match="dtypes do not match"):
assert_schema_equal(df1_schema, df2_schema, check_column_order=False)


def test_tracebackhide(testdir: pytest.Testdir) -> None:
testdir.makefile(
".py",
Expand Down
Loading