mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
setitem with arange fusion 1 (#5898)
This commit is contained in:
@@ -1385,17 +1385,29 @@ class TestIndexing(unittest.TestCase):
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum())
|
||||
|
||||
def test_arange_childless(self):
|
||||
def test_arange_childless_base(self):
|
||||
a = Tensor.arange(4)
|
||||
self.check_schedule(a, 1)
|
||||
np.testing.assert_equal(a.numpy(), np.arange(4))
|
||||
|
||||
def test_arange_group_childless(self):
|
||||
def test_arange_childless_view(self):
|
||||
a = Tensor.arange(4).reshape(2, 2)
|
||||
a[0] = 4
|
||||
np.testing.assert_equal(a.numpy(), [[4, 4], [2, 3]])
|
||||
|
||||
def test_arange_group_childless_base(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randint(4)
|
||||
a = Tensor.arange(4)+x
|
||||
self.check_schedule(a, 1)
|
||||
np.testing.assert_equal(a.numpy(), np.arange(4)+x.numpy())
|
||||
|
||||
def test_arange_group_childless_view(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.ones(4).contiguous().realize()
|
||||
a = Tensor.arange(4)+x
|
||||
a[0] = 6
|
||||
np.testing.assert_equal(a.numpy(), [6., 2., 3., 4.])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -301,7 +301,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
|
||||
for r in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
|
||||
if DEBUG_ARANGE:=(getenv("DEBUG_ARANGE")): print(f"checking {r} {group=}")
|
||||
if any(tr.forced_realize or tr in outs for tr in group): continue
|
||||
if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
|
||||
if DEBUG_ARANGE: print(colored(f"folding {r}", "green"))
|
||||
for tr in group: del realizes[tr]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user