update all_same [pr] (#14270)

add type annotation and unit test
This commit is contained in:
chenyu
2026-01-21 11:26:15 -05:00
committed by GitHub
parent 9ad3c865ac
commit e64111ad08
2 changed files with 8 additions and 2 deletions

View File

@@ -1,6 +1,6 @@
import ctypes, gzip, unittest, timeit, pickle
from tinygrad import Variable
from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count
from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count, all_same
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
from tinygrad.helpers import ceildiv
from tinygrad.tensor import Tensor, get_shape
@@ -84,6 +84,12 @@ class TestContextVars(unittest.TestCase):
...
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."
class TestAllSame(unittest.TestCase):
def test_empty(self): self.assertTrue(all_same([]))
def test_single(self): self.assertTrue(all_same([1]))
def test_same(self): self.assertTrue(all_same([1, 1, 1]))
def test_different(self): self.assertFalse(all_same([1, 2, 1]))
class TestMergeDicts(unittest.TestCase):
def test_merge_dicts(self):
a = {"a": 1, "b": 2}

View File

@@ -25,7 +25,7 @@ def argfix(*x):
return x
# https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__))
def all_same(items:tuple[T, ...]|list[T]): return all(x == items[0] for x in items)
def all_same(items:Sequence): return all(x == items[0] for x in items) # works for empty input
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
def colored(st, color:str|None, background=False): # replace the termcolor library
colors = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']