mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
simple VIZ=1 and get_location changes (#6599)
* simpler replace * this get_location is fine? * python things * ctx location
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast,
|
||||
import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
|
||||
from enum import auto, IntEnum, Enum
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
|
||||
from tinygrad.helpers import pretty_print, prod, getenv, all_same
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
@@ -373,7 +373,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
|
||||
def get_location() -> Tuple[str, int]:
|
||||
frm = sys._getframe(1)
|
||||
# find the real frame
|
||||
# find the real frame in the file that has the UPat
|
||||
while frm.f_back is not None and any(fp == frm.f_back.f_code.co_filename.split("/")[-1] for fp in {"ops.py", "uopgraph.py", "schedule.py"}):
|
||||
frm = frm.f_back
|
||||
return frm.f_code.co_filename, frm.f_lineno
|
||||
@@ -486,9 +486,9 @@ TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
|
||||
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
||||
@dataclass(frozen=True)
|
||||
class TrackedRewriteContext:
|
||||
loc: str # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
rewrites: List[Tuple[UOp, UOp, UPat]] # all rewrites of sparents. (before, after, UPat)
|
||||
loc: str # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
|
||||
contexts: List[TrackedRewriteContext] = []
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
@@ -550,7 +550,7 @@ class RewriteContext:
|
||||
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
|
||||
return found
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append(TrackedRewriteContext(f"{(l:=get_location())[0].split('/')[-1]}:{l[1]}", sink, []))
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink))
|
||||
return RewriteContext(pm).rewrite(sink)
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
15
viz/serve.py
15
viz/serve.py
@@ -34,13 +34,10 @@ class UOpRet:
|
||||
diffs: List[Tuple[str, str, List[str]]] # the diffs for each rewrite
|
||||
extra: List[List[str]] # these become code blocks in the UI
|
||||
|
||||
def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp:
|
||||
if (found:=cache.get(base.key)): return found
|
||||
if base.key == prev.key: ret = new
|
||||
else:
|
||||
new_srcs = tuple(replace_uop(x, prev, new, cache) for x in base.src)
|
||||
ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
|
||||
cache[base.key] = ret
|
||||
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
|
||||
if (found:=replaces.get(base.key)): return found
|
||||
new_srcs = tuple(replace_uop(x, replaces) for x in base.src)
|
||||
replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
|
||||
return ret
|
||||
|
||||
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
@@ -52,7 +49,8 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
|
||||
if pattern.location[0].split("/")[-1] == "ops.py": continue
|
||||
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
|
||||
new_sink = replace_uop(uops[-1], first, rewritten, {**seen_replaces})
|
||||
seen_replaces[first.key] = rewritten
|
||||
new_sink = replace_uop(uops[-1], {**seen_replaces})
|
||||
# sanity check
|
||||
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
|
||||
# update ret data
|
||||
@@ -62,7 +60,6 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
graphs.append((new_sink, uops[-1], rewritten, first))
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
seen_replaces[first.key] = rewritten
|
||||
return UOpRet(ctx.loc, graphs, diffs, extra)
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
|
||||
@@ -88,6 +88,7 @@ class TestViz(unittest.TestCase):
|
||||
new_sink = graph_rewrite(sink, pm)
|
||||
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
|
||||
self.assert_valid_ctx(contexts)
|
||||
assert all(ctx.loc.split("/")[-1].split(":")[0] == __file__.split("/")[-1] for ctx in contexts)
|
||||
|
||||
@unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules")
|
||||
def test_fuzz_resnet(self):
|
||||
|
||||
Reference in New Issue
Block a user