Skip to content

Commit 03a40d3

Browse files
Explicitly check num_batches is not None
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
1 parent 0fee002 commit 03a40d3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchdrift/utils/fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def fit(
4343
total = None
4444

4545
for i, b in enumerate(tqdm.tqdm(dl, total=total)):
46-
if num_batches and i >= num_batches:
46+
if num_batches is not None and i >= num_batches:
4747
break
4848

4949
if not isinstance(b, torch.Tensor):

0 commit comments

Comments
 (0)