From 1abb6297f6c9893263fc312db9786837b8426547 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 19 Mar 2026 03:34:30 -0400
Subject: [PATCH] more Tensor(UOp) cleanups (#15364)

* more Tensor(UOp) cleanups

* function too
---
 tinygrad/function.py |  4 ++--
 tinygrad/tensor.py   | 29 +++++++++++++----------------
 2 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/tinygrad/function.py b/tinygrad/function.py
index 71c2f13ae0..495bcd8748 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -78,9 +78,9 @@ class _function(Generic[ReturnType]):
 
     fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile, precompile_backward=self.precompile_backward)
     if isinstance(ret, tuple):
-      return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
+      return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
     else:
-      return cast(ReturnType, Tensor(fret.gettuple(0), device=fret.device))
+      return cast(ReturnType, Tensor(fret.gettuple(0)))
 
 # overload signatures support both @function and @function(precompile=True) syntax
 @overload
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7ea368540a..c5f191e3e0 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -249,7 +249,7 @@ class Tensor(OpMixin):
 
     This API is alpha and may change.
     """
-    return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
+    return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
 
   def callify(self, *lst:Tensor) -> Tensor:
     big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
@@ -409,10 +409,8 @@ class Tensor(OpMixin):
     """
     Moves the tensor to the given device.
     """
-    device = canonicalize_device(device)
-    if device == self.device: return self
-    if not isinstance(device, str): return self.shard(device)
-    ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
+    if (device:=canonicalize_device(device)) == self.device: return self
+    ret = Tensor(self.uop.copy_to_device(device), requires_grad=self.requires_grad)
     if self.grad is not None: ret.grad = self.grad.to(device)
     return ret
 
@@ -436,8 +434,8 @@ class Tensor(OpMixin):
     if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
     if len(devices) == 1: return self.to(devices[0])
     devices = cast(tuple[str, ...], canonicalize_device(devices))
-    mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
-    return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
+    uop = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
+    return Tensor(uop, requires_grad=self.requires_grad)
 
   def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
     """
@@ -532,7 +530,7 @@ class Tensor(OpMixin):
     if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
     # TODO: add test for multidevice tensor
     device = canonicalize_device(device)
-    return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
+    return Tensor(UOp.new_buffer(device, size, dtype), **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
 
   def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor:
     """
@@ -541,7 +539,7 @@ class Tensor(OpMixin):
     """
     dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device
     if isinstance(device, tuple) and (axis := self.uop.axis) is not None:
-      return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis), device=device)
+      return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis))
     return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs)
 
   @staticmethod
@@ -802,8 +800,8 @@ class Tensor(OpMixin):
     dtype = kwargs.pop("dtype", self.dtype)
     if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
     if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
-    stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
-    return Tensor(UOp.multi(stacked, axis=self.uop.axis), dtype=dtype)
+    stacked = UOp.mstack(*[fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device])
+    return Tensor(stacked.multi(self.uop.axis))
 
   def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
     """
@@ -1065,14 +1063,13 @@ class Tensor(OpMixin):
     if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
     target_uops = [x.uop for x in targets]
     grads = compute_gradient(self.uop, gradient.uop, set(target_uops))
-    ret = []
+    ret:list[Tensor] = []
     for x in target_uops:
       if (y:=grads.get(x)) is None:
         if materialize_grads: y = x.const_like(0)
         else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.uop}")
-      ret.append(y)
-    # create returned Tensors
-    return [Tensor(u, device=t.device) for t,u in zip(targets, ret)]
+      ret.append(Tensor(y))
+    return ret
 
   def backward(self, gradient:Tensor|None=None) -> Tensor:
     """
@@ -3197,7 +3194,7 @@ class Tensor(OpMixin):
     assert frame_pos.op is Ops.BIND, "frame_pos must be a bound Variable"
     srcs = (out:=Tensor.empty(*shape, device=self.device, dtype=self.dtype), self.contiguous(), state.contiguous(), *ref_frames)
     fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
-    return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)), device=self.device)
+    return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
 
   # ***** functional nn ops *****
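
Note (editor's sketch, not part of the patch): the common thread in these
call sites is that a UOp already carries its device and dtype, so Tensor(uop)
can derive both instead of taking redundant keyword arguments, and a device
move becomes copy_to_device on the uop itself rather than a re-wrap with a
different device. A minimal self-contained Python illustration of that
pattern follows; Node, Wrapper, and the toy copy_to_device are illustrative
stand-ins, not tinygrad's real classes.

from __future__ import annotations
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  # IR node: device metadata lives on the graph node, not on the wrapper
  op: str
  device: str
  src: tuple[Node, ...] = ()
  def copy_to_device(self, device: str) -> Node:
    # a device move is just another node in the graph
    return Node("COPY", device, (self,))

class Wrapper:
  def __init__(self, node: Node):
    self.node = node  # no device= argument: everything is derived from the node
  @property
  def device(self) -> str: return self.node.device
  def to(self, device: str) -> Wrapper:
    if device == self.device: return self
    return Wrapper(self.node.copy_to_device(device))

w = Wrapper(Node("BUFFER", "CPU"))
assert w.to("GPU").device == "GPU" and w.to("CPU") is w

Keeping the wrapper constructor to a single source of truth is what lets the
patch delete the device=/dtype= arguments at every Tensor(UOp) call site
without changing behavior.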