Skip to content

Commit ff85ff0

Browse files
committed
replace asserts by check and runtime error
1 parent 2e813f6 commit ff85ff0

File tree

11 files changed

+53
-44
lines changed

11 files changed

+53
-44
lines changed

notebooks/deployment_monitoring_example.ipynb

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -397,14 +397,6 @@
397397
"source": [
398398
"So in this notebook we saw how to use model hooks with the drift detector to automatically set off the alarm when something bad happens. Just remember that if you set the p-value threshold to $x\\%$, you should expect a false alarm about once every $100\\%/x\\%$ batches even without drift, so choose it carefully to avoid spamming your emergency contact."
399399
]
400-
},
401-
{
402-
"cell_type": "code",
403-
"execution_count": null,
404-
"id": "careful-exposure",
405-
"metadata": {},
406-
"outputs": [],
407-
"source": []
408400
}
409401
],
410402
"metadata": {

notebooks/drift_detection_overview.ipynb

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -282,14 +282,6 @@
282282
"\n",
283283
"To make this operational, we can get out our toolbox of classifiers, e.g. Neural Networks and Nearest-Neighbor ones, see [D. Lopez-Paz, M. Oquab: Revisiting classifier two-sample tests, ICLR 2017](https://arxiv.org/abs/1610.06545). Note that this approach can be data-intensive: To execute, we need to split the samples $x^{ref}_i$ and $x_i$ into train and test samples. When using neural networks, we also need to train the classifier, adding computational requirements. When we have enough data and time, we may hope that such a classification-based approach may be highly effective.\n"
284284
]
285-
},
286-
{
287-
"cell_type": "code",
288-
"execution_count": null,
289-
"id": "loaded-quebec",
290-
"metadata": {},
291-
"outputs": [],
292-
"source": []
293285
}
294286
],
295287
"metadata": {

test/test_corruption_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def test_gaussian_blur():
3939
a1 = torchdrift.data.functional.gaussian_blur(a, severity=5)
4040
a2 = scipy.ndimage.gaussian_filter(a, [0, 0, 6, 6])
4141
assert ((a1 - a2)[:, :, 32:-32, 32:-32]).max().abs() < 1e-2
42-
42+
with pytest.raises(RuntimeError):
43+
a3 = torchdrift.data.functional.gaussian_blur(a, severity=6)
44+
4345

4446
if __name__ == "__main__":
4547
pytest.main([__file__])

test/test_functions.py

Whitespace-only changes.

torchdrift/data/functional/corruption_functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from torch import Tensor
2727
import torch
2828
import math
29+
import torchdrift.utils
2930

3031
__all__ = []
3132

@@ -49,7 +50,9 @@ def _export(fn):
4950

5051

5152
def interpolate_severity(img, cifar, imagenet, severity):
52-
assert severity >= 1 and severity <= 5
53+
torchdrift.utils.check(
54+
severity >= 1 and severity <= 5, "severity needs to be between 1 and 5"
55+
)
5356
length = (img.size(-1) * img.size(-2)) ** 0.5
5457
alpha = max(min((length - 32) / (224 - 32), 1), 0)
5558
res = (1 - alpha) * cifar[severity - 1] + alpha * imagenet[severity - 1]

torchdrift/detectors/detector.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torchdrift.utils
23

34

45
class Detector(torch.nn.Module):
@@ -35,7 +36,9 @@ def compute_p_value(self, inputs: torch.Tensor) -> torch.Tensor:
3536
"""Performs a statistical test for drift and returns the p-value.
3637
3738
This method calls `predict_shift_from_features` under the hood, so you only need to override that when subclassing."""
38-
assert self.base_outputs is not None, "Please call fit before compute_p_value"
39+
torchdrift.utils.check(
40+
self.base_outputs is not None, "Please call fit before compute_p_value"
41+
)
3942
_, p_value = self.predict_shift_from_features(
4043
self.base_outputs, inputs, compute_score=False, compute_p_value=True
4144
)
@@ -47,7 +50,9 @@ def forward(
4750
"""Performs a statistical test for drift and returns the score or, if `return_p_value` has been set in the constructor, the p-value.
4851
4952
This method calls `predict_shift_from_features` under the hood, so you only need to override that when subclassing."""
50-
assert self.base_outputs is not None, "Please call fit before predict_shift"
53+
torchdrift.utils.check(
54+
self.base_outputs is not None, "Please call fit before predict_shift"
55+
)
5156
ood_score, p_value = self.predict_shift_from_features(
5257
self.base_outputs,
5358
inputs,

torchdrift/detectors/ks.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy
55

66
from . import Detector
7+
import torchdrift.utils
78

89
try:
910
import numba
@@ -60,7 +61,7 @@ def ks_two_sample_multi_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
6061
"""
6162
n_x, n_features = x.shape
6263
n_y, n_features_y = y.shape
63-
assert n_features == n_features_y
64+
torchdrift.utils.check(n_features == n_features_y, "feature dimension mismatch")
6465

6566
joint_sorted = torch.argsort(torch.cat([x, y], dim=0), dim=0)
6667
sign = (joint_sorted < n_x).to(dtype=torch.float) * (1 / (n_x) + 1 / (n_y)) - (
@@ -92,9 +93,9 @@ def predict_shift_from_features(
9293
compute_p_value: bool,
9394
individual_samples: bool = False,
9495
):
95-
assert (
96-
not individual_samples
97-
), "Individual samples not supported by MMD detector"
96+
torchdrift.utils.check(
97+
not individual_samples, "Individual samples not supported by MMD detector"
98+
)
9899
ood_score = ks_two_sample_multi_dim(outputs, self.base_outputs)
99100
# Like failing loudly suggests to return the minimum p-value under the
100101
# label Bonferroni correction, this would correspond to the maximum score

torchdrift/detectors/mmd.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,26 @@
33
import torch
44

55
from . import Detector
6+
import torchdrift.utils
67

78

89
class Kernel:
910
"""Base class for kernels
1011
11-
Unless otherwise noted, all kernels implementing lengthscale detection
12-
use the median of pairwise distances as the lengthscale."""
12+
Unless otherwise noted, all kernels implementing lengthscale detection
13+
use the median of pairwise distances as the lengthscale."""
14+
1315
pass
1416

1517

1618
class GaussianKernel(Kernel):
1719
r"""Unnormalized gaussian kernel
1820
19-
.. math::
20-
k(|x-y|) = \exp(-|x-y|^2/(2\ell^2))
21+
.. math::
22+
k(|x-y|) = \exp(-|x-y|^2/(2\ell^2))
23+
24+
where :math:`\ell` is the `lengthscale` (autodetected or given)."""
2125

22-
where :math:`\ell` is the `lengthscale` (autodetected or given).
23-
"""
2426
def __init__(self, lengthscale=None):
2527
super().__init__()
2628
self.lengthscale = lengthscale
@@ -37,11 +39,10 @@ def __call__(self, dists):
3739
class ExpKernel(Kernel):
3840
r"""Unnormalized exponential kernel
3941
40-
.. math::
41-
k(|x-y|) = \exp(-|x-y|/\ell)
42+
.. math::
43+
k(|x-y|) = \exp(-|x-y|/\ell)
4244
43-
where :math:`\ell` is the `lengthscale` (autodetected or given).
44-
"""
45+
where :math:`\ell` is the `lengthscale` (autodetected or given)."""
4546

4647
def __init__(self, lengthscale=None):
4748
super().__init__()
@@ -58,11 +59,10 @@ def __call__(self, dists):
5859
class RationalQuadraticKernel(Kernel):
5960
r"""Unnormalized rational quadratic kernel
6061
61-
.. math::
62-
k(|x-y|) = (1+|x-y|^2/(2 \alpha \ell^2))^{-\alpha}
62+
.. math::
63+
k(|x-y|) = (1+|x-y|^2/(2 \alpha \ell^2))^{-\alpha}
6364
64-
where :math:`\ell` is the `lengthscale` (autodetected or given).
65-
"""
65+
where :math:`\ell` is the `lengthscale` (autodetected or given)."""
6666

6767
def __init__(self, lengthscale=None, alpha=1.0):
6868
super().__init__()
@@ -95,7 +95,7 @@ def kernel_mmd(x, y, n_perm=1000, kernel=GaussianKernel()):
9595

9696
n, d = x.shape
9797
m, d2 = y.shape
98-
assert d == d2
98+
torchdrift.utils.check(d == d2, "feature dimension mismatch")
9999
xy = torch.cat([x.detach(), y.detach()], dim=0)
100100
dists = torch.cdist(xy, xy, p=2.0)
101101
# we are a bit sloppy here as we just keep the diagonal and everything twice
@@ -161,9 +161,9 @@ def predict_shift_from_features(
161161
compute_p_value: bool,
162162
individual_samples: bool = False,
163163
):
164-
assert (
165-
not individual_samples
166-
), "Individual samples not supported by MMD detector"
164+
torchdrift.utils.check(
165+
not individual_samples, "Individual samples not supported by MMD detector"
166+
)
167167
if not compute_p_value:
168168
ood_score = kernel_mmd(
169169
outputs, base_outputs, n_perm=None, kernel=self.kernel

torchdrift/reducers/pca.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from . import Reducer
3+
import torchdrift.utils
34

45

56
class PCAReducer(Reducer):
@@ -21,7 +22,10 @@ def extra_repr(self) -> str:
2122

2223
def fit(self, x: torch.Tensor) -> torch.Tensor:
2324
batch, feat = x.shape
24-
assert min(batch, feat) >= self.n_components
25+
torchdrift.utils.check(
26+
min(batch, feat) >= self.n_components,
27+
"need number of samples and size of feature to be at least the number of components",
28+
)
2529
self.mean = x.mean(0, keepdim=True)
2630
x = x - self.mean
2731
u, s, v = x.svd()

torchdrift/utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
from .experiments import DriftDetectionExperiment
22
from .fit import fit
3+
4+
5+
def check(check, message):
6+
"""tests `check` and raises `RuntimeError` with `message` if false"""
7+
if not check:
8+
raise RuntimeError(message)

0 commit comments

Comments
 (0)