Skip to content

Commit 6b83f5c

Browse files
committed
fix randomness in test
1 parent b8ff1cd commit 6b83f5c

File tree

5 files changed

+48
-18
lines changed

5 files changed

+48
-18
lines changed

test/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
21
import os
3-
os.environ['NUMBA_DISABLE_JIT'] = '1'
2+
3+
os.environ["NUMBA_DISABLE_JIT"] = "1"

test/test_corruption_functions.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,35 +5,40 @@
55

66
# these tests could be made more specific, eventually
77

8+
89
def test_gaussian_noise():
910
a = torch.rand(4, 3, 32, 32)
1011
b = torch.rand(1, 3, 200, 200)
1112
a1 = torchdrift.data.functional.gaussian_noise(a)
1213
b1 = torchdrift.data.functional.gaussian_noise(b)
1314
assert (a1 == a1.clamp(min=0, max=1)).all()
14-
assert (a-a1).std() < (b-b1).std()
15+
assert (a - a1).std() < (b - b1).std()
16+
1517

1618
def test_shot_noise():
1719
a = torch.rand(4, 3, 100, 100)
1820
a1 = torchdrift.data.functional.shot_noise(a, severity=4)
1921
a2 = torchdrift.data.functional.gaussian_noise(a, severity=4)
20-
assert (a1-a).abs().max() < (a2-a).abs().max()
21-
22+
assert (a1 - a).abs().max() < (a2 - a).abs().max()
23+
24+
2225
def test_impulse_noise():
23-
a = torch.rand(4, 3, 100, 100)/2+0.25
26+
a = torch.rand(4, 3, 100, 100) / 2 + 0.25
2427
a1 = torchdrift.data.functional.impulse_noise(a, severity=4)
2528
assert a1.min().abs() + (1 - a1.max()).abs() < 1e-6
2629

30+
2731
def test_speckle_noise():
28-
a = torch.rand(4, 3, 100, 100)/2+0.25
32+
a = torch.rand(4, 3, 100, 100) / 2 + 0.25
2933
a1 = torchdrift.data.functional.speckle_noise(a, severity=4)
3034
assert a1.min().abs() + (1 - a1.max()).abs() < 1e-6
3135

36+
3237
def test_gaussian_blur():
3338
a = torch.rand(1, 3, 300, 300)
3439
a1 = torchdrift.data.functional.gaussian_blur(a, severity=5)
3540
a2 = scipy.ndimage.gaussian_filter(a, [0, 0, 6, 6])
36-
assert ((a1-a2)[:,:, 32:-32, 32:-32]).max().abs() < 1e-2
41+
assert ((a1 - a2)[:, :, 32:-32, 32:-32]).max().abs() < 1e-2
3742

3843

3944
if __name__ == "__main__":

test/test_detectors.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import torch
44
import sklearn.decomposition
55

6+
67
def test_detector():
78
x = torch.randn(5, 5)
89
d = torchdrift.detectors.Detector()
910
d.fit(x)
1011
with pytest.raises(NotImplementedError):
1112
d(x)
1213

14+
1315
def _test_detector_class(cls):
1416
torch.manual_seed(1234)
1517
d = cls()
@@ -18,10 +20,15 @@ def _test_detector_class(cls):
1820
y = torch.randn(5, 5) + 1.0
1921
d.fit(x)
2022
d2.fit(x)
21-
assert (d(x).item() < d(y).item())
23+
assert d(x).item() < d(y).item()
2224
assert d.compute_p_value(x) > 0.80
2325
assert d.compute_p_value(y) < 0.05
24-
assert d.compute_p_value(y) == d2(y)
26+
torch.manual_seed(1234)
27+
p1 = d.compute_p_value(y)
28+
torch.manual_seed(1234)
29+
p2 = d2(y)
30+
assert p1 == p2
31+
2532

2633
def test_ksdetector():
2734
_test_detector_class(torchdrift.detectors.KSDriftDetector)
@@ -33,17 +40,21 @@ def _test_mmd_kernel(kernel):
3340
x = torch.randn(5, 5)
3441
y = torch.randn(5, 5) + 1.0
3542
d.fit(x)
36-
assert (d(x).item() < d(y).item())
43+
assert d(x).item() < d(y).item()
3744
assert d.compute_p_value(x) > 0.80
3845
assert d.compute_p_value(y) < 0.05
3946

47+
4048
def test_mmddetector():
4149
_test_detector_class(torchdrift.detectors.KernelMMDDriftDetector)
4250
_test_mmd_kernel(torchdrift.detectors.mmd.GaussianKernel(lengthscale=1.0))
4351
_test_mmd_kernel(torchdrift.detectors.mmd.ExpKernel())
4452
_test_mmd_kernel(torchdrift.detectors.mmd.ExpKernel(lengthscale=1.0))
4553
_test_mmd_kernel(torchdrift.detectors.mmd.RationalQuadraticKernel())
46-
_test_mmd_kernel(torchdrift.detectors.mmd.RationalQuadraticKernel(lengthscale=1.0, alpha=2.0))
54+
_test_mmd_kernel(
55+
torchdrift.detectors.mmd.RationalQuadraticKernel(lengthscale=1.0, alpha=2.0)
56+
)
57+
4758

4859
if __name__ == "__main__":
4960
pytest.main([__file__])

test/test_reducers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import torch
44
import sklearn.decomposition
55

6+
67
def test_pca():
78
pca = torchdrift.reducers.PCAReducer(n_components=2)
8-
assert 'n_components' in str(pca)
9+
assert "n_components" in str(pca)
910
a = torch.randn(100, 50, dtype=torch.double)
1011
red = pca.fit(a)
1112
pca_ref = sklearn.decomposition.PCA(n_components=2)
@@ -16,6 +17,7 @@ def test_pca():
1617
red2 = pca(b)
1718
red2_ref = torch.from_numpy(pca_ref.transform(b))
1819

20+
1921
def test_reducer():
2022
x = torch.randn(5, 5)
2123
r = torchdrift.reducers.Reducer()
@@ -24,5 +26,6 @@ def test_reducer():
2426
with pytest.raises(NotImplementedError):
2527
r(x)
2628

29+
2730
if __name__ == "__main__":
2831
pytest.main([__file__])

test/test_utils.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
import torch.utils.data
55

6+
67
class TensorDataModule:
78
def __init__(self, *args):
89
self.ds = torch.utils.data.TensorDataset(*args)
@@ -13,19 +14,27 @@ def default_dataloader(self, batch_size=None, num_samples=None, shuffle=True):
1314
batch_size = self.val_batch_size
1415
replacement = num_samples is not None
1516
if shuffle:
16-
sampler = torch.utils.data.RandomSampler(dataset, replacement=replacement, num_samples=num_samples)
17+
sampler = torch.utils.data.RandomSampler(
18+
dataset, replacement=replacement, num_samples=num_samples
19+
)
1720
else:
1821
sampler = None
19-
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)
22+
return torch.utils.data.DataLoader(
23+
dataset, batch_size=batch_size, sampler=sampler
24+
)
25+
2026

2127
def test_fit():
2228
dm_ref = TensorDataModule(torch.randn(500, 5))
2329
d = torchdrift.detectors.KernelMMDDriftDetector()
2430
torchdrift.utils.fit(
25-
dm_ref.default_dataloader(batch_size=10), torch.nn.Identity(),
31+
dm_ref.default_dataloader(batch_size=10),
32+
torch.nn.Identity(),
2633
[torch.nn.Identity(), d],
2734
num_batches=3,
28-
device='cpu')
35+
device="cpu",
36+
)
37+
2938

3039
def test_experiment():
3140
torch.manual_seed(1234)
@@ -34,7 +43,8 @@ def test_experiment():
3443
dm_x = TensorDataModule(torch.randn(500, 5))
3544
dm_y = TensorDataModule(torch.randn(500, 5) + 1)
3645
experiment = torchdrift.utils.DriftDetectionExperiment(
37-
d, torch.nn.Linear(5, 5),
46+
d,
47+
torch.nn.Linear(5, 5),
3848
)
3949
experiment.post_training(torch.utils.data.DataLoader(dm_ref.ds, batch_size=100))
4050
experiment.evaluate(dm_x, dm_y)
@@ -44,5 +54,6 @@ def test_experiment():
4454
experiment.post_training(torch.utils.data.DataLoader(dm_ref.ds, batch_size=100))
4555
experiment.evaluate(dm_x, dm_y)
4656

57+
4758
if __name__ == "__main__":
4859
pytest.main([__file__])

0 commit comments

Comments
 (0)