clean up test_uop_symbolic [pr] (#8165)

removed old `Node` references
This commit is contained in:
chenyu
2024-12-11 14:13:19 -05:00
committed by GitHub
parent 5eadae204b
commit 0e57152dbb

View File

@@ -2,10 +2,6 @@
import unittest, pickle
from typing import Tuple
# TODO: fix all the @unittest.expectedFailure
# *** fake symobilc uops ***
from tinygrad.dtype import dtypes, ConstType
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
@@ -20,15 +16,9 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
def NumNode(val): return UOp.const(dtypes.int, val)
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
@staticmethod
def ands(ops): return functools.reduce(lambda x,y: x*y, ops)
def __floordiv__(a,b,unk): return a//b
def SumNode(x): return Node.sum(x)
def MulNode(x, y): return x*y
def uconst(val): return UOp.const(dtypes.int, val)
def usum(ops): return functools.reduce(lambda x,y: x+y, ops)
def uand(ops): return functools.reduce(lambda x,y: x*y, ops)
# *** leave tests the same
@@ -74,8 +64,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(expr, 0, 1, "(idx<128)")
def test_lt_divides_and(self):
expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
expr = uand([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
def test_lt_factors(self):
@@ -113,15 +103,9 @@ class TestSymbolic(unittest.TestCase):
def test_add_1(self):
self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)")
def test_add_num_1(self):
self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, "(a+1)")
def test_sub_1(self):
self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(a+-1)")
def test_sub_num_1(self):
self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(a+-1)")
def test_add_self(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a+a, 0, 16, "(a*2)")
@@ -165,31 +149,31 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "(a//-2)")
def test_sum_div_remove(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
def test_sum_div_min_max(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
def test_sum_div_mod_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
def test_sum_div_some_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, ("(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"))
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, ("(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"))
def test_sum_div_trim_const(self):
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)")
def test_sum_div_some_partial_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
self.helper_test_variable(usum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(usum([uconst(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)")
def test_sum_div_no_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
def test_mod_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)")
self.helper_test_variable(usum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)")
def test_mod_to_sub(self):
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
@@ -219,16 +203,16 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((3+Variable("a",4,5))%4, 0, 3, "((a*-3)+15)")
def test_sum_div_const(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 4, 0, 7, "a")
def test_sum_div_const_big(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 16, 0, 1, "(a//4)")
def test_sum_lt_fold(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1,
self.helper_test_variable(usum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
self.helper_test_variable(usum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1,
("(((a*4)+b)<16)", "((b+(a*4))<16)"))
self.helper_test_variable(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
self.helper_test_variable(usum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
def test_mul_mod_large(self):
self.helper_test_variable((Variable("a", 0, 20)*10)%9, 0, 8, "(a%9)")
@@ -260,14 +244,14 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)")
def test_distribute_mul(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
self.helper_test_variable(usum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, "((a*-2)+10)")
def test_mod_mul_sum(self):
self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, ("(b+a)", "(a+b)"))
self.helper_test_variable(usum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, ("(b+a)", "(a+b)"))
def test_sum_0(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a")
self.helper_test_variable(usum([Variable("a", 0, 7)]), 0, 7, "a")
def test_mod_remove(self):
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
@@ -303,26 +287,26 @@ class TestSymbolic(unittest.TestCase):
"((((a*3)+(b*2))+(c*4))<1)")
def test_and_fold(self):
self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
self.helper_test_variable(uand([uconst(0), Variable("a", 0, 1)]), 0, 0, "0")
def test_and_remove(self):
self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
self.helper_test_variable(uand([uconst(1), Variable("a", 0, 1)]), 0, 1, "a")
def test_mod_factor_negative(self):
self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
def test_sum_combine_num(self):
self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(a+6)")
self.helper_test_variable(usum([uconst(29), Variable("a", 0, 10), uconst(-23)]), 6, 16, "(a+6)")
def test_sum_num_hoisted_and_factors_cancel_out(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
def test_div_cancel(self):
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")
def test_mod_cancel(self):
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
@@ -345,7 +329,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
def test_sum_div_partial_remove(self):
self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
@unittest.expectedFailure
def test_div_numerator_negative(self):
@@ -541,7 +525,7 @@ class TestSymbolicNumeric(unittest.TestCase):
MIN, MAX = 0, 10
# one number
for i in range(MIN, MAX):
v = graph_rewrite(f(NumNode(i)), sym)
v = graph_rewrite(f(uconst(i)), sym)
self.assertEqual(v.vmin, v.vmax)
self.assertEqual(v.vmin, f(i))
for kmin in range(MIN, MAX):
@@ -565,16 +549,16 @@ class TestSymbolicNumeric(unittest.TestCase):
class TestSymbolicVars(unittest.TestCase):
def test_simple(self):
z = NumNode(0)
z = uconst(0)
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
assert z.vars() == z.vars() == set()
print(a.vars())
assert a.vars() == a.vars() == {a}
m = MulNode(a, 3)
m = a * 3
assert m.vars() == {a}
s = SumNode([a, b, c])
s = usum([a, b, c])
assert s.vars() == {a, b, c}
def test_compound(self):
@@ -608,29 +592,19 @@ class TestSymInfer(unittest.TestCase):
assert sym_infer(a*b+c, var_vals) == 10
"""
@unittest.skip("not supported on uops yet")
class TestSymRender(unittest.TestCase):
def test_sym_render(self):
a = Variable("a", 1, 8)
b = Variable("b", 1, 10)
assert sym_render(a) == "a"
assert sym_render(1) == "1"
assert sym_render(a+1) == "(1+a)"
assert sym_render(a*b) == "(a*b)"
@unittest.skip("not supported on uops yet")
class TestSymbolicSymbolicOps(unittest.TestCase):
def test_node_divmod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, i*3-1)
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
assert uconst(0) // (Variable("i", 1, 10)*128) == 0
assert uconst(0) % (Variable("i", 1, 10)*128) == 0
assert uconst(127) // (Variable("i", 1, 10)*128) == 0
assert uconst(127) % (Variable("i", 1, 10)*128) == 127
assert 127 // (Variable("i", 1, 10)*128) == 0
assert 127 % (Variable("i", 1, 10)*128) == 127
assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
assert uconst(128) // (Variable("i", 1, 10)*128 + 128) == 0
assert uconst(128) % (Variable("i", 1, 10)*128 + 128) == 128
assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
assert 0 // (Variable("i", 1, 10)*128) == 0
@@ -639,10 +613,10 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
assert idx0 % (i*3) == idx0
assert i // i == 1
assert i % i == 0
assert 128 // NumNode(4) == 32
assert 128 % NumNode(4) == 0
assert NumNode(128) // NumNode(4) == 32
assert NumNode(128) % NumNode(4) == 0
assert 128 // uconst(4) == 32
assert 128 % uconst(4) == 0
assert uconst(128) // uconst(4) == 32
assert uconst(128) % uconst(4) == 0
def test_mulnode_divmod_node(self):
i = Variable("i", 1, 10)
@@ -667,12 +641,12 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
# assert (i*128+128)*2 // (i*128+128) == 2
# assert (i*128+128)*2 % (i*128+128) == 0
def test_sumnode_div_numnode_no_factoring(self):
def test_sumnode_div_uconst_no_factoring(self):
gid = Variable("gid", 0, 1023)
lid = Variable("lid", 0, 3)
expr_before_div = NumNode(-1019)-4*lid-gid
unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False)
factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True)
expr_before_div = uconst(-1019)-4*lid-gid
unfactored_expr = Node.__floordiv__(expr_before_div, uconst(-16), False)
factored_expr = Node.__floordiv__(expr_before_div, uconst(-16), True)
self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
@@ -698,21 +672,21 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
def test_num_node_mul_node(self):
a = Variable("a", 1, 5)
b = NumNode(2) * a
b = uconst(2) * a
assert b == a * 2
assert isinstance(b, MulNode)
b = NumNode(1) * a
b = uconst(1) * a
assert b == a
assert isinstance(b, Variable)
b = NumNode(0) * a
b = uconst(0) * a
assert b == 0
assert isinstance(b, NumNode)
assert isinstance(b, uconst)
def test_substitute(self):
a = Variable("idx0", 1, 3)
b = a + 1
c = b.substitute({a: NumNode(1)})
assert c == NumNode(2)
c = b.substitute({a: uconst(1)})
assert c == uconst(2)
"""
class TestSymbolicRealWorld(unittest.TestCase):