add and enable tests for indexing const folding (#4068)

* enable test in test_indexing

* added tests

* rename stuff

* del a test case cuz it's loadops.copy
This commit is contained in:
geohotstan
2024-04-04 22:46:28 +08:00
committed by GitHub
parent ba118abfec
commit 1a1dd1c1a7
2 changed files with 18 additions and 3 deletions

View File

@@ -1087,9 +1087,8 @@ class TestIndexing(unittest.TestCase):
# indexing by a scalar should slice (not copy)
self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))
# NOTE: skipped cuz casting in tinygrad makes _to_const_val not work
# self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
# self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))
# scalar indexed with scalar
r = Tensor.randn()

View File

@@ -75,6 +75,22 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
def test_tensor_one_pow(self):
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
# folds advance indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):
def test_scalar_index(self):
t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
_check_ast_count(0, t[:,:,Tensor(1),:])
_check_ast_count(0, t[:,:,Tensor(1)+2,:])
_check_ast_count(0, t[:,:,Tensor(1),Tensor(0)])
@unittest.expectedFailure
def test_const_tensor_index(self):
# TODO: implement const tensor folded indexing
t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
_check_ast_count(0, t[:,:,Tensor.ones(2,1),:])
_check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:])
_check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])
class TestMovedConstFolding(unittest.TestCase):
def test_add_shrunk_zero(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))