simpler expand UOps acc [run_process_replay] (#5754)

This commit is contained in:
qazal
2024-07-27 20:20:56 +08:00
committed by GitHub
parent de66d93859
commit e5fb08acbc

View File

@@ -373,7 +373,7 @@ def do_reduce_with_expand(root):
expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
expands_reduce = [x for x in expands if root.src[0].op is UOps.EXPAND and all(y in root.src[0].arg for y in x.arg)]
expands_non_reduce = [x for x in expands if x not in expands_reduce]
const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
const = UOp.const(root.dtype.scalar(), 0 if root.arg is ReduceOps.SUM else dtypes.min(root.dtype))
ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,))
acc_number += 1
alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)]