mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix bmnist torch with RANGEIFY=1 (#12442)
* fix bmnist torch with RANGEIFY=1 * alt * test and comment * this was always wrong * simple failing test for rangeify * simple upat to match the old behavior
This commit is contained in:
@@ -130,6 +130,11 @@ class TestAssign(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)
|
||||
|
||||
def test_assign_changes_buffer_alt(self):
|
||||
a, b = [Tensor(Tensor(0).contiguous().realize().uop.as_buf()) for _ in range(2)]
|
||||
Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
|
||||
self.assertEqual((a + b).item(), 3)
|
||||
|
||||
def test_assign_diamond_cycle(self):
|
||||
# NOTE: should *not* raise AssertionError from numpy
|
||||
with self.assertRaisesRegex(RuntimeError, "cycle"):
|
||||
|
||||
@@ -81,8 +81,8 @@ earliest_rewrites = PatternMatcher([
|
||||
# copy only to different device
|
||||
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP, tag=copy.tag) if x.device == copy.device else None),
|
||||
|
||||
# contiguous/buffer/copy/assign is already contiguous
|
||||
#(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
|
||||
# contiguous buffer is buffer, this is for *correctness* of assign, not just speed
|
||||
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.BUFFER),)), lambda root: root.src[0].forced_reshape(root.shape).rtag(root.tag)),
|
||||
])
|
||||
|
||||
# *****************
|
||||
|
||||
Reference in New Issue
Block a user