Skip to content

Commit b8ff1cd

Browse files
committed
formatting, true divide, and kernels
1 parent 4a0397d commit b8ff1cd

File tree

9 files changed

+192
-79
lines changed

9 files changed

+192
-79
lines changed

test/test_detectors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,24 @@ def _test_detector_class(cls):
2626
def test_ksdetector():
2727
_test_detector_class(torchdrift.detectors.KSDriftDetector)
2828

29+
30+
def _test_mmd_kernel(kernel):
31+
torch.manual_seed(1234)
32+
d = torchdrift.detectors.KernelMMDDriftDetector(kernel=kernel)
33+
x = torch.randn(5, 5)
34+
y = torch.randn(5, 5) + 1.0
35+
d.fit(x)
36+
assert (d(x).item() < d(y).item())
37+
assert d.compute_p_value(x) > 0.80
38+
assert d.compute_p_value(y) < 0.05
39+
2940
def test_mmddetector():
3041
_test_detector_class(torchdrift.detectors.KernelMMDDriftDetector)
42+
_test_mmd_kernel(torchdrift.detectors.mmd.GaussianKernel(lengthscale=1.0))
43+
_test_mmd_kernel(torchdrift.detectors.mmd.ExpKernel())
44+
_test_mmd_kernel(torchdrift.detectors.mmd.ExpKernel(lengthscale=1.0))
45+
_test_mmd_kernel(torchdrift.detectors.mmd.RationalQuadraticKernel())
46+
_test_mmd_kernel(torchdrift.detectors.mmd.RationalQuadraticKernel(lengthscale=1.0, alpha=2.0))
3147

3248
if __name__ == "__main__":
3349
pytest.main([__file__])

torchdrift/detectors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .detector import Detector
22
from .mmd import kernel_mmd, KernelMMDDriftDetector
33
from .ks import ks_two_sample_multi_dim, KSDriftDetector, ks_p_value
4+
from . import mmd

torchdrift/detectors/detector.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,60 @@
11
import torch
22

3+
34
class Detector(torch.nn.Module):
45
"""Detector class.
56
6-
The detector is is a `nn.Module` subclass that, after fitting, performs a drift test when called and returns a score or p-value.
7+
The detector is a `nn.Module` subclass that, after fitting, performs a drift test when called and returns a score or p-value.
8+
9+
Constructor Args:
10+
return_p_value (bool): If set, forward returns a p-value (estimate) instead of the raw test score.
11+
"""
712

8-
Constructor Args:
9-
return_p_value (bool): If set, forward returns a p-value (estimate) instead of the raw test score.
10-
"""
11-
def __init__(self, *, return_p_value: bool=False):
13+
def __init__(self, *, return_p_value: bool = False):
1214
super().__init__()
13-
self.register_buffer('base_outputs', None)
15+
self.register_buffer("base_outputs", None)
1416
self.return_p_value = return_p_value
1517

1618
def fit(self, x: torch.Tensor):
1719
"""Record a sample as the reference distribution"""
1820
self.base_outputs = x.detach()
1921
return x
2022

21-
def predict_shift_from_features(self, base_outputs: torch.Tensor, outputs: torch.Tensor, compute_score: bool, compute_p_value: bool, individual_samples: bool = False) -> torch.Tensor:
23+
def predict_shift_from_features(
24+
self,
25+
base_outputs: torch.Tensor,
26+
outputs: torch.Tensor,
27+
compute_score: bool,
28+
compute_p_value: bool,
29+
individual_samples: bool = False,
30+
) -> torch.Tensor:
2231
"""stub to be overridden by subclasses"""
2332
raise NotImplementedError("Override predict_shift_from_features in detectors")
2433

2534
def compute_p_value(self, inputs: torch.Tensor) -> torch.Tensor:
2635
"""Performs a statistical test for drift and returns the p-value.
2736
28-
This method calls `predict_shift_from_features` under the hood, so you only need to override that when subclassing.
29-
"""
37+
This method calls `predict_shift_from_features` under the hood, so you only need to override that when subclassing."""
3038
assert self.base_outputs is not None, "Please call fit before compute_p_value"
31-
_, p_value = self.predict_shift_from_features(self.base_outputs, inputs, compute_score=False, compute_p_value=True)
39+
_, p_value = self.predict_shift_from_features(
40+
self.base_outputs, inputs, compute_score=False, compute_p_value=True
41+
)
3242
return p_value
3343

3444
def forward(
35-
self, inputs: torch.Tensor,
36-
individual_samples: bool = False
45+
self, inputs: torch.Tensor, individual_samples: bool = False
3746
) -> torch.Tensor:
3847
"""Performs a statistical test for drift and returns the score or, if `return_p_value` has been set in the constructor, the p-value.
3948
40-
This method calls `predict_shift_from_features` under the hood, so you only need to override that when subclassing.
41-
"""
49+
This method calls `predict_shift_from_features` under the hood, so you only need to override that when subclassing."""
4250
assert self.base_outputs is not None, "Please call fit before predict_shift"
43-
ood_score, p_value = self.predict_shift_from_features(self.base_outputs, inputs, compute_score=not self.return_p_value, compute_p_value=self.return_p_value, individual_samples=individual_samples)
51+
ood_score, p_value = self.predict_shift_from_features(
52+
self.base_outputs,
53+
inputs,
54+
compute_score=not self.return_p_value,
55+
compute_p_value=self.return_p_value,
56+
individual_samples=individual_samples,
57+
)
4458
if self.return_p_value:
4559
return p_value
4660
return ood_score
47-

torchdrift/detectors/ks.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
try:
99
import numba
10+
1011
njit = numba.jit(nopython=True, fastmath=True)
11-
except ImportError: # pragma: no cover
12+
except ImportError: # pragma: no cover
1213
njit = lambda x: x
1314

1415

@@ -17,22 +18,23 @@
1718
# two-sample Kolmogorov-Smirnov test
1819
# https://arxiv.org/abs/2102.08037
1920

21+
2022
@njit
21-
def ks_p_value(n : int, m : int, d : float) -> float:
23+
def ks_p_value(n: int, m: int, d: float) -> float:
2224
"""Computes the p-value for the two-sided two-sample KS test from the D-statistic.
2325
24-
This uses the stable recursion from T. Viehmann: Numerically more stable computation of the p-values for the two-sample Kolmogorov-Smirnov test.
26+
This uses the stable recursion from T. Viehmann: Numerically more stable computation of the p-values for the two-sample Kolmogorov-Smirnov test.
2527
"""
26-
size = int(2*m*d+2)
28+
size = int(2 * m * d + 2)
2729
lastrow, row = numpy.zeros((2, size), dtype=numpy.float64)
2830
last_start_j = 0
2931
for i in range(n + 1):
30-
start_j = max(int(m * (i/n + d)) + 1-size, 0)
32+
start_j = max(int(m * (i / n + d)) + 1 - size, 0)
3133
lastrow, row = row, lastrow
3234
val = 0.0
3335
for jj in range(size):
3436
j = jj + start_j
35-
dist = i/n - j/m
37+
dist = i / n - j / m
3638
if dist > d or dist < -d:
3739
val = 1.0
3840
elif i == 0 or j == 0:
@@ -46,6 +48,7 @@ def ks_p_value(n : int, m : int, d : float) -> float:
4648
last_start_j = start_j
4749
return row[m - start_j]
4850

51+
4952
def ks_two_sample_multi_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5053
"""Computes the two-sample two-sided Kolmogorov-Smirnov statistic.
5154
@@ -58,26 +61,37 @@ def ks_two_sample_multi_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5861
n_x, n_features = x.shape
5962
n_y, n_features_y = y.shape
6063
assert n_features == n_features_y
61-
64+
6265
joint_sorted = torch.argsort(torch.cat([x, y], dim=0), dim=0)
63-
sign = (joint_sorted < n_x).to(dtype=torch.float) * (1 /(n_x) + 1/(n_y)) - (1/(n_y))
66+
sign = (joint_sorted < n_x).to(dtype=torch.float) * (1 / (n_x) + 1 / (n_y)) - (
67+
1 / (n_y)
68+
)
6469
ks_scores = sign.cumsum(0).abs().max(0).values
6570
return ks_scores
6671

72+
6773
class KSDriftDetector(Detector):
6874
"""Drift detector based on (multiple) Kolmogorov-Smirnov tests.
6975
70-
This detector uses the Kolmogorov-Smirnov test on the marginals of the features
71-
for each feature.
76+
This detector uses the Kolmogorov-Smirnov test on the marginals of the features
77+
for each feature.
7278
73-
For scores, it returns the maximum score. p-values are computed with the
74-
Bonferroni correction of multiplying the p-value of the maximum score by
75-
the number of features/tests.
79+
For scores, it returns the maximum score. p-values are computed with the
80+
Bonferroni correction of multiplying the p-value of the maximum score by
81+
the number of features/tests.
7682
77-
This is modelled after the KS drift detection in
78-
S. Rabanser et al: *Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift* (NeurIPS), 2019.
83+
This is modelled after the KS drift detection in
84+
S. Rabanser et al: *Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift* (NeurIPS), 2019.
7985
"""
80-
def predict_shift_from_features(self, base_outputs: torch.Tensor, outputs: torch.Tensor, compute_score: bool, compute_p_value: bool, individual_samples: bool = False):
86+
87+
def predict_shift_from_features(
88+
self,
89+
base_outputs: torch.Tensor,
90+
outputs: torch.Tensor,
91+
compute_score: bool,
92+
compute_p_value: bool,
93+
individual_samples: bool = False,
94+
):
8195
assert (
8296
not individual_samples
8397
), "Individual samples not supported by MMD detector"

torchdrift/detectors/mmd.py

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,69 @@
44

55
from . import Detector
66

7-
def kernel_mmd(x, y, n_perm=1000):
7+
8+
class Kernel:
9+
pass
10+
11+
12+
class GaussianKernel(Kernel):
13+
"""Unnormalized gaussian kernel"""
14+
15+
def __init__(self, lengthscale=None):
16+
super().__init__()
17+
self.lengthscale = lengthscale
18+
19+
def __call__(self, dists):
20+
# note that lengthscale should be squared in the RBF to match the Gretton et al heuristic
21+
if self.lengthscale is not None:
22+
lengthscale = self.lengthscale
23+
else:
24+
lengthscale = dists[:100, :100].median()
25+
return torch.exp((-1 / lengthscale ** 2) * dists ** 2)
26+
27+
28+
class ExpKernel(Kernel):
29+
"""Unnormalized exponential kernel"""
30+
31+
def __init__(self, lengthscale=None):
32+
super().__init__()
33+
self.lengthscale = lengthscale
34+
35+
def __call__(self, dists):
36+
if self.lengthscale is not None:
37+
lengthscale = self.lengthscale
38+
else:
39+
lengthscale = dists[:100, :100].median()
40+
return torch.exp((-1 / lengthscale) * dists)
41+
42+
43+
class RationalQuadraticKernel(Kernel):
44+
"""Unnormalized rational quadratic kernel
45+
46+
k(|x-y|) = (1+|x-y|^2/(2 alpha lengthscale**2))^(-alpha)"""
47+
48+
def __init__(self, lengthscale=None, alpha=1.0):
49+
super().__init__()
50+
self.alpha = alpha
51+
self.lengthscale = lengthscale
52+
53+
def __call__(self, dists):
54+
if self.lengthscale is not None:
55+
lengthscale = self.lengthscale
56+
else:
57+
lengthscale = dists[:100, :100].median()
58+
return torch.pow(
59+
1 + (1 / (2 * self.alpha * lengthscale ** 2)) * dists ** 2, -self.alpha
60+
)
61+
62+
63+
def kernel_mmd(x, y, n_perm=1000, kernel=GaussianKernel()):
864
"""Implements the kernel MMD two-sample test.
965
1066
It is modelled after the kernel MMD paper and code:
1167
A. Gretton et al.: A kernel two-sample test, JMLR 13 (2012)
1268
http://www.gatsby.ucl.ac.uk/~gretton/mmd/mmd.htm
13-
69+
1470
The arguments `x` and `y` should be two-dimensional tensors.
1571
The first is the batch dimension (which may differ), the second
1672
the features (which must be the same on both `x` and `y`).
@@ -24,8 +80,7 @@ def kernel_mmd(x, y, n_perm=1000):
2480
xy = torch.cat([x.detach(), y.detach()], dim=0)
2581
dists = torch.cdist(xy, xy, p=2.0)
2682
# we are a bit sloppy here as we just keep the diagonal and everything twice
27-
# note that sigma should be squared in the RBF to match the Gretton et al heuristic
28-
k = torch.exp((-1 / dists[:100, :100].median() ** 2) * dists ** 2)
83+
k = kernel(dists)
2984
k_x = k[:n, :n]
3085
k_y = k[n:, n:]
3186
k_xy = k[:n, n:]
@@ -55,34 +110,46 @@ def kernel_mmd(x, y, n_perm=1000):
55110
mmd_0s.append(mmd_0)
56111
count = count + (mmd_0 > mmd)
57112
# pyplot.hist(torch.stack(mmd_0s, dim=0).tolist(), bins=50)
58-
p_val = count / n_perm
113+
# true_divide is used for torch 1.6 compatibility; replace with "/" after October 2021
114+
p_val = torch.true_divide(count, n_perm)
115+
59116
return mmd, p_val
60117

61118

62119
class KernelMMDDriftDetector(Detector):
63120
"""Drift detector based on the kernel Maximum Mean Discrepancy (MMD) test.
64121
65-
This is modelled after the MMD drift detection in
66-
S. Rabanser et al: *Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift* (NeurIPS), 2019.
122+
This is modelled after the MMD drift detection in
123+
S. Rabanser et al: *Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift* (NeurIPS), 2019.
67124
68-
Note that our heuristic choice of the kernel bandwith is more closely aligned with that of the original MMD paper and code than S. Rabanser's.
125+
Note that our heuristic choice of the kernel bandwidth is more closely aligned with that of the original MMD paper and code than S. Rabanser's.
69126
"""
70-
71-
def __init__(self, *, return_p_value=False, n_perm: int = 1000):
127+
128+
def __init__(
129+
self, *, return_p_value=False, n_perm: int = 1000, kernel=GaussianKernel()
130+
):
72131
super().__init__(return_p_value=return_p_value)
73132
self.n_perm = n_perm
133+
self.kernel = kernel
74134

75-
def predict_shift_from_features(self, base_outputs: torch.Tensor, outputs: torch.Tensor, compute_score: bool, compute_p_value: bool, individual_samples: bool = False):
135+
def predict_shift_from_features(
136+
self,
137+
base_outputs: torch.Tensor,
138+
outputs: torch.Tensor,
139+
compute_score: bool,
140+
compute_p_value: bool,
141+
individual_samples: bool = False,
142+
):
76143
assert (
77144
not individual_samples
78145
), "Individual samples not supported by MMD detector"
79146
if not compute_p_value:
80147
ood_score = kernel_mmd(
81-
outputs, base_outputs, n_perm=None
148+
outputs, base_outputs, n_perm=None, kernel=self.kernel
82149
)
83150
p_value = None
84151
else:
85152
ood_score, p_value = kernel_mmd(
86-
outputs, base_outputs, n_perm=self.n_perm
153+
outputs, base_outputs, n_perm=self.n_perm, kernel=self.kernel
87154
)
88155
return ood_score, p_value

torchdrift/reducers/pca.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
import torch
22
from . import Reducer
33

4+
45
class PCAReducer(Reducer):
56
"""Reduce dimensions using PCA.
67
7-
This nn.Modue subclass reduces the dimensions of the inputs
8-
specified by `n_components`.
8+
This nn.Module subclass reduces the dimensions of the inputs
9+
specified by `n_components`.
910
10-
The input is a 2-dimensional `Tensor` of size `batch` x `features`,
11-
the output is a `Tensor` of size `batch` x `n_components`.
11+
The input is a 2-dimensional `Tensor` of size `batch` x `features`,
12+
the output is a `Tensor` of size `batch` x `n_components`.
1213
"""
13-
def __init__(self, n_components:int = 2):
14+
15+
def __init__(self, n_components: int = 2):
1416
super().__init__()
1517
self.n_components = n_components
1618

1719
def extra_repr(self) -> str:
18-
return f'n_components={self.n_components}'
20+
return f"n_components={self.n_components}"
1921

2022
def fit(self, x: torch.Tensor) -> torch.Tensor:
2123
batch, feat = x.shape
2224
assert min(batch, feat) >= self.n_components
2325
self.mean = x.mean(0, keepdim=True)
2426
x = x - self.mean
2527
u, s, v = x.svd()
26-
self.comp = v[:, :self.n_components]
28+
self.comp = v[:, : self.n_components]
2729
reduced = x @ self.comp
2830
return reduced
2931

0 commit comments

Comments
 (0)