|
4 | 4 | import torch.utils.data
|
5 | 5 |
|
6 | 6 |
|
| 7 | +class DummyIterableDataset(torch.utils.data.IterableDataset): |
| 8 | + def __init__(self, *args) -> None: |
| 9 | + super().__init__() |
| 10 | + self.args = args |
| 11 | + |
| 12 | + def __iter__(self): |
| 13 | + return iter(*self.args) |
| 14 | + |
| 15 | + |
class TensorDataModule:
    """Test helper that wraps tensors in a dataset and builds dataloaders.

    Args:
        *args: Tensors forwarded to the underlying dataset.
        ds_type: ``"map"`` wraps the tensors in a
            ``torch.utils.data.TensorDataset``; ``"iterable"`` wraps them in
            ``DummyIterableDataset``.

    Raises:
        ValueError: If ``ds_type`` is neither ``"map"`` nor ``"iterable"``.
    """

    def __init__(self, *args, ds_type="map"):
        # Reject unknown values up front: otherwise a typo such as "Map" would
        # silently select the iterable dataset and disable shuffling.
        if ds_type not in ("map", "iterable"):
            raise ValueError(
                f"ds_type must be 'map' or 'iterable', got {ds_type!r}"
            )
        self.ds_type = ds_type

        if ds_type == "map":
            self.ds = torch.utils.data.TensorDataset(*args)
        else:
            self.ds = DummyIterableDataset(*args)

    def default_dataloader(self, batch_size=None, num_samples=None, shuffle=True):
        """Build a ``DataLoader`` over the wrapped dataset.

        Args:
            batch_size: Batch size; when ``None``, ``self.val_batch_size`` is
                used instead.
            num_samples: Number of samples the random sampler draws. When
                given, sampling is done with replacement. Only used when a
                random sampler is created.
            shuffle: Whether to sample randomly. Only honoured for the
                ``"map"`` dataset type.

        Returns:
            A ``torch.utils.data.DataLoader`` over ``self.ds``.
        """
        if batch_size is None:
            # NOTE(review): val_batch_size is not assigned anywhere in the
            # visible class; presumably a subclass or caller sets it - confirm.
            batch_size = self.val_batch_size
        replacement = num_samples is not None
        # A sampler is only built for map-style data; iterable-style datasets
        # define their own iteration order, so they are read in stream order.
        if shuffle and self.ds_type == "map":
            sampler = torch.utils.data.RandomSampler(
                self.ds, replacement=replacement, num_samples=num_samples
            )
        else:
            sampler = None
        return torch.utils.data.DataLoader(
            self.ds, batch_size=batch_size, sampler=sampler
        )
|
25 | 38 |
|
26 | 39 |
|
27 |
| -def test_fit(): |
28 |
| - dm_ref = TensorDataModule(torch.randn(500, 5)) |
| 40 | +@pytest.mark.parametrize("ds_type", ("map", "iterable")) |
| 41 | +def test_fit(ds_type): |
| 42 | + dm_ref = TensorDataModule(torch.randn(500, 5), ds_type=ds_type) |
29 | 43 | d = torchdrift.detectors.KernelMMDDriftDetector()
|
30 | 44 | torchdrift.utils.fit(
|
31 | 45 | dm_ref.default_dataloader(batch_size=10),
|
|
0 commit comments