diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index d7409c4cc0..6997444bcf 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -93,7 +93,7 @@ class TestViz(unittest.TestCase): self.assertEqual(len(ret), 1) def test_track_rewrites_name_fxn(self): - @track_rewrites(name_fxn=lambda r: f"output_{r}") + @track_rewrites(name_fxn=lambda _,ret: f"output_{ret}") def do_rewrite(x:UOp): x = graph_rewrite(x, symbolic) return x.render() diff --git a/test/web/test_viz.js b/test/web/test_viz.js index 69d43bd96c..80475734b9 100644 --- a/test/web/test_viz.js +++ b/test/web/test_viz.js @@ -14,7 +14,7 @@ async function main() { try { browser = await puppeteer.launch({ headless: true }); const page = await browser.newPage(); - const res = await page.goto("http://localhost:8000"); + const res = await page.goto("http://localhost:8000", { waitUntil:"domcontentloaded" }); if (res.status() !== 200) throw new Error("Failed to load page"); const scheduleSelector = await page.waitForSelector("ul"); scheduleSelector.click(); diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index eba82fe9ba..7538e65428 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -501,10 +501,6 @@ do_fuse = PatternMatcher([ (UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange), ]) -def get_name(becomes_map:dict[UOp, UOp]) -> str: - assigned_kernels = {u.base.buf_uop:u.base.src[1] for u in becomes_map.values() if u.base.op is Ops.ASSIGN}.values() - return f"Schedule {pluralize('Kernel', len(set(assigned_kernels)))}" - add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"), lambda ctx,x: x.replace(tag=1).gbarrier() if x in ctx and x.tag is None else None)]) @@ -538,7 +534,7 @@ finalize_gbarrier = PatternMatcher([ remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) -@track_rewrites(name_fxn=get_name) +@track_rewrites(name_fxn=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}") def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]: # multi + merge_views + simplify tensor_map = graph_rewrite_map(big_sink, multi_pm+replace_allreduce+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views") diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index a065d0895a..9325a2c07b 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -910,7 +910,7 @@ def track_rewrites(named=False, name_fxn:Callable|None=None): tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if count_names else args[0]) tracked_ctxs.append([]) ret = func(*args, **kwargs) - if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(ret)} n{_name_cnt[func.__name__]}" + if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(*args, **kwargs, ret=ret)} n{_name_cnt[func.__name__]}" if getenv("CAPTURE_PROCESS_REPLAY"): # find the unittest frame we're capturing in frm = sys._getframe(1)