Files
tinygrad/test/test_tqdm.py
SnakeOnex b1db2d0094 tqdm replacement (#4846)
* tqdm replacement (almost complete)

* formatting

* formatting

* imports

* line len

* fix

* removed set description :(

* removed set description :(

* fix

* fix

* green check?

* rewrote as class, fixed several bugs

* types spacing

* removed imports

* fix

* iterable

* typing

* mypy disagreement

* imports

* more e2e tests vs tqdm

* removed seed setting

* robustness against time.sleep() flakiness

* flaky fix

* automatic bar closing when count==total

* cleanup

* clang error with tqdm

* tqdm back

* use os lib, print to stderr (fixes the clang bug where the bar was leaking into the generated C program)

* back to shutil

* unit_scale + unit_scale test

* custom unit to tests

* pretty

* clean

* removed flaky test

* less test iters

* empty line

* remove disable
2024-06-09 23:46:03 +02:00

109 lines
4.5 KiB
Python

import time, random, unittest
from tqdm import tqdm
from unittest.mock import patch
from io import StringIO
from tinygrad.helpers import tinytqdm
from collections import namedtuple
class TestProgressBar(unittest.TestCase):
  """Checks that tinytqdm draws the same bar as tqdm.format_meter and is not much slower than tqdm."""

  @staticmethod
  def _parse_timer(timer):
    """Convert a tqdm timer string such as "01:05" or "1:01:05" into a number of seconds."""
    seconds = 0
    for part in timer.split(":"): seconds = seconds * 60 + int(part)
    return seconds

  @staticmethod
  def _reset_stream(stream):
    """Empty a StringIO buffer so the next test round starts from a clean stream.

    seek(0) is required: truncate() does not move the stream position, so without it the next write
    would land at the old offset and pad the buffer with NUL characters.
    """
    stream.seek(0)
    stream.truncate(0)

  @staticmethod
  def _last_bar(stream):
    """Return the text written to `stream` after the last carriage return, i.e. the latest redraw."""
    return stream.getvalue().split("\r")[-1].rstrip()

  def _compare_bars(self, bar1, bar2, cmp_prog=False):
    """Assert that two rendered bars ("prefix|progress|suffix") match.

    The total lengths and prefixes must be identical. When neither suffix contains "?", the
    elapsed/remaining timers may differ by up to 5 seconds (wall-clock jitter), and the rest of the
    suffix must match exactly. When `cmp_prog` is true, the progress sections may differ in at
    most one character position.
    """
    prefix1, prog1, suffix1 = bar1.split("|")
    prefix2, prog2, suffix2 = bar2.split("|")
    self.assertEqual(len(bar1), len(bar2))
    self.assertEqual(prefix1, prefix2)
    if "?" not in suffix1 and "?" not in suffix2:
      # the part after "[" starts with "elapsed<remaining," -- compare the timers with a tolerance
      # (removes flakiness) and everything else in the suffix exactly
      head1, _, stats1 = suffix1.partition("[")
      head2, _, stats2 = suffix2.partition("[")
      timers1, _, rest1 = stats1.partition(",")
      timers2, _, rest2 = stats2.partition(",")
      elapsed1, remaining1 = (self._parse_timer(t) for t in timers1.split("<"))
      elapsed2, remaining2 = (self._parse_timer(t) for t in timers2.split("<"))
      self.assertLessEqual(abs(elapsed1 - elapsed2), 5)
      self.assertLessEqual(abs(remaining1 - remaining2), 5)
      self.assertEqual(head1 + rest1, head2 + rest2)
    else:
      self.assertEqual(suffix1, suffix2)
    # allow 1 char diff (due to tqdm special chars); count MISMATCHING positions, not matching ones
    mismatches = sum(c1 != c2 for c1, c2 in zip(prog1, prog2))
    if cmp_prog: self.assertLessEqual(mismatches, 1)

  def _assert_matches_tqdm(self, stream, n, total, ncols):
    """Compare the latest tinytqdm bar in `stream` with tqdm.format_meter for the same n/total/ncols.

    format_meter needs an elapsed time, so it is reconstructed from the it/s rate tinytqdm printed.
    """
    tinytqdm_output = self._last_bar(stream)
    iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n > 0 else 0
    # guard against a printed rate of 0 (would otherwise raise ZeroDivisionError)
    elapsed = n / iters_per_sec if iters_per_sec else 0
    tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
    self._compare_bars(tinytqdm_output, tqdm_output)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_output_iter_e2e(self, mock_terminal_size, mock_stderr):
    terminal_size = namedtuple(field_names='columns', typename='terminal_size')
    for _ in range(10):
      total, ncols = random.randint(5, 30), random.randint(80, 240)
      mock_terminal_size.return_value = terminal_size(ncols)
      self._reset_stream(mock_stderr)
      # compare bars at each iteration (only when tinytqdm bar has been updated)
      for n in (bar := tinytqdm(range(total), desc="Test: ")):
        time.sleep(0.01)
        if bar.i % bar.skip != 0: continue
        self._assert_matches_tqdm(mock_stderr, n, total, ncols)
      # compare final bars
      self._assert_matches_tqdm(mock_stderr, total, total, ncols)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_output_custom_e2e(self, mock_terminal_size, mock_stderr):
    terminal_size = namedtuple(field_names='columns', typename='terminal_size')
    for _ in range(10):
      total, ncols = random.randint(10000, 100000), random.randint(80, 120)
      mock_terminal_size.return_value = terminal_size(ncols)
      self._reset_stream(mock_stderr)
      # compare bars at each iteration (only when tinytqdm bar has been updated)
      bar = tinytqdm(total=total, desc="Test: ")
      n = 0
      while n < total:
        time.sleep(0.01)
        # never step past total
        incr = min((total // 10) + random.randint(0, 100), total - n)
        bar.update(incr, close=n+incr==total)
        n += incr
        if bar.i % bar.skip != 0: continue
        self._assert_matches_tqdm(mock_stderr, n, total, ncols)

  def test_tqdm_perf(self):
    st = time.perf_counter()
    for _ in tqdm(range(100)): time.sleep(0.01)
    tqdm_time = time.perf_counter() - st
    st = time.perf_counter()
    for _ in tinytqdm(range(100)): time.sleep(0.01)
    tinytqdm_time = time.perf_counter() - st
    self.assertLess(tinytqdm_time, 2.0 * tqdm_time)

  def test_tqdm_perf_high_iter(self):
    # 10**7, not 10^7: ``^`` is bitwise XOR in Python, so 10^7 == 13 and the test measured only setup cost
    iters = 10**7
    st = time.perf_counter()
    for _ in tqdm(range(iters)): pass
    tqdm_time = time.perf_counter() - st
    st = time.perf_counter()
    for _ in tinytqdm(range(iters)): pass
    tinytqdm_time = time.perf_counter() - st
    self.assertLess(tinytqdm_time, 5 * tqdm_time)
if __name__ == "__main__":
  # executed as a script: hand control to unittest's command-line runner
  unittest.main()