mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix lt_folding VCONST issue [run_process_replay] (#6424)
* le and ge [run_process_replay] * bugfix * fix divides bug * fix lt_folding issue
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import unittest, math
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.helpers import all_same
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, exec_alu
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite
|
||||
|
||||
@@ -111,6 +112,17 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
|
||||
self.assertEqual(optimized_div_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_div_uop.arg, 1)
|
||||
|
||||
def test_graph_rewrite_div_folding_bug(self):
|
||||
lhs = UOp(UOps.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(UOps.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
|
||||
UOp(UOps.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
|
||||
rhs = UOp.const(dtypes.int.vec(4), 2)
|
||||
unopt = lhs.lt(rhs)
|
||||
opt = apply_rewrite(unopt)
|
||||
print(unopt)
|
||||
print(opt)
|
||||
if opt.op is UOps.VECTORIZE: self.assertFalse(all_same(opt.src))
|
||||
|
||||
def test_full_graph_rewrite_modulo_large_divisor(self):
|
||||
x_var_uop = UOp.define_var('x', dtypes.int32, 1, 5)
|
||||
self.assertIs(apply_rewrite(x_var_uop % 10), x_var_uop)
|
||||
|
||||
@@ -61,13 +61,13 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
def test_cmp_simple(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))", '((a<8)!=1)'})
|
||||
|
||||
def test_ge(self):
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, {"((a*-1)<-7)", "((a*(-1))<(-7))", '((a<8)!=1)'})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))", '((a<4)!=1)'})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
|
||||
|
||||
@@ -255,8 +255,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, {"((a*-1)<0)", "((a*(-1))<0)"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-1))<(-2))"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-1))<(-2))", '((a<3)!=1)'})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))", '((a<4)!=1)'})
|
||||
|
||||
def test_div_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
|
||||
|
||||
Reference in New Issue
Block a user