diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py
index afc0460b63..dec69cff30 100644
--- a/test/test_fusion_op.py
+++ b/test/test_fusion_op.py
@@ -19,7 +19,7 @@ class TestFusionOp(unittest.TestCase):
     out = (bt*2).expand(10,10).sum(1)
     sched = create_schedule([out.lazydata], None)
     run_schedule(sched)
-    outd = out.data().tolist()
+    outd = out.tolist()
     assert all(x == 20.0 for x in outd)
 
   def test_recursive_add(self):
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 84aa1757cb..057cf8c35f 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -244,12 +244,9 @@ class TestTinygrad(unittest.TestCase):
     with self.assertRaises(IndexError): t2.size(2)
 
   def test_tolist(self):
-    assert Tensor([1,2,3]).tolist() == [1,2,3]
-    assert Tensor([1.5,2,3]).tolist() == [1.5,2,3]
-
-    # TODO: match torch here
-    # NotImplementedError: multi-dimensional sub-views are not implemented
-    #assert Tensor([[1,2,3], [4,5,6]]).tolist() == [[1,2,3], [4,5,6]]
+    # NOTE: float16 Tensor.tolist() requires python 3.12
+    for arr in [[1,2,3], [1.5,2,3], [[1,2,3], [4,5,6]], 3]:
+      assert Tensor(arr).tolist() == torch.tensor(arr).tolist() == arr
 
   def test_element_size(self):
     for _, dtype in dtypes.fields().items():
diff --git a/test/test_tensor_data.py b/test/test_tensor_data.py
index c7ef664330..4ab945deb0 100644
--- a/test/test_tensor_data.py
+++ b/test/test_tensor_data.py
@@ -31,6 +31,14 @@ class TestTensorData(unittest.TestCase):
     assert dat[0, 0] == 1
     assert dat[1, 1] == 4
 
+  def test_data_const(self):
+    a = Tensor(3, dtype=dtypes.int32)
+    dat = a.data()
+    assert dat.format == "i"
+    assert dat.itemsize == 4
+    assert dat.tolist() == 3
+    assert dat.shape == ()
+
   def test_data_float32(self):
     a = Tensor([[1,2.5],[3,4]], dtype=dtypes.float32)
     dat = a.data()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 5f05f4be09..c826a1ea36 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -191,13 +191,15 @@ class Tensor:
   def data(self) -> memoryview:
     assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-    return self._data().cast(self.dtype.fmt, self.shape if len(self.shape) else (1,))
+    return self._data().cast(self.dtype.fmt, self.shape)
   def item(self) -> ConstType:
     """Returns the value of this tensor as a standard Python number."""
     assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
     assert self.numel() == 1, "must have one element for item"
     return self._data().cast(self.dtype.fmt)[0]
-  def tolist(self) -> List[ConstType]: return list(self.data())
+  # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
+  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
+  def tolist(self) -> Union[Sequence[ConstType], ConstType]: return self.data().tolist()
   def numpy(self) -> np.ndarray:
     if self.dtype == dtypes.bfloat16: return self.float().numpy()
     assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
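
A minimal sketch of the resulting behavior, assuming the top-level tinygrad exports of Tensor and dtypes (the assertions mirror the tests above, not new guarantees):

from tinygrad import Tensor, dtypes

# multi-dimensional tensors now round-trip through tolist(), matching torch
assert Tensor([[1, 2, 3], [4, 5, 6]]).tolist() == [[1, 2, 3], [4, 5, 6]]

# a 0-d (const) tensor yields a plain Python scalar: data() now keeps the
# () shape instead of padding it to (1,), and memoryview.tolist() on a
# 0-d view returns the scalar itself
t = Tensor(3, dtype=dtypes.int32)
assert t.data().shape == ()
assert t.tolist() == 3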