Skip to content

Commit 0fee002

Browse files
author
alessiamarcolini
committed
Allow num_batches to be None
1 parent 2248181 commit 0fee002

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

test/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,16 @@ def default_dataloader(self, batch_size=None, num_samples=None, shuffle=True):
3636
self.ds, batch_size=batch_size, sampler=sampler
3737
)
3838

39-
39+
@pytest.mark.parametrize("num_batches", (3, None))
4040
@pytest.mark.parametrize("ds_type", ("map", "iterable"))
41-
def test_fit(ds_type):
41+
def test_fit(ds_type, num_batches):
4242
dm_ref = TensorDataModule(torch.randn(500, 5), ds_type=ds_type)
4343
d = torchdrift.detectors.KernelMMDDriftDetector()
4444
torchdrift.utils.fit(
4545
dm_ref.default_dataloader(batch_size=10),
4646
torch.nn.Identity(),
4747
[torch.nn.Identity(), d],
48-
num_batches=3,
48+
num_batches=num_batches,
4949
device="cpu",
5050
)
5151

torchdrift/utils/fit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def fit(
4343
total = None
4444

4545
for i, b in enumerate(tqdm.tqdm(dl, total=total)):
46-
if i >= num_batches:
46+
if num_batches and i >= num_batches:
4747
break
48-
48+
4949
if not isinstance(b, torch.Tensor):
5050
b = b[0]
5151
with torch.no_grad():

0 commit comments

Comments
 (0)