From e5fb08acbcc493ce96e73332cce63f843fbf1114 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 27 Jul 2024 20:20:56 +0800 Subject: [PATCH] simpler expand UOps acc [run_process_replay] (#5754) --- tinygrad/codegen/uopgraph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 6d6a66739e..2edd096720 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -373,7 +373,7 @@ def do_reduce_with_expand(root): expands = [x for x in root.src[1:] if x.op is UOps.EXPAND] expands_reduce = [x for x in expands if root.src[0].op is UOps.EXPAND and all(y in root.src[0].arg for y in x.arg)] expands_non_reduce = [x for x in expands if x not in expands_reduce] - const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar())) + const = UOp.const(root.dtype.scalar(), 0 if root.arg is ReduceOps.SUM else dtypes.min(root.dtype)) ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,)) acc_number += 1 alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)]