mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 15:45:27 -05:00
use matchers to verify AST shape [pr] (#8828)
* use matchers to verify kernel AST [pr] * work * use swizzle_cnt * add comment * imports * modified_ast comment * brief
This commit is contained in:
@@ -14,12 +14,13 @@ from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.spec import type_verify, shape_spec
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec)
|
||||
class KernelCountException(Exception): pass
|
||||
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
|
||||
if to_prerealize:
|
||||
@@ -1824,7 +1825,7 @@ class TestIndexing(unittest.TestCase):
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
# this AST first needs to swizzle, but it doesn't have implicit movementops
|
||||
with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink)
|
||||
self.assertEqual(swizzle_cnt(sink), 1)
|
||||
verify_ast(rsink)
|
||||
|
||||
def test_no_reshape_reduceop(self):
|
||||
|
||||
Reference in New Issue
Block a user