mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
tolist to return multidimensional list (#4192)
* lol does this work * some more changes * a tiny note * rename a variable * add test for data const and add TODO comment * make type correct make type correct
This commit is contained in:
@@ -19,7 +19,7 @@ class TestFusionOp(unittest.TestCase):
|
||||
out = (bt*2).expand(10,10).sum(1)
|
||||
sched = create_schedule([out.lazydata], None)
|
||||
run_schedule(sched)
|
||||
outd = out.data().tolist()
|
||||
outd = out.tolist()
|
||||
assert all(x == 20.0 for x in outd)
|
||||
|
||||
def test_recursive_add(self):
|
||||
|
||||
@@ -244,12 +244,9 @@ class TestTinygrad(unittest.TestCase):
|
||||
with self.assertRaises(IndexError): t2.size(2)
|
||||
|
||||
def test_tolist(self):
|
||||
assert Tensor([1,2,3]).tolist() == [1,2,3]
|
||||
assert Tensor([1.5,2,3]).tolist() == [1.5,2,3]
|
||||
|
||||
# TODO: match torch here
|
||||
# NotImplementedError: multi-dimensional sub-views are not implemented
|
||||
#assert Tensor([[1,2,3], [4,5,6]]).tolist() == [[1,2,3], [4,5,6]]
|
||||
# NOTE: float16 Tensor.tolist() requires python 3.12
|
||||
for arr in [[1,2,3], [1.5,2,3], [[1,2,3], [4,5,6]], 3]:
|
||||
assert Tensor(arr).tolist() == torch.tensor(arr).tolist() == arr
|
||||
|
||||
def test_element_size(self):
|
||||
for _, dtype in dtypes.fields().items():
|
||||
|
||||
@@ -31,6 +31,14 @@ class TestTensorData(unittest.TestCase):
|
||||
assert dat[0, 0] == 1
|
||||
assert dat[1, 1] == 4
|
||||
|
||||
def test_data_const(self):
|
||||
a = Tensor(3, dtype=dtypes.int32)
|
||||
dat = a.data()
|
||||
assert dat.format == "i"
|
||||
assert dat.itemsize == 4
|
||||
assert dat.tolist() == 3
|
||||
assert dat.shape == ()
|
||||
|
||||
def test_data_float32(self):
|
||||
a = Tensor([[1,2.5],[3,4]], dtype=dtypes.float32)
|
||||
dat = a.data()
|
||||
|
||||
@@ -191,13 +191,15 @@ class Tensor:
|
||||
def data(self) -> memoryview:
|
||||
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
|
||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
return self._data().cast(self.dtype.fmt, self.shape if len(self.shape) else (1,))
|
||||
return self._data().cast(self.dtype.fmt, self.shape)
|
||||
def item(self) -> ConstType:
|
||||
"""Returns the value of this tensor as a standard Python number."""
|
||||
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
|
||||
assert self.numel() == 1, "must have one element for item"
|
||||
return self._data().cast(self.dtype.fmt)[0]
|
||||
def tolist(self) -> List[ConstType]: return list(self.data())
|
||||
# TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
|
||||
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
|
||||
def tolist(self) -> Union[Sequence[ConstType], ConstType]: return self.data().tolist()
|
||||
def numpy(self) -> np.ndarray:
|
||||
if self.dtype == dtypes.bfloat16: return self.float().numpy()
|
||||
assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
|
||||
|
||||
Reference in New Issue
Block a user