Skip to content

Commit febc8d1

Browse files
committed
partial MMD distance
1 parent fc66d73 commit febc8d1

File tree

7 files changed

+378
-29
lines changed

7 files changed

+378
-29
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ jobs:
6363
pip install flake8
6464
pip install scipy
6565
pip install sklearn
66+
pip install qpsolvers
67+
pip install numba
6668
6769
- name: Install typing for old Python
6870
run: pip install typing

test/test_detectors.py

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import functools
23
import torchdrift
34
import torch
45

@@ -12,32 +13,39 @@ def test_detector():
1213

1314

1415
def _test_detector_class(cls):
15-
torch.manual_seed(1234)
16-
d = cls()
17-
d2 = cls(return_p_value=True)
18-
x = torch.randn(5, 5)
19-
y = torch.randn(5, 5) + 1.0
20-
d.fit(x)
21-
d2.fit(x)
22-
assert d(x).item() < d(y).item()
23-
assert d.compute_p_value(x) > 0.80
24-
assert d.compute_p_value(y) < 0.05
25-
torch.manual_seed(1234)
26-
p1 = d.compute_p_value(y)
27-
torch.manual_seed(1234)
28-
p2 = d2(y)
29-
assert p1 == p2
16+
devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
17+
for device in devices:
18+
torch.manual_seed(1234)
19+
d = cls()
20+
d2 = cls(return_p_value=True)
21+
x = torch.randn(5, 5, device=device)
22+
y = torch.randn(5, 5, device=device) + 1.0
23+
d.fit(x)
24+
d2.fit(x)
25+
assert d(x).item() < d(y).item()
26+
assert d.compute_p_value(x) > 0.80
27+
assert d.compute_p_value(y) < 0.05
28+
torch.manual_seed(1234)
29+
p1 = d.compute_p_value(y)
30+
torch.manual_seed(1234)
31+
p2 = d2(y)
32+
assert p1 == p2
33+
assert p1.device == x.device
3034

3135

3236
def _test_detector_class_fit_bootstrap(cls):
33-
torch.manual_seed(1234)
34-
d = cls()
35-
x = torch.randn(100, 5)
36-
y = torch.randn(50, 5) + 1.0
37-
z = torch.randn(50, 5)
38-
d.fit(x, n_test=50)
39-
assert d.compute_p_value(x[:50]) > 0.80
40-
assert d.compute_p_value(y) < 0.05
37+
devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
38+
for device in devices:
39+
torch.manual_seed(1234)
40+
d = cls()
41+
x = torch.randn(100, 5, device=device)
42+
y = torch.randn(50, 5, device=device) + 1.0
43+
z = torch.randn(50, 5, device=device)
44+
d.fit(x, n_test=50)
45+
assert d.compute_p_value(x[:50]) > 0.80
46+
assert d.compute_p_value(y) < 0.05
47+
p = d.compute_p_value(y)
48+
assert p.device == x.device
4149

4250

4351
def test_ksdetector():
@@ -84,6 +92,45 @@ def partial_wasserstein(return_p_value=False):
8492
d, p, c = torchdrift.detectors.wasserstein(x, y, return_coupling=True)
8593
d, c = torchdrift.detectors.wasserstein(x, y, return_coupling=True, n_perm=None)
8694

95+
def test_partial_mmd_detector():
96+
_test_detector_class_fit_bootstrap(torchdrift.detectors.PartialKernelMMDDriftDetector)
97+
pmmd = functools.partial(
98+
torchdrift.detectors.PartialKernelMMDDriftDetector,
99+
fraction_to_match=0.5,
100+
n_perm=100,
101+
)
102+
_test_detector_class_fit_bootstrap(pmmd)
103+
pmmd_approx = functools.partial(
104+
torchdrift.detectors.PartialKernelMMDDriftDetector,
105+
method=torchdrift.detectors.PartialKernelMMDDriftDetector.METHOD_APPROX,
106+
fraction_to_match=0.5,
107+
n_perm=100,
108+
)
109+
_test_detector_class_fit_bootstrap(pmmd_approx)
110+
pmmd_approx = functools.partial(
111+
torchdrift.detectors.PartialKernelMMDDriftDetector,
112+
method=torchdrift.detectors.PartialKernelMMDDriftDetector.METHOD_APPROX,
113+
n_perm=100,
114+
fraction_to_match=1.0, # corner case
115+
)
116+
_test_detector_class_fit_bootstrap(pmmd_approx)
117+
pmmd_qp = functools.partial(
118+
torchdrift.detectors.PartialKernelMMDDriftDetector,
119+
method=torchdrift.detectors.PartialKernelMMDDriftDetector.METHOD_QP,
120+
n_perm=100,
121+
fraction_to_match=0.5,
122+
)
123+
_test_detector_class_fit_bootstrap(pmmd_qp)
124+
125+
# Check that we can also just get the distance...
126+
dd = torchdrift.detectors.PartialKernelMMDDriftDetector(
127+
fraction_to_match=0.5
128+
)
129+
x = torch.randn(5, 5)
130+
y = torch.randn(5, 5) + 1.0
131+
dd.fit(x)
132+
dd(y)
133+
87134

88135
if __name__ == "__main__":
89136
pytest.main([__file__])

torchdrift/detectors/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,10 @@
22
from .mmd import kernel_mmd, KernelMMDDriftDetector
33
from .ks import ks_two_sample_multi_dim, KSDriftDetector, ks_p_value
44
from .wasserstein import wasserstein, WassersteinDriftDetector
5+
from .partial_mmd import (
6+
partial_kernel_mmd_twostage,
7+
partial_kernel_mmd_approx,
8+
partial_kernel_mmd_qp,
9+
PartialKernelMMDDriftDetector
10+
)
511
from . import mmd

torchdrift/detectors/ks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def predict_shift_from_features(
108108
ny, _ = outputs.shape
109109
# multiply by n_features for Bonferroni correction.
110110
p_value = min(1.0, ks_p_value(nx, ny, ood_score.item()) * n_features)
111+
p_value = torch.as_tensor(p_value, device=ood_score.device)
111112
else:
112113
p_value = None
113114
return ood_score, p_value

0 commit comments

Comments
 (0)