mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
PatternMatcher add (#5532)
* PatternMatcher add [run_process_replay] * f4 dynamic * test_failure_36 is fixed * fix PTX
This commit is contained in:
@@ -300,7 +300,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
# UOps.UNMUL left after linearize
|
||||
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))))),), arg=dtypes.uint),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.uint, st=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.uchar),), arg=MemBuffer(idx=0, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0)]
|
||||
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["METAL", "GPU", "CUDA", "AMD", "NV", "CLANG", "LLVM"])
|
||||
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
|
||||
# BEGIN METAL=1 ./examples/beautiful_mnist.py failures
|
||||
# log : PYTHONPATH=. LOGKERNS=/tmp/beautiful_mnist.kernels.txt METAL=1 python3 ./examples/beautiful_mnist.py
|
||||
|
||||
@@ -63,6 +63,9 @@ class PatternMatcher:
|
||||
assert p.op is not None
|
||||
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
|
||||
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
|
||||
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
|
||||
@@ -347,8 +350,6 @@ constant_folder = PatternMatcher([
|
||||
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
|
||||
])
|
||||
|
||||
constant_folder_w_f4 = PatternMatcher(constant_folder.patterns + float4_folding.patterns)
|
||||
|
||||
# *** uop expander ***
|
||||
|
||||
def _expand_arg_to_idx(arg:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]):
|
||||
@@ -386,7 +387,7 @@ def do_expand(root:UOp):
|
||||
lnew_src.append(src.src[_expand_arg_to_idx(src.arg, lrpk)])
|
||||
if len(dont_expand_args):
|
||||
if root.op is UOps.WMMA:
|
||||
new_src.append(lnew_src[0]) # TODO: is this always right?
|
||||
new_src.append(lnew_src[0]) # TODO: is this always right? all lnew_src should be the same
|
||||
else:
|
||||
new_src.append(UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
|
||||
else:
|
||||
@@ -494,10 +495,9 @@ class UOpGraph:
|
||||
# used by linearizer
|
||||
self._uops: Optional[List[UOp]] = None
|
||||
self.opts = opts
|
||||
self.folder = constant_folder if opts is None or not opts.supports_float4 else constant_folder_w_f4
|
||||
self.folder = constant_folder if opts is None or not opts.supports_float4 else (constant_folder+float4_folding)
|
||||
if TRANSCENDENTAL >= 2 or (opts is not None and TRANSCENDENTAL >= 1 and opts.device in {"CLANG", "LLVM"}):
|
||||
# TODO: slow to rebuild this...
|
||||
self.folder = PatternMatcher(self.folder.patterns + transcendental_folding.patterns)
|
||||
self.folder = self.folder + transcendental_folding
|
||||
|
||||
def __reduce__(self): return self.__class__, (self.sink, self.opts)
|
||||
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
|
||||
@@ -546,15 +546,13 @@ class UOpGraph:
|
||||
|
||||
# do graph rewrite
|
||||
sink = graph_rewrite(sink, self.folder)
|
||||
if extra_pm: sink = graph_rewrite(sink, PatternMatcher(self.folder.patterns+extra_pm.patterns))
|
||||
|
||||
# expand
|
||||
UOpGraph.cnt += 1
|
||||
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander)
|
||||
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander+self.folder)
|
||||
|
||||
# do graph rewrite (2)
|
||||
sink = graph_rewrite(sink, self.folder)
|
||||
if extra_pm: sink = graph_rewrite(sink, PatternMatcher(self.folder.patterns+extra_pm.patterns))
|
||||
# for PTX only
|
||||
if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)
|
||||
|
||||
# filter nodes that don't link to a sink
|
||||
# BFS toposort
|
||||
|
||||
Reference in New Issue
Block a user