mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
maintain input tensor dtype (#969)
This commit is contained in:
@@ -11,7 +11,7 @@ inverse_type_map = {v:k for k,v in type_map.items()}
|
||||
|
||||
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
|
||||
BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
|
||||
BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)),
|
||||
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
|
||||
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
|
||||
|
||||
@@ -483,9 +483,9 @@ class Tensor:
|
||||
def tan(self): return self.sin() / self.cos()
|
||||
|
||||
@staticmethod
|
||||
def _tri(r:int, c:int, k:int=0) -> Tensor: return Tensor.arange(r).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k).unsqueeze(0).expand(r,c)
|
||||
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k).where(self, Tensor.zeros_like(self))
|
||||
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1).where(Tensor.zeros_like(self), self)
|
||||
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c)
|
||||
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
|
||||
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
|
||||
|
||||
# ***** math functions (unary) *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user