From 496806ce75949766ea14e2269df1fc6a2b4ab937 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 19 Sep 2024 01:54:01 -0400 Subject: [PATCH] another example of openpilot conv with valid (#6595) --- test/unit/test_image_valid.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index face5c8a3a..5788637776 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -74,7 +74,7 @@ class TestValidSimplification(unittest.TestCase): self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (-gidx0).lt(-7), idx)) @unittest.expectedFailure # TODO: FIXME - def test_openpilot_conv(self): + def test_openpilot_conv1(self): # first conv in openpilot # kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6 @@ -96,6 +96,25 @@ class TestValidSimplification(unittest.TestCase): self.assertEqual(render(shape, valid, idx), "read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))") + @unittest.expectedFailure # TODO: FIXME + def test_openpilot_conv2(self): + # conv in test/external/external_test_valid_remove.py + idx1 = Variable("idx1", 32) + idx2 = Variable("idx2", 64) + ridx0 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx0", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2))) + ridx1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2))) + ridx2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2))) + + alu1 = ((idx2*2)+ridx1) + alu3 = ((idx1*24)+(ridx2*3)+ridx0) + + valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0)) + shape = (128, 768, 4) + idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))) + + self.assertEqual(render(shape, valid, idx), + "read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-3)),((idx2*2)+ridx1+(-1))))") + def test_simplify1(self): # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1) gidx = Variable("gidx", 512)