From 5977a3d8a671b95142bdf5fc154a46eec3f5c5bf Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 17 Dec 2024 23:23:27 +0200 Subject: [PATCH] regression test viz failure when there's no tracked context (#8297) * regression test viz failure when there's no tracked context * test inner rewrite locations, keep notes --- test/test_viz.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/test/test_viz.py b/test/test_viz.py index 1a94e95cd0..c5a4bace74 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional import unittest, decimal, json from tinygrad.dtype import dtypes -from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites +from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto @@ -113,6 +113,33 @@ class TestViz(unittest.TestCase): self.assertEqual(len(ret), 1) self.assertIs(ret[0], a.sqrt().sin()) # only rewrite + # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ + @unittest.expectedFailure + def test_rewrite_without_context(self): + def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic) + @track_rewrites(named=True) + def tracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic) + # test + add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1) + untracked_graph_rewrite(add) + self.assertEqual(len(contexts), 0) + tracked_graph_rewrite(add) + self.assertEqual(len(contexts), 1) + + def test_inner_rewrite_location(self): + # inner rewrite gets tracked in another context + def inner_rewrite(sink): return graph_rewrite(sink, symbolic) + @track_rewrites(named=True) + def tracked_graph_rewrite(sink): return inner_rewrite(sink) + # test + add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1) + tracked_graph_rewrite(add) + self.assertEqual(len(contexts), 1) + # location of context is inner_rewrite + fp, lineno = contexts[0][0].loc + self.assertEqual(lineno, inner_rewrite.__code__.co_firstlineno) + self.assertEqual(fp, inner_rewrite.__code__.co_filename) + class TextVizProfiler(unittest.TestCase): def test_perfetto_node(self): prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),