mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
more actionable verify_lazyop assert (#5635)
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Union, Tuple, Any, List, Dict, Callable
|
||||
import functools, hashlib, math, operator, ctypes, struct
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import dedup, pretty_print
|
||||
from tinygrad.helpers import prod, dedup, pretty_print
|
||||
from tinygrad.dtype import dtypes, DType, ConstType
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -144,13 +144,13 @@ def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return
|
||||
def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
|
||||
assert ast.op is MetaOps.KERNEL, "must be SINK"
|
||||
sts: Dict[LazyOp, ShapeTracker] = {}
|
||||
def dfs(op:LazyOp, st:ShapeTracker):
|
||||
def assert_valid(op:LazyOp, st:ShapeTracker):
|
||||
if op in sts: return
|
||||
# restore globals from the two stage reduce
|
||||
if op.op is BufferOps.LOAD and op.arg.idx == -1:
|
||||
dfs(local_reduce:=op.src[0].src[0], op.arg.st)
|
||||
assert_valid(local_reduce:=op.src[0].src[0], op.arg.st)
|
||||
return sts.setdefault(op, sts[local_reduce])
|
||||
for x in op.src: dfs(x, st)
|
||||
for x in op.src: assert_valid(x, st)
|
||||
# only reduceop is allowed to change shape, limited to turning n to 1
|
||||
if op.op in ReduceOps:
|
||||
axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
|
||||
@@ -158,13 +158,16 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
|
||||
st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
|
||||
else:
|
||||
# movementops are pushed to the edges with LOAD
|
||||
if op.op in BufferOps: st = op.arg.st
|
||||
else: st = sts[op.src[0]]
|
||||
for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
|
||||
# elementwise inherits shape
|
||||
st = op.arg.st if op.op in BufferOps else sts[op.src[0]]
|
||||
for x in op.src:
|
||||
if sts[x].shape != st.shape:
|
||||
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {sts[x].shape}")
|
||||
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {sts[x].shape} {prod(sts[x].shape)} != {prod(st.shape)}")
|
||||
sts[op] = st
|
||||
for i, out in enumerate(ast.src):
|
||||
assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
|
||||
assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
|
||||
assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
|
||||
dfs(out, out.arg.st)
|
||||
assert_valid(out, out.arg.st)
|
||||
return sts
|
||||
|
||||
Reference in New Issue
Block a user