DROT-PyBind11

A PyBind11 binding for Douglas-Rachford Splitting for Optimal Transport (DROT).

Installation

# create conda environment
$ conda create -n fastot python=3.9
$ conda activate fastot

# install conda packages
$ conda install -c conda-forge pot matplotlib
$ conda install scipy numpy

# configure with cmake from the repository root; replace the CUDA location with yours
$ cmake -DCMAKE_CUDA_COMPILER=/usr/local/cuda-11.7/bin/nvcc .
$ make

Examples

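# imports (fast_ot is assumed to be the extension module produced by the build above)
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import multivariate_normal

import fast_ot

# generate data: two 2-D Gaussian point clouds with means 0 and 1
dim = 2
n_sample = 1000

samples0 = multivariate_normal.rvs(np.zeros((dim,)), np.eye(dim), size=n_sample).reshape(-1, dim)
samples1 = multivariate_normal.rvs(np.ones((dim,)), np.eye(dim), size=n_sample).reshape(-1, dim)

# cost matrix: pairwise Euclidean distances
C = cdist(samples0, samples1)

# run drot with uniform marginals
p = np.ones((C.shape[0],)) / C.shape[0]
q = np.ones((C.shape[1],)) / C.shape[1]
stepsize = 2. / sum(C.shape)
maxiters = 100000
eps = 1e-2

result = fast_ot.drot(C, p, q, C.shape[0], C.shape[1], stepsize, maxiters, eps, False, True)
print(result.fval)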

See example.ipynb for the complete example.
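To sanity-check `result.fval` without POT, one option is an exact reference value. The sketch below relies on the setup used in the example: uniform marginals `p` and `q` of equal length n. In that special case an optimal plan exists at a permutation matrix scaled by 1/n, so the optimal transport cost equals the minimum assignment cost divided by n. `reference_ot_cost` is a hypothetical helper name, not part of fast_ot, and it uses `scipy.optimize.linear_sum_assignment` rather than DROT.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def reference_ot_cost(C):
    """Exact OT cost for uniform marginals on a square cost matrix C.

    Hypothetical helper (not part of fast_ot). Valid only when both
    marginals are uniform and C is n x n.
    """
    C = np.asarray(C, dtype=float)
    n, m = C.shape
    if n != m:
        raise ValueError("reference_ot_cost needs a square cost matrix")
    # Minimum-cost perfect matching between rows and columns.
    rows, cols = linear_sum_assignment(C)
    # Each matched pair carries mass 1/n.
    return C[rows, cols].sum() / n


# Same kind of data as the example, with a smaller sample and a fixed seed.
rng = np.random.default_rng(0)
dim, n_sample = 2, 200
samples0 = rng.normal(0.0, 1.0, size=(n_sample, dim))
samples1 = rng.normal(1.0, 1.0, size=(n_sample, dim))
C = cdist(samples0, samples1)

ref = reference_ot_cost(C)
print(ref)
```

The DROT value can then be compared against it, for example with `abs(result.fval - ref)`. For marginals that are not uniform or not of equal size, this shortcut does not apply.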

Comments

This repo was motivated by one question: is this new method a good replacement for ot.emd2? For this version, at least, it is not. Even for small samples, ot.emd2 is faster than the current code, and it is also more accurate.

With eps=1e-5 and n=10000, the current code takes 100 seconds to converge, while ot.emd2 takes less than 10 seconds.
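A comparison like this can be reproduced with a small timing harness. The sketch below is only a stand-in: `time_call` is a hypothetical helper, and the solver being timed is `scipy.optimize.linear_sum_assignment` on a small random problem, chosen so the snippet runs without fast_ot or POT installed.

```python
import time

import numpy as np
from scipy.optimize import linear_sum_assignment


def time_call(fn, *args, repeats=3):
    """Return (best wall-clock seconds, last return value) over `repeats` runs."""
    best = float("inf")
    value = None
    for _ in range(repeats):
        start = time.perf_counter()
        value = fn(*args)
        best = min(best, time.perf_counter() - start)
    return best, value


# Stand-in problem: a random 300 x 300 cost matrix.
rng = np.random.default_rng(0)
C = rng.random((300, 300))

seconds, (rows, cols) = time_call(linear_sum_assignment, C)
print(f"stand-in solver: {seconds:.4f} s")
```

With both packages available, the same helper could wrap the real calls, e.g. `time_call(ot.emd2, p, q, C)` and `time_call(fast_ot.drot, C, p, q, C.shape[0], C.shape[1], stepsize, maxiters, eps, False, True)`, using the variables from the example above.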

Acknowledgement