fix lt_folding VCONST issue [run_process_replay] (#6424)

* le and ge [run_process_replay]

* bugfix

* fix divides bug

* fix lt_folding issue
This commit is contained in:
George Hotz
2024-09-19 14:59:20 +08:00
committed by GitHub
parent 309ea63c03
commit 012a2c449a
4 changed files with 27 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import UOp, UOps, BinaryOps, exec_alu
from tinygrad.codegen.uopgraph import full_graph_rewrite
@@ -111,6 +112,17 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
self.assertEqual(optimized_div_uop.op, UOps.CONST)
self.assertEqual(optimized_div_uop.arg, 1)
def test_graph_rewrite_div_folding_bug(self):
lhs = UOp(UOps.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
UOp(UOps.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(UOps.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
UOp(UOps.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
rhs = UOp.const(dtypes.int.vec(4), 2)
unopt = lhs.lt(rhs)
opt = apply_rewrite(unopt)
print(unopt)
print(opt)
if opt.op is UOps.VECTORIZE: self.assertFalse(all_same(opt.src))
def test_full_graph_rewrite_modulo_large_divisor(self):
x_var_uop = UOp.define_var('x', dtypes.int32, 1, 5)
self.assertIs(apply_rewrite(x_var_uop % 10), x_var_uop)

View File

@@ -61,13 +61,13 @@ class TestSymbolic(unittest.TestCase):
def test_cmp_simple(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))"})
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))", '((a<8)!=1)'})
def test_ge(self):
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))"})
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))"})
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))", '((a<8)!=1)'})
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))", '((a<4)!=1)'})
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
@@ -255,8 +255,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, {"((a*-1)<0)", "((a*(-1))<0)"})
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-1))<(-2))"})
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))"})
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-1))<(-2))", '((a<3)!=1)'})
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))", '((a<4)!=1)'})
def test_div_div(self):
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")