diff --git a/test/test_ops.py b/test/test_ops.py
index 5fbd55a860..d590df5c92 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -327,6 +327,12 @@ class TestOps(unittest.TestCase):
   def test_rsqrt(self):
     helper_test_op([(45,65)], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0)
     helper_test_op([()], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0)
+  def test_xor(self):
+    tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
+    ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
+    helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True)
+    helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
+    helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
 
   def test_sin(self):
     helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index ee0766fbe9..2309d6affb 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -90,6 +90,10 @@ class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     return x.e(BinaryOps.CMPLT, y)
 
+class Xor(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
+    return x.e(BinaryOps.XOR, y)
+
 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     return x.e(BinaryOps.ADD, y)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index ccfd454e53..ec0e3775bd 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -11,7 +11,7 @@ from dataclasses import dataclass
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 # NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
+class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); XOR = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 46497d9c9b..a6f089f869 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -35,8 +35,8 @@ class CStyleLanguage(NamedTuple):
     BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
     BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
-    BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})",
-    TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})"
+    BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
+    TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})"
   }
 
   # returns a str expression of the casted xs with the given type
@@ -154,7 +154,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
     elif uop == UOps.ALU:
       assert dtype is not None
       # remove parens if ALU types are the same. TODO: can do more here
-      if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
+      if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.XOR}:
         val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
       else:
         val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 99801ece50..27ea46223c 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -20,6 +20,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.CMPLT: lambda builder,x,y: builder.icmp_unsigned("<", x, y) if is_bool(x.type) else builder.icmp_signed("<", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS),
   BinaryOps.MAX: lambda builder,x,y: builder.select(builder.icmp_unsigned(">", x, y) if is_bool(x.type) else builder.icmp_signed(">", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y),
   BinaryOps.MOD: lambda builder,x,y: builder.urem(x,y) if is_bool(x.type) else builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y),
+  BinaryOps.XOR: lambda builder,x,y: builder.xor(x,y),
   TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
   TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS), y, z),
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index e4c0c56c0d..63d1d7ba2c 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -30,7 +30,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)),
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.XOR: lambda x,y: np.bitwise_xor(*match_types(x,y)),
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ … @@ class Tensor:
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
+  def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
 
   def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
   def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
@@ -754,6 +755,7 @@ class Tensor:
   def __pow__(self, x) -> Tensor: return self.pow(x)
   def __truediv__(self, x) -> Tensor: return self.div(x)
   def __matmul__(self, x) -> Tensor: return self.matmul(x)
+  def __xor__(self, x) -> Tensor: return self.xor(x)
 
   def __radd__(self, x) -> Tensor: return self.add(x, True)
   def __rsub__(self, x) -> Tensor: return self.sub(x, True)
@@ -761,6 +763,7 @@ class Tensor:
   def __rpow__(self, x) -> Tensor: return self.pow(x, True)
   def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
   def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
+  def __rxor__(self, x) -> Tensor: return self.xor(x, True)
 
   def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
   def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
@@ -768,6 +771,7 @@ class Tensor:
   def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
   def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
+  def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
 
   def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
   def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
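
A quick way to exercise the new op after applying the patch (a minimal sketch, not part of the diff; it assumes the era-appropriate import `from tinygrad.helpers import dtypes`, the same one test_ops.py uses). Note the tests above are forward_only=True because mlops.Xor defines no backward pass:

  from tinygrad.tensor import Tensor
  from tinygrad.helpers import dtypes

  a = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
  print((a ^ a).numpy())       # x ^ x == 0, so this prints all zeros
  print((a ^ 0x1337).numpy())  # a scalar is broadcast, like the other binary ops
  print((0x1337 ^ a).numpy())  # the reflected case goes through __rxor__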