diff --git a/test/test_schedule.py b/test/test_schedule.py index 0b8351fc5b..4f8fa2542d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -698,7 +698,6 @@ class TestSchedule(unittest.TestCase): c = (a.sum(2).contiguous() + b).contiguous() check_schedule(c, 2) - @expect_rangeify_fails def test_kernelize(self): a = Tensor.empty(10) b = Tensor.empty(10) @@ -706,20 +705,20 @@ class TestSchedule(unittest.TestCase): d = c+2 check_schedule(d, 2) - @expect_rangeify_fails def test_kernelize_view(self): a = Tensor.empty(4,1) b = a*2 c = b.kernelize()+Tensor.empty(4,4) check_schedule(c, 2) - @expect_rangeify_fails def test_kernelize_diamond(self): a = Tensor([0]).realize() prev_a = (a+1).contiguous() a.assign(Tensor([2])) a.kernelize(prev_a) - assert prev_a.uop in a.uop.src, "contiguous usage must run before assign" + # RANGEIFY doesn't apply the post diamond graph, it's fine since we can always apply the fixup on each kernelize call + if not RANGEIFY: + assert prev_a.uop in a.uop.src, "contiguous usage must run before assign" self.assertEqual((prev_a+a*3).item(), 1+2*3) @expect_rangeify_fails @@ -734,7 +733,6 @@ class TestSchedule(unittest.TestCase): self.assertEqual(b.buffer.numpy(), [12]) # unlike schedule, kernelize can be called multiple times on a Tensor - @expect_rangeify_fails def test_double_kerenlize(self): a = Tensor.empty(10) b = Tensor.empty(10) @@ -743,7 +741,6 @@ class TestSchedule(unittest.TestCase): e = c.kernelize()+d.kernelize() check_schedule(e, 3) - @expect_rangeify_fails def test_kernelize_bw(self): a = Tensor.full((3,), 2.0, requires_grad=True).contiguous() b = Tensor.full((3,), 3.0, requires_grad=True).contiguous() @@ -754,7 +751,6 @@ class TestSchedule(unittest.TestCase): self.assertEqual(z.item(), 18.0) self.assertEqual(z.grad.item(), 1.0) - @expect_rangeify_fails def test_kernelize_bw_view(self): a = Tensor.full((3,1), 2.0, requires_grad=True).contiguous() b = Tensor.full((3,1), 3.0, requires_grad=True).contiguous() diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index a6aadb674a..a93cd2a8e2 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -13,7 +13,7 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, si ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, - Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD} + Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL} double_reshape = PatternMatcher([ # RESHAPE on RESHAPE is the second reshape @@ -343,7 +343,8 @@ pm_rangeify = pm_mops+PatternMatcher([ # handle assign (UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"), - lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],))), + lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],)) \ + if assign.src[1].op is not Ops.KERNEL else None), # move MAP through elementwise ALU / reduce. these are the items with cost (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union(