fix FUSE pushing through SHRINK (#10271)

This commit is contained in:
qazal
2025-05-13 11:38:53 +03:00
committed by GitHub
parent 1c4ab6b991
commit a2d6b0afe0
2 changed files with 9 additions and 1 deletions

View File

@@ -1726,6 +1726,14 @@ class TestIndexing(unittest.TestCase):
self.check_schedule(out, 1)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_shrink(self):
Tensor.manual_seed(0)
with Context(TRACK_MATCH_STATS=0):
x = Tensor.randn(11).realize()
a = Tensor.arange(22)
out = (x + a[:11]).sum()
self.check_schedule(out, 1)
@unittest.skip("TOOD: FUSE_ARANGE overrules Tensor.arange().contiguous()")
def test_arange_index_contiguous(self):
Tensor.manual_seed(0)

View File

@@ -308,7 +308,7 @@ view_left = merge_views+PatternMatcher([
def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
if (st:=unwrap(view.st)).contiguous: return None
if (st:=unwrap(view.st)).contiguous and st.size == r.size: return None
input_st = ShapeTracker.from_shape(src.shape)
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])