mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Use correct dtype in Tensor when data is an ndarray (#1785)
* use correct dtype in Tensor when data is an ndarray * attempt 2 * add assert to be consistent * Add test case for ndarray * Add test case for list * remove whitespace
This commit is contained in:
@@ -224,5 +224,17 @@ class TestTinygrad(unittest.TestCase):
|
||||
Tensor([]).realize()
|
||||
Tensor([]).numpy()
|
||||
|
||||
def test_tensor_ndarray_dtype(self):
|
||||
arr = np.array([1]) # where dtype is implicitly int64
|
||||
assert Tensor(arr).dtype == dtypes.int64
|
||||
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype
|
||||
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else
|
||||
|
||||
def test_tensor_list_dtype(self):
|
||||
arr = [1]
|
||||
assert Tensor(arr).dtype == Tensor.default_type
|
||||
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32
|
||||
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user