@JPZ4-5 JPZ4-5 commented Oct 7, 2025

Here, `gaussian_kernel` computes $K(x, y) = \exp(-\gamma \|x - y\|^2)$. Replacing `my_cdist` with `torch.cdist` significantly improves performance without changing the original logic.
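For illustration, a minimal sketch of the proposed replacement (the name `gaussian_kernel_cdist` is mine, chosen to avoid clashing with the existing function; the actual diff just swaps the distance computation inside `gaussian_kernel`):

```python
import torch

def gaussian_kernel_cdist(x, y, gamma):
    # Squared pairwise Euclidean distances from the fused torch.cdist kernel,
    # clamped away from zero as in the original my_cdist, then the RBF kernel.
    D2 = torch.cdist(x, y, p=2.0).pow(2).clamp_min_(1e-30)
    return torch.exp(D2.mul(-gamma))
```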

In my tests (RTX 2080 Super, profiled with `AdvancedProfiler` in `pl.Trainer`), the cumulative time spent inside `gaussian_kernel` was roughly halved (25.3 s vs 57.7 s in the logs below).

Here is a partial comparison from perf.log:

new:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      150    0.467    0.003   58.552    0.390 CACM.py:105(training_step)
      150   21.678    0.145   55.006    0.367 regularization.py:99(conditional_reg)
    87897    0.257    0.000   33.031    0.000 regularization.py:33(mmd)
    87897    3.731    0.000   32.772    0.000 utils.py:27(mmd_compute)
   263691    1.720    0.000   25.306    0.000 utils.py:23(gaussian_kernel)
   154/24    0.005    0.000    9.601    0.400 strategy.py:380(training_step)
   263691    6.774    0.000    6.774    0.000 {method 'clamp_min_' of 'torch._C.TensorBase' objects}
   263691    1.006    0.000    6.295    0.000 functional.py:1431(cdist)
   263691    5.010    0.000    5.010    0.000 {built-in method torch.cdist}

old:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      150    0.855    0.006   95.147    0.634 CACM.py:105(training_step)
      150   24.117    0.161   90.420    0.603 regularization.py:99(conditional_reg)
    88725    0.276    0.000   66.707    0.001 regularization.py:33(mmd)
    88725    4.468    0.000   66.429    0.001 utils.py:32(mmd_compute)
   266175    1.639    0.000   57.745    0.000 utils.py:23(gaussian_kernel)
   266175    2.162    0.000   41.267    0.000 utils.py:16(my_cdist)
   154/51    0.005    0.000   27.495    0.539 strategy.py:380(training_step)
   266175   10.389    0.000   10.389    0.000 {built-in method torch.addmm}

To check the numerical error introduced by the change, I tested with the following code:

import torch

def my_cdist(x1, x2):
    # Original implementation: ||x1||^2 - 2*x1@x2^T + ||x2||^2 via a single addmm
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
    return res.clamp_min_(1e-30)


def gaussian_kernel(x, y, gamma):
    # Original kernel computation, unchanged
    D = my_cdist(x, y)
    K = torch.zeros_like(D)

    K.add_(torch.exp(D.mul(-gamma)))

    return K


def my(x, y, gamma):
    # Proposed torch.cdist-based version
    return torch.exp(torch.cdist(x, y, p=2.0).pow(2).clamp_min_(1e-30).mul(-gamma))


def mmd_compute(x, y, kernel_type, gamma):
    if kernel_type == "gaussian":
        Kxx = gaussian_kernel(x, x, gamma).mean()
        Kyy = gaussian_kernel(y, y, gamma).mean()
        Kxy = gaussian_kernel(x, y, gamma).mean()
        try:
            assert torch.allclose(gaussian_kernel(x, y, gamma), my(x, y, gamma), atol=1e-7, rtol=0.0), "cdist-based Kxy differs from the original"
        except AssertionError as e:
            print(e)
            torch.set_printoptions(precision=10)
            print(gaussian_kernel(x, y, gamma))
            print(my(x, y, gamma))

        return Kxx + Kyy - 2 * Kxy
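As a self-contained sketch of how the equivalence check can be exercised outside the training loop (the shapes, seed, and gamma below are arbitrary, not the ones from the actual run):

```python
import torch

def my_cdist(x1, x2):
    # Original implementation: ||x1||^2 - 2*x1@x2^T + ||x2||^2 via addmm
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
    return res.clamp_min_(1e-30)

torch.manual_seed(0)
x = torch.randn(64, 16)
y = torch.randn(64, 16)
gamma = 0.1

old = torch.exp(my_cdist(x, y).mul(-gamma))
new = torch.exp(torch.cdist(x, y, p=2.0).pow(2).clamp_min_(1e-30).mul(-gamma))
max_err = (old - new).abs().max()
```

With FP32 inputs of this size the maximum elementwise difference stays well below 1e-5, consistent with the tolerance discussed below.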

In terms of numerical stability, there were only 2 assertion failures across over 87k calls, and the maximum elementwise error stayed below 2e-7, which should be acceptable for FP32 computation.

Looking forward to any suggestions for further performance improvements!
