more expand prereqs [run_process_replay] (#6499)

This commit is contained in:
George Hotz
2024-09-12 17:46:12 +08:00
committed by GitHub
parent 327eb12600
commit 9543e4c92e

View File

@@ -234,6 +234,8 @@ constant_folder = PatternMatcher([
lambda gep, vec: UOp(UOps.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# GEP add push
(UPat(UOps.GEP, None, (UPat.var('x') + UPat.cvar('c1'),), name="gep"), lambda x,c1,gep: x.gep(gep.arg) + c1.gep(gep.arg)),
# tensor core with a 0 input is acc
*[(UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],
*[(UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],
@@ -294,8 +296,8 @@ constant_folder = PatternMatcher([
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c"), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c"), lambda x,c: c if c.arg else x),
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
# ** zero folding **
# x*0 -> 0 or 0*x -> 0
# if x is nan or inf it should render the nan value.
@@ -314,7 +316,7 @@ constant_folder = PatternMatcher([
((UPat.cvar("c0", vec=False)*UPat.var("x")).lt(UPat.cvar("c1", vec=False)),
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
# c0*x<c1 for negative int c0 and non-positive c1
((UPat.cvar("c0")*UPat.var("x")).lt(UPat.cvar("c1")),
((UPat.cvar("c0", vec=False)*UPat.var("x")).lt(UPat.cvar("c1", vec=False)),
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if dtypes.is_int(x.dtype) and c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# mul add lt
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
@@ -363,6 +365,10 @@ constant_folder = PatternMatcher([
# remove NOOPs from SINK
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
# remove EXPANDs from SINK
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, tuple(flatten(x.src if x.op in {UOps.SINK, UOps.EXPAND} else (x,) for x in root.src)), root.arg)
if any(x.op in {UOps.SINK, UOps.EXPAND} for x in root.src) else None),
# ** move add consts to end (NOTE: this is still happening before constant folding) **
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x+c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))), UPat.var("y")]),