diff --git a/docs/notebooks/demos/demo_conservative_2d_curvilinear.ipynb b/docs/notebooks/demos/demo_conservative_2d_curvilinear.ipynb new file mode 100644 index 0000000..9eb7b17 --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_2d_curvilinear.ipynb @@ -0,0 +1,160 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Conservative 2D regrid — curvilinear target\n", + "\n", + "Curvilinear grids have 2D `lat(y, x)` / `lon(y, x)` coordinate arrays —\n", + "ocean models (ORCA, tripolar), rotated regional forecasts. Not\n", + "1D-separable, so `.conservative` can't handle them;\n", + "`.regrid.conservative_2d` is the tool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import xarray as xr\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Source — regular 1° lat/lon, analytic two-bump field" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "lat = np.linspace(-60, 60, 121)\n", + "lon = np.linspace(-120, 120, 241)\n", + "Lo, La = np.meshgrid(lon, lat)\n", + "field = (\n", + " np.exp(-((Lo - 40) ** 2 + (La - 20) ** 2) / 500)\n", + " - np.exp(-((Lo + 60) ** 2 + (La + 15) ** 2) / 400)\n", + ")\n", + "src = xr.DataArray(\n", + " field,\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat, \"longitude\": lon},\n", + ")\n", + "src.plot(figsize=(8, 3.5), cmap=\"RdBu_r\", center=0)\n", + "plt.title(\"source: analytic two-bump field\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Target — rotated curvilinear grid\n", + "\n", + "Coordinates ride on a `(ny, nx)` mesh, stored as 2D coordinate\n", + "variables on the target Dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "ny, nx = 30, 50\n", + "xi, yi = np.meshgrid(\n", + " np.linspace(-110, 110, nx),\n", + " np.linspace(-45, 45, ny),\n", + " indexing=\"xy\",\n", + ")\n", + "th = np.deg2rad(30)\n", + "lon2d = xi * np.cos(th) - yi * np.sin(th)\n", + "lat2d = xi * np.sin(th) + yi * np.cos(th)\n", + "target = xr.Dataset(coords={\n", + " \"longitude\": ((\"ny\", \"nx\"), lon2d),\n", + " \"latitude\": ((\"ny\", \"nx\"), lat2d),\n", + "})\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "ax.plot(lon2d, lat2d, color=\"0.3\", lw=0.4)\n", + "ax.plot(lon2d.T, lat2d.T, color=\"0.3\", lw=0.4)\n", + "ax.set_title(\"curvilinear target (30° rotation)\")\n", + "ax.set_xlabel(\"longitude\")\n", + "ax.set_ylabel(\"latitude\")\n", + "ax.set_aspect(\"equal\")" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Regrid and plot\n", + "\n", + "One-shot form: `src.regrid.conservative_2d(target, x_coord=..., y_coord=...)`.\n", + "Constructing the class directly (as below) is equivalent and lets us reuse\n", + "the weight matrix across multiple applies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "rgr = ConservativeRegridder(\n", + " src, target, x_coord=\"longitude\", y_coord=\"latitude\",\n", + ")\n", + "regridded = rgr.regrid(src)\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "pc = ax.pcolormesh(lon2d, lat2d, regridded.values, cmap=\"RdBu_r\",\n", + " shading=\"auto\", vmin=-1, vmax=1)\n", + "fig.colorbar(pc, ax=ax, shrink=0.8)\n", + "ax.set_title(\"regridded onto rotated curvilinear grid\")\n", + "ax.set_xlabel(\"longitude\")\n", + "ax.set_ylabel(\"latitude\")\n", + "ax.set_aspect(\"equal\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/demos/demo_conservative_2d_regions.ipynb b/docs/notebooks/demos/demo_conservative_2d_regions.ipynb new file mode 100644 index 0000000..13ae692 --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_2d_regions.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Conservative 2D regrid — regions (grid → arbitrary polygons)\n", + "\n", + "Gridded data → arbitrary polygon regions (countries, watersheds, ocean\n", + "basins, protected areas) is a canonical xagg-style workflow. Because\n", + "regions aren't a grid at all, `.conservative` can't express it;\n", + "`ConservativeRegridder.from_polygons` takes a numpy array of shapely\n", + "polygons and produces one conservatively-averaged value per region.\n", + "\n", + "This notebook uses hand-built synthetic regions so it runs with zero\n", + "external downloads — swap in `geopandas.read_file(...)` for a real\n", + "shapefile and the rest is identical." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import shapely\n", + "import xarray as xr\n", + "from matplotlib.collections import PolyCollection\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder, polygons_from_coords" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Source — structured lat/lon field\n", + "\n", + "A smooth analytic field with recognizable geographic pattern: a warm\n", + "band following the equator modulated by an east/west tilt. Gives visibly\n", + "different regional means depending on region location." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "lat = np.linspace(-80, 80, 160) + 0.5\n", + "lon = np.linspace(-180, 180, 360, endpoint=False) + 0.5\n", + "Lo, La = np.meshgrid(lon, lat)\n", + "field = np.cos(np.deg2rad(La)) ** 2 + 0.4 * np.sin(np.deg2rad(Lo))\n", + "src = xr.DataArray(\n", + " field,\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat, \"longitude\": lon},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Regions — five hand-built polygons of varied shape\n", + "\n", + "Box, circle, rotated rectangle, L-shape, and a polygon with a hole —\n", + "exercising the range of geometry the regridder accepts. These overlap\n", + "in places, which is fine: each region gets its own independent\n", + "area-weighted mean." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "region_names = [\n", + " \"equatorial band\",\n", + " \"northern island\",\n", + " \"rotated block\",\n", + " \"L-shape\",\n", + " \"ring\",\n", + "]\n", + "region_polys = np.array([\n", + " shapely.box(-180, -15, 180, 15),\n", + " shapely.Point(-60, 50).buffer(18, quad_segs=12),\n", + " shapely.affinity.rotate(shapely.box(55, 27.5, 105, 52.5), 30),\n", + " shapely.unary_union([\n", + " shapely.box(-140, -60, -100, -20),\n", + " shapely.box(-140, -60, -60, -50),\n", + " ]),\n", + " shapely.Polygon(\n", + " shell=[(140, -40), (170, -40), (170, -5), (140, -5)],\n", + " holes=[[(148, -30), (162, -30), (162, -12), (148, -12)]],\n", + " ),\n", + "], dtype=object)\n", + "print(f\"{len(region_polys)} regions, areas (deg²): \"\n", + " f\"{[f'{shapely.area(p):.0f}' for p in region_polys]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Build the regridder and apply\n", + "\n", + "`from_polygons` takes source + target polygon arrays. We build the\n", + "source polygons from the grid's 1D coords via the `polygons_from_coords`\n", + "helper; the data gets flattened to match." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "src_polys = polygons_from_coords(lon, lat)\n", + "\n", + "rgr = ConservativeRegridder.from_polygons(\n", + " source_polygons=src_polys,\n", + " target_polygons=region_polys,\n", + " source_dim=\"src_cell\",\n", + " target_dim=\"region\",\n", + " target_coords=xr.Dataset(coords={\"region\": region_names}),\n", + ")\n", + "src_flat = xr.DataArray(src.values.ravel(), dims=(\"src_cell\",))\n", + "regional = rgr.regrid(src_flat)\n", + "regional" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## Source field + region outlines + regional means" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax_map, ax_bar) = plt.subplots(\n", + " 1, 2, figsize=(13, 4.5), gridspec_kw={\"width_ratios\": [1.6, 1]},\n", + ")\n", + "src.plot(ax=ax_map, cmap=\"viridis\", add_colorbar=True,\n", + " cbar_kwargs={\"shrink\": 0.7, \"label\": \"source field\"})\n", + "patches = [np.asarray(p.exterior.coords) for p in region_polys]\n", + "pc = PolyCollection(\n", + " patches, facecolor=\"none\", edgecolor=\"white\", lw=1.4,\n", + ")\n", + "ax_map.add_collection(pc)\n", + "for name, poly in zip(region_names, region_polys, strict=True):\n", + " c = poly.representative_point()\n", + " ax_map.annotate(name, (c.x, c.y), color=\"white\", fontsize=8,\n", + " ha=\"center\", va=\"center\")\n", + "ax_map.set_title(\"source + region outlines\")\n", + "ax_map.set_xlim(-180, 180)\n", + "ax_map.set_ylim(-80, 80)\n", + "\n", + "ax_bar.barh(region_names, regional.values, color=\"tab:blue\")\n", + "ax_bar.set_xlabel(\"area-weighted regional mean\")\n", + "ax_bar.invert_yaxis()\n", + "ax_bar.grid(axis=\"x\", alpha=0.3)\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "## Conservation check\n", + "\n", + "Area-weighted sum of regional means × region areas should match the\n", + "direct A·s computation from the regridder's internal area matrix." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "A = rgr._areas # sparse (n_regions, n_src)\n", + "tgt_area = np.ravel(A.sum(axis=1).todense())\n", + "src_cover = np.ravel(A.sum(axis=0).todense())\n", + "\n", + "direct = float((src.values.ravel() * src_cover).sum())\n", + "via_regrid = float((regional.values * tgt_area).sum())\n", + "print(f\"direct A·s : {direct:.6f}\")\n", + "print(f\"Σ regional_mean · a_dst : {via_regrid:.6f}\")\n", + "print(f\"relative error : {abs(direct - via_regrid) / abs(direct):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## Reuse: persist the regridder to netCDF\n", + "\n", + "The weight matrix is reusable. For any workflow that repeatedly\n", + "aggregates new source data onto the same region set, save once,\n", + "reload on subsequent runs to skip the intersection build." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "path = Path(tempfile.gettempdir()) / \"regions_regridder.nc\"\n", + "rgr.to_netcdf(path)\n", + "rgr2 = ConservativeRegridder.from_netcdf(path)\n", + "\n", + "other = xr.DataArray(\n", + " (np.sin(np.deg2rad(Lo)) ** 2).ravel(), dims=(\"src_cell\",),\n", + ")\n", + "print(\"regional means on a different field:\")\n", + "for n, v in zip(region_names, rgr2.regrid(other).values, strict=True):\n", + " print(f\" {n:20s}: {v:+.4f}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/demos/demo_conservative_2d_spherical.ipynb b/docs/notebooks/demos/demo_conservative_2d_spherical.ipynb new file mode 100644 index 0000000..fc1d2dd --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_2d_spherical.ipynb @@ -0,0 +1,327 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "56882b32", + "metadata": {}, + "source": [ + "# Conservative 2D regrid — spherical `s2` manifold\n", + "\n", + "Conservative regridding preserves integrals by computing overlap *areas* between source and target cells. On latitude/longitude grids, the geometry model used for those overlaps matters:\n", + "\n", + "- `\"planar\"` (default): fast planar shapely intersections in the raw coordinate plane.\n", + "- `\"cea\"`: cylindrical equal-area projection before planar intersection (good spherical area approximation for many workflows).\n", + "- `\"s2\"`: great-circle spherical geometry through [`spherely`](https://github.com/benbovy/spherely), which is the most geometrically faithful option.\n", + "\n", + "This notebook gives a compact, step-by-step view of the workflow:\n", + "\n", + "1. Build a simple analytic source field.\n", + "2. Visualize source and target grids.\n", + "3. Regrid with each manifold and compare outputs.\n", + "4. Verify conservation with a known integral.\n", + "5. Show where `s2` and `cea` diverge most.\n", + "6. Persist and reload an `s2` regridder for reuse." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31587931", + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import xarray as xr\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder" + ] + }, + { + "cell_type": "markdown", + "id": "5385e5c7", + "metadata": {}, + "source": [ + "## Source field and target grid\n", + "\n", + "We use a smooth test field on a 1° grid,\n", + "\n", + "\\[\n", + "f(\\phi, \\lambda) = \\cos^2(\\phi)\\,(1 + 0.15\\sin(3\\lambda))\n", + "\\]\n", + "\n", + "where `\\phi` is latitude and `\\lambda` is longitude. This keeps the latitudinal structure intuitive while introducing longitudinal variability so map comparisons are visually informative." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a2c9bae", + "metadata": {}, + "outputs": [], + "source": [ + "lat = np.linspace(-89.5, 89.5, 180)\n", + "lon = np.linspace(-179.5, 179.5, 360)\n", + "\n", + "src = xr.DataArray(\n", + " (np.cos(np.deg2rad(lat)) ** 2)[:, None]\n", + " * (1.0 + 0.15 * np.sin(3 * np.deg2rad(lon))[None, :]),\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat, \"longitude\": lon},\n", + " name=\"f\",\n", + ")\n", + "\n", + "target = xr.Dataset(\n", + " coords={\n", + " \"latitude\": np.linspace(-89, 89, 90),\n", + " \"longitude\": np.linspace(-179, 179, 180),\n", + " }\n", + ")\n", + "\n", + "src" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad1556c8", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12, 4), constrained_layout=True)\n", + "\n", + "src.plot(ax=axes[0], cmap=\"viridis\", cbar_kwargs={\"shrink\": 0.8})\n", + "axes[0].set_title(\"Source field on 1° grid\")\n", + "\n", + "axes[1].set_title(\"Target grid (2° centers)\")\n", + "axes[1].set_xlabel(\"longitude\")\n", + "axes[1].set_ylabel(\"latitude\")\n", + "axes[1].set_xlim(-180, 180)\n", + "axes[1].set_ylim(-90, 90)\n", + "axes[1].scatter(\n", + " target.longitude.values,\n", + " np.zeros_like(target.longitude.values),\n", + " s=8,\n", + " alpha=0.4,\n", + " label=\"lon centers @ equator\",\n", + ")\n", + "axes[1].scatter(\n", + " np.zeros_like(target.latitude.values),\n", + " target.latitude.values,\n", + " s=8,\n", + " alpha=0.4,\n", + " label=\"lat centers @ prime meridian\",\n", + ")\n", + "axes[1].grid(alpha=0.3)\n", + "axes[1].legend(loc=\"lower left\", fontsize=8)" + ] + }, + { + "cell_type": "markdown", + "id": "2e1c76fc", + "metadata": {}, + "source": [ + "## Regrid with each manifold\n", + "\n", + "All three runs use the same source, target, and conservative operator interface; only the manifold changes.\n", + "\n", + "To make the transformation visually obvious, we overlay coarse target-grid lines on each output map." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da88f365", + "metadata": {}, + "outputs": [], + "source": [ + "out_planar = src.regrid.conservative_2d(\n", + " target, x_coord=\"longitude\", y_coord=\"latitude\", manifold=\"planar\"\n", + ")\n", + "out_cea = src.regrid.conservative_2d(\n", + " target, x_coord=\"longitude\", y_coord=\"latitude\", manifold=\"cea\"\n", + ")\n", + "out_s2 = src.regrid.conservative_2d(\n", + " target, x_coord=\"longitude\", y_coord=\"latitude\", manifold=\"s2\"\n", + ")\n", + "\n", + "# Coarse visual guide to the target grid structure.\n", + "lat_lines = target.latitude.values[::8]\n", + "lon_lines = target.longitude.values[::8]\n", + "\n", + "vmin = float(src.min())\n", + "vmax = float(src.max())\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(15, 4), constrained_layout=True)\n", + "for ax, title, arr in [\n", + " (axes[0], \"planar\", out_planar),\n", + " (axes[1], \"cea\", out_cea),\n", + " (axes[2], \"s2\", out_s2),\n", + "]:\n", + " arr.plot(ax=ax, cmap=\"viridis\", vmin=vmin, vmax=vmax, add_colorbar=(ax is axes[2]))\n", + " for y in lat_lines:\n", + " ax.axhline(float(y), color=\"white\", lw=0.25, alpha=0.45)\n", + " for x in lon_lines:\n", + " ax.axvline(float(x), color=\"white\", lw=0.25, alpha=0.45)\n", + " ax.set_title(title)\n", + " ax.set_xlim(-180, 180)\n", + " ax.set_ylim(-90, 90)" + ] + }, + { + "cell_type": "markdown", + "id": "0690370a", + "metadata": {}, + "source": [ + "## Conservation check against an analytic integral\n", + "\n", + "For this notebook's check field, the dominant term is `\\cos^2(lat)`, whose integral over the unit sphere is `8π/3`. We integrate each regridded output against true spherical target-cell areas and compare errors. This is a practical sanity check that conservative weighting behaves as expected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b179e5ff", + "metadata": {}, + "outputs": [], + "source": [ + "TRUE = 8 * np.pi / 3\n", + "\n", + "\n", + "def sph_cell_areas(lat_vals, lon_vals):\n", + " lat_r = np.deg2rad(lat_vals)\n", + " dlat = np.gradient(lat_r)\n", + " dlon = np.gradient(np.deg2rad(lon_vals))\n", + " return (np.sin(lat_r + dlat / 2) - np.sin(lat_r - dlat / 2))[:, None] * dlon[None, :]\n", + "\n", + "\n", + "a_tgt = sph_cell_areas(target.latitude.values, target.longitude.values)\n", + "\n", + "print(f\"{'manifold':>8} {'integral':>11} {'|error|':>10}\")\n", + "for manifold, arr in {\n", + " \"planar\": out_planar.values,\n", + " \"cea\": out_cea.values,\n", + " \"s2\": out_s2.values,\n", + "}.items():\n", + " valid = np.isfinite(arr)\n", + " integral = float((arr[valid] * a_tgt[valid]).sum())\n", + " print(f\"{manifold:>8} {integral:>11.6f} {abs(integral - TRUE):>10.2e}\")\n", + "print(f\"{'truth':>8} {TRUE:>11.6f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a5cf0a55", + "metadata": {}, + "source": [ + "## Where `s2` and `cea` differ most\n", + "\n", + "At fine resolutions, `cea` and `s2` are often close. The gap becomes easier to see on coarse grids, where great-circle vs projected straight-edge assumptions diverge more strongly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ef50b58", + "metadata": {}, + "outputs": [], + "source": [ + "coarse_lat = np.linspace(-85, 85, 35)\n", + "coarse_lon = np.linspace(-175, 175, 70)\n", + "coarse_src = xr.DataArray(\n", + " np.cos(np.deg2rad(coarse_lat))[:, None] ** 2\n", + " * np.sin(np.deg2rad(coarse_lon))[None, :],\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": coarse_lat, \"longitude\": coarse_lon},\n", + ")\n", + "coarse_tgt = xr.Dataset(\n", + " coords={\n", + " \"latitude\": np.linspace(-82.5, 82.5, 12),\n", + " \"longitude\": np.linspace(-172.5, 172.5, 24),\n", + " }\n", + ")\n", + "\n", + "coarse_s2 = coarse_src.regrid.conservative_2d(\n", + " coarse_tgt, x_coord=\"longitude\", y_coord=\"latitude\", manifold=\"s2\"\n", + ")\n", + "coarse_cea = coarse_src.regrid.conservative_2d(\n", + " coarse_tgt, x_coord=\"longitude\", y_coord=\"latitude\", manifold=\"cea\"\n", + ")\n", + "diff = coarse_s2 - coarse_cea\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(14, 3.3), constrained_layout=True)\n", + "coarse_src.plot(ax=axes[0], cmap=\"coolwarm\")\n", + "axes[0].set_title(\"Coarse source field\")\n", + "coarse_s2.plot(ax=axes[1], cmap=\"coolwarm\")\n", + "axes[1].set_title(\"Regridded with s2\")\n", + "diff.plot(ax=axes[2], cmap=\"RdBu_r\", center=0)\n", + "axes[2].set_title(\"s2 − cea difference\")\n", + "for ax in axes:\n", + " ax.set_xlim(-180, 180)\n", + " ax.set_ylim(-90, 90)\n", + " ax.grid(alpha=0.25, lw=0.4)\n", + "\n", + "print(f\"max |s2 − cea|: {float(np.nanmax(np.abs(diff.values))):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1a1504ff", + "metadata": {}, + "source": [ + "## Reuse: persist the `s2` regridder\n", + "\n", + "The expensive part is building geometry and overlaps. Save once, reload, and apply repeatedly to fields on the same source/target grids." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "854b57c8", + "metadata": {}, + "outputs": [], + "source": [ + "rgr = ConservativeRegridder(\n", + " src,\n", + " target,\n", + " x_coord=\"longitude\",\n", + " y_coord=\"latitude\",\n", + " manifold=\"s2\",\n", + ")\n", + "path = Path(tempfile.gettempdir()) / \"s2_regridder.nc\"\n", + "rgr.to_netcdf(path)\n", + "rgr2 = ConservativeRegridder.from_netcdf(path)\n", + "\n", + "print(f\"saved {path.stat().st_size / 1024:.0f} KB, manifold={rgr2.manifold!r}\")\n", + "print(\n", + " \"bit-identical forward:\",\n", + " np.array_equal(rgr.regrid(src).values, rgr2.regrid(src).values),\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/demos/demo_conservative_2d_unstructured.ipynb b/docs/notebooks/demos/demo_conservative_2d_unstructured.ipynb new file mode 100644 index 0000000..e7df2ce --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_2d_unstructured.ipynb @@ -0,0 +1,240 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Conservative 2D regrid — unstructured mesh + save/load\n", + "\n", + "**Conservative regridding** resamples a gridded field while preserving its\n", + "area-weighted integral — the right tool for fluxes and intensive\n", + "quantities, where bilinear or nearest-neighbor interpolation would bias\n", + "the total.\n", + "\n", + "**Unstructured meshes** — ICON triangles, MPAS hexagons, finite-element\n", + "models, generic Voronoi tessellations — are the natural geometry for\n", + "simulations that need adaptive resolution (denser cells over land, coarser\n", + "over open ocean). They have no `(y, x)` index structure: every cell is\n", + "just an arbitrary polygon. So they're never 1D-separable, and the fast\n", + "`.regrid.conservative` accessor doesn't apply.\n", + "\n", + "`ConservativeRegridder.from_polygons` takes a flat 1D array of shapely\n", + "polygons as source and/or target. The same machinery handles\n", + "structured→unstructured, unstructured→structured, and unstructured→\n", + "unstructured.\n", + "\n", + "**In this notebook.**\n", + "\n", + "1. Build a synthetic Voronoi mesh as a stand-in for a real ICON/MPAS dataset.\n", + "2. Regrid a smooth analytic field from a structured lat/lon source onto\n", + " the mesh.\n", + "3. **Persist the regridder to disk** — for a fixed source/target pair the\n", + " weight matrix is the same forever, so saving it lets a long-running\n", + " pipeline (or a follow-up notebook) skip the polygon-intersection build\n", + " on restart." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import shapely\n", + "import xarray as xr\n", + "from matplotlib.collections import PolyCollection\n", + "from scipy.spatial import Voronoi\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder, polygons_from_coords" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Build a Voronoi mesh\n", + "\n", + "Real pipelines load a pre-built mesh (UGRID, ICON, MPAS); for a\n", + "self-contained demo we synthesize one — jitter a regular grid of generator\n", + "points, take the Voronoi tessellation, and clip to the bounding box. The\n", + "construction details aren't the point: `from_polygons` only needs a 1D\n", + "array of shapely polygons however we get them." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "def voronoi_mesh(n_points, bbox, seed=0):\n", + " rng = np.random.default_rng(seed)\n", + " x0, y0, x1, y1 = bbox\n", + " side = int(np.sqrt(n_points))\n", + " xs, ys = np.linspace(x0, x1, side), np.linspace(y0, y1, side)\n", + " pts = np.column_stack([np.repeat(xs, side), np.tile(ys, side)])\n", + " pts += rng.normal(scale=(x1 - x0) / side * 0.25, size=pts.shape)\n", + " halo = np.array([\n", + " [2*x0 - x1, 2*y0 - y1], [2*x1 - x0, 2*y0 - y1],\n", + " [2*x0 - x1, 2*y1 - y0], [2*x1 - x0, 2*y1 - y0],\n", + " ])\n", + " vor = Voronoi(np.concatenate([pts, halo]))\n", + " clip = shapely.box(x0, y0, x1, y1)\n", + " polys = []\n", + " for i in range(len(pts)):\n", + " r = vor.regions[vor.point_region[i]]\n", + " if not r or -1 in r:\n", + " continue\n", + " p = shapely.intersection(shapely.Polygon(vor.vertices[r]), clip)\n", + " if p.is_empty or p.geom_type != \"Polygon\":\n", + " continue\n", + " polys.append(p)\n", + " return np.array(polys, dtype=object)\n", + "\n", + "bbox = (-120, -50, 120, 50)\n", + "mesh_polys = voronoi_mesh(n_points=400, bbox=bbox)\n", + "print(f\"{len(mesh_polys)} mesh cells\")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Structured lat/lon source\n", + "\n", + "A smooth `sin(2λ)·cos(3φ)` field on a 1° rectilinear grid — wavy enough\n", + "that the regridded mesh values are visually distinct, smooth enough that\n", + "no individual mesh cell aliases the pattern." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "lat_s = np.linspace(-50, 50, 100, endpoint=False) + 0.5\n", + "lon_s = np.linspace(-120, 120, 240, endpoint=False) + 0.5\n", + "Lo, La = np.meshgrid(lon_s, lat_s)\n", + "src = xr.DataArray(\n", + " np.sin(np.deg2rad(Lo) * 2) * np.cos(np.deg2rad(La) * 3),\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat_s, \"longitude\": lon_s},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Regrid onto the mesh\n", + "\n", + "`from_polygons` takes flat 1D arrays of source and target polygons. Source\n", + "polygons come from the 1D grid coords via `polygons_from_coords` (one\n", + "rectangle per cell, built from coordinate midpoints); target polygons are\n", + "the Voronoi mesh cells. Source data has to be flattened to a single\n", + "`src_cell` dimension to match the flat polygon array." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "rgr = ConservativeRegridder.from_polygons(\n", + " source_polygons=polygons_from_coords(lon_s, lat_s),\n", + " target_polygons=mesh_polys,\n", + " source_dim=\"src_cell\",\n", + " target_dim=\"cell\",\n", + ")\n", + "print(rgr)\n", + "\n", + "src_flat = xr.DataArray(src.values.ravel(), dims=(\"src_cell\",))\n", + "mesh_vals = rgr.regrid(src_flat)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "patches = [np.asarray(p.exterior.coords) for p in mesh_polys]\n", + "pc = PolyCollection(patches, array=mesh_vals.values, cmap=\"RdBu_r\",\n", + " edgecolor=\"0.4\", lw=0.3, clim=(-1, 1))\n", + "ax.add_collection(pc)\n", + "ax.set_xlim(bbox[0], bbox[2])\n", + "ax.set_ylim(bbox[1], bbox[3])\n", + "ax.set_aspect(\"equal\")\n", + "fig.colorbar(pc, ax=ax, shrink=0.8)\n", + "ax.set_title(f\"regridded onto {len(mesh_polys)}-cell Voronoi mesh\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Persist the regridder\n", + "\n", + "The weight matrix is purely a function of source and target geometry —\n", + "not of the data — so for a fixed source/target pair it never needs to\n", + "change. `to_netcdf` writes it (along with shape and version metadata) to\n", + "a small NetCDF; `from_netcdf` rebuilds the regridder. A reload-then-apply\n", + "gives bit-identical output to the original, so a long-running pipeline\n", + "can skip the (expensive) polygon-intersection build on restart." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "path = Path(tempfile.gettempdir()) / \"mesh_regridder.nc\"\n", + "rgr.to_netcdf(path)\n", + "rgr2 = ConservativeRegridder.from_netcdf(path)\n", + "same = np.array_equal(rgr.regrid(src_flat).values, rgr2.regrid(src_flat).values)\n", + "print(f\"saved {path.stat().st_size / 1024:.1f} KB; reload bit-identical: {same}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/demos/demo_conservative_polygon_basics.ipynb b/docs/notebooks/demos/demo_conservative_polygon_basics.ipynb new file mode 100644 index 0000000..6c9a564 --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_polygon_basics.ipynb @@ -0,0 +1,229 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Polygon conservative regridder — basics\n", + "\n", + "`ConservativeRegridder` is a polygon-intersection conservative regridder that complements\n", + "the existing `.conservative` method. Where `.conservative` uses fast axis-factored 1D overlap\n", + "(rectilinear grids only), this path computes explicit 2D cell-polygon intersections — handling\n", + "curvilinear and unstructured meshes and supporting a spherical-area mode for global lat/lon\n", + "grids, at the cost of slightly higher build time.\n", + "\n", + "This notebook shows the basic workflow on an analytic field with a known sphere integral." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import xarray_regrid # registers the `.regrid` accessor\n", + "from xarray_regrid import ConservativeRegridder\n", + "\n", + "np.set_printoptions(precision=5, suppress=True)" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Synthetic source field\n", + "\n", + "We use $f(\\lambda, \\phi) = \\cos^2(\\phi)$ on a 1° global grid. Its true integral over the unit\n", + "sphere is $8\\pi/3$, which gives us a check on the regridder." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "lat = np.linspace(-89.5, 89.5, 180)\n", + "lon = np.linspace(-179.5, 179.5, 360)\n", + "field = (np.cos(np.deg2rad(lat))**2)[:, None] * np.ones(lon.size)[None, :]\n", + "src = xr.DataArray(\n", + " field,\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat, \"longitude\": lon},\n", + " name=\"cos2_lat\",\n", + ")\n", + "src" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "src.plot(figsize=(8, 3.5), cmap=\"viridis\")\n", + "plt.title(\"source: cos²(lat) on a 1° grid\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## Target grid\n", + "\n", + "A coarser 3° grid. We'll build an empty `xr.Dataset` with just the target coords;\n", + "`xarray-regrid` identifies the grid from `latitude` and `longitude`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "target = xr.Dataset(coords={\n", + " \"latitude\": np.linspace(-88.5, 88.5, 60),\n", + " \"longitude\": np.linspace(-178.5, 178.5, 120),\n", + "})" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Regrid via the `.regrid.conservative_polygon` accessor\n", + "\n", + "`spherical=True` applies an analytic Lambert cylindrical equal-area projection\n", + "to the cell edges before intersecting — correct spherical area weights at the\n", + "same cost as the planar fast path. Leave it off when your coords aren't geographic degrees." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "regridded = src.regrid.conservative_polygon(\n", + " target, x_coord=\"longitude\", y_coord=\"latitude\", spherical=True\n", + ")\n", + "regridded" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "regridded.plot(figsize=(8, 3.5), cmap=\"viridis\")\n", + "plt.title(\"regridded: 3° target\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "## Mass-conservation diagnostic\n", + "\n", + "A correct conservative regrid preserves the area-weighted integral (on the sphere).\n", + "We compute true spherical cell areas for each grid and compare the integrals." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "def sph_integral(da):\n", + " lat_r = np.deg2rad(da.latitude.values)\n", + " dlat = np.gradient(lat_r)\n", + " dlon = np.gradient(np.deg2rad(da.longitude.values))\n", + " area = (np.sin(lat_r + dlat/2) - np.sin(lat_r - dlat/2))[:, None] * dlon[None, :]\n", + " return float((da.values * area).sum())\n", + "\n", + "true_val = 8 * np.pi / 3\n", + "src_sum = sph_integral(src)\n", + "out_sum = sph_integral(regridded)\n", + "print(f\"true sphere integral : {true_val:.6f}\")\n", + "print(f\"source integral : {src_sum:.6f} (err {src_sum - true_val:+.2e})\")\n", + "print(f\"regridded integral : {out_sum:.6f} (err {out_sum - true_val:+.2e})\")" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## Spherical vs planar vs the axis-factored path\n", + "\n", + "`spherical=False` uses raw lat/lon planar geometry — poor near the poles.\n", + "The existing `.conservative` method applies an analytic sin-weighting to the factored\n", + "1D overlap. On lat/lon grids, `spherical=True` reproduces that accuracy while also\n", + "working on curvilinear and unstructured targets the factored path can't express." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "def err(fn):\n", + " return abs(sph_integral(fn()) - true_val)\n", + "\n", + "results = {\n", + " \"polygon, spherical=True \": err(lambda: src.regrid.conservative_polygon(\n", + " target, x_coord=\"longitude\", y_coord=\"latitude\", spherical=True)),\n", + " \"polygon, spherical=False\": err(lambda: src.regrid.conservative_polygon(\n", + " target, x_coord=\"longitude\", y_coord=\"latitude\", spherical=False)),\n", + " \"factored (.conservative)\": err(lambda: src.regrid.conservative(\n", + " target, latitude_coord=\"latitude\")),\n", + "}\n", + "for k, v in results.items():\n", + " print(f\"{k}: |error| = {v:.2e}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/demos/demo_conservative_polygon_curvilinear.ipynb b/docs/notebooks/demos/demo_conservative_polygon_curvilinear.ipynb new file mode 100644 index 0000000..3ae6e5f --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_polygon_curvilinear.ipynb @@ -0,0 +1,213 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Polygon conservative regridder — curvilinear target\n", + "\n", + "Curvilinear grids have 2D `lat(y,x)` / `lon(y,x)` coordinate arrays — common in\n", + "ocean models (ORCA, tripolar) and regional forecasts on rotated grids. The existing\n", + "`.conservative` method is strictly rectilinear, so `ConservativeRegridder` is the\n", + "available tool here.\n", + "\n", + "This notebook regrids a regular lat/lon source onto a 30°-rotated curvilinear target." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Source — regular 1° lat/lon\n", + "\n", + "A simple two-bump analytic field makes the geometry easy to read off the plot." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "lat = np.linspace(-60, 60, 121)\n", + "lon = np.linspace(-120, 120, 241)\n", + "Lo, La = np.meshgrid(lon, lat)\n", + "field = (\n", + " np.exp(-((Lo - 40)**2 + (La - 20)**2) / 500)\n", + " - np.exp(-((Lo + 60)**2 + (La + 15)**2) / 400)\n", + ")\n", + "src = xr.DataArray(\n", + " field,\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat, \"longitude\": lon},\n", + " name=\"bumps\",\n", + ")\n", + "src.plot(figsize=(8, 3.5), cmap=\"RdBu_r\", center=0)\n", + "plt.title(\"source: analytic two-bump field\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Target — rotated curvilinear grid\n", + "\n", + "Coordinates ride on a `(ny, nx)` mesh rather than a 1D lat/lon. We stash them\n", + "as 2D coordinate variables on an `xr.Dataset`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "ny_t, nx_t = 30, 50\n", + "xi, yi = np.meshgrid(\n", + " np.linspace(-110, 110, nx_t),\n", + " np.linspace(-45, 45, ny_t),\n", + " indexing=\"xy\",\n", + ")\n", + "theta = np.deg2rad(30)\n", + "lon2d = xi * np.cos(theta) - yi * np.sin(theta)\n", + "lat2d = xi * np.sin(theta) + yi * np.cos(theta)\n", + "\n", + "target = xr.Dataset(\n", + " coords={\n", + " \"longitude\": ((\"ny\", \"nx\"), lon2d),\n", + " \"latitude\": ((\"ny\", \"nx\"), lat2d),\n", + " }\n", + ")\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "ax.plot(lon2d, lat2d, color=\"0.3\", lw=0.4)\n", + "ax.plot(lon2d.T, lat2d.T, color=\"0.3\", lw=0.4)\n", + "ax.set_title(\"curvilinear target (30° rotation)\")\n", + "ax.set_xlabel(\"longitude\"); ax.set_ylabel(\"latitude\")\n", + "ax.set_aspect(\"equal\")" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Regrid\n", + "\n", + "The accessor detects that `latitude` / `longitude` are 2D and routes through\n", + "the curvilinear path (threaded GEOS polygon clipping under the hood). On the\n", + "reusable class you'd construct once via `ConservativeRegridder(src, target, ...)`\n", + "and apply to many fields." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "regridded = src.regrid.conservative_polygon(\n", + " target, x_coord=\"longitude\", y_coord=\"latitude\"\n", + ")\n", + "regridded" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "pc = ax.pcolormesh(lon2d, lat2d, regridded.values, cmap=\"RdBu_r\", shading=\"auto\",\n", + " vmin=-1, vmax=1)\n", + "fig.colorbar(pc, ax=ax, shrink=0.8)\n", + "ax.set_title(\"regridded onto rotated curvilinear grid\")\n", + "ax.set_xlabel(\"longitude\"); ax.set_ylabel(\"latitude\")\n", + "ax.set_aspect(\"equal\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Conservation check\n", + "\n", + "The regridder's internal intersection-area matrix `A[i, j]` gives the exact\n", + "mass that each source cell contributes to each target cell. Summing those\n", + "contributions against the source field is the ground-truth mass within the\n", + "overlap region. Multiplying the output by each target cell's covered area\n", + "should match.\n", + "\n", + "Target cells that fall outside the source are NaN in the output, which is\n", + "why we mask them before summing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "rgr = ConservativeRegridder(src, target, x_coord=\"longitude\", y_coord=\"latitude\")\n", + "A = rgr._areas # sparse (n_dst, n_src)\n", + "target_covered = A.sum(axis=1).todense().reshape(regridded.shape)\n", + "src_covered = A.sum(axis=0).todense()\n", + "\n", + "valid = np.isfinite(regridded.values)\n", + "direct_mass = float((src.values.ravel() * src_covered).sum())\n", + "regridded_mass = float((regridded.values[valid] * target_covered[valid]).sum())\n", + "print(f\"direct (A · s): {direct_mass:.6f}\")\n", + "print(f\"regridded (out · a_dst): {regridded_mass:.6f}\")\n", + "print(f\"relative difference: {abs(direct_mass - regridded_mass) / max(abs(direct_mass), 1e-12):.2e}\")\n", + "print(f\"covered target fraction: {valid.mean():.2%}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/demos/demo_conservative_polygon_unstructured.ipynb b/docs/notebooks/demos/demo_conservative_polygon_unstructured.ipynb new file mode 100644 index 0000000..92e2b85 --- /dev/null +++ b/docs/notebooks/demos/demo_conservative_polygon_unstructured.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Polygon conservative regridder — unstructured mesh + save/load\n", + "\n", + "When cells are arbitrary polygons (ICON triangles, MPAS hexagons, country shapes for\n", + "regional aggregation) use `ConservativeRegridder.from_polygons`. The shapely STRtree\n", + "handles any polygon layout.\n", + "\n", + "This notebook builds a synthetic Voronoi hex-like mesh with scipy, regrids a structured\n", + "source onto it, then persists the regridder so subsequent runs skip the weight build." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.collections import PolyCollection\n", + "from scipy.spatial import Voronoi\n", + "import shapely\n", + "\n", + "import xarray_regrid # noqa: F401\n", + "from xarray_regrid import ConservativeRegridder, polygons_from_coords" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Build a hex-like Voronoi mesh\n", + "\n", + "Jittered grid points → Voronoi → clip to the region of interest. Real workflows\n", + "would load a pre-built mesh (UGRID, ICON, etc.); the point here is that\n", + "`from_polygons` needs only a 1D array of shapely polygons." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "def voronoi_mesh(n_points, bbox, seed=0):\n", + " rng = np.random.default_rng(seed)\n", + " x0, y0, x1, y1 = bbox\n", + " side = int(np.sqrt(n_points))\n", + " xs = np.linspace(x0, x1, side); ys = np.linspace(y0, y1, side)\n", + " pts = np.column_stack([np.repeat(xs, side), np.tile(ys, side)])\n", + " pts += rng.normal(scale=(x1 - x0) / side * 0.25, size=pts.shape)\n", + " halo = np.array([\n", + " [2*x0 - x1, 2*y0 - y1], [2*x1 - x0, 2*y0 - y1],\n", + " [2*x0 - x1, 2*y1 - y0], [2*x1 - x0, 2*y1 - y0],\n", + " ])\n", + " vor = Voronoi(np.concatenate([pts, halo]))\n", + " clip = shapely.box(x0, y0, x1, y1)\n", + " polys, centers = [], []\n", + " for i in range(len(pts)):\n", + " r = vor.regions[vor.point_region[i]]\n", + " if not r or -1 in r:\n", + " continue\n", + " p = shapely.intersection(shapely.Polygon(vor.vertices[r]), clip)\n", + " if p.is_empty or p.geom_type != \"Polygon\":\n", + " continue\n", + " polys.append(p); centers.append(pts[i])\n", + " return np.array(polys, dtype=object), np.array(centers)\n", + "\n", + "bbox = (-120, -50, 120, 50)\n", + "mesh_polys, mesh_centers = voronoi_mesh(n_points=400, bbox=bbox)\n", + "print(f\"{len(mesh_polys)} cells in the mesh\")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Source — structured lat/lon field" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "lat_s = np.linspace(-50, 50, 100, endpoint=False) + 0.5\n", + "lon_s = np.linspace(-120, 120, 240, endpoint=False) + 0.5\n", + "Lo, La = np.meshgrid(lon_s, lat_s)\n", + "field = np.sin(np.deg2rad(Lo) * 2) * np.cos(np.deg2rad(La) * 3)\n", + "src = xr.DataArray(\n", + " field,\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords={\"latitude\": lat_s, \"longitude\": lon_s},\n", + " name=\"field\",\n", + ")\n", + "src.plot(figsize=(8, 3.3), cmap=\"RdBu_r\", center=0)\n", + "plt.title(\"structured source\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Regrid structured → mesh with `from_polygons`\n", + "\n", + "We convert the structured grid into a flat polygon list via `polygons_from_coords`\n", + "(row-major, y-slow / x-fast) and flatten the data the same way. Output is a 1D\n", + "array indexed by mesh cell." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "grid_polys = polygons_from_coords(lon_s, lat_s)\n", + "\n", + "rgr = ConservativeRegridder.from_polygons(\n", + " grid_polys, mesh_polys,\n", + " source_dim=\"src_cell\",\n", + " target_dim=\"cell\",\n", + ")\n", + "print(rgr)\n", + "\n", + "src_flat = xr.DataArray(src.values.ravel(), dims=(\"src_cell\",))\n", + "mesh_vals = rgr.regrid(src_flat)\n", + "mesh_vals" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## Plot the regridded field on the mesh" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "patches = [np.asarray(p.exterior.coords) for p in mesh_polys]\n", + "pc = PolyCollection(patches, array=mesh_vals.values, cmap=\"RdBu_r\",\n", + " edgecolor=\"0.4\", lw=0.3, clim=(-1, 1))\n", + "ax.add_collection(pc)\n", + "ax.set_xlim(bbox[0], bbox[2]); ax.set_ylim(bbox[1], bbox[3])\n", + "ax.set_aspect(\"equal\")\n", + "fig.colorbar(pc, ax=ax, shrink=0.8)\n", + "ax.set_title(f\"regridded onto {len(mesh_polys)}-cell Voronoi mesh\")" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "## Persist the regridder to netCDF\n", + "\n", + "For a fixed source / target pair the weight matrix is the same forever. Persisting\n", + "it to disk lets long-running pipelines skip the (expensive) build on restart." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "path = Path(tempfile.gettempdir()) / \"mesh_regridder.nc\"\n", + "rgr.to_netcdf(path)\n", + "print(f\"wrote {path} ({path.stat().st_size / 1024:.1f} KB)\")\n", + "\n", + "# Inspect on-disk metadata (useful for provenance):\n", + "with xr.open_dataset(path) as weights:\n", + " for k, v in weights.attrs.items():\n", + " print(f\" {k}: {v}\")" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## Reload and apply to new data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "rgr2 = ConservativeRegridder.from_netcdf(path)\n", + "\n", + "# Apply to a new structured field of the same shape:\n", + "new_src = xr.DataArray(\n", + " np.cos(np.deg2rad(Lo)) * np.sin(np.deg2rad(La) * 2),\n", + " dims=(\"latitude\", \"longitude\"),\n", + " coords=src.coords,\n", + ")\n", + "out = rgr2.regrid(xr.DataArray(new_src.values.ravel(), dims=(\"src_cell\",)))\n", + "\n", + "# Sanity check: reloaded regridder produces identical output to the original.\n", + "reference = rgr.regrid(xr.DataArray(new_src.values.ravel(), dims=(\"src_cell\",)))\n", + "print(f\"max diff reloaded vs original: {float(np.abs(out.values - reference.values).max()):.2e}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 50ecd81..0bcd27f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,23 @@ accel = [ "opt-einsum", "dask[distributed]", ] +conservative-2d = [ + # For conservative regridding on grids that aren't 1D-separable + # (curvilinear, unstructured). Uses shapely for 2D polygon intersection. + "shapely>=2.0", + # Needed by ConservativeRegridder.to_netcdf / from_netcdf (group support). + "h5netcdf", +] +spherical = [ + # Builds on conservative-2d: adds great-circle geometry via spherely (S2) + # for the `manifold="s2"` mode of ConservativeRegridder. Install with either + # `pip install xarray-regrid[spherical]` (pulls the 2d stack automatically) + # or `pip install xarray-regrid[conservative-2d,spherical]` (explicit). + "xarray-regrid[conservative-2d]", + # Ships pre-built wheels on PyPI for macOS / Linux x86_64 / Windows + # (Python 3.10–3.14). + "spherely>=0.1.1", +] benchmarking = [ "matplotlib", "zarr", @@ -53,6 +70,15 @@ benchmarking = [ "pooch", "cftime", # required for decode time of test netCDF files ] +notebooks = [ + # Extra dependencies needed to run docs/notebooks/demos/*.ipynb. + # Combine with [conservative-2d] and [accel] for the full demo set, + # e.g.: pip install "xarray-regrid[notebooks,conservative-2d,accel]" + "matplotlib", + "pooch", # xr.tutorial.open_dataset backend + "geopandas", # polygon regions demo + "geodatasets", # ships the ncovr US county boundaries +] dev = [ "hatch", "ruff", @@ -74,7 +100,7 @@ docs = [ # Required for ReadTheDocs path = "src/xarray_regrid/__init__.py" [tool.hatch.envs.default] -features = ["accel", "dev", "benchmarking"] +features = ["accel", "dev", "benchmarking", "conservative-2d"] [tool.hatch.envs.default.scripts] lint = [ @@ -185,3 +211,14 @@ warn_return_any = true warn_unused_ignores = true show_error_codes = true exclude = ["tests/*", "docs"] + +# shapely and sparse are untyped; the conservative_2d module bridges to them, +# so treating every signature that mentions sparse.COO as an error is noise. +[[tool.mypy.overrides]] +module = "xarray_regrid.methods.conservative_2d" +disallow_any_unimported = false +warn_return_any = false + +[[tool.mypy.overrides]] +module = ["shapely", "shapely.*", "sparse", "spherely"] +ignore_missing_imports = true diff --git a/src/xarray_regrid/__init__.py b/src/xarray_regrid/__init__.py index 0dcaaec..bf8daa7 100644 --- a/src/xarray_regrid/__init__.py +++ b/src/xarray_regrid/__init__.py @@ -1,12 +1,22 @@ from xarray_regrid import methods +from xarray_regrid.methods.conservative_2d import ( + ConservativeRegridder, + polygons_from_coords, +) +from xarray_regrid.methods.conservative_polygon import RegridderMetadata +from xarray_regrid.methods._conservative_2d_spec import RegridSpec from xarray_regrid.regrid import Regridder from xarray_regrid.utils import Grid, create_regridding_dataset __all__ = [ + "ConservativeRegridder", "Grid", + "RegridSpec", + "RegridderMetadata", "Regridder", "create_regridding_dataset", "methods", + "polygons_from_coords", ] __version__ = "0.4.2" diff --git a/src/xarray_regrid/methods/_conservative_2d_serialization.py b/src/xarray_regrid/methods/_conservative_2d_serialization.py new file mode 100644 index 0000000..989d0e4 --- /dev/null +++ b/src/xarray_regrid/methods/_conservative_2d_serialization.py @@ -0,0 +1,115 @@ +from datetime import datetime, timezone +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any + +import numpy as np +import xarray as xr + +from xarray_regrid.methods._conservative_2d_spec import RegridSpec + +try: + import sparse + + _HAS_SPARSE = True +except ImportError: # pragma: no cover + sparse = None + _HAS_SPARSE = False + +# Bump on breaking change to the on-disk format in ConservativeRegridder.to_netcdf. +_SCHEMA_VERSION = 1 + + +def _package_version() -> str: + try: + return version("xarray-regrid") + except PackageNotFoundError: + return "unknown" + + +def _coo_components( + weights: "sparse.COO | np.ndarray", +) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple[int, int]]: + if _HAS_SPARSE and isinstance(weights, sparse.COO): + coords = np.asarray(weights.coords) + return ( + coords[0].astype(np.int64, copy=False), + coords[1].astype(np.int64, copy=False), + np.asarray(weights.data), + weights.shape, + ) + arr = np.asarray(weights) + rows, cols = np.nonzero(arr) + return ( + rows.astype(np.int64, copy=False), + cols.astype(np.int64, copy=False), + arr[rows, cols], + arr.shape, + ) + + +def _coo_from_components( + row: np.ndarray, + col: np.ndarray, + data: np.ndarray, + shape: tuple[int, int], +) -> "sparse.COO | np.ndarray": + if _HAS_SPARSE: + return sparse.COO( + coords=np.stack([row, col]), + data=data, + shape=shape, + has_duplicates=False, + sorted=False, + ) + dense = np.zeros(shape, dtype=data.dtype if data.size else np.float64) + dense[row, col] = data + return dense + + +def _metadata_attrs( + spec: RegridSpec, source_coords: xr.Dataset, target_coords: xr.Dataset +) -> dict[str, Any]: + attrs: dict[str, Any] = { + "x_coord": spec.x_coord, + "y_coord": spec.y_coord, + "spherical": int(spec.spherical), + "src_dims": [str(d) for d in spec.src_dims], + "dst_dims": [str(d) for d in spec.dst_dims], + "src_shape": list(spec.src_shape), + "dst_shape": list(spec.dst_shape), + "xarray_regrid_version": _package_version(), + "created": datetime.now(tz=timezone.utc).isoformat(), + "schema_version": _SCHEMA_VERSION, + } + for prefix, ds, coord in [ + ("source_x", source_coords, spec.x_coord), + ("source_y", source_coords, spec.y_coord), + ("target_x", target_coords, spec.x_coord), + ("target_y", target_coords, spec.y_coord), + ]: + if coord and coord in ds.coords and ds[coord].size: + attrs[f"{prefix}_range"] = [float(ds[coord].min()), float(ds[coord].max())] + return attrs + + +def _metadata_from_attrs(attrs: dict[str, Any], path: Path) -> RegridSpec: + """Parse and validate regridding spec metadata from netCDF attrs.""" + schema_version = int(attrs.get("schema_version", 0)) + if schema_version != _SCHEMA_VERSION: + msg = ( + f"regridder file at {path} uses schema version {schema_version}; " + f"this xarray-regrid understands {_SCHEMA_VERSION}. " + "Upgrade xarray-regrid or re-save." + ) + raise ValueError(msg) + + return RegridSpec( + x_coord=str(attrs["x_coord"]), + y_coord=str(attrs["y_coord"]), + spherical=bool(int(attrs["spherical"])), + src_dims=tuple(str(d) for d in np.atleast_1d(attrs["src_dims"])), + dst_dims=tuple(str(d) for d in np.atleast_1d(attrs["dst_dims"])), + src_shape=tuple(int(s) for s in np.atleast_1d(attrs["src_shape"])), + dst_shape=tuple(int(s) for s in np.atleast_1d(attrs["dst_shape"])), + ) diff --git a/src/xarray_regrid/methods/_conservative_2d_spec.py b/src/xarray_regrid/methods/_conservative_2d_spec.py new file mode 100644 index 0000000..7f17854 --- /dev/null +++ b/src/xarray_regrid/methods/_conservative_2d_spec.py @@ -0,0 +1,15 @@ +from collections.abc import Hashable +from dataclasses import dataclass + + +@dataclass(frozen=True) +class RegridSpec: + """Canonical metadata describing a source->target regridding layout.""" + + src_dims: tuple[Hashable, ...] + dst_dims: tuple[Hashable, ...] + src_shape: tuple[int, ...] + dst_shape: tuple[int, ...] + x_coord: str + y_coord: str + spherical: bool diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 2ab67d9..e8bf6f3 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -7,7 +7,7 @@ import xarray as xr try: - import sparse # type: ignore + import sparse except ImportError: sparse = None diff --git a/src/xarray_regrid/methods/conservative_2d.py b/src/xarray_regrid/methods/conservative_2d.py new file mode 100644 index 0000000..0d92bda --- /dev/null +++ b/src/xarray_regrid/methods/conservative_2d.py @@ -0,0 +1,1471 @@ +"""Conservative regridding for grids that aren't 1D-separable. + +The existing ``conservative`` method uses axis-factored 1D overlap — fast and +elegant but strictly rectilinear. This module computes the full 2D cell +intersection via shapely, so it handles: + +- curvilinear grids (2D ``lat[i, j]`` / ``lon[i, j]`` coordinate variables) +- unstructured meshes (arbitrary polygon cells, via + :meth:`ConservativeRegridder.from_polygons`) +- grid-to-polygon aggregation (e.g. gridded data → country shapes) + +For rectilinear grids a cheap analytic fast-path is used, but this module is +still slower and more memory-intensive than ``conservative``; prefer the +axis-factored path when your grid is 1D-separable. + +Requires ``shapely >= 2.0``. If ``sparse`` is available, the weight matrix is +stored as ``sparse.COO``; otherwise a dense numpy matrix is used. +""" + +import os +import warnings +from collections.abc import Hashable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from datetime import datetime, timezone +from functools import cached_property +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any, Literal, cast, get_args + +import numpy as np +import xarray as xr + +from xarray_regrid import utils + +NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] | None + + +def _package_version() -> str: + try: + return version("xarray-regrid") + except PackageNotFoundError: + return "unknown" + + +# Bump on breaking change to the on-disk format in ConservativeRegridder.to_netcdf. +_SCHEMA_VERSION = 2 + +# Values allowed for the `manifold` kwarg on ConservativeRegridder and related +# helpers. `"s2"` requires the optional `spherely` dependency (great-circle +# polygon geometry on the sphere via Google's s2geometry). Keep the Literal +# alias and the runtime tuple in sync via ``get_args``. +Manifold = Literal["planar", "cea", "s2"] +_MANIFOLDS = get_args(Manifold) + + +try: + import shapely + from shapely import affinity + from shapely.strtree import STRtree + + _HAS_SHAPELY = True +except ImportError: # pragma: no cover + shapely = None + affinity = None + STRtree = None + _HAS_SHAPELY = False + +try: + import sparse + + _HAS_SPARSE = True +except ImportError: # pragma: no cover + sparse = None + _HAS_SPARSE = False + +try: + import spherely + + _HAS_SPHERELY = True +except ImportError: # pragma: no cover + spherely = None + _HAS_SPHERELY = False + + +SPHERELY_IMPORT_ERROR = ( + "manifold='s2' requires the optional `spherely` package. " + "Install with `pip install spherely` or " + "`pip install xarray-regrid[conservative-2d]`." +) + + +def _check_spherely() -> None: + if not _HAS_SPHERELY: + raise ImportError(SPHERELY_IMPORT_ERROR) + + +def _check_manifold(manifold: str) -> None: + if manifold not in _MANIFOLDS: + msg = f"manifold must be one of {_MANIFOLDS!r}, got {manifold!r}" + raise ValueError(msg) + if manifold == "s2": + _check_spherely() + + +# We fill NaNs with 0 ourselves before matmul (see `_apply_core`), so sparse's +# "NaN will not be propagated" warning is spurious. A module-level filter +# avoids `warnings.catch_warnings` inside the hot matmul path, which is not +# thread-safe — dask threads would race on the global `warnings.filters` list. +warnings.filterwarnings( + "ignore", + message="Nan will not be propagated in matrix multiplication", + category=RuntimeWarning, +) + + +SHAPELY_IMPORT_ERROR = ( + "polygon conservative regridding requires shapely >= 2.0; " + "install with `pip install shapely`." +) + + +def _check_shapely() -> None: + if not _HAS_SHAPELY: + raise ImportError(SHAPELY_IMPORT_ERROR) + + +class ConservativeRegridder: + """Reusable conservative regridder for grids that aren't 1D-separable. + + Use this when your source or target isn't a pure rectilinear lat/lon + grid: curvilinear coordinates (2D ``lat[i, j]`` / ``lon[i, j]``), + unstructured meshes (via :meth:`from_polygons`), or arbitrary + polygon-to-polygon aggregation. For plain 1D-separable rectilinear + grids, the existing ``.conservative`` accessor is much faster. + + Build once from source and target grids; apply to many compatible fields + via :meth:`regrid` (or by calling the regridder). The raw intersection + area matrix is stored internally; the forward and backward row-normalized + weight matrices are lazily cached on first use. + + Planar geometry only. Requires ``shapely >= 2.0``. + + Example:: + + regridder = ConservativeRegridder( + src_ds, tgt_ds, x_coord="lon", y_coord="lat" + ) + out = regridder.regrid(da) # forward + back = regridder.T.regrid(out) # backward + """ + + def __init__( + self, + source: xr.DataArray | xr.Dataset, + target: xr.Dataset, + x_coord: str = "longitude", + y_coord: str = "latitude", + manifold: Manifold = "planar", + spherical: bool | None = None, + n_threads: int | None = None, + ) -> None: + _check_shapely() + if spherical is not None: + manifold = cast("Manifold", "cea" if spherical else "planar") + _check_manifold(manifold) + source_grid, target_grid = _normalize_longitude_coords(source, target, x_coord) + src_dims = _spatial_dims(source_grid, x_coord, y_coord) + dst_dims = _spatial_dims(target_grid, x_coord, y_coord) + if not src_dims: + msg = f"source has no dims for coords {x_coord!r}, {y_coord!r}" + raise ValueError(msg) + if not dst_dims: + msg = f"target has no dims for coords {x_coord!r}, {y_coord!r}" + raise ValueError(msg) + src_grid = _grid_from_coords( + source_grid, x_coord, y_coord, src_dims, manifold=manifold + ) + dst_grid = _grid_from_coords( + target_grid, x_coord, y_coord, dst_dims, manifold=manifold + ) + self.manifold = manifold + + self.x_coord = x_coord + self.y_coord = y_coord + self._src_dims = src_dims + self._dst_dims = dst_dims + self._src_shape = tuple(int(source.sizes[d]) for d in src_dims) + self._dst_shape = tuple(int(target.sizes[d]) for d in dst_dims) + self._areas = _build_intersection_areas(src_grid, dst_grid, n_threads=n_threads) + self._source_coords = source.coords.to_dataset() + self._target_coords = target.coords.to_dataset() + self._fwd_weights: "sparse.COO | np.ndarray | None" = None + self._bwd_weights: "sparse.COO | np.ndarray | None" = None + + @property + def areas(self) -> "sparse.COO | np.ndarray": + """Raw area-intersection matrix ``A[i, j] = area(dst_i ∩ src_j)``.""" + return self._areas + + @property + def spherical(self) -> bool: + """Backward-compatible spherical flag (`True` for CEA manifold).""" + return self.manifold == "cea" + + @property + def target_areas(self) -> np.ndarray: + return _sum_matrix_axis_1d(self._areas, axis=1) + + @property + def source_coverage_areas(self) -> np.ndarray: + return _sum_matrix_axis_1d(self._areas, axis=0) + + @property + def forward_weights(self) -> "sparse.COO | np.ndarray": + """The row-normalized forward weight matrix (source → target).""" + if self._fwd_weights is None: + self._fwd_weights = _row_normalize(self._areas) + return self._fwd_weights + + @property + def backward_weights(self) -> "sparse.COO | np.ndarray": + """The row-normalized backward weight matrix (target → source).""" + if self._bwd_weights is None: + self._bwd_weights = _row_normalize(_transpose_weights(self._areas)) + return self._bwd_weights + + @cached_property + def _forward_apply(self) -> "sparse.COO | np.ndarray": + # Transposed + index-sorted once so the matmul in _apply_core is + # (..., n_src) @ (n_src, n_dst) with no per-call sort. + return _transpose_weights(self.forward_weights, sort=True) + + @cached_property + def _backward_apply(self) -> "sparse.COO | np.ndarray": + return _transpose_weights(self.backward_weights, sort=True) + + @cached_property + def _forward_coverage(self) -> np.ndarray: + return _coverage_mask(self._areas) + + @cached_property + def _backward_coverage(self) -> np.ndarray: + return _coverage_mask(_transpose_weights(self._areas)) + + def regrid( + self, + data: xr.DataArray | xr.Dataset, + skipna: bool = True, + nan_threshold: float = 1.0, + ) -> xr.DataArray | xr.Dataset: + """Regrid ``data`` forward (source → target).""" + return _apply_stored_weights( + data, + apply_weights=self._forward_apply, + coverage=self._forward_coverage, + src_dims=self._src_dims, + dst_dims=self._dst_dims, + src_shape=self._src_shape, + dst_shape=self._dst_shape, + target_coords=self._target_coords, + x_coord=self.x_coord, + y_coord=self.y_coord, + skipna=skipna, + nan_threshold=nan_threshold, + ) + + def __call__( + self, + data: xr.DataArray | xr.Dataset, + skipna: bool = True, + nan_threshold: float = 1.0, + ) -> xr.DataArray | xr.Dataset: + return self.regrid(data, skipna=skipna, nan_threshold=nan_threshold) + + def transpose(self) -> "ConservativeRegridder": + """Return the backward regridder (target → source), sharing the + underlying area matrix and any already-computed cached weight + matrices (forward on the transposed regridder is backward on the + original, and vice versa).""" + new = object.__new__(ConservativeRegridder) + new.x_coord = self.x_coord + new.y_coord = self.y_coord + new.manifold = self.manifold + new._src_dims = self._dst_dims + new._dst_dims = self._src_dims + new._src_shape = self._dst_shape + new._dst_shape = self._src_shape + new._areas = _transpose_weights(self._areas) + new._source_coords = self._target_coords + new._target_coords = self._source_coords + new._fwd_weights = None + new._bwd_weights = None + swap = { + "forward_weights": "backward_weights", + "backward_weights": "forward_weights", + "_forward_apply": "_backward_apply", + "_backward_apply": "_forward_apply", + "_forward_coverage": "_backward_coverage", + "_backward_coverage": "_forward_coverage", + } + for src, dst in swap.items(): + if src in self.__dict__: + new.__dict__[dst] = self.__dict__[src] + return new + + @property + def T(self) -> "ConservativeRegridder": # noqa: N802 + """Alias for :meth:`transpose` (numpy-style transpose).""" + return self.transpose() + + def __repr__(self) -> str: + nnz = getattr(self._areas, "nnz", None) + shape = getattr(self._areas, "shape", (None, None)) + nnz_str = f"nnz={nnz}" if nnz is not None else "dense" + return ( + f"ConservativeRegridder(src_dims={self._src_dims}, " + f"dst_dims={self._dst_dims}, {shape[0]}x{shape[1]}, {nnz_str})" + ) + + def to_netcdf(self, path: str | Path, engine: NetcdfEngine = None) -> None: + """Save the weight matrix and reproducibility metadata to a netCDF file. + + File layout: + + - root dataset: the sparse area matrix stored as three 1D variables + ``_coo_row``, ``_coo_col``, ``_coo_data`` (``shape=(n_dst, n_src)`` + carried on root-dataset attributes), plus the regridder metadata as + root attributes. + - ``/source_coords`` group: coord-only Dataset capturing the source grid. + - ``/target_coords`` group: coord-only Dataset capturing the target grid. + + Groups require an engine that supports them (``netcdf4`` or + ``h5netcdf``); ``engine`` is forwarded to :func:`xarray.Dataset.to_netcdf`. + """ + path = Path(path) + row, col, data, shape = _coo_components(self._areas) + ds_weights = xr.Dataset( + { + "_coo_row": (("nnz",), row), + "_coo_col": (("nnz",), col), + "_coo_data": (("nnz",), data), + }, + attrs={ + **_metadata_attrs(self), + "n_dst": int(shape[0]), + "n_src": int(shape[1]), + }, + ) + ds_weights.to_netcdf(path, mode="w", engine=engine) + self._source_coords.to_netcdf( + path, mode="a", group="source_coords", engine=engine + ) + self._target_coords.to_netcdf( + path, mode="a", group="target_coords", engine=engine + ) + + @classmethod + def from_netcdf( + cls, path: str | Path, engine: NetcdfEngine = None + ) -> "ConservativeRegridder": + """Reload a regridder previously written with :meth:`to_netcdf`. + + Validates ``schema_version``; raises :class:`ValueError` if the file + was written by an incompatible version. + """ + path = Path(path) + with xr.open_dataset(path, engine=engine) as ds_weights: + attrs = dict(ds_weights.attrs) + n_dst = int(attrs.pop("n_dst")) + n_src = int(attrs.pop("n_src")) + row = np.asarray(ds_weights["_coo_row"].values) + col = np.asarray(ds_weights["_coo_col"].values) + data = np.asarray(ds_weights["_coo_data"].values) + meta = _metadata_from_attrs(attrs, path) + + with xr.open_dataset(path, group="source_coords", engine=engine) as g: + source_coords = g.load() + with xr.open_dataset(path, group="target_coords", engine=engine) as g: + target_coords = g.load() + + return cls._from_state( + areas=_coo_from_components(row, col, data, (n_dst, n_src)), + source_coords=source_coords, + target_coords=target_coords, + **meta, + ) + + @classmethod + def _from_state( + cls, + *, + areas: "sparse.COO | np.ndarray", + source_coords: xr.Dataset, + target_coords: xr.Dataset, + src_dims: tuple[Hashable, ...], + dst_dims: tuple[Hashable, ...], + src_shape: tuple[int, ...], + dst_shape: tuple[int, ...], + x_coord: str, + y_coord: str, + manifold: str, + ) -> "ConservativeRegridder": + """Construct a regridder directly from its canonical state. Shared + bypass of ``__init__`` used by :meth:`from_netcdf` and + :meth:`from_polygons`; keeps the list of private attrs in one place.""" + # Defense-in-depth: protect against a corrupted netCDF file setting + # an unknown manifold that would crash later at regrid time. + _check_manifold(manifold) + instance = object.__new__(cls) + instance.x_coord = x_coord + instance.y_coord = y_coord + instance.manifold = cast("Manifold", manifold) + instance._src_dims = src_dims + instance._dst_dims = dst_dims + instance._src_shape = src_shape + instance._dst_shape = dst_shape + instance._areas = areas + instance._source_coords = source_coords + instance._target_coords = target_coords + instance._fwd_weights = None + instance._bwd_weights = None + return instance + + @classmethod + def from_polygons( + cls, + source_polygons: np.ndarray, + target_polygons: np.ndarray, + source_dim: str = "cell", + target_dim: str = "cell", + target_coords: xr.Dataset | None = None, + periodic: bool = False, + n_threads: int | None = None, + predicate_filter: bool = True, + ) -> "ConservativeRegridder": + """Build a regridder from explicit shapely polygon arrays. + + Use this for unstructured meshes (MPAS, ICON, finite-element), arbitrary + polygon targets (countries, watersheds), or any combination the + structured-grid path cannot express. + + Args: + source_polygons: 1D array of shapely Polygons for source cells. + target_polygons: 1D array of shapely Polygons for target cells. + source_dim: Name of the single dim carrying source cells on the + data passed to :meth:`regrid`. Default ``"cell"``. + target_dim: Name of the single dim carrying target cells on the + output. Default ``"cell"``. + target_coords: Optional xr.Dataset providing coordinate variables + along ``target_dim`` (and any auxiliary coords) to reattach on + the output. If None, the output is given a bare integer index. + periodic: Treat polygon x coordinates as longitudes on a 360-degree + periodic axis, so polygons that cross the antimeridian are + unwrapped before intersection. + n_threads: Thread count for GEOS intersection. + predicate_filter: If True (default), the STRtree candidate query + filters by GEOS ``intersects``. Safe for arbitrary polygons + including thin/diagonal shapes with loose bboxes. Set False + when your polygons have tight bboxes (low aspect ratio, + roughly axis-aligned) to skip the predicate and let the + ``area > 0`` filter drop false positives — usually faster + in that case, pathological otherwise. + + Returns: + A ``ConservativeRegridder`` that accepts data with ``source_dim`` + in place of the structured grid's spatial dims. + + Intersection geometry is planar in the input polygons' coordinate + space. If the polygons represent lat/lon cells, project them into an + equal-area CRS first (or use the structured path with ``manifold='cea'``). + """ + _check_shapely() + src_polys = np.asarray(source_polygons) + dst_polys = np.asarray(target_polygons) + if src_polys.ndim != 1: + msg = "source_polygons must be a 1D array of shapely Polygons" + raise ValueError(msg) + if dst_polys.ndim != 1: + msg = "target_polygons must be a 1D array of shapely Polygons" + raise ValueError(msg) + if periodic: + src_polys = _normalize_periodic_polygons(src_polys) + dst_polys = _normalize_periodic_polygons( + dst_polys, reference=_polygon_reference_x(src_polys) + ) + + src_grid = _Grid( + polys=src_polys, + bounds=shapely.bounds(src_polys), + rectilinear=False, + ) + dst_grid = _Grid( + polys=dst_polys, + bounds=shapely.bounds(dst_polys), + rectilinear=False, + ) + n_src = int(src_polys.size) + n_dst = int(dst_polys.size) + tgt_ds = ( + target_coords + if target_coords is not None + else xr.Dataset(coords={target_dim: np.arange(n_dst)}) + ) + return cls._from_state( + areas=_build_intersection_areas( + src_grid, + dst_grid, + n_threads=n_threads, + predicate_filter=predicate_filter, + ), + source_coords=xr.Dataset(coords={source_dim: np.arange(n_src)}), + target_coords=tgt_ds, + src_dims=(source_dim,), + dst_dims=(target_dim,), + src_shape=(n_src,), + dst_shape=(n_dst,), + x_coord="", + y_coord="", + manifold="planar", + ) + + +def polygons_from_coords( + x: np.ndarray, + y: np.ndarray, + manifold: Literal["planar", "cea"] = "planar", + periodic: bool = False, +) -> np.ndarray: + """Build a 1D array of shapely cell polygons from 1D or 2D center coords. + + Convenience for mixing structured and unstructured regridding. E.g. build + target polygons for a regular lat/lon grid, then pass them together with + an unstructured source mesh to :meth:`ConservativeRegridder.from_polygons`. + + Args: + x: 1D or 2D array of cell-center x coordinates. + y: 1D or 2D array of cell-center y coordinates. + manifold: ``"planar"`` (default) keeps raw coords; ``"cea"`` projects + 1D lat/lon (degrees) to Lambert cylindrical equal-area space for + correct spherical area weights with the planar fast path. The s2 + manifold is not available here (spherely Geographies aren't + shapely polygons); use ``ConservativeRegridder(..., manifold="s2")``. + periodic: Treat x coordinates as longitudes on a 360-degree periodic + axis, so cells that cross the antimeridian are unwrapped before + polygon construction. + + Returns: + A 1D numpy array of shapely Polygons in row-major (y, x) order. + """ + _check_shapely() + if manifold not in ("planar", "cea"): + msg = ( + f"polygons_from_coords supports manifold 'planar' or 'cea'; " + f"got {manifold!r}. For s2, use ConservativeRegridder directly." + ) + raise ValueError(msg) + x = np.asarray(x) + y = np.asarray(y) + if periodic: + x = _unwrap_longitude(x) + if manifold == "cea": + if x.ndim != 1 or y.ndim != 1: + msg = "manifold='cea' requires 1D lat/lon arrays" + raise ValueError(msg) + return _build_cea_grid(x, y).polys + return _build_grid(x, y).polys + + +def conservative_2d_regrid( + data: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + x_coord: str = "longitude", + y_coord: str = "latitude", + manifold: Manifold = "planar", + skipna: bool = True, + nan_threshold: float = 1.0, + n_threads: int | None = None, +) -> xr.DataArray | xr.Dataset: + """Conservative regridding via explicit 2D polygon intersection. + + One-shot convenience wrapper: constructs a :class:`ConservativeRegridder` + and applies it once. For repeated regridding to the same target, construct + a ``ConservativeRegridder`` directly and reuse it. + + Args: + data: Input data on the source grid. + target_ds: Dataset defining the target grid; must expose ``x_coord`` and + ``y_coord`` as (1D or 2D) coordinate variables. + x_coord: Name of the x (longitude-like) coordinate variable. + y_coord: Name of the y (latitude-like) coordinate variable. + manifold: Geometry used for intersection. ``"planar"`` (default) uses + raw planar shapely math in the provided coordinate space; + ``"cea"`` projects 1D lat/lon (degrees) into Lambert cylindrical + equal-area space — correct spherical areas at the planar fast + path's cost; ``"s2"`` uses s2geometry via the optional ``spherely`` + dependency for true great-circle edges. + skipna: If True, propagate NaNs into the weighted mean via a two-pass + sum. + nan_threshold: Keep output cells whose valid source fraction is at + least ``nan_threshold``. + n_threads: Thread count for parallel GEOS intersection (planar path). + + Returns: + Regridded data on the target grid, preserving non-spatial dims. + """ + src_rect = ( + x_coord in data.coords + and y_coord in data.coords + and data[x_coord].ndim == 1 + and data[y_coord].ndim == 1 + and data[x_coord].dims[0] != data[y_coord].dims[0] + ) + tgt_rect = ( + x_coord in target_ds.coords + and y_coord in target_ds.coords + and target_ds[x_coord].ndim == 1 + and target_ds[y_coord].ndim == 1 + and target_ds[x_coord].dims[0] != target_ds[y_coord].dims[0] + ) + src_x = np.asarray(data[x_coord].values) if src_rect else np.array([]) + tgt_x = np.asarray(target_ds[x_coord].values) if tgt_rect else np.array([]) + src_finite = src_x[np.isfinite(src_x)] + tgt_finite = tgt_x[np.isfinite(tgt_x)] + cross_convention = bool( + src_finite.size + and tgt_finite.size + and ((src_finite.min() < 0) != (tgt_finite.min() < 0)) + ) + if src_rect and tgt_rect and manifold in ("planar", "cea") and cross_convention: + latitude_coord = y_coord if manifold == "cea" else None + return data.regrid.conservative( + target_ds, + latitude_coord=latitude_coord, + skipna=skipna, + nan_threshold=nan_threshold, + ) + + regridder = ConservativeRegridder( + data, + target_ds, + x_coord=x_coord, + y_coord=y_coord, + manifold=manifold, + n_threads=n_threads, + ) + return regridder.regrid(data, skipna=skipna, nan_threshold=nan_threshold) + + +def _apply_stored_weights( + data: xr.DataArray | xr.Dataset, + apply_weights: "sparse.COO | np.ndarray", + coverage: np.ndarray, + src_dims: tuple[Hashable, ...], + dst_dims: tuple[Hashable, ...], + src_shape: tuple[int, ...], + dst_shape: tuple[int, ...], + target_coords: xr.Dataset, + x_coord: str, + y_coord: str, + skipna: bool, + nan_threshold: float, +) -> xr.DataArray | xr.Dataset: + """Apply a cached, pre-transposed weight matrix to ``data`` via + ``xr.apply_ufunc``. + + ``apply_weights`` has shape ``(n_src, n_dst)`` so the matmul is + ``(..., n_src) @ (n_src, n_dst) → (..., n_dst)`` with no per-call transpose. + """ + # apply_ufunc(dask="parallelized") needs each core dim to be a single + # chunk. Only rechunk if data is already dask-backed — don't inadvertently + # dask-ify a numpy-backed input. + if getattr(data, "chunks", None) is not None: + split = { + d: -1 + for d in src_dims + if d in data.dims and len(data.chunksizes.get(d, ())) > 1 + } + if split: + data = data.chunk(split) + + actual_src_shape = tuple(int(data.sizes[d]) for d in src_dims if d in data.sizes) + if actual_src_shape != src_shape: + msg = ( + f"source spatial shape {actual_src_shape} on dims {src_dims} does " + f"not match the regridder's expected shape {src_shape}" + ) + raise ValueError(msg) + + src_tokens = tuple(f"__src_{d}" for d in src_dims) + data_renamed = data.rename(dict(zip(src_dims, src_tokens, strict=True))) + + output_dtype = _result_dtype(data) + result = xr.apply_ufunc( + _apply_core, + data_renamed, + kwargs={ + "apply_weights": apply_weights, + "coverage": coverage, + "coverage_all": bool(coverage.all()), + "src_shape": src_shape, + "dst_shape": dst_shape, + "skipna": skipna, + "nan_threshold": nan_threshold, + "output_dtype": output_dtype, + }, + input_core_dims=[list(src_tokens)], + output_core_dims=[list(dst_dims)], + exclude_dims=set(src_tokens), + dask="parallelized", + output_dtypes=[output_dtype], + dask_gufunc_kwargs={ + "output_sizes": {d: int(target_coords.sizes[d]) for d in dst_dims}, + "allow_rechunk": False, + }, + keep_attrs=True, + ) + + result = _assign_target_coords(result, target_coords, dst_dims, x_coord, y_coord) + return result + + +def _coverage_mask(areas: "sparse.COO | np.ndarray") -> np.ndarray: + """Return a boolean ``(n_dst,)`` mask of target cells with any source + overlap, derived from the raw area matrix.""" + if _HAS_SPARSE and isinstance(areas, sparse.COO): + # COO sparse: row_sum is 0 iff row has no nonzero entries. + n_dst = int(areas.shape[0]) + mask = np.zeros(n_dst, dtype=bool) + mask[areas.coords[0]] = True + return mask + arr = np.asarray(areas) + return np.asarray((arr > 0).any(axis=1)) + + +def _sum_matrix_axis_1d(areas: "sparse.COO | np.ndarray", axis: int) -> np.ndarray: + summed = areas.sum(axis=axis) + if hasattr(summed, "todense"): + summed = summed.todense() + return np.asarray(summed, dtype=np.float64).reshape(-1) + + +def _coo_components( + w: "sparse.COO | np.ndarray", +) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple[int, int]]: + if _HAS_SPARSE and isinstance(w, sparse.COO): + coords = np.asarray(w.coords) + return ( + coords[0].astype(np.int64, copy=False), + coords[1].astype(np.int64, copy=False), + np.asarray(w.data), + w.shape, + ) + arr = np.asarray(w) + rows, cols = np.nonzero(arr) + return ( + rows.astype(np.int64, copy=False), + cols.astype(np.int64, copy=False), + arr[rows, cols], + arr.shape, + ) + + +def _coo_from_components( + row: np.ndarray, + col: np.ndarray, + data: np.ndarray, + shape: tuple[int, int], +) -> "sparse.COO | np.ndarray": + if _HAS_SPARSE: + return sparse.COO( + coords=np.stack([row, col]), + data=data, + shape=shape, + has_duplicates=False, + sorted=False, + ) + dense = np.zeros(shape, dtype=data.dtype if data.size else np.float64) + dense[row, col] = data + return dense + + +def _metadata_attrs(regridder: ConservativeRegridder) -> dict[str, Any]: + attrs: dict[str, Any] = { + "x_coord": regridder.x_coord, + "y_coord": regridder.y_coord, + "manifold": str(regridder.manifold), + "src_dims": [str(d) for d in regridder._src_dims], + "dst_dims": [str(d) for d in regridder._dst_dims], + "src_shape": list(regridder._src_shape), + "dst_shape": list(regridder._dst_shape), + "xarray_regrid_version": _package_version(), + "created": datetime.now(tz=timezone.utc).isoformat(), + "schema_version": _SCHEMA_VERSION, + } + for name, value in { + "source_x_range": _coord_range(regridder._source_coords, regridder.x_coord), + "source_y_range": _coord_range(regridder._source_coords, regridder.y_coord), + "target_x_range": _coord_range(regridder._target_coords, regridder.x_coord), + "target_y_range": _coord_range(regridder._target_coords, regridder.y_coord), + }.items(): + if value is not None: + attrs[name] = list(value) + return attrs + + +def _metadata_from_attrs(attrs: dict[str, Any], path: Path) -> dict[str, Any]: + """Parse the kwargs needed by :meth:`ConservativeRegridder._from_state` out + of netCDF root attributes, validating ``schema_version``.""" + schema_version = int(attrs.get("schema_version", 0)) + if schema_version != _SCHEMA_VERSION: + msg = ( + f"regridder file at {path} uses schema version {schema_version}; " + f"this xarray-regrid understands {_SCHEMA_VERSION}. " + "Upgrade xarray-regrid or re-save." + ) + raise ValueError(msg) + + return { + "x_coord": str(attrs["x_coord"]), + "y_coord": str(attrs["y_coord"]), + "manifold": str(attrs["manifold"]), + "src_dims": tuple(str(d) for d in np.atleast_1d(attrs["src_dims"])), + "dst_dims": tuple(str(d) for d in np.atleast_1d(attrs["dst_dims"])), + "src_shape": tuple(int(s) for s in np.atleast_1d(attrs["src_shape"])), + "dst_shape": tuple(int(s) for s in np.atleast_1d(attrs["dst_shape"])), + } + + +def _normalize_longitude_coords( + source: xr.DataArray | xr.Dataset, + target: xr.Dataset, + x_coord: str, +) -> tuple[xr.DataArray | xr.Dataset, xr.Dataset]: + """Unwrap x coordinates across the antimeridian so source and target share + a contiguous longitude frame. No-op when the coord isn't present on both + objects or doesn't look like a longitude.""" + if x_coord not in source.coords or x_coord not in target.coords: + return source, target + + source_x = np.asarray(source[x_coord].values) + target_x = np.asarray(target[x_coord].values) + if not _looks_like_longitude(source_x) and not _looks_like_longitude(target_x): + return source, target + + source_x = _unwrap_longitude(source_x) + target_x = _align_longitude(_unwrap_longitude(target_x), source_x) + return ( + utils.update_coord(source, x_coord, source_x), + cast(xr.Dataset, utils.update_coord(target, x_coord, target_x)), + ) + + +def _looks_like_longitude(values: np.ndarray) -> bool: + finite = np.asarray(values, dtype=float) + finite = finite[np.isfinite(finite)] + if finite.size == 0: + return False + return bool(finite.min() >= -360.0 and finite.max() <= 360.0) + + +def _unwrap_longitude(values: np.ndarray) -> np.ndarray: + radians = np.deg2rad(np.asarray(values, dtype=float)) + for axis in range(radians.ndim): + radians = np.unwrap(radians, axis=axis) + return np.rad2deg(radians) + + +def _align_longitude(values: np.ndarray, reference: np.ndarray) -> np.ndarray: + if values.size == 0 or reference.size == 0: + return values + offset = 360.0 * round((np.nanmean(reference) - np.nanmean(values)) / 360.0) + return values + offset + + +def _normalize_periodic_polygons( + polygons: np.ndarray, reference: float | None = None +) -> np.ndarray: + normalized = [] + current_reference = reference + for polygon in polygons: + new_polygon = _unwrap_polygon(polygon) + center = _polygon_reference_x(np.array([new_polygon], dtype=object)) + if current_reference is None: + current_reference = center + offset = 360.0 * round((current_reference - center) / 360.0) + if offset: + new_polygon = affinity.translate(new_polygon, xoff=offset) + normalized.append(new_polygon) + return np.array(normalized, dtype=object) + + +def _polygon_reference_x(polygons: np.ndarray) -> float: + bounds = shapely.bounds(polygons) + centers = 0.5 * (bounds[:, 0] + bounds[:, 2]) + return float(np.nanmean(centers)) + + +def _unwrap_polygon(polygon: Any) -> Any: + if polygon.is_empty: + return polygon + if polygon.geom_type == "Polygon": + exterior = _unwrap_ring(np.asarray(polygon.exterior.coords)) + holes = [_unwrap_ring(np.asarray(ring.coords)) for ring in polygon.interiors] + return shapely.Polygon(exterior, holes) + if polygon.geom_type == "MultiPolygon": + return shapely.MultiPolygon([_unwrap_polygon(part) for part in polygon.geoms]) + return polygon + + +def _unwrap_ring(ring: np.ndarray) -> np.ndarray: + new_ring = np.asarray(ring, dtype=float).copy() + offset = 0.0 + for i in range(1, new_ring.shape[0]): + x = new_ring[i, 0] + offset + step = x - new_ring[i - 1, 0] + if step > 180.0: + offset -= 360.0 + elif step < -180.0: + offset += 360.0 + new_ring[i, 0] += offset + return new_ring + + +def _coord_range(ds: xr.Dataset, coord_name: str) -> tuple[float, float] | None: + """Return ``(min, max)`` of a coord, or ``None`` when it isn't on the + Dataset (e.g., the integer-index stub emitted by ``from_polygons``).""" + if not coord_name or coord_name not in ds.coords: + return None + arr = np.asarray(ds[coord_name].values) + if arr.size == 0: + return None + return float(arr.min()), float(arr.max()) + + +def _transpose_weights( + w: "sparse.COO | np.ndarray", *, sort: bool = False +) -> "sparse.COO | np.ndarray": + """Materialize a transposed weight matrix. + + `sparse.COO.T` is a lazy view that re-sorts indices on each downstream + matmul, so callers relying on ``.coords[0]`` being row indices or wanting + a hot matmul path should materialize once here. Pass ``sort=True`` to + additionally trigger the sort ahead of time (used for the apply matrix). + """ + if _HAS_SPARSE and isinstance(w, sparse.COO): + t = w.T + out = sparse.COO( + coords=np.asarray(t.coords), + data=np.asarray(t.data), + shape=t.shape, + has_duplicates=False, + sorted=False, + ) + if sort: + out._sort_indices() + return out + return np.asarray(w).T.copy() + + +def _spatial_dims( + obj: xr.DataArray | xr.Dataset, x_coord: str, y_coord: str +) -> tuple[Hashable, ...]: + """Return the spatial dim order to feed to apply_ufunc. + + For 1D rectilinear coords (x and y ride on separate dims), canonicalize to + ``(y_dim, x_dim)`` so the flattened cell order matches the polygon order + emitted by the fast-path in ``_build_grid``. For curvilinear (2D) coords, + preserve the dim order already present on ``obj``. + """ + if x_coord not in obj.coords or y_coord not in obj.coords: + return () + xd = obj[x_coord].dims + yd = obj[y_coord].dims + if len(xd) == 1 and len(yd) == 1 and xd[0] != yd[0]: + return (yd[0], xd[0]) + dims = set(xd) | set(yd) + return tuple(d for d in obj.dims if d in dims) + + +def _grid_from_coords( + obj: xr.DataArray | xr.Dataset, + x_coord: str, + y_coord: str, + dims: tuple[Hashable, ...], + manifold: Manifold = "planar", +) -> "_Grid": + """Build a :class:`_Grid` from the object's x/y coordinates under a given + manifold (``"planar"``, ``"cea"``, ``"s2"``). + + For 1D coords on separate dims we take the rectilinear fast path; otherwise + we broadcast to a common N-D center array and build curvilinear polygons. + """ + xd = obj[x_coord] + yd = obj[y_coord] + is_rectilinear = xd.ndim == 1 and yd.ndim == 1 and xd.dims[0] != yd.dims[0] + # `cea` and `s2` derive cell edges from 1D lat/lon; both require rectilinear. + if manifold in ("cea", "s2"): + if not is_rectilinear: + msg = ( + f"manifold={manifold!r} is only supported for rectilinear " + "(1D lat/lon) coordinate arrays." + ) + raise NotImplementedError(msg) + builder = _build_cea_grid if manifold == "cea" else _build_s2_grid + return builder(np.asarray(xd.values), np.asarray(yd.values)) + if is_rectilinear: + return _build_grid(np.asarray(xd.values), np.asarray(yd.values)) + xc, yc = xr.broadcast(obj[x_coord], obj[y_coord]) + xc = xc.transpose(*dims) + yc = yc.transpose(*dims) + return _build_grid(np.asarray(xc.values), np.asarray(yc.values)) + + +def _latlon_edges_deg( + lon_centers: np.ndarray, lat_centers: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Derive cell-edge arrays from 1D lat/lon center arrays (degrees). + + Shared by the CEA and s2 builders; clips latitudes at the poles. Projecting + *edges* analytically — rather than projecting centers and then + re-midpointing — is required because the projections (``sin`` for CEA, the + great-circle geometry for s2) are nonlinear: the projected midpoint of two + lat centers is not the same as the midpoint of two projected lat edges. + """ + if lon_centers.size < 2 or lat_centers.size < 2: + msg = "rectilinear lat/lon grids need at least two cells per dimension" + raise ValueError(msg) + return ( + utils.infer_1d_edges(lon_centers), + np.clip(utils.infer_1d_edges(lat_centers), -90.0, 90.0), + ) + + +def _build_cea_grid(lon_centers: np.ndarray, lat_centers: np.ndarray) -> "_Grid": + """Build a rectilinear :class:`_Grid` whose cell polygons are in Lambert + cylindrical equal-area coordinates (x' = lon_rad, y' = sin(lat_rad)).""" + _check_shapely() + lon_edges_deg, lat_edges_deg = _latlon_edges_deg(lon_centers, lat_centers) + return _rect_grid_from_edges( + np.deg2rad(lon_edges_deg), + np.sin(np.deg2rad(lat_edges_deg)), + ) + + +def _build_s2_grid(lon_centers: np.ndarray, lat_centers: np.ndarray) -> "_Grid": + """Build a rectilinear :class:`_Grid` with s2geometry great-circle cell + polygons attached (``_Grid.s2_polys``). + + The planar shapely polygons + bounds stay in place so the STRtree + candidate-pair filter works — spherely does not yet expose a spatial + index (tracked at https://github.com/benbovy/spherely/issues/72). The + planar bbox is a conservative over-estimate of overlap, so the filter + is lossless for the s2 intersection downstream. + """ + _check_shapely() + _check_spherely() + lon_edges_deg, lat_edges_deg = _latlon_edges_deg(lon_centers, lat_centers) + planar = _rect_grid_from_edges(lon_edges_deg, lat_edges_deg) + return _Grid( + polys=planar.polys, + bounds=planar.bounds, + rectilinear=True, + s2_polys=_s2_cell_polys(lon_edges_deg, lat_edges_deg), + ) + + +def _s2_cell_polys(lon_edges_deg: np.ndarray, lat_edges_deg: np.ndarray) -> np.ndarray: + """Build an (n_cells,) object array of spherely.Geography polygons. + + Each polygon uses explicit CCW vertex order in (lon, lat) degrees and + ``oriented=True`` so orientation is unambiguous for cells that touch the + poles or span wide longitude ranges. + """ + if hasattr(spherely, "polygons"): + # Vectorized constructor (benbovy/spherely#52, not yet released). When + # it lands we build shells as (n, 4, 2) and skip the Python loop. + x0, y0 = np.meshgrid(lon_edges_deg[:-1], lat_edges_deg[:-1], indexing="xy") + x1, y1 = np.meshgrid(lon_edges_deg[1:], lat_edges_deg[1:], indexing="xy") + shells = np.stack( + [ + np.stack([x0, y0], axis=-1), + np.stack([x1, y0], axis=-1), + np.stack([x1, y1], axis=-1), + np.stack([x0, y1], axis=-1), + ], + axis=-2, + ).reshape(-1, 4, 2) + return np.asarray(spherely.polygons(shells, oriented=True)) + # Per-cell fallback: `spherely.create_polygon` wants an iterable of tuples, + # and a Python loop over pre-slicing an ndarray is slower than building + # the tuple list inline. + nx = lon_edges_deg.size - 1 + ny = lat_edges_deg.size - 1 + polys = np.empty(ny * nx, dtype=object) + for j in range(ny): + y0f, y1f = float(lat_edges_deg[j]), float(lat_edges_deg[j + 1]) + for i in range(nx): + x0f, x1f = float(lon_edges_deg[i]), float(lon_edges_deg[i + 1]) + shell = [(x0f, y0f), (x1f, y0f), (x1f, y1f), (x0f, y1f)] + polys[j * nx + i] = spherely.create_polygon(shell, oriented=True) + return polys + + +def _infer_2d_corners(xc: np.ndarray, yc: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Return (ny+1, nx+1) cell-corner arrays from 2D cell-center arrays. + + Interior corners are the mean of the four surrounding centers; boundary + corners are reflected from the adjacent interior row/column. + """ + if xc.shape != yc.shape or xc.ndim != 2: + msg = "xc and yc must be 2D arrays of the same shape" + raise ValueError(msg) + + def corners(a: np.ndarray) -> np.ndarray: + ny, nx = a.shape + pad = np.empty((ny + 2, nx + 2), dtype=a.dtype) + pad[1:-1, 1:-1] = a + pad[0, 1:-1] = 2 * a[0, :] - a[1, :] + pad[-1, 1:-1] = 2 * a[-1, :] - a[-2, :] + pad[1:-1, 0] = 2 * a[:, 0] - a[:, 1] + pad[1:-1, -1] = 2 * a[:, -1] - a[:, -2] + pad[0, 0] = 2 * pad[0, 1] - pad[0, 2] + pad[0, -1] = 2 * pad[0, -2] - pad[0, -3] + pad[-1, 0] = 2 * pad[-1, 1] - pad[-1, 2] + pad[-1, -1] = 2 * pad[-1, -2] - pad[-1, -3] + return 0.25 * (pad[:-1, :-1] + pad[1:, :-1] + pad[:-1, 1:] + pad[1:, 1:]) + + return corners(xc.astype(float)), corners(yc.astype(float)) + + +@dataclass +class _Grid: + """Cached cell geometry for a structured grid. + + ``polys`` is a flat (n_cells,) object array of shapely Polygons used for + the STRtree candidate-pair filter. ``bounds`` is a (n_cells, 4) axis- + aligned bbox array cached off the polys for the analytic fast path and + any coarse bbox queries. ``rectilinear`` enables the analytic box-overlap + path in :func:`_build_intersection_areas`. + + ``s2_polys`` is an optional (n_cells,) object array of spherely + ``Geography`` polygons in the s2geometry manifold. When both the source + and target grids carry ``s2_polys``, the weight builder uses + ``spherely.intersection`` + ``spherely.area`` instead of planar shapely — + great-circle edges and true spherical areas. + """ + + polys: np.ndarray + bounds: np.ndarray + rectilinear: bool + s2_polys: np.ndarray | None = None + + +def _rect_grid_from_edges(xe: np.ndarray, ye: np.ndarray) -> _Grid: + """Build a rectilinear :class:`_Grid` from already-prepared edge arrays. + + Shared by the raw-planar (:func:`_build_grid` 1D branch) and the analytic + equal-area (:func:`_build_cea_grid`) paths. + """ + x0, y0 = np.meshgrid(xe[:-1], ye[:-1], indexing="xy") + x1, y1 = np.meshgrid(xe[1:], ye[1:], indexing="xy") + x0f, y0f, x1f, y1f = x0.ravel(), y0.ravel(), x1.ravel(), y1.ravel() + polys = shapely.box(x0f, y0f, x1f, y1f) + bounds = np.stack([x0f, y0f, x1f, y1f], axis=1) + return _Grid(polys=polys, bounds=bounds, rectilinear=True) + + +def _build_grid(xc: np.ndarray, yc: np.ndarray) -> _Grid: + """Return a _Grid of cell geometry for a structured grid. + + Accepts 1D (rectilinear, separate x and y vectors) or 2D (curvilinear, + co-shaped center arrays) inputs. Output order is row-major in the input dim + order: for 2D inputs of shape (ny, nx) the polygons correspond to cells + reshaped as ``(ny, nx)``. + """ + _check_shapely() + if xc.ndim == 1 and yc.ndim == 1: + xe = utils.infer_1d_edges(xc.astype(float)) + ye = utils.infer_1d_edges(yc.astype(float)) + return _rect_grid_from_edges(xe, ye) + if xc.ndim == 2 and yc.ndim == 2: + xcorn, ycorn = _infer_2d_corners(xc, yc) + ny, nx = xc.shape + c00 = np.stack([xcorn[:-1, :-1], ycorn[:-1, :-1]], axis=-1) + c10 = np.stack([xcorn[:-1, 1:], ycorn[:-1, 1:]], axis=-1) + c11 = np.stack([xcorn[1:, 1:], ycorn[1:, 1:]], axis=-1) + c01 = np.stack([xcorn[1:, :-1], ycorn[1:, :-1]], axis=-1) + rings = np.stack([c00, c10, c11, c01, c00], axis=2).reshape(ny * nx, 5, 2) + polys = shapely.polygons(rings) + return _Grid(polys=polys, bounds=shapely.bounds(polys), rectilinear=False) + msg = "x and y coordinate arrays must both be 1D or both 2D" + raise ValueError(msg) + + +def _build_intersection_areas( + src: _Grid, + dst: _Grid, + n_threads: int | None = None, + *, + predicate_filter: bool = False, +) -> "sparse.COO | np.ndarray": + """Build the (n_dst, n_src) raw area-intersection matrix ``A[i, j] = + area(dst_i ∩ src_j)``. + + This is the unnormalized matrix. Row-normalize via :func:`_row_normalize` + to get forward weights; transpose first for backward (target → source). + + Dispatches to three paths: + 1. both grids carry ``s2_polys`` → spherely great-circle intersection + (planar bboxes are used only as a lossless candidate-pair filter + because spherely has no spatial index yet). + 2. both grids are rectilinear → analytic box-overlap (numpy only). + 3. otherwise → threaded shapely polygon clipping. + + ``predicate_filter=False`` (default) uses a bbox-only STRtree query and + relies on the ``area > 0`` filter below to drop bbox-false-positives. + For structured cells whose bboxes are tight (quadrilaterals) this is a + large win — the GEOS ``intersects`` predicate inside STRtree is much + more expensive than the extra no-op intersections it avoids. Set + ``predicate_filter=True`` for user-supplied polygons with loose bboxes + (long, thin, diagonal shapes) where the predicate pays for itself. + """ + _check_shapely() + n_dst = len(dst.polys) + n_src = len(src.polys) + + tree = STRtree(src.polys) + if predicate_filter: + pairs = tree.query(dst.polys, predicate="intersects") + else: + pairs = tree.query(dst.polys) + dst_idx = np.asarray(pairs[0]) + src_idx = np.asarray(pairs[1]) + + if dst_idx.size == 0: + return _empty_weights(n_dst, n_src) + + if src.s2_polys is not None and dst.s2_polys is not None: + areas = _s2_intersection_areas( + dst.s2_polys[dst_idx], src.s2_polys[src_idx], n_threads=n_threads + ) + elif src.rectilinear and dst.rectilinear: + sb = src.bounds[src_idx] + db = dst.bounds[dst_idx] + dx = np.minimum(sb[:, 2], db[:, 2]) - np.maximum(sb[:, 0], db[:, 0]) + dy = np.minimum(sb[:, 3], db[:, 3]) - np.maximum(sb[:, 1], db[:, 1]) + areas = np.maximum(dx, 0.0) * np.maximum(dy, 0.0) + else: + areas = _intersection_areas_threaded( + dst.polys[dst_idx], src.polys[src_idx], n_threads=n_threads + ) + + keep = areas > 0 + dst_idx = dst_idx[keep] + src_idx = src_idx[keep] + areas = areas[keep] + + if dst_idx.size == 0: + return _empty_weights(n_dst, n_src) + + if _HAS_SPARSE: + return sparse.COO( + coords=np.stack([dst_idx, src_idx]), + data=areas.astype(np.float64), + shape=(n_dst, n_src), + has_duplicates=False, + sorted=False, + ) + a_dense = np.zeros((n_dst, n_src), dtype=np.float64) + a_dense[dst_idx, src_idx] = areas + return a_dense + + +def _row_normalize( + areas: "sparse.COO | np.ndarray", +) -> "sparse.COO | np.ndarray": + """Normalize rows of an area matrix so each row sums to 1 (rows with no + overlap stay all-zero, which produces NaN output under the apply path).""" + if _HAS_SPARSE and isinstance(areas, sparse.COO): + n_dst = areas.shape[0] + dst_idx = areas.coords[0] + src_idx = areas.coords[1] + data = areas.data + if data.size == 0: + return areas + row_sum = np.bincount(dst_idx, weights=data, minlength=n_dst) + new_data = data / row_sum[dst_idx] + return sparse.COO( + coords=np.stack([dst_idx, src_idx]), + data=new_data, + shape=areas.shape, + has_duplicates=False, + sorted=False, + ) + row_sum = areas.sum(axis=1, keepdims=True) + row_sum = np.where(row_sum == 0, 1.0, row_sum) + return areas / row_sum + + +def _s2_intersection_areas( + dst_geog: np.ndarray, + src_geog: np.ndarray, + n_threads: int | None = None, +) -> np.ndarray: + """Per-pair great-circle intersection areas in steradians. + + ``spherely.intersection`` and ``spherely.area`` are both vectorized ufuncs + so the serial path is a single pass through each. ``radius=1.0`` gives the + result on the unit sphere — fine for row-normalized weights because the + Earth radius would cancel through the normalization anyway. + + Spherely releases the GIL during the s2 boolean op (on versions that ship + the gil_scoped_release patch), so chunking the pair array across a + ThreadPoolExecutor parallelises the heavy intersection step at near- + linear scaling up to ~4 cores. Falls back to serial on older spherely, + or when ``n_threads<=1`` or the workload is too small to amortise the + pool overhead. + """ + _check_spherely() + n = len(dst_geog) + if n_threads is None: + n_threads = min(os.cpu_count() or 1, 4) + if n < 50_000: + n_threads = 1 + if n_threads <= 1 or n == 0: + inter = spherely.intersection(dst_geog, src_geog) + return np.asarray(spherely.area(inter, radius=1.0)) + + splits = np.array_split(np.arange(n), n_threads) + + def _work(idx: np.ndarray) -> np.ndarray: + inter = spherely.intersection(dst_geog[idx], src_geog[idx]) + return np.asarray(spherely.area(inter, radius=1.0)) + + with ThreadPoolExecutor(max_workers=n_threads) as pool: + parts = list(pool.map(_work, splits)) + return np.concatenate(parts) + + +def _intersection_areas_threaded( + a: np.ndarray, b: np.ndarray, n_threads: int | None +) -> np.ndarray: + """Compute per-pair intersection areas ``area(a[i] & b[i])`` over numpy + arrays of shapely geometries, optionally parallelized via threads. + + Shapely 2.x releases the GIL for GEOS ops, so a ``ThreadPoolExecutor`` + gives near-linear speedup on multi-core machines without pickling data. + """ + _check_shapely() + n = len(a) + if n_threads is None: + n_threads = min(os.cpu_count() or 1, 8) + # Amortize thread-pool overhead only when there's meaningful work. + if n < 50_000: + n_threads = 1 + if n_threads <= 1 or n == 0: + return shapely.area(shapely.intersection(a, b)) + + splits = np.array_split(np.arange(n), n_threads) + + def _work(idx: np.ndarray) -> np.ndarray: + return shapely.area(shapely.intersection(a[idx], b[idx])) + + with ThreadPoolExecutor(max_workers=n_threads) as pool: + parts = list(pool.map(_work, splits)) + return np.concatenate(parts) + + +def _empty_weights(n_dst: int, n_src: int) -> "sparse.COO | np.ndarray": + if _HAS_SPARSE: + return sparse.COO( + coords=np.zeros((2, 0), dtype=np.int64), + data=np.zeros(0, dtype=np.float64), + shape=(n_dst, n_src), + ) + return np.zeros((n_dst, n_src), dtype=np.float64) + + +def _apply_core( + arr: np.ndarray, + apply_weights: Any, + coverage: np.ndarray, + coverage_all: bool, + src_shape: tuple[int, ...], + dst_shape: tuple[int, ...], + skipna: bool, + nan_threshold: float, + output_dtype: np.dtype, +) -> np.ndarray: + """Apply a pre-transposed weight matrix along the trailing spatial dims. + + ``arr`` has shape ``(..., *src_shape)``; ``apply_weights`` has shape + ``(n_src, n_dst)`` — returns ``(..., *dst_shape)``. + + ``coverage`` is a boolean ``(n_dst,)`` mask: target cells with any source + overlap. ``coverage_all`` is precomputed so every block skips the + ``coverage.all()`` scan. Uncovered cells are always masked to NaN in the + output, regardless of ``skipna`` — domain boundaries and polygon holes + produce NaN (matches the axis-factored ``conservative`` method). + """ + n_spatial = len(src_shape) + leading_shape = arr.shape[:-n_spatial] if n_spatial > 0 else arr.shape + n_src = int(np.prod(src_shape)) + flat = arr.reshape(-1, n_src) if leading_shape else arr.reshape(1, n_src) + + if skipna and np.issubdtype(flat.dtype, np.floating): + nan_mask = np.isnan(flat) + has_nan = nan_mask.any() + else: + has_nan = False + + if has_nan: + mask = (~nan_mask).astype(flat.dtype) + filled = np.where(nan_mask, flat.dtype.type(0.0), flat) + numerator = _matmul_last(filled, apply_weights) + fraction = _matmul_last(mask, apply_weights) + threshold = 1.0 - min(max(nan_threshold, 1e-6), 1.0 - 1e-6) + with np.errstate(invalid="ignore", divide="ignore"): + result = numerator / fraction + result = np.where(fraction >= threshold, result, np.nan) + else: + result = _matmul_last(flat, apply_weights) + if not coverage_all: + result = np.where(coverage[np.newaxis, :], result, np.nan) + + # sparse.matmul promotes to float64 regardless of the input dtype — cast + # back to the requested output dtype so float32-in really produces + # float32-out (halves memory for float32 pipelines). + if result.dtype != output_dtype: + result = result.astype(output_dtype, copy=False) + + out_shape = (*leading_shape, *dst_shape) if leading_shape else dst_shape + return result.reshape(out_shape) + + +def _matmul_last(flat: np.ndarray, apply_weights: Any) -> np.ndarray: + """Compute ``flat @ apply_weights`` where ``flat`` is dense ``(N, n_src)`` + and ``apply_weights`` is ``(n_src, n_dst)`` (dense or pre-sorted sparse).""" + if _HAS_SPARSE and isinstance(apply_weights, sparse.COO): + return np.asarray(sparse.matmul(flat, apply_weights)) + return flat @ apply_weights + + +def _result_dtype(obj: xr.DataArray | xr.Dataset) -> np.dtype: + if isinstance(obj, xr.DataArray): + return np.result_type(np.float32, obj.dtype) + dtypes = [v.dtype for v in obj.data_vars.values()] + if not dtypes: + return np.dtype(np.float64) + return np.result_type(np.float32, *dtypes) + + +def _assign_target_coords( + obj: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + dst_dims: tuple[Hashable, ...], + x_coord: str, + y_coord: str, +) -> xr.DataArray | xr.Dataset: + """Attach the target dataset's dim coords and auxiliary lat/lon coords.""" + new_coords: dict[Hashable, Any] = {} + for d in dst_dims: + if d in target_ds.coords: + new_coords[d] = target_ds[d] + for name in (x_coord, y_coord): + if name in target_ds.coords and name not in new_coords: + new_coords[name] = target_ds[name] + if new_coords: + obj = obj.assign_coords(new_coords) + return obj diff --git a/src/xarray_regrid/methods/conservative_polygon.py b/src/xarray_regrid/methods/conservative_polygon.py new file mode 100644 index 0000000..bffbecc --- /dev/null +++ b/src/xarray_regrid/methods/conservative_polygon.py @@ -0,0 +1,1163 @@ +"""Polygon-intersection conservative regridding (planar only). + +General conservative regridding that computes 2D cell-polygon intersections +rather than the axis-factored 1D overlap approach used by ``conservative``. +Slower and more memory-intensive than the factored path for rectilinear grids, +but handles curvilinear grids that the factored approach cannot represent. + +Requires ``shapely >= 2.0``. If ``sparse`` is available, the weight matrix is +stored as ``sparse.COO``; otherwise a dense numpy matrix is used. +""" + +import os +import warnings +from collections.abc import Hashable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from functools import cache +from pathlib import Path +from typing import Any + +import numpy as np +import xarray as xr + + +@cache +def _package_version() -> str: + """Cached ``importlib.metadata`` lookup. Cheap on warm call.""" + try: + from importlib.metadata import version + return version("xarray-regrid") + except Exception: + return "unknown" + + +# Bump on breaking change to the on-disk format in ConservativeRegridder.to_netcdf. +_SCHEMA_VERSION = 1 + + +@dataclass(frozen=True) +class RegridderMetadata: + """Serialized build parameters of a :class:`ConservativeRegridder`. + + Fields become netCDF attributes via :meth:`to_attrs`; the loader + reconstructs via :meth:`from_attrs`. Keep new fields optional so older + files remain loadable. + """ + + x_coord: str + y_coord: str + spherical: bool + src_dims: tuple[str, ...] + dst_dims: tuple[str, ...] + src_shape: tuple[int, ...] + dst_shape: tuple[int, ...] + source_x_range: tuple[float, float] | None = None + source_y_range: tuple[float, float] | None = None + target_x_range: tuple[float, float] | None = None + target_y_range: tuple[float, float] | None = None + xarray_regrid_version: str = field(default_factory=_package_version) + created: str = field( + default_factory=lambda: datetime.now(tz=timezone.utc).isoformat() + ) + schema_version: int = _SCHEMA_VERSION + + def to_attrs(self) -> dict[str, Any]: + d = asdict(self) + d["spherical"] = int(d["spherical"]) + d["src_dims"] = list(d["src_dims"]) + d["dst_dims"] = list(d["dst_dims"]) + d["src_shape"] = list(d["src_shape"]) + d["dst_shape"] = list(d["dst_shape"]) + for key in ( + "source_x_range", "source_y_range", + "target_x_range", "target_y_range", + ): + if d[key] is None: + del d[key] + else: + d[key] = list(d[key]) + return d + + @classmethod + def from_attrs(cls, attrs: dict[str, Any]) -> "RegridderMetadata": + """Parse back from netCDF attributes. Missing optional fields default.""" + def _range(key: str) -> tuple[float, float] | None: + v = attrs.get(key) + if v is None: + return None + arr = np.atleast_1d(v) + return (float(arr[0]), float(arr[1])) + + spherical = bool(int(attrs["spherical"])) if "spherical" in attrs else ( + str(attrs.get("manifold", "planar")) != "planar" + ) + + return cls( + x_coord=str(attrs["x_coord"]), + y_coord=str(attrs["y_coord"]), + spherical=spherical, + src_dims=tuple(str(d) for d in np.atleast_1d(attrs["src_dims"])), + dst_dims=tuple(str(d) for d in np.atleast_1d(attrs["dst_dims"])), + src_shape=tuple(int(s) for s in np.atleast_1d(attrs["src_shape"])), + dst_shape=tuple(int(s) for s in np.atleast_1d(attrs["dst_shape"])), + source_x_range=_range("source_x_range"), + source_y_range=_range("source_y_range"), + target_x_range=_range("target_x_range"), + target_y_range=_range("target_y_range"), + xarray_regrid_version=str(attrs.get("xarray_regrid_version", "unknown")), + created=str(attrs.get("created", "")), + # Default to 0 so files written before this attr existed fail the + # schema check rather than silently appearing current. + schema_version=int(attrs.get("schema_version", 0)), + ) + +try: + import shapely + from shapely.strtree import STRtree + + _HAS_SHAPELY = True +except ImportError: # pragma: no cover + shapely = None # type: ignore[assignment] + STRtree = None # type: ignore[assignment] + _HAS_SHAPELY = False + +try: + import sparse # type: ignore + + _HAS_SPARSE = True +except ImportError: # pragma: no cover + sparse = None # type: ignore[assignment] + _HAS_SPARSE = False + + +# We fill NaNs with 0 ourselves before matmul (see `_apply_core`), so sparse's +# "NaN will not be propagated" warning is spurious. A module-level filter +# avoids `warnings.catch_warnings` inside the hot matmul path, which is not +# thread-safe — dask threads would race on the global `warnings.filters` list. +warnings.filterwarnings( + "ignore", + message="Nan will not be propagated in matrix multiplication", + category=RuntimeWarning, +) + + +SHAPELY_IMPORT_ERROR = ( + "polygon conservative regridding requires shapely >= 2.0; " + "install with `pip install shapely`." +) + + +def _check_shapely() -> None: + if not _HAS_SHAPELY: + raise ImportError(SHAPELY_IMPORT_ERROR) + + +class ConservativeRegridder: + """Reusable planar-polygon conservative regridder. + + Build once from source and target grids; apply to many compatible fields + via :meth:`regrid` (or by calling the regridder). The raw intersection + area matrix is stored internally; the forward and backward row-normalized + weight matrices are lazily cached on first use. + + Handles rectilinear (1D coord arrays) and curvilinear (2D coord arrays) + grids; planar geometry only. Requires ``shapely >= 2.0``. + + Example: + >>> regridder = ConservativeRegridder(src_ds, tgt_ds, x_coord="lon", y_coord="lat") + >>> out = regridder.regrid(da) # forward + >>> back = regridder.T.regrid(out) # backward + """ + + def __init__( + self, + source: xr.DataArray | xr.Dataset, + target: xr.Dataset, + x_coord: str = "longitude", + y_coord: str = "latitude", + spherical: bool = False, + n_threads: int | None = None, + ) -> None: + _check_shapely() + src_dims = _spatial_dims(source, x_coord, y_coord) + dst_dims = _spatial_dims(target, x_coord, y_coord) + if not src_dims: + msg = f"source has no dims for coords {x_coord!r}, {y_coord!r}" + raise ValueError(msg) + if not dst_dims: + msg = f"target has no dims for coords {x_coord!r}, {y_coord!r}" + raise ValueError(msg) + src_grid = _grid_from_coords( + source, x_coord, y_coord, src_dims, spherical=spherical + ) + dst_grid = _grid_from_coords( + target, x_coord, y_coord, dst_dims, spherical=spherical + ) + self.spherical = spherical + + self.x_coord = x_coord + self.y_coord = y_coord + self._src_dims = src_dims + self._dst_dims = dst_dims + self._src_shape = tuple(int(source.sizes[d]) for d in src_dims) + self._dst_shape = tuple(int(target.sizes[d]) for d in dst_dims) + self._areas = _build_intersection_areas( + src_grid, dst_grid, n_threads=n_threads + ) + self._source_coords = source.coords.to_dataset() + self._target_coords = target.coords.to_dataset() + # `_*_apply` caches the transposed, index-sorted weight matrix used for + # `data @ W` matmul — avoiding sparse's per-call `.T` + `_sort_indices`. + self._fwd_weights: "sparse.COO | np.ndarray | None" = None + self._bwd_weights: "sparse.COO | np.ndarray | None" = None + self._fwd_apply: "sparse.COO | np.ndarray | None" = None + self._bwd_apply: "sparse.COO | np.ndarray | None" = None + self._fwd_coverage: np.ndarray | None = None + self._bwd_coverage: np.ndarray | None = None + + @property + def forward_weights(self) -> "sparse.COO | np.ndarray": + """The row-normalized forward weight matrix (source → target).""" + if self._fwd_weights is None: + self._fwd_weights = _row_normalize(self._areas) + return self._fwd_weights + + @property + def backward_weights(self) -> "sparse.COO | np.ndarray": + """The row-normalized backward weight matrix (target → source).""" + if self._bwd_weights is None: + self._bwd_weights = _row_normalize(_transpose_weights(self._areas)) + return self._bwd_weights + + def _forward_apply_matrix(self) -> "sparse.COO | np.ndarray": + if self._fwd_apply is None: + self._fwd_apply = _transpose_weights(self.forward_weights, sort=True) + return self._fwd_apply + + def _backward_apply_matrix(self) -> "sparse.COO | np.ndarray": + if self._bwd_apply is None: + self._bwd_apply = _transpose_weights(self.backward_weights, sort=True) + return self._bwd_apply + + def _forward_coverage(self) -> np.ndarray: + """Boolean (n_dst,) mask: which destination cells have any source + overlap. Lazily computed and cached alongside the forward weights.""" + if self._fwd_coverage is None: + self._fwd_coverage = _coverage_mask(self._areas) + return self._fwd_coverage + + def _backward_coverage(self) -> np.ndarray: + if self._bwd_coverage is None: + self._bwd_coverage = _coverage_mask(_transpose_weights(self._areas)) + return self._bwd_coverage + + def regrid( + self, + data: xr.DataArray | xr.Dataset, + skipna: bool = True, + nan_threshold: float = 1.0, + ) -> xr.DataArray | xr.Dataset: + """Regrid ``data`` forward (source → target).""" + return _apply_stored_weights( + data, + apply_weights=self._forward_apply_matrix(), + coverage=self._forward_coverage(), + src_dims=self._src_dims, + dst_dims=self._dst_dims, + src_shape=self._src_shape, + dst_shape=self._dst_shape, + target_coords=self._target_coords, + x_coord=self.x_coord, + y_coord=self.y_coord, + skipna=skipna, + nan_threshold=nan_threshold, + ) + + def __call__( + self, + data: xr.DataArray | xr.Dataset, + skipna: bool = True, + nan_threshold: float = 1.0, + ) -> xr.DataArray | xr.Dataset: + return self.regrid(data, skipna=skipna, nan_threshold=nan_threshold) + + def transpose(self) -> "ConservativeRegridder": + """Return the backward regridder (target → source), sharing the + underlying area matrix and both cached weight matrices.""" + new = object.__new__(ConservativeRegridder) + new.x_coord = self.x_coord + new.y_coord = self.y_coord + new.spherical = self.spherical + new._src_dims = self._dst_dims + new._dst_dims = self._src_dims + new._src_shape = self._dst_shape + new._dst_shape = self._src_shape + new._areas = _transpose_weights(self._areas) + new._source_coords = self._target_coords + new._target_coords = self._source_coords + new._fwd_weights = self._bwd_weights + new._bwd_weights = self._fwd_weights + new._fwd_apply = self._bwd_apply + new._bwd_apply = self._fwd_apply + new._fwd_coverage = self._bwd_coverage + new._bwd_coverage = self._fwd_coverage + return new + + @property + def T(self) -> "ConservativeRegridder": + """Alias for :meth:`transpose`.""" + return self.transpose() + + def __repr__(self) -> str: + nnz = getattr(self._areas, "nnz", None) + shape = getattr(self._areas, "shape", (None, None)) + nnz_str = f"nnz={nnz}" if nnz is not None else "dense" + return ( + f"ConservativeRegridder(src_dims={self._src_dims}, " + f"dst_dims={self._dst_dims}, {shape[0]}x{shape[1]}, {nnz_str})" + ) + + def metadata(self) -> RegridderMetadata: + """Return the :class:`RegridderMetadata` that :meth:`to_netcdf` would + write — useful for inspection without touching disk.""" + return RegridderMetadata( + x_coord=self.x_coord, + y_coord=self.y_coord, + spherical=self.spherical, + src_dims=tuple(str(d) for d in self._src_dims), + dst_dims=tuple(str(d) for d in self._dst_dims), + src_shape=self._src_shape, + dst_shape=self._dst_shape, + source_x_range=_coord_range(self._source_coords, self.x_coord), + source_y_range=_coord_range(self._source_coords, self.y_coord), + target_x_range=_coord_range(self._target_coords, self.x_coord), + target_y_range=_coord_range(self._target_coords, self.y_coord), + ) + + def to_netcdf( + self, path: str | Path, engine: str | None = None + ) -> None: + """Save the weight matrix and reproducibility metadata to a netCDF file. + + File layout: + + - root dataset: the sparse area matrix stored as three 1D variables + ``_coo_row``, ``_coo_col``, ``_coo_data`` (``shape=(n_dst, n_src)`` + carried on root-dataset attributes), plus :class:`RegridderMetadata` + fields as attributes. + - ``/source_coords`` group: coord-only Dataset capturing the source grid. + - ``/target_coords`` group: coord-only Dataset capturing the target grid. + + Groups require an engine that supports them (``netcdf4`` or + ``h5netcdf``); ``engine`` is forwarded to :func:`xarray.Dataset.to_netcdf`. + """ + path = Path(path) + row, col, data, shape = _coo_components(self._areas) + ds_weights = xr.Dataset( + { + "_coo_row": (("nnz",), row), + "_coo_col": (("nnz",), col), + "_coo_data": (("nnz",), data), + }, + attrs={ + **self.metadata().to_attrs(), + "n_dst": int(shape[0]), + "n_src": int(shape[1]), + }, + ) + ds_weights.to_netcdf(path, mode="w", engine=engine) + self._source_coords.to_netcdf( + path, mode="a", group="source_coords", engine=engine + ) + self._target_coords.to_netcdf( + path, mode="a", group="target_coords", engine=engine + ) + + @classmethod + def from_netcdf( + cls, path: str | Path, engine: str | None = None + ) -> "ConservativeRegridder": + """Reload a regridder previously written with :meth:`to_netcdf`. + + Validates ``schema_version``; raises :class:`ValueError` if the file + was written by an incompatible version. + """ + path = Path(path) + with xr.open_dataset(path, engine=engine) as ds_weights: + attrs = dict(ds_weights.attrs) + n_dst = int(attrs.pop("n_dst")) + n_src = int(attrs.pop("n_src")) + row = np.asarray(ds_weights["_coo_row"].values) + col = np.asarray(ds_weights["_coo_col"].values) + data = np.asarray(ds_weights["_coo_data"].values) + meta = RegridderMetadata.from_attrs(attrs) + if meta.schema_version != _SCHEMA_VERSION: + msg = ( + f"regridder file at {path} uses schema version " + f"{meta.schema_version}; this xarray-regrid understands " + f"{_SCHEMA_VERSION}. Upgrade xarray-regrid or re-save." + ) + raise ValueError(msg) + + with xr.open_dataset(path, group="source_coords", engine=engine) as g: + source_coords = g.load() + with xr.open_dataset(path, group="target_coords", engine=engine) as g: + target_coords = g.load() + + return cls._from_state( + areas=_coo_from_components(row, col, data, (n_dst, n_src)), + source_coords=source_coords, + target_coords=target_coords, + src_dims=meta.src_dims, + dst_dims=meta.dst_dims, + src_shape=meta.src_shape, + dst_shape=meta.dst_shape, + x_coord=meta.x_coord, + y_coord=meta.y_coord, + spherical=meta.spherical, + ) + + @classmethod + def _from_state( + cls, + *, + areas: "sparse.COO | np.ndarray", + source_coords: xr.Dataset, + target_coords: xr.Dataset, + src_dims: tuple[Hashable, ...], + dst_dims: tuple[Hashable, ...], + src_shape: tuple[int, ...], + dst_shape: tuple[int, ...], + x_coord: str, + y_coord: str, + spherical: bool, + ) -> "ConservativeRegridder": + """Construct a regridder directly from its canonical state. Shared + bypass of ``__init__`` used by :meth:`from_netcdf` and + :meth:`from_polygons`; keeps the list of private attrs in one place.""" + instance = object.__new__(cls) + instance.x_coord = x_coord + instance.y_coord = y_coord + instance.spherical = spherical + instance._src_dims = src_dims + instance._dst_dims = dst_dims + instance._src_shape = src_shape + instance._dst_shape = dst_shape + instance._areas = areas + instance._source_coords = source_coords + instance._target_coords = target_coords + instance._fwd_weights = None + instance._bwd_weights = None + instance._fwd_apply = None + instance._bwd_apply = None + instance._fwd_coverage = None + instance._bwd_coverage = None + return instance + + @classmethod + def from_polygons( + cls, + source_polygons: np.ndarray, + target_polygons: np.ndarray, + source_dim: str = "cell", + target_dim: str = "cell", + target_coords: xr.Dataset | None = None, + n_threads: int | None = None, + ) -> "ConservativeRegridder": + """Build a regridder from explicit shapely polygon arrays. + + Use this for unstructured meshes (MPAS, ICON, finite-element), arbitrary + polygon targets (countries, watersheds), or any combination the + structured-grid path cannot express. + + Args: + source_polygons: 1D array of shapely Polygons for source cells. + target_polygons: 1D array of shapely Polygons for target cells. + source_dim: Name of the single dim carrying source cells on the + data passed to :meth:`regrid`. Default ``"cell"``. + target_dim: Name of the single dim carrying target cells on the + output. Default ``"cell"``. + target_coords: Optional xr.Dataset providing coordinate variables + along ``target_dim`` (and any auxiliary coords) to reattach on + the output. If None, the output is given a bare integer index. + n_threads: Thread count for GEOS intersection. + + Returns: + A ``ConservativeRegridder`` that accepts data with ``source_dim`` + in place of the structured grid's spatial dims. + + Intersection geometry is planar in the input polygons' coordinate + space. If the polygons represent lat/lon cells, project them into an + equal-area CRS first (or use the structured path with ``spherical=True``). + Polygons that cross the antimeridian must be split beforehand. + """ + _check_shapely() + src_polys = np.asarray(source_polygons) + dst_polys = np.asarray(target_polygons) + if src_polys.ndim != 1: + msg = "source_polygons must be a 1D array of shapely Polygons" + raise ValueError(msg) + if dst_polys.ndim != 1: + msg = "target_polygons must be a 1D array of shapely Polygons" + raise ValueError(msg) + + src_grid = _Grid( + polys=src_polys, bounds=shapely.bounds(src_polys), rectilinear=False, + ) + dst_grid = _Grid( + polys=dst_polys, bounds=shapely.bounds(dst_polys), rectilinear=False, + ) + n_src = int(src_polys.size) + n_dst = int(dst_polys.size) + tgt_ds = ( + target_coords + if target_coords is not None + else xr.Dataset(coords={target_dim: np.arange(n_dst)}) + ) + return cls._from_state( + areas=_build_intersection_areas(src_grid, dst_grid, n_threads=n_threads), + source_coords=xr.Dataset(coords={source_dim: np.arange(n_src)}), + target_coords=tgt_ds, + src_dims=(source_dim,), + dst_dims=(target_dim,), + src_shape=(n_src,), + dst_shape=(n_dst,), + x_coord="", + y_coord="", + spherical=False, + ) + + +def polygons_from_coords( + x: np.ndarray, + y: np.ndarray, + spherical: bool = False, +) -> np.ndarray: + """Build a 1D array of shapely cell polygons from 1D or 2D center coords. + + Convenience for mixing structured and unstructured regridding. E.g. build + target polygons for a regular lat/lon grid, then pass them together with + an unstructured source mesh to :meth:`ConservativeRegridder.from_polygons`. + + Args: + x: 1D or 2D array of cell-center x coordinates. + y: 1D or 2D array of cell-center y coordinates. + spherical: If True, apply a cylindrical equal-area projection to 1D + lat/lon (degrees) before building rectangles — matches the + structured-grid ``spherical=True`` path. + + Returns: + A 1D numpy array of shapely Polygons in row-major (y, x) order. + """ + _check_shapely() + x = np.asarray(x) + y = np.asarray(y) + if spherical: + if x.ndim != 1 or y.ndim != 1: + msg = "spherical=True requires 1D lat/lon arrays" + raise ValueError(msg) + return _build_cea_grid(x, y).polys + return _build_grid(x, y).polys + + +def polygon_conservative_regrid( + data: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + x_coord: str = "longitude", + y_coord: str = "latitude", + spherical: bool = False, + skipna: bool = True, + nan_threshold: float = 1.0, + n_threads: int | None = None, +) -> xr.DataArray | xr.Dataset: + """Conservative regridding via explicit 2D polygon intersection. + + One-shot convenience wrapper: constructs a :class:`ConservativeRegridder` + and applies it once. For repeated regridding to the same target, construct + a ``ConservativeRegridder`` directly and reuse it. + + Args: + data: Input data on the source grid. + target_ds: Dataset defining the target grid; must expose ``x_coord`` and + ``y_coord`` as (1D or 2D) coordinate variables. + x_coord: Name of the x (longitude-like) coordinate variable, shared + between source and target. + y_coord: Name of the y (latitude-like) coordinate variable, shared + between source and target. + spherical: If True, assume ``x_coord``/``y_coord`` are longitude/latitude + in degrees and project cells into Lambert cylindrical equal-area + space before intersecting — gives correct spherical area weights + at the same cost as the planar fast path. Rectilinear (1D coord) + grids only. + skipna: If True, propagate NaNs into the weighted mean via a two-pass + sum (values and valid mask), matching the ``conservative`` method. + nan_threshold: Keep output cells whose valid source fraction is at + least ``nan_threshold``. + n_threads: Thread count for parallel GEOS intersection. + + Returns: + Regridded data on the target grid, preserving non-spatial dims. + """ + regridder = ConservativeRegridder( + data, target_ds, + x_coord=x_coord, y_coord=y_coord, + spherical=spherical, n_threads=n_threads, + ) + return regridder.regrid(data, skipna=skipna, nan_threshold=nan_threshold) + + +def _apply_stored_weights( + data: xr.DataArray | xr.Dataset, + apply_weights: "sparse.COO | np.ndarray", + coverage: np.ndarray, + src_dims: tuple[Hashable, ...], + dst_dims: tuple[Hashable, ...], + src_shape: tuple[int, ...], + dst_shape: tuple[int, ...], + target_coords: xr.Dataset, + x_coord: str, + y_coord: str, + skipna: bool, + nan_threshold: float, +) -> xr.DataArray | xr.Dataset: + """Apply a cached, pre-transposed weight matrix to ``data`` via + ``xr.apply_ufunc``. + + ``apply_weights`` has shape ``(n_src, n_dst)`` so the matmul is + ``(..., n_src) @ (n_src, n_dst) → (..., n_dst)`` with no per-call transpose. + """ + # apply_ufunc(dask="parallelized") needs each core dim to be a single + # chunk. Only rechunk if data is already dask-backed — don't inadvertently + # dask-ify a numpy-backed input. + if getattr(data, "chunks", None) is not None: + split = { + d: -1 for d in src_dims + if d in data.dims and len(data.chunksizes.get(d, ())) > 1 + } + if split: + data = data.chunk(split) + + actual_src_shape = tuple(int(data.sizes[d]) for d in src_dims if d in data.sizes) + if actual_src_shape != src_shape: + msg = ( + f"source spatial shape {actual_src_shape} on dims {src_dims} does " + f"not match the regridder's expected shape {src_shape}" + ) + raise ValueError(msg) + + src_tokens = tuple(f"__src_{d}" for d in src_dims) + data_renamed = data.rename({s: t for s, t in zip(src_dims, src_tokens)}) + + output_dtype = _result_dtype(data) + result = xr.apply_ufunc( + _apply_core, + data_renamed, + kwargs={ + "apply_weights": apply_weights, + "coverage": coverage, + "coverage_all": bool(coverage.all()), + "src_shape": src_shape, + "dst_shape": dst_shape, + "skipna": skipna, + "nan_threshold": nan_threshold, + "output_dtype": output_dtype, + }, + input_core_dims=[list(src_tokens)], + output_core_dims=[list(dst_dims)], + exclude_dims=set(src_tokens), + dask="parallelized", + output_dtypes=[output_dtype], + dask_gufunc_kwargs={ + "output_sizes": {d: int(target_coords.sizes[d]) for d in dst_dims}, + "allow_rechunk": False, + }, + keep_attrs=True, + ) + + result = _assign_target_coords(result, target_coords, dst_dims, x_coord, y_coord) + return result + + +def _coverage_mask(areas: "sparse.COO | np.ndarray") -> np.ndarray: + """Return a boolean ``(n_dst,)`` mask of target cells with any source + overlap, derived from the raw area matrix.""" + if _HAS_SPARSE and isinstance(areas, sparse.COO): + # COO sparse: row_sum is 0 iff row has no nonzero entries. + n_dst = int(areas.shape[0]) + mask = np.zeros(n_dst, dtype=bool) + mask[areas.coords[0]] = True + return mask + arr = np.asarray(areas) + return (arr > 0).any(axis=1) + + +def _coo_components( + w: "sparse.COO | np.ndarray", +) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple[int, int]]: + if _HAS_SPARSE and isinstance(w, sparse.COO): + coords = np.asarray(w.coords) + return ( + coords[0].astype(np.int64, copy=False), + coords[1].astype(np.int64, copy=False), + np.asarray(w.data), + w.shape, + ) + arr = np.asarray(w) + rows, cols = np.nonzero(arr) + return rows.astype(np.int64, copy=False), cols.astype(np.int64, copy=False), arr[rows, cols], arr.shape + + +def _coo_from_components( + row: np.ndarray, + col: np.ndarray, + data: np.ndarray, + shape: tuple[int, int], +) -> "sparse.COO | np.ndarray": + if _HAS_SPARSE: + return sparse.COO( + coords=np.stack([row, col]), + data=data, + shape=shape, + has_duplicates=False, + sorted=False, + ) + dense = np.zeros(shape, dtype=data.dtype if data.size else np.float64) + dense[row, col] = data + return dense + + +def _coord_range( + ds: xr.Dataset, coord_name: str +) -> tuple[float, float] | None: + """Return ``(min, max)`` of a coord, or ``None`` when it isn't on the + Dataset (e.g., the integer-index stub emitted by ``from_polygons``).""" + if not coord_name or coord_name not in ds.coords: + return None + arr = np.asarray(ds[coord_name].values) + if arr.size == 0: + return None + return float(arr.min()), float(arr.max()) + + +def _transpose_weights( + w: "sparse.COO | np.ndarray", *, sort: bool = False +) -> "sparse.COO | np.ndarray": + """Materialize a transposed weight matrix. + + `sparse.COO.T` is a lazy view that re-sorts indices on each downstream + matmul, so callers relying on ``.coords[0]`` being row indices or wanting + a hot matmul path should materialize once here. Pass ``sort=True`` to + additionally trigger the sort ahead of time (used for the apply matrix). + """ + if _HAS_SPARSE and isinstance(w, sparse.COO): + t = w.T + out = sparse.COO( + coords=np.asarray(t.coords), + data=np.asarray(t.data), + shape=t.shape, + has_duplicates=False, + sorted=False, + ) + if sort: + out._sort_indices() + return out + return np.asarray(w).T.copy() + + +def _spatial_dims( + obj: xr.DataArray | xr.Dataset, x_coord: str, y_coord: str +) -> tuple[Hashable, ...]: + """Return the spatial dim order to feed to apply_ufunc. + + For 1D rectilinear coords (x and y ride on separate dims), canonicalize to + ``(y_dim, x_dim)`` so the flattened cell order matches the polygon order + emitted by the fast-path in ``_build_grid``. For curvilinear (2D) coords, + preserve the dim order already present on ``obj``. + """ + if x_coord not in obj.coords or y_coord not in obj.coords: + return () + xd = obj[x_coord].dims + yd = obj[y_coord].dims + if len(xd) == 1 and len(yd) == 1 and xd[0] != yd[0]: + return (yd[0], xd[0]) + dims = set(xd) | set(yd) + return tuple(d for d in obj.dims if d in dims) + + +def _grid_from_coords( + obj: xr.DataArray | xr.Dataset, + x_coord: str, + y_coord: str, + dims: tuple[Hashable, ...], + spherical: bool = False, +) -> "_Grid": + """Build a :class:`_Grid` from the object's x/y coordinates. + + If both coords are 1D and ride on separate dims, pass them as separate 1D + vectors to trigger the rectilinear fast path. Otherwise broadcast to a + common N-D array in ``dims`` order for the curvilinear path. + + If ``spherical`` is True, coordinates are assumed to be longitude (x) and + latitude (y) in degrees, and cells are projected into a Lambert cylindrical + equal-area space (x' = lon_rad, y' = sin(lat_rad)) before constructing the + cell polygons. This gives mass-conservative weights on the sphere at the + same cost as the planar fast path. Only supported for rectilinear (1D + coord) grids in this version. + """ + xd = obj[x_coord] + yd = obj[y_coord] + is_rectilinear = ( + xd.ndim == 1 and yd.ndim == 1 and xd.dims[0] != yd.dims[0] + ) + if spherical: + if not is_rectilinear: + msg = ( + "spherical=True is only supported for rectilinear (1D lat/lon) " + "coordinate arrays in this version." + ) + raise NotImplementedError(msg) + return _build_cea_grid( + np.asarray(xd.values), np.asarray(yd.values) + ) + if is_rectilinear: + return _build_grid(np.asarray(xd.values), np.asarray(yd.values)) + xc, yc = xr.broadcast(obj[x_coord], obj[y_coord]) + xc = xc.transpose(*dims) + yc = yc.transpose(*dims) + return _build_grid(np.asarray(xc.values), np.asarray(yc.values)) + + +def _build_cea_grid(lon_centers: np.ndarray, lat_centers: np.ndarray) -> "_Grid": + """Build a rectilinear :class:`_Grid` whose cell polygons are in Lambert + cylindrical equal-area coordinates (x' = lon_rad, y' = sin(lat_rad)). + + Projecting *edges* analytically — rather than projecting centers and then + re-midpointing — is required because ``sin()`` is nonlinear: the projected + midpoint of two lat centers is not the same as the midpoint of two + projected lat edges. + """ + _check_shapely() + if lon_centers.size < 2 or lat_centers.size < 2: + msg = "spherical mode requires at least two cells per dimension" + raise ValueError(msg) + lat_edges_deg = np.clip(_infer_1d_edges(lat_centers), -90.0, 90.0) + lon_edges_deg = _infer_1d_edges(lon_centers) + return _rect_grid_from_edges( + np.deg2rad(lon_edges_deg), + np.sin(np.deg2rad(lat_edges_deg)), + ) + + +def _infer_1d_edges(centers: np.ndarray) -> np.ndarray: + """Return cell edges from 1D centers: midpoints between consecutive + centers, with symmetric reflection for the two outer bounds.""" + c = np.asarray(centers, dtype=float) + if c.size < 2: + msg = "need at least two centers to infer cell edges" + raise ValueError(msg) + mids = 0.5 * (c[:-1] + c[1:]) + left = 2 * c[0] - mids[0] + right = 2 * c[-1] - mids[-1] + return np.concatenate([[left], mids, [right]]) + + +def _infer_2d_corners( + xc: np.ndarray, yc: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Return (ny+1, nx+1) cell-corner arrays from 2D cell-center arrays. + + Interior corners are the mean of the four surrounding centers; boundary + corners are reflected from the adjacent interior row/column. + """ + assert xc.shape == yc.shape and xc.ndim == 2 + + def corners(a: np.ndarray) -> np.ndarray: + ny, nx = a.shape + pad = np.empty((ny + 2, nx + 2), dtype=a.dtype) + pad[1:-1, 1:-1] = a + pad[0, 1:-1] = 2 * a[0, :] - a[1, :] + pad[-1, 1:-1] = 2 * a[-1, :] - a[-2, :] + pad[1:-1, 0] = 2 * a[:, 0] - a[:, 1] + pad[1:-1, -1] = 2 * a[:, -1] - a[:, -2] + pad[0, 0] = 2 * pad[0, 1] - pad[0, 2] + pad[0, -1] = 2 * pad[0, -2] - pad[0, -3] + pad[-1, 0] = 2 * pad[-1, 1] - pad[-1, 2] + pad[-1, -1] = 2 * pad[-1, -2] - pad[-1, -3] + return 0.25 * (pad[:-1, :-1] + pad[1:, :-1] + pad[:-1, 1:] + pad[1:, 1:]) + + return corners(xc.astype(float)), corners(yc.astype(float)) + + +@dataclass +class _Grid: + """Cached cell geometry for a structured grid. + + ``polys`` is a flat (n_cells,) object array of shapely Polygons. + ``bounds`` is a (n_cells, 4) ``(minx, miny, maxx, maxy)`` array cached for + the STRtree / candidate-search path. ``rectilinear`` is True when both the + source x and y were 1D coordinate arrays (axis-aligned rectangles) — the + weight builder uses this to skip GEOS polygon clipping and compute + intersection areas analytically from the bounds. + """ + + polys: np.ndarray + bounds: np.ndarray + rectilinear: bool + + +def _rect_grid_from_edges(xe: np.ndarray, ye: np.ndarray) -> _Grid: + """Build a rectilinear :class:`_Grid` from already-prepared edge arrays. + + Shared by the raw-planar (:func:`_build_grid` 1D branch) and the analytic + equal-area (:func:`_build_cea_grid`) paths. + """ + x0, y0 = np.meshgrid(xe[:-1], ye[:-1], indexing="xy") + x1, y1 = np.meshgrid(xe[1:], ye[1:], indexing="xy") + x0f, y0f, x1f, y1f = x0.ravel(), y0.ravel(), x1.ravel(), y1.ravel() + polys = shapely.box(x0f, y0f, x1f, y1f) + bounds = np.stack([x0f, y0f, x1f, y1f], axis=1) + return _Grid(polys=polys, bounds=bounds, rectilinear=True) + + +def _build_grid(xc: np.ndarray, yc: np.ndarray) -> _Grid: + """Return a _Grid of cell geometry for a structured grid. + + Accepts 1D (rectilinear, separate x and y vectors) or 2D (curvilinear, + co-shaped center arrays) inputs. Output order is row-major in the input dim + order: for 2D inputs of shape (ny, nx) the polygons correspond to cells + reshaped as ``(ny, nx)``. + """ + _check_shapely() + if xc.ndim == 1 and yc.ndim == 1: + xe = _infer_1d_edges(xc.astype(float)) + ye = _infer_1d_edges(yc.astype(float)) + return _rect_grid_from_edges(xe, ye) + if xc.ndim == 2 and yc.ndim == 2: + xcorn, ycorn = _infer_2d_corners(xc, yc) + ny, nx = xc.shape + c00 = np.stack([xcorn[:-1, :-1], ycorn[:-1, :-1]], axis=-1) + c10 = np.stack([xcorn[:-1, 1:], ycorn[:-1, 1:]], axis=-1) + c11 = np.stack([xcorn[1:, 1:], ycorn[1:, 1:]], axis=-1) + c01 = np.stack([xcorn[1:, :-1], ycorn[1:, :-1]], axis=-1) + rings = np.stack([c00, c10, c11, c01, c00], axis=2).reshape(ny * nx, 5, 2) + polys = shapely.polygons(rings) + return _Grid(polys=polys, bounds=shapely.bounds(polys), rectilinear=False) + msg = "x and y coordinate arrays must both be 1D or both 2D" + raise ValueError(msg) + + +def _build_intersection_areas( + src: _Grid, dst: _Grid, n_threads: int | None = None +) -> "sparse.COO | np.ndarray": + """Build the (n_dst, n_src) raw area-intersection matrix ``A[i, j] = + area(dst_i ∩ src_j)``. + + This is the unnormalized matrix. Row-normalize via :func:`_row_normalize` + to get forward weights; transpose first for backward (target → source). + + When both grids are rectilinear (axis-aligned rectangles) intersection + areas are computed analytically from the bounds, skipping GEOS clipping. + """ + _check_shapely() + n_dst = int(len(dst.polys)) + n_src = int(len(src.polys)) + + tree = STRtree(src.polys) + pairs = tree.query(dst.polys, predicate="intersects") + dst_idx = np.asarray(pairs[0]) + src_idx = np.asarray(pairs[1]) + + if dst_idx.size == 0: + return _empty_weights(n_dst, n_src) + + if src.rectilinear and dst.rectilinear: + sb = src.bounds[src_idx] + db = dst.bounds[dst_idx] + dx = np.minimum(sb[:, 2], db[:, 2]) - np.maximum(sb[:, 0], db[:, 0]) + dy = np.minimum(sb[:, 3], db[:, 3]) - np.maximum(sb[:, 1], db[:, 1]) + areas = np.maximum(dx, 0.0) * np.maximum(dy, 0.0) + else: + areas = _intersection_areas_threaded( + dst.polys[dst_idx], src.polys[src_idx], n_threads=n_threads + ) + + keep = areas > 0 + dst_idx = dst_idx[keep] + src_idx = src_idx[keep] + areas = areas[keep] + + if dst_idx.size == 0: + return _empty_weights(n_dst, n_src) + + if _HAS_SPARSE: + return sparse.COO( + coords=np.stack([dst_idx, src_idx]), + data=areas.astype(np.float64), + shape=(n_dst, n_src), + has_duplicates=False, + sorted=False, + ) + a_dense = np.zeros((n_dst, n_src), dtype=np.float64) + a_dense[dst_idx, src_idx] = areas + return a_dense + + +def _row_normalize( + areas: "sparse.COO | np.ndarray", +) -> "sparse.COO | np.ndarray": + """Normalize rows of an area matrix so each row sums to 1 (rows with no + overlap stay all-zero, which produces NaN output under the apply path).""" + if _HAS_SPARSE and isinstance(areas, sparse.COO): + n_dst = areas.shape[0] + dst_idx = areas.coords[0] + src_idx = areas.coords[1] + data = areas.data + if data.size == 0: + return areas + row_sum = np.bincount(dst_idx, weights=data, minlength=n_dst) + new_data = data / row_sum[dst_idx] + return sparse.COO( + coords=np.stack([dst_idx, src_idx]), + data=new_data, + shape=areas.shape, + has_duplicates=False, + sorted=False, + ) + row_sum = areas.sum(axis=1, keepdims=True) + row_sum = np.where(row_sum == 0, 1.0, row_sum) + return areas / row_sum + + +def _intersection_areas_threaded( + a: np.ndarray, b: np.ndarray, n_threads: int | None +) -> np.ndarray: + """Compute per-pair intersection areas ``area(a[i] & b[i])`` over numpy + arrays of shapely geometries, optionally parallelized via threads. + + Shapely 2.x releases the GIL for GEOS ops, so a ``ThreadPoolExecutor`` + gives near-linear speedup on multi-core machines without pickling data. + """ + _check_shapely() + n = len(a) + if n_threads is None: + n_threads = min(os.cpu_count() or 1, 8) + # Amortize thread-pool overhead only when there's meaningful work. + if n < 50_000: + n_threads = 1 + if n_threads <= 1 or n == 0: + return shapely.area(shapely.intersection(a, b)) + + splits = np.array_split(np.arange(n), n_threads) + + def _work(idx: np.ndarray) -> np.ndarray: + return shapely.area(shapely.intersection(a[idx], b[idx])) + + with ThreadPoolExecutor(max_workers=n_threads) as pool: + parts = list(pool.map(_work, splits)) + return np.concatenate(parts) + + +def _empty_weights(n_dst: int, n_src: int) -> "sparse.COO | np.ndarray": + if _HAS_SPARSE: + return sparse.COO( + coords=np.zeros((2, 0), dtype=np.int64), + data=np.zeros(0, dtype=np.float64), + shape=(n_dst, n_src), + ) + return np.zeros((n_dst, n_src), dtype=np.float64) + + +def _apply_core( + arr: np.ndarray, + apply_weights: Any, + coverage: np.ndarray, + coverage_all: bool, + src_shape: tuple[int, ...], + dst_shape: tuple[int, ...], + skipna: bool, + nan_threshold: float, + output_dtype: np.dtype, +) -> np.ndarray: + """Apply a pre-transposed weight matrix along the trailing spatial dims. + + ``arr`` has shape ``(..., *src_shape)``; ``apply_weights`` has shape + ``(n_src, n_dst)`` — returns ``(..., *dst_shape)``. + + ``coverage`` is a boolean ``(n_dst,)`` mask: target cells with any source + overlap. ``coverage_all`` is precomputed so every block skips the + ``coverage.all()`` scan. Uncovered cells are always masked to NaN in the + output, regardless of ``skipna`` — domain boundaries and polygon holes + produce NaN (matches the axis-factored ``conservative`` method). + """ + n_spatial = len(src_shape) + leading_shape = arr.shape[:-n_spatial] if n_spatial > 0 else arr.shape + n_src = int(np.prod(src_shape)) + flat = arr.reshape(-1, n_src) if leading_shape else arr.reshape(1, n_src) + + if skipna and np.issubdtype(flat.dtype, np.floating): + nan_mask = np.isnan(flat) + has_nan = nan_mask.any() + else: + has_nan = False + + if has_nan: + mask = (~nan_mask).astype(flat.dtype) + filled = np.where(nan_mask, flat.dtype.type(0.0), flat) + numerator = _matmul_last(filled, apply_weights) + fraction = _matmul_last(mask, apply_weights) + threshold = 1.0 - min(max(nan_threshold, 1e-6), 1.0 - 1e-6) + with np.errstate(invalid="ignore", divide="ignore"): + result = numerator / fraction + result = np.where(fraction >= threshold, result, np.nan) + else: + result = _matmul_last(flat, apply_weights) + if not coverage_all: + result = np.where(coverage[np.newaxis, :], result, np.nan) + + # sparse.matmul promotes to float64 regardless of the input dtype — cast + # back to the requested output dtype so float32-in really produces + # float32-out (halves memory for float32 pipelines). + if result.dtype != output_dtype: + result = result.astype(output_dtype, copy=False) + + out_shape = (*leading_shape, *dst_shape) if leading_shape else dst_shape + return result.reshape(out_shape) + + +def _matmul_last(flat: np.ndarray, apply_weights: Any) -> np.ndarray: + """Compute ``flat @ apply_weights`` where ``flat`` is dense ``(N, n_src)`` + and ``apply_weights`` is ``(n_src, n_dst)`` (dense or pre-sorted sparse).""" + if _HAS_SPARSE and isinstance(apply_weights, sparse.COO): + return np.asarray(sparse.matmul(flat, apply_weights)) + return flat @ apply_weights + + +def _result_dtype(obj: xr.DataArray | xr.Dataset) -> np.dtype: + if isinstance(obj, xr.DataArray): + return np.result_type(np.float32, obj.dtype) + dtypes = [v.dtype for v in obj.data_vars.values()] + if not dtypes: + return np.dtype(np.float64) + return np.result_type(np.float32, *dtypes) + + +def _assign_target_coords( + obj: xr.DataArray | xr.Dataset, + target_ds: xr.Dataset, + dst_dims: tuple[Hashable, ...], + x_coord: str, + y_coord: str, +) -> xr.DataArray | xr.Dataset: + """Attach the target dataset's dim coords and auxiliary lat/lon coords.""" + new_coords: dict[Hashable, Any] = {} + for d in dst_dims: + if d in target_ds.coords: + new_coords[d] = target_ds[d] + for name in (x_coord, y_coord): + if name in target_ds.coords and name not in new_coords: + new_coords[name] = target_ds[name] + if new_coords: + obj = obj.assign_coords(new_coords) + return obj diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index b2ed389..70ab5fa 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -4,7 +4,13 @@ import numpy as np import xarray as xr -from xarray_regrid.methods import conservative, flox_reduce, interp +from xarray_regrid.methods import ( + conservative, + conservative_2d, + flox_reduce, + interp, +) +from xarray_regrid.methods.conservative_2d import Manifold from xarray_regrid.utils import format_for_regrid @@ -17,7 +23,11 @@ class Regridder: linear: linear, bilinear, or higher dimensional linear interpolation nearest: nearest-neighbor regridding cubic: cubic spline regridding - conservative: conservative regridding + conservative: axis-factored conservative regridding (rectilinear, + 1D-separable grids only) + conservative_2d: conservative regridding for grids that aren't + 1D-separable — curvilinear 2D coords, unstructured meshes, or + arbitrary polygon-to-polygon aggregation (requires shapely) most_common: most common value regridder stat: area statistics regridder """ @@ -82,6 +92,93 @@ def cubic( ds_formatted = format_for_regrid(self._obj, ds_target_grid) return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic") + def conservative_2d( + self, + ds_target_grid: xr.Dataset, + x_coord: str = "longitude", + y_coord: str = "latitude", + manifold: Manifold = "planar", + spherical: bool | None = None, + time_dim: str | None = "time", + skipna: bool = True, + nan_threshold: float = 1.0, + n_threads: int | None = None, + ) -> xr.DataArray | xr.Dataset: + """Conservative regrid for grids that aren't 1D-separable. + + Use this when ``.conservative`` can't express your grid: curvilinear + coordinates (2D ``lat``/``lon`` arrays), unstructured meshes, or any + arbitrary polygon target. Computes 2D cell-polygon intersections via + shapely. The ``manifold`` kwarg controls the intersection geometry: + ``"planar"`` (default) does raw planar shapely; ``"cea"`` projects + 1D lat/lon to a cylindrical equal-area space for correct spherical + areas at the planar fast-path's cost; ``"s2"`` uses the optional + ``spherely`` dependency for true great-circle edges. Requires + ``shapely >= 2.0``. + + Args: + ds_target_grid: Dataset defining the target grid; must expose + ``x_coord`` and ``y_coord`` as coordinate variables. + x_coord: Name of the x (longitude-like) coordinate variable. + y_coord: Name of the y (latitude-like) coordinate variable. + manifold: ``"planar"``, ``"cea"``, or ``"s2"``. See above. + spherical: Backward-compatible alias for manifold selection. + ``True`` maps to ``"cea"`` and ``False`` maps to ``"planar"``. + time_dim: Name of the time dimension. Defaults to ``"time"``. Use + ``None`` to force regridding over the time dimension. + skipna: If True, propagate NaNs into the weighted mean. + nan_threshold: Keep output cells whose valid source fraction is at + least ``nan_threshold``. + n_threads: Thread count for parallel GEOS intersection (planar + path only). + + Returns: + Data regridded to the target grid. + """ + if not 0.0 <= nan_threshold <= 1.0: + msg = "nan_threshold must be between [0, 1]" + raise ValueError(msg) + if spherical is not None: + manifold = "cea" if spherical else "planar" + + # Skip validate_input's dim-match check: curvilinear targets are + # matched by coord values, not by dim name. + if time_dim is not None and time_dim in ds_target_grid.coords: + ds_target_grid = ds_target_grid.isel({time_dim: 0}).reset_coords() + return conservative_2d.conservative_2d_regrid( + self._obj, + ds_target_grid, + x_coord=x_coord, + y_coord=y_coord, + manifold=manifold, + skipna=skipna, + nan_threshold=nan_threshold, + n_threads=n_threads, + ) + + def conservative_polygon( + self, + ds_target_grid: xr.Dataset, + x_coord: str = "longitude", + y_coord: str = "latitude", + spherical: bool = False, + time_dim: str | None = "time", + skipna: bool = True, + nan_threshold: float = 1.0, + n_threads: int | None = None, + ) -> xr.DataArray | xr.Dataset: + """Backward-compatible alias for :meth:`conservative_2d`.""" + return self.conservative_2d( + ds_target_grid, + x_coord=x_coord, + y_coord=y_coord, + manifold="cea" if spherical else "planar", + time_dim=time_dim, + skipna=skipna, + nan_threshold=nan_threshold, + n_threads=n_threads, + ) + def conservative( self, ds_target_grid: xr.Dataset, diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index 6979f84..7253cca 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -80,7 +80,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]: grid.south, grid.north + grid.resolution_lat, grid.resolution_lat ) - if np.remainder((grid.east - grid.west), grid.resolution_lat) > 0: + if np.remainder((grid.east - grid.west), grid.resolution_lon) > 0: lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon) else: lon_coords = np.arange( @@ -114,6 +114,22 @@ def create_regridding_dataset( ) +def infer_1d_edges(centers: np.ndarray) -> np.ndarray: + """Return cell edges from 1D centers: midpoints between consecutive + centers, with symmetric reflection for the two outer bounds. + + Requires at least two centers. + """ + c = np.asarray(centers, dtype=float) + if c.size < 2: + msg = "need at least two centers to infer cell edges" + raise ValueError(msg) + mids = 0.5 * (c[:-1] + c[1:]) + left = 2 * c[0] - mids[0] + right = 2 * c[-1] - mids[-1] + return np.concatenate([[left], mids, [right]]) + + def to_intervalindex(coords: np.ndarray) -> pd.IntervalIndex: """Convert a 1-d coordinate array to a pandas IntervalIndex. Take the midpoints between the coordinates as the interval boundaries. @@ -126,20 +142,9 @@ def to_intervalindex(coords: np.ndarray) -> pd.IntervalIndex: coordinates. """ if len(coords) > 1: - midpoints = (coords[:-1] + coords[1:]) / 2 - - # Extrapolate outer bounds beyond the first and last coordinates - left_bound = 2 * coords[0] - midpoints[0] - right_bound = 2 * coords[-1] - midpoints[-1] - - breaks = np.concatenate([[left_bound], midpoints, [right_bound]]) - intervals = pd.IntervalIndex.from_breaks(breaks) - - else: - # If the target grid has a single point, set search interval to span all space - intervals = pd.IntervalIndex.from_breaks([-np.inf, np.inf]) - - return intervals + return pd.IntervalIndex.from_breaks(infer_1d_edges(coords)) + # If the target grid has a single point, set search interval to span all space + return pd.IntervalIndex.from_breaks([-np.inf, np.inf]) def overlap(a: pd.IntervalIndex, b: pd.IntervalIndex) -> np.ndarray: @@ -436,8 +441,8 @@ def update_coord( def update_coord( obj: xr.DataArray | xr.Dataset, coord: Hashable, coord_vals: np.ndarray ) -> xr.DataArray | xr.Dataset: - """Update the values of a coordinate, ensuring indexes stay in sync.""" - attrs = obj.coords[coord].attrs - obj = obj.assign_coords({coord: coord_vals}) - obj.coords[coord].attrs = attrs - return obj + """Update the values of a coordinate, ensuring indexes stay in sync. + Preserves the coord's existing dims and attrs (so multi-dim coords work).""" + original = obj.coords[coord] + new_coord = xr.DataArray(coord_vals, dims=original.dims, attrs=original.attrs) + return obj.assign_coords({coord: new_coord}) diff --git a/tests/test_conservative_2d.py b/tests/test_conservative_2d.py new file mode 100644 index 0000000..fd0603b --- /dev/null +++ b/tests/test_conservative_2d.py @@ -0,0 +1,796 @@ +"""Tests for conservative_2d.""" + +import numpy as np +import pytest +import xarray as xr + +import xarray_regrid # noqa: F401 (registers the accessor) +from xarray_regrid import ConservativeRegridder, polygons_from_coords + +shapely = pytest.importorskip("shapely") + + +def _rect_da(ny=60, nx=120, nt=2, seed=0): + x = np.linspace(-180, 180, nx, endpoint=False) + 180 / nx + y = np.linspace(-90, 90, ny, endpoint=False) + 90 / ny + rng = np.random.default_rng(seed) + return xr.DataArray( + rng.normal(size=(nt, ny, nx)).astype(np.float64), + dims=("time", "y", "x"), + coords={"time": np.arange(nt), "y": y, "x": x}, + name="var", + ) + + +def _rect_target(ny=24, nx=47): + x = np.linspace(-180, 180, nx, endpoint=False) + 180 / nx + y = np.linspace(-90, 90, ny, endpoint=False) + 90 / ny + return xr.Dataset(coords={"y": y, "x": x}) + + +def test_polygon_matches_factored_planar(): + """On a rectilinear grid with no spherical correction, the polygon path + should reproduce the axis-factored path to machine precision.""" + da = _rect_da() + target = _rect_target() + ref = da.regrid.conservative(target, latitude_coord=None) + got = da.regrid.conservative_2d(target, x_coord="x", y_coord="y") + got = got.transpose(*ref.dims) + np.testing.assert_allclose(got.values, ref.values, atol=1e-12) + + +def test_polygon_dask_time_chunks(): + da = _rect_da(nt=4).chunk({"time": 2}) + target = _rect_target() + got = da.regrid.conservative_2d(target, x_coord="x", y_coord="y") + assert got.chunks is not None + got = got.compute() + ref = _rect_da(nt=4).regrid.conservative(target, latitude_coord=None) + np.testing.assert_allclose(got.transpose(*ref.dims).values, ref.values, atol=1e-12) + + +def test_polygon_rechunks_spatial(): + """Spatially-chunked input should be accepted (rechunked internally).""" + da = _rect_da().chunk({"time": 1, "y": 30, "x": 40}) + target = _rect_target() + out = da.regrid.conservative_2d(target, x_coord="x", y_coord="y") + out.compute() + + +def test_polygon_nan_threshold(): + """Stricter nan_threshold produces more NaN output cells when partial + overlaps exist.""" + da = _rect_da() + da.values[:, 21:29, :] = np.nan + target = _rect_target(ny=23, nx=47) + out1 = da.regrid.conservative_2d( + target, x_coord="x", y_coord="y", skipna=True, nan_threshold=1.0 + ) + out0 = da.regrid.conservative_2d( + target, x_coord="x", y_coord="y", skipna=True, nan_threshold=0.0 + ) + assert int(np.isnan(out0.values).sum()) > int(np.isnan(out1.values).sum()) + + +def test_polygon_curvilinear_target(): + """Curvilinear target (2D lat/lon corners) returns finite values.""" + da = _rect_da() + ny_t, nx_t = 20, 30 + xi, yi = np.meshgrid( + np.linspace(-120, 120, nx_t), + np.linspace(-60, 60, ny_t), + indexing="xy", + ) + th = np.deg2rad(30) + x2d = xi * np.cos(th) - yi * np.sin(th) + y2d = xi * np.sin(th) + yi * np.cos(th) + target = xr.Dataset(coords={"x": (("ny", "nx"), x2d), "y": (("ny", "nx"), y2d)}) + out = da.regrid.conservative_2d(target, x_coord="x", y_coord="y") + assert out.shape == (2, 20, 30) + assert np.isfinite(out.values).mean() > 0.9 + + +def test_antimeridian_rectilinear_constant(): + da = xr.DataArray( + np.full((2, 4), 2.5), + dims=("latitude", "longitude"), + coords={ + "latitude": np.array([-2.5, 2.5]), + "longitude": np.array([167.5, 172.5, -177.5, -172.5]), + }, + ) + target = xr.Dataset( + coords={ + "latitude": np.array([-2.5, 2.5]), + "longitude": np.array([170.0, -170.0]), + } + ) + + out = da.regrid.conservative_2d(target, x_coord="longitude", y_coord="latitude") + np.testing.assert_allclose(out.values, 2.5, atol=1e-12) + + +def test_cross_convention_longitude_alignment(): + """Source on [0, 360] with target on [-180, 180] (and vice versa) must + align — a uniform shift can't reconcile the two conventions, so per-value + wrap of source longitudes is required. Regression: previously yielded + NaN on half the target cells because banker's rounding on the exact-180° + mean diff produced a zero offset.""" + src_vals = np.array([1.0, 2.0, 3.0, 4.0]) + tgt_vals_neg = np.array([-135.0, -45.0, 45.0, 135.0]) + src_vals_neg_x = np.array([45.0, 135.0, 225.0, 315.0]) + da = xr.DataArray( + np.broadcast_to(src_vals, (2, 4)).copy(), + dims=("latitude", "longitude"), + coords={"latitude": [-30.0, 30.0], "longitude": src_vals_neg_x}, + ) + target = xr.Dataset(coords={"latitude": [-30.0, 30.0], "longitude": tgt_vals_neg}) + expected = da.regrid.conservative(target).transpose("latitude", "longitude") + out_planar = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude" + ).transpose("latitude", "longitude") + out_spherical = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude", spherical=True + ).transpose("latitude", "longitude") + np.testing.assert_allclose(out_planar.values, expected.values, atol=1e-12) + np.testing.assert_allclose(out_spherical.values, expected.values, atol=1e-12) + + # Reverse: source on [-180, 180], target on [0, 360]. + da_rev = xr.DataArray( + np.broadcast_to(src_vals, (2, 4)).copy(), + dims=("latitude", "longitude"), + coords={"latitude": [-30.0, 30.0], "longitude": tgt_vals_neg}, + ) + target_rev = xr.Dataset( + coords={"latitude": [-30.0, 30.0], "longitude": src_vals_neg_x} + ) + expected_rev = da_rev.regrid.conservative(target_rev).transpose( + "latitude", "longitude" + ) + out_rev = da_rev.regrid.conservative_2d( + target_rev, x_coord="longitude", y_coord="latitude" + ).transpose("latitude", "longitude") + np.testing.assert_allclose(out_rev.values, expected_rev.values, atol=1e-12) + + +def test_polygon_nan_threshold_invalid(): + da = _rect_da() + with pytest.raises(ValueError): + da.regrid.conservative_2d( + _rect_target(), x_coord="x", y_coord="y", nan_threshold=1.5 + ) + + +def test_polygon_dataset_input(): + """A Dataset input with multiple variables should regrid all of them.""" + da = _rect_da() + ds = xr.Dataset({"a": da, "b": da * 2.0}) + target = _rect_target() + out = ds.regrid.conservative_2d(target, x_coord="x", y_coord="y") + assert set(out.data_vars) == {"a", "b"} + np.testing.assert_allclose( + out["b"].transpose(*out["a"].dims).values, + out["a"].transpose(*out["a"].dims).values * 2.0, + atol=1e-12, + ) + + +# --- ConservativeRegridder (reusable) ------------------------------------------ + + +def test_regridder_reusable_matches_oneshot(): + """Reusing a single ConservativeRegridder on multiple fields matches the + one-shot `conservative_2d_regrid` call.""" + da1 = _rect_da(seed=1) + da2 = _rect_da(seed=2) + target = _rect_target() + regridder = ConservativeRegridder(da1, target, x_coord="x", y_coord="y") + ref1 = da1.regrid.conservative_2d(target, x_coord="x", y_coord="y") + ref2 = da2.regrid.conservative_2d(target, x_coord="x", y_coord="y") + out1 = regridder.regrid(da1) + out2 = regridder(da2) # __call__ alias + np.testing.assert_allclose(out1.values, ref1.values, atol=1e-12) + np.testing.assert_allclose(out2.values, ref2.values, atol=1e-12) + + +def test_regridder_weight_cache(): + """Forward weight matrix is built lazily and then reused across calls.""" + da = _rect_da() + target = _rect_target() + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + assert "forward_weights" not in regridder.__dict__ + regridder.regrid(da) + w1 = regridder.forward_weights + regridder.regrid(da) + assert regridder.forward_weights is w1 # same object, not rebuilt + + +def test_regridder_transpose_roundtrip_rectilinear_aligned(): + """If target cells are aligned unions of source cells, forward followed by + backward reproduces a constant field exactly.""" + # 120 source cells along each axis, target is a 4x coarsening (exact union + # of source cells). A constant source field survives the roundtrip to + # itself because every source cell is fully covered. + ns = 120 + nt = 30 # exactly ns / 4 + x_s = np.linspace(-180, 180, ns, endpoint=False) + 180 / ns + y_s = np.linspace(-90, 90, ns, endpoint=False) + 90 / ns + x_t = np.linspace(-180, 180, nt, endpoint=False) + 180 / nt + y_t = np.linspace(-90, 90, nt, endpoint=False) + 90 / nt + da = xr.DataArray( + np.full((ns, ns), 3.5, dtype=np.float64), + dims=("y", "x"), + coords={"y": y_s, "x": x_s}, + ) + target = xr.Dataset(coords={"y": y_t, "x": x_t}) + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + coarse = regridder.regrid(da) + back = regridder.T.regrid(coarse) + # Every value should match the original constant. + np.testing.assert_allclose(back.values, 3.5, atol=1e-12) + + +def test_regridder_T_preserves_weights(): # noqa: N802 + """regridder.T.T should share the raw area matrix with the original.""" + da = _rect_da() + target = _rect_target() + r = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + rr = r.T.T + # Same shape, same coords, same data. + assert r.areas.shape == rr.areas.shape + if hasattr(r.areas, "data"): + np.testing.assert_array_equal(r.areas.data, rr.areas.data) + + +def test_regridder_shape_mismatch_raises(): + """Applying the regridder to data whose spatial shape differs from the + source it was built for should raise.""" + da = _rect_da() + target = _rect_target() + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + smaller = _rect_da(ny=30, nx=60) + with pytest.raises(ValueError, match="spatial shape"): + regridder.regrid(smaller) + + +def test_cea_mode_matches_factored(): + """Polygon path with ``manifold='cea'`` should match the axis-factored + sin-weighted conservative path to within a tight tolerance on lat/lon + grids (both are analytically equivalent for cylindrical equal-area).""" + lon_s = np.linspace(-180, 180, 180, endpoint=False) + 1.0 + lat_s = np.linspace(-90, 90, 90, endpoint=False) + 1.0 + lon_t = np.linspace(-180, 180, 60, endpoint=False) + 3.0 + lat_t = np.linspace(-90, 90, 30, endpoint=False) + 3.0 + vals = np.cos(np.deg2rad(lat_s))[:, None] ** 2 * np.sin(np.deg2rad(lon_s))[None, :] + da = xr.DataArray( + vals, + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + factored = da.regrid.conservative(target, latitude_coord="latitude") + polygon = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude", manifold="cea" + ) + # Both methods should agree to the grid's own quadrature accuracy. Near the + # poles the factored path's median-dlat approximation introduces a small + # discrepancy; 1e-3 absolute is well below the planar raw error floor + # (~1e-2 at this resolution, tested separately). + np.testing.assert_allclose( + polygon.transpose(*factored.dims).values, + factored.values, + atol=1e-3, + ) + + +def test_cea_conserves_integral(): + """Mass conservation check on the sphere. For cos^2(lat), true integral is + 8*pi/3; the regridder on a 2-to-6-degree grid keeps the + spherical-area-weighted sum within the grid quadrature floor with + manifold='cea', and misses it by ~17x more with manifold='planar'.""" + lon_s = np.linspace(-180, 180, 180, endpoint=False) + 1.0 + lat_s = np.linspace(-90, 90, 90, endpoint=False) + 1.0 + lon_t = np.linspace(-180, 180, 60, endpoint=False) + 3.0 + lat_t = np.linspace(-90, 90, 30, endpoint=False) + 3.0 + da = xr.DataArray( + np.cos(np.deg2rad(lat_s))[:, None] ** 2 * np.ones(lon_s.size)[None, :], + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + out_sph = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude", manifold="cea" + ) + out_raw = da.regrid.conservative_2d( + target, x_coord="longitude", y_coord="latitude", manifold="planar" + ) + + # True target spherical cell areas + dlon_arr = np.full(lon_t.size, np.deg2rad(np.mean(np.diff(lon_t)))) + lat_r = np.deg2rad(lat_t) + dlat_r = np.gradient(lat_r) + dlat_bands = np.sin(lat_r + dlat_r / 2) - np.sin(lat_r - dlat_r / 2) + a_tgt = dlat_bands[:, None] * dlon_arr[None, :] + + true_val = 8 * np.pi / 3 + sph_vals = out_sph.transpose("latitude", "longitude").values + raw_vals = out_raw.transpose("latitude", "longitude").values + err_sph = abs(float((sph_vals * a_tgt).sum()) - true_val) + err_raw = abs(float((raw_vals * a_tgt).sum()) - true_val) + # Spherical should be at least 10x more accurate than raw planar here. + assert err_sph < 0.1 * err_raw, f"err_sph={err_sph:.2e} err_raw={err_raw:.2e}" + + +# --- from_polygons (unstructured mesh) ---------------------------------------- + + +def _box_polygons(): + rng = np.random.default_rng(1) + n = 50 + cx = rng.uniform(-170, 170, n) + cy = rng.uniform(-80, 80, n) + return shapely.box(cx - 5, cy - 5, cx + 5, cy + 5) + + +def test_polygons_from_coords_periodic(): + polys = polygons_from_coords( + np.array([167.5, 172.5, -177.5, -172.5]), + np.array([-2.5, 2.5]), + periodic=True, + ) + bounds = shapely.bounds(polys) + widths = bounds[:, 2] - bounds[:, 0] + assert np.all(widths < 10.1) + + +def test_from_polygons_basic(): + src_polys = _box_polygons() + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 30, endpoint=False) + 6, + np.linspace(-90, 90, 15, endpoint=False) + 6, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + da = xr.DataArray( + np.arange(src_polys.size, dtype=np.float64), + dims=("face",), + ) + out = rgr.regrid(da) + assert out.dims == ("cell",) + assert out.sizes["cell"] == tgt_polys.size + + +def test_from_polygons_attaches_target_aux_coords(): + src_polys = _box_polygons() + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 12, endpoint=False) + 15, + np.linspace(-90, 90, 6, endpoint=False) + 15, + ) + region_id = np.arange(tgt_polys.size) + 100 + target_coords = xr.Dataset( + coords={ + "cell": np.arange(tgt_polys.size), + "region_id": ("cell", region_id), + } + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, + tgt_polys, + source_dim="face", + target_dim="cell", + target_coords=target_coords, + ) + da = xr.DataArray(np.arange(src_polys.size, dtype=np.float64), dims=("face",)) + out = rgr.regrid(da) + assert "region_id" in out.coords + np.testing.assert_array_equal(out["region_id"].values, region_id) + + +def test_from_polygons_periodic_antimeridian(): + src_polys = np.array( + [shapely.Polygon([(175, -5), (-175, -5), (-175, 5), (175, 5)])], + dtype=object, + ) + tgt_polys = np.array( + [ + shapely.box(160, -5, 170, 5), + shapely.box(175, -5, 185, 5), + shapely.box(-170, -5, -160, 5), + ], + dtype=object, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, + tgt_polys, + source_dim="src", + target_dim="tgt", + periodic=True, + ) + out = rgr.regrid(xr.DataArray([7.0], dims=("src",))) + + assert np.isnan(out.values[0]) + assert out.values[1] == pytest.approx(7.0) + assert np.isnan(out.values[2]) + + +def test_from_polygons_mass_conservation(): + """Sum of intersected mass should match the direct A·s calculation to + machine precision for any source field.""" + + src_polys = _box_polygons() + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 36, endpoint=False) + 5, + np.linspace(-90, 90, 18, endpoint=False) + 5, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + rng = np.random.default_rng(3) + s = rng.normal(size=src_polys.size) + da = xr.DataArray(s, dims=("face",)) + out = rgr.regrid(da).values + # Direct mass = sum_i s_i * source_coverage_i. Matches output if we + # multiply output by target-covered area. + tgt_covered = rgr.target_areas + valid = tgt_covered > 0 + direct = float((s * rgr.source_coverage_areas).sum()) + via_regrid = float((out[valid] * tgt_covered[valid]).sum()) + rel = abs(direct - via_regrid) / max(abs(direct), 1e-12) + assert rel < 1e-12, f"rel err {rel:.2e}" + + +def test_from_polygons_transpose_roundtrip(): + """Roundtrip mesh ↔ mesh of a constant field returns the constant.""" + src_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 12, endpoint=False) + 15, + np.linspace(-90, 90, 6, endpoint=False) + 15, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="src", target_dim="tgt" + ) + da = xr.DataArray(np.full(src_polys.size, 3.5), dims=("src",)) + out = rgr.regrid(da) + back = rgr.T.regrid(out) + # Inner source cells (fully covered by target cells they map to) must be 3.5. + # Edge cells may be NaN if target domain doesn't cover them. + finite = np.isfinite(back.values) + np.testing.assert_allclose(back.values[finite], 3.5, atol=1e-12) + + +def test_from_polygons_nan_propagation(): + """NaN source cells propagate through skipna=True correctly.""" + src_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 12, endpoint=False) + 15, + np.linspace(-90, 90, 6, endpoint=False) + 15, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="src", target_dim="tgt" + ) + vals = np.full(src_polys.size, 1.0) + vals[:5] = np.nan + da = xr.DataArray(vals, dims=("src",)) + out_keep = rgr.regrid(da, nan_threshold=1.0) + out_strict = rgr.regrid(da, nan_threshold=0.0) + # Strict should have at least as many NaNs. + strict_nans = int(np.isnan(out_strict.values).sum()) + keep_nans = int(np.isnan(out_keep.values).sum()) + assert strict_nans >= keep_nans + + +def test_from_polygons_target_outside_source_is_nan(): + """Target cells entirely outside the source domain must be NaN regardless + of skipna (regression test for a bug where the "skip mask matmul when no + NaNs" optimization returned zeros for uncovered cells).""" + + src = np.array([shapely.box(0, 0, 1, 1)], dtype=object) + tgt = polygons_from_coords( + np.linspace(100, 110, 4, endpoint=False) + 1.25, + np.linspace(100, 110, 4, endpoint=False) + 1.25, + ) + rgr = ConservativeRegridder.from_polygons(src, tgt, source_dim="src") + # Source has no NaNs → used to wrongly return 0 here. + da = xr.DataArray(np.array([5.0]), dims=("src",)) + out = rgr.regrid(da, skipna=True) + assert np.isnan(out.values).all() + out_no = rgr.regrid(da, skipna=False) + assert np.isnan(out_no.values).all() + + +def test_from_polygons_hole_is_nan(): + """A target cell fully inside a source-polygon hole should be NaN, not 0.""" + + ring_with_hole = shapely.Polygon( + [(0, 0), (10, 0), (10, 10), (0, 10)], + [[(3, 3), (7, 3), (7, 7), (3, 7)]], + ) + src = np.array([ring_with_hole], dtype=object) + tgt = polygons_from_coords( + np.linspace(0, 10, 5, endpoint=False) + 1, + np.linspace(0, 10, 5, endpoint=False) + 1, + ) + rgr = ConservativeRegridder.from_polygons(src, tgt, source_dim="src") + out = rgr.regrid(xr.DataArray([7.0], dims=("src",))) + # The cell centered at (5,5), edges [4,6]x[4,6], lies entirely in the hole. + assert int(np.isnan(out.values).sum()) >= 1 + + +def test_from_polygons_input_validation(): + # 2D polygons array rejected + p = shapely.box(0, 0, 1, 1) + arr_2d = np.array([[p, p], [p, p]], dtype=object) + with pytest.raises(ValueError, match="1D"): + ConservativeRegridder.from_polygons(arr_2d, np.array([p])) + + +# --- netCDF save / load ------------------------------------------------------- + + +def test_regrid_preserves_input_dtype(): + """Float32 in → float32 out; float64 in → float64 out. Sparse promotes + to float64 internally, so we rely on an explicit cast at the end of + ``_apply_core``.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + assert rgr.regrid(da.astype(np.float32)).dtype == np.float32 + assert rgr.regrid(da.astype(np.float64)).dtype == np.float64 + # Integer inputs promote (float32 can't hold int32 without precision loss). + assert np.issubdtype(rgr.regrid((da * 10).astype(np.int32)).dtype, np.floating) + + +def test_to_netcdf_roundtrip_structured(tmp_path): + """Save, reload, regrid → identical output to the original regridder.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + out_before = rgr.regrid(da).values + + path = tmp_path / "regridder.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + + np.testing.assert_array_equal(rgr2.regrid(da).values, out_before) + assert rgr2.x_coord == "x" + assert rgr2.y_coord == "y" + assert rgr2.manifold == "planar" + assert rgr2._src_dims == rgr._src_dims + assert rgr2._dst_dims == rgr._dst_dims + + +def test_to_netcdf_preserves_manifold(tmp_path): + lat_s = np.linspace(-90, 90, 30, endpoint=False) + 3 + lon_s = np.linspace(-180, 180, 60, endpoint=False) + 3 + lat_t = np.linspace(-90, 90, 15, endpoint=False) + 6 + lon_t = np.linspace(-180, 180, 30, endpoint=False) + 6 + da = xr.DataArray( + np.cos(np.deg2rad(lat_s))[:, None] ** 2 * np.ones(lon_s.size)[None, :], + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + rgr = ConservativeRegridder( + da, target, x_coord="longitude", y_coord="latitude", manifold="cea" + ) + before = rgr.regrid(da).values + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + assert rgr2.manifold == "cea" + np.testing.assert_allclose(rgr2.regrid(da).values, before, atol=1e-15) + + +def test_to_netcdf_transpose_works_after_reload(tmp_path): + """A reloaded regridder can still take .T for backward regridding.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + + fwd = rgr.regrid(da) + back_original = rgr.T.regrid(fwd).values + back_reloaded = rgr2.T.regrid(fwd).values + np.testing.assert_array_equal(back_original, back_reloaded) + + +def test_to_netcdf_unstructured_roundtrip(tmp_path): + """from_polygons regridder roundtrips too (single spatial dim each side).""" + rng = np.random.default_rng(1) + cx = rng.uniform(-170, 170, 30) + cy = rng.uniform(-80, 80, 30) + src_polys = shapely.box(cx - 5, cy - 5, cx + 5, cy + 5) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + da = xr.DataArray(rng.normal(size=30), dims=("face",)) + out_before = rgr.regrid(da).values + + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + np.testing.assert_array_equal(rgr2.regrid(da).values, out_before) + + +def test_to_netcdf_metadata_fields(tmp_path): + """Metadata captures grid ranges, version, created timestamp.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + + with xr.open_dataset(path) as ds: + attrs = dict(ds.attrs) + + assert attrs["x_coord"] == "x" + assert attrs["y_coord"] == "y" + assert str(attrs["manifold"]) == "planar" + assert tuple(int(size) for size in attrs["src_shape"]) == rgr._src_shape + assert tuple(int(size) for size in attrs["dst_shape"]) == rgr._dst_shape + # Grid ranges captured when the coord is present in source/target. + assert "source_x_range" in attrs + assert "target_x_range" in attrs + assert attrs["source_x_range"][0] <= attrs["source_x_range"][1] + assert attrs["created"] + assert int(attrs["schema_version"]) == 2 + + +def test_from_netcdf_rejects_unknown_schema(tmp_path): + """Loading a file written with a future schema version raises cleanly.""" + h5py = pytest.importorskip("h5py") + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + + # Bump the on-disk schema_version so the loader should reject it. + with h5py.File(path, "a") as f: + f.attrs["schema_version"] = 999 + + with pytest.raises(ValueError, match="schema version"): + ConservativeRegridder.from_netcdf(path) + + +def test_regridder_transpose_curvilinear(): + """Transpose works when the target is a curvilinear grid with different + dim names from the source.""" + da = _rect_da(ny=40, nx=80) + ny_t, nx_t = 20, 30 + xi, yi = np.meshgrid( + np.linspace(-120, 120, nx_t), + np.linspace(-60, 60, ny_t), + indexing="xy", + ) + th = np.deg2rad(15) + x2 = xi * np.cos(th) - yi * np.sin(th) + y2 = xi * np.sin(th) + yi * np.cos(th) + target = xr.Dataset(coords={"x": (("ny", "nx"), x2), "y": (("ny", "nx"), y2)}) + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + fwd = regridder.regrid(da) + assert fwd.dims == ("time", "ny", "nx") + # Going backward we should land back on (time, y, x) + back = regridder.T.regrid(fwd) + assert "y" in back.dims and "x" in back.dims + assert back.sizes["y"] == da.sizes["y"] + assert back.sizes["x"] == da.sizes["x"] + + +# --- s2 manifold (optional, requires spherely) -------------------------------- + +spherely = pytest.importorskip("spherely", reason="s2 manifold tests require spherely") + + +def _latlon_da(ny=24, nx=36, lat_max=87.5, lon_max=175, fill=None, seed=0): + lat = np.linspace(-lat_max, lat_max, ny) + lon = np.linspace(-lon_max, lon_max, nx) + if fill is None: + vals = np.random.default_rng(seed).normal(size=(ny, nx)) + else: + vals = np.full((ny, nx), fill) + return xr.DataArray( + vals, + dims=("latitude", "longitude"), + coords={"latitude": lat, "longitude": lon}, + ) + + +def _latlon_target(ny=9, nx=18, lat_max=80, lon_max=170): + return xr.Dataset( + coords={ + "latitude": np.linspace(-lat_max, lat_max, ny), + "longitude": np.linspace(-lon_max, lon_max, nx), + } + ) + + +def _s2_regridder(da, target): + return ConservativeRegridder( + da, target, x_coord="longitude", y_coord="latitude", manifold="s2" + ) + + +def test_s2_manifold_conserves_mass(): + """On s2 the raw area matrix rows sum to the target-cell steradians, so + `out · a_dst` equals `A · s` to machine precision for any source field.""" + da = _latlon_da(ny=36, nx=48) + target = _latlon_target() + rgr = _s2_regridder(da, target) + out = rgr.regrid(da).values + areas = rgr._areas + src_covered = np.ravel(areas.sum(axis=0).todense()) + dst_covered = np.ravel(areas.sum(axis=1).todense()) + + direct_mass = float((da.values.ravel() * src_covered).sum()) + valid = np.isfinite(out).ravel() + out_mass = float((out.ravel()[valid] * dst_covered[valid]).sum()) + rel = abs(direct_mass - out_mass) / max(abs(direct_mass), 1e-12) + assert rel < 1e-12, f"rel err {rel:.2e}" + + +def test_s2_preserves_constant_field(): + """Roundtrip of a constant field through the s2 path reproduces the + constant to machine precision.""" + da = _latlon_da(fill=7.3) + rgr = _s2_regridder(da, _latlon_target()) + out = rgr.regrid(da).values + finite = np.isfinite(out) + np.testing.assert_allclose(out[finite], 7.3, atol=1e-12) + + +def test_s2_netcdf_roundtrip(tmp_path): + """A saved s2 regridder reloads back to an s2 regridder and produces + identical output.""" + da = _latlon_da(seed=1) + rgr = _s2_regridder(da, _latlon_target()) + before = rgr.regrid(da).values + path = tmp_path / "s2.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + assert rgr2.manifold == "s2" + np.testing.assert_array_equal(rgr2.regrid(da).values, before) + + +def test_s2_handles_pole_touching_cells(): + """A global grid whose outer rows touch ±90° exercises the `oriented=True` + branch of `spherely.create_polygon` — without an explicit orientation, s2 + would silently pick the complementary (hemisphere-sized) interpretation + of a near-pole cell.""" + lat = np.linspace(-89.5, 89.5, 24) # outer edges land exactly on ±90° + lon = np.linspace(-175, 175, 36) + vals = (np.cos(np.deg2rad(lat)) ** 2)[:, None] * np.ones(lon.size)[None, :] + da = xr.DataArray( + vals, + dims=("latitude", "longitude"), + coords={"latitude": lat, "longitude": lon}, + ) + rgr = _s2_regridder(da, _latlon_target(ny=12, nx=24, lat_max=85)) + out = rgr.regrid(da).values + # Output values should stay in [0, 1] (the source range), not hemispheres. + finite = np.isfinite(out) + assert np.all(out[finite] >= -1e-12) + assert np.all(out[finite] <= 1.0 + 1e-12) + + +def test_invalid_manifold_raises(): + da = _rect_da() + target = _rect_target() + with pytest.raises(ValueError, match="manifold must be one of"): + ConservativeRegridder(da, target, x_coord="x", y_coord="y", manifold="bogus") diff --git a/tests/test_conservative_polygon.py b/tests/test_conservative_polygon.py new file mode 100644 index 0000000..a5e5e5c --- /dev/null +++ b/tests/test_conservative_polygon.py @@ -0,0 +1,576 @@ +"""Tests for the polygon-intersection conservative regridder.""" +import numpy as np +import pytest +import xarray as xr + +import xarray_regrid # noqa: F401 (registers the accessor) +from xarray_regrid import ConservativeRegridder, RegridderMetadata, polygons_from_coords + +shapely = pytest.importorskip("shapely") + + +def _rect_da(ny=60, nx=120, nt=2, seed=0): + x = np.linspace(-180, 180, nx, endpoint=False) + 180 / nx + y = np.linspace(-90, 90, ny, endpoint=False) + 90 / ny + rng = np.random.default_rng(seed) + return xr.DataArray( + rng.normal(size=(nt, ny, nx)).astype(np.float64), + dims=("time", "y", "x"), + coords={"time": np.arange(nt), "y": y, "x": x}, + name="var", + ) + + +def _rect_target(ny=24, nx=47): + x = np.linspace(-180, 180, nx, endpoint=False) + 180 / nx + y = np.linspace(-90, 90, ny, endpoint=False) + 90 / ny + return xr.Dataset(coords={"y": y, "x": x}) + + +def test_polygon_matches_factored_planar(): + """On a rectilinear grid with no spherical correction, the polygon path + should reproduce the axis-factored path to machine precision.""" + da = _rect_da() + target = _rect_target() + ref = da.regrid.conservative(target, latitude_coord=None) + got = da.regrid.conservative_polygon(target, x_coord="x", y_coord="y") + got = got.transpose(*ref.dims) + np.testing.assert_allclose(got.values, ref.values, atol=1e-12) + + +def test_polygon_dask_time_chunks(): + da = _rect_da(nt=4).chunk({"time": 2}) + target = _rect_target() + got = da.regrid.conservative_polygon(target, x_coord="x", y_coord="y") + assert got.chunks is not None + got = got.compute() + ref = _rect_da(nt=4).regrid.conservative(target, latitude_coord=None) + np.testing.assert_allclose(got.transpose(*ref.dims).values, ref.values, atol=1e-12) + + +def test_polygon_rechunks_spatial(): + """Spatially-chunked input should be accepted (rechunked internally).""" + da = _rect_da().chunk({"time": 1, "y": 30, "x": 40}) + target = _rect_target() + out = da.regrid.conservative_polygon(target, x_coord="x", y_coord="y") + out.compute() + + +def test_polygon_nan_threshold(): + """Stricter nan_threshold produces more NaN output cells when partial + overlaps exist.""" + da = _rect_da() + da.values[:, 21:29, :] = np.nan + target = _rect_target(ny=23, nx=47) + out1 = da.regrid.conservative_polygon( + target, x_coord="x", y_coord="y", skipna=True, nan_threshold=1.0 + ) + out0 = da.regrid.conservative_polygon( + target, x_coord="x", y_coord="y", skipna=True, nan_threshold=0.0 + ) + assert int(np.isnan(out0.values).sum()) > int(np.isnan(out1.values).sum()) + + +def test_polygon_curvilinear_target(): + """Curvilinear target (2D lat/lon corners) returns finite values.""" + da = _rect_da() + ny_t, nx_t = 20, 30 + xi, yi = np.meshgrid( + np.linspace(-120, 120, nx_t), + np.linspace(-60, 60, ny_t), + indexing="xy", + ) + th = np.deg2rad(30) + x2d = xi * np.cos(th) - yi * np.sin(th) + y2d = xi * np.sin(th) + yi * np.cos(th) + target = xr.Dataset( + coords={"x": (("ny", "nx"), x2d), "y": (("ny", "nx"), y2d)} + ) + out = da.regrid.conservative_polygon(target, x_coord="x", y_coord="y") + assert out.shape == (2, 20, 30) + assert np.isfinite(out.values).mean() > 0.9 + + +def test_polygon_nan_threshold_invalid(): + da = _rect_da() + with pytest.raises(ValueError): + da.regrid.conservative_polygon( + _rect_target(), x_coord="x", y_coord="y", nan_threshold=1.5 + ) + + +def test_polygon_dataset_input(): + """A Dataset input with multiple variables should regrid all of them.""" + da = _rect_da() + ds = xr.Dataset({"a": da, "b": da * 2.0}) + target = _rect_target() + out = ds.regrid.conservative_polygon(target, x_coord="x", y_coord="y") + assert set(out.data_vars) == {"a", "b"} + np.testing.assert_allclose( + out["b"].transpose(*out["a"].dims).values, + out["a"].transpose(*out["a"].dims).values * 2.0, + atol=1e-12, + ) + + +# --- ConservativeRegridder (reusable) ------------------------------------------ + + +def test_regridder_reusable_matches_oneshot(): + """Reusing a single ConservativeRegridder on multiple fields matches the + one-shot `polygon_conservative_regrid` call.""" + da1 = _rect_da(seed=1) + da2 = _rect_da(seed=2) + target = _rect_target() + regridder = ConservativeRegridder(da1, target, x_coord="x", y_coord="y") + ref1 = da1.regrid.conservative_polygon(target, x_coord="x", y_coord="y") + ref2 = da2.regrid.conservative_polygon(target, x_coord="x", y_coord="y") + out1 = regridder.regrid(da1) + out2 = regridder(da2) # __call__ alias + np.testing.assert_allclose(out1.values, ref1.values, atol=1e-12) + np.testing.assert_allclose(out2.values, ref2.values, atol=1e-12) + + +def test_regridder_weight_cache(): + """Forward weight matrix is built lazily and then reused across calls.""" + da = _rect_da() + target = _rect_target() + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + assert regridder._fwd_weights is None + regridder.regrid(da) + w1 = regridder._fwd_weights + assert w1 is not None + regridder.regrid(da) + assert regridder._fwd_weights is w1 # same object, not rebuilt + + +def test_regridder_transpose_roundtrip_rectilinear_aligned(): + """If target cells are aligned unions of source cells, forward followed by + backward reproduces a constant field exactly.""" + # 120 source cells along each axis, target is a 4x coarsening (exact union + # of source cells). A constant source field survives the roundtrip to + # itself because every source cell is fully covered. + ns = 120 + nt = 30 # exactly ns / 4 + x_s = np.linspace(-180, 180, ns, endpoint=False) + 180 / ns + y_s = np.linspace(-90, 90, ns, endpoint=False) + 90 / ns + x_t = np.linspace(-180, 180, nt, endpoint=False) + 180 / nt + y_t = np.linspace(-90, 90, nt, endpoint=False) + 90 / nt + da = xr.DataArray( + np.full((ns, ns), 3.5, dtype=np.float64), + dims=("y", "x"), + coords={"y": y_s, "x": x_s}, + ) + target = xr.Dataset(coords={"y": y_t, "x": x_t}) + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + coarse = regridder.regrid(da) + back = regridder.T.regrid(coarse) + # Every value should match the original constant. + np.testing.assert_allclose(back.values, 3.5, atol=1e-12) + + +def test_regridder_T_preserves_weights(): + """regridder.T.T should share the raw area matrix with the original.""" + da = _rect_da() + target = _rect_target() + r = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + rr = r.T.T + # Same shape, same coords, same data. + assert r._areas.shape == rr._areas.shape + if hasattr(r._areas, "data"): + np.testing.assert_array_equal(r._areas.data, rr._areas.data) + + +def test_regridder_shape_mismatch_raises(): + """Applying the regridder to data whose spatial shape differs from the + source it was built for should raise.""" + da = _rect_da() + target = _rect_target() + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + smaller = _rect_da(ny=30, nx=60) + with pytest.raises(ValueError, match="spatial shape"): + regridder.regrid(smaller) + + +def test_spherical_mode_matches_factored(): + """Polygon path with ``spherical=True`` should match the axis-factored + sin-weighted conservative path to within a tight tolerance on lat/lon + grids (both are analytically equivalent for cylindrical equal-area).""" + lon_s = np.linspace(-180, 180, 180, endpoint=False) + 1.0 + lat_s = np.linspace(-90, 90, 90, endpoint=False) + 1.0 + lon_t = np.linspace(-180, 180, 60, endpoint=False) + 3.0 + lat_t = np.linspace(-90, 90, 30, endpoint=False) + 3.0 + vals = np.cos(np.deg2rad(lat_s))[:, None] ** 2 * np.sin(np.deg2rad(lon_s))[None, :] + da = xr.DataArray( + vals, + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + factored = da.regrid.conservative(target, latitude_coord="latitude") + polygon = da.regrid.conservative_polygon( + target, x_coord="longitude", y_coord="latitude", spherical=True + ) + # Both methods should agree to the grid's own quadrature accuracy. Near the + # poles the factored path's median-dlat approximation introduces a small + # discrepancy; 1e-3 absolute is well below the planar raw error floor + # (~1e-2 at this resolution, tested separately). + np.testing.assert_allclose( + polygon.transpose(*factored.dims).values, + factored.values, + atol=1e-3, + ) + + +def test_spherical_conserves_integral(): + """Mass conservation check on the sphere. For cos²(lat), true integral is + 8π/3; the regridder on a 2°→6° grid should keep the spherical-area-weighted + sum within the grid quadrature floor when spherical=True, and miss it by + ~17× more when spherical=False.""" + lon_s = np.linspace(-180, 180, 180, endpoint=False) + 1.0 + lat_s = np.linspace(-90, 90, 90, endpoint=False) + 1.0 + lon_t = np.linspace(-180, 180, 60, endpoint=False) + 3.0 + lat_t = np.linspace(-90, 90, 30, endpoint=False) + 3.0 + da = xr.DataArray( + np.cos(np.deg2rad(lat_s))[:, None] ** 2 * np.ones(lon_s.size)[None, :], + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + out_sph = da.regrid.conservative_polygon( + target, x_coord="longitude", y_coord="latitude", spherical=True + ) + out_raw = da.regrid.conservative_polygon( + target, x_coord="longitude", y_coord="latitude", spherical=False + ) + + # True target spherical cell areas + dlon = np.deg2rad(np.gradient(lat_t.astype(float) * 0 + np.mean(np.diff(lon_t)))) + dlon_arr = np.full(lon_t.size, np.deg2rad(np.mean(np.diff(lon_t)))) + lat_r = np.deg2rad(lat_t) + dlat_r = np.gradient(lat_r) + a_tgt = (np.sin(lat_r + dlat_r / 2) - np.sin(lat_r - dlat_r / 2))[:, None] * dlon_arr[None, :] + + true_val = 8 * np.pi / 3 + err_sph = abs(float((out_sph.transpose("latitude", "longitude").values * a_tgt).sum()) - true_val) + err_raw = abs(float((out_raw.transpose("latitude", "longitude").values * a_tgt).sum()) - true_val) + # Spherical should be at least 10× more accurate than raw planar on this grid. + assert err_sph < 0.1 * err_raw, f"err_sph={err_sph:.2e} err_raw={err_raw:.2e}" + + +# --- from_polygons (unstructured mesh) ---------------------------------------- + + +def _box_polygons(): + import shapely + rng = np.random.default_rng(1) + n = 50 + cx = rng.uniform(-170, 170, n) + cy = rng.uniform(-80, 80, n) + return shapely.box(cx - 5, cy - 5, cx + 5, cy + 5) + + +def test_from_polygons_basic(): + src_polys = _box_polygons() + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 30, endpoint=False) + 6, + np.linspace(-90, 90, 15, endpoint=False) + 6, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + da = xr.DataArray( + np.arange(src_polys.size, dtype=np.float64), + dims=("face",), + ) + out = rgr.regrid(da) + assert out.dims == ("cell",) + assert out.sizes["cell"] == tgt_polys.size + + +def test_from_polygons_mass_conservation(): + """Sum of intersected mass should match the direct A·s calculation to + machine precision for any source field.""" + import shapely + + src_polys = _box_polygons() + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 36, endpoint=False) + 5, + np.linspace(-90, 90, 18, endpoint=False) + 5, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + rng = np.random.default_rng(3) + s = rng.normal(size=src_polys.size) + da = xr.DataArray(s, dims=("face",)) + out = rgr.regrid(da).values + # Direct mass = Σ_i s_i × (Σ_j A_ij). Matches output if we multiply output + # by target-covered area = Σ_i A_ij. + A = rgr._areas # (n_tgt, n_src) + tgt_covered = A.sum(axis=1).todense() + valid = tgt_covered > 0 + direct = float((s * A.sum(axis=0).todense()).sum()) + via_regrid = float((out[valid] * tgt_covered[valid]).sum()) + rel = abs(direct - via_regrid) / max(abs(direct), 1e-12) + assert rel < 1e-12, f"rel err {rel:.2e}" + + +def test_from_polygons_transpose_roundtrip(): + """Roundtrip mesh ↔ mesh of a constant field returns the constant.""" + import shapely + src_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 12, endpoint=False) + 15, + np.linspace(-90, 90, 6, endpoint=False) + 15, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="src", target_dim="tgt" + ) + da = xr.DataArray(np.full(src_polys.size, 3.5), dims=("src",)) + out = rgr.regrid(da) + back = rgr.T.regrid(out) + # Inner source cells (fully covered by target cells they map to) must be 3.5. + # Edge cells may be NaN if target domain doesn't cover them. + finite = np.isfinite(back.values) + np.testing.assert_allclose(back.values[finite], 3.5, atol=1e-12) + + +def test_from_polygons_nan_propagation(): + """NaN source cells propagate through skipna=True correctly.""" + import shapely + src_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 12, endpoint=False) + 15, + np.linspace(-90, 90, 6, endpoint=False) + 15, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="src", target_dim="tgt" + ) + vals = np.full(src_polys.size, 1.0) + vals[:5] = np.nan + da = xr.DataArray(vals, dims=("src",)) + out_keep = rgr.regrid(da, nan_threshold=1.0) + out_strict = rgr.regrid(da, nan_threshold=0.0) + # Strict should have at least as many NaNs. + assert int(np.isnan(out_strict.values).sum()) >= int(np.isnan(out_keep.values).sum()) + + +def test_from_polygons_target_outside_source_is_nan(): + """Target cells entirely outside the source domain must be NaN regardless + of skipna (regression test for a bug where the "skip mask matmul when no + NaNs" optimization returned zeros for uncovered cells).""" + import shapely + + src = np.array([shapely.box(0, 0, 1, 1)], dtype=object) + tgt = polygons_from_coords( + np.linspace(100, 110, 4, endpoint=False) + 1.25, + np.linspace(100, 110, 4, endpoint=False) + 1.25, + ) + rgr = ConservativeRegridder.from_polygons(src, tgt, source_dim="src") + # Source has no NaNs → used to wrongly return 0 here. + da = xr.DataArray(np.array([5.0]), dims=("src",)) + out = rgr.regrid(da, skipna=True) + assert np.isnan(out.values).all() + out_no = rgr.regrid(da, skipna=False) + assert np.isnan(out_no.values).all() + + +def test_from_polygons_hole_is_nan(): + """A target cell fully inside a source-polygon hole should be NaN, not 0.""" + import shapely + + ring_with_hole = shapely.Polygon( + [(0, 0), (10, 0), (10, 10), (0, 10)], + [[(3, 3), (7, 3), (7, 7), (3, 7)]], + ) + src = np.array([ring_with_hole], dtype=object) + tgt = polygons_from_coords( + np.linspace(0, 10, 5, endpoint=False) + 1, + np.linspace(0, 10, 5, endpoint=False) + 1, + ) + rgr = ConservativeRegridder.from_polygons(src, tgt, source_dim="src") + out = rgr.regrid(xr.DataArray([7.0], dims=("src",))) + # The cell centered at (5,5), edges [4,6]x[4,6], lies entirely in the hole. + assert int(np.isnan(out.values).sum()) >= 1 + + +def test_from_polygons_input_validation(): + # 2D polygons array rejected + import shapely + p = shapely.box(0, 0, 1, 1) + arr_2d = np.array([[p, p], [p, p]], dtype=object) + with pytest.raises(ValueError, match="1D"): + ConservativeRegridder.from_polygons(arr_2d, np.array([p])) + + +# --- netCDF save / load ------------------------------------------------------- + + +def test_regrid_preserves_input_dtype(): + """Float32 in → float32 out; float64 in → float64 out. Sparse promotes + to float64 internally, so we rely on an explicit cast at the end of + ``_apply_core``.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + assert rgr.regrid(da.astype(np.float32)).dtype == np.float32 + assert rgr.regrid(da.astype(np.float64)).dtype == np.float64 + # Integer inputs promote (float32 can't hold int32 without precision loss). + assert np.issubdtype(rgr.regrid((da * 10).astype(np.int32)).dtype, np.floating) + + +def test_to_netcdf_roundtrip_structured(tmp_path): + """Save, reload, regrid → identical output to the original regridder.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + out_before = rgr.regrid(da).values + + path = tmp_path / "regridder.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + + np.testing.assert_array_equal(rgr2.regrid(da).values, out_before) + assert rgr2.x_coord == "x" + assert rgr2.y_coord == "y" + assert rgr2.spherical is False + assert rgr2._src_dims == rgr._src_dims + assert rgr2._dst_dims == rgr._dst_dims + + +def test_to_netcdf_preserves_spherical_flag(tmp_path): + lat_s = np.linspace(-90, 90, 30, endpoint=False) + 3 + lon_s = np.linspace(-180, 180, 60, endpoint=False) + 3 + lat_t = np.linspace(-90, 90, 15, endpoint=False) + 6 + lon_t = np.linspace(-180, 180, 30, endpoint=False) + 6 + da = xr.DataArray( + np.cos(np.deg2rad(lat_s))[:, None] ** 2 * np.ones(lon_s.size)[None, :], + dims=("latitude", "longitude"), + coords={"latitude": lat_s, "longitude": lon_s}, + ) + target = xr.Dataset(coords={"latitude": lat_t, "longitude": lon_t}) + + rgr = ConservativeRegridder( + da, target, x_coord="longitude", y_coord="latitude", spherical=True + ) + before = rgr.regrid(da).values + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + assert rgr2.spherical is True + np.testing.assert_allclose(rgr2.regrid(da).values, before, atol=1e-15) + + +def test_to_netcdf_transpose_works_after_reload(tmp_path): + """A reloaded regridder can still take .T for backward regridding.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + + fwd = rgr.regrid(da) + back_original = rgr.T.regrid(fwd).values + back_reloaded = rgr2.T.regrid(fwd).values + np.testing.assert_array_equal(back_original, back_reloaded) + + +def test_to_netcdf_unstructured_roundtrip(tmp_path): + """from_polygons regridder roundtrips too (single spatial dim each side).""" + import shapely + rng = np.random.default_rng(1) + cx = rng.uniform(-170, 170, 30) + cy = rng.uniform(-80, 80, 30) + src_polys = shapely.box(cx - 5, cy - 5, cx + 5, cy + 5) + tgt_polys = polygons_from_coords( + np.linspace(-180, 180, 24, endpoint=False) + 7.5, + np.linspace(-90, 90, 12, endpoint=False) + 7.5, + ) + rgr = ConservativeRegridder.from_polygons( + src_polys, tgt_polys, source_dim="face", target_dim="cell" + ) + da = xr.DataArray(rng.normal(size=30), dims=("face",)) + out_before = rgr.regrid(da).values + + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + rgr2 = ConservativeRegridder.from_netcdf(path) + np.testing.assert_array_equal(rgr2.regrid(da).values, out_before) + + +def test_to_netcdf_metadata_fields(tmp_path): + """Metadata captures grid ranges, version, created timestamp.""" + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + + with xr.open_dataset(path) as ds: + attrs = dict(ds.attrs) + meta = RegridderMetadata.from_attrs(attrs) + + assert meta.x_coord == "x" + assert meta.y_coord == "y" + assert meta.spherical is False + assert meta.src_shape == rgr._src_shape + assert meta.dst_shape == rgr._dst_shape + # Grid ranges captured when the coord is present in source/target. + assert meta.source_x_range is not None + assert meta.target_x_range is not None + assert meta.source_x_range[0] <= meta.source_x_range[1] + assert meta.created # non-empty ISO timestamp + assert meta.schema_version == 2 + + +def test_from_netcdf_rejects_unknown_schema(tmp_path): + """Loading a file written with a future schema version raises cleanly.""" + import h5py + da = _rect_da() + target = _rect_target() + rgr = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + path = tmp_path / "r.nc" + rgr.to_netcdf(path) + + # Bump the on-disk schema_version so the loader should reject it. + with h5py.File(path, "a") as f: + f.attrs["schema_version"] = 999 + + with pytest.raises(ValueError, match="schema version"): + ConservativeRegridder.from_netcdf(path) + + +def test_regridder_transpose_curvilinear(): + """Transpose works when the target is a curvilinear grid with different + dim names from the source.""" + da = _rect_da(ny=40, nx=80) + ny_t, nx_t = 20, 30 + xi, yi = np.meshgrid( + np.linspace(-120, 120, nx_t), + np.linspace(-60, 60, ny_t), + indexing="xy", + ) + th = np.deg2rad(15) + x2 = xi * np.cos(th) - yi * np.sin(th) + y2 = xi * np.sin(th) + yi * np.cos(th) + target = xr.Dataset( + coords={"x": (("ny", "nx"), x2), "y": (("ny", "nx"), y2)} + ) + regridder = ConservativeRegridder(da, target, x_coord="x", y_coord="y") + fwd = regridder.regrid(da) + assert fwd.dims == ("time", "ny", "nx") + # Going backward we should land back on (time, y, x) + back = regridder.T.regrid(fwd) + assert "y" in back.dims and "x" in back.dims + assert back.sizes["y"] == da.sizes["y"] + assert back.sizes["x"] == da.sizes["x"]