use matchers to verify AST shape [pr] (#8828)

* use matchers to verify kernel AST [pr]

* work

* use swizzle_cnt

* add comment

* imports

* modified_ast comment

* brief
This commit is contained in:
qazal
2025-01-31 02:17:42 -05:00
committed by GitHub
parent 643c09a6c6
commit af4f9d1aa9
4 changed files with 32 additions and 53 deletions

View File

@@ -14,12 +14,13 @@ from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
from tinygrad.codegen.kernel import verify_ast
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec)
class KernelCountException(Exception): pass
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
if to_prerealize:
@@ -1824,7 +1825,7 @@ class TestIndexing(unittest.TestCase):
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# this AST first needs to swizzle, but it doesn't have implicit movementops
with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink)
self.assertEqual(swizzle_cnt(sink), 1)
verify_ast(rsink)
def test_no_reshape_reduceop(self):