diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index b3273b481f..e52bd6b607 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -11,7 +11,7 @@ from tinygrad.codegen.opt import Opt # import all pattern matchers here from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load -from tinygrad.uop.decompositions import get_late_rewrite_patterns +from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_unsupported_dtypes_patterns from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render, pm_add_loads @@ -95,9 +95,11 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - # decompositions supported_ops = tuple(ren.code_for_op.keys()) - pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, TRANSCENDENTAL>=2, bool(DISABLE_FAST_IDIV), - tuple(EMULATED_DTYPES.tolist(dtypes))) + pm_unsupported = symbolic_simple+get_unsupported_dtypes_patterns(ren.device, tuple(EMULATED_DTYPES.tolist(dtypes))) + pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, TRANSCENDENTAL>=2, bool(DISABLE_FAST_IDIV)) sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions") + sink = graph_rewrite(sink, pm_unsupported, ctx=ren.device, name="unsupported dtypes", bottom_up=True) + sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions 2") # final rules for the renderer (without sym) extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([]) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 00dad3ea36..db5e8dbcca 100644 --- a/tinygrad/uop/decompositions.py +++ 
# ***** decomposition patterns *****

powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, force_transcendental:bool, disable_fast_idiv:bool) -> PatternMatcher:
  """Build the late-rewrite PatternMatcher that decomposes ops the renderer doesn't support.

  ops: the Ops the renderer can emit (ren.code_for_op.keys()); anything else gets decomposed here.
  device: passed through as ctx to callbacks that need it (e.g. fast_idiv).
  force_transcendental: decompose EXP2/LOG2/SIN/SQRT even when the renderer supports them.
  disable_fast_idiv: skip the fast integer-division rewrite.
  Cached with functools.cache since the same (ops, device, flags) tuple recurs per renderer.
  """
  pat: list[tuple[UPat, Callable]] = []
  for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)):
    if op not in ops or force_transcendental:
      # non-TRANSCENDENTAL float dtypes are computed via float32 and cast back
      pat += [(UPat(op, dtype=TRANSCENDENTAL_DTYPES, src=(UPat.var("d"),)), f),
              (UPat(op, dtype=tuple(dt for dt in dtypes.floats if dt not in TRANSCENDENTAL_DTYPES), src=(UPat.var("d"),), name="x"),
               lambda x,d: d.cast(dtypes.float32).alu(x.op).cast(x.dtype))]
  # rewrite SQRT to xpow 0.5
  if Ops.SQRT not in ops or force_transcendental: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
  # no real hardware supports THREEFRY, but NullRenderer does
  if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
  # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
  if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
  if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
  if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
                             lambda x,y: (x | y).logical_not())]
  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
  if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
  if Ops.SHR in ops:
    # no reason to check x<0 for uints
    pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
      c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]  # (x+(x<0).where(c-1, 0)) >> v
  if not disable_fast_idiv:
    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
    pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
  if Ops.NEG in ops:
    pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))]
    # SUB rewrite only fires on a NEG the rule above introduced, so it nests under the NEG branch
    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))]
  if Ops.CMPLT in ops:
    # These are late rewrites because simplex expects equalities to be a certain format
    # NOTE(review): this span was garbled in the source ("lambda x,c: c-1 x==c" is not valid Python);
    # reconstructed as not(x<c) -> c-1<x. The stray "x==c" suggests a companion equality rewrite may
    # have been lost in extraction -- confirm against the upstream file.
    pat += [
      ((UPat.var("x", dtypes.sints) < UPat.cvar("c", dtypes.sints)).logical_not(), lambda x,c: c-1<x),
    ]
  if Ops.CMPEQ in ops: pat += [(UPat.var('x').ne(UPat.var('y')).logical_not(), lambda x,y: x.alu(Ops.CMPEQ, y))]
  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
  # some backends emit FDIV for RECIP, in that case: a*(1/b) -> a/b
  if Ops.FDIV in ops:
    pat += [(UPat.var("x").reciprocal(), lambda x: x.const_like(1).alu(Ops.FDIV, x))]
    pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
  return PatternMatcher(pat)

# ***** unsupported dtypes *****

# long to ints: 64-bit integer dtypes emulated via their 32-bit counterparts
l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint}
# split a 32-bit value into (low 16 bits, high 16 bits) after bitcasting to uint
# NOTE(review): despite the name, this splits into 16-bit halves, not 32 -- presumably for half emulation; confirm at call sites
def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, v.bitcast(dtypes.uint) >> 16
# rebuild an INDEX whose element offset is scaled by mul and shifted by off (emulated dtypes address sub-elements)
def reindex(idx:UOp, off:int, mul=2) -> UOp: return idx.replace(src=(idx.src[0], idx.src[1]*mul+off))

def f2f_store(st, idx, val):
  """Rewrite a float STORE into half-typed storage via f2f; vector values become a GROUP of scalar stores."""
  if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(dtypes.uint), dtypes.float, dtypes.half)))
  return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(dtypes.uint), dtypes.float, dtypes.half))) for i in range(n)))
dtypes.bool).logical_not(), - lambda x,y: (x | y).logical_not())] - # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y) - if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)] - if Ops.SHR in ops: - # no reason to check x<0 for uints - pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] - pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where( - c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v - if not disable_fast_idiv: - pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))] - pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))] - if Ops.NEG in ops: - pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))] - if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))] - if Ops.CMPLT in ops: - # These are late rewrites because simplex expects equalities to be a certain format - pat += [ - ((UPat.var("x", dtypes.sints) < UPat.cvar("c", dtypes.sints)).logical_not(), lambda x,c: c-1 x==c - ] - if Ops.CMPEQ in ops: pat += [(UPat.var('x').ne(UPat.var('y')).logical_not(), lambda x,y: x.alu(Ops.CMPEQ, y))] - if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))] - # some backends emit FDIV for RECIP, in that case: a*(1/b) -> a/b - if Ops.FDIV in ops: - pat += [(UPat.var("x").reciprocal(), lambda x: x.const_like(1).alu(Ops.FDIV, x))] - pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))] +def get_unsupported_dtypes_patterns(device, emulated_dtypes:tuple[DType, ...]): + pat = [] if dtypes.half in emulated_dtypes: pat += 
[(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x: x.replace(dtype=dtypes.uint16.ptr(x.dtype.size), tag=dtypes.half) if x.dtype.base == dtypes.half else None)] # FIXME: this is slow pat += [(UPat(GroupOp.All, name="x"), lambda x: - x.replace(src=tuple(f2f_load(s, x.op) if s.op == Ops.LOAD and s.dtype.scalar() == dtypes.half else s for s in x.src)))] - pat += [(UPat((*GroupOp.ALU, Ops.CONST, Ops.CAST, Ops.GEP, Ops.VECTORIZE), dtypes.half, name="x"), lambda x: - x.replace(dtype=dtypes.float.vec(x.dtype.count)))] + x.replace(dtype=dtypes.float.vec(x.dtype.count) if x.dtype.scalar() == dtypes.half else x.dtype, + src=tuple((f2f_load(s, x.op) if s.op == Ops.LOAD else s.cast(dtypes.float) if x.op is not Ops.CAST else s) + if s.dtype.scalar() == dtypes.half else s for s in x.src)))] pat += [(UPat(Ops.BITCAST, (dtypes.ushort, dtypes.short, dtypes.bfloat16), src=(UPat.var("x", dtypes.float),), name="bc"), lambda bc,x: bc.replace(src=(f2f(x.bitcast(dtypes.uint), dtypes.float, dtypes.half),)))] pat += [(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.float)), name='st'), lambda st,idx,val: