mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
@@ -2,10 +2,6 @@
|
||||
import unittest, pickle
|
||||
from typing import Tuple
|
||||
|
||||
# TODO: fix all the @unittest.expectedFailure
|
||||
|
||||
# *** fake symobilc uops ***
|
||||
|
||||
from tinygrad.dtype import dtypes, ConstType
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||
@@ -20,15 +16,9 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
|
||||
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
|
||||
|
||||
def NumNode(val): return UOp.const(dtypes.int, val)
|
||||
class Node:
|
||||
@staticmethod
|
||||
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
|
||||
@staticmethod
|
||||
def ands(ops): return functools.reduce(lambda x,y: x*y, ops)
|
||||
def __floordiv__(a,b,unk): return a//b
|
||||
def SumNode(x): return Node.sum(x)
|
||||
def MulNode(x, y): return x*y
|
||||
def uconst(val): return UOp.const(dtypes.int, val)
|
||||
def usum(ops): return functools.reduce(lambda x,y: x+y, ops)
|
||||
def uand(ops): return functools.reduce(lambda x,y: x*y, ops)
|
||||
|
||||
# *** leave tests the same
|
||||
|
||||
@@ -74,8 +64,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(expr, 0, 1, "(idx<128)")
|
||||
|
||||
def test_lt_divides_and(self):
|
||||
expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
|
||||
expr = uand([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
|
||||
self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
|
||||
|
||||
def test_lt_factors(self):
|
||||
@@ -113,15 +103,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_add_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)")
|
||||
|
||||
def test_add_num_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, "(a+1)")
|
||||
|
||||
def test_sub_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(a+-1)")
|
||||
|
||||
def test_sub_num_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(a+-1)")
|
||||
|
||||
def test_add_self(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a+a, 0, 16, "(a*2)")
|
||||
@@ -165,31 +149,31 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "(a//-2)")
|
||||
|
||||
def test_sum_div_remove(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
|
||||
|
||||
def test_sum_div_min_max(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
|
||||
def test_sum_div_mod_factor(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
|
||||
|
||||
def test_sum_div_some_factor(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, ("(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"))
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, ("(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"))
|
||||
|
||||
def test_sum_div_trim_const(self):
|
||||
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)")
|
||||
|
||||
def test_sum_div_some_partial_factor(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||
self.helper_test_variable(usum([uconst(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||
self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)")
|
||||
|
||||
def test_sum_div_no_factor(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
|
||||
def test_mod_factor(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)")
|
||||
|
||||
def test_mod_to_sub(self):
|
||||
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
|
||||
@@ -219,16 +203,16 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((3+Variable("a",4,5))%4, 0, 3, "((a*-3)+15)")
|
||||
|
||||
def test_sum_div_const(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 4, 0, 7, "a")
|
||||
|
||||
def test_sum_div_const_big(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 16, 0, 1, "(a//4)")
|
||||
|
||||
def test_sum_lt_fold(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1,
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1,
|
||||
("(((a*4)+b)<16)", "((b+(a*4))<16)"))
|
||||
self.helper_test_variable(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
|
||||
self.helper_test_variable(usum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
|
||||
|
||||
def test_mul_mod_large(self):
|
||||
self.helper_test_variable((Variable("a", 0, 20)*10)%9, 0, 8, "(a%9)")
|
||||
@@ -260,14 +244,14 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)")
|
||||
|
||||
def test_distribute_mul(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
|
||||
self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, "((a*-2)+10)")
|
||||
|
||||
def test_mod_mul_sum(self):
|
||||
self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, ("(b+a)", "(a+b)"))
|
||||
self.helper_test_variable(usum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, ("(b+a)", "(a+b)"))
|
||||
|
||||
def test_sum_0(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)]), 0, 7, "a")
|
||||
|
||||
def test_mod_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
|
||||
@@ -303,26 +287,26 @@ class TestSymbolic(unittest.TestCase):
|
||||
"((((a*3)+(b*2))+(c*4))<1)")
|
||||
|
||||
def test_and_fold(self):
|
||||
self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
|
||||
self.helper_test_variable(uand([uconst(0), Variable("a", 0, 1)]), 0, 0, "0")
|
||||
|
||||
def test_and_remove(self):
|
||||
self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
|
||||
self.helper_test_variable(uand([uconst(1), Variable("a", 0, 1)]), 0, 1, "a")
|
||||
|
||||
def test_mod_factor_negative(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
|
||||
self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
|
||||
self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
|
||||
self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
|
||||
|
||||
def test_sum_combine_num(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(a+6)")
|
||||
self.helper_test_variable(usum([uconst(29), Variable("a", 0, 10), uconst(-23)]), 6, 16, "(a+6)")
|
||||
|
||||
def test_sum_num_hoisted_and_factors_cancel_out(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
|
||||
|
||||
def test_div_cancel(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")
|
||||
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")
|
||||
|
||||
def test_mod_cancel(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
|
||||
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
|
||||
|
||||
def test_mul_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
|
||||
@@ -345,7 +329,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
|
||||
|
||||
def test_sum_div_partial_remove(self):
|
||||
self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
|
||||
self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_div_numerator_negative(self):
|
||||
@@ -541,7 +525,7 @@ class TestSymbolicNumeric(unittest.TestCase):
|
||||
MIN, MAX = 0, 10
|
||||
# one number
|
||||
for i in range(MIN, MAX):
|
||||
v = graph_rewrite(f(NumNode(i)), sym)
|
||||
v = graph_rewrite(f(uconst(i)), sym)
|
||||
self.assertEqual(v.vmin, v.vmax)
|
||||
self.assertEqual(v.vmin, f(i))
|
||||
for kmin in range(MIN, MAX):
|
||||
@@ -565,16 +549,16 @@ class TestSymbolicNumeric(unittest.TestCase):
|
||||
|
||||
class TestSymbolicVars(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
z = NumNode(0)
|
||||
z = uconst(0)
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert z.vars() == z.vars() == set()
|
||||
print(a.vars())
|
||||
assert a.vars() == a.vars() == {a}
|
||||
m = MulNode(a, 3)
|
||||
m = a * 3
|
||||
assert m.vars() == {a}
|
||||
s = SumNode([a, b, c])
|
||||
s = usum([a, b, c])
|
||||
assert s.vars() == {a, b, c}
|
||||
|
||||
def test_compound(self):
|
||||
@@ -608,29 +592,19 @@ class TestSymInfer(unittest.TestCase):
|
||||
assert sym_infer(a*b+c, var_vals) == 10
|
||||
|
||||
"""
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymRender(unittest.TestCase):
|
||||
def test_sym_render(self):
|
||||
a = Variable("a", 1, 8)
|
||||
b = Variable("b", 1, 10)
|
||||
assert sym_render(a) == "a"
|
||||
assert sym_render(1) == "1"
|
||||
assert sym_render(a+1) == "(1+a)"
|
||||
assert sym_render(a*b) == "(a*b)"
|
||||
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
def test_node_divmod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, i*3-1)
|
||||
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
|
||||
assert uconst(0) // (Variable("i", 1, 10)*128) == 0
|
||||
assert uconst(0) % (Variable("i", 1, 10)*128) == 0
|
||||
assert uconst(127) // (Variable("i", 1, 10)*128) == 0
|
||||
assert uconst(127) % (Variable("i", 1, 10)*128) == 127
|
||||
assert 127 // (Variable("i", 1, 10)*128) == 0
|
||||
assert 127 % (Variable("i", 1, 10)*128) == 127
|
||||
assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
|
||||
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert uconst(128) // (Variable("i", 1, 10)*128 + 128) == 0
|
||||
assert uconst(128) % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
|
||||
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert 0 // (Variable("i", 1, 10)*128) == 0
|
||||
@@ -639,10 +613,10 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
assert idx0 % (i*3) == idx0
|
||||
assert i // i == 1
|
||||
assert i % i == 0
|
||||
assert 128 // NumNode(4) == 32
|
||||
assert 128 % NumNode(4) == 0
|
||||
assert NumNode(128) // NumNode(4) == 32
|
||||
assert NumNode(128) % NumNode(4) == 0
|
||||
assert 128 // uconst(4) == 32
|
||||
assert 128 % uconst(4) == 0
|
||||
assert uconst(128) // uconst(4) == 32
|
||||
assert uconst(128) % uconst(4) == 0
|
||||
|
||||
def test_mulnode_divmod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
@@ -667,12 +641,12 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
# assert (i*128+128)*2 // (i*128+128) == 2
|
||||
# assert (i*128+128)*2 % (i*128+128) == 0
|
||||
|
||||
def test_sumnode_div_numnode_no_factoring(self):
|
||||
def test_sumnode_div_uconst_no_factoring(self):
|
||||
gid = Variable("gid", 0, 1023)
|
||||
lid = Variable("lid", 0, 3)
|
||||
expr_before_div = NumNode(-1019)-4*lid-gid
|
||||
unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False)
|
||||
factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True)
|
||||
expr_before_div = uconst(-1019)-4*lid-gid
|
||||
unfactored_expr = Node.__floordiv__(expr_before_div, uconst(-16), False)
|
||||
factored_expr = Node.__floordiv__(expr_before_div, uconst(-16), True)
|
||||
self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
|
||||
self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
|
||||
|
||||
@@ -698,21 +672,21 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
|
||||
def test_num_node_mul_node(self):
|
||||
a = Variable("a", 1, 5)
|
||||
b = NumNode(2) * a
|
||||
b = uconst(2) * a
|
||||
assert b == a * 2
|
||||
assert isinstance(b, MulNode)
|
||||
b = NumNode(1) * a
|
||||
b = uconst(1) * a
|
||||
assert b == a
|
||||
assert isinstance(b, Variable)
|
||||
b = NumNode(0) * a
|
||||
b = uconst(0) * a
|
||||
assert b == 0
|
||||
assert isinstance(b, NumNode)
|
||||
assert isinstance(b, uconst)
|
||||
|
||||
def test_substitute(self):
|
||||
a = Variable("idx0", 1, 3)
|
||||
b = a + 1
|
||||
c = b.substitute({a: NumNode(1)})
|
||||
assert c == NumNode(2)
|
||||
c = b.substitute({a: uconst(1)})
|
||||
assert c == uconst(2)
|
||||
"""
|
||||
|
||||
class TestSymbolicRealWorld(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user