mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
verify_ast prep refactor for intermediate uops type spec (#6135)
* refactor to ops * refactor to two functions * the uop's shape become local_reduce
This commit is contained in:
@@ -760,30 +760,35 @@ class Kernel:
|
||||
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
||||
# the living definition of UOps.SHAPETRACKER
|
||||
# the living definition of UOp st_arg
|
||||
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
|
||||
if uop in sts: return
|
||||
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
|
||||
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
|
||||
if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
|
||||
# restore globals from the two stage reduce
|
||||
if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
|
||||
_assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
|
||||
sts[uop] = sts[local_reduce]
|
||||
return
|
||||
for x in src: _assert_valid_uop(x, st, sts)
|
||||
# only reduceuop is allowed to change shape, limited to turning n to 1
|
||||
if op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[1][-1] if arg[0] is ReduceOps.WMMA else arg[1]))
|
||||
else:
|
||||
# movementops are pushed to the edges with SHAPETRACKER
|
||||
# elementwise inherits shape
|
||||
st = arg if op is UOps.SHAPETRACKER else sts[src[-1]]
|
||||
for x in (src[1:] if op in BUFFER_UOPS else src):
|
||||
if sts[x].shape != st.shape:
|
||||
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
|
||||
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
|
||||
sts[uop] = st
|
||||
|
||||
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
|
||||
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
|
||||
assert len(set(x.st_arg.size for x in ast.src)) == 1, "outputs must be exactly the same size"
|
||||
sts: Dict[UOp, ShapeTracker] = {}
|
||||
def assert_valid(op:UOp, st:ShapeTracker):
|
||||
if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
|
||||
# restore globals from the two stage reduce
|
||||
if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
|
||||
assert_valid(local_reduce:=op.src[2].src[2], op.st_arg)
|
||||
return sts.setdefault(op, sts[local_reduce])
|
||||
for x in op.src: assert_valid(x, st)
|
||||
# only reduceop is allowed to change shape, limited to turning n to 1
|
||||
if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[op.src[0]].reduce(op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
|
||||
else:
|
||||
# movementops are pushed to the edges with SHAPETRACKER
|
||||
# elementwise inherits shape
|
||||
st = op.arg if op.op is UOps.SHAPETRACKER else sts[op.src[-1]]
|
||||
for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
|
||||
if sts[x].shape != st.shape:
|
||||
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
|
||||
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
|
||||
sts[op] = st
|
||||
for out in ast.src: assert_valid(out, out.st_arg)
|
||||
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
|
||||
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
|
||||
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
|
||||
return sts
|
||||
|
||||
Reference in New Issue
Block a user