remove RewriteStep premature optimization (#12840)

* remove RewriteStep premature optimization

* fix ebs

* core line count
This commit is contained in:
George Hotz
2025-10-21 21:45:20 +08:00
committed by GitHub
parent 7f798a9630
commit 8960ac54f3
3 changed files with 38 additions and 59 deletions

9
sz.py
View File

@@ -54,6 +54,8 @@ def gen_diff(table_old, table_new):
def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff)
NONCORE_DIRS = {"tinygrad/apps", "tinygrad/nn", "tinygrad/renderer", "tinygrad/runtime", "tinygrad/viz"}
if __name__ == "__main__":
if len(sys.argv) == 3:
headers = ["Name", "Lines", "Diff", "Tokens/Line", "Diff"]
@@ -76,9 +78,12 @@ if __name__ == "__main__":
else:
print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", floatfmt=".1f")+"\n")
groups = sorted([('/'.join(x[0].rsplit("/", 1)[0].split("/")[0:2]), x[1], x[2]) for x in table])
dir_sizes = {}
for dir_name, group in itertools.groupby(groups, key=lambda x:x[0]):
print(f"{dir_name:30s} : {sum([x[1] for x in group]):6d}")
dir_sizes[dir_name] = sum([x[1] for x in group])
print(f"{dir_name:30s} : {dir_sizes[dir_name]:6d}")
print(f"\n core line count: {sum([v for k,v in dir_sizes.items() if k not in NONCORE_DIRS])}")
total_lines = sum([x[1] for x in table])
print(f"\ntotal line count: {total_lines}")
print(f"total line count: {total_lines}")
max_line_count = int(os.getenv("MAX_LINE_COUNT", "-1"))
assert max_line_count == -1 or total_lines <= max_line_count, f"OVER {max_line_count} LINES"

View File

@@ -2,7 +2,7 @@ from extra.models.resnet import ResNet50
from tinygrad import Tensor, nn, Device
from tinygrad.helpers import Profiling, Timing, getenv
from tinygrad.uop.ops import Ops
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
from tinygrad.codegen import full_rewrite_to_sink
from tinygrad.codegen.late.control_flow import linearize
from tinygrad.uop.spec import type_verify
@@ -29,12 +29,11 @@ if __name__ == "__main__":
asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values())
if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
rewrites = get_rewrites_for_renderer(Device.default.renderer, linearizer=False)
with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
with Timing("***** model rewrite in "):
rewritten_uops = []
for u in asts:
rewritten_uops.append(apply_rewrites(u, rewrites))
rewritten_uops.append(full_rewrite_to_sink(u, opts=Device.default.renderer))
if LINEARIZE:
with Timing("***** model linearize in "):

View File

@@ -1,6 +1,3 @@
from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify
@@ -19,91 +16,69 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
from tinygrad.codegen.late.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize
@dataclass
class RewriteStep:
pm: PatternMatcher
ctx: Callable[[UOp], Any]|None = None
name: str|None = None
bottom_up: bool = False
def __call__(self, sink:UOp):
return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
# cache with the values of the context vars
return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
# ** lowerer **
ret: list[RewriteStep] = []
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
if opts is None: opts = Renderer()
# first we optimize
if optimize:
# lowerer first
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
if QUANTIZE and opts.device in {"CPU", "DSP"}: sink = graph_rewrite(sink, pm_quant, name="quantize")
# split ranges
ret.append(RewriteStep(pm_split_ranges+pm_flatten_range, ctx=lambda _: {}, name="split ranges"))
sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic"))
sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
# optimize (schedule) the AST
ret.append(RewriteStep(pm_simplify_ranges, name="simplify ranges"))
ret.append(RewriteStep(pm_reduce_simplify, name="simplify reduces"))
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
sink = graph_rewrite(sink, pm_reduce_simplify, name="simplify reduces")
sink = graph_rewrite(sink, pm_postrange_opt, ctx=opts, name="post optimize ast")
# ** expander (expand_rewrite) **
ret.append(RewriteStep(sym+migrate_indexing+pm_move_where_on_load, name="postopt symbolic"))
sink = graph_rewrite(sink, sym+migrate_indexing+pm_move_where_on_load, name="postopt symbolic")
# expand
ret.append(RewriteStep(sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander"))
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
# add locals
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
sink = graph_rewrite(sink, pm_add_buffers+rangeify_codegen, name="add local buffers")
# ** devectorizer (full_graph_rewrite) **
# remove reduce
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
# add gpu dims (late). this works after devectorize, but it's faster here
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
sink = graph_rewrite(sink, pm_add_gpudims, ctx=opts, name="add gpudims")
# devectorize (TODO: does this need opts?)
if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))
supported_ops = tuple(opts.code_for_op.keys())
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
sink = graph_rewrite(sink, pm_devectorize, ctx=opts, name="devectorize")
# lower the index dtype to a concrete int
ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))
ret.append(RewriteStep(symbolic, name="post index symbolic"))
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing, ctx=opts.device, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
# optional pre matcher
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
if opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher, name="pre_matcher")
# decompositions
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)
ret.append(RewriteStep(pm_decomp, lambda _: opts.device, name="decompositions"))
supported_ops = tuple(opts.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=opts.device, name="decompositions")
# final rules for the renderer (without sym)
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+pm_render+extra_matcher
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
sink = graph_rewrite(sink, pm_final_rewrite, ctx=opts.device, name="final rewrite")
# this was the linearizer
ret.append(RewriteStep(pm_merge_ends, name="merge ends"))
ret.append(RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True))
sink = graph_rewrite(sink, pm_merge_ends, name="merge ends")
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow starts", bottom_up=True)
# return the list
return ret
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize))
# return the rewritten sink
return sink
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
"""