mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
only match with op, not arg [pr] (#7534)
This commit is contained in:
@@ -404,7 +404,7 @@ expander = PatternMatcher([
|
||||
lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
|
||||
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand),
|
||||
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([Ops.EXPAND])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# vectorize DEFINE_ACC
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
|
||||
|
||||
@@ -531,8 +531,7 @@ class UPat(MathTrait):
|
||||
__slots__ = ["op", "dtype", "arg", "name", "src"]
|
||||
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...], Set[Ops]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
|
||||
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None,
|
||||
custom_early_reject:Optional[Set[Tuple[Ops, Any]]]=None):
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Ops]]=None):
|
||||
assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
|
||||
self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
|
||||
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
|
||||
@@ -553,7 +552,7 @@ class UPat(MathTrait):
|
||||
if custom_early_reject is not None: self.early_reject = custom_early_reject
|
||||
else:
|
||||
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
|
||||
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
|
||||
self.early_reject = set(pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1)
|
||||
|
||||
def named(self, name:str):
  """Return a new UPat built from this pattern's op, dtype, src, arg and length rule, with `name` as its name."""
  # custom_early_reject is passed by keyword: the 7th positional parameter of UPat.__init__ is `location`,
  # so passing it positionally bound the reject set to `location` and the new pattern fell back to the
  # src-derived early_reject instead of the custom one.
  return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, custom_early_reject=self.custom_early_reject)
|
||||
|
||||
@@ -597,11 +596,11 @@ class UPat(MathTrait):
|
||||
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
|
||||
|
||||
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
if (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
|
||||
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
|
||||
(self.arg is not None and self.arg != uop.arg) or \
|
||||
(self.op is not None and uop.op not in self.op) or \
|
||||
(self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
|
||||
if (self.op is not None and uop.op not in self.op) or \
|
||||
(self.name is not None and store.setdefault(self.name, uop) is not uop) or \
|
||||
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
|
||||
(self.arg is not None and self.arg != uop.arg) or \
|
||||
(self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
|
||||
if self.src is None: return [store]
|
||||
res: List[Dict[str, UOp]] = []
|
||||
for vp in self.src:
|
||||
@@ -632,14 +631,14 @@ class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
self.patterns = patterns
|
||||
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
|
||||
self.pdict: Dict[Tuple[Ops, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
|
||||
self.pdict: Dict[Ops, List[Tuple[UPat, Callable, Set, bool]]] = {}
|
||||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
assert p.op is not None
|
||||
tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
|
||||
tuple_fxn[1]['__builtins__'] = __builtins__ # NOTE: Python 3.8 requires this for "all" and "len" and friends
|
||||
real_fxn = types.FunctionType(*tuple_fxn)
|
||||
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))
|
||||
for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))
|
||||
|
||||
def __reduce__(self):
  """Reduce to (PatternMatcher, (patterns,)), replacing handlers named "<lambda>" with deconstruct_function(handler)."""
  # getattr instead of fxn.__name__: __init__ keeps `patterns` as given and accepts handlers that are already
  # deconstructed tuples, which have no __name__; those (and any other handler without a name) pass through as-is.
  reduced = [(pat, deconstruct_function(fxn) if getattr(fxn, "__name__", None) == "<lambda>" else fxn) for pat,fxn in self.patterns]
  return PatternMatcher, (reduced,)
|
||||
|
||||
@@ -647,8 +646,8 @@ class PatternMatcher:
|
||||
def __add__(self, more:PatternMatcher):
  """Return a new PatternMatcher whose pattern list is this matcher's patterns followed by `more`'s."""
  combined = self.patterns + more.patterns
  return PatternMatcher(combined)
|
||||
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject,has_ctx in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
|
||||
ler = set(u.op for u in uop.src)
|
||||
for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
|
||||
if not early_reject.issubset(ler): continue
|
||||
for match in p.match(uop, {}):
|
||||
if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None: return ret
|
||||
@@ -688,8 +687,8 @@ class TrackedPatternMatcher(PatternMatcher):
|
||||
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ret = None
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject,has_ctx in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
|
||||
ler = set(u.op for u in uop.src)
|
||||
for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
|
||||
st = time.perf_counter()
|
||||
if not early_reject.issubset(ler):
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
|
||||
Reference in New Issue
Block a user