diff --git a/test/core/test_vector_calculus.py b/test/core/test_vector_calculus.py index a991a258b..e6009dafc 100644 --- a/test/core/test_vector_calculus.py +++ b/test/core/test_vector_calculus.py @@ -193,6 +193,79 @@ def test_divergence_basic(self, gridpath, datasetpath): assert np.isfinite(div_field.values).any() +class TestScalarDotGradientMPASOcean: + + def test_scalardotgradient_uses_known_gradient_components( + self, gridpath, datasetpath, monkeypatch + ): + """Test scalar dot gradient against independently supplied gradients.""" + uxds = ux.open_dataset( + gridpath("mpas", "QU", "480", "grid.nc"), + datasetpath("mpas", "QU", "480", "data.nc"), + ) + + n_face = uxds.uxgrid.n_face + dims = ["n_face"] + scalar = ux.UxDataArray( + np.zeros(n_face), dims=dims, uxgrid=uxds.uxgrid, name="scalar" + ) + u_component = ux.UxDataArray( + np.full(n_face, 2.0), dims=dims, uxgrid=uxds.uxgrid, name="u" + ) + v_component = ux.UxDataArray( + np.full(n_face, -0.5), dims=dims, uxgrid=uxds.uxgrid, name="v" + ) + + def mock_gradient(self): + return ux.UxDataset( + { + "zonal_gradient": ux.UxDataArray( + np.full(n_face, 3.0), dims=dims, uxgrid=self.uxgrid + ), + "meridional_gradient": ux.UxDataArray( + np.full(n_face, -4.0), dims=dims, uxgrid=self.uxgrid + ), + }, + uxgrid=self.uxgrid, + ) + + monkeypatch.setattr(ux.UxDataArray, "gradient", mock_gradient) + + result = u_component.scalardotgradient(v_component, scalar) + + expected = np.full(n_face, 8.0) + nt.assert_allclose(result.values, expected, rtol=0.0, atol=0.0) + + assert isinstance(result, ux.UxDataArray) + assert result.name == "scalar_dot_gradient" + assert result.attrs["long_name"] == "scalar dot gradient" + assert result.sizes == u_component.sizes + + def test_scalardotgradient_rejects_misaligned_indexes(self, gridpath, datasetpath): + """Test scalar dot gradient fails instead of silently realigning faces.""" + uxds = ux.open_dataset( + gridpath("mpas", "QU", "480", "grid.nc"), + datasetpath("mpas", "QU", "480", "data.nc"), + ) + + scalar = uxds["bottomDepth"] + u_component = ux.UxDataArray( + np.ones(scalar.size), + dims=["n_face"], + coords={"n_face": np.arange(scalar.size)}, + uxgrid=uxds.uxgrid, + ) + v_component = ux.UxDataArray( + np.ones(scalar.size), + dims=["n_face"], + coords={"n_face": np.arange(scalar.size) + 1}, + uxgrid=uxds.uxgrid, + ) + + with pytest.raises(ValueError): + u_component.scalardotgradient(v_component, scalar) + + class TestDivergenceDyamondSubset: def test_divergence_constant_field(self, gridpath, datasetpath): diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index b75199cc7..5ce27cfec 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -1614,6 +1614,66 @@ def divergence(self, other: "UxDataArray", **kwargs) -> "UxDataArray": return divergence_da + def scalardotgradient(self, v: "UxDataArray", q: "UxDataArray") -> "UxDataArray": + """ + Compute the dot product between a vector field and the gradient of a scalar field. + + Parameters + ---------- + v : UxDataArray + The meridional component of the vector field. ``self`` is treated as + the zonal component. + q : UxDataArray + Scalar field whose gradient is dotted with the vector field. + + Returns + ------- + scalar_dot_gradient : UxDataArray + Dot product ``self * dq/dx + v * dq/dy``. + """ + if not isinstance(v, UxDataArray): + raise TypeError("v must be a UxDataArray") + + if not isinstance(q, UxDataArray): + raise TypeError("q must be a UxDataArray") + + if self.uxgrid != v.uxgrid or self.uxgrid != q.uxgrid: + raise ValueError("All UxDataArrays must have the same grid") + + if self.dims != v.dims or self.dims != q.dims: + raise ValueError("All UxDataArrays must have the same dimensions") + + if self.ndim > 1: + raise ValueError( + "Scalar dot gradient currently requires 1D face-centered data. " + "Consider selecting a single slice before computing." + ) + + if not (self._face_centered() and v._face_centered() and q._face_centered()): + raise ValueError( + "Computing the scalar dot gradient is only supported for face-centered data variables." + ) + + u = self + + q_gradient = q.gradient() + q_zonal = q_gradient["zonal_gradient"] + q_meridional = q_gradient["meridional_gradient"] + + u_aligned, v_aligned, q_zonal, q_meridional = xr.align( + u, v, q_zonal, q_meridional, join="exact", copy=False + ) + scalar_dot_gradient = (u_aligned * q_zonal) + (v_aligned * q_meridional) + scalar_dot_gradient.name = "scalar_dot_gradient" + scalar_dot_gradient.attrs.update( + { + "long_name": "scalar dot gradient", + "description": "Dot product u * (dq/dx) + v * (dq/dy).", + } + ) + + return UxDataArray(scalar_dot_gradient, uxgrid=self.uxgrid) + def difference(self, destination: str | None = "edge"): """Computes the absolute difference of a data variable.