no bool in range [pr] (#7988)

* no bool in range [pr]

* fix llvm

* add arg to range spec

* fix broken test

* forgot this one

* hotfix: test_tiny jit is a real test
George Hotz, 2024-12-02 19:05:16 +08:00 (committed by GitHub)
parent 8909dbd82c · commit 0c7477b108
7 changed files with 38 additions and 47 deletions
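
What the change amounts to, as visible in the test diff below: a RANGE uop no longer carries a bool in its arg marking it as a reduce range (the old assertions probed u.arg[1]); a range now counts as a reduce range when it appears in the src of a DEFINE_ACC. A minimal sketch of the two detection styles, using a toy stand-in rather than tinygrad's real UOp class (all names and fields here are illustrative only):

# toy model of the old and new ways to find reduce ranges; not tinygrad's real UOp
from dataclasses import dataclass
from enum import Enum, auto

class Ops(Enum): RANGE = auto(); ENDRANGE = auto(); DEFINE_ACC = auto()

@dataclass(frozen=True)
class UOp:
  op: Ops
  arg: tuple = ()
  src: tuple = ()

def reduce_ranges_old(uops):
  # old convention: arg[1] was a bool flagging a reduce range
  return [u for u in uops if u.op is Ops.RANGE and u.arg[1]]

def reduce_ranges_new(uops):
  # new convention: a range is a reduce range iff some DEFINE_ACC lists it in src
  in_acc = {x for u in uops if u.op is Ops.DEFINE_ACC for x in u.src if x.op is Ops.RANGE}
  return [u for u in uops if u.op is Ops.RANGE and u in in_acc]

r0 = UOp(Ops.RANGE, arg=(0, True))    # reduce range (old-style bool kept for contrast)
r1 = UOp(Ops.RANGE, arg=(1, False))   # plain loop range
acc = UOp(Ops.DEFINE_ACC, src=(r0,))  # accumulator iterated over r0
uops = [r0, r1, acc, UOp(Ops.ENDRANGE, src=(r1,)), UOp(Ops.ENDRANGE, src=(r0,))]
assert reduce_ranges_old(uops) == reduce_ranges_new(uops) == [r0]

With the structural check in place, the bool can be dropped from the RANGE arg entirely, which is what the updated tests below rely on.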

test/test_linearizer.py

@@ -100,6 +100,14 @@ class TestLinearizer(unittest.TestCase):
     assert len(mutable_bufs) == len(stores) == 2
     assert [u.arg for u in mutable_bufs] == [0, 1]

+  def _test_no_nested_ranges(self, lins, skip=None):
+    for l in lins:
+      range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_ACC])
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
+      for i,u in enumerate(ranges):
+        if skip and i in skip: continue
+        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
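
The invariant this helper enforces: if you list the ops of the reduce RANGE/ENDRANGE uops in linearized order, they must alternate, since two adjacent Ops.RANGE entries would mean one reduce loop opened inside another. The skip indices exempt the positions where ranges[i-1] wraps around to the end of the list (index 0 for a single reduce pair, 0 and 1 when two independent pairs are expected). A standalone sketch of the same alternation check on bare strings (hypothetical names, not part of the test file):

def assert_not_nested(ranges, skip=None):
  # same shape as _test_no_nested_ranges, minus the uop filtering
  for i, u in enumerate(ranges):
    if skip and i in skip: continue
    assert ranges[i-1] != u, f"nested at index {i}: {ranges[i-1]!r} repeated"

assert_not_nested(["RANGE", "ENDRANGE", "RANGE", "ENDRANGE"], skip=[0])  # sequential reduces: passes
try:
  assert_not_nested(["RANGE", "RANGE", "ENDRANGE", "ENDRANGE"], skip=[0])  # nested reduces
except AssertionError as e:
  print(e)  # nested at index 1: 'RANGE' repeated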
@@ -130,11 +138,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1).reshape(1,1)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -194,11 +198,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   def test_triple_multireduce(self):
     Tensor.manual_seed(0)
@@ -218,11 +218,7 @@ class TestLinearizer(unittest.TestCase):
     sink = UOp(Ops.SINK, src=(store,))
     wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
     lins = helper_linearizer_ast(sink, [x0,x1,x2], wanna_output=[wanna_output])
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -270,11 +266,7 @@ class TestLinearizer(unittest.TestCase):
       Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
     ]
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i < 2: continue
-        assert ranges[i-2] != u or ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-2], ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0, 1])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
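
One behavioral note on this hunk: the old inline check tolerated a run of two equal ops anywhere (ranges[i-2] != u or ranges[i-1] != u passes when only the immediate neighbor matches), while _test_no_nested_ranges with skip=[0, 1] exempts just the first two positions and demands strict alternation afterwards. Reusing the toy checker sketched after the first hunk:

try:
  assert_not_nested(["RANGE", "RANGE", "ENDRANGE", "ENDRANGE"], skip=[0, 1])
except AssertionError:
  pass  # fails at index 3 (ENDRANGE repeated); the old two-back check let this pattern through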
@@ -301,11 +293,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -339,11 +327,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     wanna_output = (x.numpy()-(x.numpy().sum(-1, keepdims=True)+np.exp2(x_p.numpy()).sum(-1, keepdims=True))).sum(-1).reshape(4, 1,1)
     lins = helper_linearizer_ast(sink, [x,x_p], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   def test_multiout_multireduce(self):