mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update uop and tests to not use lt/gt/le/ge [pr] (#8023)
just use the comparison dunder methods (`<`, `>`, `<=`, `>=`) instead; the named lt/gt/le/ge methods will eventually be removed from ops
This commit is contained in:
@@ -355,7 +355,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
|
||||
alu = ld.lt(1).cast(dtypes.bool)
|
||||
alu = (ld<1).cast(dtypes.bool)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
|
||||
@@ -662,7 +662,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid&(lidx.ne(2))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
@@ -681,7 +681,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
def test_expand_ifs_one_gate(self):
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4))<1
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
gate = valid&(lidx.ne(2))
|
||||
st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
|
||||
@@ -700,7 +700,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_expand_ifs_dumb(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid&(lidx.ne(2))
|
||||
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
|
||||
@@ -252,7 +252,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
store = UOp(Ops.STORE, dtypes.void, (idx, val, gate))
|
||||
uops = to_uops_list([store])
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
@@ -271,7 +271,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
@@ -291,7 +291,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val, gate)]
|
||||
uops = to_uops_list(stores)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
|
||||
@@ -118,7 +118,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
|
||||
UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
|
||||
UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
|
||||
rhs = UOp.const(dtypes.int.vec(4), 2)
|
||||
unopt = lhs.lt(rhs)
|
||||
unopt = lhs<rhs
|
||||
opt = apply_rewrite(unopt)
|
||||
print(unopt)
|
||||
print(opt)
|
||||
|
||||
@@ -53,7 +53,7 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
def test_cumsum(self):
|
||||
gidx0 = Special("gidx0", 5)
|
||||
lidx0 = Special("lidx0", 4)
|
||||
gate = (gidx0*4+lidx0).lt(19).ne(True)
|
||||
gate = (gidx0*4+lidx0<19).ne(True)
|
||||
idx = gidx0*4+lidx0-19
|
||||
load = get_gated_load_uop(gate, idx)
|
||||
self.check(load,
|
||||
@@ -65,7 +65,7 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
ridx1 = Range(1, 4)
|
||||
ridx2 = Range(2, 4)
|
||||
ridx3 = Range(3, 4)
|
||||
valid = (ridx0*3+ridx1).lt(8) & (((ridx0*3+ridx1)//8+ridx2*3+ridx3)%4).lt(2)
|
||||
valid = ((ridx0*3+ridx1)<8) & ((((ridx0*3+ridx1)//8+ridx2*3+ridx3)%4)<2)
|
||||
idx = ridx0+ridx1+ridx2+ridx3
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
@@ -76,13 +76,13 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
gidx0 = Special("gidx0", 56)
|
||||
ridx0 = Range(0, 3)
|
||||
alu0 = gidx0+ridx0
|
||||
valid = alu0.lt(57) & alu0.ge(1)
|
||||
valid = (alu0 < 57) & (alu0 >= 1)
|
||||
self.assertIsNone(simplify_valid(valid))
|
||||
|
||||
def test_valid_order_matters1(self):
|
||||
ridx0 = Range(0, 2)
|
||||
v0 = ridx0.lt(1)
|
||||
v1 = ((ridx0*5+1)%6).lt(5)
|
||||
v0 = ridx0<1
|
||||
v1 = ((ridx0*5+1)%6)<5
|
||||
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
|
||||
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
|
||||
|
||||
@@ -91,10 +91,10 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
gidx1 = Special("gidx1", 13)
|
||||
ridx0 = Range(0, 4)
|
||||
alu0 = (gidx1+(ridx0*13))
|
||||
v0 = ((gidx0+11)%14).lt(11)
|
||||
v1 = ((alu0+((gidx0+39)//42))%14).lt(11)
|
||||
v2 = gidx0.lt(3)
|
||||
v3 = alu0.lt(42)
|
||||
v0 = (gidx0+11)%14<11
|
||||
v1 = (alu0+((gidx0+39)//42))%14<11
|
||||
v2 = gidx0<3
|
||||
v3 = alu0<42
|
||||
|
||||
for v in itertools.permutations([v0,v1,v2,v3]):
|
||||
self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")
|
||||
@@ -116,17 +116,17 @@ class TestImageSimplification(unittest.TestCase):
|
||||
gidx0 = Special("gidx0", 32)
|
||||
gidx1 = Special("gidx1", 32)
|
||||
shape = (10, 10, 4)
|
||||
load = get_load_image_uop(shape, (gidx1).lt(1).ne(True), (gidx0, gidx1-1))
|
||||
load = get_load_image_uop(shape, (gidx1<1).ne(True), (gidx0, gidx1-1))
|
||||
self.check(load, None, "gidx0", "(gidx1+-1)")
|
||||
load = get_load_image_uop(shape, (gidx1).lt(1).ne(True), (gidx0, gidx1-2))
|
||||
load = get_load_image_uop(shape, (gidx1<1).ne(True), (gidx0, gidx1-2))
|
||||
self.check(load, None, "gidx0", "(gidx1+-2)")
|
||||
|
||||
# should match any one of the AND clause and drop the matched statement from valid
|
||||
valid = (gidx0).lt(1).ne(True) & (gidx1).lt(1).ne(True)
|
||||
valid = (gidx0<1).ne(True) & (gidx1<1).ne(True)
|
||||
load = get_load_image_uop(shape, valid, (gidx0+1, gidx1-1))
|
||||
self.check(load, "((gidx0<1)!=True)", "(gidx0+1)", "(gidx1+-1)")
|
||||
|
||||
valid = (gidx1).lt(1).ne(True) & (gidx1).lt(1).ne(True)
|
||||
valid = (gidx0<1).ne(True) & (gidx1<1).ne(True)
|
||||
load = get_load_image_uop(shape, valid, (gidx0, gidx1-1))
|
||||
self.check(load, None, "gidx0", "(gidx1+-1)")
|
||||
|
||||
@@ -134,15 +134,15 @@ class TestImageSimplification(unittest.TestCase):
|
||||
# (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
|
||||
gidx0 = Special("gidx0", 32)
|
||||
gidx1 = Special("gidx1", 32)
|
||||
load = get_load_image_uop((10, 10, 4), (gidx1).lt(10), (gidx0, gidx1))
|
||||
load = get_load_image_uop((10, 10, 4), gidx1<10, (gidx0, gidx1))
|
||||
self.check(load, None, "gidx0", "gidx1")
|
||||
|
||||
# same thing, valid has a div
|
||||
load = get_load_image_uop((10, 10, 4), (gidx1//2).lt(5), (gidx0, gidx1))
|
||||
load = get_load_image_uop((10, 10, 4), gidx1//2<5, (gidx0, gidx1))
|
||||
self.check(load, None, "gidx0", "gidx1")
|
||||
|
||||
# 10x20 image, not out of bound
|
||||
load = get_load_image_uop((20, 10, 4), (gidx1).lt(10), (gidx0, gidx1))
|
||||
load = get_load_image_uop((20, 10, 4), gidx1<10, (gidx0, gidx1))
|
||||
self.check(load, "(gidx1<10)", "gidx0", "gidx1")
|
||||
|
||||
def test_generic_idx_lt_bound(self):
|
||||
@@ -150,10 +150,10 @@ class TestImageSimplification(unittest.TestCase):
|
||||
gidx0 = Special("gidx0", 32)
|
||||
gidx1 = Special("gidx1", 32)
|
||||
shape = (10, 10, 4)
|
||||
load = get_load_image_uop(shape, (gidx1).lt(8), (gidx0, gidx1+2))
|
||||
load = get_load_image_uop(shape, (gidx1<8), (gidx0, gidx1+2))
|
||||
self.check(load, None, "gidx0", "(gidx1+2)")
|
||||
|
||||
load = get_load_image_uop(shape, (gidx1).lt(5), (gidx0, gidx1+5))
|
||||
load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
|
||||
self.check(load, None, "gidx0", "(gidx1+5)")
|
||||
|
||||
def test_valid_empty_set(self):
|
||||
@@ -162,11 +162,11 @@ class TestImageSimplification(unittest.TestCase):
|
||||
shape = (32, 32, 4)
|
||||
idx = (gidx0%2, gidx1+2)
|
||||
# not empty
|
||||
load = get_load_image_uop(shape, (gidx0).lt(8), idx)
|
||||
load = get_load_image_uop(shape, gidx0<8, idx)
|
||||
self.check(load, "(gidx0<8)", "(gidx0%2)", "(gidx1+2)")
|
||||
|
||||
# empty -> invalid
|
||||
load = get_load_image_uop(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx)
|
||||
load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
self.assertEqual(load.op, Ops.VECTORIZE)
|
||||
self.assertEqual(load.dtype.count, 4)
|
||||
@@ -186,7 +186,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
alu1 = ((idx2*2)+ridx1)
|
||||
alu4 = ((idx1*48)+(ridx2*6)+ridx0)
|
||||
|
||||
valid = (((idx2*2)+(ridx1)).lt(1).ne(True))&(((idx1*8)+(ridx2)).lt(1).ne(True))
|
||||
valid = ((((idx2*2)+(ridx1))<1).ne(True))&((((idx1*8)+(ridx2))<1).ne(True))
|
||||
shape = (128, 1536, 4)
|
||||
idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
|
||||
|
||||
@@ -207,7 +207,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
alu1 = ((idx2*2)+ridx1)
|
||||
alu3 = ((idx1*24)+(ridx2*3)+ridx0)
|
||||
|
||||
valid = (((idx2*2)+ridx1).lt(1).ne(True))&(((idx1*8)+ridx2).lt(1).ne(True))
|
||||
valid = ((((idx2*2)+ridx1)<1).ne(True))&((((idx1*8)+ridx2)<1).ne(True))
|
||||
shape = (128, 768, 4)
|
||||
idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
@@ -226,7 +226,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
alu4 = ((idx1*8)+ridx1)
|
||||
alu6 = ((idx1*512)+(ridx1*64)+idx0)
|
||||
|
||||
valid = alu2.lt(11)&(alu4.lt(3).ne(True))
|
||||
valid = (alu2<11)&(alu4<3).ne(True)
|
||||
shape = (8, 1024, 4)
|
||||
idx = (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)//8)+1)//2)+(-4)))
|
||||
|
||||
@@ -240,7 +240,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
gidx = Special("gidx", 512)
|
||||
valid = gidx.lt(488) & (gidx).lt(480).ne(True)
|
||||
valid = (gidx<488) & (gidx<480).ne(True)
|
||||
idx = ((gidx*3+18)%26, (gidx*3+18)//26-56)
|
||||
load = get_load_image_uop((1, 26, 4), valid, idx)
|
||||
self.check(load, None, "((gidx*3)+-1438)", "0")
|
||||
@@ -248,7 +248,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
def test_simplify2(self):
|
||||
# from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
|
||||
lidx = Special("lidx", 4)
|
||||
valid = lidx.lt(3) & lidx.lt(1).ne(True)
|
||||
valid = (lidx<3) & (lidx<1).ne(True)
|
||||
idx = ((lidx+1)%2, (lidx+1)//2-1)
|
||||
load = get_load_image_uop((1, 2, 4), valid, idx)
|
||||
self.check(load, None, "(lidx+-1)", "0")
|
||||
@@ -256,7 +256,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
def test_simplify3(self):
|
||||
# from openpilot
|
||||
idx0 = Special("idx0", 265)
|
||||
valid = idx0.lt(201).ne(True)
|
||||
valid = (idx0<201).ne(True)
|
||||
idx = ((idx0+55)%64, (idx0+55)//64-4)
|
||||
load = get_load_image_uop((1, 64, 4), valid, idx)
|
||||
self.check(load, None, "(idx0+-201)", "0")
|
||||
@@ -269,7 +269,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
alu4 = ((idx0*4+3)%32)
|
||||
alu5 = (idx0*4%32)
|
||||
alu8 = (idx0//8%32//4)
|
||||
alu9 = idx0.lt(256)
|
||||
alu9 = idx0<256
|
||||
|
||||
# TODO: can this be simplified further?
|
||||
load = get_load_image_uop(shape, alu9, (((alu8+(alu2*8))%64),(alu2//8)))
|
||||
@@ -294,7 +294,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
alu2 = idx1//3
|
||||
alu3 = ((alu1+1)%768)
|
||||
idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
|
||||
valid = alu3.lt(640)
|
||||
valid = alu3<640
|
||||
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
self.check(load, "(((idx0+(idx1*64))%192)<160)", "((idx0+((idx1//3)*16))+128)", "(((idx0+(idx1*64))%192)//16)")
|
||||
|
||||
@@ -27,8 +27,6 @@ class Node:
|
||||
@staticmethod
|
||||
def ands(ops): return functools.reduce(lambda x,y: x*y, ops)
|
||||
def __floordiv__(a,b,unk): return a//b
|
||||
def create_lt_node(v, n): return v.lt(n)
|
||||
def create_ge_node(v, n): return v.ge(n)
|
||||
def SumNode(x): return Node.sum(x)
|
||||
def MulNode(x, y): return x*y
|
||||
|
||||
@@ -48,36 +46,36 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.assertEqual(nmax, m)
|
||||
|
||||
def test_cmp_simple(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=True)")
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a<8)!=True)")
|
||||
|
||||
def test_ge(self):
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "False")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "False")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=True)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a<4)!=True)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "True")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "True")
|
||||
self.helper_test_variable(Variable("a", 3, 8) >= 77, 0, 0, "False")
|
||||
self.helper_test_variable(Variable("a", 3, 8) >= 9, 0, 0, "False")
|
||||
self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a<8)!=True)")
|
||||
self.helper_test_variable(Variable("a", 3, 8) >= 4, 0, 1, "((a<4)!=True)")
|
||||
self.helper_test_variable(Variable("a", 3, 8) >= 3, 1, 1, "True")
|
||||
self.helper_test_variable(Variable("a", 3, 8) >= 2, 1, 1, "True")
|
||||
|
||||
def test_lt(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "True")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "True")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "False")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "False")
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 77, 1, 1, "True")
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 9, 1, 1, "True")
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 8, 0, 1, "(a<8)")
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 3, 0, 0, "False")
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 2, 0, 0, "False")
|
||||
|
||||
def test_ge_divides(self):
|
||||
expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
|
||||
def test_lt_divides(self):
|
||||
expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
|
||||
self.helper_test_variable(expr, 0, 1, "(idx<128)")
|
||||
|
||||
def test_ge_divides_and(self):
|
||||
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
||||
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
|
||||
def test_lt_divides_and(self):
|
||||
expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
|
||||
self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
|
||||
|
||||
def test_lt_factors(self):
|
||||
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
|
||||
expr = (Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512
|
||||
self.helper_test_variable(expr, 0, 1, ("(((idx1*4)+FLOAT4_INDEX)<512)", "((FLOAT4_INDEX+(idx1*4))<512)"))
|
||||
|
||||
def test_div_reduction(self):
|
||||
@@ -223,10 +221,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
|
||||
|
||||
def test_sum_lt_fold(self):
|
||||
self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]), 16), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]), 16), 0, 1,
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1,
|
||||
("(((a*4)+b)<16)", "((b+(a*4))<16)"))
|
||||
self.helper_test_variable(create_lt_node(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]), (4 * 67)), 0, 1, "(a<23)")
|
||||
self.helper_test_variable(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
|
||||
|
||||
def test_mul_mod_large(self):
|
||||
self.helper_test_variable((Variable("a", 0, 20)*10)%9, 0, 8, "(a%9)")
|
||||
@@ -244,11 +242,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
|
||||
|
||||
def test_mul_lt(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a<3)!=True)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a<4)!=True)")
|
||||
self.helper_test_variable(Variable("a", 0, 5)*4 < 13, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Variable("a", 0, 5)*4 < 16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Variable("a", 0, 5)*(-2) < 0, 0, 1, "((a*-1)<0)")
|
||||
self.helper_test_variable(Variable("a", 0, 5)*4 >= 12, 0, 1, "((a<3)!=True)")
|
||||
self.helper_test_variable(Variable("a", 0, 5)*4 >= 13, 0, 1, "((a<4)!=True)")
|
||||
|
||||
def test_div_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
|
||||
@@ -275,26 +273,25 @@ class TestSymbolic(unittest.TestCase):
|
||||
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
|
||||
|
||||
def test_ge_remove(self):
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "False")
|
||||
self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "False")
|
||||
|
||||
def test_lt_remove(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "False")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "True")
|
||||
self.helper_test_variable(Variable("a", 0, 6) < -3, 0, 0, "False")
|
||||
self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "True")
|
||||
|
||||
def test_lt_sum_remove(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) + 2 < 3, 0, 1, "(a<1)")
|
||||
|
||||
def test_lt_simple_factor(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6+Variable("b", 0, 6)*6, 8), 0, 1,
|
||||
"(((a*3)+(b*3))<4)")
|
||||
self.helper_test_variable((Variable("a", 0, 6)*6+Variable("b", 0, 6)*6) < 8, 0, 1, "(((a*3)+(b*3))<4)")
|
||||
|
||||
def test_lt_sum_factor_rhs_partial(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 4), 0, 1,
|
||||
self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 4, 0, 1,
|
||||
"((((a*3)+(b*2))+(c*4))<2)")
|
||||
|
||||
def test_lt_sum_factor_rhs_all(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 2), 0, 1,
|
||||
self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 2, 0, 1,
|
||||
"((((a*3)+(b*2))+(c*4))<1)")
|
||||
|
||||
def test_and_fold(self):
|
||||
@@ -469,35 +466,35 @@ class TestSymbolic(unittest.TestCase):
|
||||
idx = Variable("idx", 0, 24)
|
||||
self.helper_test_variable(idx//4, 0, 6, "(idx//4)")
|
||||
# TODO: simplify the true branch
|
||||
self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx//4) if (idx<4) else -1)")
|
||||
self.helper_test_variable((idx<4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx//4) if (idx<4) else -1)")
|
||||
|
||||
def test_idiv_lt(self):
|
||||
idx = Variable("idx", 0, 24)
|
||||
self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)")
|
||||
self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//-4)<-3)")
|
||||
self.helper_test_variable((idx//4<3), 0, 1, "(idx<12)")
|
||||
self.helper_test_variable((idx//-4<-3), 0, 1, "((idx//-4)<-3)")
|
||||
|
||||
def test_simplex_lt(self):
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
c = Variable("c", 0, 3)
|
||||
d = Variable("d", -3, 3)
|
||||
self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=True)")
|
||||
self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=True)")
|
||||
self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=True)")
|
||||
self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified
|
||||
self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified
|
||||
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
self.helper_test_variable((a<1).ne(True), 0, 1, "((a<1)!=True)")
|
||||
self.helper_test_variable((a+b<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
|
||||
self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
|
||||
self.helper_test_variable((a*(-3)+b*4<1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified
|
||||
self.helper_test_variable((a*3+d*4<1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified
|
||||
self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
|
||||
def test_where_removal(self):
|
||||
cond = Variable("a", 0, 3).lt(2)
|
||||
cond = Variable("a", 0, 3) < 2
|
||||
u1, u0 = cond.ufix(1), cond.ufix(0)
|
||||
self.helper_test_variable(cond, 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
|
||||
|
||||
def test_where_combine(self):
|
||||
cond = Variable("x", 0, 3).lt(2)
|
||||
cond = Variable("x", 0, 3) < 2
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
aa = cond.where(a, a.ufix(0))
|
||||
@@ -674,33 +671,6 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
mod = gidx0 % 2
|
||||
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
|
||||
|
||||
def test_node_lt_node(self):
|
||||
a = Variable("a", 1, 5)
|
||||
b = Variable("b", 6, 9)
|
||||
c = Variable("c", 1, 10)
|
||||
d = Variable("d", 5, 10)
|
||||
# if the comparison output is always the same, it folds to num
|
||||
assert create_lt_node(a, b) == NumNode(1)
|
||||
assert create_lt_node(b, a) == NumNode(0)
|
||||
assert create_lt_node(d, a) == NumNode(0)
|
||||
assert create_lt_node(a, a) == NumNode(0)
|
||||
assert create_lt_node(a, a) == NumNode(0)
|
||||
# if it remains as a LtNode, bool is always true and (min, max) == (0, 1)
|
||||
a_lt_c = create_lt_node(a, c)
|
||||
assert isinstance(a_lt_c, LtNode) and a_lt_c.min == 0 and a_lt_c.max == 1
|
||||
assert a_lt_c
|
||||
# same when comparing with a constant
|
||||
a_lt_3 = create_lt_node(a, 3)
|
||||
assert a_lt_3 and a_lt_3.min == 0 and a_lt_3.max == 1
|
||||
|
||||
def test_sumnode_mulnode_lt(self):
|
||||
a = Variable("a", 1, 2)
|
||||
b = Variable("b", 1, 2)
|
||||
c = Variable("c", 1, 2)
|
||||
x = SumNode([MulNode(a, b), c])
|
||||
with self.assertRaises(AssertionError):
|
||||
create_lt_node(x, 3)
|
||||
|
||||
def test_nested_variable_mod(self):
|
||||
i = Variable("i", 1, 5)
|
||||
idx0 = Variable("idx0", 0, i)
|
||||
|
||||
@@ -66,7 +66,7 @@ class TestVminVmaxProperties(unittest.TestCase):
|
||||
x = UOp.variable('x', 0, 10)
|
||||
y = UOp.variable('y', 1, 11)
|
||||
z = UOp.variable('z', 2, 12)
|
||||
uop = x.lt(5).where(y, z)
|
||||
uop = (x<5).where(y, z)
|
||||
self.assertEqual(uop.vmin, 1)
|
||||
self.assertEqual(uop.vmax, 12)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user