This commit is contained in:
George Hotz
2025-10-27 15:29:26 +08:00
parent af3211f73c
commit b987b8b22a
12 changed files with 92 additions and 55 deletions

View File

@@ -759,7 +759,7 @@ class TestSchedule(unittest.TestCase):
def test_pow_neg_05_is_rsqrt(self):
t = Tensor([1.0, 2.0, 3.0]) ** -0.5
self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT])
self.assertEqual(self._alu_from_tensor(t), [Ops.RECIPROCAL, Ops.SQRT])
def test_pow_2_has_1_mul(self):
t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)

View File

@@ -115,7 +115,7 @@ class TestFloatUOps(TestUOps):
def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
@unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop')
def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a))
def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf'))
def test_recip(self): self._test_uop_fxn(Ops.RECIPROCAL, lambda a: 1/a if a != 0 else float('inf'))
def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
def test_add(self): self._test_bop_fxn(Ops.ADD, lambda a,b: a+b)
@@ -218,18 +218,18 @@ class TestExecALU(TestUOps):
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (-50, 6)), -8)
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
def test_recip(self):
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (8,)), 1/8)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (7,)), 1/7)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-3,)), 1/-3)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (-50,)), 1/-50)
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (8,)), 1/8)
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (7,)), 1/7)
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (-3,)), 1/-3)
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (-50,)), 1/-50)
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
np.testing.assert_allclose(exec_alu(Ops.RECIP, dtypes.float32, (10,)), 1/10)
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, ((34**2),)), 1/(34**2))
np.testing.assert_allclose(exec_alu(Ops.RECIPROCAL, dtypes.float32, (10,)), 1/10)
def test_bool_cmplt(self):
self.assertEqual(exec_alu(Ops.CMPLT, dtypes.bool, (False, False)), False)

View File

@@ -137,7 +137,7 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr
min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
if BEAM_DEBUG:
print("BEAM_SEARCH:")
print('\n'.join(pyrender(lin.ast.replace(arg=None))))
print(pyrender(lin.ast.replace(arg=None)))
if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
try:

View File

@@ -26,7 +26,7 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
"""
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print('\n'.join(pyrender(ast)))
if DEBUG >= 5: print(pyrender(ast))
# linearize
if renderer is None: renderer = Device.default.renderer
@@ -38,7 +38,7 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
except RuntimeError as e:
print("***** LINEARIZE FAILURE *****")
print(e)
print('\n'.join(pyrender(ast)))
print(pyrender(ast))
raise
assert uops[-1].op is Ops.SINK, "last uop must be sink"

View File

@@ -15,7 +15,7 @@ def reduce_gradient(ctx:UOp, ret:UOp):
# ctx is grad_output
pm_gradient = PatternMatcher([
(UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
(UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
(UPat(Ops.RECIPROCAL, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
(UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
(UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
(UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),

View File

@@ -95,7 +95,7 @@ class CStyleLanguage(Renderer):
infinity: str = "INFINITY"
nan: str = "NAN"
code_for_op: dict = {
Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIPROCAL: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
Ops.TRUNC: lambda x,dtype: f"trunc({x})",
Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
@@ -208,7 +208,7 @@ class ClangRenderer(CStyleLanguage):
# language options
buffer_suffix = " restrict"
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC, Ops.RECIP]}),
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC, Ops.RECIPROCAL]}),
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})",
Ops.FDIV: lambda a,b,dtype: f"({a}/{b})"}
@@ -365,7 +365,7 @@ class CUDARenderer(CStyleLanguage):
Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
Ops.RECIPROCAL: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"}
extra_matcher = PatternMatcher([
(UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None),

View File

@@ -21,7 +21,7 @@ def glsl_type(t:DType) -> mesa.struct_glsl_type:
u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior",
Ops.AND: "iand", Ops.XOR: "ixor", Ops.WHERE: "bcsel", Ops.MAX: "umax"}
s_aop = {**u_aop, Ops.CMPLT: "ilt", Ops.IDIV: "idiv", Ops.MOD: "irem", Ops.MAX: "imax"}
f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIP: "frcp",
f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIPROCAL: "frcp",
Ops.MAX: "fmax", Ops.TRUNC: "ftrunc", Ops.SIN: "fsin", Ops.EXP2: "fexp2", Ops.LOG2: "flog2"}
aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dtypes.sints}, **{x:f_aop for x in dtypes.floats}}

View File

@@ -16,7 +16,7 @@ def render_val(x, dtype):
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
asm_for_op: dict[Ops, Callable] = {
Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
Ops.RECIPROCAL: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
Ops.TRUNC: lambda d,a,dt,name: f"cvt.rzi.{name}.{name} {d}, {a};",

View File

@@ -319,13 +319,13 @@ def bufferize_to_store(x:UOp, allow_locals=True):
mops.append((walk.op, walk.marg))
walk = walk.src[0]
for m in mops[::-1]: ret = ret._mop(*m)
return ret.forced_reshape(shape).replace(tag=x.tag)
return ret.reshape(shape, can_fold=False).replace(tag=x.tag)
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], tag=x.tag).end(*[x for x in rngs if x.op is Ops.RANGE])
ret = buf.after(do_store).forced_reshape(shape)
ret = buf.after(do_store).reshape(shape, can_fold=False)
# TODO: is this right? what if it's offset
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])

View File

@@ -48,7 +48,7 @@ class Ops(FastEnum):
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
# load/store before math
LOAD = auto(); STORE = auto() # noqa: E702
@@ -79,7 +79,7 @@ class Ops(FastEnum):
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC}
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
Ternary = {Ops.WHERE, Ops.MULACC}
@@ -108,6 +108,6 @@ class GroupOp:
Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
All = set(Ops)

View File

@@ -114,7 +114,7 @@ class MathTrait:
return self._binop(Ops.IDIV, x, reverse)
def mod(self:TMT, x:TMT|ConstType, reverse:bool=False): return self._binop(Ops.MOD, x, reverse)
def sub(self:TMT, x:TMT|ConstType, reverse:bool=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
def div(self:TMT, x:TMT|ConstType, reverse:bool=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
def div(self:TMT, x:TMT|ConstType, reverse:bool=False): return (self.ufix(x)*self.alu(Ops.RECIPROCAL)) if reverse else (self*self.ufix(x).alu(Ops.RECIPROCAL))
def __neg__(self): return self.neg()
@@ -162,7 +162,7 @@ class MathTrait:
if isinstance(y, type(self)): return self.alu(Ops.WHERE, y.ufix(x), y)
raise RuntimeError("where needs at least one UOp arg")
def threefry(self:TMT, seed:TMT): return self.alu(Ops.THREEFRY, seed)
def reciprocal(self): return self.alu(Ops.RECIP)
def reciprocal(self): return self.alu(Ops.RECIPROCAL)
def trunc(self): return self.alu(Ops.TRUNC)
def sqrt(self): return self.alu(Ops.SQRT)
def sin(self): return self.alu(Ops.SIN)

View File

@@ -53,17 +53,23 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
def test_pyrender(u:UOp):
def test_pyrender(test_ast:UOp):
from tinygrad.dtype import AddrSpace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.schedule.rangeify import BufferizeOpts, Kernel
code = pyrender(u)
print(code)
code = pyrender(test_ast)
print("\n\n"+code)
lcls:dict[str, Any] = {"inf": math.inf, "nan": math.nan,
"KernelInfo": KernelInfo, "Kernel": Kernel,
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace}
"KernelInfo": KernelInfo, "Kernel": Kernel,
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace}
exec(code, None, lcls)
if lcls['ast'] is not u: raise RuntimeError(f"PYRENDER ISSUE:\nCODE:\n{code}\nUOP:\n{u}\nPRODUCED:\n{lcls['ast']}")
ast:UOp = lcls['ast']
if ast is not test_ast:
if str(test_ast) == str(ast):
for u1,u2 in zip(list(test_ast.toposort()), list(ast.toposort())):
if u1 is not u2:
raise RuntimeError("STRING SAME, UOP MISMATCH", u1, u2, id(u1), id(u2), id(u1.arg), id(u2.arg))
raise RuntimeError(f"PYRENDER ISSUE:\nCODE:\n{code}\nSTR MATCH: {str(test_ast) == str(ast)}\nUOP:\n{test_ast}\nPRODUCED:\n{ast}")
return code
class UOpMetaClass(type):
@@ -372,7 +378,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs)
def end(self, *src:UOp):
if len(src) == 0: return self
assert all(x.op is Ops.RANGE for x in src), "end only ends ranges"
return UOp(Ops.END, src=(self,)+src)
def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
@@ -390,10 +395,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
@staticmethod
def range(end:sint, *arg, dtype=dtypes.index):
def range(end:sint, *arg, dtype=dtypes.index, **kwargs):
if len(arg) == 0: raise RuntimeError("range needs an arg")
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=arg)
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=arg, **kwargs)
@staticmethod
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
def r(self, op:Ops, axis:tuple[int, ...]):
@@ -526,27 +531,29 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
# in these four, if the shape doesn't change we can return self
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
def reshape(self, arg:tuple[sint, ...], can_fold=True): return self._mop(Ops.RESHAPE, arg, same_shape_noop=can_fold)
def expand(self, arg:tuple[sint, ...], can_fold=True): return self._mop(Ops.EXPAND, arg, same_shape_noop=can_fold)
def shrink(self, arg:tuple[tuple[sint, sint], ...], can_fold=True): return self._mop(Ops.SHRINK, arg, same_shape_noop=can_fold)
def pad(self, arg:tuple[tuple[sint, sint], ...], can_fold=True): return self._mop(Ops.PAD, arg, same_shape_noop=can_fold)
# in these two, we have custom logic to check if they are a no-op
def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self
def permute(self, arg:tuple[int, ...], can_fold=True):
return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) and can_fold else self
def flip(self, arg:tuple[bool, ...], can_fold=True):
return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) and can_fold else self
# *** uop UNIQUE ***
# TODO: use this in Buffer
unique_num = itertools.count(0)
@staticmethod
def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))
def unique(arg:int|None=None): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num) if arg is None else arg)
# *** uop Buffer stuff ***
@staticmethod
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp.unique(), UOp(Ops.DEVICE, arg=device)), size)
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
@property
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
@recursive_property
@@ -764,7 +771,7 @@ def safe_pow(x, y):
python_alu: dict[Ops, Callable] = {
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIPROCAL: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
@@ -1228,7 +1235,7 @@ renderer = PatternMatcher([
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
#(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")),
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.RECIP, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")),
(UPat(Ops.RECIPROCAL, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(1/{x.src[0].arg})")),
(UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
@@ -1267,28 +1274,58 @@ pm_pyrender = PatternMatcher([
])
"""
sugar = { Ops.SINK: "sink" } #, Ops.STORE: "store", Ops.LOAD: "load" }
#, Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce",
# Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
sugar = { Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE,
Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS}
pm_pyrender = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.SINK, src=()), lambda: UOp.sink()),
(UPat(set(sugar.keys()), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{sugar[x.op]}("+', '.join([ctx[y] for y in x.src[1:]] + \
#(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
# f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})"),
#(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
# f"UOp.new_buffer(\"{d.arg}\", {x.size}, {x.dtype}, {u.arg})"),
#(UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), lambda ctx,x,d: f"{ctx[x]}.copy_to_device({repr(d.arg)})"),
#(UPat(Ops.RANGE, name="x"), lambda ctx,x:
# "UOp.range("+', '.join([ctx[x.src[0]]] + [str(y) for y in x.arg])+
# (', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+\
# (', tag='+str(x.tag) if x.tag is not None else '')+")"),
# simplest
#(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.marg}, can_fold=False)"),
(UPat(set(syms.keys()), name="x"), lambda ctx,x: f"({ctx[x.src[0]]}{syms[x.op]}{ctx[x.src[1]]})"),
(UPat(sugar, src=(), name="x"), lambda ctx,x: f"UOp.{x.op.name.lower()}("+', '.join( \
([f'arg={repr(x.arg)}'] if x.arg is not None else []) + ([f'tag={repr(x.tag)}'] if x.tag is not None else []))+")"),
(UPat(sugar, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}("+', '.join([ctx[y] for y in x.src[1:]] + \
([f'arg={repr(x.arg)}'] if x.arg is not None else []) + ([f'tag={repr(x.tag)}'] if x.tag is not None else []))+")"),
(UPat(GroupOp.All, name="u"), lambda ctx,u: "UOp("+', '.join([str(u.op), str(u.dtype)] + \
[f"({ctx[u.src[0]]},)" if len(u.src) == 1 else f"({','.join([ctx[x] for x in u.src])})"] + \
([f"({ctx[u.src[0]]},)"] if len(u.src) == 1 else ([f"({', '.join([ctx[x] for x in u.src])})"] if len(u.src) > 1 else [])) + \
([f"arg={repr(u.arg)}"] if u.arg is not None else []) + ([f"tag={repr(u.tag)}"] if u.tag is not None else []))+")"),
])
@Context(SPEC=0)
def pyrender(ast:UOp) -> str:
cmap = ast.get_consumer_map()
uops = list(ast.toposort())
ret: dict[str, str] = {}
r: dict[UOp, str] = {}
not_rendered = {Ops.CONST}
always_rendered = {Ops.DEFINE_GLOBAL, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY}
to_render = {ast}
for u in uops:
if u.op is Ops.STORE: to_render.add(u.src[1])
if len(cmap[u]) == 1 and u.op not in always_rendered or u.op in not_rendered: continue
if u.op in {Ops.SINK}:
for s in u.src: to_render.add(s)
to_render.add(u)
for i,u in enumerate(uops):
r[u] = f"c{i}" if u is not uops[-1] else "ast"
ret[r[u]] = cast(str, pm_pyrender.rewrite(u, ctx=r))
ren = pm_pyrender.rewrite(u, ctx=r)
assert isinstance(ren, str)
if u not in to_render:
r[u] = ren
else:
r[u] = f"c{i}" if u is not uops[-1] else "ast"
ret[r[u]] = ren
return '\n'.join([f"{k} = {v}" for k,v in ret.items()])
"""