verify_ast prep refactor for intermediate uops type spec (#6135)

* refactor to ops

* refactor to two functions

* the uop's shape becomes that of local_reduce
This commit is contained in:
qazal
2024-08-17 20:34:18 +08:00
committed by GitHub
parent d9ce664350
commit 41ac8bdd63

View File

@@ -760,30 +760,35 @@ class Kernel:
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
# the living definition of UOps.SHAPETRACKER
# the living definition of UOp st_arg
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
  """Recursively record a ShapeTracker in `sts` for `uop` and the uops reachable through its sources.

  `sts` is filled in place and doubles as the visited set. `st` is only forwarded to the
  recursive calls on the sources; the tracker stored for `uop` is always derived from
  its op, its arg and the trackers already stored for its sources.

  Raises AssertionError when a checked source's shape differs from the shape assigned to `uop`.
  """
  # visited uops are skipped; NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
  if uop in sts or uop.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
  op, children = uop.op, uop.src
  # restore globals from the two stage reduce
  if op is UOps.LOAD and children[0].op is UOps.DEFINE_LOCAL:
    local_reduce = children[2].src[2]
    _assert_valid_uop(local_reduce, uop.st_arg, sts)
    # the local load takes over the tracker stored for local_reduce
    sts[uop] = sts[local_reduce]
    return
  # sources first, so every sts lookup below finds an entry
  for child in children: _assert_valid_uop(child, st, sts)
  # only reduceuop is allowed to change shape, limited to turning n to 1
  if op is UOps.REDUCE_AXIS:
    # for WMMA the axes to reduce are the last element of arg[1]
    reduce_axes = uop.arg[1][-1] if uop.arg[0] is ReduceOps.WMMA else uop.arg[1]
    sts[uop] = ShapeTracker.from_shape(sts[children[0]].reduce(reduce_axes))
    return
  # movementops are pushed to the edges with SHAPETRACKER
  # elementwise inherits shape
  st = uop.arg if op is UOps.SHAPETRACKER else sts[children[-1]]
  # for BUFFER_UOPS the first source is left out of the shape comparison
  for x in (children[1:] if op in BUFFER_UOPS else children):
    if sts[x].shape == st.shape: continue
    # same element count but different shape is reported as a reshape, anything else as an expand
    if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
    raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
  sts[uop] = st
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
  """Check the shapes of a SINK ast and return the ShapeTracker recorded for each of its uops.

  `ast` must be a UOps.SINK whose sources are all UOps.STORE, and the `st_arg.size` of
  those sources must all be equal. Every output is then walked with `_assert_valid_uop`,
  which fills the returned dict in place.

  Raises AssertionError if the ast is not a SINK of STOREs, if the output sizes differ,
  if `_assert_valid_uop` rejects a uop, or if some dimension takes sizes other than
  a single value or the pair (1, n) across the recorded shapes.
  """
  assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
  assert len(set(x.st_arg.size for x in ast.src)) == 1, "outputs must be exactly the same size"
  sts: Dict[UOp, ShapeTracker] = {}
  # the per-uop shape rules live only in _assert_valid_uop; sts is shared across outputs so
  # uops reachable from several outputs are visited once
  for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
  # group sizes per dimension across every recorded shape (zip stops at the shortest shape),
  # then de-duplicate with dedup and sort so a 1 comes first
  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
  return sts