mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
disallow some uops at different levels [run_process_replay] (#6186)
* assert intermediate ones * assert low-level uops
This commit is contained in:
@@ -784,6 +784,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
|
||||
# only reduceuop is allowed to change shape, limited to turning n to 1
|
||||
if op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[1][-1] if arg[0] is ReduceOps.WMMA else arg[1]))
|
||||
else:
|
||||
assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
|
||||
# movementops are pushed to the edges with SHAPETRACKER
|
||||
# elementwise inherits shape
|
||||
st = arg if op is UOps.SHAPETRACKER else sts[src[-1]]
|
||||
|
||||
@@ -589,7 +589,7 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, s
|
||||
|
||||
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
|
||||
if not skip_check:
|
||||
bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}])
|
||||
bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE, UOps.REDUCE_AXIS, UOps.SHAPETRACKER}])
|
||||
try:
|
||||
type_verify(_uops)
|
||||
assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}"
|
||||
|
||||
Reference in New Issue
Block a user