diff --git a/RELEASES.md b/RELEASES.md index 0f8918cac..5dbd04627 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765) - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765) +- Wrapper for barycenter solvers with free support `ot.solve_bary_sample` (PR #730) + #### Closed issues diff --git a/examples/barycenters/plot_solve_barycenter_variants.py b/examples/barycenters/plot_solve_barycenter_variants.py new file mode 100644 index 000000000..9403fbfce --- /dev/null +++ b/examples/barycenters/plot_solve_barycenter_variants.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +""" +=============================================== +Optimal Transport Barycenter solvers comparison +=============================================== + +This example illustrates solutions returned for different variants of exact, +regularized and unbalanced OT barycenter problems with free support using our wrapper `ot.solve_bary_sample`. +""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 2 + +# %% + +import numpy as np +import matplotlib.pylab as pl +import ot +from ot.plot import plot2D_samples_mat + +# %% +# 2D data example +# --------------- +# +# We first generate two sets of samples in 2D, with 25 and 50 +# samples respectively, located on circles. The weights of the samples are +# uniform. 
+ +# Problem size +n1 = 25 +n2 = 50 + +# Generate random data +np.random.seed(0) + +x1 = np.random.randn(n1, 2) +x1 /= np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2 + +x2 = np.random.randn(n2, 2) +x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 + +style = {"markeredgecolor": "k"} + +pl.figure(1, (4, 4)) +pl.plot(x1[:, 0], x1[:, 1], "ob", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", **style) +pl.title("Source distributions") +pl.show() + + +# %% +# Set up parameters for barycenter solvers and solve +# --------------------------------------- + +lst_regs = [ + "No Reg.", + "Entropic", +] # support e.g ["No Reg.", "Entropic", "L2", "Group Lasso + L2"] +lst_unbalanced = [ + "Balanced", + "Unbalanced KL", +] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"] + +lst_solvers = [ # name, param for ot.solve function + # balanced OT + ("Exact OT", dict()), + ("Entropic Reg. OT", dict(reg=0.1)), + # unbalanced OT KL + ("Unbalanced KL No Reg.", dict(unbalanced=0.05)), + ( + "Unbalanced KL with KL Reg.", + dict(reg=0.1, unbalanced=0.05, unbalanced_type="kl", reg_type="kl"), + ), +] + +lst_res = [] +for name, param in lst_solvers: + print(f"-- name = {name} / param = {param}") + res = ot.solve_bary_sample(X_a_list=[x1, x2], n=35, **param) + lst_res.append(res) + list_P = [res.list_res[k].plan for k in range(2)] + print("X:", res.X) + print("loss:", res.value) + print("loss:", res.log) + print( + "marginals OT 1:", + res.list_res[0].plan.sum(axis=1), + res.list_res[0].plan.sum(axis=0), + ) + print( + "marginals OT 2:", + res.list_res[1].plan.sum(axis=1), + res.list_res[1].plan.sum(axis=0), + ) + +############################################################################## +# Plot distributions and plans +# ---------- + +pl.figure(2, figsize=(16, 16)) + +for i, bname in enumerate(lst_unbalanced): + for j, rname in enumerate(lst_regs): + pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1) + + X = lst_res[i * len(lst_regs) + j].X + list_P = [lst_res[i * 
len(lst_regs) + j].list_res[k].plan for k in range(2)] + loss = lst_res[i * len(lst_regs) + j].value + + plot2D_samples_mat(x1, X, list_P[0]) + plot2D_samples_mat(x2, X, list_P[1]) + + if i == 0 and j == 0: # add labels + pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style) + pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style) + pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style) + pl.legend(loc="best") + else: + pl.plot(x1[:, 0], x1[:, 1], "ob", **style) + pl.plot(x2[:, 0], x2[:, 1], "or", **style) + pl.plot(X[:, 0], X[:, 1], "og", **style) + + if i == 0: + pl.title(rname) + if j == 0: + pl.ylabel(bname, fontsize=14) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 6704b860e..d5cfac8e6 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -638,6 +638,59 @@ def df(G): plot_plan(P_fgw, "Fused GW plan", axis=False) pl.show() +# sphinx_gallery_end_ignore + +# %% +# +# Solving barycenter problems +# ------------------------------- +# Solve Optimal transport barycenter problem with free support between several input distributions. +# ~~~~~~~~~~~~~~~~~~~~ +# +# The :func:`ot.solve_bary_sample` function can be used to solve the Optimal Transport barycenter problem +# between multiple sets of samples while optimizing the support of the barycenter and keeping its probability weights fixed. +# The function takes as its first argument the list of samples in each input distribution, +# and as second argument the number of samples to learn in the barycenter. By default, the probability weights in each distribution and the barycentric weights are uniform but they can be customized by the user. +# +# The function returns an :class:`ot.utils.BaryResult` object that contains, in particular, the barycenter samples and the OT plans between the barycenter and each input distribution. 
+# +# In the following, we illustrate the use of this function with the same 2D data as above considered as input distributions and compute their barycenter while using exact OT. +# Notice that most of the arguments of the :func:`ot.solve_bary_sample` function are similar to those of the :func:`ot.solve_sample` function and that the same regularization and unbalanced parameters can be used to solve regularized and unbalanced barycenter problems. + +# Solve the OT barycenter problem (exact OT without any regularization) +sol = ot.solve_bary_sample([x1, x2], n=35) + +# get the barycenter support +X = sol.X + +# get the OT plans between the barycenter and each input distribution +list_P = [sol.list_res[i].plan for i in range(2)] + +# get the barycenterOT loss +loss = sol.value + +print(f"Barycenter OT loss = {loss:1.3f}") + +# sphinx_gallery_start_ignore +pl.figure(1, (8, 8)) +plot2D_samples_mat(x1, X, list_P[0]) +plot2D_samples_mat(x2, X, list_P[1]) + +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style) +pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style) + +pl.title( + "Barycenter samples and OT plans \n total loss= %s = 0.5 * %s + 0.5 * %s" + % ( + np.round(loss, 3), + np.round(sol.list_res[0].value, 3), + np.round(sol.list_res[1].value, 3), + ) +) +pl.legend(loc="best") +pl.show() + # sphinx_gallery_end_ignore # %% # diff --git a/ot/__init__.py b/ot/__init__.py index 75f17fed6..944c60e8c 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -78,7 +78,7 @@ ) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample +from .solvers import solve, solve_gromov, solve_sample, solve_bary_sample from .lowrank import lowrank_sinkhorn from .batch import solve_batch, solve_sample_batch, solve_gromov_batch, dist_batch @@ -135,6 +135,7 @@ "solve", "solve_gromov", 
"solve_sample", + "solve_bary_sample", "smooth", "stochastic", "unbalanced", diff --git a/ot/solvers.py b/ot/solvers.py index 25f3bd32f..02d3d4beb 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -4,10 +4,11 @@ """ # Author: Remi Flamary +# Cédric Vincent-Cuaz # # License: MIT License -from .utils import OTResult, dist +from .utils import OTResult, BaryResult, dist from .lp import emd2, emd2_lazy, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced @@ -38,6 +39,7 @@ from .optim import cg import warnings +import numpy as np lst_method_lazy = [ @@ -1444,7 +1446,7 @@ def solve_sample( plan_init : array-like, shape (dim_a, dim_b), optional Initialization of the OT plan for iterative methods, by default None rank : int, optional - Rank of the OT matrix for lazy solers (method='factored') or (method='nystroem'), by default 100 + Rank of the OT matrix for lazy solvers (method='factored') or (method='nystroem'), by default 100 scaling : float, optional Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 potentials_init : (array-like(dim_a,),array-like(dim_b,)), optional @@ -2019,3 +2021,523 @@ def solve_sample( log=log, ) return res + + +def _bary_sample_bcd( + X_a_list, + X_b_init, + a_list, + b_init, + w, + metric, + inner_solver, + update_masses, + warmstart_plan, + warmstart_potentials, + stopping_criterion, + max_iter_bary, + tol_bary, + verbose, + log, + nx, +): + """Compute the barycenter using BCD. + + Parameters + ---------- + X_a_list : list of array-like, shape (n_samples_k, dim) + List of samples in each source distribution + X_b_init : array-like, shape (n_samples_b, dim), + Initialization of the barycenter samples. + a_list : list of array-like, shape (dim_k,) + List of samples weights in each source distribution + b_init : array-like, shape (n_samples_b,) + Initialization of the barycenter weights. 
+ w : list of array-like, shape (N,) + Samples barycentric weights + metric : str + Metric to use for the cost matrix, by default "sqeuclidean" + inner_solver : callable + Function to solve the inner OT problem + update_masses : bool + Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used. + warmstart_plan : bool + Use the previous plan as initialization for the inner solver. Set based on inner solver type in ot.solve_bary_sample + warmstart_potentials : bool + Use the previous potentials as initialization for the inner solver. Set based on inner solver type in ot.solve_bary_sample + stopping_criterion : str + Stopping criterion for the BCD algorithm. Can be "loss" or "bary". + max_iter_bary : int + Maximum number of iterations for the barycenter + tol_bary : float + Tolerance for the barycenter convergence + verbose : bool + Print information in the solver + log : bool + Log the stopping criterion during the iterations + nx: backend + Backend to use for the computation. Must match the backend of the input arrays. + Returns + ------- + + res : BaryResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.X : Barycenter samples + - res.b : Barycenter weights + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.list_res: List of OTResult for each inner OT problem (one per source distribution) + - res.log: log of the optimization process (if log=True) + + See :any:`BaryResult` for more information. + + """ + + X_b = X_b_init + b = b_init + inv_b = nx.nan_to_num(1.0 / b, nan=1.0, posinf=1.0, neginf=1.0) + + prev_criterion = np.inf + n_samples = len(X_a_list) + + log_ = None + if log: + log_ = {"stopping_criterion": []} + + # Compute the barycenter using BCD + for it in range(max_iter_bary): + # Solve the inner OT problem for each source distribution + if it == 0: # no pre-defined warmstart used at iteration 0. 
+ list_res = [ + inner_solver(X_a_list[k], X_b, a_list[k], b, None, None) + for k in range(n_samples) + ] + elif warmstart_plan: + list_res = [ + inner_solver(X_a_list[k], X_b, a_list[k], b, list_res[k].plan, None) + for k in range(n_samples) + ] + elif warmstart_potentials: + list_res = [ + inner_solver( + X_a_list[k], X_b, a_list[k], b, None, list_res[k].potentials + ) + for k in range(n_samples) + ] + else: + list_res = [ + inner_solver(X_a_list[k], X_b, a_list[k], b, None, None) + for k in range(n_samples) + ] + + # Update the estimated barycenter weights in unbalanced cases + if update_masses: + b = sum([w[k] * list_res[k].plan.sum(axis=0) for k in range(n_samples)]) + inv_b = nx.nan_to_num(1.0 / b, nan=1.0, posinf=1.0, neginf=1.0) + + # Update the barycenter samples + if metric in ["sqeuclidean", "euclidean"]: + X_b_new = ( + sum([w[k] * list_res[k].plan.T @ X_a_list[k] for k in range(n_samples)]) + * inv_b[:, None] + ) + else: + raise NotImplementedError('Not implemented metric="{}"'.format(metric)) + + # compute criterion + if stopping_criterion == "loss": + new_criterion = sum([w[k] * list_res[k].value for k in range(n_samples)]) + else: # stopping_criterion = "bary" + new_criterion = nx.sum((X_b_new - X_b) ** 2) + + if verbose: + if it % 1 == 0: + print( + f"BCD iteration {it}: criterion {stopping_criterion} = {new_criterion:.4f}" + ) + + if log: + log_["stopping_criterion"].append(new_criterion) + # Check convergence + if abs(new_criterion - prev_criterion) / abs(prev_criterion) < tol_bary: + print(f"BCD converged in {it} iterations") + break + + X_b = X_b_new + prev_criterion = new_criterion + + # compute loss values + + value_linear = sum([w[k] * list_res[k].value_linear for k in range(n_samples)]) + if stopping_criterion == "loss": + value = new_criterion + else: + value = sum([w[k] * list_res[k].value for k in range(n_samples)]) + # update BaryResult + bary_res = BaryResult( + X=X_b, + b=b, + value=value, + value_linear=value_linear, + log=log_, + 
list_res=list_res, + backend=nx, + ) + return bary_res + + +def solve_bary_sample( + X_a_list, + n, + a_list=None, + w=None, + X_b_init=None, + b_init=None, + metric="sqeuclidean", + reg=None, + c=None, + reg_type="KL", + unbalanced=None, + unbalanced_type="KL", + lazy=False, + method=None, + warmstart=False, + stopping_criterion="loss", + max_iter_bary=1000, + tol_bary=1e-5, + random_state=0, + verbose=False, + **kwargs, +): + r"""Solve the discrete OT barycenter problem over source distributions optimizing the barycenter support using Block-Coordinate Descent. + + The function solves the following general OT barycenter problem + + .. math:: + \min_{\mathbf{X} \in \mathbb{R}^{n \times d}} \min_{\{ \mathbf{T}^{(k)} \}_k \in \mathbb{R}_+^{n_i \times n}} \quad \sum_k w_k \{ \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_r R(\mathbf{T}^{(k)}) + + \lambda_u U(\mathbf{T^{(k)}}\mathbf{1},\mathbf{a}^{(k)}) + + \lambda_u U(\mathbf{T}^{(k)T}\mathbf{1},\mathbf{b}) \} + + where the cost matrices :math:`\mathbf{M}^{(k)}` from each input distribution :math:`(\mathbf{X}^{(k)}, \mathbf{a}^{(k)})` + to the barycenter domain are computed as :math:`M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)` where + :math:`d` is a metric (by default the squared Euclidean distance). The barycenter probability weights are fixed to :math:`\mathbf{b}`. + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). 
+ + Parameters + ---------- + X_a_list : list of array-like, shape (n_samples_k, dim) + List of N samples in each source distribution + n : int + number of samples in the barycenter domain + a_list : list of array-like, shape (n_samples_k,), optional + List of samples weights in each source distribution (default is uniform) + w : list of array-like, shape (N,), optional + Samples barycentric weights (default is uniform) + X_b_init : array-like, shape (n, dim), optional + Initialization of the barycenter samples (default is gaussian random sampling) + b_init : array-like, shape (n,), optional + Initialization of the barycenter weights (default is uniform) + metric : str, optional + Metric to use for the cost matrix, by default "sqeuclidean" + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + c : array-like, shape (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a}^{(k)} \mathbf{b}^T`. + If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{|a^{(k)}|} 1_{|b|}^T`. + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float or indexable object of length 1 or 2 + Marginal relaxation term. + If it is a scalar or an indexable object of length 1, + then the same relaxation is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`unbalanced=float("inf")`. + For semi-relaxed case, use either + :math:`unbalanced=(float("inf"), scalar)` or + :math:`unbalanced=(scalar, float("inf"))`. + If unbalanced is an array, + it must have the same backend as input arrays `(a, b, M)`. 
+ unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + lazy : bool, optional + Use lazy solvers to reduce memory cost when True (not implemented yet: raises NotImplementedError), by + default False + method : str, optional + Method for solving the problem, this can be used to select the solver + for unbalanced problems (see :any:`ot.solve`), or to select a specific + large scale solver. + warmstart : bool, optional + Use the previous OT plans or potentials as initialization for the next inner solver iteration, by default False. + stopping_criterion : str, optional + Stopping criterion for the outer loop of the BCD solver, by default 'loss'. + Either 'loss' to use the optimized objective or 'bary' for variations of the barycenter w.r.t the Frobenius norm. + max_iter_bary : int, optional + Maximum number of iterations for the outer loop of the BCD solver, by default 1000. + tol_bary : float, optional + Tolerance for solution precision of the barycenter problem, by default 1e-5. + random_state : int, optional + Random seed for the initialization of the barycenter samples, by default 0. + Only used if `X_b_init` is None. + verbose : bool, optional + Print information in the solver, by default False + kwargs : optional + Additional parameters for the inner solver (see :any:`ot.solve_sample`) + Returns + ------- + + res : BaryResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.X : Barycenter samples + - res.b : Barycenter weights + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.list_res: List of OTResult for each inner OT problem (one per source distribution) + - res.log: log of the optimization process (values of the stopping criterion at each iteration) + + See :any:`BaryResult` for more information. 
+ + Notes + ----- + + The following methods are available for solving barycenter problems with respect to these inner OT problems: + + - **Classical exact OT problem [1]** (default parameters) : + + .. math:: + \forall k, \quad \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + + s.t. \ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)} + + \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b} + + \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j) + + + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w) + + # for uniform sample weights and barycentric weights, + res = ot.solve_bary_sample([x1, x2], n) + + - **Entropic regularized OT [2]** (when ``reg!=None``): + + .. math:: + \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda R(\mathbf{T}^{(k)}) + + s.t. \ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)} + + \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b} + + \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j) + + + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0) + + # or for original Sinkhorn paper formulation [2] + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='entropy') + + + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): + + .. math:: + \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda R(\mathbf{T}^{(k))}) + + s.t. \ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)} + + \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b} + + \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j) + + can be solved with the following code: + + .. 
code-block:: python + + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='L2') + + - **Unbalanced OT [41]** (when ``unbalanced!=None``): + + .. math:: + \min_{\mathbf{T}^{(k)}\geq 0} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)^T}\mathbf{1},\mathbf{b}) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0) + + # quadratic unbalanced OT + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0, unbalanced_type='L2') + # TV = partial OT + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0, unbalanced_type='TV') + + + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): + + .. math:: + \min_{\mathbf{T}^{(k)} \geq 0} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_r R(\mathbf{T}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)^T}\mathbf{1},\mathbf{b}) + + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` for both + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, unbalanced=1.0) + # quadratic unbalanced OT with KL regularization + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, unbalanced=1.0, unbalanced_type='L2') + # both quadratic + res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2') + + .. _references-solve_bary_sample: + References + ---------- + + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." 
Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + + if method is not None and method.lower() in lst_method_lazy: + raise NotImplementedError( + f"method {method} operating on lazy tensors is not implemented yet" + ) + + if stopping_criterion not in ["loss", "bary"]: + raise ValueError( + "stopping_criterion must be either 'loss' or 'bary', got {}".format( + stopping_criterion + ) + ) + + n_samples = len(X_a_list) + + if ( + not lazy + ): # default non lazy solver calls ot.solve_sample within _bary_sample_bcd + # Detect backend + nx = get_backend(*X_a_list, X_b_init, b_init, w) + + # check sample weights + if a_list is None: + a_list = [ + nx.ones((X_a_list[k].shape[0],), type_as=X_a_list[k]) + / X_a_list[k].shape[0] + for k in range(n_samples) + ] + + # check samples barycentric weights + if w is None: + w = nx.ones(n_samples, type_as=X_a_list[0]) / n_samples + + # check X_b_init + if X_b_init is None: + rng = np.random.RandomState(random_state) + mean_ = nx.concatenate( + [nx.mean(X_a_list[k], axis=0) for k in range(n_samples)], + axis=0, + ) + mean_ = nx.mean(mean_, axis=0) + std_ = nx.concatenate( + [nx.std(X_a_list[k], axis=0) for k in range(n_samples)], + axis=0, + ) + std_ = nx.mean(std_, axis=0) + X_b_init = rng.normal( + loc=mean_, + scale=std_, + size=(n, X_a_list[0].shape[1]), + ) + X_b_init = nx.from_numpy(X_b_init, type_as=X_a_list[0]) + else: + if (X_b_init.shape[0] != n) or (X_b_init.shape[1] != X_a_list[0].shape[1]): + raise ValueError("X_b_init must have shape (n, dim)") + + # check b_init + if b_init is None: + b_init = nx.ones((n,), type_as=X_a_list[0]) / n + + if warmstart: + if reg is None: # exact OT + warmstart_plan = True + warmstart_potentials = False + else: # regularized OT + # unbalanced AND regularized OT + if ( + not isinstance(reg_type, tuple) + and reg_type.lower() in ["kl"] + and unbalanced_type.lower() == "kl" + ): + warmstart_plan = False + warmstart_potentials = True + + else: + warmstart_plan = 
True + warmstart_potentials = False + else: + warmstart_plan = False + warmstart_potentials = False + + def inner_solver(X_a, X_b, a, b, plan_init, potentials_init): + return solve_sample( + X_a=X_a, + X_b=X_b, + a=a, + b=b, + metric=metric, + reg=reg, + c=c, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + method=method, + plan_init=plan_init, + potentials_init=potentials_init, + verbose=False, + **kwargs, + ) + + # compute the barycenter using BCD + update_masses = unbalanced is not None + res = _bary_sample_bcd( + X_a_list, + X_b_init, + a_list, + b_init, + w, + metric, + inner_solver, + update_masses, + warmstart_plan, + warmstart_potentials, + stopping_criterion, + max_iter_bary, + tol_bary, + verbose, + True, # log set to True by default + nx, + ) + + return res + + else: + raise (NotImplementedError("Barycenter solver with lazy=True not implemented")) diff --git a/ot/utils.py b/ot/utils.py index 64bf1ace9..a62f0a347 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1328,6 +1328,183 @@ def citation(self): """ +class BaryResult: + """Base class for OT barycenter results. + + Parameters + ---------- + X : array-like, shape (`n`, `d`) + Barycenter features. + C: array-like, shape (`n`, `n`) + Barycenter structure for Gromov Wasserstein solutions. + b : array-like, shape (`n`,) + Barycenter weights. + value : float, array-like + Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions. + value_linear : float, array-like + The linear part of the transport cost, i.e. the product between the + transport plan and the cost. + value_quad : float, array-like + The quadratic part of the transport cost for Gromov-Wasserstein + solutions. + log : dict + Dictionary containing potential information about the solver. + list_res: list of OTResult + List of results for the individual OT matching. + status : int or str + Status of the solver. 
+ + Attributes + ---------- + + X : array-like, shape (`n`, `d`) + Barycenter features. + C: array-like, shape (`n`, `n`) + Barycenter structure for Gromov Wasserstein solutions. + b : array-like, shape (`n`,) + Barycenter weights. + value : float, array-like + Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions. + value_linear : float, array-like + The linear part of the transport cost, i.e. the product between the + transport plan and the cost. + value_quad : float, array-like + The quadratic part of the transport cost for Gromov-Wasserstein + solutions. + log : dict + Dictionary containing potential information about the solver. + list_res: list of OTResult + List of results for the individual OT matching. + status : int or str + Status of the solver. + backend : Backend + Backend used to compute the results. + """ + + def __init__( + self, + X=None, + C=None, + b=None, + value=None, + value_linear=None, + value_quad=None, + log=None, + list_res=None, + status=None, + backend=None, + ): + self._X = X + self._C = C + self._b = b + self._value = value + self._value_linear = value_linear + self._value_quad = value_quad + self._log = log + self._list_res = list_res + self._status = status + self._backend = backend if backend is not None else NumpyBackend() + + def __repr__(self): + s = "BaryResult(" + if self._value is not None: + s += "value={},".format(self._value) + if self._value_linear is not None: + s += "value_linear={},".format(self._value_linear) + if self._X is not None: + s += "X={}(shape={}),".format(self._X.__class__.__name__, self._X.shape) + if self._C is not None: + s += "C={}(shape={}),".format(self._C.__class__.__name__, self._C.shape) + if self._b is not None: + s += "b={}(shape={}),".format(self._b.__class__.__name__, self._b.shape) + if s[-1] != "(": + s = s[:-1] + ")" + else: + s = s + ")" + return s + + # Barycerters -------------------------------- + + @property + def X(self): + 
"""Barycenter features.""" + return self._X + + @property + def C(self): + """Barycenter structure for Gromov Wasserstein solutions.""" + return self._C + + @property + def b(self): + """Barycenter weights.""" + return self._b + + # Loss values -------------------------------- + + @property + def value(self): + """Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions.""" + return self._value + + @property + def value_linear(self): + """The "minimal" transport cost, i.e. the product between the transport plan and the cost.""" + return self._value_linear + + @property + def value_quad(self): + """The quadratic part of the transport cost for Gromov-Wasserstein solutions.""" + return self._value_quad + + # List of OTResult objects ------------------------- + + @property + def list_res(self): + """List of results for the individual OT matching.""" + return self._list_res + + @property + def status(self): + """Optimization status of the solver.""" + return self._status + + @property + def log(self): + """Dictionary containing potential information about the solver.""" + return self._log + + # Miscellaneous -------------------------------- + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. + return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. 
Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ + + class LazyTensor(object): """A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. diff --git a/test/test_solvers.py b/test/test_solvers.py index 802aca631..020d062a2 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -1,6 +1,7 @@ """Tests for ot solvers""" # Author: Remi Flamary +# Cédric Vincent-Cuaz # # License: MIT License @@ -12,6 +13,7 @@ import ot from ot.bregman import geomloss from ot.backend import torch +from ot.solvers import lst_method_lazy lst_reg = [None, 1] @@ -61,6 +63,12 @@ }, # fail lazy for unbalanced and regularized ] +lst_parameters_solve_bary_sample_NotImplemented = [ + {"method": method} for method in lst_method_lazy +] + [ + {"lazy": True}, # fail lazy +] + # set readable ids for each param lst_method_params_solve_sample = [ pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample @@ -69,6 +77,10 @@ pytest.param(param, id=str(param)) for param in lst_parameters_solve_sample_NotImplemented ] +lst_parameters_solve_bary_sample_NotImplemented = [ + pytest.param(param, id=str(param)) + for param in lst_parameters_solve_bary_sample_NotImplemented +] def 
def assert_allclose_bary_sol(sol1, sol2):
    """Check that two barycenter solutions agree on their main fields.

    The attributes ``X``, ``b``, ``value``, ``value_linear`` and ``log`` are
    read from both solutions and compared one by one:

    * an attribute that is ``None`` in both solutions is considered equal;
    * an attribute that is ``None`` in only one solution is a mismatch;
    * a ``dict`` attribute is compared key by key (using the keys of
      ``sol1``), each value being converted with ``np.array``;
    * any other attribute is converted with the ``to_numpy`` method of the
      backend of its solution and compared with ``np.allclose``
      (``equal_nan=True``).

    A comparison that raises ``NotImplementedError`` is skipped.

    Parameters
    ----------
    sol1, sol2 : objects exposing ``_backend`` and the attributes listed above
        Solutions to compare. When ``_backend`` is ``None`` a
        ``ot.backend.NumpyBackend`` instance is used for the conversion.

    Returns
    -------
    bool
        ``True`` when every compared attribute matches, ``False`` at the
        first mismatch. The helper does not raise on mismatch, so existing
        call sites that ignore the result keep working; wrap the call in
        ``assert`` to enforce the comparison.
    """
    lst_attr = ["X", "b", "value", "value_linear", "log"]

    nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
    nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()

    for attr in lst_attr:
        var1 = getattr(sol1, attr)
        var2 = getattr(sol2, attr)

        if var1 is None and var2 is None:
            # Nothing to compare for this field; keep checking the others
            # instead of declaring the whole solution equal.
            continue
        if var1 is None or var2 is None:
            return False

        try:
            if isinstance(var1, dict):  # only contains lists
                is_close = all(
                    np.allclose(
                        np.array(var1[key]),
                        np.array(var2[key]),
                        equal_nan=True,
                    )
                    for key in var1
                )
            else:
                is_close = np.allclose(
                    nx1.to_numpy(var1),
                    nx2.to_numpy(var2),
                    equal_nan=True,
                )
        except NotImplementedError:
            # Conversion not available for this attribute: skip it.
            continue

        if not is_close:
            return False

    return True
@pytest.mark.parametrize(
    "method_params", lst_parameters_solve_bary_sample_NotImplemented
)
def test_solve_bary_sample_NotImplemented(nx, method_params):
    """Check that unsupported options make ``ot.solve_bary_sample`` raise.

    Each entry of ``lst_parameters_solve_bary_sample_NotImplemented`` is a
    dict of keyword arguments (a lazy ``method`` or ``lazy=True``) that is
    forwarded to the solver and is expected to trigger
    ``NotImplementedError``.
    """
    # Fixed seed so that a failure can be reproduced with the same data.
    rng = np.random.RandomState(0)

    K = 2  # number of distributions
    ns = rng.randint(10, 20, K)  # number of samples within each distribution
    n = 5  # number of samples in the barycenter

    X_list = [rng.randn(ns_i, 2) for ns_i in ns]

    a_list = [ot.utils.unif(X.shape[0]) for X in X_list]
    b = ot.utils.unif(n)

    w = ot.utils.unif(K)

    # convert the inputs to the backend under test
    X_listb = nx.from_numpy(*X_list)
    a_listb = nx.from_numpy(*a_list)
    wb, bb = nx.from_numpy(w, b)

    with pytest.raises(NotImplementedError):
        ot.solve_bary_sample(
            X_listb, n, a_list=a_listb, b_init=bb, w=wb, **method_params
        )