Use correct dtype in Tensor when data is an ndarray (#1785)

* use correct dtype in Tensor when data is an ndarray

* attempt 2

* add assert to be consistent

* Add test case for ndarray

* Add test case for list

* remove whitespace
This commit is contained in:
badcc
2023-09-06 07:35:32 -07:00
committed by GitHub
parent 130cd55942
commit ee9ac20752
2 changed files with 14 additions and 1 deletions

View File

@@ -224,5 +224,17 @@ class TestTinygrad(unittest.TestCase):
Tensor([]).realize()
Tensor([]).numpy()
def test_tensor_ndarray_dtype(self):
arr = np.array([1]) # where dtype is implicitly int64
assert Tensor(arr).dtype == dtypes.int64
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else
def test_tensor_list_dtype(self):
arr = [1]
assert Tensor(arr).dtype == Tensor.default_type
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64
if __name__ == '__main__':
unittest.main()