From e64111ad08d82cf873795fc51587c6e7cf2f03cd Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 21 Jan 2026 11:26:15 -0500 Subject: [PATCH] update all_same [pr] (#14270) add type annotation and unit test --- test/unit/test_helpers.py | 8 +++++++- tinygrad/helpers.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index ff615ee21f..c75e8d07f0 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -1,6 +1,6 @@ import ctypes, gzip, unittest, timeit, pickle from tinygrad import Variable -from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count +from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count, all_same from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits from tinygrad.helpers import ceildiv from tinygrad.tensor import Tensor, get_shape @@ -84,6 +84,12 @@ class TestContextVars(unittest.TestCase): ... assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value." +class TestAllSame(unittest.TestCase): + def test_empty(self): self.assertTrue(all_same([])) + def test_single(self): self.assertTrue(all_same([1])) + def test_same(self): self.assertTrue(all_same([1, 1, 1])) + def test_different(self): self.assertFalse(all_same([1, 2, 1])) + class TestMergeDicts(unittest.TestCase): def test_merge_dicts(self): a = {"a": 1, "b": 2} diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e75e5ae5b1..d4b0e7d053 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -25,7 +25,7 @@ def argfix(*x): return x # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) -def all_same(items:tuple[T, ...]|list[T]): return all(x == items[0] for x in items) +def all_same(items:Sequence): return all(x == items[0] for x in items) # works for empty input def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t) def colored(st, color:str|None, background=False): # replace the termcolor library colors = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']