From 2dae657415e9c2d227663668e61598d3ab2d33cc Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 3 Jun 2024 14:57:57 +0200
Subject: [PATCH] improve readability (#4809)

---
 tinygrad/tensor.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index b5e0e128ec..8a59bc1efd 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -104,13 +104,14 @@ class Tensor:
     # None (the default) will be updated to True if it's put in an optimizer
     self.requires_grad: Optional[bool] = requires_grad
 
-    # internal variables used for autograd graph construction
+    # internal variable used for autograd graph construction
     self._ctx: Optional[Function] = None
+
+    # create a LazyBuffer from the different types of inputs
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
     elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
     elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
     elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
-    elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
     elif isinstance(data, list):
       if dtype is None:
         if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
@@ -120,16 +121,25 @@ class Tensor:
     elif isinstance(data, np.ndarray):
       if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
       else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
+    elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
+
+    # by this point, it has to be a LazyBuffer
+    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
+      raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
 
     # data is a LazyBuffer, but it might be on the wrong device
-    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
 
     if isinstance(device, tuple):
-      # TODO: what if it's a MultiLazyBuffer on other devices?
-      self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
+      # if device is a tuple, we should have/construct a MultiLazyBuffer
+      if isinstance(data, MultiLazyBuffer):
+        assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
+        self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
+      else:
+        self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
     else: self.lazydata = data if data.device == device else data.copy_to_device(device)
 
-  def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
+  def __repr__(self):
+    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
 
   # Python has a non moving GC, so this should be okay
   def __hash__(self): return id(self)
@@ -196,6 +206,7 @@ class Tensor:
     if not self.lazydata.is_realized(): return self.replace(x)
     self.lazydata = self.lazydata.assign(x.lazydata)
     return self
+
   def detach(self) -> Tensor:
     """
     Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
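
The reordered isinstance chain and the explicit RuntimeError make the constructor's
contract easy to state: every accepted input type is normalized to a LazyBuffer,
and anything else now fails loudly. A minimal sketch of the resulting behavior,
assuming tinygrad at this commit (the commented values are inferred from the
diff, not captured output):

    # the constructor paths this patch touches
    from tinygrad import Tensor

    print(Tensor([1, 2, 3]).dtype)    # default int dtype, inferred via fully_flatten
    print(Tensor(b"\x00\x01").dtype)  # uint8, via np.frombuffer
    print(Tensor(None).shape)         # (0,), the LoadOps.EMPTY branch moved to the end
    Tensor(object())                  # RuntimeError: can't create Tensor from ... with type ...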
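
The tuple-device branch is normally reached through Tensor.shard, which is public
API at this commit. A sketch under that assumption (the ":1"/":2" device suffixes
follow the convention of tinygrad's multi-device tests; any two device instances
work):

    # a tuple device puts a MultiLazyBuffer behind the Tensor
    from tinygrad import Tensor, Device

    ds = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
    t = Tensor.arange(8).shard(ds)  # lazydata becomes a MultiLazyBuffer; axis=None replicates
    print(t.device)                 # the device tuple, as the new assert expects

With the new assert, Tensor(t.lazydata, device=ds) round-trips, while a mismatched
device tuple now trips the assert instead of silently keeping the buffer's own
devices (the old TODO).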