mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix tqdm unit_scale and support hours in time (#5227)
* fix tqdm unit_scale and support hours in time previously it only supports MM:SS. more chars to unitscales, strip trailing "." and " " in formatting, and more tests * simpler
This commit is contained in:
@@ -4,6 +4,7 @@ from io import StringIO
|
||||
from collections import namedtuple
|
||||
from tqdm import tqdm
|
||||
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
|
||||
import numpy as np
|
||||
|
||||
class TestProgressBar(unittest.TestCase):
|
||||
def _compare_bars(self, bar1, bar2, cmp_prog=False):
|
||||
@@ -13,14 +14,14 @@ class TestProgressBar(unittest.TestCase):
|
||||
self.assertEqual(len(bar1), len(bar2))
|
||||
self.assertEqual(prefix1, prefix2)
|
||||
|
||||
def parse_timer(timer): return sum([int(timer.split(":")[0])*60, int(timer.split(":")[1])])
|
||||
def parse_timer(timer): return sum(int(x) * y for x, y in zip(timer.split(':')[::-1], (1, 60, 3600)))
|
||||
|
||||
if "?" not in suffix1 and "?" not in suffix2:
|
||||
# allow for few sec diff in timers (removes flakiness)
|
||||
timer1, rm1 = [parse_timer(timer) for timer in suffix1.split("[")[-1].split(",")[0].split("<")]
|
||||
timer2, rm2 = [parse_timer(timer) for timer in suffix2.split("[")[-1].split(",")[0].split("<")]
|
||||
self.assertTrue(abs(timer1 - timer2) <= 5)
|
||||
self.assertTrue(abs(rm1 - rm2) <= 5)
|
||||
np.testing.assert_allclose(timer1, timer2, atol=5, rtol=1e-2)
|
||||
np.testing.assert_allclose(rm1, rm2, atol=5, rtol=1e-2)
|
||||
|
||||
# get suffix without timers
|
||||
suffix1 = suffix1.split("[")[0] + suffix1.split(",")[1]
|
||||
@@ -57,6 +58,30 @@ class TestProgressBar(unittest.TestCase):
|
||||
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
||||
self._compare_bars(tinytqdm_output, tqdm_output)
|
||||
|
||||
@patch('sys.stderr', new_callable=StringIO)
|
||||
@patch('shutil.get_terminal_size')
|
||||
def test_unit_scale(self, mock_terminal_size, mock_stderr):
|
||||
for unit_scale in [True, False]:
|
||||
# NOTE: numpy comparison raises TypeError if exponent > 22
|
||||
for exponent in range(1, 22, 3):
|
||||
low, high = 10 ** exponent, 10 ** (exponent+1)
|
||||
for _ in range(3):
|
||||
total, ncols = random.randint(low, high), random.randint(80, 240)
|
||||
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
|
||||
mock_stderr.truncate(0)
|
||||
|
||||
# compare bars at each iteration (only when tinytqdm bar has been updated)
|
||||
for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale):
|
||||
time.sleep(0.01)
|
||||
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
|
||||
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
|
||||
elapsed = n/iters_per_sec if n>0 else 0
|
||||
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
|
||||
# print(f"tiny: {tinytqdm_output}")
|
||||
# print(f"tqdm: {tqdm_output}")
|
||||
self._compare_bars(tinytqdm_output, tqdm_output)
|
||||
if n > 3: break
|
||||
|
||||
@patch('sys.stderr', new_callable=StringIO)
|
||||
@patch('shutil.get_terminal_size')
|
||||
def test_set_description(self, mock_terminal_size, mock_stderr):
|
||||
|
||||
@@ -294,9 +294,9 @@ class tqdm:
|
||||
if (self.i % self.skip != 0 and not close) or self.dis: return
|
||||
prog, dur, term = self.n/self.t if self.t else -1, time.perf_counter()-self.st, shutil.get_terminal_size().columns
|
||||
if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1) if self.i else 1
|
||||
def fmt(t): return ':'.join([f'{x:02d}' for x in divmod(int(t), 60)]) if t!=-1 else '?'
|
||||
def fmt(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([t//3600,t%3600//60,t%60]) if i or x) if (t:=int(t)) != -1 else '?'
|
||||
def scl(x): return x/1000**int(math.log(x,1000))
|
||||
def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4]+(' kMGTP'[int(math.log(x,1000))]) if x else '0.00')
|
||||
def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(math.log(x,1000))].strip(' ')) if x else '0.00'
|
||||
if self.t: unit_text = f"{fn(self.n)}/{fn(self.t)}" if self.unit_scale else f"{self.n}/{self.t}"
|
||||
else: unit_text = f"{fn(self.n)}{self.unit}" if self.unit_scale else f"{self.n}{self.unit}"
|
||||
it_text = (f"{fn(self.n/dur)}" if self.unit_scale else f"{self.n/dur:5.2f}") if self.n else "?"
|
||||
|
||||
Reference in New Issue
Block a user