update uop and tests to not use lt/gt/le/ge [pr] (#8023)

just use dunder methods, eventually remove those from ops
This commit is contained in:
chenyu
2024-12-03 21:02:52 -05:00
committed by GitHub
parent 03bf9c2985
commit 0c060fa040
10 changed files with 100 additions and 130 deletions

View File

@@ -355,7 +355,7 @@ class TestUOpGraph(unittest.TestCase):
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
alu = ld.lt(1).cast(dtypes.bool)
alu = (ld<1).cast(dtypes.bool)
out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
@@ -662,7 +662,7 @@ class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2))
idx = UOp.const(dtypes.int, 0)
@@ -681,7 +681,7 @@ class TestIFUOps(unittest.TestCase):
def test_expand_ifs_one_gate(self):
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4))<1
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
gate = valid&(lidx.ne(2))
st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
@@ -700,7 +700,7 @@ class TestIFUOps(unittest.TestCase):
@unittest.expectedFailure
def test_expand_ifs_dumb(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2))
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]

View File

@@ -252,7 +252,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2)))
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
gate = gidx0<UOp.const(dtypes.int, 1)
store = UOp(Ops.STORE, dtypes.void, (idx, val, gate))
uops = to_uops_list([store])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
@@ -271,7 +271,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
gate = gidx0<UOp.const(dtypes.int, 1)
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val)]
uops = to_uops_list(stores)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
@@ -291,7 +291,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
gate = gidx0<UOp.const(dtypes.int, 1)
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val, gate)]
uops = to_uops_list(stores)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))

View File

@@ -118,7 +118,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
rhs = UOp.const(dtypes.int.vec(4), 2)
unopt = lhs.lt(rhs)
unopt = lhs<rhs
opt = apply_rewrite(unopt)
print(unopt)
print(opt)

View File

@@ -53,7 +53,7 @@ class TestValidIdxSimplification(unittest.TestCase):
def test_cumsum(self):
gidx0 = Special("gidx0", 5)
lidx0 = Special("lidx0", 4)
gate = (gidx0*4+lidx0).lt(19).ne(True)
gate = (gidx0*4+lidx0<19).ne(True)
idx = gidx0*4+lidx0-19
load = get_gated_load_uop(gate, idx)
self.check(load,
@@ -65,7 +65,7 @@ class TestValidIdxSimplification(unittest.TestCase):
ridx1 = Range(1, 4)
ridx2 = Range(2, 4)
ridx3 = Range(3, 4)
valid = (ridx0*3+ridx1).lt(8) & (((ridx0*3+ridx1)//8+ridx2*3+ridx3)%4).lt(2)
valid = ((ridx0*3+ridx1)<8) & ((((ridx0*3+ridx1)//8+ridx2*3+ridx3)%4)<2)
idx = ridx0+ridx1+ridx2+ridx3
load = get_gated_load_uop(valid, idx)
self.check(load,
@@ -76,13 +76,13 @@ class TestValidIdxSimplification(unittest.TestCase):
gidx0 = Special("gidx0", 56)
ridx0 = Range(0, 3)
alu0 = gidx0+ridx0
valid = alu0.lt(57) & alu0.ge(1)
valid = (alu0 < 57) & (alu0 >= 1)
self.assertIsNone(simplify_valid(valid))
def test_valid_order_matters1(self):
ridx0 = Range(0, 2)
v0 = ridx0.lt(1)
v1 = ((ridx0*5+1)%6).lt(5)
v0 = ridx0<1
v1 = ((ridx0*5+1)%6)<5
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
@@ -91,10 +91,10 @@ class TestValidIdxSimplification(unittest.TestCase):
gidx1 = Special("gidx1", 13)
ridx0 = Range(0, 4)
alu0 = (gidx1+(ridx0*13))
v0 = ((gidx0+11)%14).lt(11)
v1 = ((alu0+((gidx0+39)//42))%14).lt(11)
v2 = gidx0.lt(3)
v3 = alu0.lt(42)
v0 = (gidx0+11)%14<11
v1 = (alu0+((gidx0+39)//42))%14<11
v2 = gidx0<3
v3 = alu0<42
for v in itertools.permutations([v0,v1,v2,v3]):
self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")
@@ -116,17 +116,17 @@ class TestImageSimplification(unittest.TestCase):
gidx0 = Special("gidx0", 32)
gidx1 = Special("gidx1", 32)
shape = (10, 10, 4)
load = get_load_image_uop(shape, (gidx1).lt(1).ne(True), (gidx0, gidx1-1))
load = get_load_image_uop(shape, (gidx1<1).ne(True), (gidx0, gidx1-1))
self.check(load, None, "gidx0", "(gidx1+-1)")
load = get_load_image_uop(shape, (gidx1).lt(1).ne(True), (gidx0, gidx1-2))
load = get_load_image_uop(shape, (gidx1<1).ne(True), (gidx0, gidx1-2))
self.check(load, None, "gidx0", "(gidx1+-2)")
# should match any one of the AND clause and drop the matched statement from valid
valid = (gidx0).lt(1).ne(True) & (gidx1).lt(1).ne(True)
valid = (gidx0<1).ne(True) & (gidx1<1).ne(True)
load = get_load_image_uop(shape, valid, (gidx0+1, gidx1-1))
self.check(load, "((gidx0<1)!=True)", "(gidx0+1)", "(gidx1+-1)")
valid = (gidx1).lt(1).ne(True) & (gidx1).lt(1).ne(True)
valid = (gidx0<1).ne(True) & (gidx1<1).ne(True)
load = get_load_image_uop(shape, valid, (gidx0, gidx1-1))
self.check(load, None, "gidx0", "(gidx1+-1)")
@@ -134,15 +134,15 @@ class TestImageSimplification(unittest.TestCase):
# (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
gidx0 = Special("gidx0", 32)
gidx1 = Special("gidx1", 32)
load = get_load_image_uop((10, 10, 4), (gidx1).lt(10), (gidx0, gidx1))
load = get_load_image_uop((10, 10, 4), gidx1<10, (gidx0, gidx1))
self.check(load, None, "gidx0", "gidx1")
# same thing, valid has a div
load = get_load_image_uop((10, 10, 4), (gidx1//2).lt(5), (gidx0, gidx1))
load = get_load_image_uop((10, 10, 4), gidx1//2<5, (gidx0, gidx1))
self.check(load, None, "gidx0", "gidx1")
# 10x20 image, not out of bound
load = get_load_image_uop((20, 10, 4), (gidx1).lt(10), (gidx0, gidx1))
load = get_load_image_uop((20, 10, 4), gidx1<10, (gidx0, gidx1))
self.check(load, "(gidx1<10)", "gidx0", "gidx1")
def test_generic_idx_lt_bound(self):
@@ -150,10 +150,10 @@ class TestImageSimplification(unittest.TestCase):
gidx0 = Special("gidx0", 32)
gidx1 = Special("gidx1", 32)
shape = (10, 10, 4)
load = get_load_image_uop(shape, (gidx1).lt(8), (gidx0, gidx1+2))
load = get_load_image_uop(shape, (gidx1<8), (gidx0, gidx1+2))
self.check(load, None, "gidx0", "(gidx1+2)")
load = get_load_image_uop(shape, (gidx1).lt(5), (gidx0, gidx1+5))
load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
self.check(load, None, "gidx0", "(gidx1+5)")
def test_valid_empty_set(self):
@@ -162,11 +162,11 @@ class TestImageSimplification(unittest.TestCase):
shape = (32, 32, 4)
idx = (gidx0%2, gidx1+2)
# not empty
load = get_load_image_uop(shape, (gidx0).lt(8), idx)
load = get_load_image_uop(shape, gidx0<8, idx)
self.check(load, "(gidx0<8)", "(gidx0%2)", "(gidx1+2)")
# empty -> invalid
load = get_load_image_uop(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx)
load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
load = full_graph_rewrite(load.sink()).src[0]
self.assertEqual(load.op, Ops.VECTORIZE)
self.assertEqual(load.dtype.count, 4)
@@ -186,7 +186,7 @@ class TestImageSimplification(unittest.TestCase):
alu1 = ((idx2*2)+ridx1)
alu4 = ((idx1*48)+(ridx2*6)+ridx0)
valid = (((idx2*2)+(ridx1)).lt(1).ne(True))&(((idx1*8)+(ridx2)).lt(1).ne(True))
valid = ((((idx2*2)+(ridx1))<1).ne(True))&((((idx1*8)+(ridx2))<1).ne(True))
shape = (128, 1536, 4)
idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
@@ -207,7 +207,7 @@ class TestImageSimplification(unittest.TestCase):
alu1 = ((idx2*2)+ridx1)
alu3 = ((idx1*24)+(ridx2*3)+ridx0)
valid = (((idx2*2)+ridx1).lt(1).ne(True))&(((idx1*8)+ridx2).lt(1).ne(True))
valid = ((((idx2*2)+ridx1)<1).ne(True))&((((idx1*8)+ridx2)<1).ne(True))
shape = (128, 768, 4)
idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
load = get_load_image_uop(shape, valid, idx)
@@ -226,7 +226,7 @@ class TestImageSimplification(unittest.TestCase):
alu4 = ((idx1*8)+ridx1)
alu6 = ((idx1*512)+(ridx1*64)+idx0)
valid = alu2.lt(11)&(alu4.lt(3).ne(True))
valid = (alu2<11)&(alu4<3).ne(True)
shape = (8, 1024, 4)
idx = (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)//8)+1)//2)+(-4)))
@@ -240,7 +240,7 @@ class TestImageSimplification(unittest.TestCase):
def test_simplify1(self):
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
gidx = Special("gidx", 512)
valid = gidx.lt(488) & (gidx).lt(480).ne(True)
valid = (gidx<488) & (gidx<480).ne(True)
idx = ((gidx*3+18)%26, (gidx*3+18)//26-56)
load = get_load_image_uop((1, 26, 4), valid, idx)
self.check(load, None, "((gidx*3)+-1438)", "0")
@@ -248,7 +248,7 @@ class TestImageSimplification(unittest.TestCase):
def test_simplify2(self):
# from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
lidx = Special("lidx", 4)
valid = lidx.lt(3) & lidx.lt(1).ne(True)
valid = (lidx<3) & (lidx<1).ne(True)
idx = ((lidx+1)%2, (lidx+1)//2-1)
load = get_load_image_uop((1, 2, 4), valid, idx)
self.check(load, None, "(lidx+-1)", "0")
@@ -256,7 +256,7 @@ class TestImageSimplification(unittest.TestCase):
def test_simplify3(self):
# from openpilot
idx0 = Special("idx0", 265)
valid = idx0.lt(201).ne(True)
valid = (idx0<201).ne(True)
idx = ((idx0+55)%64, (idx0+55)//64-4)
load = get_load_image_uop((1, 64, 4), valid, idx)
self.check(load, None, "(idx0+-201)", "0")
@@ -269,7 +269,7 @@ class TestImageSimplification(unittest.TestCase):
alu4 = ((idx0*4+3)%32)
alu5 = (idx0*4%32)
alu8 = (idx0//8%32//4)
alu9 = idx0.lt(256)
alu9 = idx0<256
# TODO: can this be simplified further?
load = get_load_image_uop(shape, alu9, (((alu8+(alu2*8))%64),(alu2//8)))
@@ -294,7 +294,7 @@ class TestImageSimplification(unittest.TestCase):
alu2 = idx1//3
alu3 = ((alu1+1)%768)
idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
valid = alu3.lt(640)
valid = alu3<640
load = get_load_image_uop(shape, valid, idx)
self.check(load, "(((idx0+(idx1*64))%192)<160)", "((idx0+((idx1//3)*16))+128)", "(((idx0+(idx1*64))%192)//16)")

View File

@@ -27,8 +27,6 @@ class Node:
@staticmethod
def ands(ops): return functools.reduce(lambda x,y: x*y, ops)
def __floordiv__(a,b,unk): return a//b
def create_lt_node(v, n): return v.lt(n)
def create_ge_node(v, n): return v.ge(n)
def SumNode(x): return Node.sum(x)
def MulNode(x, y): return x*y
@@ -48,36 +46,36 @@ class TestSymbolic(unittest.TestCase):
self.assertEqual(nmax, m)
def test_cmp_simple(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=True)")
self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a<8)!=True)")
def test_ge(self):
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "False")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "False")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=True)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a<4)!=True)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "True")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "True")
self.helper_test_variable(Variable("a", 3, 8) >= 77, 0, 0, "False")
self.helper_test_variable(Variable("a", 3, 8) >= 9, 0, 0, "False")
self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a<8)!=True)")
self.helper_test_variable(Variable("a", 3, 8) >= 4, 0, 1, "((a<4)!=True)")
self.helper_test_variable(Variable("a", 3, 8) >= 3, 1, 1, "True")
self.helper_test_variable(Variable("a", 3, 8) >= 2, 1, 1, "True")
def test_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "True")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "True")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "False")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "False")
self.helper_test_variable(Variable("a", 3, 8) < 77, 1, 1, "True")
self.helper_test_variable(Variable("a", 3, 8) < 9, 1, 1, "True")
self.helper_test_variable(Variable("a", 3, 8) < 8, 0, 1, "(a<8)")
self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
self.helper_test_variable(Variable("a", 3, 8) < 3, 0, 0, "False")
self.helper_test_variable(Variable("a", 3, 8) < 2, 0, 0, "False")
def test_ge_divides(self):
expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
def test_lt_divides(self):
expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
self.helper_test_variable(expr, 0, 1, "(idx<128)")
def test_ge_divides_and(self):
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
def test_lt_divides_and(self):
expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
def test_lt_factors(self):
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
expr = (Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512
self.helper_test_variable(expr, 0, 1, ("(((idx1*4)+FLOAT4_INDEX)<512)", "((FLOAT4_INDEX+(idx1*4))<512)"))
def test_div_reduction(self):
@@ -223,10 +221,10 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
def test_sum_lt_fold(self):
self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]), 16), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]), 16), 0, 1,
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1,
("(((a*4)+b)<16)", "((b+(a*4))<16)"))
self.helper_test_variable(create_lt_node(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]), (4 * 67)), 0, 1, "(a<23)")
self.helper_test_variable(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
def test_mul_mod_large(self):
self.helper_test_variable((Variable("a", 0, 20)*10)%9, 0, 8, "(a%9)")
@@ -244,11 +242,11 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
def test_mul_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a<3)!=True)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a<4)!=True)")
self.helper_test_variable(Variable("a", 0, 5)*4 < 13, 0, 1, "(a<4)")
self.helper_test_variable(Variable("a", 0, 5)*4 < 16, 0, 1, "(a<4)")
self.helper_test_variable(Variable("a", 0, 5)*(-2) < 0, 0, 1, "((a*-1)<0)")
self.helper_test_variable(Variable("a", 0, 5)*4 >= 12, 0, 1, "((a<3)!=True)")
self.helper_test_variable(Variable("a", 0, 5)*4 >= 13, 0, 1, "((a<4)!=True)")
def test_div_div(self):
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
@@ -275,26 +273,25 @@ class TestSymbolic(unittest.TestCase):
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
def test_ge_remove(self):
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "False")
self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "False")
def test_lt_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "False")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "True")
self.helper_test_variable(Variable("a", 0, 6) < -3, 0, 0, "False")
self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "True")
def test_lt_sum_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
self.helper_test_variable(Variable("a", 0, 6) + 2 < 3, 0, 1, "(a<1)")
def test_lt_simple_factor(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6+Variable("b", 0, 6)*6, 8), 0, 1,
"(((a*3)+(b*3))<4)")
self.helper_test_variable((Variable("a", 0, 6)*6+Variable("b", 0, 6)*6) < 8, 0, 1, "(((a*3)+(b*3))<4)")
def test_lt_sum_factor_rhs_partial(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 4), 0, 1,
self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 4, 0, 1,
"((((a*3)+(b*2))+(c*4))<2)")
def test_lt_sum_factor_rhs_all(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 2), 0, 1,
self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 2, 0, 1,
"((((a*3)+(b*2))+(c*4))<1)")
def test_and_fold(self):
@@ -469,35 +466,35 @@ class TestSymbolic(unittest.TestCase):
idx = Variable("idx", 0, 24)
self.helper_test_variable(idx//4, 0, 6, "(idx//4)")
# TODO: simplify the true branch
self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx//4) if (idx<4) else -1)")
self.helper_test_variable((idx<4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx//4) if (idx<4) else -1)")
def test_idiv_lt(self):
idx = Variable("idx", 0, 24)
self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)")
self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//-4)<-3)")
self.helper_test_variable((idx//4<3), 0, 1, "(idx<12)")
self.helper_test_variable((idx//-4<-3), 0, 1, "((idx//-4)<-3)")
def test_simplex_lt(self):
a = Variable("a", 0, 3)
b = Variable("b", 0, 3)
c = Variable("c", 0, 3)
d = Variable("d", -3, 3)
self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=True)")
self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=True)")
self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=True)")
self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified
self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
self.helper_test_variable((a<1).ne(True), 0, 1, "((a<1)!=True)")
self.helper_test_variable((a+b<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
self.helper_test_variable((a*(-3)+b*4<1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified
self.helper_test_variable((a*3+d*4<1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified
self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
def test_where_removal(self):
cond = Variable("a", 0, 3).lt(2)
cond = Variable("a", 0, 3) < 2
u1, u0 = cond.ufix(1), cond.ufix(0)
self.helper_test_variable(cond, 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
def test_where_combine(self):
cond = Variable("x", 0, 3).lt(2)
cond = Variable("x", 0, 3) < 2
a = Variable("a", 0, 3)
b = Variable("b", 0, 3)
aa = cond.where(a, a.ufix(0))
@@ -674,33 +671,6 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
mod = gidx0 % 2
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
def test_node_lt_node(self):
a = Variable("a", 1, 5)
b = Variable("b", 6, 9)
c = Variable("c", 1, 10)
d = Variable("d", 5, 10)
# if the comparison output is always the same, it folds to num
assert create_lt_node(a, b) == NumNode(1)
assert create_lt_node(b, a) == NumNode(0)
assert create_lt_node(d, a) == NumNode(0)
assert create_lt_node(a, a) == NumNode(0)
assert create_lt_node(a, a) == NumNode(0)
# if it remains as a LtNode, bool is always true and (min, max) == (0, 1)
a_lt_c = create_lt_node(a, c)
assert isinstance(a_lt_c, LtNode) and a_lt_c.min == 0 and a_lt_c.max == 1
assert a_lt_c
# same when comparing with a constant
a_lt_3 = create_lt_node(a, 3)
assert a_lt_3 and a_lt_3.min == 0 and a_lt_3.max == 1
def test_sumnode_mulnode_lt(self):
a = Variable("a", 1, 2)
b = Variable("b", 1, 2)
c = Variable("c", 1, 2)
x = SumNode([MulNode(a, b), c])
with self.assertRaises(AssertionError):
create_lt_node(x, 3)
def test_nested_variable_mod(self):
i = Variable("i", 1, 5)
idx0 = Variable("idx0", 0, i)

View File

@@ -66,7 +66,7 @@ class TestVminVmaxProperties(unittest.TestCase):
x = UOp.variable('x', 0, 10)
y = UOp.variable('y', 1, 11)
z = UOp.variable('z', 2, 12)
uop = x.lt(5).where(y, z)
uop = (x<5).where(y, z)
self.assertEqual(uop.vmin, 1)
self.assertEqual(uop.vmax, 12)