rangeify: fix kernelize (#12357)

This commit is contained in:
qazal
2025-09-30 10:10:08 +03:00
committed by GitHub
parent 86c5c969ea
commit 4ff7f20b9d
2 changed files with 6 additions and 9 deletions

View File

@@ -698,7 +698,6 @@ class TestSchedule(unittest.TestCase):
c = (a.sum(2).contiguous() + b).contiguous()
check_schedule(c, 2)
@expect_rangeify_fails
def test_kernelize(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@@ -706,20 +705,20 @@ class TestSchedule(unittest.TestCase):
d = c+2
check_schedule(d, 2)
@expect_rangeify_fails
def test_kernelize_view(self):
a = Tensor.empty(4,1)
b = a*2
c = b.kernelize()+Tensor.empty(4,4)
check_schedule(c, 2)
@expect_rangeify_fails
def test_kernelize_diamond(self):
a = Tensor([0]).realize()
prev_a = (a+1).contiguous()
a.assign(Tensor([2]))
a.kernelize(prev_a)
assert prev_a.uop in a.uop.src, "contiguous usage must run before assign"
# RANGEIFY doesn't apply the post diamond graph; it's fine since we can always apply the fixup on each kernelize call
if not RANGEIFY:
assert prev_a.uop in a.uop.src, "contiguous usage must run before assign"
self.assertEqual((prev_a+a*3).item(), 1+2*3)
@expect_rangeify_fails
@@ -734,7 +733,6 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(b.buffer.numpy(), [12])
# unlike schedule, kernelize can be called multiple times on a Tensor
@expect_rangeify_fails
def test_double_kerenlize(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@@ -743,7 +741,6 @@ class TestSchedule(unittest.TestCase):
e = c.kernelize()+d.kernelize()
check_schedule(e, 3)
@expect_rangeify_fails
def test_kernelize_bw(self):
a = Tensor.full((3,), 2.0, requires_grad=True).contiguous()
b = Tensor.full((3,), 3.0, requires_grad=True).contiguous()
@@ -754,7 +751,6 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(z.item(), 18.0)
self.assertEqual(z.grad.item(), 1.0)
@expect_rangeify_fails
def test_kernelize_bw_view(self):
a = Tensor.full((3,1), 2.0, requires_grad=True).contiguous()
b = Tensor.full((3,1), 3.0, requires_grad=True).contiguous()

View File

@@ -13,7 +13,7 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, si
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}
double_reshape = PatternMatcher([
# RESHAPE on RESHAPE is the second reshape
@@ -343,7 +343,8 @@ pm_rangeify = pm_mops+PatternMatcher([
# handle assign
(UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"),
lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],))),
lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],)) \
if assign.src[1].op is not Ops.KERNEL else None),
# move MAP through elementwise ALU / reduce. these are the items with cost
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union(