Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 15:38:29 -05:00
use matchers to verify AST shape [pr] (#8828)
* use matchers to verify kernel AST [pr]
* work
* use swizzle_cnt
* add comment
* imports
* modified_ast comment
* brief
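The heart of the change: the hand-rolled recursive verify_ast in the kernel code is deleted, and kernel ASTs are instead checked against a declarative PatternMatcher spec (shape_spec). A minimal sketch of how the new verification is invoked, built only from pieces this diff introduces (the tensor computation is arbitrary):

from tinygrad import Tensor
from tinygrad.ops import UOp
from tinygrad.spec import type_verify, shape_spec

# the test-side helper this diff adds: check every UOp reachable from the
# SINK against the base spec plus the shape spec
def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec)

# schedule a small computation and verify its kernel AST;
# type_verify raises RuntimeError if any UOp violates a spec rule
a = Tensor.randn(4, 4) + 2
verify_ast(a.schedule()[-1].ast)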
@@ -14,12 +14,13 @@ from tinygrad.dtype import DType, ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
-from tinygrad.codegen.kernel import verify_ast
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis

+def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec)
 class KernelCountException(Exception): pass
 def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
   if to_prerealize:
@@ -1824,7 +1825,7 @@ class TestIndexing(unittest.TestCase):
     sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
     rsink = graph_rewrite(sink, view_right)
     # this AST first needs to swizzle, but it doesn't have implicit movementops
-    with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink)
+    self.assertEqual(swizzle_cnt(sink), 1)
     verify_ast(rsink)

   def test_no_reshape_reduceop(self):
@@ -5,7 +5,7 @@ from tinygrad import Tensor
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.helpers import DEBUG
 from tinygrad.ops import UOp, Ops, print_uops
-from tinygrad.codegen.kernel import verify_ast
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad import dtypes
 from tinygrad.shape.view import View
@@ -15,8 +15,8 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
   sink = UOp(Ops.SINK, dtypes.void, stores)
   if DEBUG >= 3:
     for op in stores: print(op)
-  try: verify_ast(sink)
-  except AssertionError as e: raise InvalidASTException(e.args)
+  try: type_verify(list(sink.toposort), shape_spec)
+  except RuntimeError as e: raise InvalidASTException(e.args)
   k = Kernel(sink)
   k.linearize()
   if DEBUG >= 6: print_uops(k.uops)
@@ -64,23 +64,24 @@ class TestVerifyAST(unittest.TestCase):
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
     r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
-    with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
+    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)

   def test_reduce_add_store(self):
     bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
     r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
-    with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
+    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)

   def test_buffer_uops_st(self):
     a = Tensor.randn(4, 4)+2
-    verify_ast(ast:=a.schedule()[-1].ast)
+    helper_test_verify_ast(ast:=a.schedule()[-1].ast)
     store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0]
     self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
     const_st = [u.st for u in ast.toposort if u.op is Ops.CONST][0]
     self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))

+  @unittest.skip("questionable if we want this")
   def test_assert_swizzle(self):
     buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
     a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))
@@ -6,7 +6,7 @@ from typing import Optional, cast, Final, Callable, Sequence
 from enum import Enum, auto

 from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
-from tinygrad.spec import type_verify
+from tinygrad.spec import type_verify, shape_spec
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
 from tinygrad.dtype import ImageDType
@@ -57,11 +57,8 @@ class Kernel:
     if ast.op is Ops.SINK: self.ast = ast

     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
-    try: verify_ast(self.ast)
-    except AssertionError as e:
-      print("INVALID AST")
-      print(self.ast)
-      raise e
+    # verify AST matches the spec
+    if __debug__: type_verify(list(self.ast.toposort), shape_spec)

     self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]

@@ -673,7 +670,10 @@ class Kernel:
       if getenv("RAWAST"): print(self.ast)
       print(modified_ast)
       print(self.applied_opts)
-    verify_ast(modified_ast)
+    # verify AST matches the spec after applying opts
+    if __debug__: type_verify(list(modified_ast.toposort))
+    # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
+    #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)

     self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
     if DEBUG >= 5: print_uops(self.uops)
@@ -693,39 +693,3 @@ class Kernel:
                       key=lambda x: (x.op, x.src[0].arg)))
     return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
                        global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
-
-# the living definition of intermediate UOps
-
-def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> None:
-  if uop in sts: return
-  # restore globals from the two stage reduce
-  # this is because this LOAD has an implicit movement op
-  if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
-    _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
-    sts[uop] = sts[local_reduce]
-    return
-  for x in uop.src: _assert_valid_uop(x, st, sts)
-  # only reduceuop is allowed to change shape, limited to turning n to 1
-  if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
-  # movementops are pushed to VIEW
-  elif uop.op is Ops.VIEW:
-    # NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine
-    assert len(uop.src) == 0 or uop.src[0].op is Ops.DEVICE, f"can't swizzle in kernel yet {uop}"
-    st = uop.arg
-  # everything else inherits shape
-  else:
-    if len(src_sts:=[sts[x] for x in uop.src if x in sts]) == 0: return None
-    st = src_sts[0]
-    if not all_same(shapes:=[x.shape for x in src_sts]):
-      if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
-      raise AssertionError(f"found implicit expand {sizes} {shapes}")
-  sts[uop] = st
-
-def verify_ast(ast:UOp) -> None:
-  assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
-  assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
-  sts: dict[UOp, ShapeTracker] = {}
-  for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
-  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
-  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-  type_verify(list(sts))
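Note the if __debug__: guards in the kernel hunks above: the checks run in normal interpreter sessions but are stripped at compile time when Python runs with -O, so verification costs nothing in optimized runs. A quick illustration (standard Python behavior, not tinygrad-specific):

# python script.py     -> __debug__ is True, the guarded check runs
# python -O script.py  -> __debug__ is False, the block is compiled away
if __debug__:
  print("spec verification enabled")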
@@ -1,7 +1,7 @@
 from typing import cast
 from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
-from tinygrad.helpers import all_int, prod
+from tinygrad.helpers import all_int, all_same, dedup, prod

 # *** this is the spec of a Tensor in UOp ***

@@ -61,7 +61,7 @@ spec = PatternMatcher([
   (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),

   (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
-  (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
+  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),

   # early LOAD has a <buf, shapetracker, store?>
   (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
@@ -121,6 +121,19 @@ spec = PatternMatcher([
   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
 ])

+# *** this is the UOp shape spec ***
+
+def verify_sink_dims(sink:UOp):
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])]
+  return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)
+
+shape_spec = PatternMatcher([
+  # shapes must have either 1 or n in each dimension
+  (UPat(Ops.SINK, src=UPat(Ops.STORE), allow_any_len=True, name="sink"), verify_sink_dims),
+  # all parent UOps must have the same shape
+  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
+])
+
 # ***** uop helpers *****

 def type_verify(uops:list[UOp], *extra_specs:PatternMatcher):
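As the updated tests show, shape violations now surface as failed spec rules rather than bespoke assert messages. A sketch assembled from the test_reduce_add_store case in this diff: the REDUCE_AXIS output has shape (1, 1) while its sibling LOAD keeps (32, 1), so the "all parent UOps must have the same shape" rule of shape_spec fails:

from tinygrad import dtypes
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.spec import type_verify, shape_spec

bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
# r+a mixes the reduced (1, 1) shape with the unreduced (32, 1) shape
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
sink = UOp(Ops.SINK, dtypes.void, (st,))
type_verify(list(sink.toposort), shape_spec)  # raises RuntimeError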