Fix load collapse MAX to ADD (#13406)

* add Ops.ADD to pattern

* add test
This commit is contained in:
Sieds Lykles
2025-11-21 12:26:14 +01:00
committed by GitHub
parent 87c248eafa
commit 114bb94c55
2 changed files with 4 additions and 1 deletion

View File

@@ -2699,6 +2699,9 @@ class TestOps(unittest.TestCase):
a = Tensor(3.14)
np.testing.assert_allclose(Tensor.stack(a, a).numpy(), Tensor([3.14, 3.14]).numpy())
def test_stack_max(self):
  # stack two one-element tensors, then take the max along the stacking axis (axis 0)
  # torch's max along an axis is indexed with [0] so only the first item of its result is compared
  torch_fxn = lambda x, y: torch.stack((x, y)).max(axis=0)[0]
  tinygrad_fxn = lambda x, y: Tensor.stack(x, y).max(axis=0)
  helper_test_op(None, torch_fxn, tinygrad_fxn, vals=[[1.], [2.]])
def test_repeat(self):
x = Tensor.randn(4, 6, 3)
base_repeats = [2, 4, 3]

View File

@@ -142,7 +142,7 @@ pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([
# remove REDUCE on load, comes from indexing a tensor with another tensor
def no_load(u:UOp) -> bool:
  # True when none of the UOps in u.backward_slice_with_self is an Ops.INDEX
  return all(x.op is not Ops.INDEX for x in u.backward_slice_with_self)
pm_load_collapse = PatternMatcher([
  # only ADD reductions are handed to reduce_load_collapse: a REDUCE pattern without `arg` would
  # also match other reduce ops such as MAX (exercised by test_stack_max), so no catch-all entry here
  (UPat(Ops.REDUCE, arg=Ops.ADD, src=(UPat.var("u"), UPat()), name="red"), reduce_load_collapse),
  # we want to make sure we don't do math on a loaded index since that can cause overflow, this undoes the rule in pm_reduce_load_collapse
  # rewrites (x + y) < c into x < c - y, only when y and c contain no INDEX but x does
  ((UPat.var("x", dtypes.index)+UPat.var("y"))<UPat.var("c"), lambda x,y,c: x < c-y if no_load(y) and no_load(c) and not no_load(x) else None),
])