From e356807696975e282d54b2e523656ea55e65416c Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 22 Jun 2024 14:45:06 -0400 Subject: [PATCH] tinytqdm.set_description and tinytrange (#5101) --- examples/beautiful_cartpole.py | 2 +- examples/beautiful_mnist.py | 3 +-- examples/beautiful_mnist_multigpu.py | 3 +-- examples/gpt2.py | 3 +-- examples/mixtral.py | 3 +-- examples/mnist_gan.py | 3 +-- examples/stable_diffusion.py | 3 +-- examples/train_efficientnet.py | 3 +-- extra/datasets/fake_imagenet_from_mnist.py | 2 +- extra/training.py | 3 +-- test/external/fuzz_shapetracker_math.py | 3 +-- test/unit/test_tqdm.py | 30 ++++++++++++++++++++-- tinygrad/helpers.py | 4 +++ 13 files changed, 43 insertions(+), 22 deletions(-) diff --git a/examples/beautiful_cartpole.py b/examples/beautiful_cartpole.py index abb6a08677..5ff6d3e7ee 100644 --- a/examples/beautiful_cartpole.py +++ b/examples/beautiful_cartpole.py @@ -2,7 +2,7 @@ from typing import Tuple import time from tinygrad import Tensor, TinyJit, nn import gymnasium as gym -from tqdm import trange +from tinygrad.helpers import trange import numpy as np # TODO: remove numpy import ENVIRONMENT_NAME = 'CartPole-v1' diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py index 62fb7aa9fc..94a8143271 100644 --- a/examples/beautiful_mnist.py +++ b/examples/beautiful_mnist.py @@ -1,9 +1,8 @@ # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 from typing import List, Callable from tinygrad import Tensor, TinyJit, nn, GlobalCounters -from tinygrad.helpers import getenv, colored +from tinygrad.helpers import getenv, colored, trange from tinygrad.nn.datasets import mnist -from tqdm import trange class Model: def __init__(self): diff --git a/examples/beautiful_mnist_multigpu.py b/examples/beautiful_mnist_multigpu.py index 6c5cb70dcc..e649a79577 100644 --- a/examples/beautiful_mnist_multigpu.py +++ b/examples/beautiful_mnist_multigpu.py @@ -1,9 +1,8 @@ # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 from typing import List, Callable from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device -from tinygrad.helpers import getenv, colored +from tinygrad.helpers import getenv, colored, trange from extra.datasets import fetch_mnist -from tqdm import trange GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))] diff --git a/examples/gpt2.py b/examples/gpt2.py index 8226d5f048..1e073f8b93 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -1,11 +1,10 @@ #!/usr/bin/env python3 from typing import Optional, Union import argparse -from tqdm import trange import numpy as np import tiktoken from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable -from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored +from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored, trange from tinygrad.nn import Embedding, Linear, LayerNorm from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict diff --git a/examples/mixtral.py b/examples/mixtral.py index e1a46b2880..f2627f5dff 100644 --- a/examples/mixtral.py +++ b/examples/mixtral.py @@ -1,7 +1,6 @@ import functools, argparse, pathlib -from tqdm import tqdm from tinygrad import Tensor, nn, Device, GlobalCounters, Variable -from tinygrad.helpers import Timing, Profiling, CI +from tinygrad.helpers import Timing, Profiling, CI, tqdm from tinygrad.nn.state import torch_load, get_state_dict from extra.models.llama import FeedForward, Transformer diff --git a/examples/mnist_gan.py b/examples/mnist_gan.py index 46adde04e4..75f39a42ae 100644 --- a/examples/mnist_gan.py +++ b/examples/mnist_gan.py @@ -1,11 +1,10 @@ from pathlib import Path import numpy as np -from tqdm import trange import torch from torchvision.utils import make_grid, save_image from tinygrad.nn.state import get_parameters from tinygrad.tensor import Tensor -from tinygrad.helpers import getenv +from tinygrad.helpers import trange from tinygrad.nn import optim from extra.datasets import fetch_mnist diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 15ae87913e..5b1df9a14f 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -8,9 +8,8 @@ from collections import namedtuple from PIL import Image import numpy as np -from tqdm import tqdm from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit -from tinygrad.helpers import Timing, Context, getenv, fetch, colored +from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict diff --git a/examples/train_efficientnet.py b/examples/train_efficientnet.py index e0bd2240b0..521c981182 100644 --- a/examples/train_efficientnet.py +++ b/examples/train_efficientnet.py @@ -2,10 +2,9 @@ import traceback import time from multiprocessing import Process, Queue import numpy as np -from tqdm import trange from tinygrad.nn.state import get_parameters from tinygrad.nn import optim -from tinygrad.helpers import getenv +from tinygrad.helpers import getenv, trange from tinygrad.tensor import Tensor from extra.datasets import fetch_cifar from extra.models.efficientnet import EfficientNet diff --git a/extra/datasets/fake_imagenet_from_mnist.py b/extra/datasets/fake_imagenet_from_mnist.py index 983ac38475..978bb86d15 100755 --- a/extra/datasets/fake_imagenet_from_mnist.py +++ b/extra/datasets/fake_imagenet_from_mnist.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import pathlib, json -from tqdm import trange +from tinygrad.helpers import trange from extra.datasets import fetch_mnist from PIL import Image import numpy as np diff --git a/extra/training.py b/extra/training.py index 9145830840..d134e68bf2 100644 --- a/extra/training.py +++ b/extra/training.py @@ -1,7 +1,6 @@ import numpy as np -from tqdm import trange from tinygrad.tensor import Tensor -from tinygrad.helpers import CI +from tinygrad.helpers import CI, trange from tinygrad.engine.jit import TinyJit diff --git a/test/external/fuzz_shapetracker_math.py b/test/external/fuzz_shapetracker_math.py index 35665a872b..663259a743 100644 --- a/test/external/fuzz_shapetracker_math.py +++ b/test/external/fuzz_shapetracker_math.py @@ -1,7 +1,6 @@ import random from typing import Tuple -from tqdm import trange -from tinygrad.helpers import getenv, DEBUG, colored +from tinygrad.helpers import getenv, DEBUG, colored, trange from tinygrad.shape.shapetracker import ShapeTracker from test.external.fuzz_shapetracker import shapetracker_ops from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad diff --git a/test/unit/test_tqdm.py b/test/unit/test_tqdm.py index 93e199fdb0..2e3a8efcdb 100644 --- a/test/unit/test_tqdm.py +++ b/test/unit/test_tqdm.py @@ -1,9 +1,9 @@ import time, random, unittest -from tqdm import tqdm from unittest.mock import patch from io import StringIO -from tinygrad.helpers import tqdm as tinytqdm from collections import namedtuple +from tqdm import tqdm +from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange class TestProgressBar(unittest.TestCase): def _compare_bars(self, bar1, bar2, cmp_prog=False): @@ -31,6 +31,7 @@ class TestProgressBar(unittest.TestCase): diff = sum([1 for c1, c2 in zip(prog1, prog2) if c1 == c2]) # allow 1 char diff (due to tqdm special chars) self.assertTrue(not cmp_prog or diff <= 1) + @patch('sys.stderr', new_callable=StringIO) @patch('shutil.get_terminal_size') def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr): @@ -56,6 +57,31 @@ class TestProgressBar(unittest.TestCase): tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test") self._compare_bars(tinytqdm_output, tqdm_output) + @patch('sys.stderr', new_callable=StringIO) + @patch('shutil.get_terminal_size') + def test_trange_output_iter(self, mock_terminal_size, mock_stderr): + for _ in range(5): + total, ncols = random.randint(5, 30), random.randint(80, 240) + mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols) + mock_stderr.truncate(0) + + # compare bars at each iteration (only when tinytqdm bar has been updated) + for n in (bar := tinytrange(total, desc="Test: ")): + time.sleep(0.01) + if bar.i % bar.skip != 0: continue + tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip() + iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0 + elapsed = n/iters_per_sec if n>0 else 0 + tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test") + self._compare_bars(tiny_output, tqdm_output) + + # compare final bars + tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip() + iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0 + elapsed = total/iters_per_sec if n>0 else 0 + tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test") + self._compare_bars(tiny_output, tqdm_output) + @patch('sys.stderr', new_callable=StringIO) @patch('shutil.get_terminal_size') def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr): diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 7ca874b512..01d56e9786 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -260,6 +260,7 @@ class tqdm: yield item self.update(1) finally: self.update(close=True) + def set_description(self, desc:str): self.desc = desc def update(self, n:int=0, close:bool=False): self.n, self.i = self.n+n, self.i+1 if (self.i % self.skip != 0 and not close) or self.dis: return @@ -276,3 +277,6 @@ class tqdm: sz = max(term-5-len(suf)-len(self.desc), 1) bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}' print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr) + +class trange(tqdm): + def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs) \ No newline at end of file