Mirror of https://github.com/tinygrad/tinygrad.git
no bool in range [pr] (#7988)
* no bool in range [pr]
* fix llvm
* add arg to range spec
* fix broken test
* forgot this one
* hotfix: test_tiny jit is a real test
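The test refactor below folds the per-test range-nesting checks, which relied on a boolean carried in each RANGE uop's arg (u.arg[1]), into one _test_no_nested_ranges helper that finds the reduce ranges through the sources of DEFINE_ACC uops instead. A minimal standalone sketch of that check, assuming Ops is importable from tinygrad.ops and flatten from tinygrad.helpers at this revision; the function name and the uops parameter are illustrative, not tinygrad API:

# Sketch of the nested-range check introduced by this commit (not the library's own API).
from tinygrad.ops import Ops
from tinygrad.helpers import flatten

def assert_no_nested_reduce_ranges(uops, skip=None):
  # reduce ranges are the RANGE uops feeding a DEFINE_ACC; no bool flag in the RANGE arg is needed
  range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in uops if u.op is Ops.DEFINE_ACC])
  # ops of those RANGE/ENDRANGE uops in linearized order; they should alternate open/close
  ranges = [u.op for u in uops if (u.op is Ops.RANGE and u in range_in_acc) or
            (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
  for i, u in enumerate(ranges):
    if skip and i in skip: continue
    # two identical consecutive entries mean one reduce loop was opened (or closed) inside another
    assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1]}, {u}"

The tests in the diff call the class-method form of this check as self._test_no_nested_ranges(lins, [0]), passing skip indices for the leading range pair.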
@@ -100,6 +100,14 @@ class TestLinearizer(unittest.TestCase):
     assert len(mutable_bufs) == len(stores) == 2
     assert [u.arg for u in mutable_bufs] == [0, 1]

+  def _test_no_nested_ranges(self, lins, skip=None):
+    for l in lins:
+      range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_ACC])
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
+      for i,u in enumerate(ranges):
+        if skip and i in skip: continue
+        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -130,11 +138,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1).reshape(1,1)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -194,11 +198,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   def test_triple_multireduce(self):
     Tensor.manual_seed(0)
@@ -218,11 +218,7 @@ class TestLinearizer(unittest.TestCase):
     sink = UOp(Ops.SINK, src=(store,))
     wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
     lins = helper_linearizer_ast(sink, [x0,x1,x2], wanna_output=[wanna_output])
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -270,11 +266,7 @@ class TestLinearizer(unittest.TestCase):
       Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
     ]
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i < 2: continue
-        assert ranges[i-2] != u or ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-2], ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0, 1])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -301,11 +293,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -339,11 +327,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     wanna_output = (x.numpy()-(x.numpy().sum(-1, keepdims=True)+np.exp2(x_p.numpy()).sum(-1, keepdims=True))).sum(-1).reshape(4, 1,1)
     lins = helper_linearizer_ast(sink, [x,x_p], wanna_output=[wanna_output], opts=opts)
-    for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
-      for i,u in enumerate(ranges):
-        if i == 0: continue
-        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
+    self._test_no_nested_ranges(lins, [0])

   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   def test_multiout_multireduce(self):