mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
tinytqdm.set_description and tinytrange (#5101)
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Tuple
|
|||||||
import time
|
import time
|
||||||
from tinygrad import Tensor, TinyJit, nn
|
from tinygrad import Tensor, TinyJit, nn
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from tqdm import trange
|
from tinygrad.helpers import trange
|
||||||
import numpy as np # TODO: remove numpy import
|
import numpy as np # TODO: remove numpy import
|
||||||
|
|
||||||
ENVIRONMENT_NAME = 'CartPole-v1'
|
ENVIRONMENT_NAME = 'CartPole-v1'
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||||
from typing import List, Callable
|
from typing import List, Callable
|
||||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
||||||
from tinygrad.helpers import getenv, colored
|
from tinygrad.helpers import getenv, colored, trange
|
||||||
from tinygrad.nn.datasets import mnist
|
from tinygrad.nn.datasets import mnist
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||||
from typing import List, Callable
|
from typing import List, Callable
|
||||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
|
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
|
||||||
from tinygrad.helpers import getenv, colored
|
from tinygrad.helpers import getenv, colored, trange
|
||||||
from extra.datasets import fetch_mnist
|
from extra.datasets import fetch_mnist
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
|
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import argparse
|
import argparse
|
||||||
from tqdm import trange
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
|
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
|
||||||
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored
|
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored, trange
|
||||||
from tinygrad.nn import Embedding, Linear, LayerNorm
|
from tinygrad.nn import Embedding, Linear, LayerNorm
|
||||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import functools, argparse, pathlib
|
import functools, argparse, pathlib
|
||||||
from tqdm import tqdm
|
|
||||||
from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
|
from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
|
||||||
from tinygrad.helpers import Timing, Profiling, CI
|
from tinygrad.helpers import Timing, Profiling, CI, tqdm
|
||||||
from tinygrad.nn.state import torch_load, get_state_dict
|
from tinygrad.nn.state import torch_load, get_state_dict
|
||||||
from extra.models.llama import FeedForward, Transformer
|
from extra.models.llama import FeedForward, Transformer
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import trange
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision.utils import make_grid, save_image
|
from torchvision.utils import make_grid, save_image
|
||||||
from tinygrad.nn.state import get_parameters
|
from tinygrad.nn.state import get_parameters
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import trange
|
||||||
from tinygrad.nn import optim
|
from tinygrad.nn import optim
|
||||||
from extra.datasets import fetch_mnist
|
from extra.datasets import fetch_mnist
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,8 @@ from collections import namedtuple
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
|
||||||
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
|
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
|
||||||
from tinygrad.helpers import Timing, Context, getenv, fetch, colored
|
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
|
||||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ import traceback
|
|||||||
import time
|
import time
|
||||||
from multiprocessing import Process, Queue
|
from multiprocessing import Process, Queue
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import trange
|
|
||||||
from tinygrad.nn.state import get_parameters
|
from tinygrad.nn.state import get_parameters
|
||||||
from tinygrad.nn import optim
|
from tinygrad.nn import optim
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv, trange
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from extra.datasets import fetch_cifar
|
from extra.datasets import fetch_cifar
|
||||||
from extra.models.efficientnet import EfficientNet
|
from extra.models.efficientnet import EfficientNet
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import pathlib, json
|
import pathlib, json
|
||||||
from tqdm import trange
|
from tinygrad.helpers import trange
|
||||||
from extra.datasets import fetch_mnist
|
from extra.datasets import fetch_mnist
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import trange
|
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import CI
|
from tinygrad.helpers import CI, trange
|
||||||
from tinygrad.engine.jit import TinyJit
|
from tinygrad.engine.jit import TinyJit
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
3
test/external/fuzz_shapetracker_math.py
vendored
3
test/external/fuzz_shapetracker_math.py
vendored
@@ -1,7 +1,6 @@
|
|||||||
import random
|
import random
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from tqdm import trange
|
from tinygrad.helpers import getenv, DEBUG, colored, trange
|
||||||
from tinygrad.helpers import getenv, DEBUG, colored
|
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from test.external.fuzz_shapetracker import shapetracker_ops
|
from test.external.fuzz_shapetracker import shapetracker_ops
|
||||||
from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad
|
from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import time, random, unittest
|
import time, random, unittest
|
||||||
from tqdm import tqdm
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from tinygrad.helpers import tqdm as tinytqdm
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from tqdm import tqdm
|
||||||
|
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
|
||||||
|
|
||||||
class TestProgressBar(unittest.TestCase):
|
class TestProgressBar(unittest.TestCase):
|
||||||
def _compare_bars(self, bar1, bar2, cmp_prog=False):
|
def _compare_bars(self, bar1, bar2, cmp_prog=False):
|
||||||
@@ -31,6 +31,7 @@ class TestProgressBar(unittest.TestCase):
|
|||||||
|
|
||||||
diff = sum([1 for c1, c2 in zip(prog1, prog2) if c1 == c2]) # allow 1 char diff (due to tqdm special chars)
|
diff = sum([1 for c1, c2 in zip(prog1, prog2) if c1 == c2]) # allow 1 char diff (due to tqdm special chars)
|
||||||
self.assertTrue(not cmp_prog or diff <= 1)
|
self.assertTrue(not cmp_prog or diff <= 1)
|
||||||
|
|
||||||
@patch('sys.stderr', new_callable=StringIO)
|
@patch('sys.stderr', new_callable=StringIO)
|
||||||
@patch('shutil.get_terminal_size')
|
@patch('shutil.get_terminal_size')
|
||||||
def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
|
def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
|
||||||
@@ -56,6 +57,31 @@ class TestProgressBar(unittest.TestCase):
|
|||||||
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
||||||
self._compare_bars(tinytqdm_output, tqdm_output)
|
self._compare_bars(tinytqdm_output, tqdm_output)
|
||||||
|
|
||||||
|
@patch('sys.stderr', new_callable=StringIO)
|
||||||
|
@patch('shutil.get_terminal_size')
|
||||||
|
def test_trange_output_iter(self, mock_terminal_size, mock_stderr):
|
||||||
|
for _ in range(5):
|
||||||
|
total, ncols = random.randint(5, 30), random.randint(80, 240)
|
||||||
|
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
|
||||||
|
mock_stderr.truncate(0)
|
||||||
|
|
||||||
|
# compare bars at each iteration (only when tinytqdm bar has been updated)
|
||||||
|
for n in (bar := tinytrange(total, desc="Test: ")):
|
||||||
|
time.sleep(0.01)
|
||||||
|
if bar.i % bar.skip != 0: continue
|
||||||
|
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
|
||||||
|
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
|
||||||
|
elapsed = n/iters_per_sec if n>0 else 0
|
||||||
|
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
||||||
|
self._compare_bars(tiny_output, tqdm_output)
|
||||||
|
|
||||||
|
# compare final bars
|
||||||
|
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
|
||||||
|
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
|
||||||
|
elapsed = total/iters_per_sec if n>0 else 0
|
||||||
|
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
||||||
|
self._compare_bars(tiny_output, tqdm_output)
|
||||||
|
|
||||||
@patch('sys.stderr', new_callable=StringIO)
|
@patch('sys.stderr', new_callable=StringIO)
|
||||||
@patch('shutil.get_terminal_size')
|
@patch('shutil.get_terminal_size')
|
||||||
def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):
|
def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):
|
||||||
|
|||||||
@@ -260,6 +260,7 @@ class tqdm:
|
|||||||
yield item
|
yield item
|
||||||
self.update(1)
|
self.update(1)
|
||||||
finally: self.update(close=True)
|
finally: self.update(close=True)
|
||||||
|
def set_description(self, desc:str): self.desc = desc
|
||||||
def update(self, n:int=0, close:bool=False):
|
def update(self, n:int=0, close:bool=False):
|
||||||
self.n, self.i = self.n+n, self.i+1
|
self.n, self.i = self.n+n, self.i+1
|
||||||
if (self.i % self.skip != 0 and not close) or self.dis: return
|
if (self.i % self.skip != 0 and not close) or self.dis: return
|
||||||
@@ -276,3 +277,6 @@ class tqdm:
|
|||||||
sz = max(term-5-len(suf)-len(self.desc), 1)
|
sz = max(term-5-len(suf)-len(self.desc), 1)
|
||||||
bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
|
bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
|
||||||
print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
|
print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
|
||||||
|
|
||||||
|
class trange(tqdm):
|
||||||
|
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
|
||||||
Reference in New Issue
Block a user