Skip to content

Commit d20f1c0

Browse files
committed
update docstrings
1 parent a8d9171 commit d20f1c0

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

torchdrift/reducers/reducer.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
import torch
22

33
class Reducer(torch.nn.Module):
4-
"""Base class for reducers"""
4+
"""Base class for reducers.
5+
6+
This is a `torch.nn.Module` with an additional `fit` method.
7+
The usual forward is for testing after fitting."""
58

69
def fit(self, x: torch.Tensor) -> torch.Tensor:
10+
"""Fits the reducer to reference data `x` and returns the reduced
11+
data.
12+
13+
Override this in your reducer implementation.
14+
"""
715
raise NotImplementedError("Override fit in subclass")
816

917
def forward(self, x: torch.Tensor) -> torch.Tensor:
18+
"""Reduces the input `x` (in testing) and returns the reduced data.
19+
20+
Override this in your reducer implementation.
21+
"""
1022
raise NotImplementedError("Override forward in subclass")

torchdrift/utils/fit.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,20 @@ def fit(
1010
dl: torch.utils.data.DataLoader,
1111
feature_extractor: torch.nn.Module,
1212
reducers_detectors: Union[Reducer, Detector, List[Union[Reducer, Detector]]],
13+
*,
1314
num_batches: Optional[int] = None,
1415
device: Union[torch.device, str, None] = None
1516
):
1617
"""Train drift detector on reference distribution.
18+
19+
The dataloader `dl` should provide the reference distribution. Optionally you can limit the number of batches sampled from the dataloader with `num_batches`.
20+
21+
The `feature extractor` can be any module be anything that does not need to be fit.
22+
23+
The reducers and detectors should be passed (in the order they should be applied, one takes the output from the previous) as a list. A single detector or reducer can also be passed.
24+
25+
If you provide a `device`, data is moved there before running through the
26+
feature extractor, otherwise the functions try to infer the device from the `feature_extractor`.
1727
"""
1828
if not isinstance(reducers_detectors, typing.Iterable):
1929
reducers_detectors = [reducers_detectors]

0 commit comments

Comments
 (0)