diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index cfe055d..fec7db3 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -69,7 +69,7 @@ def get_map(self, x0, x1): x0 : Tensor, shape (bs, *dim) represents the source minibatch x1 : Tensor, shape (bs, *dim) - represents the source minibatch + represents the target minibatch Returns -------