diff --git a/test/test_tensor.py b/test/test_tensor.py index 4a1fbfce26..8113e573f7 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -214,6 +214,13 @@ class TestTinygrad(unittest.TestCase): self.assertEqual(Tensor.empty(1,10,20).shape, (1,10,20)) self.assertEqual(Tensor.empty((10,20,40)).shape, (10,20,40)) + with self.assertRaises(ValueError): + Tensor.zeros((2, 2), 2, 2) + with self.assertRaises(ValueError): + Tensor.zeros((2, 2), (2, 2)) + with self.assertRaises(ValueError): + Tensor.randn((128, 128), 0.0, 0.01) + def test_numel(self): assert Tensor.randn(10, 10).numel() == 100 assert Tensor.randn(1,2,5).numel() == 10 diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index d19236de56..6cff505237 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -17,7 +17,11 @@ OSX = platform.system() == "Darwin" CI = os.getenv("CI", "") != "" def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order -def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x +def argfix(*x): + if x and x[0].__class__ in (tuple, list): + if len(x) != 1: raise ValueError(f"bad arg {x}") + return tuple(x[0]) + return x def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python def all_same(items:List[T]): return all(x == items[0] for x in items) def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t) @@ -207,9 +211,7 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]): _pack_, _fields_ = 1, fields return CStruct def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1] -def flat_mv(mv:memoryview): - if len(mv) == 0: return mv - return mv.cast("B", shape=(mv.nbytes,)) +def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,)) # *** Helpers for CUDA-like APIs.