Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 14:58:46 -05:00)
add UPat.any support [run_process_replay] (#6602)
* add UPat.any support [run_process_replay]
* single arange pattern
* no loop_start and loop_end
This commit is contained in:
@@ -254,8 +254,9 @@ def threefry2x32(x: UOp, seed: UOp):
|
||||
|
||||
# ***** main rewriter *****
|
||||
|
||||
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None, vec=None):
|
||||
def loop_collapse(compval, idx, mval, multconst, rng:UOp, reduce, idx2=None, idx3=None, extra=None, vec=None):
|
||||
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
|
||||
loop_start, loop_end = rng.src
|
||||
if mval.arg >= 0 or loop_start.arg != 0:
|
||||
# TODO: support and test this with other mvals and loop_starts
|
||||
if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
|
||||
@@ -340,31 +341,12 @@ constant_folder = PatternMatcher([
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
# threefry
|
||||
(UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
# extra arange loop folding because we don't fold adds. TODO: fold adds
|
||||
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") +
|
||||
UPat.var("idx2") + UPat.var("idx3")).lt(UPat.cvar("compval"))
|
||||
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") +
|
||||
UPat.var("idx2")).lt(UPat.cvar("compval"))
|
||||
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding (reduce)
|
||||
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
|
||||
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding (unrolled)
|
||||
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
|
||||
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding (vectorized)
|
||||
(UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
|
||||
UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
|
||||
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding (unrolled, vectorized)
|
||||
(UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
|
||||
UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
|
||||
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding
|
||||
(UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any(
|
||||
m1:=(UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, name="rng")),
|
||||
m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1))
|
||||
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# unrolled arange div folding
|
||||
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
|
||||
# indexing (with a multiply offset)!
|
||||
|
||||
@@ -382,14 +382,14 @@ def lines(fn) -> List[str]:
|
||||
with open(fn) as f: return f.readlines()
|
||||
|
||||
class UPat(MathTrait):
|
||||
__slots__ = ["op", "dtype", "arg", "name", "src"]
|
||||
__slots__ = ["op", "dtype", "arg", "name", "src", "_any"]
|
||||
def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
|
||||
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None, _any=False,
|
||||
custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
|
||||
self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
|
||||
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
|
||||
self.arg, self.name = arg, name
|
||||
self.arg, self.name, self._any = arg, name, _any
|
||||
self.src: Any = None
|
||||
|
||||
# try all permutations if it's a list
|
||||
@@ -407,6 +407,9 @@ class UPat(MathTrait):
|
||||
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
|
||||
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
|
||||
|
||||
@staticmethod
|
||||
def any(*src): return UPat(src=src, _any=True)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name)
|
||||
@@ -446,6 +449,10 @@ class UPat(MathTrait):
|
||||
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
|
||||
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
if pat._any:
|
||||
for x in pat.src[0]:
|
||||
if (match:=_match(uop, x, store.copy())): return match
|
||||
return []
|
||||
if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
|
||||
(pat.dtype is not None and uop.dtype not in pat.dtype) or \
|
||||
(pat.arg is not None and pat.arg != uop.arg) or \
|
||||
|
||||
Reference in New Issue
Block a user