mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remove CStyleLanguage from test_uop_symbolic (#7142)
This commit is contained in:
@@ -6,11 +6,10 @@ from typing import Tuple
|
||||
|
||||
# *** fake symobilc uops ***
|
||||
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.dtype import dtypes, ConstType
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite
|
||||
from tinygrad.ops import UOp, UOps, graph_rewrite
|
||||
from tinygrad import Variable
|
||||
import functools
|
||||
|
||||
@@ -18,13 +17,8 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
|
||||
# NOTE: we need STORE so the ALU op has children
|
||||
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
|
||||
if DEBUG>=5: print_uops(uops)
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
class TestRenderer(CStyleLanguage):
|
||||
code_for_op = {**CStyleLanguage.code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
|
||||
rewritten_uop = [uop for uop in uops if uop.op is UOps.STORE][0].src[-1]
|
||||
fxn = TestRenderer().render("", uops)
|
||||
return fxn.split("*(data0+0) = ")[1].split(";")[0], rewritten_uop.vmin, rewritten_uop.vmax
|
||||
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
|
||||
|
||||
def NumNode(val): return UOp.const(dtypes.int, val)
|
||||
class Node:
|
||||
@@ -54,23 +48,23 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
def test_cmp_simple(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=1)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=True)")
|
||||
|
||||
def test_ge(self):
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=1)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a<4)!=1)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "False")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "False")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a<8)!=True)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a<4)!=True)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "True")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "True")
|
||||
|
||||
def test_lt(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "True")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "True")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "0")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "0")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "False")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "False")
|
||||
|
||||
def test_ge_divides(self):
|
||||
expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
|
||||
@@ -181,7 +175,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
|
||||
|
||||
def test_sum_div_trim_const(self):
|
||||
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "((a+b+1)//4)")
|
||||
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)")
|
||||
|
||||
def test_sum_div_some_partial_factor(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||
@@ -229,8 +223,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a<3)!=1)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a<4)!=1)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a<3)!=True)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a<4)!=True)")
|
||||
|
||||
def test_div_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
|
||||
@@ -257,12 +251,12 @@ class TestSymbolic(unittest.TestCase):
|
||||
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
|
||||
|
||||
def test_ge_remove(self):
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "False")
|
||||
|
||||
def test_lt_remove(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "False")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "1")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "True")
|
||||
|
||||
def test_lt_sum_remove(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
|
||||
@@ -273,11 +267,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
def test_lt_sum_factor_rhs_partial(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 4), 0, 1,
|
||||
"(((a*3)+(b*2)+(c*4))<2)")
|
||||
"((((a*3)+(b*2))+(c*4))<2)")
|
||||
|
||||
def test_lt_sum_factor_rhs_all(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 2), 0, 1,
|
||||
"(((a*3)+(b*2)+(c*4))<1)")
|
||||
"((((a*3)+(b*2))+(c*4))<1)")
|
||||
|
||||
def test_and_fold(self):
|
||||
self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
|
||||
@@ -351,7 +345,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
lidx1 = Variable("lidx1", 0, 15)
|
||||
lidx2 = Variable("lidx2", 0, 3)
|
||||
alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
|
||||
self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, "((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(gidx2*32)+(gidx1*8)+(lidx0*16))")
|
||||
self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, "((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))")
|
||||
|
||||
def test_sum_div_complex2(self):
|
||||
gidx0 = Variable("gidx0", 0, 7)
|
||||
@@ -372,7 +366,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
gidx0 = Variable("gidx0", 0, 7)
|
||||
lidx2 = Variable("lidx2", 0, 12)
|
||||
lidx3 = Variable("lidx3", 0, 1)
|
||||
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
|
||||
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "(((gidx0*4)+(lidx2*4))+(lidx3*4))")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_variable_divmod(self):
|
||||
@@ -387,10 +381,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_div_neg_all_range(self):
|
||||
gidx = Variable("gidx", 0, 124)
|
||||
lidx = Variable("lidx", 0, 7)
|
||||
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "((((gidx*-8)+(lidx*-1)+999)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "((((gidx*-8)+(lidx*-1)+1000)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "((((gidx*-8)+(lidx*-1)+1001)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "((((gidx*-8)+(lidx*-1)+1002)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((((gidx*-8)+(lidx*-1))+999)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1000)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1001)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1002)//-4)+250)")
|
||||
|
||||
# NOTE: tests are not correct in symbolic
|
||||
def test_div_neg_then_neg(self):
|
||||
@@ -400,7 +394,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
alu2 = -lidx0-lidx1
|
||||
self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4")
|
||||
self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "-4")
|
||||
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "((((lidx0*-1)+(lidx1*-1)+134)//-32)+4)")
|
||||
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "(((((lidx0*-1)+(lidx1*-1))+134)//-32)+4)")
|
||||
self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0")
|
||||
self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0")
|
||||
self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0")
|
||||
@@ -439,7 +433,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
idx = Variable("idx", 0, 24)
|
||||
self.helper_test_variable(idx//4, 0, 6, "(idx//4)")
|
||||
# TODO: simplify the true branch
|
||||
self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx<4)?(idx//4):-1)")
|
||||
self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx//4) if (idx<4) else -1)")
|
||||
|
||||
def test_idiv_lt(self):
|
||||
idx = Variable("idx", 0, 24)
|
||||
@@ -451,13 +445,13 @@ class TestSymbolic(unittest.TestCase):
|
||||
b = Variable("b", 0, 3)
|
||||
c = Variable("c", 0, 3)
|
||||
d = Variable("d", -3, 3)
|
||||
self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=1)")
|
||||
self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
|
||||
self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
|
||||
self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=1)") # negative coeff, should not be simplified
|
||||
self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=1)") # var can be negative, should not be simplified
|
||||
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
||||
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
||||
self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=True)")
|
||||
self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=True)")
|
||||
self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=True)")
|
||||
self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified
|
||||
self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified
|
||||
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)")
|
||||
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
|
||||
@@ -969,6 +969,8 @@ renderer = PatternMatcher([
|
||||
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
|
||||
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
|
||||
(UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.WHERE),
|
||||
lambda x: UOp(UOps.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
|
||||
])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user