From cba508c8c333af94dc20939d9f8e73126bd388ef Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 9 May 2025 01:55:53 -0400 Subject: [PATCH] update uop symbolic tests (#10228) clean up TODOs and update tests --- test/unit/test_uop_symbolic.py | 29 ++++++++++++----------------- test/unit/test_uop_vmin_vmax.py | 7 +++---- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index a67a1c1d7c..367e18e0bf 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -1,12 +1,11 @@ #!/usr/bin/env python -import unittest, pickle +import unittest, pickle, functools from tinygrad.dtype import dtypes, ConstType from tinygrad.codegen import full_rewrite from tinygrad.codegen.devectorizer import sym from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer from tinygrad import Variable -import functools def render(self) -> tuple[str, ConstType, ConstType]: # NOTE: we need STORE so the ALU op has children @@ -84,9 +83,9 @@ class TestSymbolic(unittest.TestCase): assert idx1*4 is not idx1+4 assert idx1*4 is not idx2*4 assert idx1+idx2 is idx1+idx2 - # assert idx1+idx2 is idx2+idx1 + assert idx1+idx2 is not idx2+idx1 assert idx1+idx2 is not idx2 - # assert idx1*idx2 is idx2*idx1 + assert idx1*idx2 is not idx2*idx1 def test_factorize(self): a = Variable("a", 0, 8) @@ -130,9 +129,9 @@ class TestSymbolic(unittest.TestCase): def test_mul_1(self): self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a") - @unittest.expectedFailure def test_mul_neg_1(self): - self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((((a*-1)+3)//3)+-1)") + self.helper_test_variable((Variable("a", 0, 2)*-1)//3, 0, 0, "0") + self.helper_test_variable((Variable("a", 2, 7)*-1)//3, -2, 0, "((a//3)*-1)") def test_mul_2(self): self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)") @@ -256,7 +255,7 @@ class TestSymbolic(unittest.TestCase): a = Variable("a", 0, 124) self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)") self.helper_test_variable(((-a)//2-1)//2, -31, 0, "(((a+2)//4)*-1)") - # self.helper_test_variable(((-a)//2+10)//2, -26, 5, "(((a*-1)+20)//4)") + self.helper_test_variable(((-a)//2+10)//2, -26, 5, "((((a//2)*-1)+10)//2)") def test_div_const_div_wrong_sign(self): a = Variable("a", 0, 124) @@ -276,12 +275,11 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a") def test_big_mod(self): - # NOTE: we no longer support negative variables - #self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)") - #self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)") - #self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)") + self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)") + self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(((a*-1)%10)*-1)") + self.helper_test_variable(Variable("a", -20, 1)%10, -9, 9, "(a%10)") # TODO: tighter max self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)") - #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)") + self.helper_test_variable(Variable("a", -1, 20)%10, -9, 9, "(a%10)") # TODO: tighter min def test_ge_remove(self): self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "False") @@ -353,14 +351,13 @@ class TestSymbolic(unittest.TestCase): def test_sum_div_partial_remove(self): self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") - @unittest.expectedFailure + # TODO: this is wrong def test_div_numerator_negative(self): - self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)") + self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "(idx*-1)") def test_div_into_mod(self): self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)") - # TODO: simplify the expression def test_div_neg_cancel(self): self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((idx//4)+1)") self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((idx+3)//4)") @@ -427,7 +424,6 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)") self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)") - # TODO: simplify the expression def test_div_neg_all_range(self): gidx = Variable("gidx", 0, 124) lidx = Variable("lidx", 0, 7) @@ -436,7 +432,6 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "((gidx*2)+((lidx+2)//4))") self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "((gidx*2)+((lidx+1)//4))") - # NOTE: tests are not correct in symbolic def test_div_neg_then_neg(self): # taken from arange opts lidx0 = Variable("lidx0", 0, 7) diff --git a/test/unit/test_uop_vmin_vmax.py b/test/unit/test_uop_vmin_vmax.py index 96f4821264..c4e71bd06c 100644 --- a/test/unit/test_uop_vmin_vmax.py +++ b/test/unit/test_uop_vmin_vmax.py @@ -235,10 +235,9 @@ class TestVminVmaxVConst(unittest.TestCase): def test_vmin_vmax_vconst_with_bools(self): # vmin and vmax for a vector constant of bool values - uop = UOp.const(dtypes.float32.vec(3), (True, False, False)) - # TODO: these return floats, not bool - self.assertEqual(uop.vmin, 0.0) - self.assertEqual(uop.vmax, 1.0) + uop = UOp.const(dtypes.bool.vec(3), (True, False, False)) + self.assertIs(uop.vmin, False) + self.assertIs(uop.vmax, True) class TestConstFactor(unittest.TestCase): def test_const_factor_constant(self):