mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
more Tensor(UOp) cleanups (#15364)
* more Tensor(UOp) cleanups
* function too
This commit is contained in:
@@ -78,9 +78,9 @@ class _function(Generic[ReturnType]):
|
||||
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
|
||||
precompile_backward=self.precompile_backward)
|
||||
if isinstance(ret, tuple):
|
||||
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
|
||||
return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
|
||||
else:
|
||||
return cast(ReturnType, Tensor(fret.gettuple(0), device=fret.device))
|
||||
return cast(ReturnType, Tensor(fret.gettuple(0)))
|
||||
|
||||
# overload signatures support both @function and @function(precompile=True) syntax
|
||||
@overload
|
||||
|
||||
@@ -249,7 +249,7 @@ class Tensor(OpMixin):
|
||||
|
||||
This API is alpha and may change.
|
||||
"""
|
||||
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||
|
||||
def callify(self, *lst:Tensor) -> Tensor:
|
||||
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
|
||||
@@ -409,10 +409,8 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
Moves the tensor to the given device.
|
||||
"""
|
||||
device = canonicalize_device(device)
|
||||
if device == self.device: return self
|
||||
if not isinstance(device, str): return self.shard(device)
|
||||
ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
|
||||
if (device:=canonicalize_device(device)) == self.device: return self
|
||||
ret = Tensor(self.uop.copy_to_device(device), requires_grad=self.requires_grad)
|
||||
if self.grad is not None: ret.grad = self.grad.to(device)
|
||||
return ret
|
||||
|
||||
@@ -436,8 +434,8 @@ class Tensor(OpMixin):
|
||||
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
|
||||
if len(devices) == 1: return self.to(devices[0])
|
||||
devices = cast(tuple[str, ...], canonicalize_device(devices))
|
||||
mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
|
||||
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
|
||||
uop = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
|
||||
return Tensor(uop, requires_grad=self.requires_grad)
|
||||
|
||||
def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
|
||||
"""
|
||||
@@ -532,7 +530,7 @@ class Tensor(OpMixin):
|
||||
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
|
||||
# TODO: add test for multidevice tensor
|
||||
device = canonicalize_device(device)
|
||||
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
|
||||
return Tensor(UOp.new_buffer(device, size, dtype), **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
|
||||
|
||||
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
@@ -541,7 +539,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device
|
||||
if isinstance(device, tuple) and (axis := self.uop.axis) is not None:
|
||||
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis), device=device)
|
||||
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis))
|
||||
return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
@@ -802,8 +800,8 @@ class Tensor(OpMixin):
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
|
||||
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
|
||||
return Tensor(UOp.multi(stacked, axis=self.uop.axis), dtype=dtype)
|
||||
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device])
|
||||
return Tensor(stacked.multi(self.uop.axis))
|
||||
|
||||
def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
|
||||
"""
|
||||
@@ -1065,14 +1063,13 @@ class Tensor(OpMixin):
|
||||
if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
|
||||
target_uops = [x.uop for x in targets]
|
||||
grads = compute_gradient(self.uop, gradient.uop, set(target_uops))
|
||||
ret = []
|
||||
ret:list[Tensor] = []
|
||||
for x in target_uops:
|
||||
if (y:=grads.get(x)) is None:
|
||||
if materialize_grads: y = x.const_like(0)
|
||||
else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.uop}")
|
||||
ret.append(y)
|
||||
# create returned Tensors
|
||||
return [Tensor(u, device=t.device) for t,u in zip(targets, ret)]
|
||||
ret.append(Tensor(y))
|
||||
return ret
|
||||
|
||||
def backward(self, gradient:Tensor|None=None) -> Tensor:
|
||||
"""
|
||||
@@ -3197,7 +3194,7 @@ class Tensor(OpMixin):
|
||||
assert frame_pos.op is Ops.BIND, "frame_pos must be a bound Variable"
|
||||
srcs = (out:=Tensor.empty(*shape, device=self.device, dtype=self.dtype), self.contiguous(), state.contiguous(), *ref_frames)
|
||||
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
|
||||
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)), device=self.device)
|
||||
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user