mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Tensor(uop) does not need explicit device (#15361)
This commit is contained in:
@@ -238,10 +238,10 @@ class Tensor(OpMixin):
|
||||
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device).multi(self.uop.axis)
|
||||
else:
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param, device=self.device)
|
||||
return Tensor(param)
|
||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||
fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
|
||||
return Tensor(fret.gettuple(0), device=self.device)
|
||||
return Tensor(fret.gettuple(0))
|
||||
|
||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||
"""
|
||||
@@ -324,7 +324,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
|
||||
"""
|
||||
return Tensor(self.uop.detach(), device=self.device, requires_grad=False)
|
||||
return Tensor(self.uop.detach(), requires_grad=False)
|
||||
|
||||
def _buffer(self) -> Buffer:
|
||||
from tinygrad.engine.realize import capturing
|
||||
@@ -803,7 +803,7 @@ class Tensor(OpMixin):
|
||||
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
|
||||
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
|
||||
return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
|
||||
return Tensor(UOp.multi(stacked, axis=self.uop.axis), dtype=dtype)
|
||||
|
||||
def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user