Use correct dtype in Tensor when data is an ndarray (#1785)

* use correct dtype in Tensor when data is an ndarray

* attempt 2

* add assert to be consistent

* Add test case for ndarray

* Add test case for list

* remove whitespace
This commit is contained in:
badcc
2023-09-06 07:35:32 -07:00
committed by GitHub
parent 130cd55942
commit ee9ac20752
2 changed files with 14 additions and 1 deletions

View File

@@ -224,5 +224,17 @@ class TestTinygrad(unittest.TestCase):
Tensor([]).realize()
Tensor([]).numpy()
def test_tensor_ndarray_dtype(self):
arr = np.array([1]) # where dtype is implicitly int64
assert Tensor(arr).dtype == dtypes.int64
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else
def test_tensor_list_dtype(self):
arr = [1]
assert Tensor(arr).dtype == Tensor.default_type
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64
if __name__ == '__main__':
unittest.main()

View File

@@ -59,7 +59,8 @@ class Tensor:
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
elif isinstance(data, np.ndarray):
data = LazyBuffer.fromCPU(data)
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
else: raise RuntimeError(f"can't create Tensor from {data}")
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)