mirror of https://github.com/tinygrad/tinygrad.git
binaryops xor (#2627)
* feat: initial xor
* feat: numpy xor
* feat: llvm xor
* feat: quick test for xor
* feat: slightly working xor in torch
* feat: xor in tensor
* feat: slightly better test
@@ -327,6 +327,12 @@ class TestOps(unittest.TestCase):
   def test_rsqrt(self):
     helper_test_op([(45,65)], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0)
     helper_test_op([()], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0)
+  def test_xor(self):
+    tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int)
+    ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32)
+    helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True)
+    helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
+    helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
 
   def test_sin(self):
     helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
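For reference, the expected outputs of the three checks above can be reproduced independently with plain numpy: xor with itself zeroes every element, and xor is commutative, so the last two checks produce identical values (they differ only in whether __xor__ or __rxor__ is exercised).

import numpy as np

t = np.array([[1,-8,1],[32,1,6]], dtype=np.int32)
print(t ^ t)        # all zeros: x ^ x == 0
print(t ^ 0x1337)   # elementwise xor with the scalar
print(0x1337 ^ t)   # commutative, so identical to the line above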
@@ -90,6 +90,10 @@ class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     return x.e(BinaryOps.CMPLT, y)
 
+class Xor(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
+    return x.e(BinaryOps.XOR, y)
+
 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     return x.e(BinaryOps.ADD, y)
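Like Less just above it, Xor defines only forward: bitwise xor is an integer/bool op with no useful gradient, which is presumably why the new tests pass forward_only=True. The elementwise semantics being lowered are simply Python's ^ applied per element; a trivial standalone illustration:

# Plain-Python view of the per-element op that BinaryOps.XOR computes.
xs, ys = [1, -8, 1, 32, 1, 6], [0x1337] * 6
print([x ^ y for x, y in zip(xs, ys)])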
@@ -11,7 +11,7 @@ from dataclasses import dataclass
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 # NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
+class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); XOR = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
@@ -35,8 +35,8 @@ class CStyleLanguage(NamedTuple):
     BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})",
     BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})",
     BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
-    BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})",
-    TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})"
+    BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
+    TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})"
   }
 
 # returns a str expression of the casted xs with the given type
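The new code_for_op entry renders XOR as C's infix ^ operator. A standalone sketch of the string it produces (the operand names here are illustrative, not taken from real generated kernels):

# The lambda from the hunk above, evaluated on made-up operand names.
xor_render = lambda a, b, dtype: f"({a}^{b})"
print(xor_render("val0", "val1", None))   # -> (val0^val1)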
@@ -154,7 +154,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
     elif uop == UOps.ALU:
       assert dtype is not None
       # remove parens if ALU types are the same. TODO: can do more here
-      if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
+      if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.XOR}:
        val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
       else:
        val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
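The second cstyle change adds XOR to the set of ops whose inner parentheses are stripped when the same ALU op is chained; this is safe because ^ is associative and left-associative in C. A rough standalone illustration of the effect on the emitted expression, using a simplified stand-in for strip_parens and made-up operand names:

def render_xor(a, b): return f"({a}^{b})"
def strip_parens(s): return s[1:-1] if s.startswith("(") and s.endswith(")") else s  # simplified stand-in

inner = render_xor("val0", "val1")
print(render_xor(inner, "val2"))                # without the change: ((val0^val1)^val2)
print(render_xor(strip_parens(inner), "val2"))  # with the change:    (val0^val1^val2)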
@@ -20,6 +20,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.CMPLT: lambda builder,x,y: builder.icmp_unsigned("<", x, y) if is_bool(x.type) else builder.icmp_signed("<", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS),
   BinaryOps.MAX: lambda builder,x,y: builder.select(builder.icmp_unsigned(">", x, y) if is_bool(x.type) else builder.icmp_signed(">", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y),
   BinaryOps.MOD: lambda builder,x,y: builder.urem(x,y) if is_bool(x.type) else builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y),
+  BinaryOps.XOR: lambda builder,x,y: builder.xor(x,y),
   TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
   TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS), y, z
   ),
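In the LLVM backend, BinaryOps.XOR lowers to a single integer xor instruction through llvmlite's IRBuilder.xor. A minimal, self-contained llvmlite sketch of that instruction (the module and function names are illustrative):

import llvmlite.ir as ir

# Build `i32 xor_i32(i32, i32)` whose body is the same xor instruction
# that builder.xor(x, y) emits in the entry above.
module = ir.Module(name="xor_example")
fnty = ir.FunctionType(ir.IntType(32), [ir.IntType(32), ir.IntType(32)])
fn = ir.Function(module, fnty, name="xor_i32")
builder = ir.IRBuilder(fn.append_basic_block(name="entry"))
builder.ret(builder.xor(fn.args[0], fn.args[1]))
print(module)   # prints the textual LLVM IR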
@@ -30,7 +30,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
   BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
   BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
-  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), UnaryOps.SQRT: np.sqrt,
+  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), BinaryOps.XOR: lambda x, y: np.bitwise_xor(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
   ReduceOps.MAX: lambda x, new_shape: x.max(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
   MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
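On the numpy backend the op becomes np.bitwise_xor, with match_types first casting both operands to a common dtype (roughly what np.promote_types would pick; the exact promotion rules live in match_types itself). A rough standalone illustration of that promote-then-operate step:

import numpy as np

# Rough stand-in for what happens before the xor: promote, then operate.
a = np.array([1, -8, 32], dtype=np.int32)
b = np.array([0x1337, 0, 7], dtype=np.int64)
common = np.promote_types(a.dtype, b.dtype)                 # int64
print(np.bitwise_xor(a.astype(common), b.astype(common)))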
@@ -33,6 +33,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {
   BinaryOps.SUB: lambda x,y: torch.sub(*match_types(x, y, disallow_bool=True)).type(output_type(x,y)),
   BinaryOps.MUL: lambda x,y: torch.mul(*match_types(x, y)).type(output_type(x,y)),
   BinaryOps.DIV: lambda x,y: torch.div(*match_types(x, y)).type(torch.promote_types(x.dtype, y.dtype)),
+  BinaryOps.XOR: lambda x,y: torch.bitwise_xor(*match_types(x, y)),
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
   ReduceOps.MAX: lambda x, new_shape: x.amax(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
   MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
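The torch backend mirrors this with torch.bitwise_xor (again routed through match_types). A quick standalone check with the same values as the test:

import torch

t = torch.tensor([[1, -8, 1], [32, 1, 6]], dtype=torch.int32)
print(torch.bitwise_xor(t, 0x1337))   # same result as t ^ 0x1337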
@@ -735,6 +735,7 @@ class Tensor:
     inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
     return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
+  def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
 
   def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
   def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
@@ -754,6 +755,7 @@ class Tensor:
   def __pow__(self, x) -> Tensor: return self.pow(x)
   def __truediv__(self, x) -> Tensor: return self.div(x)
   def __matmul__(self, x) -> Tensor: return self.matmul(x)
+  def __xor__(self, x) -> Tensor: return self.xor(x)
 
   def __radd__(self, x) -> Tensor: return self.add(x, True)
   def __rsub__(self, x) -> Tensor: return self.sub(x, True)
@@ -761,6 +763,7 @@
   def __rpow__(self, x) -> Tensor: return self.pow(x, True)
   def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
   def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
+  def __rxor__(self, x) -> Tensor: return self.xor(x, True)
 
   def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
   def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
@@ -768,6 +771,7 @@
   def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
   def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
+  def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
 
   def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
   def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
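Taken together, the tensor.py changes expose xor through the full operator surface: Tensor.xor, __xor__, the reflected __rxor__, and the in-place __ixor__. A minimal usage sketch (the exact import location of dtypes depends on the tinygrad version, and whichever backend your checkout defaults to handles the realization):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes   # may be `from tinygrad import dtypes` on other versions

a = Tensor([[1, -8, 1], [32, 1, 6]], dtype=dtypes.int32)
print((a ^ a).numpy())        # __xor__ -> Tensor.xor -> mlops.Xor -> BinaryOps.XOR
print((a ^ 0x1337).numpy())   # the Python int is broadcast via _broadcasted
print((0x1337 ^ a).numpy())   # __rxor__ handles the reflected case
# `a ^= 0x1337` would route through __ixor__, i.e. self.assign(self.xor(x))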