Skip to content

Commit 5d9aa6e

Browse files
authored
propagate attrs on coords in Dataset.map (#10602)
* check that weighted ops propagate attrs on coords * propagate attrs on coords in `map` if keep_attrs * directly check that `map` propagates attrs on coords * whats-new
1 parent 660b56b commit 5d9aa6e

File tree

4 files changed

+57
-1
lines changed

4 files changed

+57
-1
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ Bug fixes
3434
- Fix error when encoding an empty :py:class:`numpy.datetime64` array
3535
(:issue:`10722`, :pull:`10723`). By `Spencer Clark
3636
<https://github.com/spencerkclark>`_.
37+
- Propagation coordinate attrs in :py:meth:`xarray.Dataset.map` (:issue:`9317`, :pull:`10602`).
38+
By `Justus Magin <https://github.com/keewis>`_.
3739

3840
Documentation
3941
~~~~~~~~~~~~~

xarray/core/dataset.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6929,11 +6929,22 @@ def map(
69296929
k: maybe_wrap_array(v, func(v, *args, **kwargs))
69306930
for k, v in self.data_vars.items()
69316931
}
6932+
coord_vars, indexes = merge_coordinates_without_align(
6933+
[v.coords for v in variables.values()]
6934+
)
6935+
coords = Coordinates._construct_direct(coords=coord_vars, indexes=indexes)
6936+
69326937
if keep_attrs:
69336938
for k, v in variables.items():
69346939
v._copy_attrs_from(self.data_vars[k])
6940+
6941+
for k, v in coords.items():
6942+
if k not in self.coords:
6943+
continue
6944+
v._copy_attrs_from(self.coords[k])
6945+
69356946
attrs = self.attrs if keep_attrs else None
6936-
return type(self)(variables, attrs=attrs)
6947+
return type(self)(variables, coords=coords, attrs=attrs)
69376948

69386949
def apply(
69396950
self,

xarray/tests/test_dataset.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6236,6 +6236,38 @@ def scale(x, multiple=1):
62366236
expected = data.drop_vars("time") # time is not used on a data var
62376237
assert_equal(expected, actual)
62386238

6239+
def test_map_coords_attrs(self) -> None:
6240+
ds = xr.Dataset(
6241+
{
6242+
"a": (
6243+
["x", "y", "z"],
6244+
np.arange(24).reshape(3, 4, 2),
6245+
{"attr1": "value1"},
6246+
),
6247+
"b": ("y", np.arange(4), {"attr2": "value2"}),
6248+
},
6249+
coords={
6250+
"x": ("x", np.array([-1, 0, 1]), {"attr3": "value3"}),
6251+
"z": ("z", list("ab"), {"attr4": "value4"}),
6252+
},
6253+
)
6254+
6255+
def func(arr):
6256+
if "y" not in arr.dims:
6257+
return arr
6258+
6259+
# drop attrs from coords
6260+
return arr.mean(dim="y").drop_attrs()
6261+
6262+
expected = ds.mean(dim="y", keep_attrs=True)
6263+
actual = ds.map(func, keep_attrs=True)
6264+
6265+
assert_identical(actual, expected)
6266+
assert actual["x"].attrs
6267+
6268+
ds["x"].attrs["y"] = "x"
6269+
assert ds["x"].attrs != actual["x"].attrs
6270+
62396271
def test_apply_pending_deprecated_map(self) -> None:
62406272
data = create_test_data()
62416273
data.attrs["foo"] = "bar"

xarray/tests/test_weighted.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,17 @@ def test_weighted_operations_keep_attr_da_in_ds(operation):
770770
assert data.a.attrs == result.a.attrs
771771

772772

773+
def test_weighted_mean_keep_attrs_ds():
774+
weights = DataArray(np.random.randn(2))
775+
data = Dataset(
776+
{"a": (["dim_0", "dim_1"], np.random.randn(2, 2), dict(attr="data"))},
777+
coords={"dim_1": ("dim_1", ["a", "b"], {"attr1": "value1"})},
778+
)
779+
780+
result = data.weighted(weights).mean(dim="dim_0", keep_attrs=True)
781+
assert data.coords["dim_1"].attrs == result.coords["dim_1"].attrs
782+
783+
773784
@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean", "quantile"))
774785
@pytest.mark.parametrize("as_dataset", (True, False))
775786
def test_weighted_bad_dim(operation, as_dataset):

0 commit comments

Comments
 (0)