Patch summary (text before the first "diff --git" header is not part of any hunk):
  * test/test_ops.py: add test_const_reduce, comparing sum/prod/max of a constant
    full_like(2) tensor against torch, forward only.
  * tinygrad/codegen/simplify.py: reduce_unparented now also accepts Ops.MUL; for
    MUL, each unparented reduce source r raises the result to the power of
    r.src[0] (the existing ADD path multiplies by the same value).
  * tinygrad/uop/mathtraits.py: add MathTrait.__pow__, which forwards to pow().

diff --git a/test/test_ops.py b/test/test_ops.py
index 46ad7072ad..d842787ab1 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1408,6 +1408,11 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[False, True]])
     helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[True, False]])
 
+  def test_const_reduce(self):
+    helper_test_op([(3,3)], lambda x: torch.full_like(x, 2).sum(), lambda x: (x.full_like(2)).sum(), forward_only=True)
+    helper_test_op([(3,3)], lambda x: torch.full_like(x, 2).prod(), lambda x: (x.full_like(2)).prod(), forward_only=True)
+    helper_test_op([(3,3)], lambda x: torch.full_like(x, 2).max(), lambda x: (x.full_like(2)).max(), forward_only=True)
+
   @unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
   def test_any(self):
     helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True)
diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py
index bdd56d848d..710bab28b8 100644
--- a/tinygrad/codegen/simplify.py
+++ b/tinygrad/codegen/simplify.py
@@ -102,12 +102,14 @@ def reduce_collapse(red:UOp):
   return sink.substitute({v:k for k,v in replaces.items()})
 
 def reduce_unparented(red:UOp):
-  if red.arg not in {Ops.ADD, Ops.MAX}: return None
+  if red.arg not in {Ops.ADD, Ops.MAX, Ops.MUL}: return None
   reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
   if len(reduce_unparented) == 0: return None
   ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
   if red.arg is Ops.ADD:
     for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+  if red.arg is Ops.MUL:
+    for r in reduce_unparented: ret = ret ** r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
   return ret
 
 pm_reduce_simplify = PatternMatcher([
diff --git a/tinygrad/uop/mathtraits.py b/tinygrad/uop/mathtraits.py
index 05bf89d8be..0de976c90b 100644
--- a/tinygrad/uop/mathtraits.py
+++ b/tinygrad/uop/mathtraits.py
@@ -167,3 +167,4 @@ class MathTrait:
   def log2(self): return self.alu(Ops.LOG2)
   def exp2(self): return self.alu(Ops.EXP2)
   def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
+  def __pow__(self, x): return self.pow(x)