diff --git a/test/test_uops.py b/test/test_uops.py
index 9dedc99198..3b7d77f375 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -64,6 +64,7 @@ class TestUOps(unittest.TestCase):
 
 @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
 class TestFloatUOps(TestUOps):
+  def test_neg(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a)
   def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
   def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
   def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
@@ -85,6 +86,7 @@ class TestFloatUOps(TestUOps):
 # TODO: fix this on all the backends
 @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
 class TestNonFloatUOps(TestUOps):
+  def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, dtypes.int32)
   def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
   def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32)
   def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 8c8515b2fa..a9b5fa3486 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -27,6 +27,10 @@ class Zero(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0)
   def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.const(0)
 
+class Neg(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
+  def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.e(UnaryOps.NEG)
+
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
@@ -120,7 +124,7 @@ class Sub(Function):
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.const(0).e(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None
+           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None
 
 class Mul(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
@@ -138,7 +142,7 @@ class Div(Function):
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.const(0).e(BinaryOps.SUB, grad_output).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
+           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
 
 # ************* ternary ops *************
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index d67325d183..26bf8b4f97 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 # NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
-class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
diff --git a/tinygrad/renderer/assembly_arm64.py b/tinygrad/renderer/assembly_arm64.py
index 43a0a92acf..6a4dc48fba 100644
--- a/tinygrad/renderer/assembly_arm64.py
+++ b/tinygrad/renderer/assembly_arm64.py
@@ -25,7 +25,8 @@ def specialize_to_arm64(fn_nm, asm):
   type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
   alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
-         UnaryOps.NOOP: "mov", UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
+         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
+         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
          TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
 
   def mov_imm(value, reg):
diff --git a/tinygrad/renderer/assembly_ptx.py b/tinygrad/renderer/assembly_ptx.py
index cb78d8d961..69e6105279 100644
--- a/tinygrad/renderer/assembly_ptx.py
+++ b/tinygrad/renderer/assembly_ptx.py
@@ -33,7 +33,8 @@ def specialize_to_ptx(lang, function_name):
   ins = []
   alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
-         UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
+         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
+         UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
          TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
   for uop, out, vin, arg in lang.ins:
     if uop == UOps.ENDLOOP:
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 40fe67ab25..0e418c5a7d 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -32,6 +32,7 @@ class CStyleLanguage(NamedTuple):
   uses_ptr_arithmetic: bool = False
   launch_bounds: bool = False
   code_for_op: Dict = {
+    UnaryOps.NEG: lambda x: f"(-{x})",
     UnaryOps.EXP2: lambda x: f"exp2({x})",
     UnaryOps.LOG2: lambda x: f"log2({x})",
     UnaryOps.SIN: lambda x: f"sin({x})",
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index af6867b768..51f9446057 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -18,6 +18,7 @@ render_llvm = {
 }
 
 code_for_op: Final[Dict[Op, Callable]] = {
+  UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=('fast',)),
   UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
   UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)),
   UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py
index a1203ae877..48cae863f6 100644
--- a/tinygrad/renderer/wgsl.py
+++ b/tinygrad/renderer/wgsl.py
@@ -15,7 +15,9 @@ class WGSLLanguage(CStyleLanguage):
   generic_var_prefix = "var "
   external_local_bufs = True
   code_for_op = {
-    UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
+    UnaryOps.NEG: lambda x: f"(-{x})",
+    UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})",
+    UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
     BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
     BinaryOps.MOD: lambda x,y: f"({x}%{y})", BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index bbf92fb71b..ce73520f4b 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -10,7 +10,7 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
   return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
 
 base_fxn_for_op: Dict[Op, Callable] = {
-  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
+  UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 2cdb4dd3ab..b86406cb59 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -534,6 +534,7 @@ class Tensor:
 
   # ***** mlops (unary) *****
 
+  def __neg__(self): return mlops.Neg.apply(self)
   def contiguous(self): return mlops.Contiguous.apply(self)
   def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
   def log(self): return mlops.Log.apply(self)
@@ -557,7 +558,6 @@ class Tensor:
   def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
   def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
 
-  def __neg__(self): return 0.0-self
   def square(self): return self*self
   def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
   def abs(self): return self.relu() + (-self).relu()
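For reference, a minimal usage sketch of what the patch enables (not part of the diff; it assumes a tinygrad checkout with these changes installed and a default backend that implements UnaryOps.NEG). Tensor.__neg__ now lowers to a single NEG kernel through mlops.Neg instead of computing 0.0 - x, so negation works for integer dtypes as well, and the gradient of Neg is just the negated upstream gradient:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

# forward: -x is one UnaryOps.NEG, no float constant involved
x = Tensor([[1.0, -2.0, 3.0]])
print((-x).numpy())            # [[-1.  2. -3.]]

# also works for int32, matching test_neg_int32 above
i = Tensor([1, -2, 3], dtype=dtypes.int32)
print((-i).numpy())            # [-1  2 -3]

# backward: d(-y)/dy = -1, so grad is the negated upstream gradient
y = Tensor([2.0, -3.0], requires_grad=True)
(-y).sum().backward()
print(y.grad.numpy())          # [-1. -1.]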