diff --git a/docs/abstractions.py b/docs/abstractions.py
index fb86246f44..a1e1364b55 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -100,7 +100,7 @@ class LazyOp:
 # there's currently 28 Ops you have to implement for an accelerator.
 class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPEQ = auto(); MAX = auto()
+class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
 class ReduceOps(Enum): SUM = auto(); MAX = auto()
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
diff --git a/extra/assembly/assembly.py b/extra/assembly/assembly.py
index f128413f1d..978119d031 100644
--- a/extra/assembly/assembly.py
+++ b/extra/assembly/assembly.py
@@ -138,7 +138,7 @@ class AssemblyCodegen(Linearizer):
       elif uop == UOps.ALU and newvar is not None:
         out = newreg(newvar) if newvar not in tor else tor[newvar]  # this is the only thing that can violate SSA
-        if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
+        if args in [BinaryOps.CMPLT]:
           pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
           ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
           ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
diff --git a/extra/assembly/assembly_rdna.py b/extra/assembly/assembly_rdna.py
index 069ec0b54c..25984e3a18 100644
--- a/extra/assembly/assembly_rdna.py
+++ b/extra/assembly/assembly_rdna.py
@@ -46,7 +46,7 @@ class RDNACodegen(AssemblyCodegen):
     alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
            BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp", UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
-           BinaryOps.CMPEQ: "cmp_eq", BinaryOps.CMPLT: "cmp_lt"}
+           BinaryOps.CMPLT: "cmp_lt"}

     pend_regs:Set[Register] = set()
     rtor:Dict[Register, str] = {}
@@ -115,7 +115,7 @@ class RDNACodegen(AssemblyCodegen):
           else:
             ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
         elif uop == UOps.ALU:
-          if arg in [BinaryOps.CMPLT, BinaryOps.CMPEQ]:
+          if arg in [BinaryOps.CMPLT]:
            ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
           else:
             alu_arg = alu[arg]
diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py
index 73da08341d..cc71a0d0f4 100644
--- a/test/test_lazybuffer.py
+++ b/test/test_lazybuffer.py
@@ -26,7 +26,7 @@ class TestLazyBuffer(unittest.TestCase):
       helper(a[(slice(start, None, stride),)*ndims])

   def test_shuffle_pad_ops_cmpeq(self):
-    y = Tensor([1]).cat(Tensor([1]).eq(0)).numpy()
+    y = Tensor([1]).cat(Tensor([1]) == 0).numpy()
     z = Tensor([1, 0]).numpy()
     np.testing.assert_allclose(y, z)
diff --git a/test/test_ops.py b/test/test_ops.py
index b037e897b5..2bd285eeba 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -173,7 +173,7 @@ class TestOps(unittest.TestCase):
     self.assertRaises(RuntimeError, (t1 == t2).sum().backward)
     tt1 = Tensor.ones(4, requires_grad=True)
     tt2 = Tensor.ones(4, requires_grad=True)
-    self.assertRaises(RuntimeError, (tt1.eq(tt2)).sum().backward)
+    self.assertRaises(RuntimeError, (tt1 == tt2).sum().backward)

   def test_cmp_lt_backwards(self):
     t1 = torch.ones(4, requires_grad=True)
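The tests above keep working even though CMPEQ is removed from the op set, because equality can be rebuilt from less-than alone. A minimal NumPy sketch of that identity (illustrative only, not tinygrad code; the example arrays are made up), assuming comparisons produce float 0/1 masks as the backends do:

```python
import numpy as np

a = np.array([1.0, 0.0, 2.0])
b = np.array([1.0, 1.0, 1.0])

# a == b  <=>  not (a < b) and not (b < a); the two masks are disjoint, so "+" acts as "or"
lt_ab = (a < b).astype(np.float32)
lt_ba = (b < a).astype(np.float32)
eq = 1.0 - (lt_ab + lt_ba)

np.testing.assert_allclose(eq, (a == b).astype(np.float32))
```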
diff --git a/test/test_uops.py b/test/test_uops.py
index 3b75a1b934..96d7b9efba 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -61,7 +61,7 @@ class TestUOps(unittest.TestCase):
   def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
   def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
   def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
-  def test_cmpeq(self): self._test_bop_fxn(BinaryOps.CMPEQ, lambda a,b: float(a==b))
+  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ ... @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
       mops.append((bx.op.op, bx.op.arg))
       bx = cast(LazyBuffer, bx.op.src[0])
     # NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
-    unsafe_pad_ops = {BinaryOps.DIV, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
+    unsafe_pad_ops = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
     if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op not in unsafe_pad_ops for x in bx.op.get_lazyops())):
       new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
     else:
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 3cf762c0bf..0fe7a3b756 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -35,7 +35,7 @@ class Relu(Function):
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    mask = self.ret.const_like(1).binary_op(BinaryOps.SUB, self.ret.binary_op(BinaryOps.CMPEQ, self.ret.const_like(0)))
+    mask = self.ret.const_like(0).binary_op(BinaryOps.CMPLT, self.ret)
     return mask.binary_op(BinaryOps.MUL, grad_output)

 class Log(Function):
@@ -96,7 +96,7 @@ class Max(Function):

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand(self.x.shape))
+    max_is_1s = self.x.const_like(1).binary_op(BinaryOps.SUB, self.x.binary_op(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))

     # sum of locations, averaged
     div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
@@ -107,24 +107,9 @@ class Max(Function):

 # ************* binary ops *************

-class Equal(Function):
+class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.binary_op(BinaryOps.CMPEQ, y)
-
-class Maximum(Function):
-  __slots__ = "x", "y", "ret"
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y = x, y
-    self.ret = x.binary_op(BinaryOps.MAX, y)
-    return self.ret
-
-  def backward(self, grad_output:LazyBuffer):
-    mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret)
-    eq = self.x.binary_op(BinaryOps.CMPEQ, self.y)
-    splitter = eq.const_like(2).binary_op(BinaryOps.SUB, eq).binary_op(BinaryOps.DIV, eq.const_like(2))
-
-    return grad_output.binary_op(BinaryOps.MUL, mask.const_like(1).binary_op(BinaryOps.SUB, mask).binary_op(BinaryOps.ADD, eq)).binary_op(BinaryOps.MUL, splitter) if self.needs_input_grad[0] else None, \
-           grad_output.binary_op(BinaryOps.MUL, mask).binary_op(BinaryOps.MUL, splitter) if self.needs_input_grad[1] else None
+    return x.binary_op(BinaryOps.CMPLT, y)

 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
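The mlops.py hunk above swaps the CMPEQ-based masks in Relu.backward and Max.backward for CMPLT-based ones. A hedged NumPy sketch of the two identities being relied on (illustrative only; the arrays are made up, and float 0/1 masks stand in for what the kernels produce):

```python
import numpy as np

# Relu: the output y is never negative, so "y != 0" is the same mask as "0 < y".
y = np.maximum(np.array([-2.0, 0.0, 3.0]), 0.0)
np.testing.assert_allclose(1.0 - (y == 0.0), (0.0 < y).astype(np.float64))

# Max: "x == max" is the same mask as "1 - (x < max)", so CMPEQ becomes SUB + CMPLT.
x = np.array([1.0, 5.0, 5.0, 2.0])
m = x.max()
np.testing.assert_allclose((x == m).astype(np.float64), 1.0 - (x < m).astype(np.float64))
```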
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index b2ded5a94b..0f8007edee 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -13,7 +13,7 @@ if TYPE_CHECKING:

 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 # NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
 class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
+class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
 class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index b5335ea49e..9499bb0d4e 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -36,7 +36,7 @@ class CStyleLanguage(NamedTuple):
     BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
     BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
     BinaryOps.MAX: lambda a,b: f"max({a},{b})",
-    BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
+    BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
     TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
   }
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 8adb995b0c..833b5d772c 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -26,7 +26,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
   BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
   BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
   BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)),
-  BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()),
+  BinaryOps.CMPLT: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("<", x, y, flags=('fast',)), ir.FloatType()),
   BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
   TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),
   TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py
index 6eb7604033..180cb1ad7b 100644
--- a/tinygrad/renderer/wgsl.py
+++ b/tinygrad/renderer/wgsl.py
@@ -17,7 +17,7 @@ class WGSLLanguage(CStyleLanguage):
   code_for_op = {
     UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
     BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
-    BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPEQ: lambda x,y: f"f32({x}=={y})",
+    BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
     TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)"
   }
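In each renderer the change is mechanical: the equality formatter is dropped and a less-than formatter takes its place. A tiny standalone sketch (it copies the two string-building lambdas from the hunks above rather than importing tinygrad; "val0"/"val1" are made-up placeholder names) showing what they emit:

```python
# Standalone copies of the CMPLT formatters from the hunks above, for illustration only.
cstyle_cmplt = lambda a, b: f"({a}<{b})"    # tinygrad/renderer/cstyle.py
wgsl_cmplt = lambda x, y: f"f32({x}<{y})"   # tinygrad/renderer/wgsl.py

print(cstyle_cmplt("val0", "val1"))  # (val0<val1)     -> 0/1 in C-style source
print(wgsl_cmplt("val0", "val1"))    # f32(val0<val1)  -> comparison cast to float in WGSL
```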
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index 648648ac60..f6a4c3fd2b 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -34,7 +34,7 @@ def einsum_mulacc(einsum, get_strides, expand):
 numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ ... @@ class Tensor:
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)

-  def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
+  def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
   def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
-  def eq(self, x) -> Tensor: return self._broadcasted(mlops.Equal, x, False)

   # ***** broadcasted trinary mlops *****
@@ -651,12 +650,12 @@ class Tensor:
   def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))

-  def __ge__(self, x) -> Tensor: return self.maximum(x).eq(self)
-  def __le__(self, x) -> Tensor: return self.maximum(x).eq(x)
-  def __lt__(self, x) -> Tensor: return 1.0-(self>=x)
-  def __gt__(self, x) -> Tensor: return 1.0-(self<=x)
-  def __eq__(self, x) -> Tensor: return self.eq(x)  # type: ignore # mypy things this should be a bool
-  def __ne__(self, x) -> Tensor: return 1.0-self.eq(x)  # type: ignore
+  def __lt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, False)
+  def __gt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, True)
+  def __ge__(self, x) -> Tensor: return 1.0-(self<x)
+  def __le__(self, x) -> Tensor: return 1.0-(self>x)
+  def __ne__(self, x) -> Tensor: return (self<x) + (self>x)  # type: ignore
+  def __eq__(self, x) -> Tensor: return 1.0-(self != x)  # type: ignore

   # ***** functional nn ops *****
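With Equal and Maximum gone, tensor.py derives every comparison (and maximum itself) from the single Less mlop. A NumPy sketch of the operator identities used above (illustrative only; the arrays are made up and float 0/1 masks stand in for tinygrad comparison results):

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([3.0, 2.0, 1.0])
lt, gt = (a < b).astype(float), (a > b).astype(float)

np.testing.assert_allclose(1.0 - lt, (a >= b).astype(float))  # __ge__ = 1.0 - (self < x)
np.testing.assert_allclose(1.0 - gt, (a <= b).astype(float))  # __le__ = 1.0 - (self > x)
ne = lt + gt                                                   # __ne__ = (self < x) + (self > x)
np.testing.assert_allclose(1.0 - ne, (a == b).astype(float))  # __eq__ = 1.0 - (self != x)

# maximum via where: take x where self < x, self where self > x, and average ties
np.testing.assert_allclose(np.where(a < b, b, np.where(a > b, a, (a + b) / 2)), np.maximum(a, b))
```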