From 416f838a1a7230d84dd6d2bb2b023cf360e939b2 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 11 Jul 2024 10:30:12 -0400 Subject: [PATCH] hotfix tqdm respects total=0 if set (#5380) if you insist total=0, it should use 0 instead of inferring from iterable. matched tqdm --- tinygrad/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 4a9fa15019..5eeaa24023 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -292,9 +292,9 @@ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(m # *** tqdm class tqdm: - def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int=0, rate:int=100): + def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100): self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate - self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, total or getattr(iterable,"__len__",lambda:0)() + self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total self.update(0) def __iter__(self): for item in self.iter: