simple VIZ=1 and get_location changes (#6599)

* simpler replace

* this get_location is fine?

* python things

* ctx location
This commit is contained in:
qazal
2024-09-19 15:58:33 +08:00
committed by GitHub
parent eeee032b14
commit 94effe2a71
3 changed files with 13 additions and 15 deletions

View File

@@ -3,7 +3,7 @@ from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast,
import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
from enum import auto, IntEnum, Enum
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
from tinygrad.helpers import pretty_print, prod, getenv, all_same
from tinygrad.shape.symbolic import Variable, sint
@@ -373,7 +373,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
def get_location() -> Tuple[str, int]:
frm = sys._getframe(1)
# find the real frame
# find the real frame in the file that has the UPat
while frm.f_back is not None and any(fp == frm.f_back.f_code.co_filename.split("/")[-1] for fp in {"ops.py", "uopgraph.py", "schedule.py"}):
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
@@ -486,9 +486,9 @@ TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
@dataclass(frozen=True)
class TrackedRewriteContext:
loc: str # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
rewrites: List[Tuple[UOp, UOp, UPat]] # all rewrites of sparents. (before, after, UPat)
loc: str # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
contexts: List[TrackedRewriteContext] = []
class TrackedPatternMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
@@ -550,7 +550,7 @@ class RewriteContext:
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
return found
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
if TRACK_MATCH_STATS >= 2: contexts.append(TrackedRewriteContext(f"{(l:=get_location())[0].split('/')[-1]}:{l[1]}", sink, []))
if TRACK_MATCH_STATS >= 2: contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink))
return RewriteContext(pm).rewrite(sink)
# ***** uop type spec *****

View File

@@ -34,13 +34,10 @@ class UOpRet:
diffs: List[Tuple[str, str, List[str]]] # the diffs for each rewrite
extra: List[List[str]] # these become code blocks in the UI
def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp:
if (found:=cache.get(base.key)): return found
if base.key == prev.key: ret = new
else:
new_srcs = tuple(replace_uop(x, prev, new, cache) for x in base.src)
ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
cache[base.key] = ret
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
if (found:=replaces.get(base.key)): return found
new_srcs = tuple(replace_uop(x, replaces) for x in base.src)
replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
return ret
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
@@ -52,7 +49,8 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
if pattern.location[0].split("/")[-1] == "ops.py": continue
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
new_sink = replace_uop(uops[-1], first, rewritten, {**seen_replaces})
seen_replaces[first.key] = rewritten
new_sink = replace_uop(uops[-1], {**seen_replaces})
# sanity check
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
# update ret data
@@ -62,7 +60,6 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
graphs.append((new_sink, uops[-1], rewritten, first))
uops.append(new_sink)
extra.append([str(new_sink)])
seen_replaces[first.key] = rewritten
return UOpRet(ctx.loc, graphs, diffs, extra)
class Handler(BaseHTTPRequestHandler):

View File

@@ -88,6 +88,7 @@ class TestViz(unittest.TestCase):
new_sink = graph_rewrite(sink, pm)
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
self.assert_valid_ctx(contexts)
assert all(ctx.loc.split("/")[-1].split(":")[0] == __file__.split("/")[-1] for ctx in contexts)
@unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules")
def test_fuzz_resnet(self):