mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-16 17:45:38 -05:00
failed test case for openpilot validhack conv (#6590)
* failed test case for openpilot validhack conv can save 2ms once this is fixed * fix order
This commit is contained in:
@@ -73,6 +73,29 @@ class TestValidSimplification(unittest.TestCase):
|
||||
# empty
|
||||
self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (-gidx0).lt(-7), idx))
|
||||
|
||||
@unittest.expectedFailure # TODO: FIXME
|
||||
def test_openpilot_conv(self):
|
||||
# first conv in openpilot
|
||||
# kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
|
||||
|
||||
idx1 = Variable("idx1", 32)
|
||||
idx2 = Variable("idx2", 64)
|
||||
ridx0 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx0", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 5)))
|
||||
ridx1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
|
||||
ridx2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
|
||||
|
||||
alu1 = ((idx2*2)+ridx1)
|
||||
alu4 = ((idx1*48)+(ridx2*6)+ridx0)
|
||||
|
||||
valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
|
||||
shape = (128, 1536, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
|
||||
|
||||
# (((((idx2*(-2))+(ridx1*(-1)))<0)&(((idx1*(-8))+(ridx2*(-1)))<0))?read_imagef(data0, smp,
|
||||
# (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))")
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
gidx = Variable("gidx", 512)
|
||||
|
||||
Reference in New Issue
Block a user