regression test viz failure when there's no tracked context (#8297)

* regression test viz failure when there's no tracked context

* test inner rewrite locations, keep notes
This commit is contained in:
qazal
2024-12-17 23:23:27 +02:00
committed by GitHub
parent 777d2aec05
commit 5977a3d8a6

View File

@@ -1,7 +1,7 @@
from typing import Dict, List, Optional
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
@@ -113,6 +113,33 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(ret), 1)
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
@unittest.expectedFailure
def test_rewrite_without_context(self):
def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
@track_rewrites(named=True)
def tracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
# test
add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
untracked_graph_rewrite(add)
self.assertEqual(len(contexts), 0)
tracked_graph_rewrite(add)
self.assertEqual(len(contexts), 1)
def test_inner_rewrite_location(self):
# inner rewrite gets tracked in another context
def inner_rewrite(sink): return graph_rewrite(sink, symbolic)
@track_rewrites(named=True)
def tracked_graph_rewrite(sink): return inner_rewrite(sink)
# test
add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
tracked_graph_rewrite(add)
self.assertEqual(len(contexts), 1)
# location of context is inner_rewrite
fp, lineno = contexts[0][0].loc
self.assertEqual(lineno, inner_rewrite.__code__.co_firstlineno)
self.assertEqual(fp, inner_rewrite.__code__.co_filename)
class TextVizProfiler(unittest.TestCase):
def test_perfetto_node(self):
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),