very ugly, but passes spec

This commit is contained in:
Christopher Milan
2026-02-03 10:59:24 -08:00
parent c7ec0a5538
commit fdc3999b65
2 changed files with 65 additions and 57 deletions

View File

@@ -11,7 +11,7 @@ from tinygrad.codegen.opt import Opt
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_unsupported_dtypes_patterns
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads
@@ -95,9 +95,11 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, TRANSCENDENTAL>=2, bool(DISABLE_FAST_IDIV),
tuple(EMULATED_DTYPES.tolist(dtypes)))
pm_unsupported = symbolic_simple+get_unsupported_dtypes_patterns(ren.device, tuple(EMULATED_DTYPES.tolist(dtypes)))
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, TRANSCENDENTAL>=2, bool(DISABLE_FAST_IDIV))
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
sink = graph_rewrite(sink, pm_unsupported, ctx=ren.device, name="unsupported dtypes", bottom_up=True)
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions 2")
# final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])

View File

@@ -315,8 +315,62 @@ def threefry2x32(x: UOp, key: UOp):
return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
# ***** long as 2 ints *****
# ***** decomposition patterns *****
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, force_transcendental:bool, disable_fast_idiv:bool) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)):
if op not in ops or force_transcendental:
pat += [(UPat(op, dtype=TRANSCENDENTAL_DTYPES, src=(UPat.var("d"),)), f),
(UPat(op, dtype=tuple(dt for dt in dtypes.floats if dt not in TRANSCENDENTAL_DTYPES), src=(UPat.var("d"),), name="x"),
lambda x,d: d.cast(dtypes.float32).alu(x.op).cast(x.dtype))]
# rewrite SQRT to xpow 0.5
if Ops.SQRT not in ops or force_transcendental: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
# no real hardware supports THREEFRY, but NullRenderer does
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())]
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
if Ops.SHR in ops:
# no reason to check x<0 for uints
pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
if not disable_fast_idiv:
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))]
if Ops.CMPLT in ops:
# These are late rewrites because simplex expects equalities to be a certain format
pat += [
((UPat.var("x", dtypes.sints) < UPat.cvar("c", dtypes.sints)).logical_not(), lambda x,c: c-1<x),
((UPat.cvar("c", dtypes.sints) < UPat.var("x", dtypes.sints)).logical_not(), lambda x,c: x<c+1),
(UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*UPat.cvar("c"), lambda x,y,c: y*(-c)<x),
(UPat.var("x", dtypes.sints)*-1 < UPat.cvar("c"), lambda x,c:-c<x),
((UPat.cvar("c1",vec=False)<UPat.var("x", dtypes.sints)) & (UPat.var("x", dtypes.sints)<UPat.cvar("c2",vec=False)),
lambda x,c1,c2: x.eq(c1+1) if c1.arg+1==c2.arg-1 else None), # (c-1)<x & x<(c+1) -> x==c
]
if Ops.CMPEQ in ops: pat += [(UPat.var('x').ne(UPat.var('y')).logical_not(), lambda x,y: x.alu(Ops.CMPEQ, y))]
if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
# some backends emit FDIV for RECIP, in that case: a*(1/b) -> a/b
if Ops.FDIV in ops:
pat += [(UPat.var("x").reciprocal(), lambda x: x.const_like(1).alu(Ops.FDIV, x))]
pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
return PatternMatcher(pat)
# ***** unsupported dtypes *****
# long to ints
l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint}
def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, v.bitcast(dtypes.uint) >> 16
def reindex(idx:UOp, off:int, mul=2) -> UOp: return idx.replace(src=(idx.src[0], idx.src[1]*mul+off))
@@ -407,65 +461,17 @@ def f2f_store(st, idx, val):
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(dtypes.uint), dtypes.float, dtypes.half)))
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(dtypes.uint), dtypes.float, dtypes.half))) for i in range(n)))
# ***** decomposition patterns *****
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, force_transcendental:bool, disable_fast_idiv:bool,
emulated_dtypes:tuple[DType, ...]) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)):
if op not in ops or force_transcendental:
pat += [(UPat(op, dtype=TRANSCENDENTAL_DTYPES, src=(UPat.var("d"),)), f),
(UPat(op, dtype=tuple(dt for dt in dtypes.floats if dt not in TRANSCENDENTAL_DTYPES), src=(UPat.var("d"),), name="x"),
lambda x,d: d.cast(dtypes.float32).alu(x.op).cast(x.dtype))]
# rewrite SQRT to xpow 0.5
if Ops.SQRT not in ops or force_transcendental: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
# no real hardware supports THREEFRY, but NullRenderer does
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())]
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
if Ops.SHR in ops:
# no reason to check x<0 for uints
pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
if not disable_fast_idiv:
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))]
if Ops.CMPLT in ops:
# These are late rewrites because simplex expects equalities to be a certain format
pat += [
((UPat.var("x", dtypes.sints) < UPat.cvar("c", dtypes.sints)).logical_not(), lambda x,c: c-1<x),
((UPat.cvar("c", dtypes.sints) < UPat.var("x", dtypes.sints)).logical_not(), lambda x,c: x<c+1),
(UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*UPat.cvar("c"), lambda x,y,c: y*(-c)<x),
(UPat.var("x", dtypes.sints)*-1 < UPat.cvar("c"), lambda x,c:-c<x),
((UPat.cvar("c1",vec=False)<UPat.var("x", dtypes.sints)) & (UPat.var("x", dtypes.sints)<UPat.cvar("c2",vec=False)),
lambda x,c1,c2: x.eq(c1+1) if c1.arg+1==c2.arg-1 else None), # (c-1)<x & x<(c+1) -> x==c
]
if Ops.CMPEQ in ops: pat += [(UPat.var('x').ne(UPat.var('y')).logical_not(), lambda x,y: x.alu(Ops.CMPEQ, y))]
if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
# some backends emit FDIV for RECIP, in that case: a*(1/b) -> a/b
if Ops.FDIV in ops:
pat += [(UPat.var("x").reciprocal(), lambda x: x.const_like(1).alu(Ops.FDIV, x))]
pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
def get_unsupported_dtypes_patterns(device, emulated_dtypes:tuple[DType, ...]):
pat = []
if dtypes.half in emulated_dtypes:
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
x.replace(dtype=dtypes.uint16.ptr(x.dtype.size), tag=dtypes.half) if x.dtype.base == dtypes.half else None)]
# FIXME: this is slow
pat += [(UPat(GroupOp.All, name="x"), lambda x:
x.replace(src=tuple(f2f_load(s, x.op) if s.op == Ops.LOAD and s.dtype.scalar() == dtypes.half else s for s in x.src)))]
pat += [(UPat((*GroupOp.ALU, Ops.CONST, Ops.CAST, Ops.GEP, Ops.VECTORIZE), dtypes.half, name="x"), lambda x:
x.replace(dtype=dtypes.float.vec(x.dtype.count)))]
x.replace(dtype=dtypes.float.vec(x.dtype.count) if x.dtype.scalar() == dtypes.half else x.dtype,
src=tuple((f2f_load(s, x.op) if s.op == Ops.LOAD else s.cast(dtypes.float) if x.op is not Ops.CAST else s)
if s.dtype.scalar() == dtypes.half else s for s in x.src)))]
pat += [(UPat(Ops.BITCAST, (dtypes.ushort, dtypes.short, dtypes.bfloat16), src=(UPat.var("x", dtypes.float),), name="bc"), lambda bc,x:
bc.replace(src=(f2f(x.bitcast(dtypes.uint), dtypes.float, dtypes.half),)))]
pat += [(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.float)), name='st'), lambda st,idx,val: