tinytqdm.set_description and tinytrange (#5101)

This commit is contained in:
chenyu
2024-06-22 14:45:06 -04:00
committed by GitHub
parent 8080298739
commit e356807696
13 changed files with 43 additions and 22 deletions

View File

@@ -2,7 +2,7 @@ from typing import Tuple
import time import time
from tinygrad import Tensor, TinyJit, nn from tinygrad import Tensor, TinyJit, nn
import gymnasium as gym import gymnasium as gym
from tqdm import trange from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import import numpy as np # TODO: remove numpy import
ENVIRONMENT_NAME = 'CartPole-v1' ENVIRONMENT_NAME = 'CartPole-v1'

View File

@@ -1,9 +1,8 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from tinygrad.helpers import getenv, colored from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
from tqdm import trange
class Model: class Model:
def __init__(self): def __init__(self):

View File

@@ -1,9 +1,8 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
from tinygrad.helpers import getenv, colored from tinygrad.helpers import getenv, colored, trange
from extra.datasets import fetch_mnist from extra.datasets import fetch_mnist
from tqdm import trange
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))] GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]

View File

@@ -1,11 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from typing import Optional, Union from typing import Optional, Union
import argparse import argparse
from tqdm import trange
import numpy as np import numpy as np
import tiktoken import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored, trange
from tinygrad.nn import Embedding, Linear, LayerNorm from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict

View File

@@ -1,7 +1,6 @@
import functools, argparse, pathlib import functools, argparse, pathlib
from tqdm import tqdm
from tinygrad import Tensor, nn, Device, GlobalCounters, Variable from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
from tinygrad.helpers import Timing, Profiling, CI from tinygrad.helpers import Timing, Profiling, CI, tqdm
from tinygrad.nn.state import torch_load, get_state_dict from tinygrad.nn.state import torch_load, get_state_dict
from extra.models.llama import FeedForward, Transformer from extra.models.llama import FeedForward, Transformer

View File

@@ -1,11 +1,10 @@
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from tqdm import trange
import torch import torch
from torchvision.utils import make_grid, save_image from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv from tinygrad.helpers import trange
from tinygrad.nn import optim from tinygrad.nn import optim
from extra.datasets import fetch_mnist from extra.datasets import fetch_mnist

View File

@@ -8,9 +8,8 @@ from collections import namedtuple
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from tqdm import tqdm
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict

View File

@@ -2,10 +2,9 @@ import traceback
import time import time
from multiprocessing import Process, Queue from multiprocessing import Process, Queue
import numpy as np import numpy as np
from tqdm import trange
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim from tinygrad.nn import optim
from tinygrad.helpers import getenv from tinygrad.helpers import getenv, trange
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from extra.datasets import fetch_cifar from extra.datasets import fetch_cifar
from extra.models.efficientnet import EfficientNet from extra.models.efficientnet import EfficientNet

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import pathlib, json import pathlib, json
from tqdm import trange from tinygrad.helpers import trange
from extra.datasets import fetch_mnist from extra.datasets import fetch_mnist
from PIL import Image from PIL import Image
import numpy as np import numpy as np

View File

@@ -1,7 +1,6 @@
import numpy as np import numpy as np
from tqdm import trange
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import CI from tinygrad.helpers import CI, trange
from tinygrad.engine.jit import TinyJit from tinygrad.engine.jit import TinyJit

View File

@@ -1,7 +1,6 @@
import random import random
from typing import Tuple from typing import Tuple
from tqdm import trange from tinygrad.helpers import getenv, DEBUG, colored, trange
from tinygrad.helpers import getenv, DEBUG, colored
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from test.external.fuzz_shapetracker import shapetracker_ops from test.external.fuzz_shapetracker import shapetracker_ops
from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad

View File

@@ -1,9 +1,9 @@
import time, random, unittest import time, random, unittest
from tqdm import tqdm
from unittest.mock import patch from unittest.mock import patch
from io import StringIO from io import StringIO
from tinygrad.helpers import tqdm as tinytqdm
from collections import namedtuple from collections import namedtuple
from tqdm import tqdm
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
class TestProgressBar(unittest.TestCase): class TestProgressBar(unittest.TestCase):
def _compare_bars(self, bar1, bar2, cmp_prog=False): def _compare_bars(self, bar1, bar2, cmp_prog=False):
@@ -31,6 +31,7 @@ class TestProgressBar(unittest.TestCase):
diff = sum([1 for c1, c2 in zip(prog1, prog2) if c1 == c2]) # allow 1 char diff (due to tqdm special chars) diff = sum([1 for c1, c2 in zip(prog1, prog2) if c1 == c2]) # allow 1 char diff (due to tqdm special chars)
self.assertTrue(not cmp_prog or diff <= 1) self.assertTrue(not cmp_prog or diff <= 1)
@patch('sys.stderr', new_callable=StringIO) @patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size') @patch('shutil.get_terminal_size')
def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr): def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
@@ -56,6 +57,31 @@ class TestProgressBar(unittest.TestCase):
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test") tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tinytqdm_output, tqdm_output) self._compare_bars(tinytqdm_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_trange_output_iter(self, mock_terminal_size, mock_stderr):
for _ in range(5):
total, ncols = random.randint(5, 30), random.randint(80, 240)
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n in (bar := tinytrange(total, desc="Test: ")):
time.sleep(0.01)
if bar.i % bar.skip != 0: continue
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tiny_output, tqdm_output)
# compare final bars
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = total/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tiny_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO) @patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size') @patch('shutil.get_terminal_size')
def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr): def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):

View File

@@ -260,6 +260,7 @@ class tqdm:
yield item yield item
self.update(1) self.update(1)
finally: self.update(close=True) finally: self.update(close=True)
def set_description(self, desc:str): self.desc = desc
def update(self, n:int=0, close:bool=False): def update(self, n:int=0, close:bool=False):
self.n, self.i = self.n+n, self.i+1 self.n, self.i = self.n+n, self.i+1
if (self.i % self.skip != 0 and not close) or self.dis: return if (self.i % self.skip != 0 and not close) or self.dis: return
@@ -276,3 +277,6 @@ class tqdm:
sz = max(term-5-len(suf)-len(self.desc), 1) sz = max(term-5-len(suf)-len(self.desc), 1)
bar = f'\r{self.desc}{round(100*prog):3}%|{""*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}' bar = f'\r{self.desc}{round(100*prog):3}%|{""*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr) print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
class trange(tqdm):
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)