This commit is contained in:
George Hotz
2025-10-27 16:40:50 +08:00
parent 97a4296d8b
commit 4c63cf3914

View File

@@ -371,6 +371,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
# NOTE: float('nan') != float('nan'), so we canonicalize here
if isinstance(b, float) and math.isnan(b): b = math.nan
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,))
if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
@@ -484,7 +486,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
match self.op:
case Ops.CONST: return self.arg
case Ops.VCONST: return self.arg[i]
case Ops.VECTORIZE: return cast(sint, self.src[i].ssimplify())
case Ops.VECTORIZE: return self.src[i]
case _: raise RuntimeError(f"no sgep on {self.op}")
@functools.cached_property
@@ -1231,52 +1233,44 @@ renderer_infer = PatternMatcher([
*renderer.patterns
])
"""
sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce",
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
pm_pyrender = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
(UPat(Ops.END, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.end({', '.join([y.arg for y in x.src[1:]])})")),
(UPat(Ops.CAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.cast({x.dtype})")),
(UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")),
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")),
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=
f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
(UPat(Ops.SPECIAL, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg= f"UOp.special({x.src[0].arg}, \"{x.arg}\", dtype={x.dtype})")),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: UOp(Ops.NOOP, arg=
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
(UPat(set(sugar.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP,
arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, arg=({', '.join([str(y) for y in x.arg])}))")),
])
"""
# *** pyrender ***
sugar = { Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER}
def render_marg(ctx,x:UOp):
if x.op in {Ops.PERMUTE, Ops.FLIP}: return str(x.marg)
pieces = []
if x.op in {Ops.RESHAPE, Ops.EXPAND}:
pieces = [f"{ctx[a] if isinstance(a, UOp) else str(a)}" for a in x.marg]
if x.op in {Ops.PAD, Ops.SHRINK}:
pieces = [f"({ctx[a[0]] if isinstance(a[0], UOp) else str(a[0])}, {ctx[a[1]] if isinstance(a[1], UOp) else str(a[1])})" for a in x.marg]
return f"({','.join(pieces)})" if len(pieces) != 1 else f"({pieces[0]},)"
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})"),
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
f"UOp.new_buffer(\"{d.arg}\", {x.size}, {x.dtype}, {u.arg})"),
f"UOp.new_buffer({repr(d.arg)}, {x.size}, {x.dtype}, {u.arg})"),
(UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), lambda ctx,x,d: f"{ctx[x]}.copy_to_device({repr(d.arg)})"),
# NOTE: range has srcs sometimes after control flow
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), name="x"), lambda ctx,x,c:
"UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+
(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+\
(', tag='+str(x.tag) if x.tag is not None else '')+")"),
# simplest
# TODO: index shouldn't mismatch dtype
(UPat(Ops.INDEX, src=(UPat(), UPat()), name="x"), lambda ctx,x:
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, dtype={x.dtype})" if x.src[0].dtype != x.dtype else None),
# TODO: fix forced_reshape
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({x.marg})"),
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.marg})"),
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})"),
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
# TODO: why don't these work?
#(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(name="y"), UPat(Ops.CONST, name="z")), name="x"), lambda ctx,x,y,z: f"({ctx[y]}{syms[x.op]}{z.arg})"),
#(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(Ops.CONST, name="y"), UPat(name="z")), name="x"), lambda ctx,x,y,z: f"({y.arg}{syms[x.op]}{ctx[z]})"),
# NOTE: sub doesn't work cause it's written as add/mul
(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(name="y"), UPat(Ops.CONST, name="z")), name="x"), lambda ctx,x,y,z: f"({ctx[y]}{syms[x.op]}{z.arg})"),
(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(Ops.CONST, name="y"), UPat(name="z")), name="x"), lambda ctx,x,y,z: f"({y.arg}{syms[x.op]}{ctx[z]})"),
(UPat(set(syms.keys())-{Ops.SUB}, name="x"), lambda ctx,x: f"({ctx[x.src[0]]}{syms[x.op]}{ctx[x.src[1]]})"),
(UPat(sugar, src=(), name="x"), lambda x: f"UOp.{x.op.name.lower()}("+', '.join( \
([f'arg={repr(x.arg)}'] if x.arg is not None else []) + ([f'tag={repr(x.tag)}'] if x.tag is not None else []))+")"),
@@ -1284,13 +1278,13 @@ pm_pyrender_extra = PatternMatcher([
([f'arg={repr(x.arg)}'] if x.arg is not None else []) + ([f'tag={repr(x.tag)}'] if x.tag is not None else []))+")"),
])
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
pm_pyrender = pm_pyrender_extra+PatternMatcher([
(UPat(GroupOp.All, name="u"), lambda ctx,u: "UOp("+', '.join([str(u.op), str(u.dtype)] + \
([f"({ctx[u.src[0]]},)"] if len(u.src) == 1 else ([f"({', '.join([ctx[x] for x in u.src])})"] if len(u.src) > 1 else [])) + \
([f"arg={repr(u.arg)}"] if u.arg is not None else []) + ([f"tag={repr(u.tag)}"] if u.tag is not None else []))+")"),
])
@Context(SPEC=0)
def pyrender(ast:UOp) -> str:
cmap = ast.get_consumer_map()
uops = list(ast.toposort())
@@ -1298,7 +1292,7 @@ def pyrender(ast:UOp) -> str:
r: dict[UOp, str] = {}
not_rendered = {Ops.CONST}
always_rendered = {Ops.DEFINE_GLOBAL, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY}
always_rendered = {Ops.DEFINE_GLOBAL, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.WHERE}
to_render = {ast}
for u in uops:
if u.op is Ops.STORE: to_render.add(u.src[1])
@@ -1328,11 +1322,13 @@ def eval_pyrender(code:str) -> UOp:
exec(code, None, lcls)
return lcls['ast']
def test_pyrender(test_ast:UOp):
def test_pyrender(test_ast:UOp, check_parents=True):
code = pyrender(test_ast)
print("\n\n"+code)
ast:UOp = eval_pyrender(code)
if ast is not test_ast:
if check_parents:
for u in test_ast.toposort(): test_pyrender(u, check_parents=False)
raise RuntimeError(f"PYRENDER ISSUE:\nSTR MATCH: {str(test_ast) == str(ast)}\nUOP:\n{test_ast}\nPRODUCED:\n{ast}\nCODE:\n{code}")
return code