diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7e3badc7143..00c88f19239 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,10 @@ Deprecations Bug Fixes ~~~~~~~~~ +- :py:meth:`Dataset.map` now merges attrs from the function result and the original + using the ``drop_conflicts`` strategy when ``keep_attrs=True``, preserving attrs + set by the function (:issue:`11019`, :pull:`11020`). + By `Maximilian Roos <https://github.com/max-sixty>`_. - Ensure that ``keep_attrs='drop'`` and ``keep_attrs=False`` remove attrs from result, even when there is only one xarray object given to ``apply_ufunc`` (:issue:`10982` :pull:`10997`). By `Julia Signell <https://github.com/jsignell>`_. diff --git a/pixi.toml b/pixi.toml index bfa51c3fb54..443c8c8f211 100644 --- a/pixi.toml +++ b/pixi.toml @@ -143,6 +143,12 @@ sparse = "0.15.*" toolz = "0.12.*" zarr = "2.18.*" +# TODO: Remove `platforms` restriction once pandas nightly has win-64 wheels again. +# Without this, `pixi lock` fails because it can't solve the nightly feature for win-64, +# which breaks RTD builds (RTD has no lock file cache, unlike GitHub Actions CI).
+[feature.nightly] +platforms = ["linux-64", "osx-arm64"] + [feature.nightly.dependencies] python = "*" diff --git a/xarray/computation/weighted.py b/xarray/computation/weighted.py index b311290aabf..12d61cedc7d 100644 --- a/xarray/computation/weighted.py +++ b/xarray/computation/weighted.py @@ -544,13 +544,29 @@ def _implementation(self, func, dim, **kwargs) -> DataArray: dataset = self.obj._to_temp_dataset() dataset = dataset.map(func, dim=dim, **kwargs) - return self.obj._from_temp_dataset(dataset) + result = self.obj._from_temp_dataset(dataset) + + # Clear attrs when keep_attrs is explicitly False + # (weighted operations can propagate attrs from weights through internal computations) + if kwargs.get("keep_attrs") is False: + result.attrs = {} + + return result class DatasetWeighted(Weighted["Dataset"]): def _implementation(self, func, dim, **kwargs) -> Dataset: self._check_dim(dim) - return self.obj.map(func, dim=dim, **kwargs) + result = self.obj.map(func, dim=dim, **kwargs) + + # Clear attrs when keep_attrs is explicitly False + # (weighted operations can propagate attrs from weights through internal computations) + if kwargs.get("keep_attrs") is False: + result.attrs = {} + for var in result.data_vars.values(): + var.attrs = {} + + return result def _inject_docstring(cls, cls_name): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bce048048da..9a7b88263ef 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6910,8 +6910,10 @@ def map( DataArray. keep_attrs : bool or None, optional If True, both the dataset's and variables' attributes (`attrs`) will be - copied from the original objects to the new ones. If False, the new dataset - and variables will be returned without copying the attributes. + combined from the original objects and the function results using the + ``drop_conflicts`` strategy: matching attrs are kept, conflicting attrs + are dropped. 
If False, the new dataset and variables will have only + the attributes set by the function. args : iterable, optional Positional arguments passed on to `func`. **kwargs : Any @@ -6960,16 +6962,19 @@ def map( coords = Coordinates._construct_direct(coords=coord_vars, indexes=indexes) if keep_attrs: + # Merge attrs from function result and original, dropping conflicts + from xarray.structure.merge import merge_attrs + for k, v in variables.items(): - v._copy_attrs_from(self.data_vars[k]) + v.attrs = merge_attrs( + [v.attrs, self.data_vars[k].attrs], "drop_conflicts" + ) for k, v in coords.items(): if k in self.coords: - v._copy_attrs_from(self.coords[k]) - else: - for v in variables.values(): - v.attrs = {} - for v in coords.values(): - v.attrs = {} + v.attrs = merge_attrs( + [v.attrs, self.coords[k].attrs], "drop_conflicts" + ) + # When keep_attrs=False, leave attrs as the function returned them attrs = self.attrs if keep_attrs else None return type(self)(variables, coords=coords, attrs=attrs) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e079332780c..a64ceefb207 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -397,8 +397,10 @@ def map( # type: ignore[override] DataArray. keep_attrs : bool | None, optional If True, both the dataset's and variables' attributes (`attrs`) will be - copied from the original objects to the new ones. If False, the new dataset - and variables will be returned without copying the attributes. + combined from the original objects and the function results using the + ``drop_conflicts`` strategy: matching attrs are kept, conflicting attrs + are dropped. If False, the new dataset and variables will have only + the attributes set by the function. args : iterable, optional Positional arguments passed on to `func`. 
**kwargs : Any @@ -438,8 +440,13 @@ def map( # type: ignore[override] for k, v in self.data_vars.items() } if keep_attrs: + # Merge attrs from function result and original, dropping conflicts + from xarray.structure.merge import merge_attrs + for k, v in variables.items(): - v._copy_attrs_from(self.data_vars[k]) + v.attrs = merge_attrs( + [v.attrs, self.data_vars[k].attrs], "drop_conflicts" + ) attrs = self.attrs if keep_attrs else None # return type(self)(variables, attrs=attrs) return Dataset(variables, attrs=attrs) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 6dce32aeb5c..83ce11269c5 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6452,6 +6452,35 @@ def mixed_func(x): expected = xr.Dataset({"foo": 42, "bar": ("y", [4, 5])}) assert_identical(result, expected) + def test_map_preserves_function_attrs(self) -> None: + # Regression test for GH11019 + # Attrs added by function should be preserved in result + ds = xr.Dataset({"test": ("x", [1, 2, 3], {"original": "value"})}) + + def add_attr(da): + return da.assign_attrs(new_attr="foobar") + + # With keep_attrs=True: merge using drop_conflicts (no conflict here) + result = ds.map(add_attr, keep_attrs=True) + assert result["test"].attrs == {"original": "value", "new_attr": "foobar"} + + # With keep_attrs=False: function's attrs preserved + result = ds.map(add_attr, keep_attrs=False) + assert result["test"].attrs == {"original": "value", "new_attr": "foobar"} + + # When function modifies existing attr with keep_attrs=True, conflict is dropped + def modify_attr(da): + return da.assign_attrs(original="modified", extra="added") + + result = ds.map(modify_attr, keep_attrs=True) + assert result["test"].attrs == { + "extra": "added" + } # "original" dropped due to conflict + + # When function modifies existing attr with keep_attrs=False, function wins + result = ds.map(modify_attr, keep_attrs=False) + assert result["test"].attrs == {"original": "modified", 
"extra": "added"} + def test_apply_pending_deprecated_map(self) -> None: data = create_test_data() data.attrs["foo"] = "bar"