fix tqdm unit_scale and support hours in time (#5227)

* fix tqdm unit_scale and support hours in time

previously it only supports MM:SS.
more chars to unitscales, strip trailing "." and " " in formatting, and more tests

* simpler
This commit is contained in:
chenyu
2024-06-29 14:48:51 -04:00
committed by GitHub
parent f374fb77af
commit b2ea610df8
2 changed files with 30 additions and 5 deletions

View File

@@ -4,6 +4,7 @@ from io import StringIO
from collections import namedtuple
from tqdm import tqdm
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
import numpy as np
class TestProgressBar(unittest.TestCase):
def _compare_bars(self, bar1, bar2, cmp_prog=False):
@@ -13,14 +14,14 @@ class TestProgressBar(unittest.TestCase):
self.assertEqual(len(bar1), len(bar2))
self.assertEqual(prefix1, prefix2)
def parse_timer(timer): return sum([int(timer.split(":")[0])*60, int(timer.split(":")[1])])
def parse_timer(timer): return sum(int(x) * y for x, y in zip(timer.split(':')[::-1], (1, 60, 3600)))
if "?" not in suffix1 and "?" not in suffix2:
# allow for few sec diff in timers (removes flakiness)
timer1, rm1 = [parse_timer(timer) for timer in suffix1.split("[")[-1].split(",")[0].split("<")]
timer2, rm2 = [parse_timer(timer) for timer in suffix2.split("[")[-1].split(",")[0].split("<")]
self.assertTrue(abs(timer1 - timer2) <= 5)
self.assertTrue(abs(rm1 - rm2) <= 5)
np.testing.assert_allclose(timer1, timer2, atol=5, rtol=1e-2)
np.testing.assert_allclose(rm1, rm2, atol=5, rtol=1e-2)
# get suffix without timers
suffix1 = suffix1.split("[")[0] + suffix1.split(",")[1]
@@ -57,6 +58,30 @@ class TestProgressBar(unittest.TestCase):
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tinytqdm_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_unit_scale(self, mock_terminal_size, mock_stderr):
for unit_scale in [True, False]:
# NOTE: numpy comparison raises TypeError if exponent > 22
for exponent in range(1, 22, 3):
low, high = 10 ** exponent, 10 ** (exponent+1)
for _ in range(3):
total, ncols = random.randint(low, high), random.randint(80, 240)
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale):
time.sleep(0.01)
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
# print(f"tiny: {tinytqdm_output}")
# print(f"tqdm: {tqdm_output}")
self._compare_bars(tinytqdm_output, tqdm_output)
if n > 3: break
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_set_description(self, mock_terminal_size, mock_stderr):

View File

@@ -294,9 +294,9 @@ class tqdm:
if (self.i % self.skip != 0 and not close) or self.dis: return
prog, dur, term = self.n/self.t if self.t else -1, time.perf_counter()-self.st, shutil.get_terminal_size().columns
if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1) if self.i else 1
def fmt(t): return ':'.join([f'{x:02d}' for x in divmod(int(t), 60)]) if t!=-1 else '?'
def fmt(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([t//3600,t%3600//60,t%60]) if i or x) if (t:=int(t)) != -1 else '?'
def scl(x): return x/1000**int(math.log(x,1000))
def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4]+(' kMGTP'[int(math.log(x,1000))]) if x else '0.00')
def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(math.log(x,1000))].strip(' ')) if x else '0.00'
if self.t: unit_text = f"{fn(self.n)}/{fn(self.t)}" if self.unit_scale else f"{self.n}/{self.t}"
else: unit_text = f"{fn(self.n)}{self.unit}" if self.unit_scale else f"{self.n}{self.unit}"
it_text = (f"{fn(self.n/dur)}" if self.unit_scale else f"{self.n/dur:5.2f}") if self.n else "?"