Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ Deprecations
Bug Fixes
~~~~~~~~~

- :py:meth:`Dataset.map` now merges attrs from the function result and the original
using the ``drop_conflicts`` strategy when ``keep_attrs=True``, preserving attrs
set by the function (:issue:`11019`, :pull:`11020`).
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Ensure that ``keep_attrs='drop'`` and ``keep_attrs=False`` remove attrs from result, even when there is
only one xarray object given to ``apply_ufunc`` (:issue:`10982` :pull:`10997`).
By `Julia Signell <https://github.com/jsignell>`_.
Expand Down
6 changes: 6 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ sparse = "0.15.*"
toolz = "0.12.*"
zarr = "2.18.*"

# TODO: Remove `platforms` restriction once pandas nightly has win-64 wheels again.
# Without this, `pixi lock` fails because it can't solve the nightly feature for win-64,
# which breaks RTD builds (RTD has no lock file cache, unlike GitHub Actions CI).
[feature.nightly]
platforms = ["linux-64", "osx-arm64"]

[feature.nightly.dependencies]
python = "*"

Expand Down
20 changes: 18 additions & 2 deletions xarray/computation/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,13 +544,29 @@ def _implementation(self, func, dim, **kwargs) -> DataArray:

dataset = self.obj._to_temp_dataset()
dataset = dataset.map(func, dim=dim, **kwargs)
return self.obj._from_temp_dataset(dataset)
result = self.obj._from_temp_dataset(dataset)

# Clear attrs when keep_attrs is explicitly False
# (weighted operations can propagate attrs from weights through internal computations)
if kwargs.get("keep_attrs") is False:
result.attrs = {}

return result


class DatasetWeighted(Weighted["Dataset"]):
def _implementation(self, func, dim, **kwargs) -> Dataset:
self._check_dim(dim)
return self.obj.map(func, dim=dim, **kwargs)
result = self.obj.map(func, dim=dim, **kwargs)

# Clear attrs when keep_attrs is explicitly False
# (weighted operations can propagate attrs from weights through internal computations)
if kwargs.get("keep_attrs") is False:
result.attrs = {}
for var in result.data_vars.values():
var.attrs = {}

return result


def _inject_docstring(cls, cls_name):
Expand Down
23 changes: 14 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6910,8 +6910,10 @@ def map(
DataArray.
keep_attrs : bool or None, optional
If True, both the dataset's and variables' attributes (`attrs`) will be
copied from the original objects to the new ones. If False, the new dataset
and variables will be returned without copying the attributes.
combined from the original objects and the function results using the
``drop_conflicts`` strategy: matching attrs are kept, conflicting attrs
are dropped. If False, the new dataset and variables will have only
the attributes set by the function.
args : iterable, optional
Positional arguments passed on to `func`.
**kwargs : Any
Expand Down Expand Up @@ -6960,16 +6962,19 @@ def map(
coords = Coordinates._construct_direct(coords=coord_vars, indexes=indexes)

if keep_attrs:
# Merge attrs from function result and original, dropping conflicts
from xarray.structure.merge import merge_attrs

for k, v in variables.items():
v._copy_attrs_from(self.data_vars[k])
v.attrs = merge_attrs(
[v.attrs, self.data_vars[k].attrs], "drop_conflicts"
)
for k, v in coords.items():
if k in self.coords:
v._copy_attrs_from(self.coords[k])
else:
for v in variables.values():
v.attrs = {}
for v in coords.values():
v.attrs = {}
v.attrs = merge_attrs(
[v.attrs, self.coords[k].attrs], "drop_conflicts"
)
# When keep_attrs=False, leave attrs as the function returned them

attrs = self.attrs if keep_attrs else None
return type(self)(variables, coords=coords, attrs=attrs)
Expand Down
13 changes: 10 additions & 3 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,10 @@ def map( # type: ignore[override]
DataArray.
keep_attrs : bool | None, optional
If True, both the dataset's and variables' attributes (`attrs`) will be
copied from the original objects to the new ones. If False, the new dataset
and variables will be returned without copying the attributes.
combined from the original objects and the function results using the
``drop_conflicts`` strategy: matching attrs are kept, conflicting attrs
are dropped. If False, the new dataset and variables will have only
the attributes set by the function.
args : iterable, optional
Positional arguments passed on to `func`.
**kwargs : Any
Expand Down Expand Up @@ -438,8 +440,13 @@ def map( # type: ignore[override]
for k, v in self.data_vars.items()
}
if keep_attrs:
# Merge attrs from function result and original, dropping conflicts
from xarray.structure.merge import merge_attrs

for k, v in variables.items():
v._copy_attrs_from(self.data_vars[k])
v.attrs = merge_attrs(
[v.attrs, self.data_vars[k].attrs], "drop_conflicts"
)
attrs = self.attrs if keep_attrs else None
# return type(self)(variables, attrs=attrs)
return Dataset(variables, attrs=attrs)
Expand Down
29 changes: 29 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6452,6 +6452,35 @@ def mixed_func(x):
expected = xr.Dataset({"foo": 42, "bar": ("y", [4, 5])})
assert_identical(result, expected)

def test_map_preserves_function_attrs(self) -> None:
# Regression test for GH11019
# Attrs added by function should be preserved in result
ds = xr.Dataset({"test": ("x", [1, 2, 3], {"original": "value"})})

def add_attr(da):
return da.assign_attrs(new_attr="foobar")

# With keep_attrs=True: merge using drop_conflicts (no conflict here)
result = ds.map(add_attr, keep_attrs=True)
assert result["test"].attrs == {"original": "value", "new_attr": "foobar"}

# With keep_attrs=False: function's attrs preserved
result = ds.map(add_attr, keep_attrs=False)
assert result["test"].attrs == {"original": "value", "new_attr": "foobar"}

# When function modifies existing attr with keep_attrs=True, conflict is dropped
def modify_attr(da):
return da.assign_attrs(original="modified", extra="added")

result = ds.map(modify_attr, keep_attrs=True)
assert result["test"].attrs == {
"extra": "added"
} # "original" dropped due to conflict

# When function modifies existing attr with keep_attrs=False, function wins
result = ds.map(modify_attr, keep_attrs=False)
assert result["test"].attrs == {"original": "modified", "extra": "added"}

def test_apply_pending_deprecated_map(self) -> None:
data = create_test_data()
data.attrs["foo"] = "bar"
Expand Down
Loading