diff --git a/sz.py b/sz.py index 3bd5abf5cd..0d7e5be273 100755 --- a/sz.py +++ b/sz.py @@ -54,6 +54,8 @@ def gen_diff(table_old, table_new): def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff) +NONCORE_DIRS = {"tinygrad/apps", "tinygrad/nn", "tinygrad/renderer", "tinygrad/runtime", "tinygrad/viz"} + if __name__ == "__main__": if len(sys.argv) == 3: headers = ["Name", "Lines", "Diff", "Tokens/Line", "Diff"] @@ -76,9 +78,12 @@ if __name__ == "__main__": else: print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", floatfmt=".1f")+"\n") groups = sorted([('/'.join(x[0].rsplit("/", 1)[0].split("/")[0:2]), x[1], x[2]) for x in table]) + dir_sizes = {} for dir_name, group in itertools.groupby(groups, key=lambda x:x[0]): - print(f"{dir_name:30s} : {sum([x[1] for x in group]):6d}") + dir_sizes[dir_name] = sum([x[1] for x in group]) + print(f"{dir_name:30s} : {dir_sizes[dir_name]:6d}") + print(f"\n core line count: {sum([v for k,v in dir_sizes.items() if k not in NONCORE_DIRS])}") total_lines = sum([x[1] for x in table]) - print(f"\ntotal line count: {total_lines}") + print(f"total line count: {total_lines}") max_line_count = int(os.getenv("MAX_LINE_COUNT", "-1")) assert max_line_count == -1 or total_lines <= max_line_count, f"OVER {max_line_count} LINES" diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index 1d0b223506..ac29bb4a2a 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -2,7 +2,7 @@ from extra.models.resnet import ResNet50 from tinygrad import Tensor, nn, Device from tinygrad.helpers import Profiling, Timing, getenv from tinygrad.uop.ops import Ops -from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites +from tinygrad.codegen import full_rewrite_to_sink from tinygrad.codegen.late.control_flow import linearize from tinygrad.uop.spec import type_verify @@ -29,12 +29,11 @@ if __name__ == "__main__": asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values()) if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1] - rewrites = get_rewrites_for_renderer(Device.default.renderer, linearizer=False) with Profiling(PROFILE, fn="/tmp/rewrite.prof"): with Timing("***** model rewrite in "): rewritten_uops = [] for u in asts: - rewritten_uops.append(apply_rewrites(u, rewrites)) + rewritten_uops.append(full_rewrite_to_sink(u, opts=Device.default.renderer)) if LINEARIZE: with Timing("***** model linearize in "): diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index bd3d771e2a..af8df1c583 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,6 +1,3 @@ -from typing import Any, Callable -import functools -from dataclasses import dataclass from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype from tinygrad.uop.spec import type_verify @@ -19,91 +16,69 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen from tinygrad.codegen.late.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize -@dataclass -class RewriteStep: - pm: PatternMatcher - ctx: Callable[[UOp], Any]|None = None - name: str|None = None - bottom_up: bool = False - def __call__(self, sink:UOp): - return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up) - -def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink) - -def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]: - # cache with the values of the context vars - return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value) - -@functools.cache -def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]: - # ** lowerer ** - ret: list[RewriteStep] = [] +def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp: + if opts is None: opts = Renderer() + # first we optimize if optimize: - - # lowerer first - if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize")) + if QUANTIZE and opts.device in {"CPU", "DSP"}: sink = graph_rewrite(sink, pm_quant, name="quantize") # split ranges - ret.append(RewriteStep(pm_split_ranges+pm_flatten_range, ctx=lambda _: {}, name="split ranges")) + sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges") # symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct) - ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic")) + sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic") # optimize (schedule) the AST - ret.append(RewriteStep(pm_simplify_ranges, name="simplify ranges")) - ret.append(RewriteStep(pm_reduce_simplify, name="simplify reduces")) - ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast")) + sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges") + sink = graph_rewrite(sink, pm_reduce_simplify, name="simplify reduces") + sink = graph_rewrite(sink, pm_postrange_opt, ctx=opts, name="post optimize ast") # ** expander (expand_rewrite) ** - ret.append(RewriteStep(sym+migrate_indexing+pm_move_where_on_load, name="postopt symbolic")) + sink = graph_rewrite(sink, sym+migrate_indexing+pm_move_where_on_load, name="postopt symbolic") # expand - ret.append(RewriteStep(sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")) + sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander") # add locals - ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers")) + sink = graph_rewrite(sink, pm_add_buffers+rangeify_codegen, name="add local buffers") # ** devectorizer (full_graph_rewrite) ** # remove reduce - ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce")) + sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce") # add gpu dims (late). this works after devectorize, but it's faster here - ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims")) + sink = graph_rewrite(sink, pm_add_gpudims, ctx=opts, name="add gpudims") # devectorize (TODO: does this need opts?) - if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing - elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing + if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing + elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing - ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize")) - - supported_ops = tuple(opts.code_for_op.keys()) - extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([]) + sink = graph_rewrite(sink, pm_devectorize, ctx=opts, name="devectorize") # lower the index dtype to a concrete int - ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes")) - ret.append(RewriteStep(symbolic, name="post index symbolic")) + sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing, ctx=opts.device, name="lower all index dtypes") + sink = graph_rewrite(sink, symbolic, name="post index symbolic") # optional pre matcher - if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher")) + if opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher, name="pre_matcher") # decompositions - pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2) - ret.append(RewriteStep(pm_decomp, lambda _: opts.device, name="decompositions")) + supported_ops = tuple(opts.code_for_op.keys()) + pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2) + sink = graph_rewrite(sink, pm_decomp, ctx=opts.device, name="decompositions") # final rules for the renderer (without sym) + extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([]) pm_final_rewrite = pm_decomp+pm_render+extra_matcher - ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite")) + sink = graph_rewrite(sink, pm_final_rewrite, ctx=opts.device, name="final rewrite") # this was the linearizer - ret.append(RewriteStep(pm_merge_ends, name="merge ends")) - ret.append(RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True)) + sink = graph_rewrite(sink, pm_merge_ends, name="merge ends") + sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow starts", bottom_up=True) - # return the list - return ret - -def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp: - return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize)) + # return the rewritten sink + return sink def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: """