diff --git a/test/test_arange.py b/test/test_arange.py index a9bdc3b89b..ae5b6208a1 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -29,7 +29,7 @@ class TestArange(unittest.TestCase): f1 = self._get_flops(256, opts) + 1 f2 = self._get_flops(2560, opts) + 1 print(f"{f1=}, {f2=}") - assert (f1 < 5000 and f2 < 5000) or (f2 / f1 < 15), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X" + assert (f1 < 6000 and f2 < 6000) or (f2 / f1 < 16), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X" if limit is not None and not getenv("PTX"): # PTX counts index ALU in flops assert f1 <= limit, f"{f1=}, {limit=}" diff --git a/test/test_linearizer.py b/test/test_linearizer.py index b0ea04f99c..7a323e174c 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -102,7 +102,7 @@ class TestLinearizer(unittest.TestCase): stores = [u for u in lin.uops if u.op is Ops.STORE] mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL] for u in stores])) assert len(mutable_bufs) == len(stores) == 2 - assert [u.arg for u in mutable_bufs] == [0, 1] + self.assertSetEqual(set([u.arg for u in mutable_bufs]), set([0,1])) def _test_no_nested_ranges(self, lins, skip=None): for l in lins: diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index ebf42c4a1e..2302a4461d 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -260,7 +260,7 @@ class TestImageSimplification(unittest.TestCase): self.check(load, "((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True))", "(((idx0+((idx1*512)+(ridx1*64)))+832)%1024)", - "((((idx2*2)+ridx0)+(((idx1+((ridx1+5)//8))+1)//2))+-4)") + "(((((idx1+((ridx1+5)//8))+1)//2)+((idx2*2)+ridx0))+-4)") def test_simplify1(self): # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 45377065ea..6733cd0d09 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -197,7 +197,8 @@ class TestSymbolic(unittest.TestCase): def test_mod_congruence_multiple_vars(self): self.helper_test_variable((9+9*Variable("x",0,3)+9*Variable("y",0,3))%10, 3, 9, "(((x*-1)+(y*-1))+9)") - self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9, "(((z+(x*-1))+(y*-1))+7)") + self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9, + ("(((z+(x*-1))+(y*-1))+7)", "(((y*-1)+(z+(x*-1)))+7)")) self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)") def test_div_congruence(self): @@ -289,11 +290,11 @@ class TestSymbolic(unittest.TestCase): def test_lt_sum_factor_rhs_partial(self): self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 4, 0, 1, - "((((a*3)+(b*2))+(c*4))<2)") + ("((((a*3)+(b*2))+(c*4))<2)", "(((b*2)+((a*3)+(c*4)))<2)")) def test_lt_sum_factor_rhs_all(self): self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 2, 0, 1, - "((((a*3)+(b*2))+(c*4))<1)") + ("((((a*3)+(b*2))+(c*4))<1)", "(((b*2)+((a*3)+(c*4)))<1)")) def test_and_fold(self): self.helper_test_variable(uand([uconst(0), Variable("a", 0, 1)]), 0, 0, "0") @@ -369,6 +370,7 @@ class TestSymbolic(unittest.TestCase): alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10 self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192, ("((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))", + "(((lidx1+((lidx2+(gidx0//5))//5))//2)+((gidx2*32)+((gidx1*8)+(lidx0*16))))", "((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))")) def test_sum_div_complex2(self): @@ -390,7 +392,8 @@ class TestSymbolic(unittest.TestCase): gidx0 = Variable("gidx0", 0, 7) lidx2 = Variable("lidx2", 0, 12) lidx3 = Variable("lidx3", 0, 1) - self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "(((gidx0*4)+(lidx2*4))+(lidx3*4))") + self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, + ("(((gidx0*4)+(lidx2*4))+(lidx3*4))","((lidx3*4)+((gidx0*4)+(lidx2*4)))")) @unittest.expectedFailure def test_variable_divmod(self): @@ -497,8 +500,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)") self.helper_test_variable((a*(-3)+b*4<1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified self.helper_test_variable((a*3+d*4<1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified - self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)")) - self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)")) + self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)", '(((b+(a+c))<1)!=True)')) + self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)", '(((b+(a+c))<1)!=True)')) def test_where_removal(self): cond = Variable("a", 0, 3) < 2 @@ -740,6 +743,7 @@ class TestSymbolicRealWorld(unittest.TestCase): self.assertIn(idx.render(), ("((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)", '((lidx3+((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352)))+2207744)', + '((lidx3+((lidx4*100352)+((gidx2*8)+((gidx1*784)+((gidx0*3211264)+((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49)))))))+2207744)', )) class TestBounds(unittest.TestCase): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 70d6e7185c..78a890da03 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -132,8 +132,8 @@ class Ops(FastEnum): WMMA = auto() # BinaryOps - ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702 - SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702 + MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); ADD = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto() # noqa: E702 + XOR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702 # TernaryOps WHERE = auto(); MULACC = auto() # noqa: E702