mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
failed test case repro for openpilot model (#12475)
* failed test case repro for openpilot model * assertEqual
This commit is contained in:
@@ -2,6 +2,7 @@ import unittest
|
||||
from tinygrad import Tensor, nn
|
||||
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp
|
||||
from test.helpers import expect_rangeify_fails
|
||||
|
||||
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
|
||||
class TestRangeifyAssign(unittest.TestCase):
|
||||
@@ -300,5 +301,17 @@ class TestOuterworld(unittest.TestCase):
|
||||
o.contiguous(i).realize()
|
||||
self.assertTrue((t==o).all().item())
|
||||
|
||||
class TestRangeifyEdgeCase(unittest.TestCase):
|
||||
@expect_rangeify_fails # TODO: fix
|
||||
def test_matmul_relu_cat(self):
|
||||
a = Tensor.ones(100, 512).contiguous().realize()
|
||||
c = Tensor.ones(1, 512).contiguous().realize()
|
||||
cm = Tensor.ones(512, 512)
|
||||
c = c @ cm
|
||||
c = c.relu()
|
||||
|
||||
res = Tensor.cat(a, c, dim=0)
|
||||
self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user