diff --git a/test/test_tensor.py b/test/test_tensor.py index 8eddce3fe5..b35a92ed54 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -224,5 +224,17 @@ class TestTinygrad(unittest.TestCase): Tensor([]).realize() Tensor([]).numpy() + def test_tensor_ndarray_dtype(self): + arr = np.array([1]) # where dtype is implicitly int64 + assert Tensor(arr).dtype == dtypes.int64 + assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype + assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else + + def test_tensor_list_dtype(self): + arr = [1] + assert Tensor(arr).dtype == Tensor.default_type + assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 + assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5b4405f242..c2713ec6b6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -59,7 +59,8 @@ class Tensor: assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np)) elif isinstance(data, np.ndarray): - data = LazyBuffer.fromCPU(data) + assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" + data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data) else: raise RuntimeError(f"can't create Tensor from {data}") self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)