From 2c42e9c2c6099da086e0b1dd8bbb8e833c55ae7d Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 20 Aug 2024 23:36:58 -0700 Subject: [PATCH] faster rewrite, no folder in expand/reduce [run_process_replay] (#6216) * faster rewrite, no folder in expand/reduce [run_process_replay] * is removing the expander there okay * parens * don't reconstruct exact match uop * fast do_reduce * expand pyint * most of the parents gains with less lines --- test/external/external_benchmark_schedule.py | 57 ++++++++------------ tinygrad/codegen/uopgraph.py | 8 +-- tinygrad/ops.py | 10 ++-- 3 files changed, 32 insertions(+), 43 deletions(-) diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index cec37b518b..3dc19c64ea 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -1,57 +1,44 @@ +from typing import List from extra.models.resnet import ResNet50 from tinygrad import Tensor from tinygrad.helpers import Profiling, Timing, getenv from tinygrad.ops import UOps from tinygrad.codegen.kernel import Kernel +from tinygrad.codegen.lowerer import ast_to_uop +from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite if __name__ == "__main__": mdl = ResNet50() img = Tensor.empty(64, 3, 224, 224) - PROFILE = getenv("PROFILE", 1) + PROFILE = getenv("PROFILE", 0) FORWARD_ONLY = getenv("FORWARD_ONLY", 0) SCHEDULE_ONLY = getenv("SCHEDULE_ONLY", 0) - with Profiling(PROFILE): - with Timing("***** model forward in "): + with Timing("all "): + with Timing("***** model tensor in "): out = mdl(img) - if not FORWARD_ONLY: - with Profiling(PROFILE): - with Timing("***** model schedule in "): + if not FORWARD_ONLY: + with Timing("***** model schedule in "): sched = out.schedule() - if not SCHEDULE_ONLY: - asts = {x.ast.key:x.ast for x in sched if x.ast.op is UOps.SINK}.values() - kernels = [] - with Profiling(PROFILE): - with Timing("***** model uops in "): + if not SCHEDULE_ONLY: + asts = {x.ast.key:x.ast for x in sched if x.ast.op is UOps.SINK}.values() + kernels: List[Kernel] = [] + with Timing("***** model opts in "): for ast in asts: k = Kernel(ast) k.hand_coded_optimizations() kernels.append(k) - with Profiling(PROFILE, fn="/tmp/schedule.prof"): - with Timing("***** model linearize in "): - for k in kernels: k.linearize() - - #renderer = Device[Device.DEFAULT].renderer - #with Profiling(PROFILE, fn="/tmp/schedule.prof"): - # with Timing("***** model render in "): - # for n,u in uops: renderer.render(n, u) - - # snakeviz /tmp/schedule.prof - #with Profiling(PROFILE, fn="/tmp/schedule.prof"): - # with Timing("***** model lower in "): - # eis = list(lower_schedule(sched)) - - # random makes this slow - #with Profiling(PROFILE): - # with Timing("***** model run in "): - # for ei in eis: ei.run() - - # this is all wait - #with Profiling(PROFILE): - # with Timing("***** model finish in "): - # out.data() - + with Timing("***** model lower in "): uops = [ast_to_uop(k.get_optimized_ast(), k.opts) for k in kernels] + with Profiling(PROFILE, fn="/tmp/rewrite.prof"): + with Timing("***** model rewrite in "): uops = [full_graph_rewrite(u, k.opts) for u in uops] + if getenv("LINEARIZE", 1): + with Timing("***** model linearize in "): uops = [linearize_uop(u, skip_check=False) for u in uops] + print(sum(len(u) for u in uops)) + if getenv("GRAPHUOPS", 0): + for u in uops: + from tinygrad.engine.graph import graph_uops + graph_uops(u) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 55ee013e6b..b71f0fbb0d 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -417,7 +417,7 @@ def do_expand(root:UOp): acc_number = 0 def do_reduce(root:UOp): global acc_number - reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].parents) + reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].sparents) ret = root.src[0] if len(reduce_parented): assert root.dtype is not None @@ -495,7 +495,7 @@ reducer = PatternMatcher([ (UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), ]) -no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"), +no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND}, dtype=dtypes.pyint, name="x"), lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))]) # *** uop graph *** @@ -527,8 +527,8 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: # expand linearize_cnt += 1 if linearize_cnt != getenv("DEBUG_EXPAND", 0): - sink = graph_rewrite(sink, folder+expander+float4_folding if opts is not None and opts.supports_float4 else folder+expander) - sink = graph_rewrite(sink, folder+expander+reducer) + sink = graph_rewrite(sink, folder+(expander+float4_folding if opts is not None and opts.supports_float4 else expander)) + sink = graph_rewrite(sink, folder+reducer) # for PTX only if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 917992e454..07b1b9418e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -5,7 +5,7 @@ import math, operator, ctypes, struct, functools, hashlib, itertools from enum import Enum, auto from dataclasses import dataclass from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType -from tinygrad.helpers import merge_dicts, pretty_print, prod +from tinygrad.helpers import pretty_print, prod from tinygrad.shape.symbolic import Variable, sint if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker @@ -166,7 +166,7 @@ class UOp: @staticmethod def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src) @functools.cached_property - def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src]) + def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}} @property # parents with self def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None} @functools.cached_property @@ -301,9 +301,11 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: replace: Dict[UOp, UOp] = {} def __inner_rewrite(n:UOp) -> UOp: if rn := replace.get(n): return rn - replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg) + replace_source = (n.op, n.dtype, new_src:=tuple(__inner_rewrite(y) for y in n.src), n.arg) if found := nodes.get(replace_source): replace[n] = found - else: nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x + else: + x = UOp(*replace_source) if new_src != n.src else n + nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x)) else x return found return __inner_rewrite(sink)