From 44d46149e7f12cf4f712332f70fad0a6f78a0299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 10 Sep 2024 23:46:02 +0200 Subject: [PATCH 01/18] merge --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index cc18cc91b..277af7847 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,8 @@ - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) +- restructure `ot.unbalanced` module (PR #658) +- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) From a93c60cbe7f317c77accd5dbb6850ec8cfad584c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 21 Apr 2025 22:14:33 +0200 Subject: [PATCH 02/18] first commit --- ot/__init__.py | 3 +- ot/solvers.py | 546 ++++++++++++++++++++++++++++++++++++++++++- ot/utils.py | 171 ++++++++++++++ test/test_solvers.py | 136 +++++++++++ 4 files changed, 854 insertions(+), 2 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index 5e21d6a76..dbddbd03f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -68,7 +68,7 @@ ) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample +from .solvers import solve, solve_gromov, solve_sample, bary_sample from .lowrank import lowrank_sinkhorn # utils functions @@ -116,6 +116,7 @@ "solve", "solve_gromov", "solve_sample", + "bary_sample", "smooth", "stochastic", "unbalanced", diff --git a/ot/solvers.py b/ot/solvers.py index a5bbf0e94..257e0d701 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -4,10 +4,11 @@ """ # Author: Remi Flamary +# Cédric Vincent-Cuaz # # License: MIT License -from .utils import OTResult, dist +from .utils import OTResult, BaryResult, dist from 
.lp import emd2, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced @@ -33,6 +34,7 @@ from .optim import cg import warnings +import numpy as np lst_method_lazy = [ @@ -1936,3 +1938,545 @@ def solve_sample( log=log, ) return res + + +def _bary_sample_bcd( + X_s, + X_init, + a_s, + b_init, + w_s, + metric, + inner_solver, + max_iter_bary, + tol_bary, + verbose, + log, + nx, +): + """Compute the barycenter using BCD. + + Parameters + ---------- + X_s : list of array-like, shape (n_samples_k, dim) + List of samples in each source distribution + X_init : array-like, shape (n_samples_b, dim), + Initialization of the barycenter samples. + a_s : list of array-like, shape (dim_k,) + List of samples weights in each source distribution + b_init : array-like, shape (n_samples_b,) + Initialization of the barycenter weights. + w_s : list of array-like, shape (N,) + Samples barycentric weights + metric : str + Metric to use for the cost matrix, by default "sqeuclidean" + inner_solver : callable + Function to solve the inner OT problem + max_iter_bary : int + Maximum number of iterations for the barycenter + tol_bary : float + Tolerance for the barycenter convergence + verbose : bool + Print information in the solver + log : bool + Log the loss during the iterations + nx: backend + Backend to use for the computation. 
Must match<< + Returns + ------- + TBD + """ + + X = X_init + b = b_init + inv_b = 1.0 / b + + prev_loss = np.inf + n_samples = len(X_s) + + if log: + log_ = {"loss": []} + else: + log_ = None + # Compute the barycenter using BCD + for it in range(max_iter_bary): + # Solve the inner OT problem for each source distribution + list_res = [inner_solver(X_s[k], X, a_s[k], b) for k in range(n_samples)] + + # Update the barycenter samples + if metric in ["sqeuclidean", "euclidean"]: + X_new = ( + sum([w_s[k] * list_res[k].plan.T @ X_s[k] for k in range(n_samples)]) + * inv_b[:, None] + ) + else: + raise NotImplementedError('Not implemented metric="{}"'.format(metric)) + + # compute loss + new_loss = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) + + if verbose: + if it % 1 == 0: + print(f"BCD iteration {it}: loss = {new_loss:.4f}") + + if log: + log_["loss"].append(new_loss) + # Check convergence + if abs(new_loss - prev_loss) / abs(prev_loss) < tol_bary: + print(f"BCD converged in {it} iterations") + break + + X = X_new + prev_loss = new_loss + + # compute value_linear + value_linear = sum([w_s[k] * list_res[k].value_linear for k in range(n_samples)]) + # update BaryResult + bary_res = BaryResult( + X=X_new, + b=b, + value=new_loss, + value_linear=value_linear, + log=log_, + list_res=list_res, + backend=nx, + ) + return bary_res + + +def bary_sample( + X_s, + n, + a_s=None, + w_s=None, + X_init=None, + b_init=None, + learn_X=True, + learn_b=False, + metric="sqeuclidean", + reg=None, + c=None, + reg_type="KL", + unbalanced=None, + unbalanced_type="KL", + lazy=False, + batch_size=None, + method=None, + n_threads=1, + max_iter_bary=1000, + max_iter=None, + rank=100, + scaling=0.95, + tol_bary=1e-5, + tol=None, + random_state=0, + verbose=False, +): + r"""Solve the discrete OT barycenter problem over source distributions using Block-Coordinate Descent. + + The function solves the following general OT barycenter problem + + .. 
math:: + \min_{\mathbf{X} \in \mathbb{R}^{n \times d}, \mathbf{b} \in \Sigma_n} \min_{\{ \mathbf{T}^{(k)} \}_k \in \R_+^{n_i \times n}} \quad \sum_k w_k \sum_{i,j} T^{(k)}_{i,j}M^{(k)}_{i,j} + \lambda_r R(\mathbf{T}^{(k)}) + + \lambda_u U(\mathbf{T^{(k)}}\mathbf{1},\mathbf{a}^{(k)}) + + \lambda_u U(\mathbf{T}^{(k)T}\mathbf{1},\mathbf{b}) + + where the cost matrices :math:`\mathbf{M}^{(k)}` for each input distribution :math:`(\mathbf{X}^{(k)}, \mathbf{b}^{(k)})` + is computed from the samples in the source and barycenter domains such that + :math:`M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)` where + :math:`d` is a metric (by default the squared Euclidean distance). + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : list of array-like, shape (n_samples_k, dim) + List of samples in each source distribution + n : int + number of samples in the barycenter domain + a_s : list of array-like, shape (dim_k,), optional + List of samples weights in each source distribution (default is uniform) + w_s : list of array-like, shape (N,), optional + Samples barycentric weights (default is uniform) + X_init : array-like, shape (n_samples_b, dim), optional + Initialization of the barycenter samples (default is gaussian random sampling). + Shape must match with required n. + b_init : array-like, shape (n_samples_b,), optional + Initialization of the barycenter weights (default is uniform). + Shape must match with required n. 
+ learn_X : bool, optional + Learn the barycenter samples (default is True) + learn_b : bool, optional + Learn the barycenter weights (default is False) + metric : str, optional + Metric to use for the cost matrix, by default "sqeuclidean" + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + c : array-like, shape (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float or indexable object of length 1 or 2 + Marginal relaxation term. + If it is a scalar or an indexable object of length 1, + then the same relaxation is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`unbalanced=float("inf")`. + For semi-relaxed case, use either + :math:`unbalanced=(float("inf"), scalar)` or + :math:`unbalanced=(scalar, float("inf"))`. + If unbalanced is an array, + it must have the same backend as input arrays `(a, b, M)`. + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by + default False + batch_size : int, optional + Batch size for lazy solver, by default None (default values in each + solvers) + method : str, optional + Method for solving the problem, this can be used to select the solver + for unbalanced problems (see :any:`ot.solve`), or to select a specific + large scale solver. + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter_bary : int, optional + Maximum number of iteration for the BCD solver, by default 1000. 
+ max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + rank : int, optional + Rank of the OT matrix for lazy solers (method='factored'), by default 100 + scaling : float, optional + Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 + tol_bary : float, optional + Tolerance for solution precision of barycenter problem, by default None (default value 1e-5) + tol : float, optional + Tolerance for solution precision of inner OT solver, by default None (default values in each solvers) + random_state : int, optional + Random seed for the initialization of the barycenter samples, by default 0. + Only used if `X_init` is None. + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res : BaryResult() + Result of the optimization problem. The information can be obtained as follows: + + OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method) + + See :any:`OTResult` for more information. + + Notes + ----- + + The following methods are available for solving the OT problems: + + - **Classical exact OT problem [1]** (default parameters) : + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b) + + # for uniform weights + res = ot.solve_sample(xa, xb) + + - **Entropic regularized OT [2]** (when ``reg!=None``): + + .. 
math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve_sample(xa, xb, a, b, reg=1.0) + # or for original Sinkhorn paper formulation [2] + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy') + + # lazy solver of memory complexity O(n) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) + # lazy OT plan + lazy_plan = res.lazy_plan + + # Use envelope theorem differentiation for memory saving + res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope') + res.value.backward() # only the value is differentiable + + Note that by default the Sinkhorn solver uses automatic differentiation to + compute the gradients of the values and plan. This can be changed with the + `grad` parameter. The `envelope` mode computes the gradients only + for the value and the other outputs are detached. This is useful for + memory saving when only the gradient of value is needed. + + We also have a very efficient solver with compiled CPU/CUDA code using + geomloss/PyKeOps that can be used with the following code: + + .. 
code-block:: python + + # automatic solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss') + + # force O(n) memory efficient solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online') + + # force pre-computed cost matrix + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized') + + # use multiscale solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale') + + # One can play with speed (small scaling factor) and precision (scaling close to 1) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5) + + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2') + + - **Unbalanced OT [41]** (when ``unbalanced!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + \text{with} \ M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0) + # quadratic unbalanced OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2') + # TV = partial OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV') + + + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): + + .. 
math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + \text{with} \ M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` for both + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0) + # quadratic unbalanced OT with KL regularization + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2') + # both quadratic + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', + unbalanced=1.0, unbalanced_type='L2') + + + - **Factored OT [2]** (when ``method='factored'``): + + This method solve the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where $\mu$ is a uniform weighted empirical distribution of :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated + to the samples in the source and target domains, and :math:`W_2` is the + Wasserstein distance. This problem is solved using exact OT solvers for + `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides + two transport plans that can be used to recover a low rank OT plan between + the two distributions. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='factored', rank=10) + + # recover the lazy low rank plan + factored_solution_lazy = res.lazy_plan + + # recover the full low rank plan + factored_solution = factored_solution_lazy[:] + + - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): + + This method computes the Gaussian Bures-Wasserstein distance between two + Gaussian distributions estimated from the empirical distributions + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + + where : + + .. 
math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + The covariances and means are estimated from the data. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='gaussian') + + # recover the squared Gaussian Bures-Wasserstein distance + BW_dist = res.value + + - **Wasserstein 1d [1]** (when ``method='1D'``): + + This method computes the Wasserstein distance between two 1d distributions + estimated from the empirical distributions. For multivariate data the + distances are computed independently for each dimension. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='1D') + + # recover the squared Wasserstein distances + W_dists = res.value + + + .. _references-bary-sample: + References + ---------- + + """ + if learn_b: + raise NotImplementedError("Barycenter weights learning not implemented yet") + + if method is not None and method.lower() in lst_method_lazy: + raise NotImplementedError("Barycenter with Lazy tensors not implemented yet") + + n_samples = len(X_s) + + if ( + not lazy + ): # default non lazy solver calls ot.solve_sample within _bary_sample_bcd + # Detect backend + nx = get_backend(*X_s, X_init, b_init, w_s) + + # check sample weights + if a_s is None: + a_s = [ + nx.ones((X_s[k].shape[0],), type_as=X_s[k]) / X_s[k].shape[0] + for k in range(n_samples) + ] + + # check samples barycentric weights + if w_s is None: + w_s = nx.ones(n_samples, type_as=X_s[0]) / n_samples + + # check X_init + if X_init is None: + if (not learn_X) and learn_b: + raise ValueError( + "X_init must be provided if learn_X=False and learn_b=True" + ) + else: + rng = np.random.RandomState(random_state) + mean_ = nx.concatenate( + [nx.mean(X_s[k], axis=0) for k in range(n_samples)], + axis=0, + ) + mean_ = nx.mean(mean_, axis=0) + std_ = nx.concatenate( + [nx.std(X_s[k], axis=0) for k in range(n_samples)], + axis=0, + ) + std_ = nx.mean(std_, axis=0) + X_init = 
rng.normal( + loc=mean_, + scale=std_, + size=(n, X_s[0].shape[1]), + ) + X_init = nx.from_numpy(X_init, type_as=X_s[0]) + else: + if (X_init.shape[0] != n) or (X_init.shape[1] != X_s[0].shape[1]): + raise ValueError("X_init must have shape (n, dim)") + + # check b_init + if b_init is None: + b_init = nx.ones((n,), type_as=X_s[0]) / n + + def inner_solver(X_a, X, a, b): + return solve_sample( + X_a=X_a, + X_b=X, + a=a, + b=b, + metric=metric, + reg=reg, + c=c, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + method=method, + n_threads=n_threads, + max_iter=max_iter, + tol=tol, + verbose=False, + ) + + res = _bary_sample_bcd( + X_s, + X_init, + a_s, + b_init, + w_s, + metric, + inner_solver, + max_iter_bary, + tol_bary, + verbose, + True, # log set to True by default + nx, + ) + + return res + + else: + raise (NotImplementedError("Barycenter solver with lazy=True not implemented")) diff --git a/ot/utils.py b/ot/utils.py index 1f24fa33f..6f8a5682f 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1310,6 +1310,177 @@ def citation(self): """ +class BaryResult: + """Base class for OT barycenter results. + + Parameters + ---------- + X : array-like, shape (`n`, `d`) + Barycenter features. + C: array-like, shape (`n`, `n`) + Barycenter structure for Gromov Wasserstein solutions. + b : array-like, shape (`n`,) + Barycenter weights. + value : float, array-like + Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions. + value_linear : float, array-like + The linear part of the transport cost, i.e. the product between the + transport plan and the cost. + value_quad : float, array-like + The quadratic part of the transport cost for Gromov-Wasserstein + solutions. + log : dict + Dictionary containing potential information about the solver. + list_res: list of OTResult + List of results for the individual OT matching. 
+ + Attributes + ---------- + + X : array-like, shape (`n`, `d`) + Barycenter features. + C: array-like, shape (`n`, `n`) + Barycenter structure for Gromov Wasserstein solutions. + b : array-like, shape (`n`,) + Barycenter weights. + value : float, array-like + Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions. + value_linear : float, array-like + The linear part of the transport cost, i.e. the product between the + transport plan and the cost. + value_quad : float, array-like + The quadratic part of the transport cost for Gromov-Wasserstein + solutions. + log : dict + Dictionary containing potential information about the solver. + list_res: list of OTResult + List of results for the individual OT matching. + backend : Backend + Backend used to compute the results. + """ + + def __init__( + self, + X=None, + C=None, + b=None, + value=None, + value_linear=None, + value_quad=None, + log=None, + list_res=None, + backend=None, + ): + self._X = X + self._C = C + self._b = b + self._value = value + self._value_linear = value_linear + self._value_quad = value_quad + self._log = log + self._list_res = list_res + self._backend = backend if backend is not None else NumpyBackend() + + def __repr__(self): + s = "BaryResult(" + if self._value is not None: + s += "value={},".format(self._value) + if self._value_linear is not None: + s += "value_linear={},".format(self._value_linear) + if self._X is not None: + s += "X={}(shape={}),".format(self._X.__class__.__name__, self._X.shape) + if self._C is not None: + s += "C={}(shape={}),".format(self._C.__class__.__name__, self._C.shape) + if self._b is not None: + s += "b={}(shape={}),".format(self._b.__class__.__name__, self._b.shape) + if s[-1] != "(": + s = s[:-1] + ")" + else: + s = s + ")" + return s + + # Barycerters -------------------------------- + + @property + def X(self): + """Barycenter features.""" + return self._X + + @property + def C(self): + 
"""Barycenter structure for Gromov Wasserstein solutions.""" + return self._C + + @property + def b(self): + """Barycenter weights.""" + return self._b + + # Loss values -------------------------------- + + @property + def value(self): + """Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions.""" + return self._value + + @property + def value_linear(self): + """The "minimal" transport cost, i.e. the product between the transport plan and the cost.""" + return self._value_linear + + @property + def value_quad(self): + """The quadratic part of the transport cost for Gromov-Wasserstein solutions.""" + return self._value_quad + + # List of OTResult objects ------------------------- + + @property + def list_res(self): + """List of results for the individual OT matching.""" + return self._list_res + + @property + def status(self): + """Optimization status of the solver.""" + return self._status + + @property + def log(self): + """Dictionary containing potential information about the solver.""" + return self._log + + # Miscellaneous -------------------------------- + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. + return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. 
Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ + + class LazyTensor(object): """A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. diff --git a/test/test_solvers.py b/test/test_solvers.py index a0c1d7c43..c691b9cc0 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -1,6 +1,7 @@ """Tests for ot solvers""" # Author: Remi Flamary +# Cédric Vincent-Cuaz # # License: MIT License @@ -703,3 +704,138 @@ def test_solve_sample_NotImplemented(nx, method_params): with pytest.raises(NotImplementedError): ot.solve_sample(xb, yb, ab, bb, **method_params) + + +def assert_allclose_bary_sol(sol1, sol2): + lst_attr = ["X", "b", "value", "value_linear", "log"] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + if getattr(sol1, attr) is not None and getattr(sol2, attr) is not None: + try: + var1 = getattr(sol1, attr) + var2 = getattr(sol2, attr) + if isinstance(var1, dict): # only contains lists + for key in var1.keys(): + np.allclose( + np.array(var1[key]), + np.array(var2[key]), + equal_nan=True, + ) + else: + np.allclose( 
+ nx1.to_numpy(getattr(sol1, attr)), + nx2.to_numpy(getattr(sol2, attr)), + equal_nan=True, + ) + except NotImplementedError: + pass + elif getattr(sol1, attr) is None and getattr(sol2, attr) is None: + return True + else: + return False + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.mark.parametrize( + "reg,reg_type,unbalanced,unbalanced_type", + itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type), +) +def test_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type): + # test bary_sample when is_Lazy = False + rng = np.random.RandomState(0) + + K = 3 # number of distributions + ns = rng.randint(10, 20, K) # number of samples within each distribution + n = 5 # number of samples in the barycenter + + X_s = [rng.randn(ns_i, 2) for ns_i in ns] + # X_init = np.reshape(1.0 * np.randn(n, 2), (n, 1)) + + a_s = [ot.utils.unif(X.shape[0]) for X in X_s] + b = ot.utils.unif(n) + + w_s = ot.utils.unif(K) + + try: + if reg_type == "tuple": + + def f(G): + return np.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + + # solve default None weights + sol0 = ot.bary_sample( + X_s, + n, + w_s=None, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + max_iter_bary=4, + tol_bary=1e-3, + verbose=True, + ) + + # solve provided uniform weights + sol = ot.bary_sample( + X_s, + n, + a_s=a_s, + b_init=b, + w_s=w_s, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + max_iter_bary=4, + tol_bary=1e-3, + verbose=True, + ) + + assert_allclose_bary_sol(sol0, sol) + + # solve in backend + X_sb = nx.from_numpy(*X_s) + a_sb = nx.from_numpy(*a_s) + w_sb, bb = nx.from_numpy(w_s, b) + + if isinstance(reg_type, tuple): + + def f(G): + return nx.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + + solb = 
ot.bary_sample( + X_sb, + n, + a_s=a_sb, + b_init=bb, + w_s=w_sb, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + max_iter_bary=4, + tol_bary=1e-3, + verbose=True, + ) + assert_allclose_bary_sol(sol, solb) + + except NotImplementedError: + pytest.skip("Not implemented") From 9e25e8080506512acf14b3c56623d228d751a05c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 23 Apr 2025 00:48:18 +0200 Subject: [PATCH 03/18] handle masses in unbalanced cases --- ot/solvers.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ot/solvers.py b/ot/solvers.py index 257e0d701..0c944e709 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1948,6 +1948,7 @@ def _bary_sample_bcd( w_s, metric, inner_solver, + update_masses, max_iter_bary, tol_bary, verbose, @@ -1972,6 +1973,8 @@ def _bary_sample_bcd( Metric to use for the cost matrix, by default "sqeuclidean" inner_solver : callable Function to solve the inner OT problem + update_masses : bool + Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used. 
max_iter_bary : int Maximum number of iterations for the barycenter tol_bary : float @@ -2003,6 +2006,10 @@ def _bary_sample_bcd( # Solve the inner OT problem for each source distribution list_res = [inner_solver(X_s[k], X, a_s[k], b) for k in range(n_samples)] + # Update the estimated barycenter weights in unbalanced cases + if update_masses: + b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) + inv_b = 1.0 / b # Update the barycenter samples if metric in ["sqeuclidean", "euclidean"]: X_new = ( @@ -2461,6 +2468,8 @@ def inner_solver(X_a, X, a, b): verbose=False, ) + # compute the barycenter using BCD + update_masses = unbalanced is not None res = _bary_sample_bcd( X_s, X_init, @@ -2469,6 +2478,7 @@ def inner_solver(X_a, X, a, b): w_s, metric, inner_solver, + update_masses, max_iter_bary, tol_bary, verbose, From 46c46385eb27c5624e5ca97aaef232a5c429e2ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 24 Apr 2025 18:20:05 +0200 Subject: [PATCH 04/18] update free support --- ot/solvers.py | 102 ++++++++++++++++++++++++++++++++++++++----- ot/utils.py | 6 +++ test/test_solvers.py | 21 ++++++--- test/test_utils.py | 27 +++++++++++- 4 files changed, 136 insertions(+), 20 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 0c944e709..3e1c35ab6 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1949,6 +1949,9 @@ def _bary_sample_bcd( metric, inner_solver, update_masses, + warmstart_plan, + warmstart_potentials, + stopping_criterion, max_iter_bary, tol_bary, verbose, @@ -1975,6 +1978,12 @@ def _bary_sample_bcd( Function to solve the inner OT problem update_masses : bool Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used. + warmstart_plan : bool + Use the previous plan as initialization for the inner solver. Set based on inner solver type in ot.bary_sample + warmstart_potentials : bool + Use the previous potentials as initialization for the inner solver. 
Set based on inner solver type in ot.bary_sample + stopping_criterion : str + Stopping criterion for the BCD algorithm. Can be "loss" or "bary". max_iter_bary : int Maximum number of iterations for the barycenter tol_bary : float @@ -1994,22 +2003,41 @@ def _bary_sample_bcd( b = b_init inv_b = 1.0 / b - prev_loss = np.inf + prev_criterion = np.inf n_samples = len(X_s) if log: - log_ = {"loss": []} + log_ = {"stopping_criterion": []} else: log_ = None + # Compute the barycenter using BCD for it in range(max_iter_bary): # Solve the inner OT problem for each source distribution - list_res = [inner_solver(X_s[k], X, a_s[k], b) for k in range(n_samples)] + if it == 0: + list_res = [ + inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples) + ] + elif warmstart_plan: + list_res = [ + inner_solver(X_s[k], X, a_s[k], b, list_res[k].plan, None) + for k in range(n_samples) + ] + elif warmstart_potentials: + list_res = [ + inner_solver(X_s[k], X, a_s[k], b, None, list_res[k].potentials) + for k in range(n_samples) + ] + else: + list_res = [ + inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples) + ] # Update the estimated barycenter weights in unbalanced cases if update_masses: b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) inv_b = 1.0 / b + # Update the barycenter samples if metric in ["sqeuclidean", "euclidean"]: X_new = ( @@ -2019,30 +2047,40 @@ def _bary_sample_bcd( else: raise NotImplementedError('Not implemented metric="{}"'.format(metric)) - # compute loss - new_loss = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) + # compute criterion + if stopping_criterion == "loss": + new_criterion = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) + else: # stopping_criterion = "bary" + new_criterion = nx.norm(X_new - X, ord=2) if verbose: if it % 1 == 0: - print(f"BCD iteration {it}: loss = {new_loss:.4f}") + print( + f"BCD iteration {it}: criterion {stopping_criterion} = {new_criterion:.4f}" 
+ ) if log: - log_["loss"].append(new_loss) + log_["stopping_criterion"].append(new_criterion) # Check convergence - if abs(new_loss - prev_loss) / abs(prev_loss) < tol_bary: + if abs(new_criterion - prev_criterion) / abs(prev_criterion) < tol_bary: print(f"BCD converged in {it} iterations") break X = X_new - prev_loss = new_loss + prev_criterion = new_criterion + + # compute loss values - # compute value_linear value_linear = sum([w_s[k] * list_res[k].value_linear for k in range(n_samples)]) + if stopping_criterion == "loss": + value = new_criterion + else: + value = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) # update BaryResult bary_res = BaryResult( X=X_new, b=b, - value=new_loss, + value=value, value_linear=value_linear, log=log_, list_res=list_res, @@ -2070,6 +2108,8 @@ def bary_sample( batch_size=None, method=None, n_threads=1, + warmstart=False, + stopping_criterion="loss", max_iter_bary=1000, max_iter=None, rank=100, @@ -2154,6 +2194,11 @@ large scale solver. n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 + warmstart : bool, optional + Use the previous OT or potentials as initialization for the next inner solver iteration, by default False. + stopping_criterion : str, optional + Stopping criterion for the outer loop of the BCD solver, by default 'loss'. + Either 'loss' to use the optimized objective or 'bary' for variations of the barycenter w.r.t. the Frobenius norm. max_iter_bary : int, optional Maximum number of iteration for the BCD solver, by default 1000. 
max_iter : int, optional @@ -2398,6 +2443,13 @@ def bary_sample( if method is not None and method.lower() in lst_method_lazy: raise NotImplementedError("Barycenter with Lazy tensors not implemented yet") + if stopping_criterion not in ["loss", "bary"]: + raise ValueError( + "stopping_criterion must be either 'loss' or 'bary', got {}".format( + stopping_criterion + ) + ) + n_samples = len(X_s) if ( @@ -2449,7 +2501,28 @@ def bary_sample( if b_init is None: b_init = nx.ones((n,), type_as=X_s[0]) / n - def inner_solver(X_a, X, a, b): + if warmstart: + if reg is None: # exact OT + warmstart_plan = True + warmstart_potentials = False + else: # regularized OT + # unbalanced AND regularized OT + if ( + not isinstance(reg_type, tuple) + and reg_type.lower() in ["kl"] + and unbalanced_type.lower() == "kl" + ): + warmstart_plan = False + warmstart_potentials = True + + else: + warmstart_plan = True + warmstart_potentials = False + else: + warmstart_plan = False + warmstart_potentials = False + + def inner_solver(X_a, X, a, b, plan_init, potentials_init): return solve_sample( X_a=X_a, X_b=X, @@ -2465,6 +2538,8 @@ def inner_solver(X_a, X, a, b): n_threads=n_threads, max_iter=max_iter, tol=tol, + plan_init=plan_init, + potentials_init=potentials_init, verbose=False, ) @@ -2479,6 +2554,9 @@ def inner_solver(X_a, X, a, b): metric, inner_solver, update_masses, + warmstart_plan, + warmstart_potentials, + stopping_criterion, max_iter_bary, tol_bary, verbose, diff --git a/ot/utils.py b/ot/utils.py index 6f8a5682f..8b045984b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1334,6 +1334,8 @@ class BaryResult: Dictionary containing potential information about the solver. list_res: list of OTResult List of results for the individual OT matching. + status : int or str + Status of the solver. Attributes ---------- @@ -1357,6 +1359,8 @@ class BaryResult: Dictionary containing potential information about the solver. list_res: list of OTResult List of results for the individual OT matching. 
+ status : int or str + Status of the solver. backend : Backend Backend used to compute the results. """ @@ -1371,6 +1375,7 @@ def __init__( value_quad=None, log=None, list_res=None, + status=None, backend=None, ): self._X = X @@ -1381,6 +1386,7 @@ def __init__( self._value_quad = value_quad self._log = log self._list_res = list_res + self._status = status self._backend = backend if backend is not None else NumpyBackend() def __repr__(self): diff --git a/test/test_solvers.py b/test/test_solvers.py index c691b9cc0..75c58dd99 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -741,12 +741,16 @@ def assert_allclose_bary_sol(sol1, sol2): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") @pytest.mark.parametrize( - "reg,reg_type,unbalanced,unbalanced_type", - itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type), + "reg,reg_type,unbalanced,unbalanced_type,warmstart", + itertools.product( + lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False] + ), ) -def test_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type): +def test_bary_sample_free_support( + nx, reg, reg_type, unbalanced, unbalanced_type, warmstart +): # test bary_sample when is_Lazy = False - rng = np.random.RandomState(0) + rng = np.random.RandomState() K = 3 # number of distributions ns = rng.randint(10, 20, K) # number of samples within each distribution @@ -781,7 +785,8 @@ def df(G): reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, - max_iter_bary=4, + warmstart=warmstart, + max_iter_bary=3, tol_bary=1e-3, verbose=True, ) @@ -798,7 +803,8 @@ def df(G): reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, - max_iter_bary=4, + warmstart=warmstart, + max_iter_bary=3, tol_bary=1e-3, verbose=True, ) @@ -831,7 +837,8 @@ def df(G): reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, - 
max_iter_bary=4, + warmstart=warmstart, + max_iter_bary=3, tol_bary=1e-3, verbose=True, ) diff --git a/test/test_utils.py b/test/test_utils.py index 938fd6058..1ecd1b51f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -456,7 +456,7 @@ def test_OTResult(): # test print print(res) - # tets get citation + # test get citation print(res.citation) lst_attributes = [ @@ -486,6 +486,31 @@ def test_OTResult(): getattr(res, at) +def test_BaryResult(): + res = ot.utils.BaryResult() + + # test print + print(res) + + # test get citation + print(res.citation) + + lst_attributes = [ + "X", + "C", + "b", + "value", + "value_linear", + "value_quad", + "list_res", + "status", + "log", + ] + for at in lst_attributes: + print(at) + assert getattr(res, at) is None + + def test_get_coordinate_circle(): rng = np.random.RandomState(42) u = rng.rand(1, 100) From 671788d40508042f5833695f0cb96d09fcdbdbbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Fri, 25 Apr 2025 01:12:41 +0200 Subject: [PATCH 05/18] trying to fix tests --- ot/solvers.py | 2 +- ot/unbalanced/_lbfgs.py | 8 ++++---- test/test_solvers.py | 33 +++++++++++++++++++++++++++------ 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 3e1c35ab6..daad962a8 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -2032,7 +2032,7 @@ def _bary_sample_bcd( list_res = [ inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples) ] - + print("inv_b:", inv_b) # Update the estimated barycenter weights in unbalanced cases if update_masses: b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index c4de87474..eb995efb5 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -46,9 +46,9 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div Divergence used for regularization. 
Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend + and not tensors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) @@ -218,9 +218,9 @@ def lbfgsb_unbalanced( Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple - of two calable functions returning the reg term and its derivative. + of two callable functions returning the reg term and its derivative. Note that the callable functions should be able to handle Numpy arrays - and not tesors from the backend + and not tensors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. 
Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) diff --git a/test/test_solvers.py b/test/test_solvers.py index 75c58dd99..6ede9b3f6 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -743,7 +743,12 @@ def assert_allclose_bary_sol(sol1, sol2): @pytest.mark.parametrize( "reg,reg_type,unbalanced,unbalanced_type,warmstart", itertools.product( - lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False] + lst_reg, + ["tuple"], + lst_unbalanced, + lst_unbalanced_type, + [True, False], + # lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False] ), ) def test_bary_sample_free_support( @@ -774,7 +779,7 @@ def df(G): return 2 * G reg_type = (f, df) - + # print('test reg_type:', reg_type[0](None), reg_type[1](None)) # solve default None weights sol0 = ot.bary_sample( X_s, @@ -790,8 +795,10 @@ def df(G): tol_bary=1e-3, verbose=True, ) + print("------ [done] sol0 - no backend") # solve provided uniform weights + sol = ot.bary_sample( X_s, n, @@ -808,6 +815,7 @@ def df(G): tol_bary=1e-3, verbose=True, ) + print("------ [done] sol - no backend") assert_allclose_bary_sol(sol0, sol) @@ -816,14 +824,25 @@ def df(G): a_sb = nx.from_numpy(*a_s) w_sb, bb = nx.from_numpy(w_s, b) - if isinstance(reg_type, tuple): + if reg_type == "tuple": - def f(G): - return nx.sum(G**2) + def fb(G): + return nx.sum( + G**2 + ) # otherwise we keep previously defined (f, df) as required by inner solver - def df(G): + def dfb(G): return 2 * G + """ + if ( + unbalanced_type.lower() in ["kl", "l2", "tv"]) and ( + unbalanced is not None) and ( + reg is not None + ): + reg_type = (f, df) + else: + """ reg_type = (f, df) solb = ot.bary_sample( @@ -842,6 +861,8 @@ def df(G): tol_bary=1e-3, verbose=True, ) + print("------ [done] sol - with backend") + assert_allclose_bary_sol(sol, solb) except NotImplementedError: From df1da8d78f2241313be19e93b5c9e43ecca5ad22 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 12 Oct 2025 20:42:03 +0200 Subject: [PATCH 06/18] small updates --- RELEASES.md | 6 ++ ot/solvers.py | 151 +++++++++++++++++++++++++------------------------- 2 files changed, 80 insertions(+), 77 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index ccb9b97d2..b602b8b08 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,11 @@ # Releases +### New features +- Wrapper for barycenter solvers with free support `ot.solvers.bary_free_support` (PR #730) + +### Closed issues + + ## 0.9.6.post1 *September 2025* diff --git a/ot/solvers.py b/ot/solvers.py index 541218e84..26005fbcf 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1988,11 +1988,11 @@ def solve_sample( def _bary_sample_bcd( - X_s, - X_init, - a_s, + X_a_list, + X_b_init, + a_list, b_init, - w_s, + w, metric, inner_solver, update_masses, @@ -2009,15 +2009,15 @@ def _bary_sample_bcd( Parameters ---------- - X_s : list of array-like, shape (n_samples_k, dim) + X_a_list : list of array-like, shape (n_samples_k, dim) List of samples in each source distribution - X_init : array-like, shape (n_samples_b, dim), + X_b_init : array-like, shape (n_samples_b, dim), Initialization of the barycenter samples. - a_s : list of array-like, shape (dim_k,) + a_list : list of array-like, shape (dim_k,) List of samples weights in each source distribution b_init : array-like, shape (n_samples_b,) Initialization of the barycenter weights. 
- w_s : list of array-like, shape (N,) + w : list of array-like, shape (N,) Samples barycentric weights metric : str Metric to use for the cost matrix, by default "sqeuclidean" @@ -2046,12 +2046,12 @@ def _bary_sample_bcd( TBD """ - X = X_init + X_b = X_b_init b = b_init inv_b = 1.0 / b prev_criterion = np.inf - n_samples = len(X_s) + n_samples = len(X_a_list) if log: log_ = {"stopping_criterion": []} @@ -2063,32 +2063,36 @@ def _bary_sample_bcd( # Solve the inner OT problem for each source distribution if it == 0: list_res = [ - inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples) + inner_solver(X_a_list[k], X_b, a_list[k], b, None, None) + for k in range(n_samples) ] elif warmstart_plan: list_res = [ - inner_solver(X_s[k], X, a_s[k], b, list_res[k].plan, None) + inner_solver(X_a_list[k], X_b, a_list[k], b, list_res[k].plan, None) for k in range(n_samples) ] elif warmstart_potentials: list_res = [ - inner_solver(X_s[k], X, a_s[k], b, None, list_res[k].potentials) + inner_solver( + X_a_list[k], X_b, a_list[k], b, None, list_res[k].potentials + ) for k in range(n_samples) ] else: list_res = [ - inner_solver(X_s[k], X, a_s[k], b, None, None) for k in range(n_samples) + inner_solver(X_a_list[k], X_b, a_list[k], b, None, None) + for k in range(n_samples) ] print("inv_b:", inv_b) # Update the estimated barycenter weights in unbalanced cases if update_masses: - b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) + b = sum([w[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) inv_b = 1.0 / b # Update the barycenter samples if metric in ["sqeuclidean", "euclidean"]: - X_new = ( - sum([w_s[k] * list_res[k].plan.T @ X_s[k] for k in range(n_samples)]) + X_b_new = ( + sum([w[k] * list_res[k].plan.T @ X_a_list[k] for k in range(n_samples)]) * inv_b[:, None] ) else: @@ -2096,9 +2100,9 @@ def _bary_sample_bcd( # compute criterion if stopping_criterion == "loss": - new_criterion = sum([w_s[k] * list_res[k].value for k in 
range(n_samples)]) + new_criterion = sum([w[k] * list_res[k].value for k in range(n_samples)]) else: # stopping_criterion = "bary" - new_criterion = nx.norm(X_new - X, ord=2) + new_criterion = nx.norm(X_b_new - X_b, ord=2) if verbose: if it % 1 == 0: @@ -2113,19 +2117,19 @@ def _bary_sample_bcd( print(f"BCD converged in {it} iterations") break - X = X_new + X_b = X_b_new prev_criterion = new_criterion # compute loss values - value_linear = sum([w_s[k] * list_res[k].value_linear for k in range(n_samples)]) + value_linear = sum([w[k] * list_res[k].value_linear for k in range(n_samples)]) if stopping_criterion == "loss": value = new_criterion else: - value = sum([w_s[k] * list_res[k].value for k in range(n_samples)]) + value = sum([w[k] * list_res[k].value for k in range(n_samples)]) # update BaryResult bary_res = BaryResult( - X=X_new, + X=X_b, b=b, value=value, value_linear=value_linear, @@ -2136,14 +2140,13 @@ def _bary_sample_bcd( return bary_res -def bary_sample( - X_s, +def bary_free_support( + X_a_list, n, - a_s=None, - w_s=None, - X_init=None, + a_list=None, + w=None, + X_b_init=None, b_init=None, - learn_X=True, learn_b=False, metric="sqeuclidean", reg=None, @@ -2166,7 +2169,7 @@ def bary_sample( random_state=0, verbose=False, ): - r"""Solve the discrete OT barycenter problem over source distributions using Block-Coordinate Descent. + r"""Solve the discrete OT barycenter problem over source distributions optimizing the barycenter support using Block-Coordinate Descent. 
The function solves the following general OT barycenter problem @@ -2188,22 +2191,20 @@ def bary_sample( Parameters ---------- - X_s : list of array-like, shape (n_samples_k, dim) + X_a_list : list of array-like, shape (n_samples_k, dim) List of samples in each source distribution n : int number of samples in the barycenter domain - a_s : list of array-like, shape (dim_k,), optional + a_list : list of array-like, shape (dim_k,), optional List of samples weights in each source distribution (default is uniform) - w_s : list of array-like, shape (N,), optional + w : list of array-like, shape (N,), optional Samples barycentric weights (default is uniform) - X_init : array-like, shape (n_samples_b, dim), optional + X_b_init : array-like, shape (n_samples_b, dim), optional Initialization of the barycenter samples (default is gaussian random sampling). Shape must match with required n. b_init : array-like, shape (n_samples_b,), optional Initialization of the barycenter weights (default is uniform). Shape must match with required n. 
- learn_X : bool, optional - Learn the barycenter samples (default is True) learn_b : bool, optional Learn the barycenter weights (default is False) metric : str, optional @@ -2497,56 +2498,52 @@ def bary_sample( ) ) - n_samples = len(X_s) + n_samples = len(X_a_list) if ( not lazy ): # default non lazy solver calls ot.solve_sample within _bary_sample_bcd # Detect backend - nx = get_backend(*X_s, X_init, b_init, w_s) + nx = get_backend(*X_a_list, X_b_init, b_init, w) # check sample weights - if a_s is None: - a_s = [ - nx.ones((X_s[k].shape[0],), type_as=X_s[k]) / X_s[k].shape[0] + if a_list is None: + a_list = [ + nx.ones((X_a_list[k].shape[0],), type_as=X_a_list[k]) + / X_a_list[k].shape[0] for k in range(n_samples) ] # check samples barycentric weights - if w_s is None: - w_s = nx.ones(n_samples, type_as=X_s[0]) / n_samples - - # check X_init - if X_init is None: - if (not learn_X) and learn_b: - raise ValueError( - "X_init must be provided if learn_X=False and learn_b=True" - ) - else: - rng = np.random.RandomState(random_state) - mean_ = nx.concatenate( - [nx.mean(X_s[k], axis=0) for k in range(n_samples)], - axis=0, - ) - mean_ = nx.mean(mean_, axis=0) - std_ = nx.concatenate( - [nx.std(X_s[k], axis=0) for k in range(n_samples)], - axis=0, - ) - std_ = nx.mean(std_, axis=0) - X_init = rng.normal( - loc=mean_, - scale=std_, - size=(n, X_s[0].shape[1]), - ) - X_init = nx.from_numpy(X_init, type_as=X_s[0]) + if w is None: + w = nx.ones(n_samples, type_as=X_a_list[0]) / n_samples + + # check X_b_init + if X_b_init is None: + rng = np.random.RandomState(random_state) + mean_ = nx.concatenate( + [nx.mean(X_a_list[k], axis=0) for k in range(n_samples)], + axis=0, + ) + mean_ = nx.mean(mean_, axis=0) + std_ = nx.concatenate( + [nx.std(X_a_list[k], axis=0) for k in range(n_samples)], + axis=0, + ) + std_ = nx.mean(std_, axis=0) + X_b_init = rng.normal( + loc=mean_, + scale=std_, + size=(n, X_a_list[0].shape[1]), + ) + X_b_init = nx.from_numpy(X_b_init, 
type_as=X_a_list[0]) else: - if (X_init.shape[0] != n) or (X_init.shape[1] != X_s[0].shape[1]): - raise ValueError("X_init must have shape (n, dim)") + if (X_b_init.shape[0] != n) or (X_b_init.shape[1] != X_a_list[0].shape[1]): + raise ValueError("X_b_init must have shape (n, dim)") # check b_init if b_init is None: - b_init = nx.ones((n,), type_as=X_s[0]) / n + b_init = nx.ones((n,), type_as=X_a_list[0]) / n if warmstart: if reg is None: # exact OT @@ -2569,10 +2566,10 @@ def bary_sample( warmstart_plan = False warmstart_potentials = False - def inner_solver(X_a, X, a, b, plan_init, potentials_init): + def inner_solver(X_a, X_b, a, b, plan_init, potentials_init): return solve_sample( X_a=X_a, - X_b=X, + X_b=X_b, a=a, b=b, metric=metric, @@ -2593,11 +2590,11 @@ def inner_solver(X_a, X, a, b, plan_init, potentials_init): # compute the barycenter using BCD update_masses = unbalanced is not None res = _bary_sample_bcd( - X_s, - X_init, - a_s, + X_a_list, + X_b_init, + a_list, b_init, - w_s, + w, metric, inner_solver, update_masses, From 75d6c1144ac64458428a9932e0e0af7f7cb5708e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 12 Oct 2025 21:24:17 +0200 Subject: [PATCH 07/18] fix fun name --- ot/__init__.py | 4 ++-- test/test_solvers.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index 81fe0999b..c6b4c667e 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -69,7 +69,7 @@ ) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample, bary_sample +from .solvers import solve, solve_gromov, solve_sample, bary_free_support from .lowrank import lowrank_sinkhorn from .batch import solve_batch, solve_sample_batch, solve_gromov_batch, dist_batch @@ -120,7 +120,7 @@ "solve", "solve_gromov", "solve_sample", - "bary_sample", + "bary_free_support", "smooth", "stochastic", "unbalanced", diff --git 
a/test/test_solvers.py b/test/test_solvers.py index 9f8dee53a..dce5c22fb 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -756,9 +756,7 @@ def assert_allclose_bary_sol(sol1, sol2): # lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False] ), ) -def test_bary_sample_free_support( - nx, reg, reg_type, unbalanced, unbalanced_type, warmstart -): +def test_bary_free_support(nx, reg, reg_type, unbalanced, unbalanced_type, warmstart): # test bary_sample when is_Lazy = False rng = np.random.RandomState() From c71e5442165f9965461dae4e85993ef1d722f220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 12 Oct 2025 21:33:06 +0200 Subject: [PATCH 08/18] update tests --- test/test_solvers.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index dce5c22fb..5b27d5c32 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -764,13 +764,13 @@ def test_bary_free_support(nx, reg, reg_type, unbalanced, unbalanced_type, warms ns = rng.randint(10, 20, K) # number of samples within each distribution n = 5 # number of samples in the barycenter - X_s = [rng.randn(ns_i, 2) for ns_i in ns] + X_list = [rng.randn(ns_i, 2) for ns_i in ns] # X_init = np.reshape(1.0 * np.randn(n, 2), (n, 1)) - a_s = [ot.utils.unif(X.shape[0]) for X in X_s] + a_list = [ot.utils.unif(X.shape[0]) for X in X_list] b = ot.utils.unif(n) - w_s = ot.utils.unif(K) + w = ot.utils.unif(K) try: if reg_type == "tuple": @@ -784,10 +784,10 @@ def df(G): reg_type = (f, df) # print('test reg_type:', reg_type[0](None), reg_type[1](None)) # solve default None weights - sol0 = ot.bary_sample( - X_s, + sol0 = ot.bary_free_support( + X_list, n, - w_s=None, + w=None, metric="sqeuclidean", reg=reg, reg_type=reg_type, @@ -802,12 +802,12 @@ def df(G): # solve provided uniform weights - sol = ot.bary_sample( - X_s, + sol = ot.bary_free_support( + X_list, n, - a_s=a_s, + 
a_list=a_list, b_init=b, - w_s=w_s, + w=w, metric="sqeuclidean", reg=reg, reg_type=reg_type, @@ -823,9 +823,9 @@ def df(G): assert_allclose_bary_sol(sol0, sol) # solve in backend - X_sb = nx.from_numpy(*X_s) - a_sb = nx.from_numpy(*a_s) - w_sb, bb = nx.from_numpy(w_s, b) + X_listb = nx.from_numpy(*X_list) + a_listb = nx.from_numpy(*a_list) + wb, bb = nx.from_numpy(w, b) if reg_type == "tuple": @@ -848,12 +848,12 @@ def dfb(G): """ reg_type = (f, df) - solb = ot.bary_sample( - X_sb, + solb = ot.bary_free_support( + X_listb, n, - a_s=a_sb, + a_listb=a_listb, b_init=bb, - w_s=w_sb, + w=wb, metric="sqeuclidean", reg=reg, reg_type=reg_type, From ed9c992e1966589a27b69829950e895c2fab7a2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 12 Oct 2025 21:44:55 +0200 Subject: [PATCH 09/18] update tests --- test/test_solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index 5b27d5c32..3ad235370 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -851,7 +851,7 @@ def dfb(G): solb = ot.bary_free_support( X_listb, n, - a_listb=a_listb, + a_list=a_listb, b_init=bb, w=wb, metric="sqeuclidean", From 7132f6286a242ca62afd25b92d185558004fa29f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 15 Oct 2025 14:00:32 +0200 Subject: [PATCH 10/18] fix tests --- ot/__init__.py | 4 ++-- ot/solvers.py | 7 +------ test/test_solvers.py | 28 ++++++++-------------------- 3 files changed, 11 insertions(+), 28 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index c6b4c667e..ad799d69d 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -69,7 +69,7 @@ ) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample, bary_free_support +from .solvers import solve, solve_gromov, solve_sample, solve_bary_sample from .lowrank import lowrank_sinkhorn from .batch import 
solve_batch, solve_sample_batch, solve_gromov_batch, dist_batch @@ -120,7 +120,7 @@ "solve", "solve_gromov", "solve_sample", - "bary_free_support", + "solve_bary_sample", "smooth", "stochastic", "unbalanced", diff --git a/ot/solvers.py b/ot/solvers.py index 26005fbcf..f1daaa861 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -2140,14 +2140,13 @@ def _bary_sample_bcd( return bary_res -def bary_free_support( +def solve_bary_sample( X_a_list, n, a_list=None, w=None, X_b_init=None, b_init=None, - learn_b=False, metric="sqeuclidean", reg=None, c=None, @@ -2205,8 +2204,6 @@ def bary_free_support( b_init : array-like, shape (n_samples_b,), optional Initialization of the barycenter weights (default is uniform). Shape must match with required n. - learn_b : bool, optional - Learn the barycenter weights (default is False) metric : str, optional Metric to use for the cost matrix, by default "sqeuclidean" reg : float, optional @@ -2485,8 +2482,6 @@ def bary_free_support( ---------- """ - if learn_b: - raise NotImplementedError("Barycenter weights learning not implemented yet") if method is not None and method.lower() in lst_method_lazy: raise NotImplementedError("Barycenter with Lazy tensors not implemented yet") diff --git a/test/test_solvers.py b/test/test_solvers.py index 3ad235370..b7e157a8a 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -749,14 +749,14 @@ def assert_allclose_bary_sol(sol1, sol2): "reg,reg_type,unbalanced,unbalanced_type,warmstart", itertools.product( lst_reg, - ["tuple"], + ["tuple"], # lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False], - # lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False] + # lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, warmstart ), ) -def test_bary_free_support(nx, reg, reg_type, unbalanced, unbalanced_type, warmstart): +def test_solve_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type, warmstart): # test bary_sample when is_Lazy = False rng = 
np.random.RandomState() @@ -784,7 +784,7 @@ def df(G): reg_type = (f, df) # print('test reg_type:', reg_type[0](None), reg_type[1](None)) # solve default None weights - sol0 = ot.bary_free_support( + sol0 = ot.solve_bary_sample( X_list, n, w=None, @@ -802,7 +802,7 @@ def df(G): # solve provided uniform weights - sol = ot.bary_free_support( + sol = ot.solve_bary_sample( X_list, n, a_list=a_list, @@ -827,28 +827,16 @@ def df(G): a_listb = nx.from_numpy(*a_list) wb, bb = nx.from_numpy(w, b) - if reg_type == "tuple": + if isinstance(reg_type, tuple): def fb(G): return nx.sum( G**2 ) # otherwise we keep previously defined (f, df) as required by inner solver - def dfb(G): - return 2 * G - - """ - if ( - unbalanced_type.lower() in ["kl", "l2", "tv"]) and ( - unbalanced is not None) and ( - reg is not None - ): - reg_type = (f, df) - else: - """ - reg_type = (f, df) + reg_type = (fb, df) - solb = ot.bary_free_support( + solb = ot.solve_bary_sample( X_listb, n, a_list=a_listb, From 4a0b5ec2a4226bd08beb443bb9ac4dcc87bf1467 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Fri, 6 Mar 2026 15:55:04 +0100 Subject: [PATCH 11/18] fix tests --- ot/solvers.py | 33 ++++++++++++++++++-------- test/test_solvers.py | 56 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index f1daaa861..8452d82ab 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -56,6 +56,18 @@ ] +lst_bary_method_lazy = [ + "lowrank", + "nystroem", + "factored", + "geomloss", + "geomloss_auto", + "geomloss_tensorized", + "geomloss_online", + "geomloss_multiscale", +] + + def solve( M, a=None, @@ -1443,7 +1455,7 @@ def solve_sample( plan_init : array-like, shape (dim_a, dim_b), optional Initialization of the OT plan for iterative methods, by default None rank : int, optional - Rank of the OT matrix for lazy solers (method='factored') or (method='nystroem'), by default 100 + Rank of the OT matrix for lazy 
solvers (method='factored') or (method='nystroem'), by default 100 scaling : float, optional Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 potentials_init : (array-like(dim_a,),array-like(dim_b,)), optional @@ -2048,20 +2060,19 @@ def _bary_sample_bcd( X_b = X_b_init b = b_init - inv_b = 1.0 / b + inv_b = nx.nan_to_num(1.0 / b, nan=1.0, posinf=1.0, neginf=1.0) prev_criterion = np.inf n_samples = len(X_a_list) + log_ = None if log: log_ = {"stopping_criterion": []} - else: - log_ = None # Compute the barycenter using BCD for it in range(max_iter_bary): # Solve the inner OT problem for each source distribution - if it == 0: + if it == 0: # no pre-defined warmstart used at iteration 0. list_res = [ inner_solver(X_a_list[k], X_b, a_list[k], b, None, None) for k in range(n_samples) @@ -2083,11 +2094,11 @@ def _bary_sample_bcd( inner_solver(X_a_list[k], X_b, a_list[k], b, None, None) for k in range(n_samples) ] - print("inv_b:", inv_b) + # Update the estimated barycenter weights in unbalanced cases if update_masses: b = sum([w[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) - inv_b = 1.0 / b + inv_b = nx.nan_to_num(1.0 / b, nan=1.0, posinf=1.0, neginf=1.0) # Update the barycenter samples if metric in ["sqeuclidean", "euclidean"]: @@ -2102,7 +2113,7 @@ def _bary_sample_bcd( if stopping_criterion == "loss": new_criterion = sum([w[k] * list_res[k].value for k in range(n_samples)]) else: # stopping_criterion = "bary" - new_criterion = nx.norm(X_b_new - X_b, ord=2) + new_criterion = nx.sum((X_b_new - X_b) ** 2) if verbose: if it % 1 == 0: @@ -2483,8 +2494,10 @@ def solve_bary_sample( """ - if method is not None and method.lower() in lst_method_lazy: - raise NotImplementedError("Barycenter with Lazy tensors not implemented yet") + if method is not None and method.lower() in lst_bary_method_lazy: + raise NotImplementedError( + f"method {method} operating on lazy tensors is not implemented yet" + ) if stopping_criterion 
not in ["loss", "bary"]: raise ValueError( diff --git a/test/test_solvers.py b/test/test_solvers.py index b7e157a8a..9f2c84e7d 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -13,6 +13,7 @@ import ot from ot.bregman import geomloss from ot.backend import torch +from ot.solvers import lst_bary_method_lazy lst_reg = [None, 1] @@ -63,6 +64,12 @@ }, # fail lazy for unbalanced and regularized ] +lst_parameters_solve_bary_sample_NotImplemented = [ + {"method": method} for method in lst_bary_method_lazy +] + [ + {"lazy": True}, # fail lazy +] + # set readable ids for each param lst_method_params_solve_sample = [ pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample @@ -71,6 +78,10 @@ pytest.param(param, id=str(param)) for param in lst_parameters_solve_sample_NotImplemented ] +lst_parameters_solve_bary_sample_NotImplemented = [ + pytest.param(param, id=str(param)) + for param in lst_parameters_solve_bary_sample_NotImplemented +] def assert_allclose_sol(sol1, sol2): @@ -749,7 +760,7 @@ def assert_allclose_bary_sol(sol1, sol2): "reg,reg_type,unbalanced,unbalanced_type,warmstart", itertools.product( lst_reg, - ["tuple"], # lst_reg_type, + lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False], @@ -760,7 +771,7 @@ def test_solve_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type, warms # test bary_sample when is_Lazy = False rng = np.random.RandomState() - K = 3 # number of distributions + K = 2 # number of distributions ns = rng.randint(10, 20, K) # number of samples within each distribution n = 5 # number of samples in the barycenter @@ -772,6 +783,8 @@ def test_solve_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type, warms w = ot.utils.unif(K) + stopping_criterion = "loss" if rng.choice([True, False]) else "bary" + try: if reg_type == "tuple": @@ -794,8 +807,9 @@ def df(G): unbalanced=unbalanced, unbalanced_type=unbalanced_type, warmstart=warmstart, - max_iter_bary=3, + max_iter_bary=2, tol_bary=1e-3, + 
stopping_criterion=stopping_criterion, verbose=True, ) print("------ [done] sol0 - no backend") @@ -814,8 +828,9 @@ def df(G): unbalanced=unbalanced, unbalanced_type=unbalanced_type, warmstart=warmstart, - max_iter_bary=3, + max_iter_bary=2, tol_bary=1e-3, + stopping_criterion=stopping_criterion, verbose=True, ) print("------ [done] sol - no backend") @@ -848,8 +863,9 @@ def fb(G): unbalanced=unbalanced, unbalanced_type=unbalanced_type, warmstart=warmstart, - max_iter_bary=3, + max_iter_bary=2, tol_bary=1e-3, + stopping_criterion=stopping_criterion, verbose=True, ) print("------ [done] sol - with backend") @@ -858,3 +874,33 @@ def fb(G): except NotImplementedError: pytest.skip("Not implemented") + + +@pytest.mark.parametrize( + "method_params", lst_parameters_solve_bary_sample_NotImplemented +) +def test_solve_bary_sample_NotImplemented(nx, method_params): + # test bary_sample when is_Lazy = False + rng = np.random.RandomState() + + K = 2 # number of distributions + ns = rng.randint(10, 20, K) # number of samples within each distribution + n = 5 # number of samples in the barycenter + + X_list = [rng.randn(ns_i, 2) for ns_i in ns] + # X_init = np.reshape(1.0 * np.randn(n, 2), (n, 1)) + + a_list = [ot.utils.unif(X.shape[0]) for X in X_list] + b = ot.utils.unif(n) + + w = ot.utils.unif(K) + + # solve in backend + X_listb = nx.from_numpy(*X_list) + a_listb = nx.from_numpy(*a_list) + wb, bb = nx.from_numpy(w, b) + + with pytest.raises(NotImplementedError): + ot.solve_bary_sample( + X_listb, n, a_list=a_listb, b_init=bb, w=wb, **method_params + ) From 3d9f987284a93833833b2546ade5fbbb25182c8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 9 Mar 2026 12:03:47 +0100 Subject: [PATCH 12/18] update docstring for solve_bary_sample --- RELEASES.md | 2 + ot/solvers.py | 240 ++++++++++--------------------------------- test/test_solvers.py | 4 +- 3 files changed, 60 insertions(+), 186 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 
0f8918cac..5dbd04627 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765) - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765) +- Wrapper for barycenter solvers with free support `ot.solvers.bary_free_support` (PR #730) + #### Closed issues diff --git a/ot/solvers.py b/ot/solvers.py index 8b52952d6..7dd70053d 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -10,7 +10,6 @@ from .utils import OTResult, BaryResult, dist from .lp import emd2, emd2_lazy, wasserstein_1d -from .utils import OTResult, dist from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced from .bregman import ( @@ -57,18 +56,6 @@ ] -lst_bary_method_lazy = [ - "lowrank", - "nystroem", - "factored", - "geomloss", - "geomloss_auto", - "geomloss_tensorized", - "geomloss_online", - "geomloss_multiscale", -] - - def solve( M, a=None, @@ -2202,33 +2189,27 @@ def solve_bary_sample( unbalanced=None, unbalanced_type="KL", lazy=False, - batch_size=None, method=None, - n_threads=1, warmstart=False, stopping_criterion="loss", max_iter_bary=1000, - max_iter=None, - rank=100, - scaling=0.95, tol_bary=1e-5, - tol=None, random_state=0, verbose=False, + **kwargs, ): r"""Solve the discrete OT barycenter problem over source distributions optimizing the barycenter support using Block-Coordinate Descent. The function solves the following general OT barycenter problem .. 
math:: - \min_{\mathbf{X} \in \mathbb{R}^{n \times d}, \mathbf{b} \in \Sigma_n} \min_{\{ \mathbf{T}^{(k)} \}_k \in \R_+^{n_i \times n}} \quad \sum_k w_k \sum_{i,j} T^{(k)}_{i,j}M^{(k)}_{i,j} + \lambda_r R(\mathbf{T}^{(k)}) + + \min_{\mathbf{X} \in \mathbb{R}^{n \times d}} \min_{\{ \mathbf{T}^{(k)} \}_k \in \R_+^{n_i \times n}} \quad \sum_k w_k \sum_{i,j} T^{(k)}_{i,j}M^{(k)}_{i,j} + \lambda_r R(\mathbf{T}^{(k)}) + \lambda_u U(\mathbf{T^{(k)}}\mathbf{1},\mathbf{a}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)T}\mathbf{1},\mathbf{b}) - where the cost matrices :math:`\mathbf{M}^{(k)}` for each input distribution :math:`(\mathbf{X}^{(k)}, \mathbf{b}^{(k)})` - is computed from the samples in the source and barycenter domains such that - :math:`M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)` where - :math:`d` is a metric (by default the squared Euclidean distance). + where the cost matrices :math:`\mathbf{M}^{(k)}` from each input distribution :math:`(\mathbf{X}^{(k)}, \mathbf{b}^{(k)})` + to the barycenter domain are computed as :math:`M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)` where + :math:`d` is a metric (by default the squared Euclidean distance). The barycenter probability weights are fixed to :math:`\mathbf{b}`. The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By default ``reg=None`` and there is no regularization. The unbalanced marginal @@ -2239,19 +2220,17 @@ def solve_bary_sample( Parameters ---------- X_a_list : list of array-like, shape (n_samples_k, dim) - List of samples in each source distribution + List of N samples in each source distribution n : int number of samples in the barycenter domain a_list : list of array-like, shape (dim_k,), optional List of samples weights in each source distribution (default is uniform) w : list of array-like, shape (N,), optional Samples barycentric weights (default is uniform) - X_b_init : array-like, shape (n_samples_b, dim), optional - Initialization of the barycenter samples (default is gaussian random sampling). 
- Shape must match with required n. + X_b_init : array-like, shape (n, dim), optional + Initialization of the barycenter samples (default is gaussian random sampling) b_init : array-like, shape (n_samples_b,), optional - Initialization of the barycenter weights (default is uniform). - Shape must match with required n. + Initialization of the barycenter weights (default is uniform) metric : str, optional Metric to use for the cost matrix, by default "sqeuclidean" reg : float, optional @@ -2278,70 +2257,56 @@ def solve_bary_sample( lazy : bool, optional Return :any:`OTResultlazy` object to reduce memory cost when True, by default False - batch_size : int, optional - Batch size for lazy solver, by default None (default values in each - solvers) method : str, optional Method for solving the problem, this can be used to select the solver for unbalanced problems (see :any:`ot.solve`), or to select a specific large scale solver. - n_threads : int, optional - Number of OMP threads for exact OT solver, by default 1 warmstart : bool, optional Use the previous OT or potentials as initialization for the next inner solver iteration, by default False. stopping_criterion : str, optional Stopping criterion for the outer loop of the BCD solver, by default 'loss'. Either 'loss' to use the optimize objective or 'bary' for variations of the barycenter w.r.t the Frobenius norm. max_iter_bary : int, optional - Maximum number of iteration for the BCD solver, by default 1000. - max_iter : int, optional - Maximum number of iteration, by default None (default values in each solvers) - rank : int, optional - Rank of the OT matrix for lazy solers (method='factored'), by default 100 - scaling : float, optional - Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 + Maximum number of iteration for the outer loop of the BCD solver, by default 1000. 
tol_bary : float, optional - Tolerance for solution precision of barycenter problem, by default None (default value 1e-5) - tol : float, optional - Tolerance for solution precision of inner OT solver, by default None (default values in each solvers) + Tolerance for solution precision of the barycenter problem, by default 1e-5. random_state : int, optional Random seed for the initialization of the barycenter samples, by default 0. Only used if `X_init` is None. verbose : bool, optional Print information in the solver, by default False - + kwargs : optional + Additional parameters for the inner solver (see :any:`ot.solve`) Returns ------- res : BaryResult() Result of the optimization problem. The information can be obtained as follows: - OTResult() - Result of the optimization problem. The information can be obtained as follows: - - - res.plan : OT plan :math:`\mathbf{T}` - - res.potentials : OT dual potentials + - res.X : Barycenter samples + - res.b : Barycenter weights - res.value : Optimal value of the optimization problem - res.value_linear : Linear OT loss with the optimal OT plan - - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method) + - res.list_res: List of OTResult for each inner OT problem (one per source distribution) + - res.log: log of the optimization process (if log=True) - See :any:`OTResult` for more information. + See :any:`BaryResult` for more information. Notes ----- - The following methods are available for solving the OT problems: + The following methods are available for solving barycenter problems with respect to these inner OT problems: - **Classical exact OT problem [1]** (default parameters) : .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \forall k, \quad \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + s.t. 
\ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)} - \mathbf{T}^T \mathbf{1} = \mathbf{b} + \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b} - \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j) @@ -2349,189 +2314,98 @@ def solve_bary_sample( .. code-block:: python - res = ot.solve_sample(xa, xb, a, b) + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w) - # for uniform weights - res = ot.solve_sample(xa, xb) + # for uniform sample weights and barycentric weights, + res = ot.solve_bary_sample([x1, x2], n) - **Entropic regularized OT [2]** (when ``reg!=None``): .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + \min_{\mathbf{T}^{(k)} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda R(\mathbf{T}^{(k)}) - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + s.t. \ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)} - \mathbf{T}^T \mathbf{1} = \mathbf{b} + \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b} - \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j) - can be solved with the following code: - - .. code-block:: python - - # default is ``"KL"`` regularization (``reg_type="KL"``) - res = ot.solve_sample(xa, xb, a, b, reg=1.0) - # or for original Sinkhorn paper formulation [2] - res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy') - - # lazy solver of memory complexity O(n) - res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) - # lazy OT plan - lazy_plan = res.lazy_plan - - # Use envelope theorem differentiation for memory saving - res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope') - res.value.backward() # only the value is differentiable - Note that by default the Sinkhorn solver uses automatic differentiation to - compute the gradients of the values and plan. This can be changed with the - `grad` parameter. 
The `envelope` mode computes the gradients only - for the value and the other outputs are detached. This is useful for - memory saving when only the gradient of value is needed. - We also have a very efficient solver with compiled CPU/CUDA code using - geomloss/PyKeOps that can be used with the following code: + can be solved with the following code: .. code-block:: python - # automatic solver - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss') - - # force O(n) memory efficient solver - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online') - - # force pre-computed cost matrix - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized') + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0) - # use multiscale solver - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale') + # or for original Sinkhorn paper formulation [2] + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='entropy') - # One can play with speed (small scaling factor) and precision (scaling close to 1) - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5) - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda R(\mathbf{T}^{(k))}) - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + s.t. \ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)} - \mathbf{T}^T \mathbf{1} = \mathbf{b} + \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b} - \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j) can be solved with the following code: .. 
code-block:: python - res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2') + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='L2') - **Unbalanced OT [41]** (when ``unbalanced!=None``): .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - \text{with} \ M_{i,j} = d(x_i,y_j) + \min_{\mathbf{T}^{(k)}\geq 0} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)^T}\mathbf{1},\mathbf{b}) can be solved with the following code: .. code-block:: python # default is ``"KL"`` - res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0) + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0) + # quadratic unbalanced OT - res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2') + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0, unbalanced_type='L2') # TV = partial OT - res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV') + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0, unbalanced_type='TV') - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + \min_{\mathbf{T}^{(k)} \geq 0} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_r R(\mathbf{T}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)^T}\mathbf{1},\mathbf{b}) - \text{with} \ M_{i,j} = d(x_i,y_j) can be solved with the following code: .. 
code-block:: python # default is ``"KL"`` for both - res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0) + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, unbalanced=1.0) # quadratic unbalanced OT with KL regularization - res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2') + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, unbalanced=1.0, unbalanced_type='L2') # both quadratic - res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', - unbalanced=1.0, unbalanced_type='L2') - - - - **Factored OT [2]** (when ``method='factored'``): - - This method solve the following OT problem [40]_ - - .. math:: - \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) - - where $\mu$ is a uniform weighted empirical distribution of :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated - to the samples in the source and target domains, and :math:`W_2` is the - Wasserstein distance. This problem is solved using exact OT solvers for - `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides - two transport plans that can be used to recover a low rank OT plan between - the two distributions. - - .. code-block:: python - - res = ot.solve_sample(xa, xb, method='factored', rank=10) - - # recover the lazy low rank plan - factored_solution_lazy = res.lazy_plan + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2') - # recover the full low rank plan - factored_solution = factored_solution_lazy[:] - - - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): - - This method computes the Gaussian Bures-Wasserstein distance between two - Gaussian distributions estimated from the empirical distributions - - .. math:: - \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} - - where : - - .. 
math:: - \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) - - The covariances and means are estimated from the data. - - .. code-block:: python - - res = ot.solve_sample(xa, xb, method='gaussian') - - # recover the squared Gaussian Bures-Wasserstein distance - BW_dist = res.value - - - **Wasserstein 1d [1]** (when ``method='1D'``): - - This method computes the Wasserstein distance between two 1d distributions - estimated from the empirical distributions. For multivariate data the - distances are computed independently for each dimension. - - .. code-block:: python - - res = ot.solve_sample(xa, xb, method='1D') - - # recover the squared Wasserstein distances - W_dists = res.value - - - .. _references-bary-sample: + .. _references-solve_bary_sample: References ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
+ """ - if method is not None and method.lower() in lst_bary_method_lazy: + if method is not None and method.lower() in lst_method_lazy: raise NotImplementedError( f"method {method} operating on lazy tensors is not implemented yet" ) @@ -2624,12 +2498,10 @@ def inner_solver(X_a, X_b, a, b, plan_init, potentials_init): unbalanced=unbalanced, unbalanced_type=unbalanced_type, method=method, - n_threads=n_threads, - max_iter=max_iter, - tol=tol, plan_init=plan_init, potentials_init=potentials_init, verbose=False, + **kwargs, ) # compute the barycenter using BCD diff --git a/test/test_solvers.py b/test/test_solvers.py index 64322680a..020d062a2 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -13,7 +13,7 @@ import ot from ot.bregman import geomloss from ot.backend import torch -from ot.solvers import lst_bary_method_lazy +from ot.solvers import lst_method_lazy lst_reg = [None, 1] @@ -64,7 +64,7 @@ ] lst_parameters_solve_bary_sample_NotImplemented = [ - {"method": method} for method in lst_bary_method_lazy + {"method": method} for method in lst_method_lazy ] + [ {"lazy": True}, # fail lazy ] From 7b7cfc0741bddd4a9a6ec8cba11ac0ac9a25ba81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 9 Mar 2026 16:04:02 +0100 Subject: [PATCH 13/18] update plot quickstart guide --- examples/plot_quickstart_guide.py | 39 +++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 6704b860e..519f226da 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -638,6 +638,45 @@ def df(G): plot_plan(P_fgw, "Fused GW plan", axis=False) pl.show() + +# +# Solving free support barycenter problems +# ------------------------------- +# Solve Optimal transport barycenter problem with free support between several input distributions. 
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The :func:`ot.solve_bary_sample` function can be used to solve the Optimal Transport barycenter problem +# between multiple sets of samples while optimizing the support of the barycenter and letting fixed their probability weights. +# The function takes as its first argument the list of samples in each input distribution, +# and as second argument the number of samples to learn in the barycenter. By default, the probability weights in each distribution and the barycentric weights are uniform but they can be customized by the user. +# The function returns an :class:`ot.utils.OTBaryResult` object that contains in part the barycenter samples and the OT plans between the barycenter and each input distribution. +# In the following, we illustrate the use of this function with the same 2D data as above considered as input distributions and compute their barycenter while using exact OT. +# Notice that most of the arguments of the :func:`ot.solve_bary_sample` function are similar to those of the :func:`ot.solve_sample` function and that the same regularization and unbalanced parameters can be used to solve regularized and unbalanced barycenter problems. 
+ +# Solve the OT barycenter problem (exact OT with regularization) +sol = ot.solve_bary_sample([x1, x2], n=35) + +# get the barycenter support +X = sol.X + +# get the OT plans between the barycenter and each input distribution +list_P = [sol.list_res[i].plan for i in range(2)] + +# get the barycenterOT loss +loss = sol.value + +print(f"Barycenter OT loss = {loss:1.3f}") + +# sphinx_gallery_start_ignore +from ot.plot import plot2D_samples_mat + +pl.figure(1, (8, 4)) +plot2D_samples_mat(x1, X, list_P[0]) +plot2D_samples_mat(x2, X, list_P[1]) +pl.axis("off") +pl.title("Barycenter samples and OT plans, loss={:.3f}".format(loss)) +pl.show() + # sphinx_gallery_end_ignore # %% # From 406d026e1cf7f2c53b6a06398b73d378621f9887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 9 Mar 2026 18:28:03 +0100 Subject: [PATCH 14/18] add ex --- .../plot_solve_barycenter_variants.py | 122 ++++++++++++++++++ examples/plot_quickstart_guide.py | 2 - 2 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 examples/barycenters/plot_solve_barycenter_variants.py diff --git a/examples/barycenters/plot_solve_barycenter_variants.py b/examples/barycenters/plot_solve_barycenter_variants.py new file mode 100644 index 000000000..da9c6aa6a --- /dev/null +++ b/examples/barycenters/plot_solve_barycenter_variants.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +""" +====================================== +Optimal Transport Barycenter solvers comparison +====================================== + +This example illustrates solutions returned for different variants of exact, +regularized and unbalanced OT barycenter problems with free support using our wrapper `ot.solve_bary_sample`. 
+""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 3 + +# %% + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +# %% parameters + +n = 50 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std +b = gauss(n, m=25, s=5) + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +# %% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, "b", label="Source distribution 1") +pl.plot(x, b, "r", label="Source distribution 2") +pl.legend() + + +# %% +# Set up parameters for barycenter solvers and solve +# --------------------------------------- + +lst_regs = [ + "No Reg.", + "Entropic", +] # support e.g ["No Reg.", "Entropic", "L2", "Group Lasso + L2"] +lst_unbalanced = [ + "Balanced", + "Unbalanced KL", +] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"] + +lst_solvers = [ # name, param for ot.solve function + # balanced OT + ("Exact OT", dict()), + ("Entropic Reg. 
OT", dict(reg=0.005)), + # unbalanced OT KL + ("Unbalanced KL No Reg.", dict(unbalanced=0.005)), + ( + "Unbalanced KL with KL Reg.", + dict(reg=0.0005, unbalanced=0.005, unbalanced_type="kl", reg_type="kl"), + ), +] + +lst_res = [] +for name, param in lst_solvers: + res = ot.solve_bary_sample(X_a_list=[x, x], n=50, a_list=[a, b], **param) + lst_res.append(res) + + +############################################################################## +# Plot distributions and plans +# ---------- + +pl.figure(3, figsize=(9, 9)) + +for i, bname in enumerate(lst_unbalanced): + for j, rname in enumerate(lst_regs): + pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1) + + bary_bins = np.histogram(lst_res[i * len(lst_regs) + j], bins=x)[0] + if i == 0 and j == 0: # add labels + pl.plot(x, a, "b", label="Source distribution 1") + pl.plot(x, b, "r", label="Source distribution 2") + pl.plot(x, bary_bins, "g", label="Barycenter") + else: + pl.plot(x, a, "b") + pl.plot(x, b, "r") + pl.plot(x, bary_bins, "g") + + for i, local_res in enumerate(lst_res[i * len(lst_regs) + j].list_res): + plan = local_res.plan + m2 = plan.sum(0) + m1 = plan.sum(1) + if i == 0: + m1, m2 = m1 / a.max(), m2 * n + else: + m1, m2 = m1 / b.max(), m2 * n + pl.imshow(plan, cmap="Greys") + pl.plot(x, m2 * 10, "g") + pl.plot(m1 * 10, x, "b" if i == 0 else "r") + + pl.tick_params( + left=False, right=False, labelleft=False, labelbottom=False, bottom=False + ) + if i == 0: + pl.title(rname) + if j == 0: + pl.ylabel(bname, fontsize=14) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 519f226da..08acf3d46 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -667,8 +667,6 @@ def df(G): print(f"Barycenter OT loss = {loss:1.3f}") -# sphinx_gallery_start_ignore -from ot.plot import plot2D_samples_mat pl.figure(1, (8, 4)) plot2D_samples_mat(x1, X, list_P[0]) From 0172a8979a676897ff013b4ca300ee76c2655788 Mon Sep 17 00:00:00 
2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 10 Mar 2026 17:55:16 +0100 Subject: [PATCH 15/18] fix ex --- .../plot_solve_barycenter_variants.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/barycenters/plot_solve_barycenter_variants.py b/examples/barycenters/plot_solve_barycenter_variants.py index da9c6aa6a..4f1debac1 100644 --- a/examples/barycenters/plot_solve_barycenter_variants.py +++ b/examples/barycenters/plot_solve_barycenter_variants.py @@ -31,7 +31,7 @@ n = 50 # nb bins # bin positions -x = np.arange(n, dtype=np.float64) +x = np.arange(n, dtype=np.float64)[:, None] # Gaussian distributions a = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std @@ -45,8 +45,8 @@ # %% plot the distributions pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, "b", label="Source distribution 1") -pl.plot(x, b, "r", label="Source distribution 2") +pl.plot(x[:, 0], a, "b", label="Source distribution 1") +pl.plot(x[:, 0], b, "r", label="Source distribution 2") pl.legend() @@ -93,13 +93,13 @@ bary_bins = np.histogram(lst_res[i * len(lst_regs) + j], bins=x)[0] if i == 0 and j == 0: # add labels - pl.plot(x, a, "b", label="Source distribution 1") - pl.plot(x, b, "r", label="Source distribution 2") - pl.plot(x, bary_bins, "g", label="Barycenter") + pl.plot(x[:, 0], a, "b", label="Source distribution 1") + pl.plot(x[:, 0], b, "r", label="Source distribution 2") + pl.plot(x[:, 0], bary_bins, "g", label="Barycenter") else: - pl.plot(x, a, "b") - pl.plot(x, b, "r") - pl.plot(x, bary_bins, "g") + pl.plot(x[:, 0], a, "b") + pl.plot(x[:, 0], b, "r") + pl.plot(x[:, 0], bary_bins, "g") for i, local_res in enumerate(lst_res[i * len(lst_regs) + j].list_res): plan = local_res.plan @@ -110,7 +110,7 @@ else: m1, m2 = m1 / b.max(), m2 * n pl.imshow(plan, cmap="Greys") - pl.plot(x, m2 * 10, "g") + pl.plot(x[:, 0], m2 * 10, "g") pl.plot(m1 * 10, x, "b" if i == 0 else "r") pl.tick_params( From 
d4931940712b9be1fac21431c878dd63db2478b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 11 Mar 2026 12:41:16 +0100 Subject: [PATCH 16/18] fix docs --- .../plot_solve_barycenter_variants.py | 112 ++++++++++-------- examples/plot_quickstart_guide.py | 28 +++-- ot/solvers.py | 16 +-- 3 files changed, 89 insertions(+), 67 deletions(-) diff --git a/examples/barycenters/plot_solve_barycenter_variants.py b/examples/barycenters/plot_solve_barycenter_variants.py index 4f1debac1..b1797ea7c 100644 --- a/examples/barycenters/plot_solve_barycenter_variants.py +++ b/examples/barycenters/plot_solve_barycenter_variants.py @@ -11,43 +11,45 @@ # Author: Cédric Vincent-Cuaz # # License: MIT License -# sphinx_gallery_thumbnail_number = 3 +# sphinx_gallery_thumbnail_number = 2 # %% import numpy as np import matplotlib.pylab as pl import ot -import ot.plot -from ot.datasets import make_1D_gauss as gauss +from ot.plot import plot2D_samples_mat -############################################################################## -# Generate data -# ------------- - - -# %% parameters +# %% +# 2D data example +# --------------- +# +# We first generate two sets of samples in 2D that 25 and 50 +# samples respectively located on circles. The weights of the samples are +# uniform. 
-n = 50 # nb bins +# Problem size +n1 = 25 +n2 = 50 -# bin positions -x = np.arange(n, dtype=np.float64)[:, None] +# Generate random data +np.random.seed(0) -# Gaussian distributions -a = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std -b = gauss(n, m=25, s=5) +x1 = np.random.randn(n1, 2) +x1 /= np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2 +x2 = np.random.randn(n2, 2) +x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 -############################################################################## -# Plot distributions and loss matrix -# ---------------------------------- +style = {"markeredgecolor": "k"} -# %% plot the distributions +pl.figure(1, (4, 4)) +pl.plot(x1[:, 0], x1[:, 1], "ob", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", **style) +pl.title("Source distributions") +pl.show() -pl.figure(1, figsize=(6.4, 3)) -pl.plot(x[:, 0], a, "b", label="Source distribution 1") -pl.plot(x[:, 0], b, "r", label="Source distribution 2") -pl.legend() +# sphinx_gallery_end_ignore # %% @@ -66,56 +68,62 @@ lst_solvers = [ # name, param for ot.solve function # balanced OT ("Exact OT", dict()), - ("Entropic Reg. OT", dict(reg=0.005)), + ("Entropic Reg. 
OT", dict(reg=0.1)), # unbalanced OT KL - ("Unbalanced KL No Reg.", dict(unbalanced=0.005)), + ("Unbalanced KL No Reg.", dict(unbalanced=0.05)), ( "Unbalanced KL with KL Reg.", - dict(reg=0.0005, unbalanced=0.005, unbalanced_type="kl", reg_type="kl"), + dict(reg=0.1, unbalanced=0.05, unbalanced_type="kl", reg_type="kl"), ), ] lst_res = [] for name, param in lst_solvers: - res = ot.solve_bary_sample(X_a_list=[x, x], n=50, a_list=[a, b], **param) + print(f"-- name = {name} / param = {param}") + res = ot.solve_bary_sample(X_a_list=[x1, x2], n=35, **param) lst_res.append(res) - + list_P = [res.list_res[k].plan for k in range(2)] + print("X:", res.X) + print("loss:", res.value) + print("loss:", res.log) + print( + "marginals OT 1:", + res.list_res[0].plan.sum(axis=1), + res.list_res[0].plan.sum(axis=0), + ) + print( + "marginals OT 2:", + res.list_res[1].plan.sum(axis=1), + res.list_res[1].plan.sum(axis=0), + ) ############################################################################## # Plot distributions and plans # ---------- -pl.figure(3, figsize=(9, 9)) +pl.figure(2, figsize=(16, 16)) for i, bname in enumerate(lst_unbalanced): for j, rname in enumerate(lst_regs): pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1) - bary_bins = np.histogram(lst_res[i * len(lst_regs) + j], bins=x)[0] + X = lst_res[i * len(lst_regs) + j].X + list_P = [lst_res[i * len(lst_regs) + j].list_res[k].plan for k in range(2)] + loss = lst_res[i * len(lst_regs) + j].value + + plot2D_samples_mat(x1, X, list_P[0]) + plot2D_samples_mat(x2, X, list_P[1]) + if i == 0 and j == 0: # add labels - pl.plot(x[:, 0], a, "b", label="Source distribution 1") - pl.plot(x[:, 0], b, "r", label="Source distribution 2") - pl.plot(x[:, 0], bary_bins, "g", label="Barycenter") + pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style) + pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style) + pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", 
**style) + pl.legend(loc="best") else: - pl.plot(x[:, 0], a, "b") - pl.plot(x[:, 0], b, "r") - pl.plot(x[:, 0], bary_bins, "g") - - for i, local_res in enumerate(lst_res[i * len(lst_regs) + j].list_res): - plan = local_res.plan - m2 = plan.sum(0) - m1 = plan.sum(1) - if i == 0: - m1, m2 = m1 / a.max(), m2 * n - else: - m1, m2 = m1 / b.max(), m2 * n - pl.imshow(plan, cmap="Greys") - pl.plot(x[:, 0], m2 * 10, "g") - pl.plot(m1 * 10, x, "b" if i == 0 else "r") - - pl.tick_params( - left=False, right=False, labelleft=False, labelbottom=False, bottom=False - ) + pl.plot(x1[:, 0], x1[:, 1], "ob", **style) + pl.plot(x2[:, 0], x2[:, 1], "or", **style) + pl.plot(X[:, 0], X[:, 1], "og", **style) + if i == 0: pl.title(rname) if j == 0: diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 08acf3d46..6d8409752 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -638,22 +638,24 @@ def df(G): plot_plan(P_fgw, "Fused GW plan", axis=False) pl.show() - +# %% # -# Solving free support barycenter problems +# Solving barycenter problems # ------------------------------- # Solve Optimal transport barycenter problem with free support between several input distributions. -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~ # # The :func:`ot.solve_bary_sample` function can be used to solve the Optimal Transport barycenter problem # between multiple sets of samples while optimizing the support of the barycenter and letting fixed their probability weights. # The function takes as its first argument the list of samples in each input distribution, # and as second argument the number of samples to learn in the barycenter. By default, the probability weights in each distribution and the barycentric weights are uniform but they can be customized by the user. 
+# # The function returns an :class:`ot.utils.OTBaryResult` object that contains in part the barycenter samples and the OT plans between the barycenter and each input distribution. +# # In the following, we illustrate the use of this function with the same 2D data as above considered as input distributions and compute their barycenter while using exact OT. # Notice that most of the arguments of the :func:`ot.solve_bary_sample` function are similar to those of the :func:`ot.solve_sample` function and that the same regularization and unbalanced parameters can be used to solve regularized and unbalanced barycenter problems. -# Solve the OT barycenter problem (exact OT with regularization) +# Solve the OT barycenter problem (exact OT without any regularization) sol = ot.solve_bary_sample([x1, x2], n=35) # get the barycenter support @@ -668,11 +670,23 @@ def df(G): print(f"Barycenter OT loss = {loss:1.3f}") -pl.figure(1, (8, 4)) +pl.figure(1, (8, 8)) plot2D_samples_mat(x1, X, list_P[0]) plot2D_samples_mat(x2, X, list_P[1]) -pl.axis("off") -pl.title("Barycenter samples and OT plans, loss={:.3f}".format(loss)) + +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style) +pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style) + +pl.title( + "Barycenter samples and OT plans \n total loss= %s = 0.5 * %s + 0.5 * %s" + % ( + np.round(loss, 3), + np.round(sol.list_res[0].value, 3), + np.round(sol.list_res[1].value, 3), + ) +) +pl.legend(loc="best") pl.show() # sphinx_gallery_end_ignore diff --git a/ot/solvers.py b/ot/solvers.py index 7dd70053d..7a87013b2 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -2203,11 +2203,11 @@ def solve_bary_sample( The function solves the following general OT barycenter problem .. 
math:: - \min_{\mathbf{X} \in \mathbb{R}^{n \times d}} \min_{\{ \mathbf{T}^{(k)} \}_k \in \R_+^{n_i \times n}} \quad \sum_k w_k \sum_{i,j} T^{(k)}_{i,j}M^{(k)}_{i,j} + \lambda_r R(\mathbf{T}^{(k)}) + + \min_{\mathbf{X} \in \mathbb{R}^{n \times d}} \min_{\{ \mathbf{T}^{(k)} \}_k \in \mathbb{R}_+^{n_k \times n}} \quad \sum_k w_k \{ \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_r R(\mathbf{T}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)}) + - \lambda_u U(\mathbf{T}^{(k)T}\mathbf{1},\mathbf{b}) + \lambda_u U(\mathbf{T}^{(k)T}\mathbf{1},\mathbf{b}) \} - where the cost matrices :math:`\mathbf{M}^{(k)}` from each input distribution :math:`(\mathbf{X}^{(k)}, \mathbf{b}^{(k)})` + where the cost matrices :math:`\mathbf{M}^{(k)}` from each input distribution :math:`(\mathbf{X}^{(k)}, \mathbf{a}^{(k)})` to the barycenter domain are computed as :math:`M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)` where :math:`d` is a metric (by default the squared Euclidean distance). The barycenter probability weights are fixed to :math:`\mathbf{b}`. 
@@ -2223,13 +2223,13 @@ def solve_bary_sample( List of N samples in each source distribution n : int number of samples in the barycenter domain - a_list : list of array-like, shape (dim_k,), optional + a_list : list of array-like, shape (n_samples_k,), optional List of samples weights in each source distribution (default is uniform) w : list of array-like, shape (N,), optional Samples barycentric weights (default is uniform) X_b_init : array-like, shape (n, dim), optional Initialization of the barycenter samples (default is gaussian random sampling) - b_init : array-like, shape (n_samples_b,), optional + b_init : array-like, shape (n,), optional Initialization of the barycenter weights (default is uniform) metric : str, optional Metric to use for the cost matrix, by default "sqeuclidean" @@ -2238,8 +2238,8 @@ def solve_bary_sample( OT) c : array-like, shape (dim_a, dim_b), optional (default=None) Reference measure for the regularization. - If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. - If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. + If None, then use :math:`\mathbf{c} = \mathbf{a}^{(k)} \mathbf{b}^T`. + If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{|a^{(k)}|} 1_{|b|}^T`. reg_type : str, optional Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" unbalanced : float or indexable object of length 1 or 2 @@ -2322,7 +2322,7 @@ def solve_bary_sample( - **Entropic regularized OT [2]** (when ``reg!=None``): .. math:: - \min_{\mathbf{T}^{(k)} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda R(\mathbf{T}^{(k)}) + \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda R(\mathbf{T}^{(k)}) s.t. 
\ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)} From 95e5a1c9826e6a084029aaf9091a04a67c237afc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 11 Mar 2026 12:54:50 +0100 Subject: [PATCH 17/18] fix docs --- examples/plot_quickstart_guide.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 6d8409752..d5cfac8e6 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -638,6 +638,8 @@ def df(G): plot_plan(P_fgw, "Fused GW plan", axis=False) pl.show() +# sphinx_gallery_end_ignore + # %% # # Solving barycenter problems @@ -669,7 +671,7 @@ def df(G): print(f"Barycenter OT loss = {loss:1.3f}") - +# sphinx_gallery_start_ignore pl.figure(1, (8, 8)) plot2D_samples_mat(x1, X, list_P[0]) plot2D_samples_mat(x2, X, list_P[1]) From 05b47fc15a87c7ad31f936d50d3adefe675fa9b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 11 Mar 2026 15:59:39 +0100 Subject: [PATCH 18/18] fix sphinx --- .../barycenters/plot_solve_barycenter_variants.py | 2 -- ot/solvers.py | 14 +++++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/barycenters/plot_solve_barycenter_variants.py b/examples/barycenters/plot_solve_barycenter_variants.py index b1797ea7c..9403fbfce 100644 --- a/examples/barycenters/plot_solve_barycenter_variants.py +++ b/examples/barycenters/plot_solve_barycenter_variants.py @@ -49,8 +49,6 @@ pl.title("Source distributions") pl.show() -# sphinx_gallery_end_ignore - # %% # Set up parameters for barycenter solvers and solve diff --git a/ot/solvers.py b/ot/solvers.py index 7a87013b2..02d3d4beb 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -2079,7 +2079,19 @@ def _bary_sample_bcd( Backend to use for the computation. Must match<< Returns ------- - TBD + + res : BaryResult() + Result of the optimization problem. 
The information can be obtained as follows: + + - res.X : Barycenter samples + - res.b : Barycenter weights + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.list_res : List of OTResult for each inner OT problem (one per source distribution) + - res.log : Log of the optimization process (if log=True) + + See :any:`BaryResult` for more information. + """ X_b = X_b_init