Skip to content

Commit 3d590e8

Browse files
authored
Merge pull request #16 from alessiamarcolini/iterable-dataset
Allow iterable dataset for detector fit
2 parents d828e41 + 03a40d3 commit 3d590e8

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

test/test_utils.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,48 @@
44
import torch.utils.data
55

66

7+
class DummyIterableDataset(torch.utils.data.IterableDataset):
8+
def __init__(self, *args) -> None:
9+
super().__init__()
10+
self.args = args
11+
12+
def __iter__(self):
13+
return iter(*self.args)
14+
15+
716
class TensorDataModule:
8-
def __init__(self, *args):
9-
self.ds = torch.utils.data.TensorDataset(*args)
17+
def __init__(self, *args, ds_type="map"):
18+
self.ds_type = ds_type
19+
20+
if ds_type == "map":
21+
self.ds = torch.utils.data.TensorDataset(*args)
22+
else:
23+
self.ds = DummyIterableDataset(*args)
1024

1125
def default_dataloader(self, batch_size=None, num_samples=None, shuffle=True):
12-
dataset = self.ds
1326
if batch_size is None:
1427
batch_size = self.val_batch_size
1528
replacement = num_samples is not None
16-
if shuffle:
29+
if shuffle and self.ds_type == "map":
1730
sampler = torch.utils.data.RandomSampler(
18-
dataset, replacement=replacement, num_samples=num_samples
31+
self.ds, replacement=replacement, num_samples=num_samples
1932
)
2033
else:
2134
sampler = None
2235
return torch.utils.data.DataLoader(
23-
dataset, batch_size=batch_size, sampler=sampler
36+
self.ds, batch_size=batch_size, sampler=sampler
2437
)
2538

26-
27-
def test_fit():
28-
dm_ref = TensorDataModule(torch.randn(500, 5))
39+
@pytest.mark.parametrize("num_batches", (3, None))
40+
@pytest.mark.parametrize("ds_type", ("map", "iterable"))
41+
def test_fit(ds_type, num_batches):
42+
dm_ref = TensorDataModule(torch.randn(500, 5), ds_type=ds_type)
2943
d = torchdrift.detectors.KernelMMDDriftDetector()
3044
torchdrift.utils.fit(
3145
dm_ref.default_dataloader(batch_size=10),
3246
torch.nn.Identity(),
3347
[torch.nn.Identity(), d],
34-
num_batches=3,
48+
num_batches=num_batches,
3549
device="cpu",
3650
)
3751

torchdrift/utils/fit.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,24 @@ def fit(
3333

3434
all_outputs = []
3535
# dl = torch.utils.data.DataLoader(ref_ds, batch_size=batch_size, shuffle=True)
36-
nb = len(dl)
37-
if num_batches is not None:
38-
nb = min(nb, num_batches)
39-
for i, b in tqdm.tqdm(zip(range(nb), dl), total=nb):
36+
37+
if hasattr(dl.dataset, "__len__"):
38+
nb = len(dl)
39+
if num_batches is not None:
40+
nb = min(nb, num_batches)
41+
total = nb
42+
else:
43+
total = None
44+
45+
for i, b in enumerate(tqdm.tqdm(dl, total=total)):
46+
if num_batches is not None and i >= num_batches:
47+
break
48+
4049
if not isinstance(b, torch.Tensor):
4150
b = b[0]
4251
with torch.no_grad():
4352
all_outputs.append(feature_extractor(b.to(device)))
53+
4454
all_outputs = torch.cat(all_outputs, dim=0)
4555

4656
for m in reducers_detectors:

0 commit comments

Comments
 (0)