diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index 35c1bad572..3320fc089f 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -11,7 +11,7 @@ inverse_type_map = {v:k for k,v in type_map.items()} torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin, - BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), + BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)), MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)), MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]), diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cd359c23b6..d5100a1fcd 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -483,9 +483,9 @@ class Tensor: def tan(self): return self.sin() / self.cos() @staticmethod - def _tri(r:int, c:int, k:int=0) -> Tensor: return Tensor.arange(r).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k).unsqueeze(0).expand(r,c) - def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k).where(self, Tensor.zeros_like(self)) - def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1).where(Tensor.zeros_like(self), self) + def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c) + def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self)) + def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self) # ***** math functions (unary) *****