"""Unit tests for simplifying the `valid` guard of image loads (plus the `is_increasing` helper)."""
import unittest
from typing import Tuple, Union

from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, is_increasing
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps
from tinygrad.renderer.cstyle import OpenCLRenderer

class _IDivRenderer(OpenCLRenderer):
  # OpenCL renderer that prints integer division as `(a//b)` so expected strings show it explicitly.
  # Named without a `Test` prefix so pytest does not try to collect it as a test class.
  code_for_op = {**OpenCLRenderer().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}

def render(image_shape, valid:UOp, idx:Union[UOp, Tuple[UOp, UOp]]) -> str:
  """Rewrite, linearize and render one guarded float4 image load; return the load expression.

  Builds LOAD(image buffer `data0`, idx, float4(0,0,0,0), valid), runs `full_graph_rewrite` and
  `linearize_uop`, renders the kernel and returns the text between `float4 val0 = ` and the next `;`.

  image_shape: shape passed to `dtypes.imagef`.
  valid: boolean UOp guarding the load.
  idx: an int2 UOp, or a `(x, y)` tuple of int UOps that is wrapped in an int2 VECTORIZE.

  If the rendered source contains no `float4 val0 = ` statement, the `[1]` index raises IndexError
  (test_valid_empty_set asserts exactly this).
  """
  if isinstance(idx, tuple): idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)
  load = UOp(UOps.LOAD, dtypes.float.vec(4), (
    UOp(UOps.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
    idx,
    UOp(UOps.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
    valid))
  uops = linearize_uop(full_graph_rewrite(load.sink()))
  fxn = _IDivRenderer().render("", uops)
  return fxn.split("float4 val0 = ")[1].split(";")[0]

def Special(expr:str, nmax:int) -> UOp:
  """Return an int SPECIAL UOp with arg `(expr, nmax)`; `nmax` is presumably its exclusive upper bound."""
  return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))

def Variable(expr:str, nmin:int, nmax:int) -> UOp:
  """Return an int DEFINE_VAR UOp named `expr` whose bounds are stored as int const UOps in its arg."""
  return UOp(UOps.DEFINE_VAR, dtypes.int, (), (expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))

class TestHelpers(unittest.TestCase):
  def test_is_increasing(self):
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    ridx0 = Variable("ridx0", 0, 5)
    ridx1 = Variable("ridx1", 0, 2)
    ridx2 = Variable("ridx2", 0, 2)

    # (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
    f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
    f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
    f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
    f3 = (idx2*2)+ridx1+(-1)

    self.assertFalse(is_increasing(f0))
    self.assertTrue(is_increasing(f1))
    self.assertTrue(is_increasing(f2))
    self.assertTrue(is_increasing(f3))

class TestValidSimplification(unittest.TestCase):
  """Checks that full_graph_rewrite drops or keeps the `valid` guard of an image load as expected."""

  def test_idx_neg_lt_c(self):
    # (idx1 * (-1) < c) ? (..., idx1-1+c) : 0 can drop the valid
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    self.assertEqual(render((10, 10, 4), (gidx1*(-1)).lt(0), (gidx0, gidx1-1)),
                     "read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1))))")
    self.assertEqual(render((10, 10, 4), (gidx1*(-1)).lt(-1), (gidx0, gidx1-2)),
                     "read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-2))))")
    self.assertEqual(render((10, 10, 4), (gidx1*(-1)).lt(1), (gidx0, gidx1)),
                     "read_imagef(data0, smp, (int2)(gidx0,gidx1))")

    # should match any one of the AND clause and drop the matched statement from valid
    # NOTE: `&` builds the AND UOp; Python's `and` would short-circuit on the left UOp's truthiness
    # and hand back a single operand, silently testing only one clause.
    valid = (gidx1*(-1)).lt(0) & (gidx0*(-1)).lt(0)
    self.assertEqual(render((10, 10, 4), valid, (gidx0, gidx1-1)),
                     "(((gidx0*(-1))<0)?read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1)))):(float4)(0.0f,0.0f,0.0f,0.0f))")

    # duplicated clause: every copy matches the idx, so the whole valid is dropped
    valid = (gidx1*(-1)).lt(0) & (gidx1*(-1)).lt(0)
    self.assertEqual(render((10, 10, 4), valid, (gidx0, gidx1-1)),
                     "read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1))))")

  def test_idx_lt_bound(self):
    # (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    self.assertEqual(render((10, 10, 4), gidx1.lt(10), (gidx0, gidx1)),
                     "read_imagef(data0, smp, (int2)(gidx0,gidx1))")
    # 10x20 image, not out of bound
    self.assertEqual(render((20, 10, 4), gidx1.lt(10), (gidx0, gidx1)),
                     "((gidx1<10)?read_imagef(data0, smp, (int2)(gidx0,gidx1)):(float4)(0.0f,0.0f,0.0f,0.0f))")

  def test_generic_idx_lt_bound(self):
    # (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    self.assertEqual(render((10, 10, 4), gidx1.lt(8), (gidx0, gidx1+2)),
                     "read_imagef(data0, smp, (int2)(gidx0,(gidx1+2)))")
    self.assertEqual(render((10, 10, 4), gidx1.lt(5), (gidx0, gidx1+5)),
                     "read_imagef(data0, smp, (int2)(gidx0,(gidx1+5)))")

  def test_valid_empty_set(self):
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    shape = (1, 2, 4)
    idx = (gidx0%2, gidx1+2)
    # not empty
    self.assertEqual(render(shape, gidx0.lt(8), idx),
                     "((gidx0<8)?read_imagef(data0, smp, (int2)((gidx0%2),(gidx1+2))):(float4)(0.0f,0.0f,0.0f,0.0f))")

    # empty: gidx0<8 and -gidx0<-7 (i.e. gidx0>7) can never both hold
    self.assertRaises(IndexError, lambda: render(shape, gidx0.lt(8) & (-gidx0).lt(-7), idx))

  @unittest.expectedFailure  # TODO: FIXME
  def test_openpilot_conv1(self):
    # first conv in openpilot
    # kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    ridx0 = Variable("ridx0", 0, 5)
    ridx1 = Variable("ridx1", 0, 2)
    ridx2 = Variable("ridx2", 0, 2)
    alu1 = ((idx2*2)+ridx1)
    alu4 = ((idx1*48)+(ridx2*6)+ridx0)

    valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
    shape = (128, 1536, 4)
    idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))

    # (((((idx2*(-2))+(ridx1*(-1)))<0)&(((idx1*(-8))+(ridx2*(-1)))<0))?read_imagef(data0, smp,
    #  (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))
    # the buffer is DEFINE_GLOBAL arg=0, so it renders as data0 (not data1)
    self.assertEqual(render(shape, valid, idx),
                     "read_imagef(data0, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))")

  @unittest.expectedFailure  # TODO: FIXME
  def test_openpilot_conv2(self):
    # conv in test/external/external_test_valid_remove.py
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    ridx0 = Variable("ridx0", 0, 2)
    ridx1 = Variable("ridx1", 0, 2)
    ridx2 = Variable("ridx2", 0, 2)
    alu1 = ((idx2*2)+ridx1)
    alu3 = ((idx1*24)+(ridx2*3)+ridx0)

    valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
    shape = (128, 768, 4)
    idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))

    # under valid, 3 <= alu3 <= 752, so (alu3+765)%768 == alu3-3 (strides 24/3, unlike conv1's 48/6)
    self.assertEqual(render(shape, valid, idx),
                     "read_imagef(data0, smp, (int2)((ridx0+(idx1*24)+(ridx2*3)+(-3)),((idx2*2)+ridx1+(-1))))")

  def test_simplify1(self):
    # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
    gidx = Special("gidx", 512)
    valid = gidx.lt(488) & (-gidx).lt(-479)
    idx = ((gidx*3+18)%26, (gidx*3+18)//26-56)
    # alu0 is ((gidx*3)+18)
    self.assertEqual(render((1, 26, 4), valid, idx),
                     "read_imagef(data0, smp, (int2)(((gidx*3)+(-1438)),0))")

  def test_simplify2(self):
    # from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
    lidx = Special("lidx", 4)
    valid = lidx.lt(3) & (-lidx).lt(0)
    idx = ((lidx+1)%2, (lidx+1)//2-1)
    self.assertEqual(render((1, 2, 4), valid, idx),
                     "read_imagef(data0, smp, (int2)((lidx+(-1)),0))")

  def test_simplify3(self):
    # from openpilot
    idx0 = Special("idx0", 265)
    valid = (-idx0).lt(-200)
    idx = ((idx0+55)%64, (idx0+55)//64-4)
    self.assertEqual(render((1, 64, 4), valid, idx),
                     "read_imagef(data0, smp, (int2)((idx0+(-201)),0))")

  @unittest.expectedFailure  # TODO: not ready yet
  def test_simplify4(self):
    idx0 = Special("idx0", 512)
    data1_shape = (4, 64, 4)
    alu2 = ((idx0*4+1)%32)
    alu3 = ((idx0*4+2)%32)
    alu4 = ((idx0*4+3)%32)
    alu5 = (idx0*4%32)
    alu8 = (idx0//8%32//4)
    alu9 = idx0.lt(256)

    # TODO: simplify these, manual parsing is not going to work
    # alu0 = (((idx0*4)+1)%32)
    self.assertEqual(render(data1_shape, alu9, (((alu8+(alu2*8))%64),(alu2//8))),
                     "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
    # alu0 = (((idx0*4)+2)%32)
    self.assertEqual(render(data1_shape, alu9, (((alu8+(alu3*8))%64),(alu3//8))),
                     "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
    # alu0 = (((idx0*4)+3)%32)
    self.assertEqual(render(data1_shape, alu9, (((alu8+(alu4*8))%64),(alu4//8))),
                     "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
    # alu0 = ((idx0*4)%32)
    self.assertEqual(render(data1_shape, alu9, (((alu8+(alu5*8))%64),(alu5//8))),
                     "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")

if __name__ == '__main__':
  unittest.main()