Skip to content

Commit a8d9171

Browse files
committed
fit reducers
1 parent c39cd65 commit a8d9171

File tree

11 files changed

+137
-91
lines changed

11 files changed

+137
-91
lines changed

notebooks/drift_detection_on_images.ipynb

Lines changed: 61 additions & 52 deletions
Large diffs are not rendered by default.

torchdrift/detectors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .detector import DriftDetector
1+
from .detector import Detector
22
from .mmd import kernel_mmd, KernelMMDDriftDetector
33
from .ks import ks_two_sample_multi_dim, KSDriftDetector, ks_p_value

torchdrift/detectors/detector.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
from typing import Optional, Callable
2-
import tqdm
3-
41
import torch
52

6-
class DriftDetector(torch.nn.Module):
3+
class Detector(torch.nn.Module):
74
"""Detector class.
85
96
The detector is a `nn.Module` subclass that, after fitting, performs a drift test when called and returns a score or p-value.
@@ -16,28 +13,10 @@ def __init__(self, *, return_p_value: bool=False):
1613
self.register_buffer('base_outputs', None)
1714
self.return_p_value = return_p_value
1815

19-
def fit(
20-
self,
21-
ref_ds: torch.utils.data.Dataset,
22-
feature_extractor: torch.nn.Module,
23-
batch_size: int = 32,
24-
num_batches: Optional[int] = None,
25-
):
26-
"""Train drift detector on reference distribution.
27-
"""
28-
29-
feature_extractor.eval() # careful about test time dropout
30-
device = next(feature_extractor.parameters()).device
31-
all_outputs = []
32-
dl = torch.utils.data.DataLoader(ref_ds, batch_size=batch_size, shuffle=True)
33-
nb = len(dl)
34-
if num_batches is not None:
35-
nb = min(nb, num_batches)
36-
for i, (b, _) in tqdm.tqdm(zip(range(nb), dl), total=nb):
37-
with torch.no_grad():
38-
all_outputs.append(feature_extractor(b.to(device)))
39-
all_outputs = torch.cat(all_outputs, dim=0)
40-
self.base_outputs = all_outputs
16+
def fit(self, x: torch.Tensor):
17+
"""Record a sample as the reference distribution"""
18+
self.base_outputs = x.detach()
19+
return x
4120

4221
def predict_shift_from_features(self, base_outputs: torch.Tensor, outputs: torch.Tensor, compute_score: bool, compute_p_value: bool, individual_samples: bool = False) -> torch.Tensor:
4322
"""stub to be overridden by subclasses"""
@@ -65,3 +44,4 @@ def forward(
6544
if self.return_p_value:
6645
return p_value
6746
return ood_score
47+

torchdrift/detectors/ks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
import numpy
55

6-
from . import DriftDetector
6+
from . import Detector
77

88
try:
99
import numba
@@ -64,7 +64,7 @@ def ks_two_sample_multi_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
6464
ks_scores = sign.cumsum(0).abs().max(0).values
6565
return ks_scores
6666

67-
class KSDriftDetector(DriftDetector):
67+
class KSDriftDetector(Detector):
6868
"""Drift detector based on (multiple) Kolmogorov-Smirnov tests.
6969
7070
This detector uses the Kolmogorov-Smirnov test on the marginals of the features

torchdrift/detectors/mmd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from . import DriftDetector
5+
from . import Detector
66

77
def kernel_mmd(x, y, n_perm=1000):
88
"""Implements the kernel MMD two-sample test.
@@ -59,7 +59,7 @@ def kernel_mmd(x, y, n_perm=1000):
5959
return mmd, p_val
6060

6161

62-
class KernelMMDDriftDetector(DriftDetector):
62+
class KernelMMDDriftDetector(Detector):
6363
"""Drift detector based on the kernel Maximum Mean Discrepancy (MMD) test.
6464
6565
This is modelled after the MMD drift detection in

torchdrift/reducers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .reducer import Reducer
12
from .pca import PCAReducer

torchdrift/reducers/pca.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
2+
from . import Reducer
23

3-
class PCAReducer(torch.nn.Module):
4+
class PCAReducer(Reducer):
45
"""Reduce dimensions using PCA.
56
67
This nn.Module subclass reduces the dimensions of the inputs
@@ -13,14 +14,20 @@ def __init__(self, n_components:int = 2):
1314
super().__init__()
1415
self.n_components = n_components
1516

16-
def extra_repr(self):
17+
def extra_repr(self) -> str:
1718
return f'n_components={self.n_components}'
1819

19-
def forward(self, x: torch.Tensor):
20+
def fit(self, x: torch.Tensor) -> torch.Tensor:
2021
batch, feat = x.shape
2122
assert min(batch, feat) >= self.n_components
22-
x = x - x.mean(1, keepdim=True)
23+
self.mean = x.mean(0, keepdim=True)
24+
x = x - self.mean
2325
u, s, v = x.svd()
24-
comp = v[:, :self.n_components]
25-
reduced = x @ comp
26+
self.comp = v[:, :self.n_components]
27+
reduced = x @ self.comp
28+
return reduced
29+
30+
def forward(self, x: torch.Tensor) -> torch.Tensor:
31+
x = x - self.mean
32+
reduced = x @ self.comp
2633
return reduced

torchdrift/reducers/reducer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
3+
class Reducer(torch.nn.Module):
4+
"""Base class for reducers"""
5+
6+
def fit(self, x: torch.Tensor) -> torch.Tensor:
7+
raise NotImplementedError("Override fit in subclass")
8+
9+
def forward(self, x: torch.Tensor) -> torch.Tensor:
10+
raise NotImplementedError("Override forward in subclass")

torchdrift/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .experiments import DriftDetectionExperiment
2+
from .fit import fit

torchdrift/utils/experiments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import tqdm
3-
3+
from .fit import fit
44

55
class DriftDetectionExperiment:
66
"""An experimental setup to explore the ROC of drift detection setups
@@ -23,7 +23,7 @@ def __init__(self, drift_detector, feature_extractor, ood_ratio=1.0, sample_size
2323
# def extra_loss(self, ...): add components to loss from training the detector
2424
def post_training(self, train_dataloader):
2525
"Called after training the main model, fits the drift detector."
26-
self.drift_detector.fit(train_dataloader, self.feature_extractor)
26+
fit(train_dataloader, self.feature_extractor, self.drift_detector)
2727

2828
def evaluate(self, ind_datamodule, ood_datamodule, num_runs=50):
2929
"""runs the experiment (`num_runs` inputs)

0 commit comments

Comments
 (0)