tolist to return multidimensional list (#4192)

* lol does this work

* some more changes

* a tiny note

* rename a variable

* add test for data const and add TODO comment

* make type correct

make type correct
This commit is contained in:
geohotstan
2024-04-18 11:43:10 +08:00
committed by GitHub
parent 3644077a42
commit 269a58d5fa
4 changed files with 16 additions and 9 deletions

View File

@@ -19,7 +19,7 @@ class TestFusionOp(unittest.TestCase):
out = (bt*2).expand(10,10).sum(1)
sched = create_schedule([out.lazydata], None)
run_schedule(sched)
outd = out.data().tolist()
outd = out.tolist()
assert all(x == 20.0 for x in outd)
def test_recursive_add(self):

View File

@@ -244,12 +244,9 @@ class TestTinygrad(unittest.TestCase):
with self.assertRaises(IndexError): t2.size(2)
def test_tolist(self):
assert Tensor([1,2,3]).tolist() == [1,2,3]
assert Tensor([1.5,2,3]).tolist() == [1.5,2,3]
# TODO: match torch here
# NotImplementedError: multi-dimensional sub-views are not implemented
#assert Tensor([[1,2,3], [4,5,6]]).tolist() == [[1,2,3], [4,5,6]]
# NOTE: float16 Tensor.tolist() requires python 3.12
for arr in [[1,2,3], [1.5,2,3], [[1,2,3], [4,5,6]], 3]:
assert Tensor(arr).tolist() == torch.tensor(arr).tolist() == arr
def test_element_size(self):
for _, dtype in dtypes.fields().items():

View File

@@ -31,6 +31,14 @@ class TestTensorData(unittest.TestCase):
assert dat[0, 0] == 1
assert dat[1, 1] == 4
def test_data_const(self):
a = Tensor(3, dtype=dtypes.int32)
dat = a.data()
assert dat.format == "i"
assert dat.itemsize == 4
assert dat.tolist() == 3
assert dat.shape == ()
def test_data_float32(self):
a = Tensor([[1,2.5],[3,4]], dtype=dtypes.float32)
dat = a.data()

View File

@@ -191,13 +191,15 @@ class Tensor:
def data(self) -> memoryview:
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
return self._data().cast(self.dtype.fmt, self.shape if len(self.shape) else (1,))
return self._data().cast(self.dtype.fmt, self.shape)
def item(self) -> ConstType:
"""Returns the value of this tensor as a standard Python number."""
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
assert self.numel() == 1, "must have one element for item"
return self._data().cast(self.dtype.fmt)[0]
def tolist(self) -> List[ConstType]: return list(self.data())
# TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
def tolist(self) -> Union[Sequence[ConstType], ConstType]: return self.data().tolist()
def numpy(self) -> np.ndarray:
if self.dtype == dtypes.bfloat16: return self.float().numpy()
assert self.dtype.np is not None, f"no np dtype for {self.dtype}"