diff --git a/docs/notebooks/demos/demo_conservative_2d_curvilinear.ipynb b/docs/notebooks/demos/demo_conservative_2d_curvilinear.ipynb new file mode 100644 index 0000000..42a2679 --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_2d_curvilinear.ipynb @@ -0,0 +1,215 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Conservative 2D regrid — curvilinear target\n", + "\n", + "**Conservative regridding** resamples a gridded field while preserving its\n", + "area-weighted integral: each output cell is the area-weighted average of\n", + "the source cells it overlaps. It's the right tool for fluxes and intensive\n", + "quantities — precipitation, sea-surface temperature, mass-balance budgets —\n", + "where bilinear or nearest-neighbor interpolation would bias the total.\n", + "\n", + "The fast `.regrid.conservative` accessor only works on **1D-separable**\n", + "grids: plain rectilinear lat/lon, where lat depends only on `y` and lon\n", + "only on `x`. **Curvilinear** grids store coordinates as 2D arrays\n", + "`lat(y, x)` / `lon(y, x)` and are common in ocean models (ORCA, tripolar)\n", + "or any rotated/projected setup. Their cells aren't axis-aligned, so we\n", + "drop to `.regrid.conservative_2d`, which builds the full 2D polygon\n", + "intersection.\n", + "\n", + "**In this notebook.** We regrid a smooth analytic two-bump field from a\n", + "regular lat/lon grid onto a 30°-rotated curvilinear target — a minimal\n", + "stand-in for any curvilinear ocean grid — and verify that the\n", + "area-weighted integral is preserved to machine precision." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Source — regular 1° lat/lon, analytic two-bump field\n", + "\n", + "A simple synthetic field — one positive Gaussian bump in the northern\n", + "hemisphere, one negative in the southern — gives us something we can\n", + "visually track from source to target." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "lat = np.linspace(-60, 60, 121)\n", + "lon = np.linspace(-120, 120, 241)\n", + "Lo, La = np.meshgrid(lon, lat)\n", + "field = (\n", + " np.exp(-((Lo - 40) ** 2 + (La - 20) ** 2) / 500)\n", + " - np.exp(-((Lo + 60) ** 2 + (La + 15) ** 2) / 400)\n", + ")\n", + "src = xr.DataArray(\n", + " field,\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat, \"longitude\": lon},\n", + ")\n", + "src.plot(figsize=(8, 3.5), cmap=\"RdBu_r\", center=0)\n", + "plt.title(\"source: analytic two-bump field\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Target — rotated curvilinear grid\n", + "\n", + "To break 1D-separability we take a regular `(ny, nx)` mesh and rotate it\n", + "30° in the lat/lon plane. The result has the same kind of 2D coordinate\n", + "variables you'd find in an ORCA or rotated-pole grid: `longitude(ny, nx)`\n", + "and `latitude(ny, nx)` riding on a non-axis-aligned mesh." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "ny, nx = 30, 50\n", + "xi, yi = np.meshgrid(\n", + " np.linspace(-110, 110, nx),\n", + " np.linspace(-45, 45, ny),\n", + " indexing=\"xy\",\n", + ")\n", + "th = np.deg2rad(30)\n", + "lon2d = xi * np.cos(th) - yi * np.sin(th)\n", + "lat2d = xi * np.sin(th) + yi * np.cos(th)\n", + "target = xr.Dataset(coords={\n", + " \"longitude\": ((\"ny\", \"nx\"), lon2d),\n", + " \"latitude\": ((\"ny\", \"nx\"), lat2d),\n", + "})\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "ax.plot(lon2d, lat2d, color=\"0.3\", lw=0.4)\n", + "ax.plot(lon2d.T, lat2d.T, color=\"0.3\", lw=0.4)\n", + "ax.set_title(\"curvilinear target (30° rotation)\")\n", + "ax.set_xlabel(\"longitude\"); ax.set_ylabel(\"latitude\")\n", + "ax.set_aspect(\"equal\")" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Regrid and plot\n", + "\n", + "`ConservativeRegridder` builds the weight matrix once — that's the\n", + "expensive step, since it computes the area of every source/target polygon\n", + "intersection — and lets us reuse it for the apply step and for the\n", + "diagnostic in the next cell. The one-shot equivalent is\n", + "`src.regrid.conservative_2d(target, ...)`. Either way, output is on the\n", + "curvilinear `(ny, nx)` mesh, with NaN where target cells fall outside the\n", + "source domain." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "rgr = ConservativeRegridder(\n", + " src, target, x_coord=\"longitude\", y_coord=\"latitude\",\n", + ")\n", + "regridded = rgr.regrid(src)\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "pc = ax.pcolormesh(lon2d, lat2d, regridded.values, cmap=\"RdBu_r\",\n", + " shading=\"auto\", vmin=-1, vmax=1)\n", + "fig.colorbar(pc, ax=ax, shrink=0.8)\n", + "ax.set_title(\"regridded onto rotated curvilinear grid\")\n", + "ax.set_xlabel(\"longitude\"); ax.set_ylabel(\"latitude\")\n", + "ax.set_aspect(\"equal\")" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## Conservation check\n", + "\n", + "The regridder stores its area-intersection matrix\n", + "`A[i, j] = area(target_i ∩ source_j)`. For target cells fully inside the\n", + "source domain, the area-weighted sum of outputs equals the direct A·s\n", + "integral to machine precision — the defining property of *conservative*\n", + "regridding." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "A = rgr.areas\n", + "src_cover = np.ravel(A.sum(axis=0).todense())\n", + "tgt_cover = A.sum(axis=1).todense().reshape(regridded.shape)\n", + "valid = np.isfinite(regridded.values)\n", + "\n", + "direct = float((src.values.ravel() * src_cover).sum())\n", + "via_regrid = float((regridded.values[valid] * tgt_cover[valid]).sum())\n", + "print(f\"direct : {direct:.6f}\")\n", + "print(f\"via regrid : {via_regrid:.6f}\")\n", + "print(f\"relative err : {abs(direct - via_regrid) / max(abs(direct), 1e-12):.2e}\")\n", + "print(f\"coverage : {valid.mean():.2%} of target cells inside source\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/demos/demo_conservative_2d_regions.ipynb b/docs/notebooks/demos/demo_conservative_2d_regions.ipynb new file mode 100644 index 0000000..a66f949 --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_2d_regions.ipynb @@ -0,0 +1,315 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e5c70992", + "metadata": {}, + "source": [ + "# Conservative 2D regrid — regions (grid → country/state polygons)\n", + "\n", + "A common analysis question is: *given a gridded variable, what's its\n", + "area-weighted mean over each of these regions?* — countries, states,\n", + "watersheds, ocean basins, exclusive economic zones. 
This is the canonical\n", + "[xagg](https://github.com/ks905383/xagg)-style workflow and a natural fit\n", + "for **conservative regridding**: each region's output value is the\n", + "area-weighted average of the source cells it overlaps, which is exactly\n", + "what you want for fluxes and intensive quantities (precipitation,\n", + "temperature, mass-balance budgets) where bilinear or nearest-neighbor\n", + "interpolation would bias the total.\n", + "\n", + "Because the targets are arbitrary polygons rather than a grid, the fast\n", + "`.regrid.conservative` accessor doesn't apply. We use\n", + "`ConservativeRegridder.from_polygons`, which takes a 1D array of shapely\n", + "polygons as source and target.\n", + "\n", + "**In this notebook.**\n", + "\n", + "1. Aggregate xarray's air-temperature tutorial dataset (NMC reanalysis,\n", + " ~daily 2.5° lat/lon over North America) onto US states built by\n", + " dissolving county boundaries from `geodatasets`.\n", + "2. Visualize the per-state means against the source grid.\n", + "3. Verify conservation directly from the internal weight matrix.\n", + "4. Save the regridder, reload it, and reuse it on JJA vs. DJF subsets to\n", + " get per-state seasonal swings — showing the weight matrix is reusable\n", + " across any source field on the same grid." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b28b1f2e", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "import geopandas as gpd\n", + "import geodatasets\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder, polygons_from_coords" + ] + }, + { + "cell_type": "markdown", + "id": "537c0ed3", + "metadata": {}, + "source": [ + "## Source — NCEP reanalysis surface air temperature\n", + "\n", + "`xr.tutorial.open_dataset(\"air_temperature\")` is a small (~3 MB) NMC\n", + "reanalysis subset: 4×daily surface air temperature over North America on\n", + "a 2.5° lat/lon grid. We reduce it to a long-term mean, convert from\n", + "Kelvin to °C, and align with the state polygons' CRS:\n", + "\n", + "- Native longitudes are on `[0, 360)`; states are on `[-180, 180)`, so\n", + " we wrap.\n", + "- Native latitudes are descending; we sort to ascending so xarray's\n", + " `pcolormesh` doesn't flip the image." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "721e935c", + "metadata": {}, + "outputs": [], + "source": [ + "ds = xr.tutorial.open_dataset(\"air_temperature\")\n", + "air = (ds[\"air\"].mean(\"time\") - 273.15).sortby(\"lat\")\n", + "air = air.assign_coords(lon=(((air.lon + 180) % 360) - 180)).sortby(\"lon\")\n", + "air.attrs[\"units\"] = \"degC\"\n", + "air.name = \"mean_air_temperature\"\n", + "air" + ] + }, + { + "cell_type": "markdown", + "id": "434441b2", + "metadata": {}, + "source": [ + "## Regions — US states from `geodatasets`\n", + "\n", + "`geodatasets.get_path(\"geoda.ncovr\")` returns a GeoPackage of the 49\n", + "contiguous-US counties (+ DC). We **dissolve** on `STATE_NAME` —\n", + "geopandas' equivalent of a `groupby` for geometries — to combine each\n", + "state's counties into one (Multi)Polygon. 
The result is exactly what\n", + "`ConservativeRegridder.from_polygons` wants: one shapely (Multi)Polygon\n", + "per region. The air-temperature grid doesn't cover Alaska or Hawaii,\n", + "which is why ncovr (contiguous-US only) is a natural fit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c6437b2", + "metadata": {}, + "outputs": [], + "source": [ + "counties = gpd.read_file(geodatasets.get_path(\"geoda.ncovr\"))\n", + "states = (\n", + " counties.dissolve(by=\"STATE_NAME\", aggfunc=\"first\")\n", + " .reset_index()[[\"STATE_NAME\", \"geometry\"]]\n", + " .sort_values(\"STATE_NAME\")\n", + " .reset_index(drop=True)\n", + ")\n", + "print(f\"{len(states)} states/DC, CRS={states.crs}, \"\n", + " f\"bounds={states.total_bounds.round(1).tolist()}\")\n", + "states.head(3)" + ] + }, + { + "cell_type": "markdown", + "id": "3a3b8ac5", + "metadata": {}, + "source": [ + "## Build the regridder and apply\n", + "\n", + "`from_polygons` takes flat 1D arrays of source and target polygons:\n", + "\n", + "- **Source polygons** come from the 1D grid coords via\n", + " `polygons_from_coords`, which builds a rectangle per cell from the\n", + " coordinate midpoints (so a 25×53 grid → 1325 source polygons).\n", + "- **Target polygons** are just the states' `geometry` column.\n", + "\n", + "Source data has to be flattened to a single `src_cell` dimension to match\n", + "the flat polygon array. Constructing the regridder is the expensive step\n", + "— it computes the area of every source/target polygon intersection.\n", + "Once built, `rgr.regrid(...)` is a sparse matrix-vector product against\n", + "the precomputed weights and is essentially instant." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da3e70c8", + "metadata": {}, + "outputs": [], + "source": [ + "src_polys = polygons_from_coords(air.lon.values, air.lat.values)\n", + "tgt_polys = states.geometry.to_numpy()\n", + "\n", + "rgr = ConservativeRegridder.from_polygons(\n", + " source_polygons=src_polys,\n", + " target_polygons=tgt_polys,\n", + " source_dim=\"src_cell\",\n", + " target_dim=\"state\",\n", + " target_coords=xr.Dataset(coords={\"state\": states.STATE_NAME.values}),\n", + ")\n", + "\n", + "src_flat = xr.DataArray(air.values.ravel(), dims=(\"src_cell\",))\n", + "state_mean = rgr.regrid(src_flat)\n", + "state_mean.attrs[\"units\"] = \"degC\"\n", + "state_mean.to_series().sort_values().round(2)" + ] + }, + { + "cell_type": "markdown", + "id": "3542f836", + "metadata": {}, + "source": [ + "## Map + ranked bar chart\n", + "\n", + "States filled by area-weighted mean temperature, with the source grid\n", + "shown underneath for reference. Both panels share the same `vmin`/`vmax`\n", + "so a state's color on the bar chart matches its color on the map. A\n", + "sanity check: the warm/cool gradient should track latitude (Florida and\n", + "the Gulf states warmest, the Upper Midwest and Northeast coolest)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb8d0dda", + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax_map, ax_bar) = plt.subplots(\n", + " 1, 2, figsize=(14, 6), gridspec_kw={\"width_ratios\": [1.6, 1]},\n", + ")\n", + "\n", + "air.plot(ax=ax_map, cmap=\"coolwarm\", alpha=0.5,\n", + " cbar_kwargs={\"shrink\": 0.65, \"label\": \"grid mean T [°C]\"})\n", + "\n", + "vmin, vmax = float(air.min()), float(air.max())\n", + "states_plot = states.assign(mean_T=state_mean.values)\n", + "states_plot.plot(\n", + " column=\"mean_T\", cmap=\"coolwarm\", ax=ax_map,\n", + " edgecolor=\"black\", linewidth=0.4,\n", + " vmin=vmin, vmax=vmax,\n", + ")\n", + "ax_map.set_xlim(-128, -65); ax_map.set_ylim(22, 52)\n", + "ax_map.set_title(\"Annual mean surface T — states vs. source grid\")\n", + "ax_map.set_xlabel(\"longitude\"); ax_map.set_ylabel(\"latitude\")\n", + "\n", + "ordered = state_mean.to_series().sort_values()\n", + "colors = plt.cm.coolwarm((ordered.values - vmin) / (vmax - vmin))\n", + "ax_bar.barh(ordered.index, ordered.values, color=colors)\n", + "ax_bar.set_xlabel(\"area-weighted mean T [°C]\")\n", + "ax_bar.tick_params(axis=\"y\", labelsize=7)\n", + "ax_bar.grid(axis=\"x\", alpha=0.3)\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "803247d2", + "metadata": {}, + "source": [ + "## Conservation check\n", + "\n", + "The internal area matrix `A[i, j] = area(state_i ∩ src_cell_j)` lets us\n", + "verify conservation directly: integrating the source field weighted by\n", + "source-cell coverage should equal integrating the regridded field\n", + "weighted by target-cell area. Equality to machine precision is the\n", + "defining property of *conservative* regridding." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "041a51f3", + "metadata": {}, + "outputs": [], + "source": [ + "A = rgr.areas # (n_states, n_src)\n", + "tgt_area = rgr.target_areas\n", + "src_cover = rgr.source_coverage_areas\n", + "\n", + "direct = float((air.values.ravel() * src_cover).sum())\n", + "via_regrid = float((state_mean.values * tgt_area).sum())\n", + "print(f\"direct A·s : {direct:.6f}\")\n", + "print(f\"Σ state_mean · a_state : {via_regrid:.6f}\")\n", + "print(f\"relative error : {abs(direct - via_regrid) / abs(direct):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ff49d6b7", + "metadata": {}, + "source": [ + "## Reuse: persist the regridder, apply to summer vs. winter\n", + "\n", + "The weight matrix depends only on the source/target geometry, not on the\n", + "data. Save once with `to_netcdf`, reload with `from_netcdf`, and apply\n", + "to any field on the same source grid — no need to rebuild the (expensive)\n", + "polygon intersection. Below we use one saved regridder on JJA and DJF\n", + "subsets to compute per-state seasonal amplitude. Continental interior\n", + "states show the largest swing; Florida and California — moderated by\n", + "ocean and latitude — show the smallest." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8fa88419", + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "path = Path(tempfile.gettempdir()) / \"states_regridder.nc\"\n", + "rgr.to_netcdf(path)\n", + "print(f\"wrote {path.name} ({path.stat().st_size / 1024:.1f} KB)\")\n", + "\n", + "rgr2 = ConservativeRegridder.from_netcdf(path)\n", + "\n", + "def seasonal_mean(months):\n", + " sub = ds[\"air\"].sel(time=ds[\"time.month\"].isin(months)).mean(\"time\") - 273.15\n", + " sub = sub.sortby(\"lat\")\n", + " sub = sub.assign_coords(lon=(((sub.lon + 180) % 360) - 180)).sortby(\"lon\")\n", + " flat = xr.DataArray(sub.values.ravel(), dims=(\"src_cell\",))\n", + " return rgr2.regrid(flat)\n", + "\n", + "summer = seasonal_mean([6, 7, 8])\n", + "winter = seasonal_mean([12, 1, 2])\n", + "amplitude = (summer - winter).to_series().rename(\"JJA − DJF [°C]\").round(1)\n", + "print(\"largest seasonal swing:\")\n", + "print(amplitude.sort_values(ascending=False).head(5))\n", + "print(\"\\nsmallest seasonal swing (maritime / subtropical):\")\n", + "print(amplitude.sort_values().head(5))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/demos/demo_conservative_2d_unstructured.ipynb b/docs/notebooks/demos/demo_conservative_2d_unstructured.ipynb new file mode 100644 index 0000000..b1a939e --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_2d_unstructured.ipynb @@ -0,0 +1,244 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Conservative 2D regrid — unstructured 
mesh + save/load\n", + "\n", + "**Conservative regridding** resamples a gridded field while preserving its\n", + "area-weighted integral — the right tool for fluxes and intensive\n", + "quantities, where bilinear or nearest-neighbor interpolation would bias\n", + "the total.\n", + "\n", + "**Unstructured meshes** — ICON triangles, MPAS hexagons, finite-element\n", + "models, generic Voronoi tessellations — are the natural geometry for\n", + "simulations that need adaptive resolution (denser cells over land, coarser\n", + "over open ocean). They have no `(y, x)` index structure: every cell is\n", + "just an arbitrary polygon. So they're never 1D-separable, and the fast\n", + "`.regrid.conservative` accessor doesn't apply.\n", + "\n", + "`ConservativeRegridder.from_polygons` takes a flat 1D array of shapely\n", + "polygons as source and/or target. The same machinery handles\n", + "structured→unstructured, unstructured→structured, and unstructured→\n", + "unstructured.\n", + "\n", + "**In this notebook.**\n", + "\n", + "1. Build a synthetic Voronoi mesh as a stand-in for a real ICON/MPAS dataset.\n", + "2. Regrid a smooth analytic field from a structured lat/lon source onto\n", + " the mesh.\n", + "3. **Persist the regridder to disk** — for a fixed source/target pair the\n", + " weight matrix is the same forever, so saving it lets a long-running\n", + " pipeline (or a follow-up notebook) skip the polygon-intersection build\n", + " on restart." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.collections import PolyCollection\n", + "from scipy.spatial import Voronoi\n", + "import shapely\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder, polygons_from_coords" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Build a Voronoi mesh\n", + "\n", + "Real pipelines load a pre-built mesh (UGRID, ICON, MPAS); for a\n", + "self-contained demo we synthesize one — jitter a regular grid of generator\n", + "points, take the Voronoi tessellation, and clip to the bounding box. The\n", + "construction details aren't the point: `from_polygons` only needs a 1D\n", + "array of shapely polygons however we get them." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "def voronoi_mesh(n_points, bbox, seed=0):\n", + " rng = np.random.default_rng(seed)\n", + " x0, y0, x1, y1 = bbox\n", + " side = int(np.sqrt(n_points))\n", + " xs, ys = np.linspace(x0, x1, side), np.linspace(y0, y1, side)\n", + " pts = np.column_stack([np.repeat(xs, side), np.tile(ys, side)])\n", + " pts += rng.normal(scale=(x1 - x0) / side * 0.25, size=pts.shape)\n", + " halo = np.array([\n", + " [2*x0 - x1, 2*y0 - y1], [2*x1 - x0, 2*y0 - y1],\n", + " [2*x0 - x1, 2*y1 - y0], [2*x1 - x0, 2*y1 - y0],\n", + " ])\n", + " vor = Voronoi(np.concatenate([pts, halo]))\n", + " clip = shapely.box(x0, y0, x1, y1)\n", + " polys = []\n", + " for i in range(len(pts)):\n", + " r = vor.regions[vor.point_region[i]]\n", + " if not r or -1 in r:\n", + " continue\n", + " p = shapely.intersection(shapely.Polygon(vor.vertices[r]), clip)\n", + " if p.is_empty or p.geom_type != \"Polygon\":\n", + " continue\n", + " polys.append(p)\n", + " return np.array(polys, dtype=object)\n", + "\n", + "bbox = (-120, -50, 120, 50)\n", + "mesh_polys = voronoi_mesh(n_points=400, bbox=bbox)\n", + "print(f\"{len(mesh_polys)} mesh cells\")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Structured lat/lon source\n", + "\n", + "A smooth `sin(2λ)·cos(3φ)` field on a 1° rectilinear grid — wavy enough\n", + "that the regridded mesh values are visually distinct, smooth enough that\n", + "no individual mesh cell aliases the pattern." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "lat_s = np.linspace(-50, 50, 100, endpoint=False) + 0.5\n", + "lon_s = np.linspace(-120, 120, 240, endpoint=False) + 0.5\n", + "Lo, La = np.meshgrid(lon_s, lat_s)\n", + "src = xr.DataArray(\n", + " np.sin(np.deg2rad(Lo) * 2) * np.cos(np.deg2rad(La) * 3),\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat_s, \"longitude\": lon_s},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Regrid onto the mesh\n", + "\n", + "`from_polygons` takes flat 1D arrays of source and target polygons. Source\n", + "polygons come from the 1D grid coords via `polygons_from_coords` (one\n", + "rectangle per cell, built from coordinate midpoints); target polygons are\n", + "the Voronoi mesh cells. Source data has to be flattened to a single\n", + "`src_cell` dimension to match the flat polygon array." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "rgr = ConservativeRegridder.from_polygons(\n", + " source_polygons=polygons_from_coords(lon_s, lat_s),\n", + " target_polygons=mesh_polys,\n", + " source_dim=\"src_cell\",\n", + " target_dim=\"cell\",\n", + ")\n", + "print(rgr)\n", + "\n", + "src_flat = xr.DataArray(src.values.ravel(), dims=(\"src_cell\",))\n", + "mesh_vals = rgr.regrid(src_flat)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "patches = [np.asarray(p.exterior.coords) for p in mesh_polys]\n", + "pc = PolyCollection(patches, array=mesh_vals.values, cmap=\"RdBu_r\",\n", + " edgecolor=\"0.4\", lw=0.3, clim=(-1, 1))\n", + "ax.add_collection(pc)\n", + "ax.set_xlim(bbox[0], bbox[2]); ax.set_ylim(bbox[1], bbox[3])\n", + "ax.set_aspect(\"equal\")\n", + "fig.colorbar(pc, ax=ax, 
shrink=0.8)\n", + "ax.set_title(f\"regridded onto {len(mesh_polys)}-cell Voronoi mesh\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Persist the regridder\n", + "\n", + "The weight matrix is purely a function of source and target geometry —\n", + "not of the data — so for a fixed source/target pair it never needs to\n", + "change. `to_netcdf` writes it (along with shape and version metadata) to\n", + "a small NetCDF; `from_netcdf` rebuilds the regridder. A reload-then-apply\n", + "gives bit-identical output to the original, so a long-running pipeline\n", + "can skip the (expensive) polygon-intersection build on restart." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "path = Path(tempfile.gettempdir()) / \"mesh_regridder.nc\"\n", + "rgr.to_netcdf(path)\n", + "\n", + "with xr.open_dataset(path) as weights:\n", + " for k in (\"xarray_regrid_version\", \"created\", \"src_shape\", \"dst_shape\"):\n", + " print(f\" {k}: {weights.attrs[k]}\")\n", + "\n", + "rgr2 = ConservativeRegridder.from_netcdf(path)\n", + "same = np.array_equal(rgr.regrid(src_flat).values, rgr2.regrid(src_flat).values)\n", + "print(f\"\\nreload bit-identical: {same}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 50ecd81..fe93d6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,13 @@ accel = [ "opt-einsum", "dask[distributed]", ] +conservative-2d = [ + # For conservative regridding on grids that aren't 1D-separable + # (curvilinear, 
unstructured). Uses shapely for 2D polygon intersection. + "shapely>=2.0", + # Needed by ConservativeRegridder.to_netcdf / from_netcdf (group support). + "h5netcdf", +] benchmarking = [ "matplotlib", "zarr", @@ -53,6 +60,15 @@ benchmarking = [ "pooch", "cftime", # required for decode time of test netCDF files ] +notebooks = [ + # Extra dependencies needed to run docs/notebooks/demos/*.ipynb. + # Combine with [conservative-2d] and [accel] for the full demo set, + # e.g.: pip install "xarray-regrid[notebooks,conservative-2d,accel]" + "matplotlib", + "pooch", # xr.tutorial.open_dataset backend + "geopandas", # polygon regions demo + "geodatasets", # ships the ncovr US county boundaries +] dev = [ "hatch", "ruff", @@ -74,7 +90,7 @@ docs = [ # Required for ReadTheDocs path = "src/xarray_regrid/__init__.py" [tool.hatch.envs.default] -features = ["accel", "dev", "benchmarking"] +features = ["accel", "dev", "benchmarking", "conservative-2d"] [tool.hatch.envs.default.scripts] lint = [ @@ -185,3 +201,14 @@ warn_return_any = true warn_unused_ignores = true show_error_codes = true exclude = ["tests/*", "docs"] + +# shapely and sparse are untyped; the conservative_2d module bridges to them, +# so treating every signature that mentions sparse.COO as an error is noise. 
+[[tool.mypy.overrides]] +module = "xarray_regrid.methods.conservative_2d" +disallow_any_unimported = false +warn_return_any = false + +[[tool.mypy.overrides]] +module = ["shapely", "shapely.*", "sparse"] +ignore_missing_imports = true diff --git a/src/xarray_regrid/__init__.py b/src/xarray_regrid/__init__.py index 0dcaaec..8696e1a 100644 --- a/src/xarray_regrid/__init__.py +++ b/src/xarray_regrid/__init__.py @@ -1,12 +1,20 @@ from xarray_regrid import methods +from xarray_regrid.methods.conservative_2d import ( + ConservativeRegridder, + RegridSpec, + polygons_from_coords, +) from xarray_regrid.regrid import Regridder from xarray_regrid.utils import Grid, create_regridding_dataset __all__ = [ + "ConservativeRegridder", "Grid", + "RegridSpec", "Regridder", "create_regridding_dataset", "methods", + "polygons_from_coords", ] __version__ = "0.4.2" diff --git a/src/xarray_regrid/methods/_conservative_2d_serialization.py b/src/xarray_regrid/methods/_conservative_2d_serialization.py new file mode 100644 index 0000000..989d0e4 --- /dev/null +++ b/src/xarray_regrid/methods/_conservative_2d_serialization.py @@ -0,0 +1,115 @@ +from datetime import datetime, timezone +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any + +import numpy as np +import xarray as xr + +from xarray_regrid.methods._conservative_2d_spec import RegridSpec + +try: + import sparse + + _HAS_SPARSE = True +except ImportError: # pragma: no cover + sparse = None + _HAS_SPARSE = False + +# Bump on breaking change to the on-disk format in ConservativeRegridder.to_netcdf. 
# On-disk layout tag written into saved regridder files; `_metadata_from_attrs`
# raises ValueError on reload when the stored tag differs from this value.
_SCHEMA_VERSION = 1


def _package_version() -> str:
    """Return the installed xarray-regrid version, or ``"unknown"`` when the
    package metadata is unavailable (e.g. running from a source checkout)."""
    try:
        return version("xarray-regrid")
    except PackageNotFoundError:
        return "unknown"


def _coo_components(
    weights: "sparse.COO | np.ndarray",
) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple[int, int]]:
    """Decompose a weight matrix into COO components ``(row, col, data, shape)``.

    Accepts either a ``sparse.COO`` matrix or a dense ndarray; for the dense
    case only nonzero entries are emitted. Indices are int64.
    """
    if _HAS_SPARSE and isinstance(weights, sparse.COO):
        coords = np.asarray(weights.coords)
        return (
            coords[0].astype(np.int64, copy=False),
            coords[1].astype(np.int64, copy=False),
            np.asarray(weights.data),
            weights.shape,
        )
    arr = np.asarray(weights)
    rows, cols = np.nonzero(arr)
    return (
        rows.astype(np.int64, copy=False),
        cols.astype(np.int64, copy=False),
        arr[rows, cols],
        arr.shape,
    )


def _coo_from_components(
    row: np.ndarray,
    col: np.ndarray,
    data: np.ndarray,
    shape: tuple[int, int],
) -> "sparse.COO | np.ndarray":
    """Inverse of :func:`_coo_components`: rebuild a ``sparse.COO`` matrix, or
    a dense ndarray when the ``sparse`` package is not installed."""
    if _HAS_SPARSE:
        return sparse.COO(
            coords=np.stack([row, col]),
            data=data,
            shape=shape,
            has_duplicates=False,
            sorted=False,
        )
    # Empty `data` has no meaningful dtype to propagate; default to float64.
    dense = np.zeros(shape, dtype=data.dtype if data.size else np.float64)
    dense[row, col] = data
    return dense


def _metadata_attrs(
    spec: RegridSpec, source_coords: xr.Dataset, target_coords: xr.Dataset
) -> dict[str, Any]:
    """Build netCDF global attrs from ``spec`` plus provenance (package
    version, UTC creation timestamp, schema version) and — where the coord
    exists and is non-empty — the min/max range of each source/target
    x/y coordinate."""
    attrs: dict[str, Any] = {
        "x_coord": spec.x_coord,
        "y_coord": spec.y_coord,
        "spherical": int(spec.spherical),
        "src_dims": [str(d) for d in spec.src_dims],
        "dst_dims": [str(d) for d in spec.dst_dims],
        "src_shape": list(spec.src_shape),
        "dst_shape": list(spec.dst_shape),
        "xarray_regrid_version": _package_version(),
        "created": datetime.now(tz=timezone.utc).isoformat(),
        "schema_version": _SCHEMA_VERSION,
    }
    for prefix, ds, coord in [
        ("source_x", source_coords, spec.x_coord),
        ("source_y", source_coords, spec.y_coord),
        ("target_x", target_coords, spec.x_coord),
        ("target_y", target_coords, spec.y_coord),
    ]:
        if coord and coord in ds.coords and ds[coord].size:
            attrs[f"{prefix}_range"] = [float(ds[coord].min()), float(ds[coord].max())]
    return attrs


def _metadata_from_attrs(attrs: dict[str, Any], path: Path) -> RegridSpec:
    """Parse and validate regridding spec metadata from netCDF attrs."""
    schema_version = int(attrs.get("schema_version", 0))
    if schema_version != _SCHEMA_VERSION:
        msg = (
            f"regridder file at {path} uses schema version {schema_version}; "
            f"this xarray-regrid understands {_SCHEMA_VERSION}. "
            "Upgrade xarray-regrid or re-save."
        )
        raise ValueError(msg)

    # netCDF may hand back scalars for single-element list attrs; atleast_1d
    # keeps the tuple fields well-formed in that case.
    return RegridSpec(
        x_coord=str(attrs["x_coord"]),
        y_coord=str(attrs["y_coord"]),
        spherical=bool(int(attrs["spherical"])),
        src_dims=tuple(str(d) for d in np.atleast_1d(attrs["src_dims"])),
        dst_dims=tuple(str(d) for d in np.atleast_1d(attrs["dst_dims"])),
        src_shape=tuple(int(s) for s in np.atleast_1d(attrs["src_shape"])),
        dst_shape=tuple(int(s) for s in np.atleast_1d(attrs["dst_shape"])),
    )
diff --git a/src/xarray_regrid/methods/_conservative_2d_spec.py b/src/xarray_regrid/methods/_conservative_2d_spec.py
new file mode 100644
index 0000000..7f17854
--- /dev/null
+++ b/src/xarray_regrid/methods/_conservative_2d_spec.py
@@ -0,0 +1,15 @@
from collections.abc import Hashable
from dataclasses import dataclass


@dataclass(frozen=True)
class RegridSpec:
    """Canonical metadata describing a source->target regridding layout."""

    src_dims: tuple[Hashable, ...]
    dst_dims: tuple[Hashable, ...]
    src_shape: tuple[int, ...]
    dst_shape: tuple[int, ...]
    x_coord: str
    y_coord: str
    spherical: bool
diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py
index 2ab67d9..e8bf6f3 100644
--- a/src/xarray_regrid/methods/conservative.py
+++ b/src/xarray_regrid/methods/conservative.py
@@ -7,7 +7,7 @@
 import xarray as xr
 
 try:
-    import sparse  # type: ignore
+    import sparse
 except ImportError:
     sparse = None
diff --git a/src/xarray_regrid/methods/conservative_2d.py b/src/xarray_regrid/methods/conservative_2d.py
new file mode 100644
index 0000000..9be0763
--- /dev/null
+++ b/src/xarray_regrid/methods/conservative_2d.py
@@ -0,0 +1,1161 @@
"""Conservative regridding for grids that aren't 1D-separable.

The existing ``conservative`` method uses axis-factored 1D overlap — fast and
elegant but strictly rectilinear. This module computes the full 2D cell
intersection via shapely, so it handles:

- curvilinear grids (2D ``lat[i, j]`` / ``lon[i, j]`` coordinate variables)
- unstructured meshes (arbitrary polygon cells, via
  :meth:`ConservativeRegridder.from_polygons`)
- grid-to-polygon aggregation (e.g. gridded data → country shapes)

For rectilinear grids a cheap analytic fast-path is used, but this module is
still slower and more memory-intensive than ``conservative``; prefer the
axis-factored path when your grid is 1D-separable.

Requires ``shapely >= 2.0``. If ``sparse`` is available, the weight matrix is
stored as ``sparse.COO``; otherwise a dense numpy matrix is used.
+""" + +import os +import warnings +from collections.abc import Hashable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path +from typing import Any, Literal, cast + +import numpy as np +import xarray as xr + +from xarray_regrid import utils +from xarray_regrid.methods._conservative_2d_serialization import ( + _coo_components, + _coo_from_components, + _metadata_attrs, + _metadata_from_attrs, +) +from xarray_regrid.methods._conservative_2d_spec import RegridSpec +from xarray_regrid.methods.conservative import get_valid_threshold + +NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] | None + + +try: + import shapely + from shapely import affinity + from shapely.strtree import STRtree + + _HAS_SHAPELY = True +except ImportError: # pragma: no cover + shapely = None + affinity = None + STRtree = None + _HAS_SHAPELY = False + +try: + import sparse + + _HAS_SPARSE = True +except ImportError: # pragma: no cover + sparse = None + _HAS_SPARSE = False + + +# We fill NaNs with 0 ourselves before matmul (see `_apply_core`), so sparse's +# "NaN will not be propagated" warning is spurious. A module-level filter +# avoids `warnings.catch_warnings` inside the hot matmul path, which is not +# thread-safe — dask threads would race on the global `warnings.filters` list. +warnings.filterwarnings( + "ignore", + message="Nan will not be propagated in matrix multiplication", + category=RuntimeWarning, +) + + +SHAPELY_IMPORT_ERROR = ( + "polygon conservative regridding requires shapely >= 2.0; " + "install with `pip install shapely`." +) + + +def _check_shapely() -> None: + if not _HAS_SHAPELY: + raise ImportError(SHAPELY_IMPORT_ERROR) + + +class _Direction: + """Lazy weights, apply matrix, and coverage mask for one regrid direction. 
+ + Holds the raw ``(n_dst, n_src)`` area matrix in this direction's orientation + and derives, on first access: + + - ``weights``: row-normalized weight matrix + - ``apply_matrix``: pre-transposed and index-sorted weights, so + ``_apply_core``'s matmul is ``(..., n_src) @ (n_src, n_dst)`` with no + per-call sort + - ``coverage`` / ``coverage_all``: which output cells have any source overlap + + A regridder holds two of these (forward, backward); transposing the + regridder swaps them with no recomputation. + """ + + def __init__(self, areas: "sparse.COO | np.ndarray") -> None: + self.areas = areas + + @cached_property + def weights(self) -> "sparse.COO | np.ndarray": + return _row_normalize(self.areas) + + @cached_property + def apply_matrix(self) -> "sparse.COO | np.ndarray": + return _transpose_weights(self.weights, sort=True) + + @cached_property + def coverage(self) -> np.ndarray: + return _coverage_mask(self.areas) + + @cached_property + def coverage_all(self) -> bool: + return bool(self.coverage.all()) + + +class ConservativeRegridder: + """Reusable conservative regridder for grids that aren't 1D-separable: + curvilinear (2D ``lat``/``lon``), unstructured (via :meth:`from_polygons`), + or arbitrary polygon-to-polygon aggregation. For purely 1D-separable + rectilinear grids, the ``.regrid.conservative`` accessor is faster. + + Build once, apply to many fields via :meth:`regrid` (or by calling the + regridder); ``.T`` gives the backward regridder. Forward and backward + weight matrices are cached lazily. Requires ``shapely >= 2.0``. + + The unnormalized cell-intersection area matrix + ``A[i, j] = area(target_i ∩ source_j)`` is exposed as ``self.areas`` + (sparse ``(n_dst, n_src)`` if the ``sparse`` package is available, dense + otherwise) — useful for conservation diagnostics and per-cell coverage + analysis. 
+ """ + + def __init__( + self, + source: xr.DataArray | xr.Dataset, + target: xr.Dataset, + x_coord: str = "longitude", + y_coord: str = "latitude", + spherical: bool = False, + n_threads: int | None = None, + ) -> None: + _check_shapely() + source_grid, target_grid, src_x_sort_idx = _normalize_longitude_coords( + source, target, x_coord + ) + src_dims = _spatial_dims(source_grid, x_coord, y_coord) + dst_dims = _spatial_dims(target_grid, x_coord, y_coord) + if not src_dims: + msg = f"source has no dims for coords {x_coord!r}, {y_coord!r}" + raise ValueError(msg) + if not dst_dims: + msg = f"target has no dims for coords {x_coord!r}, {y_coord!r}" + raise ValueError(msg) + src_grid = _grid_from_coords( + source_grid, x_coord, y_coord, src_dims, spherical=spherical + ) + dst_grid = _grid_from_coords( + target_grid, x_coord, y_coord, dst_dims, spherical=spherical + ) + self.spherical = spherical + + self.x_coord = x_coord + self.y_coord = y_coord + self._src_dims = src_dims + self._dst_dims = dst_dims + self._src_shape = tuple(int(source.sizes[d]) for d in src_dims) + self._dst_shape = tuple(int(target.sizes[d]) for d in dst_dims) + areas = _build_intersection_areas(src_grid, dst_grid, n_threads=n_threads) + if src_x_sort_idx is not None: + # Source x was sorted by _normalize_longitude_coords so polygon + # construction stayed monotone. Relabel matrix columns so column + # order matches the user's original (unsorted) data layout. 
+ x_dim_index = src_dims.index(source[x_coord].dims[0]) + areas = _remap_columns_for_axis_sort( + areas, src_x_sort_idx, self._src_shape, x_dim_index + ) + self.areas = areas + self._source_coords = source.coords.to_dataset() + self._target_coords = target.coords.to_dataset() + + @property + def spec(self) -> RegridSpec: + """Canonical metadata describing this regridder's layout.""" + return RegridSpec( + src_dims=self._src_dims, + dst_dims=self._dst_dims, + src_shape=self._src_shape, + dst_shape=self._dst_shape, + x_coord=self.x_coord, + y_coord=self.y_coord, + spherical=self.spherical, + ) + + @cached_property + def _forward(self) -> _Direction: + return _Direction(self.areas) + + @cached_property + def _backward(self) -> _Direction: + return _Direction(_transpose_weights(self.areas)) + + @property + def forward_weights(self) -> "sparse.COO | np.ndarray": + """The row-normalized forward weight matrix (source → target).""" + return self._forward.weights + + @property + def backward_weights(self) -> "sparse.COO | np.ndarray": + """The row-normalized backward weight matrix (target → source).""" + return self._backward.weights + + @property + def target_areas(self) -> np.ndarray: + """Area of each target cell overlapped by the source domain.""" + return _sum_matrix_axis_1d(self.areas, axis=1) + + @property + def source_coverage_areas(self) -> np.ndarray: + """Area of each source cell covered by target cells.""" + return _sum_matrix_axis_1d(self.areas, axis=0) + + def regrid( + self, + data: xr.DataArray | xr.Dataset, + skipna: bool = True, + nan_threshold: float = 1.0, + ) -> xr.DataArray | xr.Dataset: + """Regrid ``data`` forward (source → target).""" + return _apply_stored_weights( + data, + direction=self._forward, + spec=self.spec, + target_coords=self._target_coords, + skipna=skipna, + nan_threshold=nan_threshold, + ) + + def __call__( + self, + data: xr.DataArray | xr.Dataset, + skipna: bool = True, + nan_threshold: float = 1.0, + ) -> xr.DataArray | 
xr.Dataset: + return self.regrid(data, skipna=skipna, nan_threshold=nan_threshold) + + def transpose(self) -> "ConservativeRegridder": + """Return the backward regridder (target → source). The original's + forward Direction becomes the new's backward (and vice versa), so any + already-computed weight matrices are reused, not recomputed.""" + new = type(self)._from_state( + areas=_transpose_weights(self.areas), + source_coords=self._target_coords, + target_coords=self._source_coords, + spec=RegridSpec( + src_dims=self._dst_dims, + dst_dims=self._src_dims, + src_shape=self._dst_shape, + dst_shape=self._src_shape, + x_coord=self.x_coord, + y_coord=self.y_coord, + spherical=self.spherical, + ), + ) + if "_forward" in self.__dict__: + new.__dict__["_backward"] = self.__dict__["_forward"] + if "_backward" in self.__dict__: + new.__dict__["_forward"] = self.__dict__["_backward"] + return new + + @property + def T(self) -> "ConservativeRegridder": # noqa: N802 + """Alias for :meth:`transpose` (numpy-style transpose).""" + return self.transpose() + + def __repr__(self) -> str: + nnz = getattr(self.areas, "nnz", None) + shape = getattr(self.areas, "shape", (None, None)) + nnz_str = f"nnz={nnz}" if nnz is not None else "dense" + return ( + f"ConservativeRegridder(src_dims={self._src_dims}, " + f"dst_dims={self._dst_dims}, {shape[0]}x{shape[1]}, {nnz_str})" + ) + + def to_netcdf(self, path: str | Path, engine: NetcdfEngine = None) -> None: + """Save the weight matrix and reproducibility metadata to a netCDF file. 
+ Requires a group-aware engine (``netcdf4`` or ``h5netcdf``); ``engine`` + is forwarded to :func:`xarray.Dataset.to_netcdf`.""" + path = Path(path) + row, col, data, shape = _coo_components(self.areas) + ds_weights = xr.Dataset( + { + "_coo_row": (("nnz",), row), + "_coo_col": (("nnz",), col), + "_coo_data": (("nnz",), data), + }, + attrs={ + **_metadata_attrs(self.spec, self._source_coords, self._target_coords), + "n_dst": int(shape[0]), + "n_src": int(shape[1]), + }, + ) + ds_weights.to_netcdf(path, mode="w", engine=engine) + self._source_coords.to_netcdf( + path, mode="a", group="source_coords", engine=engine + ) + self._target_coords.to_netcdf( + path, mode="a", group="target_coords", engine=engine + ) + + @classmethod + def from_netcdf( + cls, path: str | Path, engine: NetcdfEngine = None + ) -> "ConservativeRegridder": + """Reload a regridder previously written with :meth:`to_netcdf`. + + Validates ``schema_version``; raises :class:`ValueError` if the file + was written by an incompatible version. 
+ """ + path = Path(path) + with xr.open_dataset(path, engine=engine) as ds_weights: + attrs = dict(ds_weights.attrs) + n_dst = int(attrs.pop("n_dst")) + n_src = int(attrs.pop("n_src")) + row = np.asarray(ds_weights["_coo_row"].values) + col = np.asarray(ds_weights["_coo_col"].values) + data = np.asarray(ds_weights["_coo_data"].values) + meta = _metadata_from_attrs(attrs, path) + + with xr.open_dataset(path, group="source_coords", engine=engine) as g: + source_coords = g.load() + with xr.open_dataset(path, group="target_coords", engine=engine) as g: + target_coords = g.load() + + return cls._from_state( + areas=_coo_from_components(row, col, data, (n_dst, n_src)), + source_coords=source_coords, + target_coords=target_coords, + spec=meta, + ) + + @classmethod + def _from_state( + cls, + *, + areas: "sparse.COO | np.ndarray", + source_coords: xr.Dataset, + target_coords: xr.Dataset, + spec: RegridSpec, + ) -> "ConservativeRegridder": + """Construct a regridder directly from its canonical state. 
Shared + bypass of ``__init__`` used by :meth:`from_netcdf` and + :meth:`from_polygons`; keeps the list of private attrs in one place.""" + instance = object.__new__(cls) + instance.x_coord = spec.x_coord + instance.y_coord = spec.y_coord + instance.spherical = spec.spherical + instance._src_dims = spec.src_dims + instance._dst_dims = spec.dst_dims + instance._src_shape = spec.src_shape + instance._dst_shape = spec.dst_shape + instance.areas = areas + instance._source_coords = source_coords + instance._target_coords = target_coords + return instance + + @classmethod + def from_polygons( + cls, + source_polygons: np.ndarray, + target_polygons: np.ndarray, + source_dim: str = "cell", + target_dim: str = "cell", + target_coords: xr.Dataset | None = None, + periodic: bool = False, + n_threads: int | None = None, + predicate_filter: bool = True, + ) -> "ConservativeRegridder": + """Build a regridder from explicit shapely polygon arrays — for + unstructured meshes (MPAS, ICON), arbitrary polygon targets (countries, + watersheds), or any non-rectilinear combination. + + Args: + source_polygons, target_polygons: 1D arrays of shapely Polygons. + source_dim, target_dim: Dim names for source/target cells on the + input and output arrays. + target_coords: Optional Dataset of coord variables along + ``target_dim`` to reattach on the output (else: integer index). + periodic: Unwrap polygons that cross the antimeridian (treats x + as longitude on a 360-degree periodic axis). + n_threads: Thread count for parallel GEOS intersection. + predicate_filter: If True, filter STRtree candidates with GEOS + ``intersects``. Set False for tight-bbox grid cells to skip + the predicate (faster on that case, pathological otherwise). + + Geometry is planar in the polygons' own coordinate space. For lat/lon + cells, project into an equal-area CRS first or use the structured + path with ``spherical=True``. 
+ """ + _check_shapely() + src_polys = np.asarray(source_polygons) + dst_polys = np.asarray(target_polygons) + if src_polys.ndim != 1 or dst_polys.ndim != 1: + msg = "source_polygons and target_polygons must be 1D arrays" + raise ValueError(msg) + if periodic: + src_polys = _normalize_periodic_polygons(src_polys) + dst_polys = _normalize_periodic_polygons( + dst_polys, reference=_polygon_reference_x(src_polys) + ) + + src_grid = _Grid( + polys=src_polys, + bounds=shapely.bounds(src_polys), + rectilinear=False, + ) + dst_grid = _Grid( + polys=dst_polys, + bounds=shapely.bounds(dst_polys), + rectilinear=False, + ) + n_src = int(src_polys.size) + n_dst = int(dst_polys.size) + tgt_ds = ( + target_coords + if target_coords is not None + else xr.Dataset(coords={target_dim: np.arange(n_dst)}) + ) + return cls._from_state( + areas=_build_intersection_areas( + src_grid, + dst_grid, + n_threads=n_threads, + predicate_filter=predicate_filter, + ), + source_coords=xr.Dataset(coords={source_dim: np.arange(n_src)}), + target_coords=tgt_ds, + spec=RegridSpec( + src_dims=(source_dim,), + dst_dims=(target_dim,), + src_shape=(n_src,), + dst_shape=(n_dst,), + x_coord="", + y_coord="", + spherical=False, + ), + ) + + +def polygons_from_coords( + x: np.ndarray, + y: np.ndarray, + spherical: bool = False, + periodic: bool = False, +) -> np.ndarray: + """Build a 1D row-major (y, x) array of shapely cell polygons from 1D or + 2D center coords. Convenience for mixing structured and unstructured paths + via :meth:`ConservativeRegridder.from_polygons`. 
    ``spherical=True``
    projects 1D lat/lon (degrees) into Lambert cylindrical equal-area space;
    ``periodic=True`` unwraps antimeridian-crossing cells."""
    _check_shapely()
    x = np.asarray(x)
    y = np.asarray(y)
    if periodic:
        x = _unwrap_longitude(x)
    if spherical:
        if x.ndim != 1 or y.ndim != 1:
            msg = "spherical=True requires 1D lat/lon arrays"
            raise ValueError(msg)
        return _build_cea_grid(x, y).polys
    return _build_grid(x, y).polys


def _apply_stored_weights(
    data: xr.DataArray | xr.Dataset,
    direction: _Direction,
    spec: RegridSpec,
    target_coords: xr.Dataset,
    skipna: bool,
    nan_threshold: float,
) -> xr.DataArray | xr.Dataset:
    """Apply ``direction``'s cached, pre-transposed weight matrix to ``data``
    via ``xr.apply_ufunc``.

    The apply matrix has shape ``(n_src, n_dst)`` so the matmul is
    ``(..., n_src) @ (n_src, n_dst) → (..., n_dst)`` with no per-call transpose.
    """
    actual_src_shape = tuple(
        int(data.sizes[d]) for d in spec.src_dims if d in data.sizes
    )
    if actual_src_shape != spec.src_shape:
        msg = (
            f"source spatial shape {actual_src_shape} on dims {spec.src_dims} does "
            f"not match the regridder's expected shape {spec.src_shape}"
        )
        raise ValueError(msg)

    # Rename spatial dims to private tokens so output dims may legally reuse
    # the same names (apply_ufunc forbids a dim in both core-in and core-out).
    src_tokens = tuple(f"__src_{d}" for d in spec.src_dims)
    data_renamed = data.rename(dict(zip(spec.src_dims, src_tokens, strict=True)))

    output_dtype = _result_dtype(data)
    result = xr.apply_ufunc(
        _apply_core,
        data_renamed,
        kwargs={
            "apply_weights": direction.apply_matrix,
            "coverage": direction.coverage,
            "coverage_all": direction.coverage_all,
            "src_shape": spec.src_shape,
            "dst_shape": spec.dst_shape,
            "skipna": skipna,
            "nan_threshold": nan_threshold,
            "output_dtype": output_dtype,
        },
        input_core_dims=[list(src_tokens)],
        output_core_dims=[list(spec.dst_dims)],
        exclude_dims=set(src_tokens),
        dask="parallelized",
        output_dtypes=[output_dtype],
        dask_gufunc_kwargs={
            "output_sizes": {d: int(target_coords.sizes[d]) for d in spec.dst_dims},
            "allow_rechunk": True,
        },
        keep_attrs=True,
    )

    return _assign_target_coords(
        result,
        target_coords,
        spec.dst_dims,
        spec.x_coord,
        spec.y_coord,
    )


def _coverage_mask(areas: "sparse.COO | np.ndarray") -> np.ndarray:
    """Return a boolean ``(n_dst,)`` mask of target cells with any source
    overlap, derived from the raw area matrix."""
    if _HAS_SPARSE and isinstance(areas, sparse.COO):
        # COO sparse: a row is covered iff it owns at least one stored entry.
        n_dst = int(areas.shape[0])
        mask = np.zeros(n_dst, dtype=bool)
        mask[areas.coords[0]] = True
        return mask
    arr = np.asarray(areas)
    return np.asarray((arr > 0).any(axis=1))


def _sum_matrix_axis_1d(areas: "sparse.COO | np.ndarray", axis: int) -> np.ndarray:
    """Sum the area matrix along ``axis``, returning a flat float64 ndarray
    (densifying a sparse result first)."""
    summed = areas.sum(axis=axis)
    if hasattr(summed, "todense"):
        summed = summed.todense()
    return np.asarray(summed, dtype=np.float64).reshape(-1)


def _normalize_longitude_coords(
    source: xr.DataArray | xr.Dataset,
    target: xr.Dataset,
    x_coord: str,
) -> tuple[xr.DataArray | xr.Dataset, xr.Dataset, np.ndarray | None]:
    """Unwrap x coordinates across the antimeridian so source and target share
    a contiguous longitude frame. No-op when the coord isn't present on both
    objects or doesn't look like a longitude.

    For 1D rectilinear longitudes on both sides this mirrors the per-value
    wrap done by :func:`xarray_regrid.utils.format_lon` for the axis-factored
    path, which is what makes a source on ``[0, 360]`` align with a target on
    ``[-180, 180]`` (and vice versa). If wrapping breaks source monotonicity
    the source coord is sorted; the caller is expected to remap the
    area-matrix columns by the returned ``src_x_sort_idx`` so the final matrix
    columns line up with the user's original data layout. For 2D / curvilinear
    coords the existing uniform-shift fallback is kept.
    """
    if x_coord not in source.coords or x_coord not in target.coords:
        return source, target, None

    source_x = np.asarray(source[x_coord].values)
    target_x = np.asarray(target[x_coord].values)
    if not _looks_like_longitude(source_x) and not _looks_like_longitude(target_x):
        return source, target, None

    source_x = _unwrap_longitude(source_x)
    target_x = _unwrap_longitude(target_x)
    src_finite = source_x[np.isfinite(source_x)]
    tgt_finite = target_x[np.isfinite(target_x)]

    src_x_sort_idx: np.ndarray | None = None
    if (
        source_x.ndim == 1
        and target_x.ndim == 1
        and src_finite.size
        and tgt_finite.size
    ):
        # Per-value wrap source into target's 360° window — mirrors format_lon
        # so source [0, 360] vs target [-180, 180] (and either reversed)
        # aligns. A uniform offset can't reconcile cross-convention grids:
        # mean diff is exactly 180° and round() is banker's-rounded to 0.
        wrap_point = float((tgt_finite[0] + tgt_finite[-1] + 360.0) / 2.0)
        source_x = np.where(source_x < wrap_point - 360.0, source_x + 360.0, source_x)
        source_x = np.where(source_x > wrap_point, source_x - 360.0, source_x)
        diffs = np.diff(source_x)
        if not (np.all(diffs > 0) or np.all(diffs < 0)):
            # Stable sort so equal values keep their relative order.
            src_x_sort_idx = np.argsort(source_x, kind="stable")
            source_x = source_x[src_x_sort_idx]
    elif src_finite.size and tgt_finite.size:
        # 2D / curvilinear: fall back to uniform shift of target into source's
        # window. Doesn't handle cross-convention but preserves the existing
        # antimeridian-crossing behavior for 2D coords.
        target_x = target_x + _periodic_offset(
            float(src_finite.mean()), float(tgt_finite.mean())
        )
    return (
        utils.update_coord(source, x_coord, source_x),
        cast(xr.Dataset, utils.update_coord(target, x_coord, target_x)),
        src_x_sort_idx,
    )


def _looks_like_longitude(values: np.ndarray) -> bool:
    """Heuristic: every finite value sits within [-360, 360] degrees."""
    finite = values[np.isfinite(values)]
    return bool(finite.size and finite.min() >= -360.0 and finite.max() <= 360.0)


def _unwrap_longitude(values: np.ndarray) -> np.ndarray:
    """Unwrap longitudes along the trailing axis (CF convention). For 2D
    coords, latitude doesn't wrap mod 360°, so unwrapping it is incorrect."""
    radians = np.deg2rad(np.asarray(values, dtype=float))
    return np.rad2deg(np.unwrap(radians, axis=-1))


def _normalize_periodic_polygons(
    polygons: np.ndarray, reference: float | None = None
) -> np.ndarray:
    """Unwrap each polygon across the antimeridian, then shift each into the
    same 360-degree window as ``reference`` (or as the first finite center if
    ``reference`` is None)."""
    unwrapped = [_unwrap_polygon(p) for p in polygons]
    if reference is None:
        for poly in unwrapped:
            center = _polygon_center_x(poly)
            if np.isfinite(center):
                reference = center
                break
    if reference is None:
        # All centers non-finite: nothing sensible to align against.
        return np.array(unwrapped, dtype=object)

    out = []
    for poly in unwrapped:
        offset = _periodic_offset(reference, _polygon_center_x(poly))
        out.append(affinity.translate(poly, xoff=offset) if offset != 0.0 else poly)
    return np.array(out, dtype=object)


def _polygon_reference_x(polygons: np.ndarray) -> float | None:
    """Mean polygon-center x across an array of polygons, or None if all
    polygons have non-finite centers."""
    bounds = shapely.bounds(polygons)
    centers = 0.5 * (bounds[:, 0] + bounds[:, 2])
    finite = centers[np.isfinite(centers)]
    return float(finite.mean()) if finite.size else None


def _polygon_center_x(polygon: Any) -> float:
    """Midpoint of the polygon's x-extent (bbox center in x)."""
    minx, _, maxx, _ = polygon.bounds
    return 0.5 * (float(minx) + float(maxx))


def _periodic_offset(reference: float, value: float) -> float:
    """Multiple of 360° that shifts ``value`` into the same window as
    ``reference``; 0.0 when either input is non-finite."""
    if not np.isfinite(reference) or not np.isfinite(value):
        return 0.0
    return 360.0 * round((reference - value) / 360.0)


def _unwrap_polygon(polygon: Any) -> Any:
    """Unwrap a (Multi)Polygon's rings across the antimeridian, recursing into
    MultiPolygon parts; empty or other geometry types are returned as-is."""
    if polygon.is_empty:
        return polygon
    if polygon.geom_type == "Polygon":
        exterior = _unwrap_ring(np.asarray(polygon.exterior.coords))
        holes = [_unwrap_ring(np.asarray(ring.coords)) for ring in polygon.interiors]
        return shapely.Polygon(exterior, holes)
    if polygon.geom_type == "MultiPolygon":
        return shapely.MultiPolygon([_unwrap_polygon(part) for part in polygon.geoms])
    return polygon


def _unwrap_ring(ring: np.ndarray) -> np.ndarray:
    """Make a ring's x coordinates continuous: any step between consecutive
    vertices larger than 180° is treated as an antimeridian crossing and
    removed by accumulating a ±360° offset. Returns a copy."""
    new_ring = np.asarray(ring, dtype=float).copy()
    offset = 0.0
    for i in range(1, new_ring.shape[0]):
        x = new_ring[i, 0] + offset
        step = x - new_ring[i - 1, 0]
        if step > 180.0:
            offset -= 360.0
        elif step < -180.0:
            offset += 360.0
        new_ring[i, 0] += offset
    return new_ring


def _remap_columns_for_axis_sort(
    areas: "sparse.COO | np.ndarray",
    sort_idx: np.ndarray,
    src_shape: tuple[int, ...],
    axis_index: int,
) -> "sparse.COO | np.ndarray":
    """Relabel the column indices of an ``(n_dst, prod(src_shape))`` area
    matrix so that columns appear in the user's original source-data order
    after ``sort_idx`` was applied along ``axis_index`` of ``src_shape``.

    A row-major source flat index ``c = unravel(c, src_shape)`` has its
    ``axis_index`` component ``i_sorted`` permuted via
    ``i_orig = sort_idx[i_sorted]``; all other components are untouched. So
    the new flat index is
    ``c_new = c // stride * stride + (i_orig - i_sorted) * inner + ...``.
    For separable 1D-rect grids (the only case that triggers this today) the
    sorted axis sits between an outer block of size ``outer`` and an inner
    block of size ``inner`` with ``stride = nx * inner``.
    """
    nx = int(src_shape[axis_index])
    inner = int(np.prod(src_shape[axis_index + 1 :]))
    stride = nx * inner
    sort_idx = np.asarray(sort_idx, dtype=np.int64)

    if _HAS_SPARSE and isinstance(areas, sparse.COO):
        # Sparse: relabel stored column indices directly (no data movement).
        old_col = np.asarray(areas.coords[1], dtype=np.int64)
        outer_block = (old_col // stride) * stride
        within = old_col % stride
        i_sorted = within // inner
        rest = within % inner
        new_col = outer_block + sort_idx[i_sorted] * inner + rest
        coords = np.stack([np.asarray(areas.coords[0], dtype=np.int64), new_col])
        return sparse.COO(
            coords=coords,
            data=np.asarray(areas.data),
            shape=areas.shape,
            has_duplicates=False,
            sorted=False,
        )

    # Dense: gather columns via the inverse permutation instead of relabeling.
    arr = np.asarray(areas)
    n_cells = arr.shape[1]
    inv_sort_idx = np.empty_like(sort_idx)
    inv_sort_idx[sort_idx] = np.arange(sort_idx.size, dtype=sort_idx.dtype)
    cells = np.arange(n_cells, dtype=np.int64)
    outer_block = (cells // stride) * stride
    within = cells % stride
    i_orig = within // inner
    rest = within % inner
    inv_perm = outer_block + inv_sort_idx[i_orig] * inner + rest
    return arr[:, inv_perm]


def _transpose_weights(
    w: "sparse.COO | np.ndarray", *, sort: bool = False
) -> "sparse.COO | np.ndarray":
    """Materialize a transposed weight matrix.

    `sparse.COO.T` is a lazy view that re-sorts indices on each downstream
    matmul, so callers relying on ``.coords[0]`` being row indices or wanting
    a hot matmul path should materialize once here. Pass ``sort=True`` to
    additionally trigger the sort ahead of time (used for the apply matrix).
    """
    if _HAS_SPARSE and isinstance(w, sparse.COO):
        t = w.T
        out = sparse.COO(
            coords=np.asarray(t.coords),
            data=np.asarray(t.data),
            shape=t.shape,
            has_duplicates=False,
            sorted=False,
        )
        if sort:
            out._sort_indices()
        return out
    return np.asarray(w).T.copy()


def _spatial_dims(
    obj: xr.DataArray | xr.Dataset, x_coord: str, y_coord: str
) -> tuple[Hashable, ...]:
    """Return the spatial dim order to feed to apply_ufunc.

    For 1D rectilinear coords (x and y ride on separate dims), canonicalize to
    ``(y_dim, x_dim)`` so the flattened cell order matches the polygon order
    emitted by the fast-path in ``_build_grid``. For curvilinear (2D) coords,
    preserve the dim order already present on ``obj``.
    """
    if x_coord not in obj.coords or y_coord not in obj.coords:
        return ()
    xd = obj[x_coord].dims
    yd = obj[y_coord].dims
    if len(xd) == 1 and len(yd) == 1 and xd[0] != yd[0]:
        return (yd[0], xd[0])
    dims = set(xd) | set(yd)
    return tuple(d for d in obj.dims if d in dims)


def _grid_from_coords(
    obj: xr.DataArray | xr.Dataset,
    x_coord: str,
    y_coord: str,
    dims: tuple[Hashable, ...],
    spherical: bool = False,
) -> "_Grid":
    """Build a :class:`_Grid` from the object's x/y coordinates.

    Rectilinear (both coords 1D on separate dims) takes the fast path.
    Curvilinear coords are broadcast to a common N-D array in ``dims`` order.

    If ``spherical`` is True, coordinates are assumed to be longitude (x) and
    latitude (y) in degrees, and cells are projected into a Lambert cylindrical
    equal-area space (x' = lon_rad, y' = sin(lat_rad)) before constructing the
    cell polygons. This gives mass-conservative weights on the sphere at the
    same cost as the planar fast path. Rectilinear-only.
+ """ + xd = obj[x_coord] + yd = obj[y_coord] + is_rectilinear = xd.ndim == 1 and yd.ndim == 1 and xd.dims[0] != yd.dims[0] + + if spherical and not is_rectilinear: + msg = "spherical=True is only supported for rectilinear (1D lat/lon) coords" + raise NotImplementedError(msg) + + if is_rectilinear: + x = np.asarray(xd.values) + y = np.asarray(yd.values) + return _build_cea_grid(x, y) if spherical else _build_grid(x, y) + + xc, yc = xr.broadcast(xd, yd) + return _build_grid( + np.asarray(xc.transpose(*dims).values), + np.asarray(yc.transpose(*dims).values), + ) + + +def _build_cea_grid(lon_centers: np.ndarray, lat_centers: np.ndarray) -> "_Grid": + """Build a rectilinear :class:`_Grid` whose cell polygons are in Lambert + cylindrical equal-area coordinates (x' = lon_rad, y' = sin(lat_rad)). + + Projecting *edges* analytically — rather than projecting centers and then + re-midpointing — is required because ``sin()`` is nonlinear: the projected + midpoint of two lat centers is not the same as the midpoint of two + projected lat edges. + """ + _check_shapely() + if lon_centers.size < 2 or lat_centers.size < 2: + msg = "spherical mode requires at least two cells per dimension" + raise ValueError(msg) + lat_edges_deg = np.clip(utils.infer_1d_edges(lat_centers), -90.0, 90.0) + lon_edges_deg = utils.infer_1d_edges(lon_centers) + return _rect_grid_from_edges( + np.deg2rad(lon_edges_deg), + np.sin(np.deg2rad(lat_edges_deg)), + ) + + +def _infer_2d_corners(a: np.ndarray) -> np.ndarray: + """Infer (ny+1, nx+1) cell corners from a 2D cell-center array. 
Interior + corners are the mean of the 4 surrounding centers; boundary corners are + reflected from the adjacent interior row/column.""" + a = np.asarray(a, dtype=float) + ny, nx = a.shape + pad = np.empty((ny + 2, nx + 2), dtype=a.dtype) + pad[1:-1, 1:-1] = a + pad[0, 1:-1] = 2 * a[0, :] - a[1, :] + pad[-1, 1:-1] = 2 * a[-1, :] - a[-2, :] + pad[1:-1, 0] = 2 * a[:, 0] - a[:, 1] + pad[1:-1, -1] = 2 * a[:, -1] - a[:, -2] + pad[0, 0] = 2 * pad[0, 1] - pad[0, 2] + pad[0, -1] = 2 * pad[0, -2] - pad[0, -3] + pad[-1, 0] = 2 * pad[-1, 1] - pad[-1, 2] + pad[-1, -1] = 2 * pad[-1, -2] - pad[-1, -3] + return 0.25 * (pad[:-1, :-1] + pad[1:, :-1] + pad[:-1, 1:] + pad[1:, 1:]) + + +@dataclass +class _Grid: + """Cached cell geometry for a structured grid. + + ``polys`` is a flat (n_cells,) object array of shapely Polygons. + ``bounds`` is a (n_cells, 4) ``(minx, miny, maxx, maxy)`` array cached for + the STRtree / candidate-search path. ``rectilinear`` is True when both the + source x and y were 1D coordinate arrays (axis-aligned rectangles) — the + weight builder uses this to skip GEOS polygon clipping and compute + intersection areas analytically from the bounds. + """ + + polys: np.ndarray + bounds: np.ndarray + rectilinear: bool + + +def _rect_grid_from_edges(xe: np.ndarray, ye: np.ndarray) -> _Grid: + """Build a rectilinear :class:`_Grid` from already-prepared edge arrays. + + Shared by the raw-planar (:func:`_build_grid` 1D branch) and the analytic + equal-area (:func:`_build_cea_grid`) paths. + """ + x0, y0 = np.meshgrid(xe[:-1], ye[:-1], indexing="xy") + x1, y1 = np.meshgrid(xe[1:], ye[1:], indexing="xy") + x0f, y0f, x1f, y1f = x0.ravel(), y0.ravel(), x1.ravel(), y1.ravel() + polys = shapely.box(x0f, y0f, x1f, y1f) + bounds = np.stack([x0f, y0f, x1f, y1f], axis=1) + return _Grid(polys=polys, bounds=bounds, rectilinear=True) + + +def _build_grid(xc: np.ndarray, yc: np.ndarray) -> _Grid: + """Return a _Grid of cell geometry for a structured grid. 
+ + Accepts 1D (rectilinear, separate x and y vectors) or 2D (curvilinear, + co-shaped center arrays) inputs. Output order is row-major in the input dim + order: for 2D inputs of shape (ny, nx) the polygons correspond to cells + reshaped as ``(ny, nx)``. + """ + _check_shapely() + if xc.ndim == 1 and yc.ndim == 1: + xe = utils.infer_1d_edges(xc.astype(float)) + ye = utils.infer_1d_edges(yc.astype(float)) + return _rect_grid_from_edges(xe, ye) + if xc.ndim == 2 and yc.ndim == 2 and xc.shape == yc.shape: + xcorn = _infer_2d_corners(xc) + ycorn = _infer_2d_corners(yc) + ny, nx = xc.shape + c00 = np.stack([xcorn[:-1, :-1], ycorn[:-1, :-1]], axis=-1) + c10 = np.stack([xcorn[:-1, 1:], ycorn[:-1, 1:]], axis=-1) + c11 = np.stack([xcorn[1:, 1:], ycorn[1:, 1:]], axis=-1) + c01 = np.stack([xcorn[1:, :-1], ycorn[1:, :-1]], axis=-1) + rings = np.stack([c00, c10, c11, c01, c00], axis=2).reshape(ny * nx, 5, 2) + polys = shapely.polygons(rings) + return _Grid(polys=polys, bounds=shapely.bounds(polys), rectilinear=False) + msg = "x and y coordinate arrays must both be 1D or both 2D" + raise ValueError(msg) + + +def _build_intersection_areas( + src: _Grid, + dst: _Grid, + n_threads: int | None = None, + *, + predicate_filter: bool = False, +) -> "sparse.COO | np.ndarray": + """Build the (n_dst, n_src) raw area-intersection matrix ``A[i, j] = + area(dst_i ∩ src_j)``. + + This is the unnormalized matrix. Row-normalize via :func:`_row_normalize` + to get forward weights; transpose first for backward (target → source). + + When both grids are rectilinear (axis-aligned rectangles) intersection + areas are computed analytically from the bounds, skipping GEOS clipping. + + ``predicate_filter=False`` (default) uses a bbox-only STRtree query and + relies on the ``area > 0`` filter below to drop bbox-false-positives. 
+ For structured cells whose bboxes are tight (quadrilaterals) this is a + large win — the GEOS ``intersects`` predicate inside STRtree is much + more expensive than the extra no-op intersections it avoids. Set + ``predicate_filter=True`` for user-supplied polygons with loose bboxes + (long, thin, diagonal shapes) where the predicate pays for itself. + """ + _check_shapely() + n_dst = len(dst.polys) + n_src = len(src.polys) + + tree = STRtree(src.polys) + if predicate_filter: + pairs = tree.query(dst.polys, predicate="intersects") + else: + pairs = tree.query(dst.polys) + dst_idx = np.asarray(pairs[0]) + src_idx = np.asarray(pairs[1]) + + if dst_idx.size == 0: + return _empty_weights(n_dst, n_src) + + if src.rectilinear and dst.rectilinear: + sb = src.bounds[src_idx] + db = dst.bounds[dst_idx] + dx = np.minimum(sb[:, 2], db[:, 2]) - np.maximum(sb[:, 0], db[:, 0]) + dy = np.minimum(sb[:, 3], db[:, 3]) - np.maximum(sb[:, 1], db[:, 1]) + areas = np.maximum(dx, 0.0) * np.maximum(dy, 0.0) + else: + areas = _intersection_areas_threaded( + dst.polys[dst_idx], src.polys[src_idx], n_threads=n_threads + ) + + keep = areas > 0 + dst_idx = dst_idx[keep] + src_idx = src_idx[keep] + areas = areas[keep] + + if dst_idx.size == 0: + return _empty_weights(n_dst, n_src) + + if _HAS_SPARSE: + return sparse.COO( + coords=np.stack([dst_idx, src_idx]), + data=areas.astype(np.float64), + shape=(n_dst, n_src), + has_duplicates=False, + sorted=False, + ) + a_dense = np.zeros((n_dst, n_src), dtype=np.float64) + a_dense[dst_idx, src_idx] = areas + return a_dense + + +def _row_normalize( + areas: "sparse.COO | np.ndarray", +) -> "sparse.COO | np.ndarray": + """Normalize rows of an area matrix so each row sums to 1 (rows with no + overlap stay all-zero, which produces NaN output under the apply path).""" + if _HAS_SPARSE and isinstance(areas, sparse.COO): + n_dst = areas.shape[0] + dst_idx = areas.coords[0] + src_idx = areas.coords[1] + data = areas.data + if data.size == 0: + return areas + 
row_sum = np.bincount(dst_idx, weights=data, minlength=n_dst) + new_data = data / row_sum[dst_idx] + return sparse.COO( + coords=np.stack([dst_idx, src_idx]), + data=new_data, + shape=areas.shape, + has_duplicates=False, + sorted=False, + ) + row_sum = areas.sum(axis=1, keepdims=True) + row_sum = np.where(row_sum == 0, 1.0, row_sum) + return areas / row_sum + + +def _intersection_areas_threaded( + a: np.ndarray, b: np.ndarray, n_threads: int | None +) -> np.ndarray: + """Compute per-pair intersection areas ``area(a[i] & b[i])`` over numpy + arrays of shapely geometries, optionally parallelized via threads. + + Shapely 2.x releases the GIL for GEOS ops, so a ``ThreadPoolExecutor`` + gives near-linear speedup on multi-core machines without pickling data. + """ + _check_shapely() + n = len(a) + if n_threads is None: + # Below ~1k pairs the pool spin-up (~0.3 ms) dominates sub-ms work. + # Above that, scaling is near-linear with logical cores — shapely + # releases the GIL inside its GEOS ufuncs. Cap at 16 to avoid + # oversubscription on unusually wide machines. 
+ n_threads = 1 if n < 1_000 else min(os.cpu_count() or 1, 16) + if n_threads <= 1 or n == 0: + return shapely.area(shapely.intersection(a, b)) + + splits = np.array_split(np.arange(n), n_threads) + + def _work(idx: np.ndarray) -> np.ndarray: + return shapely.area(shapely.intersection(a[idx], b[idx])) + + with ThreadPoolExecutor(max_workers=n_threads) as pool: + parts = list(pool.map(_work, splits)) + return np.concatenate(parts) + + +def _empty_weights(n_dst: int, n_src: int) -> "sparse.COO | np.ndarray": + if _HAS_SPARSE: + return sparse.COO( + coords=np.zeros((2, 0), dtype=np.int64), + data=np.zeros(0, dtype=np.float64), + shape=(n_dst, n_src), + ) + return np.zeros((n_dst, n_src), dtype=np.float64) + + +def _apply_core( + arr: np.ndarray, + apply_weights: Any, + coverage: np.ndarray, + coverage_all: bool, + src_shape: tuple[int, ...], + dst_shape: tuple[int, ...], + skipna: bool, + nan_threshold: float, + output_dtype: np.dtype, +) -> np.ndarray: + """Apply a pre-transposed weight matrix along the trailing spatial dims. + + ``arr`` has shape ``(..., *src_shape)``; ``apply_weights`` has shape + ``(n_src, n_dst)`` — returns ``(..., *dst_shape)``. + + ``coverage`` is a boolean ``(n_dst,)`` mask: target cells with any source + overlap. ``coverage_all`` is precomputed so every block skips the + ``coverage.all()`` scan. Uncovered cells are always masked to NaN in the + output, regardless of ``skipna`` — domain boundaries and polygon holes + produce NaN (matches the axis-factored ``conservative`` method). 
+ """ + n_spatial = len(src_shape) + leading_shape = arr.shape[:-n_spatial] if n_spatial > 0 else arr.shape + n_src = int(np.prod(src_shape)) + flat = arr.reshape(-1, n_src) if leading_shape else arr.reshape(1, n_src) + + if skipna and np.issubdtype(flat.dtype, np.floating): + nan_mask = np.isnan(flat) + has_nan = nan_mask.any() + else: + has_nan = False + + if has_nan: + mask = (~nan_mask).astype(flat.dtype) + filled = np.where(nan_mask, flat.dtype.type(0.0), flat) + numerator = np.asarray(filled @ apply_weights) + fraction = np.asarray(mask @ apply_weights) + threshold = get_valid_threshold(nan_threshold) + with np.errstate(invalid="ignore", divide="ignore"): + result = numerator / fraction + result = np.where(fraction >= threshold, result, np.nan) + else: + result = np.asarray(flat @ apply_weights) + if not coverage_all: + result = np.where(coverage[np.newaxis, :], result, np.nan) + + # sparse.matmul promotes to float64 regardless of the input dtype — cast + # back to the requested output dtype so float32-in really produces + # float32-out (halves memory for float32 pipelines). + if result.dtype != output_dtype: + result = result.astype(output_dtype, copy=False) + + out_shape = (*leading_shape, *dst_shape) if leading_shape else dst_shape + return result.reshape(out_shape) + + +def _result_dtype(obj: xr.DataArray | xr.Dataset) -> np.dtype: + if isinstance(obj, xr.DataArray): + return np.result_type(np.float32, obj.dtype) + dtypes = [v.dtype for v in obj.data_vars.values()] + if not dtypes: + return np.dtype(np.float64) + return np.result_type(np.float32, *dtypes) + + +def _assign_target_coords( + obj: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + dst_dims: tuple[Hashable, ...], + x_coord: str, + y_coord: str, +) -> xr.DataArray | xr.Dataset: + """Attach target coordinates that name a spatial axis or live on the + output spatial dims. Scalar coords (``dims == ()``) ride along too, so + pinned target metadata (e.g. 
a fixed timestamp) is preserved.""" + dst_dim_set = set(dst_dims) + new_coords = { + name: coord + for name, coord in target_ds.coords.items() + if name in (x_coord, y_coord) or set(coord.dims).issubset(dst_dim_set) + } + return obj.assign_coords(new_coords) if new_coords else obj diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index b2ed389..fda6963 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -4,7 +4,12 @@ import numpy as np import xarray as xr -from xarray_regrid.methods import conservative, flox_reduce, interp +from xarray_regrid.methods import ( + conservative, + conservative_2d, + flox_reduce, + interp, +) from xarray_regrid.utils import format_for_regrid @@ -17,7 +22,11 @@ class Regridder: linear: linear, bilinear, or higher dimensional linear interpolation nearest: nearest-neighbor regridding cubic: cubic spline regridding - conservative: conservative regridding + conservative: axis-factored conservative regridding (rectilinear, + 1D-separable grids only) + conservative_2d: conservative regridding for grids that aren't + 1D-separable — curvilinear 2D coords, unstructured meshes, or + arbitrary polygon-to-polygon aggregation (requires shapely) most_common: most common value regridder stat: area statistics regridder """ @@ -82,6 +91,62 @@ def cubic( ds_formatted = format_for_regrid(self._obj, ds_target_grid) return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic") + def conservative_2d( + self, + ds_target_grid: xr.Dataset, + x_coord: str = "longitude", + y_coord: str = "latitude", + spherical: bool = False, + time_dim: str | None = "time", + skipna: bool = True, + nan_threshold: float = 1.0, + n_threads: int | None = None, + ) -> xr.DataArray | xr.Dataset: + """Conservative regrid for grids that aren't 1D-separable. + + Use this when ``.conservative`` can't express your grid: curvilinear + coordinates (2D ``lat``/``lon`` arrays), unstructured meshes, or any + arbitrary polygon target. 
Computes 2D cell-polygon intersections via + shapely. Defaults to planar geometry; set ``spherical=True`` for + lat/lon grids in degrees to get proper spherical area weights via an + analytic cylindrical equal-area projection. Requires ``shapely >= 2.0``. + + Args: + ds_target_grid: Dataset defining the target grid; must expose + ``x_coord`` and ``y_coord`` as coordinate variables. + x_coord: Name of the x (longitude-like) coordinate variable. + y_coord: Name of the y (latitude-like) coordinate variable. + spherical: If True, assume coords are longitude/latitude in + degrees and apply a Lambert cylindrical equal-area projection + before intersecting. Rectilinear (1D coord) grids only. + time_dim: Name of the time dimension. Defaults to ``"time"``. Use + ``None`` to force regridding over the time dimension. + skipna: If True, propagate NaNs into the weighted mean via a + two-pass sum. + nan_threshold: Keep output cells whose valid source fraction is at + least ``nan_threshold``. + n_threads: Thread count for parallel GEOS intersection. ``None`` + auto-selects; set to ``1`` to disable threading. + + Returns: + Data regridded to the target grid. + """ + if not 0.0 <= nan_threshold <= 1.0: + msg = "nan_threshold must be between [0, 1]" + raise ValueError(msg) + ds_target_grid = validate_input( + self._obj, ds_target_grid, time_dim, require_shared_dims=False + ) + regridder = conservative_2d.ConservativeRegridder( + self._obj, + ds_target_grid, + x_coord=x_coord, + y_coord=y_coord, + spherical=spherical, + n_threads=n_threads, + ) + return regridder.regrid(self._obj, skipna=skipna, nan_threshold=nan_threshold) + def conservative( self, ds_target_grid: xr.Dataset, @@ -116,7 +181,7 @@ def conservative( Data regridded to the target dataset coordinates. 
""" if not 0.0 <= nan_threshold <= 1.0: - msg = "nan_threshold must be between [0, 1]]" + msg = "nan_threshold must be between [0, 1]" raise ValueError(msg) ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) @@ -275,6 +340,7 @@ def validate_input( data: xr.Dataset, ds_target_grid: xr.Dataset, time_dim: str | None, + require_shared_dims: bool = ..., ) -> xr.Dataset: ... @@ -283,6 +349,7 @@ def validate_input( data: xr.DataArray, ds_target_grid: xr.Dataset, time_dim: str | None, + require_shared_dims: bool = ..., ) -> xr.Dataset: ... @@ -290,11 +357,14 @@ def validate_input( data: xr.DataArray | xr.Dataset, ds_target_grid: xr.Dataset, time_dim: str | None, + require_shared_dims: bool = True, ) -> xr.Dataset: if time_dim is not None and time_dim in ds_target_grid.coords: ds_target_grid = ds_target_grid.isel({time_dim: 0}).reset_coords() - if len(set(data.dims).intersection(set(ds_target_grid.dims))) == 0: + # Curvilinear regridders match source and target by coord values, not by + # dim name, so they opt out of the shared-dim requirement. 
+ if require_shared_dims and not set(data.dims) & set(ds_target_grid.dims): msg = ( "None of the target dims are in the data:\n" " regridding is not possible.\n" @@ -303,7 +373,7 @@ def validate_input( ) raise ValueError(msg) - if len(set(data.coords).intersection(set(ds_target_grid.coords))) == 0: + if not set(data.coords) & set(ds_target_grid.coords): msg = ( "None of the target coords are in the data:\n" " regridding is not possible.\n" diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index 6979f84..7253cca 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -80,7 +80,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]: grid.south, grid.north + grid.resolution_lat, grid.resolution_lat ) - if np.remainder((grid.east - grid.west), grid.resolution_lat) > 0: + if np.remainder((grid.east - grid.west), grid.resolution_lon) > 0: lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon) else: lon_coords = np.arange( @@ -114,6 +114,22 @@ def create_regridding_dataset( ) +def infer_1d_edges(centers: np.ndarray) -> np.ndarray: + """Return cell edges from 1D centers: midpoints between consecutive + centers, with symmetric reflection for the two outer bounds. + + Requires at least two centers. + """ + c = np.asarray(centers, dtype=float) + if c.size < 2: + msg = "need at least two centers to infer cell edges" + raise ValueError(msg) + mids = 0.5 * (c[:-1] + c[1:]) + left = 2 * c[0] - mids[0] + right = 2 * c[-1] - mids[-1] + return np.concatenate([[left], mids, [right]]) + + def to_intervalindex(coords: np.ndarray) -> pd.IntervalIndex: """Convert a 1-d coordinate array to a pandas IntervalIndex. Take the midpoints between the coordinates as the interval boundaries. @@ -126,20 +142,9 @@ def to_intervalindex(coords: np.ndarray) -> pd.IntervalIndex: coordinates. 
""" if len(coords) > 1: - midpoints = (coords[:-1] + coords[1:]) / 2 - - # Extrapolate outer bounds beyond the first and last coordinates - left_bound = 2 * coords[0] - midpoints[0] - right_bound = 2 * coords[-1] - midpoints[-1] - - breaks = np.concatenate([[left_bound], midpoints, [right_bound]]) - intervals = pd.IntervalIndex.from_breaks(breaks) - - else: - # If the target grid has a single point, set search interval to span all space - intervals = pd.IntervalIndex.from_breaks([-np.inf, np.inf]) - - return intervals + return pd.IntervalIndex.from_breaks(infer_1d_edges(coords)) + # If the target grid has a single point, set search interval to span all space + return pd.IntervalIndex.from_breaks([-np.inf, np.inf]) def overlap(a: pd.IntervalIndex, b: pd.IntervalIndex) -> np.ndarray: @@ -436,8 +441,8 @@ def update_coord( def update_coord( obj: xr.DataArray | xr.Dataset, coord: Hashable, coord_vals: np.ndarray ) -> xr.DataArray | xr.Dataset: - """Update the values of a coordinate, ensuring indexes stay in sync.""" - attrs = obj.coords[coord].attrs - obj = obj.assign_coords({coord: coord_vals}) - obj.coords[coord].attrs = attrs - return obj + """Update the values of a coordinate, ensuring indexes stay in sync. 
+ Preserves the coord's existing dims and attrs (so multi-dim coords work).""" + original = obj.coords[coord] + new_coord = xr.DataArray(coord_vals, dims=original.dims, attrs=original.attrs) + return obj.assign_coords({coord: new_coord}) diff --git a/tests/test_conservative_2d.py b/tests/test_conservative_2d.py new file mode 100644 index 0000000..b81fab3 --- /dev/null +++ b/tests/test_conservative_2d.py @@ -0,0 +1,693 @@ +"""Tests for conservative_2d.""" + +import numpy as np +import pytest +import xarray as xr + +import xarray_regrid # noqa: F401 (registers the accessor) +from xarray_regrid import ConservativeRegridder, polygons_from_coords + +shapely = pytest.importorskip("shapely") + + +def _rect_da(ny=60, nx=120, nt=2, seed=0): + x = np.linspace(-180, 180, nx, endpoint=False) + 180 / nx + y = np.linspace(-90, 90, ny, endpoint=False) + 90 / ny + rng = np.random.default_rng(seed) + return xr.DataArray( + rng.normal(size=(nt, ny, nx)).astype(np.float64), + dims=("time", "y", "x"), + coords={"time": np.arange(nt), "y": y, "x": x}, + name="var", + ) + + +def _rect_target(ny=24, nx=47): + x = np.linspace(-180, 180, nx, endpoint=False) + 180 / nx + y = np.linspace(-90, 90, ny, endpoint=False) + 90 / ny + return xr.Dataset(coords={"y": y, "x": x}) + + +def test_polygon_matches_factored_planar(): + """On a rectilinear grid with no spherical correction, the polygon path + should reproduce the axis-factored path to machine precision.""" + da = _rect_da() + target = _rect_target() + ref = da.regrid.conservative(target, latitude_coord=None) + got = da.regrid.conservative_2d(target, x_coord="x", y_coord="y") + got = got.transpose(*ref.dims) + np.testing.assert_allclose(got.values, ref.values, atol=1e-12) + + +def test_polygon_dask_time_chunks(): + da = _rect_da(nt=4).chunk({"time": 2}) + target = _rect_target() + got = da.regrid.conservative_2d(target, x_coord="x", y_coord="y") + assert got.chunks is not None + got = got.compute() + ref = 
_rect_da(nt=4).regrid.conservative(target, latitude_coord=None) + np.testing.assert_allclose(got.transpose(*ref.dims).values, ref.values, atol=1e-12) + + +def test_polygon_rechunks_spatial(): + """Spatially-chunked input should be accepted (rechunked internally).""" + da = _rect_da().chunk({"time": 1, "y": 30, "x": 40}) + target = _rect_target() + out = da.regrid.conservative_2d(target, x_coord="x", y_coord="y") + out.compute() + + +def test_polygon_nan_threshold(): + """Stricter nan_threshold produces more NaN output cells when partial + overlaps exist.""" + da = _rect_da() + da.values[:, 21:29, :] = np.nan + target = _rect_target(ny=23, nx=47) + out1 = da.regrid.conservative_2d( + target, x_coord="x", y_coord="y", skipna=True, nan_threshold=1.0 + ) + out0 = da.regrid.conservative_2d( + target, x_coord="x", y_coord="y", skipna=True, nan_threshold=0.0 + ) + assert int(np.isnan(out0.values).sum()) > int(np.isnan(out1.values).sum()) + + +def test_polygon_curvilinear_target(): + """Curvilinear target (2D lat/lon corners) returns finite values.""" + da = _rect_da() + ny_t, nx_t = 20, 30 + xi, yi = np.meshgrid( + np.linspace(-120, 120, nx_t), + np.linspace(-60, 60, ny_t), + indexing="xy", + ) + th = np.deg2rad(30) + x2d = xi * np.cos(th) - yi * np.sin(th) + y2d = xi * np.sin(th) + yi * np.cos(th) + target = xr.Dataset(coords={"x": (("ny", "nx"), x2d), "y": (("ny", "nx"), y2d)}) + out = da.regrid.conservative_2d(target, x_coord="x", y_coord="y") + assert out.shape == (2, 20, 30) + assert np.isfinite(out.values).mean() > 0.9 + + +def test_antimeridian_rectilinear_constant(): + da = xr.DataArray( + np.full((2, 4), 2.5), + dims=("latitude", "longitude"), + coords={ + "latitude": np.array([-2.5, 2.5]), + "longitude": np.array([167.5, 172.5, -177.5, -172.5]), + }, + ) + target = xr.Dataset( + coords={ + "latitude": np.array([-2.5, 2.5]), + "longitude": np.array([170.0, -170.0]), + } + ) + + out = da.regrid.conservative_2d(target, x_coord="longitude", y_coord="latitude") + 
np.testing.assert_allclose(out.values, 2.5, atol=1e-12) + + +def test_cross_convention_longitude_alignment(): + """Source on [0, 360] with target on [-180, 180] (and vice versa) must + align — a uniform shift can't reconcile the two conventions, so per-value + wrap of source longitudes is required. Regression: previously yielded + NaN on half the target cells because banker's rounding on the exact-180° + mean diff produced a zero offset.""" + src_vals = np.array([1.0, 2.0, 3.0, 4.0]) + tgt_vals_neg = np.array([-135.0, -45.0, 45.0, 135.0]) + src_vals_neg_x = np.array([45.0, 135.0, 225.0, 315.0]) + da = xr.DataArray( + np.broadcast_to(src_vals, (2, 4)).copy(), + dims=("latitude", "longitude"), + coords={"latitude": [-30.0, 30.0], "longitude": src_vals_neg_x}, + ) + target = xr.Dataset(coords={"latitude": [-30.0, 30.0], "longitude": tgt_vals_neg}) + expected = da.regrid.conservative(target).transpose("latitude", "longitude") + out_planar = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude" + ).transpose("latitude", "longitude") + out_spherical = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude", spherical=True + ).transpose("latitude", "longitude") + np.testing.assert_allclose(out_planar.values, expected.values, atol=1e-12) + np.testing.assert_allclose(out_spherical.values, expected.values, atol=1e-12) + + # Reverse: source on [-180, 180], target on [0, 360]. 
+ da_rev = xr.DataArray( + np.broadcast_to(src_vals, (2, 4)).copy(), + dims=("latitude", "longitude"), + coords={"latitude": [-30.0, 30.0], "longitude": tgt_vals_neg}, + ) + target_rev = xr.Dataset( + coords={"latitude": [-30.0, 30.0], "longitude": src_vals_neg_x} + ) + expected_rev = da_rev.regrid.conservative(target_rev).transpose( + "latitude", "longitude" + ) + out_rev = da_rev.regrid.conservative_2d( + target_rev, x_coord="longitude", y_coord="latitude" + ).transpose("latitude", "longitude") + np.testing.assert_allclose(out_rev.values, expected_rev.values, atol=1e-12) + + +def test_polygon_nan_threshold_invalid(): + da = _rect_da() + with pytest.raises(ValueError): + da.regrid.conservative_2d( + _rect_target(), x_coord="x", y_coord="y", nan_threshold=1.5 + ) + + +def test_polygon_dataset_input(): + """A Dataset input with multiple variables should regrid all of them.""" + da = _rect_da() + ds = xr.Dataset({"a": da, "b": da * 2.0}) + target = _rect_target() + out = ds.regrid.conservative_2d(target, x_coord="x", y_coord="y") + assert set(out.data_vars) == {"a", "b"} + np.testing.assert_allclose( + out["b"].transpose(*out["a"].dims).values, + out["a"].transpose(*out["a"].dims).values * 2.0, + atol=1e-12, + ) + + +# --- ConservativeRegridder (reusable) ------------------------------------------ + + +def test_regridder_reusable_matches_oneshot(): + """Reusing a single ConservativeRegridder on multiple fields matches the + one-shot `conservative_2d_regrid` call.""" + da1 = _rect_da(seed=1) + da2 = _rect_da(seed=2) + target = _rect_target() + regridder = ConservativeRegridder(da1, target, x_coord="x", y_coord="y") + ref1 = da1.regrid.conservative_2d(target, x_coord="x", y_coord="y") + ref2 = da2.regrid.conservative_2d(target, x_coord="x", y_coord="y") + out1 = regridder.regrid(da1) + out2 = regridder(da2) # __call__ alias + np.testing.assert_allclose(out1.values, ref1.values, atol=1e-12) + np.testing.assert_allclose(out2.values, ref2.values, atol=1e-12) + + +def 
test_regridder_weight_cache(): + """Forward weight matrix is built lazily and then reused across calls.""" + da = _rect_da() + target = _rect_target() + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + assert "forward_weights" not in regridder.__dict__ + regridder.regrid(da) + w1 = regridder.forward_weights + regridder.regrid(da) + assert regridder.forward_weights is w1 # same object, not rebuilt + + +def test_regridder_transpose_roundtrip_rectilinear_aligned(): + """If target cells are aligned unions of source cells, forward followed by + backward reproduces a constant field exactly.""" + # 120 source cells along each axis, target is a 4x coarsening (exact union + # of source cells). A constant source field survives the roundtrip to + # itself because every source cell is fully covered. + ns = 120 + nt = 30 # exactly ns / 4 + x_s = np.linspace(-180, 180, ns, endpoint=False) + 180 / ns + y_s = np.linspace(-90, 90, ns, endpoint=False) + 90 / ns + x_t = np.linspace(-180, 180, nt, endpoint=False) + 180 / nt + y_t = np.linspace(-90, 90, nt, endpoint=False) + 90 / nt + da = xr.DataArray( + np.full((ns, ns), 3.5, dtype=np.float64), + dims=("y", "x"), + coords={"y": y_s, "x": x_s}, + ) + target = xr.Dataset(coords={"y": y_t, "x": x_t}) + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + coarse = regridder.regrid(da) + back = regridder.T.regrid(coarse) + # Every value should match the original constant. + np.testing.assert_allclose(back.values, 3.5, atol=1e-12) + + +def test_regridder_T_preserves_weights(): # noqa: N802 + """regridder.T.T should share the raw area matrix with the original.""" + da = _rect_da() + target = _rect_target() + r = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + rr = r.T.T + # Same shape, same coords, same data. 
+ assert r.areas.shape == rr.areas.shape + if hasattr(r.areas, "data"): + np.testing.assert_array_equal(r.areas.data, rr.areas.data) + + +def test_regridder_shape_mismatch_raises(): + """Applying the regridder to data whose spatial shape differs from the + source it was built for should raise.""" + da = _rect_da() + target = _rect_target() + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + smaller = _rect_da(ny=30, nx=60) + with pytest.raises(ValueError, match="spatial shape"): + regridder.regrid(smaller) + + +def test_spherical_mode_matches_factored(): + """Polygon path with ``spherical=True`` should match the axis-factored + sin-weighted conservative path to within a tight tolerance on lat/lon + grids (both are analytically equivalent for cylindrical equal-area).""" + lon_s = np.linspace(-180, 180, 180, endpoint=False) + 1.0 + lat_s = np.linspace(-90, 90, 90, endpoint=False) + 1.0 + lon_t = np.linspace(-180, 180, 60, endpoint=False) + 3.0 + lat_t = np.linspace(-90, 90, 30, endpoint=False) + 3.0 + vals = np.cos(np.deg2rad(lat_s))[:, None] ** 2 * np.sin(np.deg2rad(lon_s))[None, :] + da = xr.DataArray( + vals, + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + factored = da.regrid.conservative(target, latitude_coord="latitude") + polygon = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude", spherical=True + ) + # Both methods should agree to the grid's own quadrature accuracy. Near the + # poles the factored path's median-dlat approximation introduces a small + # discrepancy; 1e-3 absolute is well below the planar raw error floor + # (~1e-2 at this resolution, tested separately). + np.testing.assert_allclose( + polygon.transpose(*factored.dims).values, + factored.values, + atol=1e-3, + ) + + +def test_spherical_conserves_integral(): + """Mass conservation check on the sphere. 
For cos^2(lat), true integral is + 8*pi/3; the regridder on a 2-to-6-degree grid should keep the + spherical-area-weighted sum within the grid quadrature floor when + spherical=True, and miss it by ~17x more when spherical=False.""" + lon_s = np.linspace(-180, 180, 180, endpoint=False) + 1.0 + lat_s = np.linspace(-90, 90, 90, endpoint=False) + 1.0 + lon_t = np.linspace(-180, 180, 60, endpoint=False) + 3.0 + lat_t = np.linspace(-90, 90, 30, endpoint=False) + 3.0 + da = xr.DataArray( + np.cos(np.deg2rad(lat_s))[:, None] ** 2 * np.ones(lon_s.size)[None, :], + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + out_sph = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude", spherical=True + ) + out_raw = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude", spherical=False + ) + + # True target spherical cell areas + dlon_arr = np.full(lon_t.size, np.deg2rad(np.mean(np.diff(lon_t)))) + lat_r = np.deg2rad(lat_t) + dlat_r = np.gradient(lat_r) + dlat_bands = np.sin(lat_r + dlat_r / 2) - np.sin(lat_r - dlat_r / 2) + a_tgt = dlat_bands[:, None] * dlon_arr[None, :] + + true_val = 8 * np.pi / 3 + sph_vals = out_sph.transpose("latitude", "longitude").values + raw_vals = out_raw.transpose("latitude", "longitude").values + err_sph = abs(float((sph_vals * a_tgt).sum()) - true_val) + err_raw = abs(float((raw_vals * a_tgt).sum()) - true_val) + # Spherical should be at least 10x more accurate than raw planar here. 
+ assert err_sph < 0.1 * err_raw, f"err_sph={err_sph:.2e} err_raw={err_raw:.2e}" + + +# --- from_polygons (unstructured mesh) ---------------------------------------- + + +def _box_polygons(): + rng = np.random.default_rng(1) + n = 50 + cx = rng.uniform(-170, 170, n) + cy = rng.uniform(-80, 80, n) + return shapely.box(cx - 5, cy - 5, cx + 5, cy + 5) + + +def test_polygons_from_coords_periodic(): + polys = polygons_from_coords( + np.array([167.5, 172.5, -177.5, -172.5]), + np.array([-2.5, 2.5]), + periodic=True, + ) + bounds = shapely.bounds(polys) + widths = bounds[:, 2] - bounds[:, 0] + assert np.all(widths < 10.1) + + +def test_from_polygons_basic(): + src_polys = _box_polygons() + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 30, endpoint=False) + 6, + np.linspace(-90, 90, 15, endpoint=False) + 6, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + da = xr.DataArray( + np.arange(src_polys.size, dtype=np.float64), + dims=("face",), + ) + out = rgr.regrid(da) + assert out.dims == ("cell",) + assert out.sizes["cell"] == tgt_polys.size + + +def test_from_polygons_attaches_target_aux_coords(): + src_polys = _box_polygons() + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 12, endpoint=False) + 15, + np.linspace(-90, 90, 6, endpoint=False) + 15, + ) + region_id = np.arange(tgt_polys.size) + 100 + target_coords = xr.Dataset( + coords={ + "cell": np.arange(tgt_polys.size), + "region_id": ("cell", region_id), + } + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, + tgt_polys, + source_dim="face", + target_dim="cell", + target_coords=target_coords, + ) + da = xr.DataArray(np.arange(src_polys.size, dtype=np.float64), dims=("face",)) + out = rgr.regrid(da) + assert "region_id" in out.coords + np.testing.assert_array_equal(out["region_id"].values, region_id) + + +def test_from_polygons_periodic_antimeridian(): + src_polys = np.array( + [shapely.Polygon([(175, -5), (-175, -5), 
(-175, 5), (175, 5)])], + dtype=object, + ) + tgt_polys = np.array( + [ + shapely.box(160, -5, 170, 5), + shapely.box(175, -5, 185, 5), + shapely.box(-170, -5, -160, 5), + ], + dtype=object, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, + tgt_polys, + source_dim="src", + target_dim="tgt", + periodic=True, + ) + out = rgr.regrid(xr.DataArray([7.0], dims=("src",))) + + assert np.isnan(out.values[0]) + assert out.values[1] == pytest.approx(7.0) + assert np.isnan(out.values[2]) + + +def test_from_polygons_mass_conservation(): + """Sum of intersected mass should match the direct A·s calculation to + machine precision for any source field.""" + + src_polys = _box_polygons() + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 36, endpoint=False) + 5, + np.linspace(-90, 90, 18, endpoint=False) + 5, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + rng = np.random.default_rng(3) + s = rng.normal(size=src_polys.size) + da = xr.DataArray(s, dims=("face",)) + out = rgr.regrid(da).values + # Direct mass = sum_i s_i * source_coverage_i. Matches output if we + # multiply output by target-covered area. 
+ tgt_covered = rgr.target_areas + valid = tgt_covered > 0 + direct = float((s * rgr.source_coverage_areas).sum()) + via_regrid = float((out[valid] * tgt_covered[valid]).sum()) + rel = abs(direct - via_regrid) / max(abs(direct), 1e-12) + assert rel < 1e-12, f"rel err {rel:.2e}" + + +def test_from_polygons_transpose_roundtrip(): + """Roundtrip mesh ↔ mesh of a constant field returns the constant.""" + src_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 12, endpoint=False) + 15, + np.linspace(-90, 90, 6, endpoint=False) + 15, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="src", target_dim="tgt" + ) + da = xr.DataArray(np.full(src_polys.size, 3.5), dims=("src",)) + out = rgr.regrid(da) + back = rgr.T.regrid(out) + # Inner source cells (fully covered by target cells they map to) must be 3.5. + # Edge cells may be NaN if target domain doesn't cover them. + finite = np.isfinite(back.values) + np.testing.assert_allclose(back.values[finite], 3.5, atol=1e-12) + + +def test_from_polygons_nan_propagation(): + """NaN source cells propagate through skipna=True correctly.""" + src_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 12, endpoint=False) + 15, + np.linspace(-90, 90, 6, endpoint=False) + 15, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="src", target_dim="tgt" + ) + vals = np.full(src_polys.size, 1.0) + vals[:5] = np.nan + da = xr.DataArray(vals, dims=("src",)) + out_keep = rgr.regrid(da, nan_threshold=1.0) + out_strict = rgr.regrid(da, nan_threshold=0.0) + # Strict should have at least as many NaNs. 
+ strict_nans = int(np.isnan(out_strict.values).sum()) + keep_nans = int(np.isnan(out_keep.values).sum()) + assert strict_nans >= keep_nans + + +def test_from_polygons_target_outside_source_is_nan(): + """Target cells entirely outside the source domain must be NaN regardless + of skipna (regression test for a bug where the "skip mask matmul when no + NaNs" optimization returned zeros for uncovered cells).""" + + src = np.array([shapely.box(0, 0, 1, 1)], dtype=object) + tgt = polygons_from_coords( + np.linspace(100, 110, 4, endpoint=False) + 1.25, + np.linspace(100, 110, 4, endpoint=False) + 1.25, + ) + rgr = ConservativeRegridder.from_polygons(src, tgt, source_dim="src") + # Source has no NaNs → used to wrongly return 0 here. + da = xr.DataArray(np.array([5.0]), dims=("src",)) + out = rgr.regrid(da, skipna=True) + assert np.isnan(out.values).all() + out_no = rgr.regrid(da, skipna=False) + assert np.isnan(out_no.values).all() + + +def test_from_polygons_hole_is_nan(): + """A target cell fully inside a source-polygon hole should be NaN, not 0.""" + + ring_with_hole = shapely.Polygon( + [(0, 0), (10, 0), (10, 10), (0, 10)], + [[(3, 3), (7, 3), (7, 7), (3, 7)]], + ) + src = np.array([ring_with_hole], dtype=object) + tgt = polygons_from_coords( + np.linspace(0, 10, 5, endpoint=False) + 1, + np.linspace(0, 10, 5, endpoint=False) + 1, + ) + rgr = ConservativeRegridder.from_polygons(src, tgt, source_dim="src") + out = rgr.regrid(xr.DataArray([7.0], dims=("src",))) + # The cell centered at (5,5), edges [4,6]x[4,6], lies entirely in the hole. 
+ assert int(np.isnan(out.values).sum()) >= 1 + + +def test_from_polygons_input_validation(): + # 2D polygons array rejected + p = shapely.box(0, 0, 1, 1) + arr_2d = np.array([[p, p], [p, p]], dtype=object) + with pytest.raises(ValueError, match="1D"): + ConservativeRegridder.from_polygons(arr_2d, np.array([p])) + + +# --- netCDF save / load ------------------------------------------------------- + + +def test_regrid_preserves_input_dtype(): + """Float32 in → float32 out; float64 in → float64 out. Sparse promotes + to float64 internally, so we rely on an explicit cast at the end of + ``_apply_core``.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + assert rgr.regrid(da.astype(np.float32)).dtype == np.float32 + assert rgr.regrid(da.astype(np.float64)).dtype == np.float64 + # Integer inputs promote (float32 can't hold int32 without precision loss). + assert np.issubdtype(rgr.regrid((da * 10).astype(np.int32)).dtype, np.floating) + + +def test_to_netcdf_roundtrip_structured(tmp_path): + """Save, reload, regrid → identical output to the original regridder.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + out_before = rgr.regrid(da).values + + path = tmp_path / "regridder.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + + np.testing.assert_array_equal(rgr2.regrid(da).values, out_before) + assert rgr2.x_coord == "x" + assert rgr2.y_coord == "y" + assert rgr2.spherical is False + assert rgr2._src_dims == rgr._src_dims + assert rgr2._dst_dims == rgr._dst_dims + + +def test_to_netcdf_preserves_spherical_flag(tmp_path): + lat_s = np.linspace(-90, 90, 30, endpoint=False) + 3 + lon_s = np.linspace(-180, 180, 60, endpoint=False) + 3 + lat_t = np.linspace(-90, 90, 15, endpoint=False) + 6 + lon_t = np.linspace(-180, 180, 30, endpoint=False) + 6 + da = xr.DataArray( + np.cos(np.deg2rad(lat_s))[:, None] ** 2 * 
np.ones(lon_s.size)[None, :], + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + rgr = ConservativeRegridder( + da, target, x_coord="longitude", y_coord="latitude", spherical=True + ) + before = rgr.regrid(da).values + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + assert rgr2.spherical is True + np.testing.assert_allclose(rgr2.regrid(da).values, before, atol=1e-15) + + +def test_to_netcdf_transpose_works_after_reload(tmp_path): + """A reloaded regridder can still take .T for backward regridding.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + + fwd = rgr.regrid(da) + back_original = rgr.T.regrid(fwd).values + back_reloaded = rgr2.T.regrid(fwd).values + np.testing.assert_array_equal(back_original, back_reloaded) + + +def test_to_netcdf_unstructured_roundtrip(tmp_path): + """from_polygons regridder roundtrips too (single spatial dim each side).""" + rng = np.random.default_rng(1) + cx = rng.uniform(-170, 170, 30) + cy = rng.uniform(-80, 80, 30) + src_polys = shapely.box(cx - 5, cy - 5, cx + 5, cy + 5) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + da = xr.DataArray(rng.normal(size=30), dims=("face",)) + out_before = rgr.regrid(da).values + + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + np.testing.assert_array_equal(rgr2.regrid(da).values, out_before) + + +def test_to_netcdf_metadata_fields(tmp_path): + """Metadata captures grid ranges, version, created timestamp.""" + da = _rect_da() + 
target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + + with xr.open_dataset(path) as ds: + attrs = dict(ds.attrs) + + assert attrs["x_coord"] == "x" + assert attrs["y_coord"] == "y" + assert bool(int(attrs["spherical"])) is False + assert tuple(int(size) for size in attrs["src_shape"]) == rgr._src_shape + assert tuple(int(size) for size in attrs["dst_shape"]) == rgr._dst_shape + # Grid ranges captured when the coord is present in source/target. + assert "source_x_range" in attrs + assert "target_x_range" in attrs + assert attrs["source_x_range"][0] <= attrs["source_x_range"][1] + assert attrs["created"] + assert int(attrs["schema_version"]) == 1 + + +def test_from_netcdf_rejects_unknown_schema(tmp_path): + """Loading a file written with a future schema version raises cleanly.""" + h5py = pytest.importorskip("h5py") + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + + # Bump the on-disk schema_version so the loader should reject it. 
+ with h5py.File(path, "a") as f: + f.attrs["schema_version"] = 999 + + with pytest.raises(ValueError, match="schema version"): + ConservativeRegridder.from_netcdf(path) + + +def test_regridder_transpose_curvilinear(): + """Transpose works when the target is a curvilinear grid with different + dim names from the source.""" + da = _rect_da(ny=40, nx=80) + ny_t, nx_t = 20, 30 + xi, yi = np.meshgrid( + np.linspace(-120, 120, nx_t), + np.linspace(-60, 60, ny_t), + indexing="xy", + ) + th = np.deg2rad(15) + x2 = xi * np.cos(th) - yi * np.sin(th) + y2 = xi * np.sin(th) + yi * np.cos(th) + target = xr.Dataset(coords={"x": (("ny", "nx"), x2), "y": (("ny", "nx"), y2)}) + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + fwd = regridder.regrid(da) + assert fwd.dims == ("time", "ny", "nx") + # Going backward we should land back on (time, y, x) + back = regridder.T.regrid(fwd) + assert "y" in back.dims and "x" in back.dims + assert back.sizes["y"] == da.sizes["y"] + assert back.sizes["x"] == da.sizes["x"]