add UPat.any support [run_process_replay] (#6602)

* add UPat.any support [run_process_replay]

* single arange pattern

* no loop_start and loop_end
This commit is contained in:
George Hotz
2024-09-19 17:11:24 +08:00
committed by GitHub
parent d06b36e527
commit 718ecad2ee
2 changed files with 18 additions and 29 deletions

View File

@@ -254,8 +254,9 @@ def threefry2x32(x: UOp, seed: UOp):
# ***** main rewriter *****
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None, vec=None):
def loop_collapse(compval, idx, mval, multconst, rng:UOp, reduce, idx2=None, idx3=None, extra=None, vec=None):
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
loop_start, loop_end = rng.src
if mval.arg >= 0 or loop_start.arg != 0:
# TODO: support and test this with other mvals and loop_starts
if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
@@ -340,31 +341,12 @@ constant_folder = PatternMatcher([
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
(UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
# extra arange loop folding because we don't fold adds. TODO: fold adds
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") +
UPat.var("idx2") + UPat.var("idx3")).lt(UPat.cvar("compval"))
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") +
UPat.var("idx2")).lt(UPat.cvar("compval"))
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (reduce)
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (unrolled)
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (vectorized)
(UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (unrolled, vectorized)
(UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding
(UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any(
m1:=(UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, name="rng")),
m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# unrolled arange div folding
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
# indexing (with a multiply offset)!

View File

@@ -382,14 +382,14 @@ def lines(fn) -> List[str]:
with open(fn) as f: return f.readlines()
class UPat(MathTrait):
__slots__ = ["op", "dtype", "arg", "name", "src"]
__slots__ = ["op", "dtype", "arg", "name", "src", "_any"]
def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None, _any=False,
custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
self.arg, self.name = arg, name
self.arg, self.name, self._any = arg, name, _any
self.src: Any = None
# try all permutations if it's a list
@@ -407,6 +407,9 @@ class UPat(MathTrait):
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
@staticmethod
def any(*src): return UPat(src=src, _any=True)
@staticmethod
@functools.lru_cache(None)
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name)
@@ -446,6 +449,10 @@ class UPat(MathTrait):
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
if pat._any:
for x in pat.src[0]:
if (match:=_match(uop, x, store.copy())): return match
return []
if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
(pat.dtype is not None and uop.dtype not in pat.dtype) or \
(pat.arg is not None and pat.arg != uop.arg) or \