diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 4a9fa15019..5eeaa24023 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -292,9 +292,9 @@ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(m # *** tqdm class tqdm: - def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int=0, rate:int=100): + def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100): self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate - self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, total or getattr(iterable,"__len__",lambda:0)() + self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total self.update(0) def __iter__(self): for item in self.iter: