mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
rename RECIP to RECIPROCAL (#12939)
This commit is contained in:
@@ -759,7 +759,7 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
def test_pow_neg_05_is_rsqrt(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** -0.5
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT])
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.RECIPROCAL, Ops.SQRT])
|
||||
|
||||
def test_pow_2_has_1_mul(self):
|
||||
t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)
|
||||
|
||||
@@ -115,7 +115,7 @@ class TestFloatUOps(TestUOps):
|
||||
def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop')
|
||||
def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a))
|
||||
def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf'))
|
||||
def test_recip(self): self._test_uop_fxn(Ops.RECIPROCAL, lambda a: 1/a if a != 0 else float('inf'))
|
||||
def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
|
||||
|
||||
def test_add(self): self._test_bop_fxn(Ops.ADD, lambda a,b: a+b)
|
||||
@@ -218,18 +218,18 @@ class TestExecALU(TestUOps):
|
||||
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, -3)), -2)
|
||||
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (-50, 6)), -8)
|
||||
|
||||
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
|
||||
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
|
||||
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
|
||||
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
|
||||
|
||||
def test_recip(self):
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (8,)), 1/8)
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (7,)), 1/7)
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-3,)), 1/-3)
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-50,)), 1/-50)
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (8,)), 1/8)
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (7,)), 1/7)
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (-3,)), 1/-3)
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (-50,)), 1/-50)
|
||||
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (10,)), 1/10)
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, ((34**2),)), 1/(34**2))
|
||||
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (10,)), 1/10)
|
||||
|
||||
def test_bool_cmplt(self):
|
||||
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, False)), False)
|
||||
|
||||
@@ -15,7 +15,7 @@ def reduce_gradient(ctx:UOp, ret:UOp):
|
||||
# ctx is grad_output
|
||||
pm_gradient = PatternMatcher([
|
||||
(UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
|
||||
(UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
|
||||
(UPat(Ops.RECIPROCAL, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
|
||||
(UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
|
||||
(UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
|
||||
(UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
|
||||
|
||||
@@ -95,7 +95,7 @@ class CStyleLanguage(Renderer):
|
||||
infinity: str = "INFINITY"
|
||||
nan: str = "NAN"
|
||||
code_for_op: dict = {
|
||||
Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
|
||||
Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIPROCAL: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
|
||||
Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
|
||||
Ops.TRUNC: lambda x,dtype: f"trunc({x})",
|
||||
Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
|
||||
@@ -208,7 +208,7 @@ class ClangRenderer(CStyleLanguage):
|
||||
# language options
|
||||
buffer_suffix = " restrict"
|
||||
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
|
||||
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC, Ops.RECIP]}),
|
||||
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC, Ops.RECIPROCAL]}),
|
||||
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
|
||||
Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})",
|
||||
Ops.FDIV: lambda a,b,dtype: f"({a}/{b})"}
|
||||
@@ -365,7 +365,7 @@ class CUDARenderer(CStyleLanguage):
|
||||
Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
|
||||
Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
|
||||
Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
|
||||
Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
|
||||
Ops.RECIPROCAL: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
|
||||
type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"}
|
||||
extra_matcher = PatternMatcher([
|
||||
(UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None),
|
||||
|
||||
@@ -21,7 +21,7 @@ def glsl_type(t:DType) -> mesa.struct_glsl_type:
|
||||
u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior",
|
||||
Ops.AND: "iand", Ops.XOR: "ixor", Ops.WHERE: "bcsel", Ops.MAX: "umax"}
|
||||
s_aop = {**u_aop, Ops.CMPLT: "ilt", Ops.IDIV: "idiv", Ops.MOD: "irem", Ops.MAX: "imax"}
|
||||
f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIP: "frcp",
|
||||
f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIPROCAL: "frcp",
|
||||
Ops.MAX: "fmax", Ops.TRUNC: "ftrunc", Ops.SIN: "fsin", Ops.EXP2: "fexp2", Ops.LOG2: "flog2"}
|
||||
aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dtypes.sints}, **{x:f_aop for x in dtypes.floats}}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ def render_val(x, dtype):
|
||||
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
|
||||
|
||||
asm_for_op: dict[Ops, Callable] = {
|
||||
Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
|
||||
Ops.RECIPROCAL: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
|
||||
Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
Ops.TRUNC: lambda d,a,dt,name: f"cvt.rzi.{name}.{name} {d}, {a};",
|
||||
|
||||
@@ -47,7 +47,7 @@ class Ops(FastEnum):
|
||||
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
|
||||
|
||||
# load/store before math
|
||||
LOAD = auto(); STORE = auto() # noqa: E702
|
||||
@@ -78,7 +78,7 @@ class Ops(FastEnum):
|
||||
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC}
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
|
||||
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
|
||||
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
|
||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||
@@ -107,6 +107,6 @@ class GroupOp:
|
||||
Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
|
||||
|
||||
# do not preserve f(0) = 0
|
||||
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
|
||||
UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
|
||||
|
||||
All = set(Ops)
|
||||
|
||||
@@ -114,7 +114,8 @@ class MathTrait:
|
||||
return self._binop(Ops.IDIV, x, reverse)
|
||||
def mod(self:TMT, x:TMT|ConstType, reverse:bool=False): return self._binop(Ops.MOD, x, reverse)
|
||||
def sub(self:TMT, x:TMT|ConstType, reverse:bool=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
|
||||
def div(self:TMT, x:TMT|ConstType, reverse:bool=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
|
||||
def div(self:TMT, x:TMT|ConstType, reverse:bool=False):
|
||||
return (self.ufix(x)*self.alu(Ops.RECIPROCAL)) if reverse else (self*self.ufix(x).alu(Ops.RECIPROCAL))
|
||||
|
||||
def __neg__(self): return self.neg()
|
||||
|
||||
@@ -162,7 +163,7 @@ class MathTrait:
|
||||
if isinstance(y, type(self)): return self.alu(Ops.WHERE, y.ufix(x), y)
|
||||
raise RuntimeError("where needs at least one UOp arg")
|
||||
def threefry(self:TMT, seed:TMT): return self.alu(Ops.THREEFRY, seed)
|
||||
def reciprocal(self): return self.alu(Ops.RECIP)
|
||||
def reciprocal(self): return self.alu(Ops.RECIPROCAL)
|
||||
def trunc(self): return self.alu(Ops.TRUNC)
|
||||
def sqrt(self): return self.alu(Ops.SQRT)
|
||||
def sin(self): return self.alu(Ops.SIN)
|
||||
|
||||
@@ -750,7 +750,7 @@ def safe_pow(x, y):
|
||||
|
||||
python_alu: dict[Ops, Callable] = {
|
||||
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
|
||||
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
||||
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIPROCAL: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
||||
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
|
||||
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
|
||||
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
|
||||
@@ -1214,7 +1214,7 @@ renderer = PatternMatcher([
|
||||
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
#(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")),
|
||||
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
|
||||
(UPat(Ops.RECIP, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")),
|
||||
(UPat(Ops.RECIPROCAL, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")),
|
||||
(UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
|
||||
(UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
|
||||
(UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
|
||||
@@ -1231,7 +1231,7 @@ renderer_infer = PatternMatcher([
|
||||
])
|
||||
|
||||
sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce",
|
||||
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
|
||||
Ops.WHERE: "where", Ops.RECIPROCAL: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
|
||||
pm_pyrender = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
|
||||
|
||||
Reference in New Issue
Block a user