mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
raise RuntimeError in merge_dicts instead of assert [pr] (#11965)
This commit is contained in:
@@ -1221,7 +1221,7 @@ class TestOps(unittest.TestCase):
|
||||
def test_einsum_shape_check(self):
|
||||
a = Tensor.zeros(3,8,10,5)
|
||||
b = Tensor.zeros(11,5,13,16,8)
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(RuntimeError):
|
||||
Tensor.einsum('pqrs,tuqvr->pstuv',a,b)
|
||||
|
||||
def test_einsum_arity_check1(self):
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestMergeDicts(unittest.TestCase):
|
||||
assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3}
|
||||
assert merge_dicts([a, c]) == a
|
||||
assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3}
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(RuntimeError):
|
||||
merge_dicts([a, d])
|
||||
|
||||
class TestStripParens(unittest.TestCase):
|
||||
|
||||
@@ -56,7 +56,7 @@ def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
|
||||
def is_numpy_ndarray(x) -> bool: return str(type(x)) == "<class 'numpy.ndarray'>"
|
||||
def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
|
||||
kvs = set([(k,v) for d in ds for k,v in d.items()])
|
||||
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
|
||||
if len(kvs) != len(set(kv[0] for kv in kvs)): raise RuntimeError(f"{kvs} contains different values for the same key")
|
||||
return {k:v for d in ds for k,v in d.items()}
|
||||
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
|
||||
ret:tuple[list[T], list[T]] = ([], [])
|
||||
|
||||
Reference in New Issue
Block a user