diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index ca9f761925..e4e3823c3e 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -254,8 +254,9 @@ def threefry2x32(x: UOp, seed: UOp): # ***** main rewriter ***** -def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None, vec=None): +def loop_collapse(compval, idx, mval, multconst, rng:UOp, reduce, idx2=None, idx3=None, extra=None, vec=None): if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE + loop_start, loop_end = rng.src if mval.arg >= 0 or loop_start.arg != 0: # TODO: support and test this with other mvals and loop_starts if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}") @@ -340,31 +341,12 @@ constant_folder = PatternMatcher([ lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), # threefry (UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32), - # extra arange loop folding because we don't fold adds. TODO: fold adds - (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") + - UPat.var("idx2") + UPat.var("idx3")).lt(UPat.cvar("compval")) - .where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), - (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") + - UPat.var("idx2")).lt(UPat.cvar("compval")) - .where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), - # arange loop folding (reduce) - (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng")) - .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),), - arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), - # arange loop folding (unrolled) - (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng")) - .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),), - arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), - # arange loop folding (vectorized) - (UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) * - UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng"))) - .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),), - arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), - # arange loop folding (unrolled, vectorized) - (UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) * - UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng"))) - .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),), - arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), + # arange loop folding + (UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any( + m1:=(UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, name="rng")), + m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1)) + .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),), + arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), # unrolled arange div folding (UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs), # indexing (with a multiply offset)! diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ea9d873657..c54f154bea 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -382,14 +382,14 @@ def lines(fn) -> List[str]: with open(fn) as f: return f.readlines() class UPat(MathTrait): - __slots__ = ["op", "dtype", "arg", "name", "src"] + __slots__ = ["op", "dtype", "arg", "name", "src", "_any"] def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None, - name:Optional[str]=None, allow_any_len:bool=False, location=None, + name:Optional[str]=None, allow_any_len:bool=False, location=None, _any=False, custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None): self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype - self.arg, self.name = arg, name + self.arg, self.name, self._any = arg, name, _any self.src: Any = None # try all permutations if it's a list @@ -407,6 +407,9 @@ class UPat(MathTrait): upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0]) self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1) + @staticmethod + def any(*src): return UPat(src=src, _any=True) + @staticmethod @functools.lru_cache(None) def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name) @@ -446,6 +449,10 @@ class UPat(MathTrait): return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0]) def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: + if pat._any: + for x in pat.src[0]: + if (match:=_match(uop, x, store.copy())): return match + return [] if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \ (pat.dtype is not None and uop.dtype not in pat.dtype) or \ (pat.arg is not None and pat.arg != uop.arg) or \