From c08521e8237afa7c1f62dce6fe1197dfb76b4db1 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 11 Oct 2024 14:19:10 +0800 Subject: [PATCH] minor cleanups from toonygrad (#6990) --- test/external/external_benchmark_schedule.py | 4 ++-- test/test_uop_graph.py | 4 ++-- tinygrad/codegen/kernel.py | 4 ++-- tinygrad/codegen/lowerer.py | 2 +- tinygrad/engine/graph.py | 8 ++------ tinygrad/helpers.py | 1 + tinygrad/ops.py | 1 + tinygrad/viz/serve.py | 10 +++++++--- 8 files changed, 18 insertions(+), 16 deletions(-) diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index 86ac60a92a..782b5c9a57 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -4,7 +4,7 @@ from tinygrad import Tensor, Device from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen from tinygrad.ops import UOps from tinygrad.codegen.kernel import Kernel -from tinygrad.codegen.lowerer import ast_to_uop +from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite from tinygrad.engine.search import beam_search, bufs_from_lin @@ -37,7 +37,7 @@ if __name__ == "__main__": else: k.hand_coded_optimizations() kernels.append(k) - with Timing("***** model lower in "): uops = [ast_to_uop(k.get_optimized_ast(), k.opts) for k in kernels] + with Timing("***** model lower in "): uops = [rewrite_shapetracker_with_index(k.get_optimized_ast(), k.opts) for k in kernels] with Profiling(PROFILE, fn="/tmp/rewrite.prof"): with Timing("***** model rewrite in "): rewritten_uops = [] diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 0885784575..48e5d6404d 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -6,7 +6,7 @@ from tinygrad.dtype import PtrDType from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo from tinygrad.ops import UPat, PatternMatcher -from tinygrad.codegen.lowerer import ast_to_uop +from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding from tinygrad.shape.shapetracker import ShapeTracker, View @@ -50,7 +50,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase): UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=( View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) - lower_sink = ast_to_uop(sink, Device[Device.DEFAULT].renderer) + lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer) cnt = [0] old_init = UOp.__init__ def uop_hook(self, *args, **kwargs): diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 85345213b1..90ea7fa05d 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -16,7 +16,7 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.view import strides_for_shape from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite -from tinygrad.codegen.lowerer import ast_to_uop, get_contraction +from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction class OptOps(Enum): TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702 @@ -715,7 +715,7 @@ class Kernel: print(self.applied_opts) verify_ast(modified_ast) - self.uops:List[UOp] = linearize_uop(full_graph_rewrite(ast_to_uop(modified_ast, self.opts), self.opts)) + self.uops:List[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts)) if DEBUG >= 5: print_uops(self.uops) if getenv("GRAPHUOPS"): from tinygrad.engine.graph import graph_uops diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 2022f1c1f0..3cb804769b 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -132,4 +132,4 @@ pm_lowerer = PatternMatcher([ (UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.VIEW)), allow_any_len=True, name="x"), lower_load_store), ]) -def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts)) +def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts)) diff --git a/tinygrad/engine/graph.py b/tinygrad/engine/graph.py index 230fefe6d3..bb2a401eb4 100644 --- a/tinygrad/engine/graph.py +++ b/tinygrad/engine/graph.py @@ -3,8 +3,9 @@ from collections import defaultdict from typing import List, Any, DefaultDict from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, UOps, UOp from tinygrad.device import Device -from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters +from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, word_wrap from tinygrad.engine.lazy import LazyBuffer +from tinygrad.viz.serve import uops_colors with contextlib.suppress(ImportError): import networkx as nx @@ -70,12 +71,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False): # realized but unseen? G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080") -uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0", - UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484", - UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff", - UOps.WMMA: "#efefc0", UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"} graph_uops_cnt = 0 -def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap)) def graph_uops(uops:List[UOp]): global graph_uops_cnt G = nx.DiGraph() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 69f047efec..2217c1c34f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -67,6 +67,7 @@ def get_child(obj, key): elif isinstance(obj, dict): obj = obj[k] else: obj = getattr(obj, k) return obj +def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap)) @functools.lru_cache(maxsize=None) def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)]) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 35ba0ab102..6ce289d804 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -99,6 +99,7 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps. class UOps(FastEnum): # uops that aren't rendered SINK = auto() + CONTIGUOUS = auto() # metaops CUSTOM = auto() diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 1fde84acee..e74c57a1bc 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -4,11 +4,15 @@ from http.server import HTTPServer, BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from dataclasses import asdict, dataclass from typing import Any, Dict, List, Tuple, Optional -from tinygrad.helpers import colored, getenv, to_function_name, tqdm +from tinygrad.helpers import colored, getenv, to_function_name, tqdm, word_wrap from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines -from tinygrad.engine.graph import word_wrap, uops_colors from tinygrad.codegen.kernel import Kernel +uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0", + UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484", + UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff", + UOps.WMMA: "#efefc0", UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"} + # ** API spec @dataclass @@ -112,7 +116,6 @@ class Handler(BaseHTTPRequestHandler): # ** main loop -stop_reloader = threading.Event() def reloader(): mtime = os.stat(__file__).st_mtime while not stop_reloader.is_set(): @@ -122,6 +125,7 @@ def reloader(): time.sleep(0.1) if __name__ == "__main__": + stop_reloader = threading.Event() multiprocessing.current_process().name = "VizProcess" # disallow opening of devices st = time.perf_counter() print("*** viz is starting")