use matchers to verify AST shape [pr] (#8828)

* use matchers to verify kernel AST [pr]

* work

* use swizzle_cnt

* add comment

* imports

* modified_ast comment

* brief
Author: qazal
Date: 2025-01-31 02:17:42 -05:00 (committed by GitHub)
Parent: 643c09a6c6
Commit: af4f9d1aa9
4 changed files with 32 additions and 53 deletions


@@ -14,12 +14,13 @@ from tinygrad.dtype import DType, ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
-from tinygrad.codegen.kernel import verify_ast
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
+def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec)
 
 class KernelCountException(Exception): pass
 def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
   if to_prerealize:
@@ -1824,7 +1825,7 @@ class TestIndexing(unittest.TestCase):
     sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
     rsink = graph_rewrite(sink, view_right)
     # this AST first needs to swizzle, but it doesn't have implicit movementops
-    with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink)
+    self.assertEqual(swizzle_cnt(sink), 1)
     verify_ast(rsink)
 
   def test_no_reshape_reduceop(self):
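Note: swizzle_cnt is referenced above but not defined in the hunks shown here. A plausible sketch, assuming the helper simply counts VIEW nodes that still have a source (the in-kernel "swizzles" the deleted assertion used to reject):

from tinygrad.ops import UOp, Ops

def swizzle_cnt(u:UOp) -> int:
  # a VIEW with a source is a movement op left in the middle of the AST
  return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])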


@@ -5,7 +5,7 @@ from tinygrad import Tensor
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.helpers import DEBUG
 from tinygrad.ops import UOp, Ops, print_uops
-from tinygrad.codegen.kernel import verify_ast
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad import dtypes
 from tinygrad.shape.view import View
@@ -15,8 +15,8 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
   sink = UOp(Ops.SINK, dtypes.void, stores)
   if DEBUG >= 3:
     for op in stores: print(op)
-  try: verify_ast(sink)
-  except AssertionError as e: raise InvalidASTException(e.args)
+  try: type_verify(list(sink.toposort), shape_spec)
+  except RuntimeError as e: raise InvalidASTException(e.args)
   k = Kernel(sink)
   k.linearize()
   if DEBUG >= 6: print_uops(k.uops)
@@ -64,23 +64,24 @@ class TestVerifyAST(unittest.TestCase):
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
     r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
-    with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
+    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
 
   def test_reduce_add_store(self):
     bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
     r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
-    with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
+    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
 
   def test_buffer_uops_st(self):
     a = Tensor.randn(4, 4)+2
-    verify_ast(ast:=a.schedule()[-1].ast)
+    helper_test_verify_ast(ast:=a.schedule()[-1].ast)
     store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0]
     self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
     const_st = [u.st for u in ast.toposort if u.op is Ops.CONST][0]
     self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
 
+  @unittest.skip("questionable if we want this")
   def test_assert_swizzle(self):
     buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
     a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))
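With this change, verifying a kernel AST is a single matcher pass over its toposort, and spec failures surface as RuntimeError rather than AssertionError (hence the new except clause in helper_test_verify_ast). A minimal usage sketch; the tensor expression mirrors test_buffer_uops_st:

from tinygrad import Tensor
from tinygrad.spec import type_verify, shape_spec

ast = (Tensor.randn(4, 4)+2).schedule()[-1].ast
type_verify(list(ast.toposort), shape_spec)  # raises RuntimeError if any UOp fails a spec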


@@ -6,7 +6,7 @@ from typing import Optional, cast, Final, Callable, Sequence
 from enum import Enum, auto
 from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
-from tinygrad.spec import type_verify
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
 from tinygrad.dtype import ImageDType
@@ -57,11 +57,8 @@ class Kernel:
     if ast.op is Ops.SINK: self.ast = ast
     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
-    try: verify_ast(self.ast)
-    except AssertionError as e:
-      print("INVALID AST")
-      print(self.ast)
-      raise e
+    # verify AST matches the spec
+    if __debug__: type_verify(list(self.ast.toposort), shape_spec)
     self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]
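Guarding the check with __debug__ makes it free in production: __debug__ is a builtin constant that is True by default and False under python -O, and CPython removes the guarded statement entirely when it is False. A self-contained illustration:

# save as debug_gate.py, then compare `python debug_gate.py` and `python -O debug_gate.py`
if __debug__: print("spec verification would run here")
else: print("verification skipped under -O")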
@@ -673,7 +670,10 @@ class Kernel:
       if getenv("RAWAST"): print(self.ast)
       print(modified_ast)
       print(self.applied_opts)
-    verify_ast(modified_ast)
+    # verify AST matches the spec after applying opts
+    if __debug__: type_verify(list(modified_ast.toposort))
+    # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
+    #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
     self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
     if DEBUG >= 5: print_uops(self.uops)
@@ -693,39 +693,3 @@ class Kernel:
                          key=lambda x: (x.op, x.src[0].arg)))
     return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
                        global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
-
-# the living definition of intermediate UOps
-def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> None:
-  if uop in sts: return
-  # restore globals from the two stage reduce
-  # this is because this LOAD has an implicit movement op
-  if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
-    _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
-    sts[uop] = sts[local_reduce]
-    return
-  for x in uop.src: _assert_valid_uop(x, st, sts)
-  # only reduceuop is allowed to change shape, limited to turning n to 1
-  if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
-  # movementops are pushed to VIEW
-  elif uop.op is Ops.VIEW:
-    # NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine
-    assert len(uop.src) == 0 or uop.src[0].op is Ops.DEVICE, f"can't swizzle in kernel yet {uop}"
-    st = uop.arg
-  # everything else inherits shape
-  else:
-    if len(src_sts:=[sts[x] for x in uop.src if x in sts]) == 0: return None
-    st = src_sts[0]
-    if not all_same(shapes:=[x.shape for x in src_sts]):
-      if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
-      raise AssertionError(f"found implicit expand {sizes} {shapes}")
-  sts[uop] = st
-
-def verify_ast(ast:UOp) -> None:
-  assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
-  assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
-  sts: dict[UOp, ShapeTracker] = {}
-  for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
-  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
-  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-  type_verify(list(sts))
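The recursive walker deleted above is replaced by the declarative shape_spec matcher in the next file. For intuition, here is a toy, stdlib-only model of matcher-based verification; the real PatternMatcher in tinygrad.ops also matches on dtype and source structure and is far richer:

from typing import Callable

class Node:
  def __init__(self, op:str, src:tuple=()):
    self.op, self.src = op, src

# each rule pairs an op name with a check; a node passes only if some rule matches and every match returns True
Rule = tuple[str, Callable[[Node], bool]]

def verify(nodes:list[Node], rules:list[Rule]) -> None:
  for n in nodes:
    results = [check(n) for op, check in rules if op == n.op]
    if not results or not all(results): raise RuntimeError(f"{n.op} failed the spec")

store = Node("STORE", (Node("DEFINE_GLOBAL"), Node("VIEW")))
verify([store, Node("SINK", (store,))], [
  ("SINK",  lambda n: all(s.op == "STORE" for s in n.src)),
  ("STORE", lambda n: len(n.src) >= 2),
])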


@@ -1,7 +1,7 @@
 from typing import cast
 from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
-from tinygrad.helpers import all_int, prod
+from tinygrad.helpers import all_int, all_same, dedup, prod
 
 # *** this is the spec of a Tensor in UOp ***
@@ -61,7 +61,7 @@ spec = PatternMatcher([
   (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
   (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
-  (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
+  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
 
   # early LOAD has a <buf, shapetracker, store?>
   (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
@@ -121,6 +121,19 @@ spec = PatternMatcher([
   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
 ])
 
+# *** this is the UOp shape spec ***
+
+def verify_sink_dims(sink:UOp):
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])]
+  return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)
+
+shape_spec = PatternMatcher([
+  # shapes must have either 1 or n in each dimension
+  (UPat(Ops.SINK, src=UPat(Ops.STORE), allow_any_len=True, name="sink"), verify_sink_dims),
+  # all parent UOps must have the same shape
+  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
+])
+
 # ***** uop helpers *****
 
-def type_verify(uops:list[UOp]):
+def type_verify(uops:list[UOp], *extra_specs:PatternMatcher):
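Two closing notes, both sketches rather than verbatim repo code. First, with *extra_specs the base spec presumably always runs while stricter specs are opt-in, matching the two call sites in kernel.py above. Second, the per-dimension rule in verify_sink_dims can be worked by hand:

# base spec only vs. base spec plus the shape spec, assuming `ast` is a kernel
# AST SINK built as in the tests above
type_verify(list(ast.toposort))              # types and graph structure only
type_verify(list(ast.toposort), shape_spec)  # additionally enforce shapes

# worked example of the verify_sink_dims rule: each dimension, collected across
# all shaped UOps, must be {n} or {1, n}, i.e. broadcast-compatible
shapes = [(4, 4), (1, 1), (4, 1)]
shape_dims = [sorted(set(dims)) for dims in zip(*shapes)]   # [[1, 4], [1, 4]]
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)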