mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
remove RewriteStep premature optimization (#12840)

* remove RewriteStep premature optimization
* fix ebs
* core line count
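The "premature optimization" being removed is the indirection itself: the old code reified every rewrite pass as a RewriteStep dataclass, cached the per-renderer step list, and folded it with functools.reduce; the new code simply threads the sink through graph_rewrite calls in order. A toy analogue of the two styles (illustrative names only, not tinygrad APIs):

import functools
from dataclasses import dataclass
from typing import Callable

# old style: reify each pass as a step object, then fold the cached list
@dataclass
class Step:
  fn: Callable[[int], int]
  def __call__(self, x: int) -> int: return self.fn(x)

steps = [Step(lambda x: x + 1), Step(lambda x: x * 2)]
out_old = functools.reduce(lambda x, f: f(x), steps, 3)  # (3+1)*2 = 8

# new style: just call the passes in sequence, no step objects
x = 3
x = x + 1
x = x * 2
assert x == out_old == 8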
sz.py (9 changed lines)
@@ -54,6 +54,8 @@ def gen_diff(table_old, table_new):
 def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff)
 
+NONCORE_DIRS = {"tinygrad/apps", "tinygrad/nn", "tinygrad/renderer", "tinygrad/runtime", "tinygrad/viz"}
+
 if __name__ == "__main__":
   if len(sys.argv) == 3:
     headers = ["Name", "Lines", "Diff", "Tokens/Line", "Diff"]
@@ -76,9 +78,12 @@ if __name__ == "__main__":
   else:
     print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", floatfmt=".1f")+"\n")
     groups = sorted([('/'.join(x[0].rsplit("/", 1)[0].split("/")[0:2]), x[1], x[2]) for x in table])
+    dir_sizes = {}
     for dir_name, group in itertools.groupby(groups, key=lambda x:x[0]):
-      print(f"{dir_name:30s} : {sum([x[1] for x in group]):6d}")
+      dir_sizes[dir_name] = sum([x[1] for x in group])
+      print(f"{dir_name:30s} : {dir_sizes[dir_name]:6d}")
+    print(f"\n core line count: {sum([v for k,v in dir_sizes.items() if k not in NONCORE_DIRS])}")
     total_lines = sum([x[1] for x in table])
-    print(f"\ntotal line count: {total_lines}")
+    print(f"total line count: {total_lines}")
   max_line_count = int(os.getenv("MAX_LINE_COUNT", "-1"))
   assert max_line_count == -1 or total_lines <= max_line_count, f"OVER {max_line_count} LINES"
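The grouping logic in sz.py is easy to check in isolation. A minimal sketch of the new core-line-count computation, using a hypothetical two-row table (real rows come from sz.py's own per-file counting):

import itertools

NONCORE_DIRS = {"tinygrad/apps", "tinygrad/nn", "tinygrad/renderer", "tinygrad/runtime", "tinygrad/viz"}
# hypothetical (filename, line_count, tokens_per_line) rows
table = [("tinygrad/uop/ops.py", 1200, 9.1), ("tinygrad/nn/state.py", 300, 8.2)]

# map each file to its top-two path components ("tinygrad/uop", "tinygrad/nn", ...)
groups = sorted(('/'.join(name.rsplit("/", 1)[0].split("/")[0:2]), lines) for name, lines, _ in table)
dir_sizes = {d: sum(l for _, l in g) for d, g in itertools.groupby(groups, key=lambda x: x[0])}
core = sum(v for k, v in dir_sizes.items() if k not in NONCORE_DIRS)
print(core)  # 1200: tinygrad/nn is excluded from the core budget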
test/external/external_benchmark_schedule.py (vendored, 5 changed lines)
@@ -2,7 +2,7 @@ from extra.models.resnet import ResNet50
 from tinygrad import Tensor, nn, Device
 from tinygrad.helpers import Profiling, Timing, getenv
 from tinygrad.uop.ops import Ops
-from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
+from tinygrad.codegen import full_rewrite_to_sink
 from tinygrad.codegen.late.control_flow import linearize
 from tinygrad.uop.spec import type_verify
@@ -29,12 +29,11 @@ if __name__ == "__main__":
   asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values())
   if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
 
-  rewrites = get_rewrites_for_renderer(Device.default.renderer, linearizer=False)
   with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
     with Timing("***** model rewrite in "):
       rewritten_uops = []
       for u in asts:
-        rewritten_uops.append(apply_rewrites(u, rewrites))
+        rewritten_uops.append(full_rewrite_to_sink(u, opts=Device.default.renderer))
 
   if LINEARIZE:
     with Timing("***** model linearize in "):
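For the benchmark, the whole migration is at the call site: no step list is built, and the renderer is passed straight to full_rewrite_to_sink. A condensed sketch of the new path (sched is the model's schedule, as in the unchanged lines of the script):

from tinygrad import Device
from tinygrad.uop.ops import Ops
from tinygrad.codegen import full_rewrite_to_sink

# dedupe the SINK asts from the schedule, then rewrite each one directly
asts = list({x.ast.key: x.ast for x in sched if x.ast.op is Ops.SINK}.values())
rewritten_uops = [full_rewrite_to_sink(u, opts=Device.default.renderer) for u in asts]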
tinygrad/codegen/__init__.py

@@ -1,6 +1,3 @@
-from typing import Any, Callable
-import functools
-from dataclasses import dataclass
 from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
 from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
 from tinygrad.uop.spec import type_verify
@@ -19,91 +16,69 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm
 from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
 from tinygrad.codegen.late.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize
 
-@dataclass
-class RewriteStep:
-  pm: PatternMatcher
-  ctx: Callable[[UOp], Any]|None = None
-  name: str|None = None
-  bottom_up: bool = False
-  def __call__(self, sink:UOp):
-    return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)
-
-def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
-
-def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
-  # cache with the values of the context vars
-  return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
-
-@functools.cache
-def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
-  # ** lowerer **
-  ret: list[RewriteStep] = []
+def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
+  if opts is None: opts = Renderer()
 
   # first we optimize
   if optimize:
     # lowerer first
-    if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
+    if QUANTIZE and opts.device in {"CPU", "DSP"}: sink = graph_rewrite(sink, pm_quant, name="quantize")
 
     # split ranges
-    ret.append(RewriteStep(pm_split_ranges+pm_flatten_range, ctx=lambda _: {}, name="split ranges"))
+    sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
 
     # symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
-    ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic"))
+    sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
 
     # optimize (schedule) the AST
-    ret.append(RewriteStep(pm_simplify_ranges, name="simplify ranges"))
-    ret.append(RewriteStep(pm_reduce_simplify, name="simplify reduces"))
-    ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
+    sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
+    sink = graph_rewrite(sink, pm_reduce_simplify, name="simplify reduces")
+    sink = graph_rewrite(sink, pm_postrange_opt, ctx=opts, name="post optimize ast")
 
   # ** expander (expand_rewrite) **
-  ret.append(RewriteStep(sym+migrate_indexing+pm_move_where_on_load, name="postopt symbolic"))
+  sink = graph_rewrite(sink, sym+migrate_indexing+pm_move_where_on_load, name="postopt symbolic")
 
   # expand
-  ret.append(RewriteStep(sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander"))
+  sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
 
   # add locals
-  ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
+  sink = graph_rewrite(sink, pm_add_buffers+rangeify_codegen, name="add local buffers")
 
   # ** devectorizer (full_graph_rewrite) **
   # remove reduce
-  ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
+  sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
 
   # add gpu dims (late). this works after devectorize, but it's faster here
-  ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
+  sink = graph_rewrite(sink, pm_add_gpudims, ctx=opts, name="add gpudims")
 
   # devectorize (TODO: does this need opts?)
-  if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
-  elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
+  if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
+  elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
   else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
-  ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))
-
-  supported_ops = tuple(opts.code_for_op.keys())
-  extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
+  sink = graph_rewrite(sink, pm_devectorize, ctx=opts, name="devectorize")
 
   # lower the index dtype to a concrete int
-  ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))
-  ret.append(RewriteStep(symbolic, name="post index symbolic"))
+  sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing, ctx=opts.device, name="lower all index dtypes")
+  sink = graph_rewrite(sink, symbolic, name="post index symbolic")
 
   # optional pre matcher
-  if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
+  if opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher, name="pre_matcher")
 
   # decompositions
-  pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)
-  ret.append(RewriteStep(pm_decomp, lambda _: opts.device, name="decompositions"))
+  supported_ops = tuple(opts.code_for_op.keys())
+  pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)
+  sink = graph_rewrite(sink, pm_decomp, ctx=opts.device, name="decompositions")
 
   # final rules for the renderer (without sym)
+  extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
   pm_final_rewrite = pm_decomp+pm_render+extra_matcher
-  ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
+  sink = graph_rewrite(sink, pm_final_rewrite, ctx=opts.device, name="final rewrite")
 
   # this was the linearizer
-  ret.append(RewriteStep(pm_merge_ends, name="merge ends"))
-  ret.append(RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True))
+  sink = graph_rewrite(sink, pm_merge_ends, name="merge ends")
+  sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow starts", bottom_up=True)
 
-  # return the list
-  return ret
-
-def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
-  return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize))
+  # return the rewritten sink
+  return sink
 
 def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
   """
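For context, full_rewrite (whose docstring is cut off above) builds on full_rewrite_to_sink: it rewrites the graph down to a single SINK, then orders it into a linear program with linearize, both already imported at the top of this file. A schematic sketch only, since its body is not part of this diff:

def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
  # schematic: rewrite the graph to one SINK, then linearize and verify it
  lst = linearize(full_rewrite_to_sink(sink, opts))
  if __debug__: type_verify(lst)
  return lst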