mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
rangeify: fix kernelize (#12357)
This commit is contained in:
@@ -698,7 +698,6 @@ class TestSchedule(unittest.TestCase):
|
||||
c = (a.sum(2).contiguous() + b).contiguous()
|
||||
check_schedule(c, 2)
|
||||
|
||||
@expect_rangeify_fails
|
||||
def test_kernelize(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
@@ -706,20 +705,20 @@ class TestSchedule(unittest.TestCase):
|
||||
d = c+2
|
||||
check_schedule(d, 2)
|
||||
|
||||
@expect_rangeify_fails
|
||||
def test_kernelize_view(self):
|
||||
a = Tensor.empty(4,1)
|
||||
b = a*2
|
||||
c = b.kernelize()+Tensor.empty(4,4)
|
||||
check_schedule(c, 2)
|
||||
|
||||
@expect_rangeify_fails
|
||||
def test_kernelize_diamond(self):
|
||||
a = Tensor([0]).realize()
|
||||
prev_a = (a+1).contiguous()
|
||||
a.assign(Tensor([2]))
|
||||
a.kernelize(prev_a)
|
||||
assert prev_a.uop in a.uop.src, "contiguous usage must run before assign"
|
||||
# RANGEIFY doesn't apply the post-diamond graph; that's fine, since we can always apply the fixup on each kernelize call
|
||||
if not RANGEIFY:
|
||||
assert prev_a.uop in a.uop.src, "contiguous usage must run before assign"
|
||||
self.assertEqual((prev_a+a*3).item(), 1+2*3)
|
||||
|
||||
@expect_rangeify_fails
|
||||
@@ -734,7 +733,6 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertEqual(b.buffer.numpy(), [12])
|
||||
|
||||
# unlike schedule, kernelize can be called multiple times on a Tensor
|
||||
@expect_rangeify_fails
|
||||
def test_double_kerenlize(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
@@ -743,7 +741,6 @@ class TestSchedule(unittest.TestCase):
|
||||
e = c.kernelize()+d.kernelize()
|
||||
check_schedule(e, 3)
|
||||
|
||||
@expect_rangeify_fails
|
||||
def test_kernelize_bw(self):
|
||||
a = Tensor.full((3,), 2.0, requires_grad=True).contiguous()
|
||||
b = Tensor.full((3,), 3.0, requires_grad=True).contiguous()
|
||||
@@ -754,7 +751,6 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertEqual(z.item(), 18.0)
|
||||
self.assertEqual(z.grad.item(), 1.0)
|
||||
|
||||
@expect_rangeify_fails
|
||||
def test_kernelize_bw_view(self):
|
||||
a = Tensor.full((3,1), 2.0, requires_grad=True).contiguous()
|
||||
b = Tensor.full((3,1), 3.0, requires_grad=True).contiguous()
|
||||
|
||||
@@ -13,7 +13,7 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, si
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}
|
||||
|
||||
double_reshape = PatternMatcher([
|
||||
# RESHAPE on RESHAPE is the second reshape
|
||||
@@ -343,7 +343,8 @@ pm_rangeify = pm_mops+PatternMatcher([
|
||||
|
||||
# handle assign
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"),
|
||||
lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],))),
|
||||
lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],)) \
|
||||
if assign.src[1].op is not Ops.KERNEL else None),
|
||||
|
||||
# move MAP through elementwise ALU / reduce. these are the items with cost
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union(
|
||||
|
||||
Reference in New Issue
Block a user