From 4d5c4d256dab4617aad45726b13e86c5f57be278 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 1 Jan 2026 11:37:26 -0500 Subject: [PATCH] update tqdm for edge case (#13956) 1.00kit/s and not 1000it/s for value 999.5 --- test/unit/test_tqdm.py | 20 ++++++++++++++++++++ tinygrad/helpers.py | 4 +++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/test/unit/test_tqdm.py b/test/unit/test_tqdm.py index 2ba3f2fe4a..bf89d49e6e 100644 --- a/test/unit/test_tqdm.py +++ b/test/unit/test_tqdm.py @@ -128,6 +128,26 @@ class TestProgressBar(unittest.TestCase): self._compare_bars(tinytqdm_output, tqdm_output) if n > 5: break + @patch('sys.stderr', new_callable=StringIO) + @patch('shutil.get_terminal_size') + def test_si_boundary(self, mock_terminal_size, mock_stderr): + """Test SI formatting at boundaries (e.g., 999.5 -> 1.00k, not 1000)""" + ncols = 80 + mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols) + + # Test rates at the boundary: 999 stays as "999", 999.5+ becomes "1.00k" + for rate in [999, 999.4, 999.5, 1000, 1001]: + mock_stderr.truncate(0) + mock_stderr.seek(0) + elapsed = 1.0 / rate + # Need 3 perf_counter calls: init st, init update, final update + with patch('time.perf_counter', side_effect=[0, 0, elapsed]): + bar = tinytqdm(desc="Test", total=1, unit_scale=True, rate=10**9) + bar.update(1, close=True) + tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip() + tqdm_output = tqdm.format_meter(n=1, total=1, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=True) + self._compare_bars(tinytqdm_output, tqdm_output) + @unittest.skip("this is flaky") @patch('sys.stderr', new_callable=StringIO) @patch('shutil.get_terminal_size') diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e3abbde9e4..a730237764 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -508,7 +508,9 @@ class tqdm(Generic[T]): if elapsed and self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1) def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x) def SI(x): - return (f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00' + if not x: return '0.00' + v = f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.') + return (f"{x/1000**(int(g)+1):.3f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)+1]) if v == "1000" else v+' kMGTPEZY'[int(g)].strip() prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}' est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else '' it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"