Files
tinygrad/test/unit/test_tqdm.py
chenyu 33211f356b fix desc in tqdm (#5107)
Per the tqdm docs (`https://tqdm.github.io/docs/tqdm/`), the user does not need to put `: ` in desc; the `: ` that follows desc is automatically removed when desc is empty.

updated test cases and added a test for set_description
2024-06-22 19:00:38 -04:00

188 lines
8.4 KiB
Python

import time, random, unittest
from unittest.mock import patch
from io import StringIO
from collections import namedtuple
from tqdm import tqdm
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
class TestProgressBar(unittest.TestCase):
def _compare_bars(self, bar1, bar2, cmp_prog=False):
prefix1, prog1, suffix1 = bar1.split("|")
prefix2, prog2, suffix2 = bar2.split("|")
self.assertEqual(len(bar1), len(bar2))
self.assertEqual(prefix1, prefix2)
def parse_timer(timer): return sum([int(timer.split(":")[0])*60, int(timer.split(":")[1])])
if "?" not in suffix1 and "?" not in suffix2:
# allow for few sec diff in timers (removes flakiness)
timer1, rm1 = [parse_timer(timer) for timer in suffix1.split("[")[-1].split(",")[0].split("<")]
timer2, rm2 = [parse_timer(timer) for timer in suffix2.split("[")[-1].split(",")[0].split("<")]
self.assertTrue(abs(timer1 - timer2) <= 5)
self.assertTrue(abs(rm1 - rm2) <= 5)
# get suffix without timers
suffix1 = suffix1.split("[")[0] + suffix1.split(",")[1]
suffix2 = suffix2.split("[")[0] + suffix2.split(",")[1]
self.assertEqual(suffix1, suffix2)
else:
self.assertEqual(suffix1, suffix2)
diff = sum([1 for c1, c2 in zip(prog1, prog2) if c1 == c2]) # allow 1 char diff (due to tqdm special chars)
self.assertTrue(not cmp_prog or diff <= 1)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
for _ in range(10):
total, ncols = random.randint(5, 30), random.randint(80, 240)
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n in (bar := tinytqdm(range(total), desc="Test")):
time.sleep(0.01)
if bar.i % bar.skip != 0: continue
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tinytqdm_output, tqdm_output)
# compare final bars
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = total/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tinytqdm_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_set_description(self, mock_terminal_size, mock_stderr):
for _ in range(10):
total, ncols = random.randint(5, 30), random.randint(80, 240)
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
expected_prefix = "Test"
# compare bars at each iteration (only when tinytqdm bar has been updated)
for i,n in enumerate(bar := tinytqdm(range(total), desc="Test")):
time.sleep(0.01)
if bar.i % bar.skip != 0: continue
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix=expected_prefix)
expected_prefix = desc = f"Test {i}" if i % 2 == 0 else ""
bar.set_description(desc)
self._compare_bars(tinytqdm_output, tqdm_output)
# compare final bars
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = total/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix=expected_prefix)
self._compare_bars(tinytqdm_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_trange_output_iter(self, mock_terminal_size, mock_stderr):
for _ in range(5):
total, ncols = random.randint(5, 30), random.randint(80, 240)
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n in (bar := tinytrange(total, desc="Test")):
time.sleep(0.01)
if bar.i % bar.skip != 0: continue
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tiny_output, tqdm_output)
# compare final bars
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = total/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tiny_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):
for _ in range(10):
total, ncols = random.randint(10000, 100000), random.randint(80, 120)
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
# compare bars at each iteration (only when tinytqdm bar has been updated)
bar = tinytqdm(total=total, desc="Test")
n = 0
while n < total:
time.sleep(0.01)
incr = (total // 10) + random.randint(0, 100)
if n + incr > total: incr = total - n
bar.update(incr, close=n+incr==total)
n += incr
if bar.i % bar.skip != 0: continue
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tinytqdm_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_tqdm_output_custom_0_total(self, mock_terminal_size, mock_stderr):
for _ in range(10):
total, ncols = random.randint(10000, 100000), random.randint(80, 120)
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
# compare bars at each iteration (only when tinytqdm bar has been updated)
bar = tinytqdm(total=0, desc="Test")
n = 0
while n < total:
time.sleep(0.01)
incr = (total // 10) + random.randint(0, 100)
if n + incr > total: incr = total - n
bar.update(incr, close=n+incr==total)
n += incr
if bar.i % bar.skip != 0: continue
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=0, elapsed=elapsed, ncols=ncols, prefix="Test")
self.assertEqual(tinytqdm_output, tqdm_output)
def test_tqdm_perf(self):
st = time.perf_counter()
for _ in tqdm(range(100)): time.sleep(0.01)
tqdm_time = time.perf_counter() - st
st = time.perf_counter()
for _ in tinytqdm(range(100)): time.sleep(0.01)
tinytqdm_time = time.perf_counter() - st
assert tinytqdm_time < 2.0 * tqdm_time
def test_tqdm_perf_high_iter(self):
st = time.perf_counter()
for _ in tqdm(range(10^7)): pass
tqdm_time = time.perf_counter() - st
st = time.perf_counter()
for _ in tinytqdm(range(10^7)): pass
tinytqdm_time = time.perf_counter() - st
assert tinytqdm_time < 5 * tqdm_time
if __name__ == '__main__':
  # Executed directly rather than imported: hand control to unittest's command-line runner.
  unittest.main()