From 50f669e43b0e2f2215e172f2ecc5437571bb65fb Mon Sep 17 00:00:00 2001 From: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com> Date: Mon, 28 Aug 2023 04:18:03 +0200 Subject: [PATCH] [ready] perf: simpler Tensor init (#1679) Co-authored-by: Roelof van Dijk --- tinygrad/tensor.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c9e617b4e4..413c836d6f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -38,7 +38,6 @@ class Tensor: training: ClassVar[bool] = False no_grad: ClassVar[bool] = False default_type: ClassVar[DType] = dtypes.float32 - def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}" device = Device.canonicalize(device) @@ -51,25 +50,18 @@ class Tensor: # internal variables used for autograd graph construction self._ctx: Optional[Function] = None - if isinstance(data, LazyBuffer): - assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" - self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data) - return - - if isinstance(data, (int, float)): + if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" + elif isinstance(data, (int, float)): self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data) return - - if data.__class__ is list: + elif data.__class__ is list: assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" - data = np.array(data, dtype=(dtype or Tensor.default_type).np) - - if isinstance(data, np.ndarray): + data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np)) + elif isinstance(data, np.ndarray): data = LazyBuffer.fromCPU(data) - self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data) - return + else: raise RuntimeError(f"can't create Tensor from {data}") - raise RuntimeError(f"can't create Tensor from {data}") + self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data) def __repr__(self): return f""