mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix FUSE pushing through SHRINK (#10271)
This commit is contained in:
@@ -1726,6 +1726,14 @@ class TestIndexing(unittest.TestCase):
|
||||
self.check_schedule(out, 1)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_index_shrink(self):
|
||||
Tensor.manual_seed(0)
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
x = Tensor.randn(11).realize()
|
||||
a = Tensor.arange(22)
|
||||
out = (x + a[:11]).sum()
|
||||
self.check_schedule(out, 1)
|
||||
|
||||
@unittest.skip("TOOD: FUSE_ARANGE overrules Tensor.arange().contiguous()")
|
||||
def test_arange_index_contiguous(self):
|
||||
Tensor.manual_seed(0)
|
||||
|
||||
@@ -308,7 +308,7 @@ view_left = merge_views+PatternMatcher([
|
||||
def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
|
||||
|
||||
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
||||
if (st:=unwrap(view.st)).contiguous: return None
|
||||
if (st:=unwrap(view.st)).contiguous and st.size == r.size: return None
|
||||
input_st = ShapeTracker.from_shape(src.shape)
|
||||
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
|
||||
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
|
||||
|
||||
Reference in New Issue
Block a user