Skip to content

Commit

Permalink
Merge pull request #269 from IFCA/feature-precompute-mmd-ref
Browse files Browse the repository at this point in the history
Add precompute kernel ref matrix values for MMD
  • Loading branch information
jaime-cespedes-sisniega authored Aug 14, 2023
2 parents 5f9c5ac + 3598aa2 commit 4f71e12
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 45 deletions.
116 changes: 71 additions & 45 deletions frouros/detectors/data_drift/batch/distance_based/mmd.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""MMD (Maximum Mean Discrepancy) module."""

import itertools
import math
from typing import Callable, Generator, Optional, List, Union

import numpy as np # type: ignore
import tqdm # type: ignore

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.detectors.data_drift.base import MultivariateData
Expand Down Expand Up @@ -64,6 +62,7 @@ def __init__( # noqa: D107
)
self.kernel = kernel
self.chunk_size = chunk_size
self._expected_k_xx = None

@property
def chunk_size(self) -> Optional[int]:
Expand Down Expand Up @@ -122,11 +121,47 @@ def _distance_measure(
Y=X,
kernel=self.kernel,
chunk_size=self.chunk_size,
expected_k_xx=self._expected_k_xx,
**kwargs,
)
distance_test = DistanceResult(distance=mmd)
return distance_test

def _fit(
self,
X: np.ndarray, # noqa: N803
) -> None:
super()._fit(X=X)
# Add dimension only for the kernel calculation (if dim == 1)
if X.ndim == 1:
X = np.expand_dims(X, axis=1) # noqa: N806
x_num_samples = len(self.X_ref) # type: ignore

chunk_size_x = (
x_num_samples
if self.chunk_size is None
else self.chunk_size # type: ignore
)

x_chunks = self._get_chunks( # noqa: N806
data=X,
chunk_size=chunk_size_x,
)
x_chunks_combinations = itertools.product(x_chunks, repeat=2) # noqa: N806

k_xx_sum = (
self._compute_kernel(
chunk_combinations=x_chunks_combinations, # type: ignore
kernel=self.kernel,
)
# Remove diagonal (j!=i case)
- x_num_samples
)

self._expected_k_xx = k_xx_sum / ( # type: ignore
x_num_samples * (x_num_samples - 1)
)

@staticmethod
def _compute_kernel(chunk_combinations: Generator, kernel: Callable) -> float:
k_sum = np.array([kernel(*chunk).sum() for chunk in chunk_combinations]).sum()
Expand Down Expand Up @@ -159,13 +194,37 @@ def _mmd( # pylint: disable=too-many-locals
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
else x_num_samples
)
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
MMD._get_chunks(

# If expected_k_xx is provided, we don't need to compute it again
if "expected_k_xx" in kwargs:
x_chunks_copy = MMD._get_chunks( # noqa: N806
data=X,
chunk_size=chunk_size_x, # type: ignore
),
2,
)
chunk_size=chunk_size_x,
)
expected_k_xx = kwargs["expected_k_xx"]
else:
# Compute expected_k_xx
x_chunks, x_chunks_copy = itertools.tee( # type: ignore
MMD._get_chunks(
data=X,
chunk_size=chunk_size_x,
),
2,
)
x_chunks_combinations = itertools.product( # type: ignore
x_chunks,
repeat=2,
)
k_xx_sum = (
MMD._compute_kernel(
chunk_combinations=x_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- x_num_samples
)
expected_k_xx = k_xx_sum / (x_num_samples * (x_num_samples - 1))

y_num_samples = len(Y) # noqa: N806
chunk_size_y = (
kwargs["chunk_size"]
Expand All @@ -175,14 +234,10 @@ def _mmd( # pylint: disable=too-many-locals
y_chunks, y_chunks_copy = itertools.tee( # noqa: N806
MMD._get_chunks(
data=Y,
chunk_size=chunk_size_y, # type: ignore
chunk_size=chunk_size_y,
),
2,
)
x_chunks_combinations = itertools.product( # noqa: N806
x_chunks,
repeat=2,
)
y_chunks_combinations = itertools.product( # noqa: N806
y_chunks,
repeat=2,
Expand All @@ -192,50 +247,21 @@ def _mmd( # pylint: disable=too-many-locals
y_chunks_copy,
)

if kwargs.get("verbose", False):
num_chunks_x = math.ceil(x_num_samples / chunk_size_x) # type: ignore
num_chunks_y = math.ceil(y_num_samples / chunk_size_y) # type: ignore
num_chunks_x_combinations = num_chunks_x**2
num_chunks_y_combinations = num_chunks_y**2
num_chunks_xy = (
math.ceil(len(X) / chunk_size_x) * num_chunks_y # type: ignore
)
x_chunks_combinations = tqdm.tqdm(
x_chunks_combinations,
total=num_chunks_x_combinations,
)
y_chunks_combinations = tqdm.tqdm(
y_chunks_combinations,
total=num_chunks_y_combinations,
)
xy_chunks_combinations = tqdm.tqdm(
xy_chunks_combinations,
total=num_chunks_xy,
)

k_xx_sum = (
MMD._compute_kernel(
chunk_combinations=x_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- x_num_samples # type: ignore
)
k_yy_sum = (
MMD._compute_kernel(
chunk_combinations=y_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- y_num_samples # type: ignore
- y_num_samples
)
k_xy_sum = MMD._compute_kernel(
chunk_combinations=xy_chunks_combinations, # type: ignore
kernel=kernel,
)
mmd = (
+k_xx_sum / (x_num_samples * (x_num_samples - 1))
+expected_k_xx
+ k_yy_sum / (y_num_samples * (y_num_samples - 1))
- 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore
- 2 * k_xy_sum / (x_num_samples * y_num_samples)
)
return mmd
1 change: 1 addition & 0 deletions frouros/tests/unit/detectors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Detectors test init."""
1 change: 1 addition & 0 deletions frouros/tests/unit/detectors/data_drift/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Data drift detectors test init."""
1 change: 1 addition & 0 deletions frouros/tests/unit/detectors/data_drift/batch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Batch data drift detectors test init."""
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Distance based batch data drift detectors test init."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Test MMD."""

from functools import partial
from typing import Optional, Tuple

import numpy as np # type: ignore
import pytest # type: ignore

from frouros.detectors.data_drift import MMD
from frouros.utils.kernels import rbf_kernel


@pytest.mark.parametrize(
"distribution_p, distribution_q, expected_distance",
[
((0, 1, 100), (0, 1, 100), 0.00052755), # (mean, std, size)
((0, 1, 100), (0, 1, 10), -0.03200193),
((0, 1, 10), (0, 1, 100), 0.07154671),
((2, 1, 100), (0, 1, 100), 0.43377622),
((2, 1, 100), (0, 1, 10), 0.23051378),
((2, 1, 10), (0, 1, 100), 0.62530767),
],
)
def test_mmd_batch_univariate(
distribution_p: Tuple[float, float, int],
distribution_q: Tuple[float, float, int],
expected_distance: float,
) -> None:
"""Test MMD batch with univariate data.
:param distribution_p: mean, std and size of samples from distribution p
:type distribution_p: Tuple[float, float, int]
:param distribution_q: mean, std and size of samples from distribution q
:type distribution_q: Tuple[float, float, int]
:param expected_distance: expected distance value
:type expected_distance: float
"""
np.random.seed(seed=31)
X_ref = np.random.normal(*distribution_p) # noqa: N806
X_test = np.random.normal(*distribution_q) # noqa: N806

detector = MMD(
kernel=partial(rbf_kernel, sigma=0.5),
)
_ = detector.fit(X=X_ref)

result = detector.compare(X=X_test)[0]

assert np.isclose(result.distance, expected_distance)


@pytest.mark.parametrize(
"distribution_p, distribution_q, chunk_size",
[
((0, 1, 100), (0, 1, 100), None), # (mean, std, size)
((0, 1, 100), (0, 1, 100), 2),
((0, 1, 100), (0, 1, 100), 10),
((0, 1, 100), (0, 1, 10), None),
((0, 1, 100), (0, 1, 10), 2),
((0, 1, 100), (0, 1, 10), 10),
((0, 1, 10), (0, 1, 100), None),
((0, 1, 10), (0, 1, 100), 2),
((0, 1, 10), (0, 1, 100), 10),
],
)
def test_mmd_batch_precomputed_expected_k_xx(
distribution_p: Tuple[float, float, int],
distribution_q: Tuple[float, float, int],
chunk_size: Optional[int],
) -> None:
"""Test MMD batch with precomputed expected k_xx.
:param distribution_p: mean, std and size of samples from distribution p
:type distribution_p: Tuple[float, float, int]
:param distribution_q: mean, std and size of samples from distribution q
:type distribution_q: Tuple[float, float, int]
:param chunk_size: chunk size
:type chunk_size: Optional[int]
"""
np.random.seed(seed=31)
X_ref = np.random.normal(*distribution_p) # noqa: N806
X_test = np.random.normal(*distribution_q) # noqa: N806

kernel = partial(rbf_kernel, sigma=0.5)

detector = MMD(
kernel=kernel,
chunk_size=chunk_size,
)
_ = detector.fit(X=X_ref)

# Computes mmd using precomputed expected k_xx
precomputed_distance = detector.compare(X=X_test)[0].distance

# Computes mmd from scratch
scratch_distance = MMD._mmd( # pylint: disable=protected-access
X=X_ref,
Y=X_test,
kernel=kernel,
chunk_size=chunk_size,
)

assert np.isclose(precomputed_distance, scratch_distance)

0 comments on commit 4f71e12

Please sign in to comment.