1
1
import pytest
2
+ import functools
2
3
import torchdrift
3
4
import torch
4
5
@@ -12,32 +13,39 @@ def test_detector():
12
13
13
14
14
15
def _test_detector_class (cls ):
15
- torch .manual_seed (1234 )
16
- d = cls ()
17
- d2 = cls (return_p_value = True )
18
- x = torch .randn (5 , 5 )
19
- y = torch .randn (5 , 5 ) + 1.0
20
- d .fit (x )
21
- d2 .fit (x )
22
- assert d (x ).item () < d (y ).item ()
23
- assert d .compute_p_value (x ) > 0.80
24
- assert d .compute_p_value (y ) < 0.05
25
- torch .manual_seed (1234 )
26
- p1 = d .compute_p_value (y )
27
- torch .manual_seed (1234 )
28
- p2 = d2 (y )
29
- assert p1 == p2
16
+ devices = ['cpu' ] + (['cuda' ] if torch .cuda .is_available () else [])
17
+ for device in devices :
18
+ torch .manual_seed (1234 )
19
+ d = cls ()
20
+ d2 = cls (return_p_value = True )
21
+ x = torch .randn (5 , 5 , device = device )
22
+ y = torch .randn (5 , 5 , device = device ) + 1.0
23
+ d .fit (x )
24
+ d2 .fit (x )
25
+ assert d (x ).item () < d (y ).item ()
26
+ assert d .compute_p_value (x ) > 0.80
27
+ assert d .compute_p_value (y ) < 0.05
28
+ torch .manual_seed (1234 )
29
+ p1 = d .compute_p_value (y )
30
+ torch .manual_seed (1234 )
31
+ p2 = d2 (y )
32
+ assert p1 == p2
33
+ assert p1 .device == x .device
30
34
31
35
32
36
def _test_detector_class_fit_bootstrap (cls ):
33
- torch .manual_seed (1234 )
34
- d = cls ()
35
- x = torch .randn (100 , 5 )
36
- y = torch .randn (50 , 5 ) + 1.0
37
- z = torch .randn (50 , 5 )
38
- d .fit (x , n_test = 50 )
39
- assert d .compute_p_value (x [:50 ]) > 0.80
40
- assert d .compute_p_value (y ) < 0.05
37
+ devices = ['cpu' ] + (['cuda' ] if torch .cuda .is_available () else [])
38
+ for device in devices :
39
+ torch .manual_seed (1234 )
40
+ d = cls ()
41
+ x = torch .randn (100 , 5 , device = device )
42
+ y = torch .randn (50 , 5 , device = device ) + 1.0
43
+ z = torch .randn (50 , 5 , device = device )
44
+ d .fit (x , n_test = 50 )
45
+ assert d .compute_p_value (x [:50 ]) > 0.80
46
+ assert d .compute_p_value (y ) < 0.05
47
+ p = d .compute_p_value (y )
48
+ assert p .device == x .device
41
49
42
50
43
51
def test_ksdetector ():
@@ -84,6 +92,45 @@ def partial_wasserstein(return_p_value=False):
84
92
d , p , c = torchdrift .detectors .wasserstein (x , y , return_coupling = True )
85
93
d , c = torchdrift .detectors .wasserstein (x , y , return_coupling = True , n_perm = None )
86
94
95
+ def test_partial_mmd_detector ():
96
+ _test_detector_class_fit_bootstrap (torchdrift .detectors .PartialKernelMMDDriftDetector )
97
+ pmmd = functools .partial (
98
+ torchdrift .detectors .PartialKernelMMDDriftDetector ,
99
+ fraction_to_match = 0.5 ,
100
+ n_perm = 100 ,
101
+ )
102
+ _test_detector_class_fit_bootstrap (pmmd )
103
+ pmmd_approx = functools .partial (
104
+ torchdrift .detectors .PartialKernelMMDDriftDetector ,
105
+ method = torchdrift .detectors .PartialKernelMMDDriftDetector .METHOD_APPROX ,
106
+ fraction_to_match = 0.5 ,
107
+ n_perm = 100 ,
108
+ )
109
+ _test_detector_class_fit_bootstrap (pmmd_approx )
110
+ pmmd_approx = functools .partial (
111
+ torchdrift .detectors .PartialKernelMMDDriftDetector ,
112
+ method = torchdrift .detectors .PartialKernelMMDDriftDetector .METHOD_APPROX ,
113
+ n_perm = 100 ,
114
+ fraction_to_match = 1.0 , # corner case
115
+ )
116
+ _test_detector_class_fit_bootstrap (pmmd_approx )
117
+ pmmd_qp = functools .partial (
118
+ torchdrift .detectors .PartialKernelMMDDriftDetector ,
119
+ method = torchdrift .detectors .PartialKernelMMDDriftDetector .METHOD_QP ,
120
+ n_perm = 100 ,
121
+ fraction_to_match = 0.5 ,
122
+ )
123
+ _test_detector_class_fit_bootstrap (pmmd_qp )
124
+
125
+ # Check that we can also just get the distance...
126
+ dd = torchdrift .detectors .PartialKernelMMDDriftDetector (
127
+ fraction_to_match = 0.5
128
+ )
129
+ x = torch .randn (5 , 5 )
130
+ y = torch .randn (5 , 5 ) + 1.0
131
+ dd .fit (x )
132
+ dd (y )
133
+
87
134
88
135
if __name__ == "__main__" :
89
136
pytest .main ([__file__ ])
0 commit comments