diff --git a/test/test_schedule.py b/test/test_schedule.py index c346690c03..6b090d911e 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -759,7 +759,7 @@ class TestSchedule(unittest.TestCase): def test_pow_neg_05_is_rsqrt(self): t = Tensor([1.0, 2.0, 3.0]) ** -0.5 - self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT]) + self.assertEqual(self._alu_from_tensor(t), [Ops.RECIPROCAL, Ops.SQRT]) def test_pow_2_has_1_mul(self): t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0) diff --git a/test/test_uops.py b/test/test_uops.py index 6749115624..4b3102b78f 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -115,7 +115,7 @@ class TestFloatUOps(TestUOps): def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan')) @unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop') def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a)) - def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf')) + def test_recip(self): self._test_uop_fxn(Ops.RECIPROCAL, lambda a: 1/a if a != 0 else float('inf')) def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) def test_add(self): self._test_bop_fxn(Ops.ADD, lambda a,b: a+b) @@ -218,18 +218,18 @@ class TestExecALU(TestUOps): self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, -3)), -2) self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (-50, 6)), -8) - np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0)) - np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0)) + np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (3.0,)))), 2+(1.0/3.0)) + np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (-3.0,)))), -2-(1.0/3.0)) def test_recip(self): - np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (8,)), 1/8) - np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (7,)), 1/7) - np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-3,)), 1/-3) - np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-50,)), 1/-50) + np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (8,)), 1/8) + np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (7,)), 1/7) + np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (-3,)), 1/-3) + np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (-50,)), 1/-50) - np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3)) - np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((34**2),)), 1/(34**2)) - np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (10,)), 1/10) + np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, ((32+521+3),)), 1/(32+521+3)) + np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, ((34**2),)), 1/(34**2)) + np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (10,)), 1/10) def test_bool_cmplt(self): self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, False)), False) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 2f51209755..e0419256b7 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -15,7 +15,7 @@ def reduce_gradient(ctx:UOp, ret:UOp): # ctx is grad_output pm_gradient = PatternMatcher([ (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)), - (UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)), + (UPat(Ops.RECIPROCAL, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)), (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)), (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)), (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)), diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index ee6c0d387a..0ab215fdb3 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -95,7 +95,7 @@ class CStyleLanguage(Renderer): infinity: str = "INFINITY" nan: str = "NAN" code_for_op: dict = { - Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}", + Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIPROCAL: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}", Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})", Ops.TRUNC: lambda x,dtype: f"trunc({x})", Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})", @@ -208,7 +208,7 @@ class ClangRenderer(CStyleLanguage): # language options buffer_suffix = " restrict" type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"} - code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC, Ops.RECIP]}), + code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC, Ops.RECIPROCAL]}), Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})", Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})", Ops.FDIV: lambda a,b,dtype: f"({a}/{b})"} @@ -365,7 +365,7 @@ class CUDARenderer(CStyleLanguage): Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})", Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})", Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})", - Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" } + Ops.RECIPROCAL: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" } type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"} extra_matcher = PatternMatcher([ (UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None), diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 116004fa05..9282e2034e 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -21,7 +21,7 @@ def glsl_type(t:DType) -> mesa.struct_glsl_type: u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior", Ops.AND: "iand", Ops.XOR: "ixor", Ops.WHERE: "bcsel", Ops.MAX: "umax"} s_aop = {**u_aop, Ops.CMPLT: "ilt", Ops.IDIV: "idiv", Ops.MOD: "irem", Ops.MAX: "imax"} -f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIP: "frcp", +f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIPROCAL: "frcp", Ops.MAX: "fmax", Ops.TRUNC: "ftrunc", Ops.SIN: "fsin", Ops.EXP2: "fexp2", Ops.LOG2: "flog2"} aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dtypes.sints}, **{x:f_aop for x in dtypes.floats}} diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 6882b736ab..2cb6ef683f 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -16,7 +16,7 @@ def render_val(x, dtype): return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "") asm_for_op: dict[Ops, Callable] = { - Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};", + Ops.RECIPROCAL: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};", Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};", Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};", Ops.TRUNC: lambda d,a,dt,name: f"cvt.rzi.{name}.{name} {d}, {a};", diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 5ac56b25c8..55d0f9e996 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -47,7 +47,7 @@ class Ops(FastEnum): UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702 # UnaryOps - CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702 + CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702 # load/store before math LOAD = auto(); STORE = auto() # noqa: E702 @@ -78,7 +78,7 @@ class Ops(FastEnum): CUSTOM = auto(); CUSTOMI = auto() # noqa: E702 class GroupOp: - Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC} + Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC} Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW} Ternary = {Ops.WHERE, Ops.MULACC} @@ -107,6 +107,6 @@ class GroupOp: Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ} # do not preserve f(0) = 0 - UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW} + UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW} All = set(Ops) diff --git a/tinygrad/uop/mathtraits.py b/tinygrad/uop/mathtraits.py index a1f5d7eca2..27c3beeb45 100644 --- a/tinygrad/uop/mathtraits.py +++ b/tinygrad/uop/mathtraits.py @@ -114,7 +114,8 @@ class MathTrait: return self._binop(Ops.IDIV, x, reverse) def mod(self:TMT, x:TMT|ConstType, reverse:bool=False): return self._binop(Ops.MOD, x, reverse) def sub(self:TMT, x:TMT|ConstType, reverse:bool=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x)) - def div(self:TMT, x:TMT|ConstType, reverse:bool=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP)) + def div(self:TMT, x:TMT|ConstType, reverse:bool=False): + return (self.ufix(x)*self.alu(Ops.RECIPROCAL)) if reverse else (self*self.ufix(x).alu(Ops.RECIPROCAL)) def __neg__(self): return self.neg() @@ -162,7 +163,7 @@ class MathTrait: if isinstance(y, type(self)): return self.alu(Ops.WHERE, y.ufix(x), y) raise RuntimeError("where needs at least one UOp arg") def threefry(self:TMT, seed:TMT): return self.alu(Ops.THREEFRY, seed) - def reciprocal(self): return self.alu(Ops.RECIP) + def reciprocal(self): return self.alu(Ops.RECIPROCAL) def trunc(self): return self.alu(Ops.TRUNC) def sqrt(self): return self.alu(Ops.SQRT) def sin(self): return self.alu(Ops.SIN) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index df9c196d3d..5c40fc7927 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -750,7 +750,7 @@ def safe_pow(x, y): python_alu: dict[Ops, Callable] = { Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2, - Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), + Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIPROCAL: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc, Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max, @@ -1214,7 +1214,7 @@ renderer = PatternMatcher([ (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), #(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")), (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), - (UPat(Ops.RECIP, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")), + (UPat(Ops.RECIPROCAL, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")), (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")), @@ -1231,7 +1231,7 @@ renderer_infer = PatternMatcher([ ]) sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce", - Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"} + Ops.WHERE: "where", Ops.RECIPROCAL: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"} pm_pyrender = PatternMatcher([ (UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")), (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),