failed test case for openpilot validhack conv (#6590)

* failed test case for openpilot validhack conv

can save 2ms once this is fixed

* fix order
This commit is contained in:
chenyu
2024-09-18 23:12:30 -04:00
committed by GitHub
parent dfcc9c9aa3
commit 1b6eee02ad

View File

@@ -73,6 +73,29 @@ class TestValidSimplification(unittest.TestCase):
# empty
self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (-gidx0).lt(-7), idx))
@unittest.expectedFailure # TODO: FIXME
def test_openpilot_conv(self):
# first conv in openpilot
# kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
idx1 = Variable("idx1", 32)
idx2 = Variable("idx2", 64)
ridx0 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx0", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 5)))
ridx1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
ridx2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
alu1 = ((idx2*2)+ridx1)
alu4 = ((idx1*48)+(ridx2*6)+ridx0)
valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
shape = (128, 1536, 4)
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
# (((((idx2*(-2))+(ridx1*(-1)))<0)&(((idx1*(-8))+(ridx2*(-1)))<0))?read_imagef(data0, smp,
# (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))
self.assertEqual(render(shape, valid, idx),
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))")
def test_simplify1(self):
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
gidx = Variable("gidx", 512)