enable test_index and test_advancedindex (#2648)

* enable test_index and test_advancedindex with pretty diff

* removed contig

* created set_ helper function

* comment change

* del empty line

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
geohotstan
2023-12-08 08:44:39 +08:00
committed by GitHub
parent 00d9eda961
commit d02ff21f1a

View File

@@ -4,6 +4,9 @@ import math, unittest, random
import numpy as np
from tinygrad.tensor import Tensor, dtypes
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
random.seed(42)
@@ -15,6 +18,15 @@ def numpy_testing_assert_equal_helper(a, b):
def consec(shape, start=1):
return Tensor(np.arange(math.prod(shape)).reshape(shape)+start)
# creates strided tensor with base set to reference tensor's base, equivalent to torch.set_()
def set_(reference: Tensor, shape, strides, offset):
if reference.lazydata.base.realized is None: reference.realize()
assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
strided = Tensor(LazyBuffer(device=reference.device, st=ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)), optype=None, op=None, dtype=reference.dtype, src=None, base=reference.lazydata.base))
assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
assert strided.lazydata in reference.lazydata.base.views, "base.views should contain strided.lazydata"
return strided
def make_tensor(shape, dtype:dtypes, noncontiguous):
r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
values uniformly drawn from ``[low, high)``.
@@ -39,7 +51,7 @@ def make_tensor(shape, dtype:dtypes, noncontiguous):
+---------------------------+------------+----------+
"""
contiguous = not noncontiguous # lol
if dtype is dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, dtype=dtype, contiguous=contiguous)
if dtype is dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, dtype=dtypes.bool, contiguous=contiguous)
elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, dtype=dtype, contiguous=contiguous)
elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, dtype=dtype, contiguous=contiguous) # signed int
elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
@@ -118,20 +130,23 @@ class TestIndexing(unittest.TestCase):
tensor_indexed = tensor[idx1]
numpy_testing_assert_equal_helper(tensor_indexed, np.array(lst_indexed))
# self.assertRaises(ValueError, lambda: reference[1:9:0])
self.assertRaises(ValueError, lambda: reference[1:9:0])
# NOTE torch doesn't support this but numpy does so we should too. Torch raises ValueError
# see test_slice_negative_strides in test_ops.py
# self.assertRaises(ValueError, lambda: reference[1:9:-1])
# self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
# self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
# self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])
self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])
# self.assertRaises(IndexError, lambda: reference[0.0])
# self.assertRaises(TypeError, lambda: reference[0.0:2.0])
# self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
# self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
# self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
# self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])
self.assertRaises(IndexError, lambda: reference[0.0])
self.assertRaises(TypeError, lambda: reference[0.0:2.0])
self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])
# TODO: delitem
# def delitem(): del reference[0]
# self.assertRaises(TypeError, delitem)
@@ -140,8 +155,7 @@ class TestIndexing(unittest.TestCase):
# pick a random valid indexer type
def ri(indices):
choice = random.randint(0, 1)
# TODO: we do not support tuple of list for index now
choice = random.randint(0, 2)
if choice == 0: return Tensor(indices)
if choice == 1: return list(indices)
return tuple(indices)
@@ -156,17 +170,19 @@ class TestIndexing(unittest.TestCase):
def validate_setting(x):
pass
# # TODO: we don't support setitem now
# x[[0]] = -2
# numpy_testing_assert_equal_helper(x[[0]], np.array([-2]))
# x[[0]] = -1
# numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1]))
# x[[2, 3, 4]] = 4
# numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4]))
# x[ri([2, 3, 4]), ] = 3
# numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
# x[ri([0, 2, 4]), ] = np.array([5, 4, 3])
# numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))
# TODO: setitem
'''
x[[0]] = -2
numpy_testing_assert_equal_helper(x[[0]], np.array([-2]))
x[[0]] = -1
numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1]))
x[[2, 3, 4]] = 4
numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4]))
x[ri([2, 3, 4]), ] = 3
numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
x[ri([0, 2, 4]), ] = np.array([5, 4, 3])
numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))
'''
# Case 1: Purely Integer Array Indexing
reference = consec((10,))
@@ -175,32 +191,31 @@ class TestIndexing(unittest.TestCase):
# setting values
validate_setting(reference)
# # Tensor with stride != 1
# # strided is [1, 3, 5, 7]
# reference = consec((10,))
# strided = np.array(())
# strided.set_(reference.storage(), storage_offset=0,
# size=torch.Size([4]), stride=[2])
# Tensor with stride != 1
# strided is [1, 3, 5, 7]
# numpy_testing_assert_equal_helper(strided[[0]], np.array([1]))
# numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1]))
# numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7]))
# numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5]))
# numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5]))
# numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ],
# np.array([[5, 3], [1, 7]]))
reference = consec((10,))
strided = set_(reference, (4,), (2,), 0)
# # stride is [4, 8]
# strided = np.array(())
# strided.set_(reference.storage(), storage_offset=4,
# size=torch.Size([2]), stride=[4])
# numpy_testing_assert_equal_helper(strided[[0]], np.array([5]))
# numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5]))
# numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9]))
# numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9]))
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9]))
# numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ],
# np.array([[5, 9], [9, 5]]))
numpy_testing_assert_equal_helper(strided[[0]], np.array([1]))
numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1]))
numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7]))
numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5]))
numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5]))
numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ],
np.array([[5, 3], [1, 7]]))
# stride is [4, 8]
strided = set_(reference, (2,), (4,), offset=4)
numpy_testing_assert_equal_helper(strided[[0]], np.array([5]))
numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5]))
numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9]))
numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9]))
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9]))
numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ],
np.array([[5, 9], [9, 5]]))
# reference is 1 2
# 3 4
@@ -210,16 +225,15 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6]))
numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], consec((1,)))
numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], consec((1,), 6))
# # TODO: we don't support list of Tensors as index
# numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
# numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
# numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))
numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))
# rows = ri([[0, 0],
# [1, 2]])
# columns = [0],
# numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1],
# [3, 5]]))
rows = ri([[0, 0],
[1, 2]])
columns = [0],
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1],
[3, 5]]))
rows = ri([[0, 0],
[1, 2]])
@@ -233,17 +247,20 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 2],
[4, 5]]))
# # setting values
# reference[ri([0]), ri([1])] = -1
# numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
# reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
# numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
# np.array([-1, 2, -4]))
# reference[rows, columns] = np.array([[4, 6], [2, 3]])
# numpy_testing_assert_equal_helper(reference[rows, columns],
# np.array([[4, 6], [2, 3]]))
# TODO: setitem
'''
# setting values
reference[ri([0]), ri([1])] = -1
numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
np.array([-1, 2, -4]))
reference[rows, columns] = np.array([[4, 6], [2, 3]])
numpy_testing_assert_equal_helper(reference[rows, columns],
np.array([[4, 6], [2, 3]]))
'''
# Verify still works with Transposed (i.e. non-contiguous) Tensors
# Verify still works with Transposed (i.e. non-contiguous) Tensors
reference = Tensor([[0, 1, 2, 3],
[4, 5, 6, 7],
@@ -258,15 +275,14 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6]))
numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], np.array([0]))
numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], np.array([6]))
# # TODO: we don't support list of Tensors as index
# numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
# numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
# numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))
numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))
# rows = ri([[0, 0],
# [1, 2]])
# columns = [0],
# numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]]))
rows = ri([[0, 0],
[1, 2]])
columns = [0],
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]]))
rows = ri([[0, 0],
[1, 2]])
@@ -278,99 +294,106 @@ class TestIndexing(unittest.TestCase):
[1, 2]])
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 4], [5, 11]]))
# # setting values
# reference[ri([0]), ri([1])] = -1
# numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
# np.array([-1]))
# reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
# numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
# np.array([-1, 2, -4]))
# reference[rows, columns] = np.array([[4, 6], [2, 3]])
# numpy_testing_assert_equal_helper(reference[rows, columns],
# np.array([[4, 6], [2, 3]]))
# TODO: setitem
'''
# setting values
reference[ri([0]), ri([1])] = -1
numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
np.array([-1]))
reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
np.array([-1, 2, -4]))
reference[rows, columns] = np.array([[4, 6], [2, 3]])
numpy_testing_assert_equal_helper(reference[rows, columns],
np.array([[4, 6], [2, 3]]))
'''
# # stride != 1
# stride != 1
# # strided is [[1 3 5 7],
# # [9 11 13 15]]
# strided is [[1 3 5 7],
# [9 11 13 15]]
# reference = torch.arange(0., 24).view(3, 8)
# strided = np.array(())
# strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
# stride=[8, 2])
reference = Tensor.arange(0., 24).reshape(3, 8)
strided = set_(reference, (2,4), (8,2), 1)
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])],
# np.array([1, 9]))
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])],
# np.array([3, 11]))
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])],
# np.array([1]))
# numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])],
# np.array([15]))
# numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]],
# np.array([1, 7]))
# numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
# np.array([9, 11, 11, 9, 15]))
# numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
# np.array([1, 3, 9, 9]))
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])],
np.array([1, 9]))
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])],
np.array([3, 11]))
numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])],
np.array([1]))
numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])],
np.array([15]))
numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]],
np.array([1, 7]))
numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
np.array([9, 11, 11, 9, 15]))
numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
np.array([1, 3, 9, 9]))
# rows = ri([[0, 0],
# [1, 1]])
# columns = [0],
# numpy_testing_assert_equal_helper(strided[rows, columns],
# np.array([[1, 1], [9, 9]]))
rows = ri([[0, 0],
[1, 1]])
columns = [0],
numpy_testing_assert_equal_helper(strided[rows, columns],
np.array([[1, 1], [9, 9]]))
# rows = ri([[0, 1],
# [1, 0]])
# columns = ri([1, 2])
# numpy_testing_assert_equal_helper(strided[rows, columns],
# np.array([[3, 13], [11, 5]]))
# rows = ri([[0, 0],
# [1, 1]])
# columns = ri([[0, 1],
# [1, 2]])
# numpy_testing_assert_equal_helper(strided[rows, columns],
# np.array([[1, 3], [11, 13]]))
rows = ri([[0, 1],
[1, 0]])
columns = ri([1, 2])
numpy_testing_assert_equal_helper(strided[rows, columns],
np.array([[3, 13], [11, 5]]))
rows = ri([[0, 0],
[1, 1]])
columns = ri([[0, 1],
[1, 2]])
numpy_testing_assert_equal_helper(strided[rows, columns],
np.array([[1, 3], [11, 13]]))
# # setting values
# # strided is [[10, 11],
# # [17, 18]]
# setting values
# reference = torch.arange(0., 24).view(3, 8)
# strided = np.array(())
# strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
# stride=[7, 1])
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
# np.array([11]))
# strided[ri([0]), ri([1])] = -1
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
# np.array([-1]))
# strided is [[10, 11],
# [17, 18]]
# reference = torch.arange(0., 24).view(3, 8)
# strided = np.array(())
# strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
# stride=[7, 1])
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
# np.array([11, 17]))
# strided[ri([0, 1]), ri([1, 0])] = np.array([-1, 2])
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
# np.array([-1, 2]))
reference = Tensor.arange(0., 24).reshape(3, 8)
strided = set_(reference, (2,2), (7,1), 10)
# reference = torch.arange(0., 24).view(3, 8)
# strided = np.array(())
# strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
# stride=[7, 1])
numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
np.array([11]))
# TODO setitem
'''
strided[ri([0]), ri([1])] = -1
numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
np.array([-1]))
'''
# rows = ri([[0],
# [1]])
# columns = ri([[0, 1],
# [0, 1]])
# numpy_testing_assert_equal_helper(strided[rows, columns],
# np.array([[10, 11], [17, 18]]))
# strided[rows, columns] = np.array([[4, 6], [2, 3]])
# numpy_testing_assert_equal_helper(strided[rows, columns],
# np.array([[4, 6], [2, 3]]))
reference = Tensor.arange(0., 24).reshape(3, 8)
strided = set_(reference, (2,2), (7,1), 10)
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
np.array([11, 17]))
# TODO setitem
'''
strided[ri([0, 1]), ri([1, 0])] = np.array([-1, 2])
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
np.array([-1, 2]))
'''
reference = Tensor.arange(0., 24).realize().reshape(3, 8)
strided = set_(reference, (2,2), (7,1), 10)
rows = ri([[0],
[1]])
columns = ri([[0, 1],
[0, 1]])
numpy_testing_assert_equal_helper(strided[rows, columns],
np.array([[10, 11], [17, 18]]))
# TODO setitem
'''
strided[rows, columns] = np.array([[4, 6], [2, 3]])
numpy_testing_assert_equal_helper(strided[rows, columns],
np.array([[4, 6], [2, 3]]))
'''
# Tests using less than the number of dims, and ellipsis
@@ -385,40 +408,26 @@ class TestIndexing(unittest.TestCase):
# verify too many indices fails
with self.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])]
# # test invalid index fails
# reference = torch.empty(10)
# # can't test cuda because it is a device assert
# if not reference.is_cuda:
# for err_idx in (10, -11):
# with self.assertRaisesRegex(IndexError, r'out of'):
# reference[err_idx]
# with self.assertRaisesRegex(IndexError, r'out of'):
# reference[torch.LongTensor([err_idx]).to(device)]
# with self.assertRaisesRegex(IndexError, r'out of'):
# reference[[err_idx]]
# test invalid index fails
reference = Tensor.empty(10)
for err_idx in (10, -11):
with self.assertRaisesRegex(IndexError, r'out of'):
reference[err_idx]
# TODO: cannot check for out of bounds with Tensor indexing
# see test_ops.py: test_slice_fancy_indexing_errors()
'''
with self.assertRaisesRegex(IndexError, r'out of'):
reference[Tensor([err_idx], dtype=dtypes.int64)]
with self.assertRaisesRegex(IndexError, r'out of'):
reference[[err_idx]]
'''
'''
def tensor_indices_to_np(tensor, indices):
# convert the Torch Tensor to a numpy array
tensor = tensor.to(device='cpu')
npt = tensor.numpy()
# convert indices
idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else
i for i in indices)
return npt, idxs
'''
def tensor_indices_to_np(tensor: Tensor, indices):
npt = tensor.numpy()
idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype is dtypes.int64 else
i for i in indices)
return npt, idxs
'''
def get_numpy(tensor, indices):
npt, idxs = tensor_indices_to_np(tensor, indices)
# index and return as a Torch Tensor
return np.array(npt[idxs])
'''
def get_numpy(tensor, indices):
npt, idxs = tensor_indices_to_np(tensor, indices)
return Tensor(npt[idxs])
@@ -440,10 +449,6 @@ class TestIndexing(unittest.TestCase):
npt[idxs] = value
return npt
'''
def assert_get_eq(tensor, indexer):
numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer))
'''
def assert_get_eq(tensor, indexer):
numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer))
@@ -459,20 +464,9 @@ class TestIndexing(unittest.TestCase):
pyt = tensor.detach()
numt = tensor.detach()
pyt[indexer] = val
numt = np.array(set_numpy(numt, indexer, val)) #TODO: shouldn't this already be a numpy array? Why wrap numpy array again???
numt = set_numpy(numt, indexer, val)
numpy_testing_assert_equal_helper(pyt, numt)
'''
def assert_backward_eq(tensor, indexer):
cpu = tensor.float().clone().detach().requires_grad_(True)
outcpu = cpu[indexer]
gOcpu = torch.rand_like(outcpu)
outcpu.backward(gOcpu)
dev = cpu.to(device).detach().requires_grad_(True)
outdev = dev[indexer]
outdev.backward(gOcpu.to(device))
numpy_testing_assert_equal_helper(cpu.grad, dev.grad)
'''
# NOTE: torch initiates the gradients using g0cpu (rand as gradients)
def assert_backward_eq(tensor: Tensor, indexer):
cpu = tensor.float().detach()
@@ -498,177 +492,180 @@ class TestIndexing(unittest.TestCase):
set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size).cast(dtypes.float64)
return set_tensor
# # Tensor is 0 1 2 3 4
# # 5 6 7 8 9
# # 10 11 12 13 14
# # 15 16 17 18 19
# reference = torch.arange(0., 20).view(4, 5)
# Tensor is 0 1 2 3 4
# 5 6 7 8 9
# 10 11 12 13 14
# 15 16 17 18 19
reference = Tensor.arange(0., 20).reshape(4, 5)
# indices_to_test = [
# # grab the second, fourth columns
# [slice(None), [1, 3]],
indices_to_test = [
# grab the second, fourth columns
[slice(None), [1, 3]],
# # first, third rows,
# [[0, 2], slice(None)],
# first, third rows,
[[0, 2], slice(None)],
# # weird shape
# [slice(None), [[0, 1],
# [2, 3]]],
# # negatives
# [[-1], [0]],
# [[0, 2], [-1]],
# [slice(None), [-1]],
# ]
# weird shape
[slice(None), [[0, 1],
[2, 3]]],
# negatives
[[-1], [0]],
[[0, 2], [-1]],
[slice(None), [-1]],
]
# # only test dupes on gets
# get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
# only test dupes on gets
get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
# for indexer in get_indices_to_test:
# assert_get_eq(reference, indexer)
# if self.device_type != 'cpu':
# assert_backward_eq(reference, indexer)
for indexer in get_indices_to_test:
assert_get_eq(reference, indexer)
assert_backward_eq(reference, indexer)
# for indexer in indices_to_test:
# assert_set_eq(reference, indexer, 44)
# assert_set_eq(reference,
# indexer,
# get_set_tensor(reference, indexer))
# TODO setitem
'''
for indexer in indices_to_test:
assert_set_eq(reference, indexer, 44)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
'''
# reference = torch.arange(0., 160).view(4, 8, 5)
reference = Tensor.arange(0., 160).reshape(4, 8, 5)
# indices_to_test = [
# [slice(None), slice(None), [0, 3, 4]],
# [slice(None), [2, 4, 5, 7], slice(None)],
# [[2, 3], slice(None), slice(None)],
# [slice(None), [0, 2, 3], [1, 3, 4]],
# [slice(None), [0], [1, 2, 4]],
# [slice(None), [0, 1, 3], [4]],
# [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
# [slice(None), [[0, 1], [2, 3]], [[0]]],
# [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
# [[0, 2, 3], [1, 3, 4], slice(None)],
# [[0], [1, 2, 4], slice(None)],
# [[0, 1, 3], [4], slice(None)],
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
# [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
# [[[0, 1], [2, 3]], [[0]], slice(None)],
# [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
# [[[2]], [[0, 3], [4, 1]], slice(None)],
# # non-contiguous indexing subspace
# [[0, 2, 3], slice(None), [1, 3, 4]],
indices_to_test = [
[slice(None), slice(None), [0, 3, 4]],
[slice(None), [2, 4, 5, 7], slice(None)],
[[2, 3], slice(None), slice(None)],
[slice(None), [0, 2, 3], [1, 3, 4]],
[slice(None), [0], [1, 2, 4]],
[slice(None), [0, 1, 3], [4]],
[slice(None), [[0, 1], [1, 0]], [[2, 3]]],
[slice(None), [[0, 1], [2, 3]], [[0]]],
[slice(None), [[5, 6]], [[0, 3], [4, 4]]],
[[0, 2, 3], [1, 3, 4], slice(None)],
[[0], [1, 2, 4], slice(None)],
[[0, 1, 3], [4], slice(None)],
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
[[[0, 1], [1, 0]], [[2, 3]], slice(None)],
[[[0, 1], [2, 3]], [[0]], slice(None)],
[[[2, 1]], [[0, 3], [4, 4]], slice(None)],
[[[2]], [[0, 3], [4, 1]], slice(None)],
# non-contiguous indexing subspace
[[0, 2, 3], slice(None), [1, 3, 4]],
# # less dim, ellipsis
# [[0, 2], ],
# [[0, 2], slice(None)],
# [[0, 2], Ellipsis],
# [[0, 2], slice(None), Ellipsis],
# [[0, 2], Ellipsis, slice(None)],
# [[0, 2], [1, 3]],
# [[0, 2], [1, 3], Ellipsis],
# [Ellipsis, [1, 3], [2, 3]],
# [Ellipsis, [2, 3, 4]],
# [Ellipsis, slice(None), [2, 3, 4]],
# [slice(None), Ellipsis, [2, 3, 4]],
# less dim, ellipsis
[[0, 2], ],
[[0, 2], slice(None)],
[[0, 2], Ellipsis],
[[0, 2], slice(None), Ellipsis],
[[0, 2], Ellipsis, slice(None)],
[[0, 2], [1, 3]],
[[0, 2], [1, 3], Ellipsis],
[Ellipsis, [1, 3], [2, 3]],
[Ellipsis, [2, 3, 4]],
[Ellipsis, slice(None), [2, 3, 4]],
[slice(None), Ellipsis, [2, 3, 4]],
# # ellipsis counts for nothing
# [Ellipsis, slice(None), slice(None), [0, 3, 4]],
# [slice(None), Ellipsis, slice(None), [0, 3, 4]],
# [slice(None), slice(None), Ellipsis, [0, 3, 4]],
# [slice(None), slice(None), [0, 3, 4], Ellipsis],
# [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
# ]
# ellipsis counts for nothing
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
[slice(None), Ellipsis, slice(None), [0, 3, 4]],
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
[slice(None), slice(None), [0, 3, 4], Ellipsis],
[Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
]
# for indexer in indices_to_test:
# assert_get_eq(reference, indexer)
# assert_set_eq(reference, indexer, 212)
# assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
# if torch.cuda.is_available():
# assert_backward_eq(reference, indexer)
for indexer in indices_to_test:
assert_get_eq(reference, indexer)
# TODO setitem
'''
assert_set_eq(reference, indexer, 212)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
'''
assert_backward_eq(reference, indexer)
# reference = torch.arange(0., 1296).view(3, 9, 8, 6)
reference = Tensor.arange(0., 1296).reshape(3, 9, 8, 6)
# indices_to_test = [
# [slice(None), slice(None), slice(None), [0, 3, 4]],
# [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
# [slice(None), [2, 3], slice(None), slice(None)],
# [[1, 2], slice(None), slice(None), slice(None)],
# [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
# [slice(None), slice(None), [0], [1, 2, 4]],
# [slice(None), slice(None), [0, 1, 3], [4]],
# [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
# [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
# [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
# [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
# [slice(None), [0], [1, 2, 4], slice(None)],
# [slice(None), [0, 1, 3], [4], slice(None)],
# [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
# [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
# [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
# [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
# [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
# [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
# [[0], [1, 2, 4], slice(None), slice(None)],
# [[0, 1, 2], [4], slice(None), slice(None)],
# [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
# [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
# [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
# [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
# [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
# [slice(None), [2, 3, 4], [1, 3, 4], [4]],
# [slice(None), [0, 1, 3], [4], [1, 3, 4]],
# [slice(None), [6], [0, 2, 3], [1, 3, 4]],
# [slice(None), [2, 3, 5], [3], [4]],
# [slice(None), [0], [4], [1, 3, 4]],
# [slice(None), [6], [0, 2, 3], [1]],
# [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
# [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
# [[2, 0, 1], [1, 2, 3], [4], slice(None)],
# [[0, 1, 2], [4], [1, 3, 4], slice(None)],
# [[0], [0, 2, 3], [1, 3, 4], slice(None)],
# [[0, 2, 1], [3], [4], slice(None)],
# [[0], [4], [1, 3, 4], slice(None)],
# [[1], [0, 2, 3], [1], slice(None)],
# [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
indices_to_test = [
[slice(None), slice(None), slice(None), [0, 3, 4]],
[slice(None), slice(None), [2, 4, 5, 7], slice(None)],
[slice(None), [2, 3], slice(None), slice(None)],
[[1, 2], slice(None), slice(None), slice(None)],
[slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
[slice(None), slice(None), [0], [1, 2, 4]],
[slice(None), slice(None), [0, 1, 3], [4]],
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
[slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
[slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
[slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
[slice(None), [0], [1, 2, 4], slice(None)],
[slice(None), [0, 1, 3], [4], slice(None)],
[slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
[slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
[slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
[slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
[slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
[[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
[[0], [1, 2, 4], slice(None), slice(None)],
[[0, 1, 2], [4], slice(None), slice(None)],
[[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
[[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
[[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
[[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
[slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
[slice(None), [2, 3, 4], [1, 3, 4], [4]],
[slice(None), [0, 1, 3], [4], [1, 3, 4]],
[slice(None), [6], [0, 2, 3], [1, 3, 4]],
[slice(None), [2, 3, 5], [3], [4]],
[slice(None), [0], [4], [1, 3, 4]],
[slice(None), [6], [0, 2, 3], [1]],
[slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
[[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
[[2, 0, 1], [1, 2, 3], [4], slice(None)],
[[0, 1, 2], [4], [1, 3, 4], slice(None)],
[[0], [0, 2, 3], [1, 3, 4], slice(None)],
[[0, 2, 1], [3], [4], slice(None)],
[[0], [4], [1, 3, 4], slice(None)],
[[1], [0, 2, 3], [1], slice(None)],
[[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
# # less dim, ellipsis
# [Ellipsis, [0, 3, 4]],
# [Ellipsis, slice(None), [0, 3, 4]],
# [Ellipsis, slice(None), slice(None), [0, 3, 4]],
# [slice(None), Ellipsis, [0, 3, 4]],
# [slice(None), slice(None), Ellipsis, [0, 3, 4]],
# [slice(None), [0, 2, 3], [1, 3, 4]],
# [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
# [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
# [[0], [1, 2, 4]],
# [[0], [1, 2, 4], slice(None)],
# [[0], [1, 2, 4], Ellipsis],
# [[0], [1, 2, 4], Ellipsis, slice(None)],
# [[1], ],
# [[0, 2, 1], [3], [4]],
# [[0, 2, 1], [3], [4], slice(None)],
# [[0, 2, 1], [3], [4], Ellipsis],
# [Ellipsis, [0, 2, 1], [3], [4]],
# ]
# less dim, ellipsis
[Ellipsis, [0, 3, 4]],
[Ellipsis, slice(None), [0, 3, 4]],
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
[slice(None), Ellipsis, [0, 3, 4]],
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
[slice(None), [0, 2, 3], [1, 3, 4]],
[slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
[Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
[[0], [1, 2, 4]],
[[0], [1, 2, 4], slice(None)],
[[0], [1, 2, 4], Ellipsis],
[[0], [1, 2, 4], Ellipsis, slice(None)],
[[1], ],
[[0, 2, 1], [3], [4]],
[[0, 2, 1], [3], [4], slice(None)],
[[0, 2, 1], [3], [4], Ellipsis],
[Ellipsis, [0, 2, 1], [3], [4]],
]
# for indexer in indices_to_test:
# assert_get_eq(reference, indexer)
# assert_set_eq(reference, indexer, 1333)
# assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
# indices_to_test += [
# [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
# [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
# ]
# for indexer in indices_to_test:
# assert_get_eq(reference, indexer)
# assert_set_eq(reference, indexer, 1333)
# if self.device_type != 'cpu':
# assert_backward_eq(reference, indexer)
# def test_advancedindex_big(self):
# reference = Tensor.arange(123344)
# numpy_testing_assert_equal_helper(reference[[0, 123, 44488, 68807, 123343],], np.array([0, 123, 44488, 68807, 123343]))
for indexer in indices_to_test:
assert_get_eq(reference, indexer)
# TODO setitem
'''
assert_set_eq(reference, indexer, 1333)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
'''
indices_to_test += [
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
[slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
]
for indexer in indices_to_test:
assert_get_eq(reference, indexer)
# TODO setitem
'''
assert_set_eq(reference, indexer, 1333)
'''
assert_backward_eq(reference, indexer)
# def test_set_item_to_scalar_tensor(self):
# m = random.randint(1, 10)