remove CStyleLanguage from test_uop_symbolic (#7142)

This commit is contained in:
chenyu
2024-10-17 19:39:34 -04:00
committed by GitHub
parent 72ed66205d
commit 0cd4b93441
2 changed files with 38 additions and 42 deletions

View File

@@ -6,11 +6,10 @@ from typing import Tuple
# *** fake symobilc uops ***
from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes, ConstType
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite
from tinygrad.ops import UOp, UOps, graph_rewrite
from tinygrad import Variable
import functools
@@ -18,13 +17,8 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
# NOTE: we need STORE so the ALU op has children
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
if DEBUG>=5: print_uops(uops)
from tinygrad.renderer.cstyle import CStyleLanguage
class TestRenderer(CStyleLanguage):
code_for_op = {**CStyleLanguage.code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
rewritten_uop = [uop for uop in uops if uop.op is UOps.STORE][0].src[-1]
fxn = TestRenderer().render("", uops)
return fxn.split("*(data0+0) = ")[1].split(";")[0], rewritten_uop.vmin, rewritten_uop.vmax
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
def NumNode(val): return UOp.const(dtypes.int, val)
class Node:
@@ -54,23 +48,23 @@ class TestSymbolic(unittest.TestCase):
def test_cmp_simple(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=1)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=True)")
def test_ge(self):
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=1)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a<4)!=1)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "False")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "False")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=True)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a<4)!=True)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "True")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "True")
def test_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "True")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "True")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "False")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "False")
def test_ge_divides(self):
expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
@@ -181,7 +175,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
def test_sum_div_trim_const(self):
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "((a+b+1)//4)")
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)")
def test_sum_div_some_partial_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
@@ -229,8 +223,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a<3)!=1)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a<4)!=1)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a<3)!=True)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a<4)!=True)")
def test_div_div(self):
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
@@ -257,12 +251,12 @@ class TestSymbolic(unittest.TestCase):
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
def test_ge_remove(self):
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "False")
def test_lt_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "False")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "True")
def test_lt_sum_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
@@ -273,11 +267,11 @@ class TestSymbolic(unittest.TestCase):
def test_lt_sum_factor_rhs_partial(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 4), 0, 1,
"(((a*3)+(b*2)+(c*4))<2)")
"((((a*3)+(b*2))+(c*4))<2)")
def test_lt_sum_factor_rhs_all(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 2), 0, 1,
"(((a*3)+(b*2)+(c*4))<1)")
"((((a*3)+(b*2))+(c*4))<1)")
def test_and_fold(self):
self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
@@ -351,7 +345,7 @@ class TestSymbolic(unittest.TestCase):
lidx1 = Variable("lidx1", 0, 15)
lidx2 = Variable("lidx2", 0, 3)
alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, "((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(gidx2*32)+(gidx1*8)+(lidx0*16))")
self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, "((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))")
def test_sum_div_complex2(self):
gidx0 = Variable("gidx0", 0, 7)
@@ -372,7 +366,7 @@ class TestSymbolic(unittest.TestCase):
gidx0 = Variable("gidx0", 0, 7)
lidx2 = Variable("lidx2", 0, 12)
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "(((gidx0*4)+(lidx2*4))+(lidx3*4))")
@unittest.expectedFailure
def test_variable_divmod(self):
@@ -387,10 +381,10 @@ class TestSymbolic(unittest.TestCase):
def test_div_neg_all_range(self):
gidx = Variable("gidx", 0, 124)
lidx = Variable("lidx", 0, 7)
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "((((gidx*-8)+(lidx*-1)+999)//-4)+250)")
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "((((gidx*-8)+(lidx*-1)+1000)//-4)+250)")
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "((((gidx*-8)+(lidx*-1)+1001)//-4)+250)")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "((((gidx*-8)+(lidx*-1)+1002)//-4)+250)")
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((((gidx*-8)+(lidx*-1))+999)//-4)+250)")
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1000)//-4)+250)")
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1001)//-4)+250)")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1002)//-4)+250)")
# NOTE: tests are not correct in symbolic
def test_div_neg_then_neg(self):
@@ -400,7 +394,7 @@ class TestSymbolic(unittest.TestCase):
alu2 = -lidx0-lidx1
self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4")
self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "-4")
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "((((lidx0*-1)+(lidx1*-1)+134)//-32)+4)")
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "(((((lidx0*-1)+(lidx1*-1))+134)//-32)+4)")
self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0")
self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0")
self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0")
@@ -439,7 +433,7 @@ class TestSymbolic(unittest.TestCase):
idx = Variable("idx", 0, 24)
self.helper_test_variable(idx//4, 0, 6, "(idx//4)")
# TODO: simplify the true branch
self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx<4)?(idx//4):-1)")
self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx//4) if (idx<4) else -1)")
def test_idiv_lt(self):
idx = Variable("idx", 0, 24)
@@ -451,13 +445,13 @@ class TestSymbolic(unittest.TestCase):
b = Variable("b", 0, 3)
c = Variable("c", 0, 3)
d = Variable("d", -3, 3)
self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=1)")
self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=1)") # negative coeff, should not be simplified
self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=1)") # var can be negative, should not be simplified
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=True)")
self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=True)")
self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=True)")
self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified
self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)")
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):

View File

@@ -969,6 +969,8 @@ renderer = PatternMatcher([
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
(UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(UOps.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
])