another example of openpilot conv with valid (#6595)

This commit is contained in:
chenyu
2024-09-19 01:54:01 -04:00
committed by GitHub
parent 0c9b7c9167
commit 496806ce75

View File

@@ -74,7 +74,7 @@ class TestValidSimplification(unittest.TestCase):
self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (-gidx0).lt(-7), idx))
@unittest.expectedFailure # TODO: FIXME
def test_openpilot_conv(self):
def test_openpilot_conv1(self):
# first conv in openpilot
# kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
@@ -96,6 +96,25 @@ class TestValidSimplification(unittest.TestCase):
self.assertEqual(render(shape, valid, idx),
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))")
@unittest.expectedFailure # TODO: FIXME
def test_openpilot_conv2(self):
# conv in test/external/external_test_valid_remove.py
idx1 = Variable("idx1", 32)
idx2 = Variable("idx2", 64)
ridx0 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx0", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
ridx1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
ridx2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
alu1 = ((idx2*2)+ridx1)
alu3 = ((idx1*24)+(ridx2*3)+ridx0)
valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
shape = (128, 768, 4)
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
self.assertEqual(render(shape, valid, idx),
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-3)),((idx2*2)+ridx1+(-1))))")
def test_simplify1(self):
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
gidx = Variable("gidx", 512)