mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
another example of openpilot conv with valid (#6595)
This commit is contained in:
@@ -74,7 +74,7 @@ class TestValidSimplification(unittest.TestCase):
|
||||
self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (-gidx0).lt(-7), idx))
|
||||
|
||||
@unittest.expectedFailure # TODO: FIXME
|
||||
def test_openpilot_conv(self):
|
||||
def test_openpilot_conv1(self):
|
||||
# first conv in openpilot
|
||||
# kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
|
||||
|
||||
@@ -96,6 +96,25 @@ class TestValidSimplification(unittest.TestCase):
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))")
|
||||
|
||||
@unittest.expectedFailure # TODO: FIXME
|
||||
def test_openpilot_conv2(self):
|
||||
# conv in test/external/external_test_valid_remove.py
|
||||
idx1 = Variable("idx1", 32)
|
||||
idx2 = Variable("idx2", 64)
|
||||
ridx0 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx0", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
|
||||
ridx1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
|
||||
ridx2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("ridx2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 2)))
|
||||
|
||||
alu1 = ((idx2*2)+ridx1)
|
||||
alu3 = ((idx1*24)+(ridx2*3)+ridx0)
|
||||
|
||||
valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
|
||||
shape = (128, 768, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
|
||||
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-3)),((idx2*2)+ridx1+(-1))))")
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
gidx = Variable("gidx", 512)
|
||||
|
||||
Reference in New Issue
Block a user