more actionable verify_lazyop assert (#5635)

This commit is contained in:
qazal
2024-07-23 00:06:11 +08:00
committed by GitHub
parent edc58e6b6e
commit fe6f9b2048

View File

@@ -3,7 +3,7 @@ from typing import Union, Tuple, Any, List, Dict, Callable
import functools, hashlib, math, operator, ctypes, struct
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import dedup, pretty_print
from tinygrad.helpers import prod, dedup, pretty_print
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
@@ -144,13 +144,13 @@ def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return
def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
assert ast.op is MetaOps.KERNEL, "must be SINK"
sts: Dict[LazyOp, ShapeTracker] = {}
def dfs(op:LazyOp, st:ShapeTracker):
def assert_valid(op:LazyOp, st:ShapeTracker):
if op in sts: return
# restore globals from the two stage reduce
if op.op is BufferOps.LOAD and op.arg.idx == -1:
dfs(local_reduce:=op.src[0].src[0], op.arg.st)
assert_valid(local_reduce:=op.src[0].src[0], op.arg.st)
return sts.setdefault(op, sts[local_reduce])
for x in op.src: dfs(x, st)
for x in op.src: assert_valid(x, st)
# only reduceop is allowed to change shape, limited to turning n to 1
if op.op in ReduceOps:
axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
@@ -158,13 +158,16 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
else:
# movementops are pushed to the edges with LOAD
if op.op in BufferOps: st = op.arg.st
else: st = sts[op.src[0]]
for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
# elementwise inherits shape
st = op.arg.st if op.op in BufferOps else sts[op.src[0]]
for x in op.src:
if sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {sts[x].shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {sts[x].shape} {prod(sts[x].shape)} != {prod(st.shape)}")
sts[op] = st
for i, out in enumerate(ast.src):
assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
dfs(out, out.arg.st)
assert_valid(out, out.arg.st)
return sts