fix bug, and add unit test to catch failure

This commit is contained in:
George Hotz
2023-03-11 16:57:25 -08:00
parent 3ec457248c
commit 61071f881a
4 changed files with 10 additions and 48 deletions

View File

@@ -546,6 +546,13 @@ class TestOps(unittest.TestCase):
def test_matvec(self):
helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z, atol=1e-4)
# this was the failure in llama early realizing freqs_cis
def test_double_slice(self):
helper_test_op([(4,4)], lambda x: x[:, 1:2][1:2])
helper_test_op([(4,4)], lambda x: x[1:3][1:2])
helper_test_op([(4,4)], lambda x: x[:, 1:2][0:1])
helper_test_op([(4,4)], lambda x: x[:, 1:2][:, 0:1])
if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)