fix bmnist torch with RANGEIFY=1 (#12442)

* fix bmnist torch with RANGEIFY=1

* alt

* test and comment

* this was always wrong

* simple failing test for rangeify

* simple upat to match the old behavior
qazal committed 2025-10-05 12:34:27 +03:00 (via GitHub)
parent b5f31d7505, commit 4b60121498
2 changed files with 7 additions and 2 deletions
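
For context, the failure can be reproduced along the lines of the new test below. This is a sketch, not part of the commit: it assumes the script is launched with the rangeify scheduler enabled (e.g. RANGEIFY=1 python repro.py) and simply mirrors the tensor/assign construction from the test.

# sketch: assign into already-realized buffers, then check the values survive
from tinygrad import Tensor

# two tensors that wrap buffers which have already been realized
a, b = [Tensor(Tensor(0).contiguous().realize().uop.as_buf()) for _ in range(2)]
# assign through contiguous(); both assigns are realized together
Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
# with RANGEIFY=1 this did not hold before the rewrite-rule change below
assert (a + b).item() == 3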


@@ -130,6 +130,11 @@ class TestAssign(unittest.TestCase):
   @unittest.expectedFailure
   def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)

+  def test_assign_changes_buffer_alt(self):
+    a, b = [Tensor(Tensor(0).contiguous().realize().uop.as_buf()) for _ in range(2)]
+    Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
+    self.assertEqual((a + b).item(), 3)
+
   def test_assign_diamond_cycle(self):
     # NOTE: should *not* raise AssertionError from numpy
     with self.assertRaisesRegex(RuntimeError, "cycle"):
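
This is the "simple failing test for rangeify" from the commit message: both tensors wrap buffers that were realized up front, so the assigns must end up in those original buffers for (a + b) to see the new values. To run just this test, something like the following should work (the file path test/test_assign.py is an assumption based on the TestAssign class):

RANGEIFY=1 python -m pytest test/test_assign.py -k test_assign_changes_buffer_alt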


@@ -81,8 +81,8 @@ earliest_rewrites = PatternMatcher([
  # copy only to different device
  (UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP, tag=copy.tag) if x.device == copy.device else None),
  # contiguous/buffer/copy/assign is already contiguous
  #(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
  # contiguous buffer is buffer, this is for *correctness* of assign, not just speed
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.BUFFER),)), lambda root: root.src[0].forced_reshape(root.shape).rtag(root.tag)),
])
# *****************
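
For intuition, here is a small self-contained toy sketch of what the replacement rule does. This is not tinygrad's UPat/PatternMatcher API; every name below is made up for illustration. The idea: a CONTIGUOUS node whose only source is a BUFFER folds away to that buffer (reshaped to the CONTIGUOUS node's shape, as forced_reshape does above), so a downstream ASSIGN ends up pointing at the real buffer instead of at a contiguous copy of it.

# toy illustration only -- not tinygrad code
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str            # "BUFFER", "CONTIGUOUS", "RESHAPE", "ASSIGN", "CONST", ...
  src: tuple = ()    # source nodes
  shape: tuple = ()  # shape of this node's output

def rewrite(n: Node) -> Node:
  # rewrite sources bottom-up first
  n = Node(n.op, tuple(rewrite(s) for s in n.src), n.shape)
  # the rule: CONTIGUOUS(BUFFER) -> the buffer itself, reshaped to the outer shape
  # (mirrors root.src[0].forced_reshape(root.shape) above; tag handling is omitted)
  if n.op == "CONTIGUOUS" and len(n.src) == 1 and n.src[0].op == "BUFFER":
    buf = n.src[0]
    return buf if buf.shape == n.shape else Node("RESHAPE", (buf,), n.shape)
  return n

buf = Node("BUFFER", (), (4,))
assign = Node("ASSIGN", (Node("CONTIGUOUS", (buf,), (4,)), Node("CONST", (), (4,))), (4,))
out = rewrite(assign)
# the ASSIGN's target is now the BUFFER itself, so the write lands in the
# original storage rather than in a separate contiguous copy
assert out.src[0].op == "BUFFER"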