clean up apply OptOps.PADTO [run_process_replay] (#6694)

This commit is contained in:
chenyu
2024-09-23 23:13:50 -04:00
committed by GitHub
parent f703180356
commit 4a2fa0b627

View File

@@ -441,10 +441,9 @@ class Kernel:
elif opt.op is OptOps.PADTO:
check(not self.vars, "does not work with symbolic shape")
check(axis < self.first_upcast, "cannot pad upcasted")
# ok to pad SUM if all parent ops have f(0) = 0
if self.first_reduce <= axis:
check((r:=cast(UOp, self.reduceop)).arg[0] is BinaryOps.ADD and \
all(not isinstance(op.arg, Enum) or op.arg not in UNSAFE_PAD_OPS for sop in r.src for op in sop.parents), "cannot pad")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and self.first_reduce <= axis:
check(r.arg[0] is BinaryOps.ADD and all(not (u.op is UOps.ALU and u.arg in UNSAFE_PAD_OPS) for u in r.parents), "cannot pad UNSAFE_PAD_OPS")
padded = False
for i,st in enumerate(self.sts):
if self.sts[i].shape[axis] == 1: continue # reduced