From 94effe2a7151f77b4b79ab9ba212e23bfd419805 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:58:33 +0800 Subject: [PATCH] simple VIZ=1 and get_location changes (#6599) * simpler replace * this get_location is fine? * python things * ctx location --- tinygrad/ops.py | 12 ++++++------ viz/serve.py | 15 ++++++--------- viz/test_viz.py | 1 + 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 94b730ee06..ea9d873657 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib from enum import auto, IntEnum, Enum from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType from tinygrad.helpers import pretty_print, prod, getenv, all_same from tinygrad.shape.symbolic import Variable, sint @@ -373,7 +373,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: def get_location() -> Tuple[str, int]: frm = sys._getframe(1) - # find the real frame + # find the real frame in the file that has the UPat while frm.f_back is not None and any(fp == frm.f_back.f_code.co_filename.split("/")[-1] for fp in {"ops.py", "uopgraph.py", "schedule.py"}): frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno @@ -486,9 +486,9 @@ TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0) match_stats:Dict[UPat, List[Union[int, float]]] = dict() @dataclass(frozen=True) class TrackedRewriteContext: - loc: str # location that called graph_rewrite - sink: UOp # the sink passed into the rewrite - rewrites: List[Tuple[UOp, UOp, UPat]] # all rewrites of sparents. (before, after, UPat) + loc: str # location that called graph_rewrite + sink: UOp # the sink passed into the rewrite + rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat) contexts: List[TrackedRewriteContext] = [] class TrackedPatternMatcher(PatternMatcher): def __init__(self, patterns:List[Tuple[UPat, Callable]]): @@ -550,7 +550,7 @@ class RewriteContext: self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x return found def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: - if TRACK_MATCH_STATS >= 2: contexts.append(TrackedRewriteContext(f"{(l:=get_location())[0].split('/')[-1]}:{l[1]}", sink, [])) + if TRACK_MATCH_STATS >= 2: contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink)) return RewriteContext(pm).rewrite(sink) # ***** uop type spec ***** diff --git a/viz/serve.py b/viz/serve.py index d485846672..f3280adb45 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -34,13 +34,10 @@ class UOpRet: diffs: List[Tuple[str, str, List[str]]] # the diffs for each rewrite extra: List[List[str]] # these become code blocks in the UI -def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp: - if (found:=cache.get(base.key)): return found - if base.key == prev.key: ret = new - else: - new_srcs = tuple(replace_uop(x, prev, new, cache) for x in base.src) - ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base - cache[base.key] = ret +def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp: + if (found:=replaces.get(base.key)): return found + new_srcs = tuple(replace_uop(x, replaces) for x in base.src) + replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base return ret def create_graph(ctx:TrackedRewriteContext) -> UOpRet: @@ -52,7 +49,8 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet: for i, (first, rewritten, pattern) in enumerate(ctx.rewrites): if pattern.location[0].split("/")[-1] == "ops.py": continue # first, rewrite this UOp with the current rewrite + all the seen rewrites before this - new_sink = replace_uop(uops[-1], first, rewritten, {**seen_replaces}) + seen_replaces[first.key] = rewritten + new_sink = replace_uop(uops[-1], {**seen_replaces}) # sanity check assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}" # update ret data @@ -62,7 +60,6 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet: graphs.append((new_sink, uops[-1], rewritten, first)) uops.append(new_sink) extra.append([str(new_sink)]) - seen_replaces[first.key] = rewritten return UOpRet(ctx.loc, graphs, diffs, extra) class Handler(BaseHTTPRequestHandler): diff --git a/viz/test_viz.py b/viz/test_viz.py index 7ec2572899..7216e792ce 100644 --- a/viz/test_viz.py +++ b/viz/test_viz.py @@ -88,6 +88,7 @@ class TestViz(unittest.TestCase): new_sink = graph_rewrite(sink, pm) if DEBUG >= 4: print_diff(sink, new_sink, unified=0) self.assert_valid_ctx(contexts) + assert all(ctx.loc.split("/")[-1].split(":")[0] == __file__.split("/")[-1] for ctx in contexts) @unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules") def test_fuzz_resnet(self):