diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index 5788637776..8b7dd9394b 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -18,14 +18,14 @@ def render(image_shape, valid:UOp, idx:UOp) -> str: # print(fxn) return fxn.split("float4 val0 = ")[1].split(";")[0] -def Variable(expr, nmax): - return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax)) +def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax)) +def Variable(expr, nmin, nmax): return UOp(UOps.DEFINE_VAR, dtypes.int, (), (expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax))) class TestValidSimplification(unittest.TestCase): def test_idx_neg_lt_c(self): # (idx1 * (-1) < c) ? (..., idx1-1+c) : 0 can drop the valid - gidx0 = Variable("gidx0", 32) - gidx1 = Variable("gidx1", 32) + gidx0 = Special("gidx0", 32) + gidx1 = Special("gidx1", 32) self.assertEqual(render((10, 10, 4), (gidx1*(-1)).lt(0), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-1))), "read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1))))") self.assertEqual(render((10, 10, 4), (gidx1*(-1)).lt(-1), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-2))), @@ -44,8 +44,8 @@ class TestValidSimplification(unittest.TestCase): def test_idx_lt_bound(self): # (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid - gidx0 = Variable("gidx0", 32) - gidx1 = Variable("gidx1", 32) + gidx0 = Special("gidx0", 32) + gidx1 = Special("gidx1", 32) self.assertEqual(render((10, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))), "read_imagef(data0, smp, (int2)(gidx0,gidx1))") # 10x20 image, not out of bound @@ -54,16 +54,16 @@ class TestValidSimplification(unittest.TestCase): def test_generic_idx_lt_bound(self): # (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid - gidx0 = Variable("gidx0", 32) - gidx1 = Variable("gidx1", 32) + gidx0 = Special("gidx0", 32) + gidx1 = Special("gidx1", 32) self.assertEqual(render((10, 10, 4), (gidx1).lt(8), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+2))), "read_imagef(data0, smp, (int2)(gidx0,(gidx1+2)))") self.assertEqual(render((10, 10, 4), (gidx1).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+5))), "read_imagef(data0, smp, (int2)(gidx0,(gidx1+5)))") def test_valid_empty_set(self): - gidx0 = Variable("gidx0", 32) - gidx1 = Variable("gidx1", 32) + gidx0 = Special("gidx0", 32) + gidx1 = Special("gidx1", 32) shape = (1, 2, 4) idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0%2, gidx1+2)) # not empty @@ -77,12 +77,11 @@ class TestValidSimplification(unittest.TestCase): def test_openpilot_conv1(self): # first conv in openpilot # kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6 - - idx1 = Variable("idx1", 32) - idx2 = Variable("idx2", 64) - ridx0 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx0", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 5))) - ridx1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2))) - ridx2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2))) + idx1 = Special("idx1", 32) + idx2 = Special("idx2", 64) + ridx0 = Variable("ridx0", 0, 5) + ridx1 = Variable("ridx1", 0, 2) + ridx2 = Variable("ridx2", 0, 2) alu1 = ((idx2*2)+ridx1) alu4 = ((idx1*48)+(ridx2*6)+ridx0) @@ -99,11 +98,11 @@ class TestValidSimplification(unittest.TestCase): @unittest.expectedFailure # TODO: FIXME def test_openpilot_conv2(self): # conv in test/external/external_test_valid_remove.py - idx1 = Variable("idx1", 32) - idx2 = Variable("idx2", 64) - ridx0 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx0", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2))) - ridx1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2))) - ridx2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2))) + idx1 = Special("idx1", 32) + idx2 = Special("idx2", 64) + ridx0 = Variable("ridx0", 0, 2) + ridx1 = Variable("ridx1", 0, 2) + ridx2 = Variable("ridx2", 0, 2) alu1 = ((idx2*2)+ridx1) alu3 = ((idx1*24)+(ridx2*3)+ridx0) @@ -117,7 +116,7 @@ class TestValidSimplification(unittest.TestCase): def test_simplify1(self): # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1) - gidx = Variable("gidx", 512) + gidx = Special("gidx", 512) valid = gidx.lt(488) & (-gidx).lt(-479) idx = ((gidx*3+18)%26, (gidx*3+18)//26-56) # alu0 is ((gidx*3)+18) @@ -126,7 +125,7 @@ class TestValidSimplification(unittest.TestCase): def test_simplify2(self): # from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d - lidx = Variable("lidx", 4) + lidx = Special("lidx", 4) valid = lidx.lt(3) & (-lidx).lt(0) idx = ((lidx+1)%2, (lidx+1)//2-1) self.assertEqual(render((1, 2, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)), @@ -134,14 +133,14 @@ class TestValidSimplification(unittest.TestCase): def test_simplify3(self): # from openpilot - idx0 = Variable("idx0", 265) + idx0 = Special("idx0", 265) valid = (-idx0).lt(-200) idx = ((idx0+55)%64, (idx0+55)//64-4) self.assertEqual(render((1, 64, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)), "read_imagef(data0, smp, (int2)((idx0+(-201)),0))") def test_simplify4(self): - idx0 = Variable("idx0", 512) + idx0 = Special("idx0", 512) data1_shape = (4, 64, 4) alu2 = ((idx0*4+1)%32) alu3 = ((idx0*4+2)%32)