raise RuntimeError in merge_dicts instead of assert [pr] (#11965)

This commit is contained in:
chenyu
2025-09-02 17:18:44 -04:00
committed by GitHub
parent f750c15965
commit 69dd1817d0
3 changed files with 3 additions and 3 deletions

View File

@@ -1221,7 +1221,7 @@ class TestOps(unittest.TestCase):
def test_einsum_shape_check(self):
a = Tensor.zeros(3,8,10,5)
b = Tensor.zeros(11,5,13,16,8)
with self.assertRaises(AssertionError):
with self.assertRaises(RuntimeError):
Tensor.einsum('pqrs,tuqvr->pstuv',a,b)
def test_einsum_arity_check1(self):

View File

@@ -93,7 +93,7 @@ class TestMergeDicts(unittest.TestCase):
assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3}
assert merge_dicts([a, c]) == a
assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3}
with self.assertRaises(AssertionError):
with self.assertRaises(RuntimeError):
merge_dicts([a, d])
class TestStripParens(unittest.TestCase):

View File

@@ -56,7 +56,7 @@ def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
def is_numpy_ndarray(x) -> bool: return str(type(x)) == "<class 'numpy.ndarray'>"
def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
kvs = set([(k,v) for d in ds for k,v in d.items()])
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
if len(kvs) != len(set(kv[0] for kv in kvs)): raise RuntimeError(f"{kvs} contains different values for the same key")
return {k:v for d in ds for k,v in d.items()}
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
ret:tuple[list[T], list[T]] = ([], [])