disallow some uops at different levels [run_process_replay] (#6186)

* assert intermediate ones

* assert low-level uops
This commit is contained in:
qazal
2024-08-20 02:23:44 +08:00
committed by GitHub
parent 5d742f7fe3
commit ee5fe12630
2 changed files with 2 additions and 1 deletions

View File

@@ -784,6 +784,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
# only reduceuop is allowed to change shape, limited to turning n to 1
if op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[1][-1] if arg[0] is ReduceOps.WMMA else arg[1]))
else:
assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
# movementops are pushed to the edges with SHAPETRACKER
# elementwise inherits shape
st = arg if op is UOps.SHAPETRACKER else sts[src[-1]]

View File

@@ -589,7 +589,7 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, s
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
if not skip_check:
bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}])
bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE, UOps.REDUCE_AXIS, UOps.SHAPETRACKER}])
try:
type_verify(_uops)
assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}"