PatternMatcher add (#5532)

* PatternMatcher add [run_process_replay]

* f4 dynamic

* test_failure_36 is fixed

* fix PTX
This commit is contained in:
George Hotz
2024-07-17 12:44:42 -07:00
committed by GitHub
parent d3c137d478
commit 1a68854766
2 changed files with 10 additions and 12 deletions

View File

@@ -300,7 +300,7 @@ class TestLinearizerFailures(unittest.TestCase):
# UOps.UNMUL left after linearize
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))))),), arg=dtypes.uint),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.uint, st=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.uchar),), arg=MemBuffer(idx=0, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["METAL", "GPU", "CUDA", "AMD", "NV", "CLANG", "LLVM"])
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
# BEGIN METAL=1 ./examples/beautiful_mnist.py failures
# log : PYTHONPATH=. LOGKERNS=/tmp/beautiful_mnist.kernels.txt METAL=1 python3 ./examples/beautiful_mnist.py

View File

@@ -63,6 +63,9 @@ class PatternMatcher:
assert p.op is not None
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp) -> Optional[UOp]:
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
@@ -347,8 +350,6 @@ constant_folder = PatternMatcher([
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
])
constant_folder_w_f4 = PatternMatcher(constant_folder.patterns + float4_folding.patterns)
# *** uop expander ***
def _expand_arg_to_idx(arg:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]):
@@ -386,7 +387,7 @@ def do_expand(root:UOp):
lnew_src.append(src.src[_expand_arg_to_idx(src.arg, lrpk)])
if len(dont_expand_args):
if root.op is UOps.WMMA:
new_src.append(lnew_src[0]) # TODO: is this always right?
new_src.append(lnew_src[0]) # TODO: is this always right? all lnew_src should be the same
else:
new_src.append(UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
else:
@@ -494,10 +495,9 @@ class UOpGraph:
# used by linearizer
self._uops: Optional[List[UOp]] = None
self.opts = opts
self.folder = constant_folder if opts is None or not opts.supports_float4 else constant_folder_w_f4
self.folder = constant_folder if opts is None or not opts.supports_float4 else (constant_folder+float4_folding)
if TRANSCENDENTAL >= 2 or (opts is not None and TRANSCENDENTAL >= 1 and opts.device in {"CLANG", "LLVM"}):
# TODO: slow to rebuild this...
self.folder = PatternMatcher(self.folder.patterns + transcendental_folding.patterns)
self.folder = self.folder + transcendental_folding
def __reduce__(self): return self.__class__, (self.sink, self.opts)
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
@@ -546,15 +546,13 @@ class UOpGraph:
# do graph rewrite
sink = graph_rewrite(sink, self.folder)
if extra_pm: sink = graph_rewrite(sink, PatternMatcher(self.folder.patterns+extra_pm.patterns))
# expand
UOpGraph.cnt += 1
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander)
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander+self.folder)
# do graph rewrite (2)
sink = graph_rewrite(sink, self.folder)
if extra_pm: sink = graph_rewrite(sink, PatternMatcher(self.folder.patterns+extra_pm.patterns))
# for PTX only
if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)
# filter nodes that don't link to a sink
# BFS toposort