diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 6d6a66739e..2edd096720 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -373,7 +373,7 @@ def do_reduce_with_expand(root): expands = [x for x in root.src[1:] if x.op is UOps.EXPAND] expands_reduce = [x for x in expands if root.src[0].op is UOps.EXPAND and all(y in root.src[0].arg for y in x.arg)] expands_non_reduce = [x for x in expands if x not in expands_reduce] - const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar())) + const = UOp.const(root.dtype.scalar(), 0 if root.arg is ReduceOps.SUM else dtypes.min(root.dtype)) ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,)) acc_number += 1 alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)]