mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
work
This commit is contained in:
@@ -371,6 +371,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None):
|
||||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
# NOTE: float('nan') != float('nan'), so we canonicalize here
|
||||
if isinstance(b, float) and math.isnan(b): b = math.nan
|
||||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,))
|
||||
if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
|
||||
@@ -484,7 +486,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
match self.op:
|
||||
case Ops.CONST: return self.arg
|
||||
case Ops.VCONST: return self.arg[i]
|
||||
case Ops.VECTORIZE: return cast(sint, self.src[i].ssimplify())
|
||||
case Ops.VECTORIZE: return self.src[i]
|
||||
case _: raise RuntimeError(f"no sgep on {self.op}")
|
||||
|
||||
@functools.cached_property
|
||||
@@ -1231,52 +1233,44 @@ renderer_infer = PatternMatcher([
|
||||
*renderer.patterns
|
||||
])
|
||||
|
||||
"""
|
||||
sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce",
|
||||
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
|
||||
pm_pyrender = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
|
||||
(UPat(Ops.END, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.end({', '.join([y.arg for y in x.src[1:]])})")),
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.cast({x.dtype})")),
|
||||
(UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")),
|
||||
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=
|
||||
f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
|
||||
(UPat(Ops.SPECIAL, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg= f"UOp.special({x.src[0].arg}, \"{x.arg}\", dtype={x.dtype})")),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: UOp(Ops.NOOP, arg=
|
||||
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
|
||||
(UPat(set(sugar.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP,
|
||||
arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, arg=({', '.join([str(y) for y in x.arg])}))")),
|
||||
])
|
||||
"""
|
||||
# *** pyrender ***
|
||||
|
||||
sugar = { Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER,
|
||||
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER}
|
||||
def render_marg(ctx,x:UOp):
|
||||
if x.op in {Ops.PERMUTE, Ops.FLIP}: return str(x.marg)
|
||||
pieces = []
|
||||
if x.op in {Ops.RESHAPE, Ops.EXPAND}:
|
||||
pieces = [f"{ctx[a] if isinstance(a, UOp) else str(a)}" for a in x.marg]
|
||||
if x.op in {Ops.PAD, Ops.SHRINK}:
|
||||
pieces = [f"({ctx[a[0]] if isinstance(a[0], UOp) else str(a[0])}, {ctx[a[1]] if isinstance(a[1], UOp) else str(a[1])})" for a in x.marg]
|
||||
return f"({','.join(pieces)})" if len(pieces) != 1 else f"({pieces[0]},)"
|
||||
|
||||
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER,
|
||||
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER}
|
||||
pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
|
||||
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})"),
|
||||
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
|
||||
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
|
||||
f"UOp.new_buffer(\"{d.arg}\", {x.size}, {x.dtype}, {u.arg})"),
|
||||
f"UOp.new_buffer({repr(d.arg)}, {x.size}, {x.dtype}, {u.arg})"),
|
||||
(UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), lambda ctx,x,d: f"{ctx[x]}.copy_to_device({repr(d.arg)})"),
|
||||
# NOTE: range has srcs sometimes after control flow
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), name="x"), lambda ctx,x,c:
|
||||
"UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+
|
||||
(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+\
|
||||
(', tag='+str(x.tag) if x.tag is not None else '')+")"),
|
||||
# simplest
|
||||
# TODO: index shouldn't mismatch dtype
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat()), name="x"), lambda ctx,x:
|
||||
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, dtype={x.dtype})" if x.src[0].dtype != x.dtype else None),
|
||||
# TODO: fix forced_reshape
|
||||
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({x.marg})"),
|
||||
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.marg})"),
|
||||
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})"),
|
||||
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
|
||||
# TODO: why don't these work?
|
||||
#(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(name="y"), UPat(Ops.CONST, name="z")), name="x"), lambda ctx,x,y,z: f"({ctx[y]}{syms[x.op]}{z.arg})"),
|
||||
#(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(Ops.CONST, name="y"), UPat(name="z")), name="x"), lambda ctx,x,y,z: f"({y.arg}{syms[x.op]}{ctx[z]})"),
|
||||
# NOTE: sub doesn't work cause it's written as add/mul
|
||||
(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(name="y"), UPat(Ops.CONST, name="z")), name="x"), lambda ctx,x,y,z: f"({ctx[y]}{syms[x.op]}{z.arg})"),
|
||||
(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(Ops.CONST, name="y"), UPat(name="z")), name="x"), lambda ctx,x,y,z: f"({y.arg}{syms[x.op]}{ctx[z]})"),
|
||||
(UPat(set(syms.keys())-{Ops.SUB}, name="x"), lambda ctx,x: f"({ctx[x.src[0]]}{syms[x.op]}{ctx[x.src[1]]})"),
|
||||
(UPat(sugar, src=(), name="x"), lambda x: f"UOp.{x.op.name.lower()}("+', '.join( \
|
||||
([f'arg={repr(x.arg)}'] if x.arg is not None else []) + ([f'tag={repr(x.tag)}'] if x.tag is not None else []))+")"),
|
||||
@@ -1284,13 +1278,13 @@ pm_pyrender_extra = PatternMatcher([
|
||||
([f'arg={repr(x.arg)}'] if x.arg is not None else []) + ([f'tag={repr(x.tag)}'] if x.tag is not None else []))+")"),
|
||||
])
|
||||
|
||||
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
|
||||
pm_pyrender = pm_pyrender_extra+PatternMatcher([
|
||||
(UPat(GroupOp.All, name="u"), lambda ctx,u: "UOp("+', '.join([str(u.op), str(u.dtype)] + \
|
||||
([f"({ctx[u.src[0]]},)"] if len(u.src) == 1 else ([f"({', '.join([ctx[x] for x in u.src])})"] if len(u.src) > 1 else [])) + \
|
||||
([f"arg={repr(u.arg)}"] if u.arg is not None else []) + ([f"tag={repr(u.tag)}"] if u.tag is not None else []))+")"),
|
||||
])
|
||||
|
||||
@Context(SPEC=0)
|
||||
def pyrender(ast:UOp) -> str:
|
||||
cmap = ast.get_consumer_map()
|
||||
uops = list(ast.toposort())
|
||||
@@ -1298,7 +1292,7 @@ def pyrender(ast:UOp) -> str:
|
||||
r: dict[UOp, str] = {}
|
||||
|
||||
not_rendered = {Ops.CONST}
|
||||
always_rendered = {Ops.DEFINE_GLOBAL, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY}
|
||||
always_rendered = {Ops.DEFINE_GLOBAL, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.WHERE}
|
||||
to_render = {ast}
|
||||
for u in uops:
|
||||
if u.op is Ops.STORE: to_render.add(u.src[1])
|
||||
@@ -1328,11 +1322,13 @@ def eval_pyrender(code:str) -> UOp:
|
||||
exec(code, None, lcls)
|
||||
return lcls['ast']
|
||||
|
||||
def test_pyrender(test_ast:UOp):
|
||||
def test_pyrender(test_ast:UOp, check_parents=True):
|
||||
code = pyrender(test_ast)
|
||||
print("\n\n"+code)
|
||||
ast:UOp = eval_pyrender(code)
|
||||
if ast is not test_ast:
|
||||
if check_parents:
|
||||
for u in test_ast.toposort(): test_pyrender(u, check_parents=False)
|
||||
raise RuntimeError(f"PYRENDER ISSUE:\nSTR MATCH: {str(test_ast) == str(ast)}\nUOP:\n{test_ast}\nPRODUCED:\n{ast}\nCODE:\n{code}")
|
||||
return code
|
||||
|
||||
|
||||
Reference in New Issue
Block a user