Skip to content

Commit

Permalink
Fix PEP8
Browse files Browse the repository at this point in the history
  • Loading branch information
jaime-cespedes-sisniega committed Aug 14, 2023
1 parent f137843 commit 3598aa2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 25 deletions.
42 changes: 19 additions & 23 deletions frouros/detectors/data_drift/batch/distance_based/mmd.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""MMD (Maximum Mean Discrepancy) module."""

import itertools
import math
from typing import Callable, Generator, Optional, List, Union

import numpy as np # type: ignore
import tqdm # type: ignore

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.detectors.data_drift.base import MultivariateData
Expand Down Expand Up @@ -137,7 +135,7 @@ def _fit(
# Add dimension only for the kernel calculation (if dim == 1)
if X.ndim == 1:
X = np.expand_dims(X, axis=1) # noqa: N806
x_num_samples = len(self.X_ref) # type: ignore # noqa: N806
x_num_samples = len(self.X_ref) # type: ignore

chunk_size_x = (
x_num_samples
Expand All @@ -147,7 +145,7 @@ def _fit(

x_chunks = self._get_chunks( # noqa: N806
data=X,
chunk_size=chunk_size_x, # type: ignore
chunk_size=chunk_size_x,
)
x_chunks_combinations = itertools.product(x_chunks, repeat=2) # noqa: N806

Expand All @@ -157,11 +155,11 @@ def _fit(
kernel=self.kernel,
)
# Remove diagonal (j!=i case)
- x_num_samples # type: ignore
- x_num_samples
)

self._expected_k_xx = k_xx_sum / ( # type: ignore
x_num_samples * (x_num_samples - 1) # type: ignore
x_num_samples * (x_num_samples - 1)
)

@staticmethod
Expand Down Expand Up @@ -201,33 +199,31 @@ def _mmd( # pylint: disable=too-many-locals
if "expected_k_xx" in kwargs:
x_chunks_copy = MMD._get_chunks( # noqa: N806
data=X,
chunk_size=chunk_size_x, # type: ignore
chunk_size=chunk_size_x,
)
expected_k_xx = kwargs["expected_k_xx"]
else:
# Compute expected_k_xx
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
x_chunks, x_chunks_copy = itertools.tee( # type: ignore
MMD._get_chunks(
data=X,
chunk_size=chunk_size_x, # type: ignore
chunk_size=chunk_size_x,
),
2,
)
x_chunks_combinations = itertools.product( # noqa: N806
x_chunks_combinations = itertools.product( # type: ignore
x_chunks,
repeat=2,
)
k_xx_sum = (
MMD._compute_kernel(
chunk_combinations=x_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- x_num_samples # type: ignore
)
expected_k_xx = k_xx_sum / ( # type: ignore
x_num_samples * (x_num_samples - 1) # type: ignore
MMD._compute_kernel(
chunk_combinations=x_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- x_num_samples
)
expected_k_xx = k_xx_sum / (x_num_samples * (x_num_samples - 1))

y_num_samples = len(Y) # noqa: N806
chunk_size_y = (
Expand All @@ -238,7 +234,7 @@ def _mmd( # pylint: disable=too-many-locals
y_chunks, y_chunks_copy = itertools.tee( # noqa: N806
MMD._get_chunks(
data=Y,
chunk_size=chunk_size_y, # type: ignore
chunk_size=chunk_size_y,
),
2,
)
Expand All @@ -257,15 +253,15 @@ def _mmd( # pylint: disable=too-many-locals
kernel=kernel,
)
# Remove diagonal (j!=i case)
- y_num_samples # type: ignore
- y_num_samples
)
k_xy_sum = MMD._compute_kernel(
chunk_combinations=xy_chunks_combinations, # type: ignore
kernel=kernel,
)
mmd = (
+ expected_k_xx
+expected_k_xx
+ k_yy_sum / (y_num_samples * (y_num_samples - 1))
- 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore
- 2 * k_xy_sum / (x_num_samples * y_num_samples)
)
return mmd
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,11 @@ def test_mmd_batch_precomputed_expected_k_xx(
precomputed_distance = detector.compare(X=X_test)[0].distance

# Computes mmd from scratch
scratch_distance = MMD._mmd(
scratch_distance = MMD._mmd( # pylint: disable=protected-access
X=X_ref,
Y=X_test,
kernel=kernel,
chunk_size=chunk_size,
)

assert np.isclose(precomputed_distance, scratch_distance)

0 comments on commit 3598aa2

Please sign in to comment.