mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
regression test viz failure when there's no tracked context (#8297)
* regression test viz failure when there's no tracked context * test inner rewrite locations, keep notes
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Optional
|
||||
import unittest, decimal, json
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
|
||||
@@ -113,6 +113,33 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(len(ret), 1)
|
||||
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
|
||||
|
||||
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
|
||||
@unittest.expectedFailure
|
||||
def test_rewrite_without_context(self):
|
||||
def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
|
||||
@track_rewrites(named=True)
|
||||
def tracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
|
||||
# test
|
||||
add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
|
||||
untracked_graph_rewrite(add)
|
||||
self.assertEqual(len(contexts), 0)
|
||||
tracked_graph_rewrite(add)
|
||||
self.assertEqual(len(contexts), 1)
|
||||
|
||||
def test_inner_rewrite_location(self):
|
||||
# inner rewrite gets tracked in another context
|
||||
def inner_rewrite(sink): return graph_rewrite(sink, symbolic)
|
||||
@track_rewrites(named=True)
|
||||
def tracked_graph_rewrite(sink): return inner_rewrite(sink)
|
||||
# test
|
||||
add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
|
||||
tracked_graph_rewrite(add)
|
||||
self.assertEqual(len(contexts), 1)
|
||||
# location of context is inner_rewrite
|
||||
fp, lineno = contexts[0][0].loc
|
||||
self.assertEqual(lineno, inner_rewrite.__code__.co_firstlineno)
|
||||
self.assertEqual(fp, inner_rewrite.__code__.co_filename)
|
||||
|
||||
class TextVizProfiler(unittest.TestCase):
|
||||
def test_perfetto_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
|
||||
|
||||
Reference in New Issue
Block a user