Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
44d4614
merge
cedricvincentcuaz Sep 10, 2024
63477c2
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Sep 10, 2024
a94c6ac
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Nov 6, 2024
27944a5
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Nov 6, 2024
0392961
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Nov 16, 2024
60d1295
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Apr 2, 2025
a93c60c
first commit
cedricvincentcuaz Apr 21, 2025
9e25e80
handle masses in unbalanced cases
cedricvincentcuaz Apr 22, 2025
46c4638
update free support
cedricvincentcuaz Apr 24, 2025
671788d
trying to fix tests
cedricvincentcuaz Apr 24, 2025
a5b0f70
Merge branch 'master' into solvers
rflamary May 23, 2025
9cf60fd
Merge branch 'master' into solvers
rflamary Jun 3, 2025
48a63ea
merge
cedricvincentcuaz Oct 11, 2025
df1da8d
small updates
cedricvincentcuaz Oct 12, 2025
75d6c11
fix fun name
cedricvincentcuaz Oct 12, 2025
c71e544
update tests
cedricvincentcuaz Oct 12, 2025
ed9c992
update tests
cedricvincentcuaz Oct 12, 2025
7132f62
fix tests
cedricvincentcuaz Oct 15, 2025
8e55777
Merge branch 'master' into solvers
rflamary Oct 21, 2025
4a0b5ec
fix tests
cedricvincentcuaz Mar 6, 2026
b8219ac
Merge branch 'solvers' of https://github.com/cedricvincentcuaz/POT in…
cedricvincentcuaz Mar 6, 2026
6f900c6
Merge branch 'master' into solvers
cedricvincentcuaz Mar 6, 2026
3d9f987
update docstring for solve_bary_sample
cedricvincentcuaz Mar 9, 2026
7b7cfc0
update plot quickstart guide
cedricvincentcuaz Mar 9, 2026
406d026
add ex
cedricvincentcuaz Mar 9, 2026
0172a89
fix ex
cedricvincentcuaz Mar 10, 2026
d493194
fix docs
cedricvincentcuaz Mar 11, 2026
95e5a1c
fix docs
cedricvincentcuaz Mar 11, 2026
05b47fc
fix sphinx
cedricvincentcuaz Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
- Wrapper for barycenter solvers with free support `ot.solvers.bary_free_support` (PR #730)


#### Closed issues

Expand Down
128 changes: 128 additions & 0 deletions examples/barycenters/plot_solve_barycenter_variants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
"""
======================================
Optimal Transport Barycenter solvers comparison
======================================

This example illustrates solutions returned for different variants of exact,
regularized and unbalanced OT barycenter problems with free support using our wrapper `ot.solve_bary_sample`.
"""

# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2

# %%

import numpy as np
import matplotlib.pylab as pl
import ot
from ot.plot import plot2D_samples_mat

# %%
# 2D data example
# ---------------
#
# We first generate two sets of samples in 2D that 25 and 50
# samples respectively located on circles. The weights of the samples are
# uniform.

# Problem size
n1 = 25
n2 = 50

# Generate random data
np.random.seed(0)

x1 = np.random.randn(n1, 2)
x1 /= np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2

x2 = np.random.randn(n2, 2)
x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4

style = {"markeredgecolor": "k"}

pl.figure(1, (4, 4))
pl.plot(x1[:, 0], x1[:, 1], "ob", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", **style)
pl.title("Source distributions")
pl.show()


# %%
# Set up parameters for barycenter solvers and solve
# ---------------------------------------

lst_regs = [
"No Reg.",
"Entropic",
] # support e.g ["No Reg.", "Entropic", "L2", "Group Lasso + L2"]
lst_unbalanced = [
"Balanced",
"Unbalanced KL",
] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"]

lst_solvers = [ # name, param for ot.solve function
# balanced OT
("Exact OT", dict()),
("Entropic Reg. OT", dict(reg=0.1)),
# unbalanced OT KL
("Unbalanced KL No Reg.", dict(unbalanced=0.05)),
(
"Unbalanced KL with KL Reg.",
dict(reg=0.1, unbalanced=0.05, unbalanced_type="kl", reg_type="kl"),
),
]

lst_res = []
for name, param in lst_solvers:
print(f"-- name = {name} / param = {param}")
res = ot.solve_bary_sample(X_a_list=[x1, x2], n=35, **param)
lst_res.append(res)
list_P = [res.list_res[k].plan for k in range(2)]
print("X:", res.X)
print("loss:", res.value)
print("loss:", res.log)
print(
"marginals OT 1:",
res.list_res[0].plan.sum(axis=1),
res.list_res[0].plan.sum(axis=0),
)
print(
"marginals OT 2:",
res.list_res[1].plan.sum(axis=1),
res.list_res[1].plan.sum(axis=0),
)

##############################################################################
# Plot distributions and plans
# ----------

pl.figure(2, figsize=(16, 16))

for i, bname in enumerate(lst_unbalanced):
for j, rname in enumerate(lst_regs):
pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1)

X = lst_res[i * len(lst_regs) + j].X
list_P = [lst_res[i * len(lst_regs) + j].list_res[k].plan for k in range(2)]
loss = lst_res[i * len(lst_regs) + j].value

plot2D_samples_mat(x1, X, list_P[0])
plot2D_samples_mat(x2, X, list_P[1])

if i == 0 and j == 0: # add labels
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style)
pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style)
pl.legend(loc="best")
else:
pl.plot(x1[:, 0], x1[:, 1], "ob", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", **style)
pl.plot(X[:, 0], X[:, 1], "og", **style)

if i == 0:
pl.title(rname)
if j == 0:
pl.ylabel(bname, fontsize=14)
53 changes: 53 additions & 0 deletions examples/plot_quickstart_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,59 @@ def df(G):
plot_plan(P_fgw, "Fused GW plan", axis=False)
pl.show()

# sphinx_gallery_end_ignore

# %%
#
# Solving barycenter problems
# -------------------------------
# Solve Optimal transport barycenter problem with free support between several input distributions.
# ~~~~~~~~~~~~~~~~~~~~
#
# The :func:`ot.solve_bary_sample` function can be used to solve the Optimal Transport barycenter problem
# between multiple sets of samples while optimizing the support of the barycenter and letting fixed their probability weights.
# The function takes as its first argument the list of samples in each input distribution,
# and as second argument the number of samples to learn in the barycenter. By default, the probability weights in each distribution and the barycentric weights are uniform but they can be customized by the user.
#
# The function returns an :class:`ot.utils.OTBaryResult` object that contains in part the barycenter samples and the OT plans between the barycenter and each input distribution.
#
# In the following, we illustrate the use of this function with the same 2D data as above considered as input distributions and compute their barycenter while using exact OT.
# Notice that most of the arguments of the :func:`ot.solve_bary_sample` function are similar to those of the :func:`ot.solve_sample` function and that the same regularization and unbalanced parameters can be used to solve regularized and unbalanced barycenter problems.

# Solve the OT barycenter problem (exact OT without any regularization)
sol = ot.solve_bary_sample([x1, x2], n=35)

# get the barycenter support
X = sol.X

# get the OT plans between the barycenter and each input distribution
list_P = [sol.list_res[i].plan for i in range(2)]

# get the barycenterOT loss
loss = sol.value

print(f"Barycenter OT loss = {loss:1.3f}")

# sphinx_gallery_start_ignore
pl.figure(1, (8, 8))
plot2D_samples_mat(x1, X, list_P[0])
plot2D_samples_mat(x2, X, list_P[1])

pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style)
pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style)

pl.title(
"Barycenter samples and OT plans \n total loss= %s = 0.5 * %s + 0.5 * %s"
% (
np.round(loss, 3),
np.round(sol.list_res[0].value, 3),
np.round(sol.list_res[1].value, 3),
)
)
pl.legend(loc="best")
pl.show()

# sphinx_gallery_end_ignore
# %%
#
3 changes: 2 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve, solve_gromov, solve_sample
from .solvers import solve, solve_gromov, solve_sample, solve_bary_sample
from .lowrank import lowrank_sinkhorn

from .batch import solve_batch, solve_sample_batch, solve_gromov_batch, dist_batch
Expand Down Expand Up @@ -135,6 +135,7 @@
"solve",
"solve_gromov",
"solve_sample",
"solve_bary_sample",
"smooth",
"stochastic",
"unbalanced",
Expand Down
Loading
Loading