From fe0724eebfca3ef5f49fca225457652e13dd9159 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sun, 4 May 2025 16:01:18 -0400
Subject: [PATCH] prebuild all rewrites [pr] (#10154)

* prebuild all rewrites [pr]

* fix that

* tests pass with linearizer
---
 .github/workflows/test.yml       |  4 +-
 tinygrad/codegen/devectorizer.py |  3 +-
 tinygrad/codegen/flow.py         | 69 ++++++++++++++++++++++++++++++++
 tinygrad/codegen/kernel.py       | 10 ++---
 tinygrad/codegen/linearize.py    | 69 +++++++++++++++++++++-----------
 tinygrad/upat.py                 | 12 +++---
 6 files changed, 131 insertions(+), 36 deletions(-)
 create mode 100644 tinygrad/codegen/flow.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 45e7c37e12..3aa868785b 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -368,8 +368,8 @@ jobs:
     #   run: NULL=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 3 --temperature 0 --timing
     - name: Run GC tests
       run: PYTHONPATH="." python test/external/external_uop_gc.py
-    - name: Repo line count < 12800 lines
-      run: MAX_LINE_COUNT=12800 python sz.py
+    - name: Repo line count < 13000 lines
+      run: MAX_LINE_COUNT=13000 python sz.py
 
   fuzzing:
     name: Fuzzing
diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py
index e98f2d4b68..558baf8d76 100644
--- a/tinygrad/codegen/devectorizer.py
+++ b/tinygrad/codegen/devectorizer.py
@@ -450,5 +450,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
   # final rules for the renderer (without sym)
-  sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher, ctx=opts)
+  sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher,
+                       ctx=opts, name="final rewrite")
   return sink
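The devectorizer change only threads a name into graph_rewrite, presumably so the pass shows up labeled in match-stat tracking and the VIZ=1 rewrite viewer. A minimal sketch of a named pass on a toy graph (the pass name "fold constants" is made up; symbolic_simple should constant-fold the MUL):

    from tinygrad.dtype import dtypes
    from tinygrad.ops import UOp, graph_rewrite
    from tinygrad.codegen.symbolic import symbolic_simple

    # 2*3 as a tiny two-node UOp graph
    x = UOp.const(dtypes.int, 2) * UOp.const(dtypes.int, 3)
    # name= labels this pass wherever rewrites are tracked (e.g. under VIZ=1)
    out = graph_rewrite(x, symbolic_simple, name="fold constants")
    print(out)  # expected: a CONST UOp with arg 6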
diff --git a/tinygrad/codegen/flow.py b/tinygrad/codegen/flow.py
new file mode 100644
index 0000000000..6a95bb1360
--- /dev/null
+++ b/tinygrad/codegen/flow.py
@@ -0,0 +1,69 @@
+from typing import Any, Callable
+import functools
+from dataclasses import dataclass
+from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
+from tinygrad.ops import PatternMatcher, graph_rewrite, UOp
+from tinygrad.renderer import Renderer
+
+# import all pattern matchers here
+from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
+from tinygrad.codegen.symbolic import sym, symbolic_simple, gep_pushing
+from tinygrad.codegen.expander import migrate_indexing, pm_store_ignore, pm_move_ignore, pm_delete_ignore, expander
+from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
+  pm_reduce, ReduceContext, correct_load_store, pm_render, get_late_rewrite_patterns
+from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
+
+@dataclass
+class RewriteStep:
+  pm: PatternMatcher
+  ctx: Callable[[UOp], Any]|None = None
+  name: str|None = None
+  bottom_up: bool = False
+  def __call__(self, sink:UOp):
+    return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)
+
+def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
+
+def get_rewrites_for_renderer(opts:Renderer, linearizer=True) -> list[RewriteStep]:
+  # ** lowerer (rewrite_shapetracker_with_index) **
+  ret: list[RewriteStep] = []
+  if QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
+  ret.append(RewriteStep(pm_lowerer, lambda ast: get_index(ast, opts), name="lowerer"))
+
+  # ** expander (expand_rewrite) **
+  ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
+
+  # ignore (for masked stores)
+  ret.append(RewriteStep(pm_store_ignore, name="store_ignore"))
+  ret.append(RewriteStep(pm_move_ignore, name="move_ignore"))
+
+  # expand + remove surviving ignores
+  ret.append(RewriteStep(pm_delete_ignore+sym+expander, name="expander"))
+
+  # ** devectorizer (full_graph_rewrite) **
+  # remove reduce
+  ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
+
+  # devectorize (TODO: does this need opts?)
+  if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
+  elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
+  else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
+  ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))
+
+  supported_ops = tuple(opts.code_for_op.keys())
+  extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
+
+  # optional pre matcher
+  if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
+
+  # final rules for the renderer (without sym)
+  pm_final_rewrite = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher
+  ret.append(RewriteStep(pm_final_rewrite, lambda _: opts, name="final rewrite"))
+
+  # ** linearizer **
+  if linearizer:
+    ret.append(RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True))
+    ret.append(RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"))
+    ret.append(RewriteStep(block_merge, name="Linearizer: Merge Blocks"))
+    ret.append(RewriteStep(pm_finalize, name="Linearizer: Finalize"))
+  return ret
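With flow.py, the whole codegen pipeline becomes plain data: a prebuilt list of RewriteStep callables that apply_rewrites folds over the sink. A sketch of driving it by hand; `ast` is an assumed name for the Kernel's optimized AST SINK (the kernel.py hunk below does exactly this):

    from tinygrad import Device
    from tinygrad.codegen.flow import get_rewrites_for_renderer, apply_rewrites

    opts = Device.default.renderer           # any Renderer works here
    steps = get_rewrites_for_renderer(opts)  # the full pipeline, built once
    sink = apply_rewrites(ast, steps)        # functools.reduce applies each step left to right
    uops = list(sink.arg.lst)                # pm_finalize leaves a BLOCKFINAL whose BasicBlock2 holds the uops

Because the steps are just a list, something like apply_rewrites(ast, steps[:3]) replays only a prefix of the pipeline, which makes it easier to bisect a bad rewrite.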
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index d69d278c90..11ac1ea320 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -14,10 +14,9 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, ro
 from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, AMX, CAPTURE_PROCESS_REPLAY
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
-from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.codegen.devectorizer import full_graph_rewrite
-from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
+from tinygrad.codegen.lowerer import get_contraction
 from tinygrad.engine.grouper import view_left
+from tinygrad.codegen.flow import get_rewrites_for_renderer, apply_rewrites
 
 class KernelOptError(Exception): pass
 
@@ -528,7 +527,7 @@ class Kernel:
       return ret
     fixed_ast = fixup_ast(self.ast)
     del fixup_ast
-    return graph_rewrite(fixed_ast, view_left)
+    return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST")
 
   # **** this is the lowerer ****
 
@@ -554,7 +553,8 @@ class Kernel:
     #if __debug__: type_verify(list(modified_ast.toposort()), shape_spec)
 
     try:
-      self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
+      rewrite_list = get_rewrites_for_renderer(self.opts)
+      self.uops:list[UOp] = list(apply_rewrites(modified_ast, rewrite_list).arg.lst)
     except RuntimeError:
       print("***** LINEARIZE FAILURE *****")
       print(f"ast = {self.ast}")
diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py
index 3316bf3cec..05ed7c7809 100644
--- a/tinygrad/codegen/linearize.py
+++ b/tinygrad/codegen/linearize.py
@@ -52,7 +52,7 @@ def disp(y:UOp) -> str:
 @dataclass(frozen=True, eq=False)
 class BasicBlock2:
   lst: tuple[UOp, ...]
-  ctx: tuple[UOp, ...]
+  ctx: tuple[UOp, ...] = ()
   end: UOp|None = None
   cnt: int = 0
   child_ctx: tuple[UOp, ...]|None = None
@@ -162,10 +162,32 @@ def make_block_bottom_up(ctx:BlockContext, x:UOp):
   bb = BasicBlock2(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
   return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
 
-block_create = PatternMatcher([(
-  UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up)
+block_create = PatternMatcher([
+  (UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
 ])
 
+# ***** blockend merging ****
+
+def merge_blockends(sink:UOp) -> UOp|None:
+  # only run on the final BLOCK with the SINK in it
+  if sink.arg.lst[-1].op is not Ops.SINK: return None
+  # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
+  blockends_to_arg: dict[UOp, list[UOp]] = {}
+  for be in sink.toposort():
+    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
+  new_forks = {}
+  for k,v in blockends_to_arg.items():
+    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
+    if len(v) > 1:
+      bb = BasicBlock2(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
+      out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
+      # NOTE: bb.ctx != u.arg.ctx can cause problems here
+      for u in v: new_forks[u] = out
+  if len(new_forks) == 0: return None
+  return sink.substitute(new_forks)
+
+pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])
+
 # ***** block merging ****
 
 def merge_block(x:UOp):
@@ -209,6 +231,19 @@ block_merge = PatternMatcher([
 
 # ****** finalize ******
 
+def finalize(sink:UOp) -> UOp:
+  if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
+    raise RuntimeError("linearize failure")
+
+  # place the early things
+  lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
+
+  if __debug__: type_verify(lst)
+
+  return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))
+
+pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
+
 def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
   assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
 
@@ -218,28 +253,14 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
 
   # wrap all uops in blocks, already reordered
   sink = graph_rewrite(sink, block_create, ctx=ctx, name="Linearizer: Create Blocks", bottom_up=True)
 
-  # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
-  blockends_to_arg: dict[UOp, list[UOp]] = {}
-  for be in sink.toposort():
-    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
-  new_forks = {}
-  for k,v in blockends_to_arg.items():
-    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
-    if len(v) > 1:
-      bb = BasicBlock2(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
-      out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
-      # NOTE: bb.ctx != u.arg.ctx can cause problems here
-      for u in v: new_forks[u] = out
-  sink = sink.substitute(new_forks)
-
-  # merge blockends
+  sink = graph_rewrite(sink, pm_blockend_merge, name="Linearizer: Merge Blockends")
+
+  # merge blocks
   sink = graph_rewrite(sink, block_merge, name="Linearizer: Merge Blocks")
-  if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src): raise RuntimeError("linearize failure")
 
-  # place the early things
-  lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
+  # finalize
+  sink = graph_rewrite(sink, pm_finalize, name="Linearizer: Finalize")
 
-  # sanity checks (NOTE: these can cause things to be skipped in BEAM)
-  if not skip_check: type_verify(lst)
 
-  return lst
+  return list(sink.arg.lst)
\ No newline at end of file
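The pattern in this refactor: each formerly imperative linearizer phase becomes a PatternMatcher whose rule matches a root BLOCK and returns either a replacement UOp or None for "no change", so graph_rewrite can run it to a fixed point and flow.py can treat every phase uniformly. A toy pass in the same idiom (rename_noop and pm_rename are made-up names):

    from tinygrad.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite

    # like merge_blockends: inspect the matched UOp, return None to leave it alone
    def rename_noop(x:UOp) -> UOp|None:
      return UOp(Ops.NOOP, arg="renamed") if x.arg != "renamed" else None

    pm_rename = PatternMatcher([(UPat(Ops.NOOP, name="x"), rename_noop)])
    out = graph_rewrite(UOp(Ops.NOOP, arg="a"), pm_rename, name="toy pass")
    print(out.arg)  # renamed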
diff --git a/tinygrad/upat.py b/tinygrad/upat.py
index 0ab3dfe7e6..17007f6492 100644
--- a/tinygrad/upat.py
+++ b/tinygrad/upat.py
@@ -1,6 +1,6 @@
 from typing import Any, Callable
 import itertools, inspect, functools, types
-from tinygrad.helpers import partition, dedup
+from tinygrad.helpers import partition, dedup, Context
 from tinygrad.ops import UPat, UPatAny, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
 
 class UPatCompileError(Exception): pass
@@ -139,10 +139,12 @@ def _final_render(x:UOp, has_ctx:bool, depth=1) -> list[str]:
 def _get_code(self:UPat, has_ctx:bool):
   ret = _get_clause(self, UOp(Ops.NOOP, arg="uop"))
   try:
-    ret = graph_rewrite(ret, pm_proc, name="process UPat")
-    dyn_lookup: dict[str, Any] = {}
-    out = graph_rewrite(ret, pm_renderer, ctx=dyn_lookup, name="compile UPat")
-    rendered = _final_render(out, has_ctx)
+    # TODO: this should be tracked in a "system" rewrite, not untracked or tracked with kernel
+    with Context(TRACK_MATCH_STATS=0):
+      ret = graph_rewrite(ret, pm_proc, name="process UPat")
+      dyn_lookup: dict[str, Any] = {}
+      out = graph_rewrite(ret, pm_renderer, ctx=dyn_lookup, name="compile UPat")
+      rendered = _final_render(out, has_ctx)
   except UPatCompileError:
     #print("FAILED", self, self.location)
     return None
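On the upat.py change: Context is tinygrad's scoped ContextVar override, so wrapping the UPat compiler's internal graph_rewrites in Context(TRACK_MATCH_STATS=0) keeps them out of the kernel match statistics until the TODO about a "system" rewrite scope is resolved. The same mechanism works for any ContextVar; a small sketch using DEBUG (which kernel.py above also imports), with values restored on exit:

    from tinygrad.helpers import Context, DEBUG

    print(DEBUG.value)      # whatever DEBUG=... in the environment set
    with Context(DEBUG=2):
      print(DEBUG.value)    # 2, only inside this block
    print(DEBUG.value)      # restored on exit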