mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
viz: one name arg in track_rewrites [pr] (#10873)
* viz: one name arg in track_rewrites [pr] * other test
This commit is contained in:
@@ -66,7 +66,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
|
||||
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
|
||||
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def schedule_graph_rewrite(big_sink:UOp): return get_kernelize_map(big_sink)[big_sink]
|
||||
|
||||
class TestSchedule(unittest.TestCase):
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
def test_viz_simple(self):
|
||||
a = UOp.variable("a", 0, 10)
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def test(sink): return graph_rewrite(sink, symbolic)
|
||||
test(a*1)
|
||||
ret = get_metadata(keys, contexts)
|
||||
@@ -41,7 +41,7 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
def test_track_two_rewrites(self):
|
||||
a = UOp.variable("a", 0, 10)
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def test(sink): return graph_rewrite(sink, symbolic)
|
||||
test((a+a)*1)
|
||||
ret = get_metadata(keys, contexts)
|
||||
@@ -53,7 +53,7 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
def test_track_multiple_calls_one_ctx(self):
|
||||
a = UOp.variable("a", 0, 10)
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def test(a, b):
|
||||
a = graph_rewrite(a, symbolic)
|
||||
b = graph_rewrite(b, symbolic)
|
||||
@@ -67,7 +67,7 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(val[1]["match_count"], 0) # no rewrites for a*5
|
||||
|
||||
def test_track_rewrites(self):
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def do_rewrite(x:UOp): return graph_rewrite(x, symbolic)
|
||||
a = UOp.variable("a", 0, 10)
|
||||
b = UOp.variable("b", 0, 4)
|
||||
@@ -93,7 +93,7 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(len(ret), 1)
|
||||
|
||||
def test_track_rewrites_name_fxn(self):
|
||||
@track_rewrites(name_fxn=lambda _,ret: f"output_{ret}")
|
||||
@track_rewrites(name=lambda _,ret: f"output_{ret}")
|
||||
def do_rewrite(x:UOp):
|
||||
x = graph_rewrite(x, symbolic)
|
||||
return x.render()
|
||||
@@ -109,7 +109,7 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_name_in_positional_arg(self):
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def test(sink): return graph_rewrite(sink, symbolic, None, False, "name")
|
||||
test(UOp.variable("a", 0, 1))
|
||||
self.assertEqual(contexts[0].pop().name, "name")
|
||||
@@ -125,7 +125,7 @@ class TestViz(unittest.TestCase):
|
||||
a = UOp.variable("a", 0, 10)
|
||||
b = UOp.variable("b", 0, 10)
|
||||
c = UOp.variable("c", 0, 10)
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def fxn(sink): return graph_rewrite(sink, substitute, ctx={a+b:c}, bottom_up=True)
|
||||
fxn(a+b)
|
||||
#UOp.substitute(a+b, {a+b:c})
|
||||
@@ -137,7 +137,7 @@ class TestViz(unittest.TestCase):
|
||||
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
|
||||
def test_rewrite_without_context(self):
|
||||
def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def tracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
|
||||
# test
|
||||
add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
|
||||
@@ -149,7 +149,7 @@ class TestViz(unittest.TestCase):
|
||||
def test_inner_rewrite_location(self):
|
||||
# inner rewrite gets tracked in another context
|
||||
def inner_rewrite(sink): return graph_rewrite(sink, symbolic)
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def tracked_graph_rewrite(sink): return inner_rewrite(sink)
|
||||
# test
|
||||
add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
|
||||
@@ -171,7 +171,7 @@ class TestViz(unittest.TestCase):
|
||||
y2 = graph_rewrite(y, inner_rewrite, name="inner_y")
|
||||
return None if (x2 is x and y2 is y) else x2+y2
|
||||
outer_rewrite = TrackedPatternMatcher([(UPat.cvar("x")+UPat.cvar("y"), make_float),])
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def rewrite(u:UOp): return graph_rewrite(u, outer_rewrite, name="outer")
|
||||
a = UOp.const(dtypes.int, 1)+UOp.const(dtypes.int, 2)
|
||||
rewrite(a)
|
||||
@@ -186,7 +186,7 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
|
||||
|
||||
def test_depth_level(self):
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def fxn(u:UOp): return graph_rewrite(u, l0)
|
||||
ret = fxn(UOp(Ops.CUSTOM, arg=0))
|
||||
assert ret is UOp(Ops.CUSTOM, arg=3)
|
||||
@@ -216,10 +216,10 @@ class TestViz(unittest.TestCase):
|
||||
@unittest.skip("TODO: doesn't work")
|
||||
def test_recursion_err(self):
|
||||
inf = TrackedPatternMatcher([
|
||||
(UPat.const(dtypes.int, 0).named("a"), lambda a: a.const_like(1)),
|
||||
(UPat.const(dtypes.int, 1).named("b"), lambda b: b.const_like(0)),
|
||||
(UPat.const(dtypes.int, 0).name("a"), lambda a: a.const_like(1)),
|
||||
(UPat.const(dtypes.int, 1).name("b"), lambda b: b.const_like(0)),
|
||||
])
|
||||
@track_rewrites(named=True)
|
||||
@track_rewrites(name=True)
|
||||
def func(u): return graph_rewrite(u, inf)
|
||||
with self.assertRaises(RecursionError): func(UOp.const(dtypes.int, 0))
|
||||
_ = list(get_details(keys[0], contexts[0][0]))
|
||||
|
||||
@@ -420,7 +420,7 @@ finalize_gbarrier = PatternMatcher([
|
||||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
@track_rewrites(name_fxn=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||
@track_rewrites(name=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(big_sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
||||
@@ -738,15 +738,15 @@ if getenv("CAPTURE_PROCESS_REPLAY"):
|
||||
def save_to_diskcache():
|
||||
for k,v in replay_capture.items(): diskcache_put("process_replay", k, v, prepickled=True)
|
||||
|
||||
def track_rewrites(named=False, name_fxn:Callable|None=None):
|
||||
def track_rewrites(name:Callable|bool|None=None):
|
||||
def _decorator(func):
|
||||
def __wrapper(*args, **kwargs):
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
if (count_names:=(named or name_fxn or not args)): _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
|
||||
if (count_names:=(name or not args)): _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
|
||||
tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if count_names else args[0])
|
||||
tracked_ctxs.append([])
|
||||
ret = func(*args, **kwargs)
|
||||
if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(*args, **kwargs, ret=ret)} n{_name_cnt[func.__name__]}"
|
||||
if TRACK_MATCH_STATS >= 2 and callable(name): tracked_keys[-1] = f"{name(*args, **kwargs, ret=ret)} n{_name_cnt[func.__name__]}"
|
||||
if getenv("CAPTURE_PROCESS_REPLAY"):
|
||||
# find the unittest frame we're capturing in
|
||||
frm = sys._getframe(1)
|
||||
|
||||
Reference in New Issue
Block a user