mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add and enable tests for indexing const folding (#4068)
* enable test in test_indexing * added tests * rename stuff * del a test case cuz it's loadops.copy
This commit is contained in:
@@ -1087,9 +1087,8 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
# indexing by a scalar should slice (not copy)
|
||||
self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))
|
||||
# NOTE: skipped cuz casting in tinygrad makes _to_const_val not work
|
||||
# self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
|
||||
# self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))
|
||||
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
|
||||
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))
|
||||
|
||||
# scalar indexed with scalar
|
||||
r = Tensor.randn()
|
||||
|
||||
@@ -75,6 +75,22 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
|
||||
def test_tensor_one_pow(self):
|
||||
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
|
||||
|
||||
# folds advance indexing into basic indexing
|
||||
class TestIndexingConstFolding(unittest.TestCase):
|
||||
def test_scalar_index(self):
|
||||
t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
|
||||
_check_ast_count(0, t[:,:,Tensor(1),:])
|
||||
_check_ast_count(0, t[:,:,Tensor(1)+2,:])
|
||||
_check_ast_count(0, t[:,:,Tensor(1),Tensor(0)])
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_const_tensor_index(self):
|
||||
# TODO: implement const tensor folded indexing
|
||||
t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
|
||||
_check_ast_count(0, t[:,:,Tensor.ones(2,1),:])
|
||||
_check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:])
|
||||
_check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])
|
||||
|
||||
class TestMovedConstFolding(unittest.TestCase):
|
||||
def test_add_shrunk_zero(self):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))
|
||||
|
||||
Reference in New Issue
Block a user