diff --git a/test/test_ops.py b/test/test_ops.py
index 2da830b7be..11653323c1 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2699,6 +2699,9 @@ class TestOps(unittest.TestCase):
     a = Tensor(3.14)
     np.testing.assert_allclose(Tensor.stack(a, a).numpy(), Tensor([3.14, 3.14]).numpy())
 
+  def test_stack_max(self):
+    helper_test_op(None, lambda x, y: torch.stack((x, y)).max(axis=0)[0], lambda x, y: Tensor.stack(x, y).max(axis=0), vals=[[1.], [2.]])
+
   def test_repeat(self):
     x = Tensor.randn(4, 6, 3)
     base_repeats = [2, 4, 3]
diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py
index 3e3514e7cb..d503bfb68e 100644
--- a/tinygrad/codegen/simplify.py
+++ b/tinygrad/codegen/simplify.py
@@ -142,7 +142,7 @@ pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([
 
 # remove REDUCE on load, comes from indexing a tensor with another tensor
 def no_load(u:UOp) -> bool: return not any(x.op is Ops.INDEX for x in u.backward_slice_with_self)
 pm_load_collapse = PatternMatcher([
-  (UPat(Ops.REDUCE, src=(UPat.var("u"), UPat()), name="red"), reduce_load_collapse),
+  (UPat(Ops.REDUCE, arg=Ops.ADD, src=(UPat.var("u"), UPat()), name="red"), reduce_load_collapse),
   # we want to make sure we dont do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
   ((UPat.var("x", dtypes.index)+UPat.var("y"))