mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Use correct dtype in Tensor when data is an ndarray (#1785)
* use correct dtype in Tensor when data is an ndarray * attempt 2 * add assert to be consistent * Add test case for ndarray * Add test case for list * remove whitespace
This commit is contained in:
@@ -224,5 +224,17 @@ class TestTinygrad(unittest.TestCase):
|
||||
Tensor([]).realize()
|
||||
Tensor([]).numpy()
|
||||
|
||||
def test_tensor_ndarray_dtype(self):
|
||||
arr = np.array([1]) # where dtype is implicitly int64
|
||||
assert Tensor(arr).dtype == dtypes.int64
|
||||
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype
|
||||
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else
|
||||
|
||||
def test_tensor_list_dtype(self):
|
||||
arr = [1]
|
||||
assert Tensor(arr).dtype == Tensor.default_type
|
||||
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32
|
||||
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -59,7 +59,8 @@ class Tensor:
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
|
||||
elif isinstance(data, np.ndarray):
|
||||
data = LazyBuffer.fromCPU(data)
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
|
||||
else: raise RuntimeError(f"can't create Tensor from {data}")
|
||||
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
|
||||
|
||||
Reference in New Issue
Block a user