diff --git a/test/external/external_benchmark_ast.py b/test/external/external_benchmark_ast.py deleted file mode 100644 index 9270f8faa1..0000000000 --- a/test/external/external_benchmark_ast.py +++ /dev/null @@ -1,59 +0,0 @@ -import time, pickle -import plotly.graph_objects as go -from typing import Dict, List, Tuple -from extra.models.resnet import ResNet50 -from tinygrad import Tensor -from tinygrad.codegen.kernel import Kernel -from tinygrad.helpers import Context, getenv, to_function_name -from tinygrad.engine.schedule import _get_output_groups, _lower_lazybuffer -from tinygrad.engine.lazy import LazyBuffer -from tinygrad.ops import UOp, UOps - -if __name__ == "__main__": - mdl = ResNet50() - img = Tensor.empty(64, 3, 224, 224) - out = mdl(img) - output_groups, realizes, _ = _get_output_groups(out.lazydata.lbs, set()) - - asts: List[UOp] = [] - no_rewrite: List[float] = [] - for k,v in output_groups.items(): - st = time.perf_counter_ns() - lsi = _lower_lazybuffer(v, realizes)[0] - et = time.perf_counter_ns() - st - if lsi.ast.op is UOps.EXT: continue - no_rewrite.append(et*1e-6) - asts.append(lsi.ast) - - rewrite: List[float] = [] - bufs: List[List[LazyBuffer]] = [] - with Context(AST_REWRITE=1): - for k,v in output_groups.items(): - st = time.perf_counter_ns() - lsi = _lower_lazybuffer(v, realizes)[0] - bufs.append(v) - et = time.perf_counter_ns() - st - if lsi.ast.op is UOps.EXT: continue - rewrite.append(et*1e-6) - - assert len(rewrite) == len(no_rewrite) == len(asts) - - kernel_tms: Dict[bytes, Tuple[UOp, float, float, List[LazyBuffer]]] = {k.key:(k, no_rewrite[i], rewrite[i], bufs[i]) for i,k in enumerate(asts)} - pct_change: Dict[bytes, float] = {k:((x-y)/x)*100 for k,(_,x,y,_) in kernel_tms.items()} - slowest_kernels = list(sorted(pct_change.items(), key=lambda x:x[1])) - names = {ast.key:Kernel(ast).name for ast,_,_,_ in kernel_tms.values()} - print("slowest ast rewrites:") - for k,pct in slowest_kernels[:10]: - _, no_rw, rw, outs = kernel_tms[k] - print(f"{names[k]:10s} {no_rw:4.2f} ms -> {rw:4.2f} ms {pct:4.2f}%") - with open("/tmp/kernel_tms", "wb") as f: pickle.dump(kernel_tms, f) - - if getenv("GRAPH_TIMING"): - sample = slowest_kernels[:20] - x: List[str] = [to_function_name(names[k]) for k,_ in sample] - y1, y2 = [kernel_tms[k][1] for k,_ in sample], [kernel_tms[k][2] for k,_ in sample] - fig = go.Figure(data=[go.Bar(name="no graph_rewrite", x=x, y=y1, marker=dict(color="#524eed", line=dict(color='rgba(0,0,0,0)'))), - go.Bar(name="graph_rewrite", x=x, y=y2, marker=dict(color="#6fcf97", line=dict(color='rgba(0,0,0,0)')))]) - fig.update_layout(barmode="group", paper_bgcolor="black", plot_bgcolor="black", - font={"color":"white"}, yaxis={"gridcolor":"rgba(255, 255, 255, 0.3)"}) - fig.show() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2974f60a9f..2126c150bd 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,15 +2,15 @@ import sys, pickle, atexit from collections import defaultdict, deque from dataclasses import dataclass from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast -from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, resolve, \ - graph_rewrite, track_rewrites, Variable, sint +from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \ + graph_rewrite, track_rewrites, sint from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \ colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap from tinygrad.dtype import ImageDType, dtypes -from tinygrad.engine.lazy import LazyBuffer from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.device import Buffer from tinygrad.shape.view import View, strides_for_shape +from tinygrad.engine.lazy import LazyBuffer +from tinygrad.device import Buffer # creation can recurse a lot sys.setrecursionlimit(10000) @@ -275,11 +275,12 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={}) return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants]) -def _get_output_groups(outs:List[LazyBuffer]) -> \ - Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # these are the output groups - Dict[Buffer, UOp], # this is a map of realized Buffers to UOps.BUFFER - Dict[LazyBuffer, LazyBuffer]]: # these are the buffers we ASSIGN to in this schedule - """find all the realizes in the graph, group the output LazyBuffers into kernels.""" +SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = [] +def _graph_schedule(outs:List[LazyBuffer]) -> \ + Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], # this is the graph + DefaultDict[LBScheduleItem, int], # this is the in-degree of the graph + Dict[Variable, int]]: # this has all the var values of the schedule + """create a graph for realizing the outputs""" # start by just realizing the buffers passed in realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None} allbufs: Dict[LazyBuffer, None] = {} @@ -373,20 +374,12 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \ # NOTE: UOps.BUFFER creation must come after the ImageDType fixup else: uop = UOp(UOps.BUFFER, buf.buffer.dtype.ptr(), (), (len(buf_uops), (buf.buffer.device, buf.buffer.size, buf.buffer.dtype))) buf_uops.setdefault(buf.buffer, uop) - return output_groups, buf_uops, assign_targets -SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = [] -def _graph_schedule(outs:List[LazyBuffer]) -> \ - Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], # this is the graph - DefaultDict[LBScheduleItem, int], # this is the in-degree of the graph - Dict[Variable, int]]: # this has all the var values of the schedule - """create a graph for realizing the outputs""" - output_groups, buf_uops, assign_targets = _get_output_groups(outs) # preschedule all buffers in realizes prescheduled: List[LBScheduleItem] = [] var_vals: Dict[Variable, int] = {} - for group in output_groups.values(): - prescheduled.append((ret:=_lower_lazybuffer(group, buf_uops))[0]) + for outs in output_groups.values(): + prescheduled.append((ret:=_lower_lazybuffer(outs, buf_uops))[0]) var_vals = merge_dicts([var_vals, ret[1]]) schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}