Skip to content

Commit ef16e90

Browse files
committed
add detector.compute_p_values
1 parent 7d2484b commit ef16e90

File tree

5 files changed

+141
-147
lines changed

5 files changed

+141
-147
lines changed

notebooks/drift_detection_on_images.ipynb

Lines changed: 67 additions & 90 deletions
Large diffs are not rendered by default.

torchdrift/detectors/detector.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import torch
55

66
class DriftDetector(torch.nn.Module):
7-
def __init__(self):
7+
def __init__(self, *, return_p_value=False):
88
super().__init__()
99
self.register_buffer('base_outputs', None)
10+
self.return_p_value = return_p_value
1011

1112
def fit(
1213
self,
@@ -28,13 +29,20 @@ def fit(
2829
all_outputs = torch.cat(all_outputs, dim=0)
2930
self.base_outputs = all_outputs
3031

31-
def predict_shift_from_features(self, base_outputs: torch.Tensor, outputs: torch.Tensor, compute_score: bool, compute_p_value: bool, individual_samples: bool = False):
    """Hook that concrete detectors must override.

    The callers in this class (``compute_p_value`` and ``forward``) unpack the
    result as a ``(score, p_value)`` pair, so overrides should return two
    values.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Override predict_shift_from_features in detectors"
    raise NotImplementedError(message)
3334

35+
def compute_p_value(self, inputs: torch.Tensor):
    """Return only the p-value for ``inputs`` against the fitted reference.

    Calls ``predict_shift_from_features`` with ``compute_score=False`` and
    ``compute_p_value=True`` and discards the first element of the result.

    Raises:
        AssertionError: if ``base_outputs`` is still ``None`` (``fit`` has not
            been called).
    """
    assert self.base_outputs is not None, "Please call fit before compute_p_value"
    _unused_score, p_value = self.predict_shift_from_features(
        self.base_outputs,
        inputs,
        compute_score=False,
        compute_p_value=True,
    )
    return p_value
39+
3440
def forward(
    self, inputs: torch.Tensor,
    individual_samples: bool = False
):
    """Run the detector on ``inputs`` and return one value.

    Returns the p-value when ``self.return_p_value`` is truthy, otherwise the
    score. Only the quantity that will be returned is requested from
    ``predict_shift_from_features``.

    Raises:
        AssertionError: if ``base_outputs`` is still ``None`` (``fit`` has not
            been called).
    """
    assert self.base_outputs is not None, "Please call fit before predict_shift"
    want_p_value = self.return_p_value
    ood_score, p_value = self.predict_shift_from_features(
        self.base_outputs,
        inputs,
        compute_score=not want_p_value,
        compute_p_value=want_p_value,
        individual_samples=individual_samples,
    )
    return p_value if want_p_value else ood_score

torchdrift/detectors/isolation_forest.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

torchdrift/detectors/ks.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,46 @@
11
from typing import Optional
22

33
import torch
4+
import numpy
45

56
from . import DriftDetector
67

7-
# TODO: ks p-value computation...
8+
# Use numba's nopython/fastmath JIT when numba can be imported; otherwise fall
# back to a decorator that returns the function unchanged.
try:
    import numba

    njit = numba.jit(nopython=True, fastmath=True)
except ImportError:

    def njit(x):
        """Fallback decorator used when numba is unavailable: returns ``x`` as is."""
        return x
13+
14+
15+
# Numerically stable p-Value computation, see
16+
# T. Viehmann: Numerically more stable computation of the p-values for the
17+
# two-sample Kolmogorov-Smirnov test
18+
# https://arxiv.org/abs/2102.08037
19+
20+
@njit
def ks_p_value(n, m, d):
    """Compute the two-sample Kolmogorov-Smirnov p-value from the statistic ``d``.

    A table indexed by ``(i, j)`` with ``0 <= i <= n`` is filled row by row in
    floating point. Only two rows are kept, and each row stores a window of
    ``size`` consecutive ``j`` values starting at ``start_j``. An entry is

    * ``1.0`` when ``|i/n - j/m| > d`` (the point lies outside the band),
    * ``0.0`` on the ``i == 0`` / ``j == 0`` border inside the band,
    * otherwise ``(above * i + left * j) / (i + j)``, where ``above`` is the
      entry for ``(i - 1, j)`` and ``left`` the entry for ``(i, j - 1)``.

    Args:
        n: size of the first sample; ``i / n`` is evaluated, so it must be
            non-zero.
        m: size of the second sample; ``j / m`` is evaluated, so it must be
            non-zero.
        d: the KS statistic the p-value is computed for.

    Returns:
        The table entry for ``(n, m)``.
    """
    # NOTE(review): the band test below is strict (``dist > d``), so points at
    # distance exactly ``d`` count as inside the band -- confirm this matches
    # the intended p-value convention.
    size = int(2*m*d+2)
    lastrow, row = numpy.zeros((2, size), dtype=numpy.float64)
    last_start_j = 0
    for i in range(n + 1):
        # First column stored for this row; clamped so it never goes below 0.
        start_j = max(int(m * (i/n + d)) + 1-size, 0)
        # Reuse the two buffers: the row just written becomes ``lastrow``.
        lastrow, row = row, lastrow
        val = 0.0
        for jj in range(size):
            j = jj + start_j
            dist = i/n - j/m
            if dist > d or dist < -d:
                val = 1.0
            elif i == 0 or j == 0:
                val = 0.0
            elif jj + start_j - last_start_j >= size:
                # (i - 1, j) is beyond the previous row's stored window; the
                # formula uses 1.0 in its place.
                val = (i + val * j) / (i + j)
            else:
                val = (lastrow[jj + start_j - last_start_j] * i + val * j) / (i + j)
            row[jj] = val
        last_start_j = start_j
    return row[m - start_j]
844

945
def ks_two_sample_multi_dim(x, y):
1046
"""
@@ -20,12 +56,20 @@ def ks_two_sample_multi_dim(x, y):
2056

2157
# As suggested in "Failing Loudly", we would return the minimum p-value under
2258
# the Bonferroni correction; this corresponds to the maximum score.
59+
# see the p-value computation below...
2360
return ks_scores.max()
2461

2562
class KSDriftDetector(DriftDetector):
    """Drift detector that scores with ``ks_two_sample_multi_dim``.

    The optional p-value comes from ``ks_p_value`` and is multiplied by the
    number of features (Bonferroni correction), capped at 1.0.
    """

    def predict_shift_from_features(self, base_outputs: torch.Tensor, outputs: torch.Tensor, compute_score: bool, compute_p_value: bool, individual_samples: bool = False):
        """Return ``(score, p_value)`` for ``outputs`` against ``base_outputs``.

        Args:
            base_outputs: reference features; unpacked as a 2-D shape
                ``(nx, n_features)`` when the p-value is requested.
            outputs: features under test; unpacked as a 2-D shape
                ``(ny, _)`` when the p-value is requested.
            compute_score: not consulted here; the score is always computed
                because the p-value is derived from it.
            compute_p_value: when false, the returned p-value is ``None``.
            individual_samples: must be false; per-sample scores are not
                supported.

        Raises:
            AssertionError: if ``individual_samples`` is true.
        """
        assert (
            not individual_samples
        ), "Individual samples not supported by KS detector"
        # Score against the reference passed in, so the score and the sample
        # size used for the p-value below come from the same tensor.
        ood_score = ks_two_sample_multi_dim(outputs, base_outputs)
        if compute_p_value:
            nx, n_features = base_outputs.shape
            ny, _ = outputs.shape
            # multiply by n_features for Bonferroni correction.
            p_value = min(1.0, ks_p_value(nx, ny, ood_score.item()) * n_features)
        else:
            p_value = None
        return ood_score, p_value

torchdrift/detectors/mmd.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,21 @@ def kernel_mmd(x, y, n_perm=1000):
5252

5353

5454
class KernelMMDDriftDetector(DriftDetector):
    """Drift detector that delegates to ``kernel_mmd``."""

    def __init__(self, *, return_p_value=False, n_perm: int = 1000):
        """Store ``n_perm``, which is passed to ``kernel_mmd`` when a p-value is requested."""
        super().__init__(return_p_value=return_p_value)
        self.n_perm = n_perm

    def predict_shift_from_features(self, base_outputs: torch.Tensor, outputs: torch.Tensor, compute_score: bool, compute_p_value: bool, individual_samples: bool = False):
        """Return ``(score, p_value)``; ``p_value`` is ``None`` unless requested.

        When ``compute_p_value`` is true, ``kernel_mmd`` is called with
        ``n_perm=self.n_perm`` and its two results are returned. Otherwise it
        is called with ``n_perm=None`` and its single result is the score.

        Raises:
            AssertionError: if ``individual_samples`` is true.
        """
        assert (
            not individual_samples
        ), "Individual samples not supported by MMD detector"
        if compute_p_value:
            score, p_val = kernel_mmd(
                outputs, base_outputs, n_perm=self.n_perm
            )
            return score, p_val
        score = kernel_mmd(
            outputs, base_outputs, n_perm=None
        )
        return score, None

0 commit comments

Comments
 (0)