mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
split test_advancedindex
This commit is contained in:
@@ -180,474 +180,6 @@ class TestIndexing(unittest.TestCase):
|
||||
# def delitem(): del reference[0]
|
||||
# self.assertRaises(TypeError, delitem)
|
||||
|
||||
# TODO: LLVM is quite fast, why are other compiled backends slow?
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["CPU", "CL", "METAL", "NV", "AMD"], "slow")
|
||||
def test_advancedindex(self):
|
||||
# integer array indexing
|
||||
|
||||
# pick a random valid indexer type
|
||||
def ri(indices):
|
||||
choice = random.randint(0, 2)
|
||||
if choice == 0: return Tensor(indices)
|
||||
if choice == 1: return list(indices)
|
||||
return tuple(indices)
|
||||
|
||||
def validate_indexing(x):
|
||||
numpy_testing_assert_equal_helper(x[[0]], consec((1,)))
|
||||
numpy_testing_assert_equal_helper(x[ri([0]),], consec((1,)))
|
||||
numpy_testing_assert_equal_helper(x[ri([3]),], consec((1,), 4))
|
||||
numpy_testing_assert_equal_helper(x[[2, 3, 4]], consec((3,), 3))
|
||||
numpy_testing_assert_equal_helper(x[ri([2, 3, 4]),], consec((3,), 3))
|
||||
numpy_testing_assert_equal_helper(x[ri([0, 2, 4]),], np.array([1, 3, 5]))
|
||||
|
||||
def validate_setting(x):
|
||||
x[[0]] = -2
|
||||
numpy_testing_assert_equal_helper(x[[0]], np.array([-2]))
|
||||
x[[0]] = -1
|
||||
numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1]))
|
||||
x[[2, 3, 4]] = 4
|
||||
numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4]))
|
||||
x[ri([2, 3, 4]), ] = 3
|
||||
numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
|
||||
x[ri([0, 2, 4]), ] = Tensor([5, 4, 3])
|
||||
numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))
|
||||
|
||||
# Case 1: Purely Integer Array Indexing
|
||||
reference = consec((10,))
|
||||
validate_indexing(reference)
|
||||
# setting values
|
||||
validate_setting(reference)
|
||||
|
||||
# Tensor with stride != 1
|
||||
# strided is [1, 3, 5, 7]
|
||||
|
||||
# # TODO: set stride
|
||||
# reference = consec((10,))
|
||||
# strided = set_(reference, (4,), (2,), 0)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[[0]], np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7]))
|
||||
# numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ],
|
||||
# np.array([[5, 3], [1, 7]]))
|
||||
|
||||
# stride is [4, 8]
|
||||
|
||||
# strided = set_(reference, (2,), (4,), offset=4)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[[0]], np.array([5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9]))
|
||||
# numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ],
|
||||
# np.array([[5, 9], [9, 5]]))
|
||||
|
||||
# reference is 1 2
|
||||
# 3 4
|
||||
# 5 6
|
||||
reference = consec((3, 2))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([1, 3, 5]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], consec((1,)))
|
||||
numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], consec((1,), 6))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = [0],
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1],
|
||||
[3, 5]]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = ri([1, 0])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[2, 1],
|
||||
[4, 5]]))
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = ri([[0, 1],
|
||||
[1, 0]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 2],
|
||||
[4, 5]]))
|
||||
|
||||
# setting values
|
||||
reference[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
|
||||
reference[ri([0, 1, 2]), ri([0])] = Tensor([-1, 2, -4])
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
|
||||
np.array([-1, 2, -4]))
|
||||
reference[rows, columns] = Tensor([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns],
|
||||
np.array([[4, 6], [2, 3]]))
|
||||
|
||||
# Verify still works with Transposed (i.e. non-contiguous) Tensors
|
||||
reference = Tensor([[0, 1, 2, 3],
|
||||
[4, 5, 6, 7],
|
||||
[8, 9, 10, 11]]).T
|
||||
|
||||
# Transposed: [[0, 4, 8],
|
||||
# [1, 5, 9],
|
||||
# [2, 6, 10],
|
||||
# [3, 7, 11]]
|
||||
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([0, 1, 2]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], np.array([0]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], np.array([6]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = [0],
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = ri([1, 0])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[4, 0], [5, 2]]))
|
||||
rows = ri([[0, 0],
|
||||
[1, 3]])
|
||||
columns = ri([[0, 1],
|
||||
[1, 2]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 4], [5, 11]]))
|
||||
|
||||
# TODO: non contiguous setitem
|
||||
'''
|
||||
# setting values
|
||||
reference[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
|
||||
np.array([-1]))
|
||||
reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
|
||||
np.array([-1, 2, -4]))
|
||||
reference[rows, columns] = np.array([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns],
|
||||
np.array([[4, 6], [2, 3]]))
|
||||
'''
|
||||
|
||||
# stride != 1
|
||||
|
||||
# strided is [[1 3 5 7],
|
||||
# [9 11 13 15]]
|
||||
|
||||
# # TODO: set stride
|
||||
# reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
# strided = set_(reference, (2,4), (8,2), 1)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])], np.array([1, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])], np.array([3, 11]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])], np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])], np.array([15]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]], np.array([1, 7]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]], np.array([9, 11, 11, 9, 15]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 3, 9, 9]))
|
||||
|
||||
# rows = ri([[0, 0],
|
||||
# [1, 1]])
|
||||
# columns = [0],
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns], np.array([[1, 1], [9, 9]]))
|
||||
|
||||
# rows = ri([[0, 1],
|
||||
# [1, 0]])
|
||||
# columns = ri([1, 2])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns], np.array([[3, 13], [11, 5]]))
|
||||
# rows = ri([[0, 0],
|
||||
# [1, 1]])
|
||||
# columns = ri([[0, 1],
|
||||
# [1, 2]])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns], np.array([[1, 3], [11, 13]]))
|
||||
|
||||
# setting values
|
||||
|
||||
# strided is [[10, 11],
|
||||
# [17, 18]]
|
||||
|
||||
# # TODO: set stride
|
||||
# reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
# strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])], np.array([11]))
|
||||
|
||||
# TODO non contiguous setitem
|
||||
'''
|
||||
strided[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
|
||||
Tensor([-1]))
|
||||
'''
|
||||
# # TODO: set stride
|
||||
# reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
# strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])], np.array([11, 17]))
|
||||
|
||||
# TODO non contiguous setitem
|
||||
'''
|
||||
strided[ri([0, 1]), ri([1, 0])] = Tensor([-1, 2])
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
|
||||
Tensor([-1, 2]))
|
||||
'''
|
||||
|
||||
# # TODO: set stride
|
||||
# reference = Tensor.arange(0., 24).realize().reshape(3, 8)
|
||||
# strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
# rows = ri([[0],
|
||||
# [1]])
|
||||
# columns = ri([[0, 1],
|
||||
# [0, 1]])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns], np.array([[10, 11], [17, 18]]))
|
||||
|
||||
# TODO non contiguous setitem
|
||||
'''
|
||||
strided[rows, columns] = Tensor([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
Tensor([[4, 6], [2, 3]]))
|
||||
'''
|
||||
|
||||
# Tests using less than the number of dims, and ellipsis
|
||||
|
||||
# reference is 1 2
|
||||
# 3 4
|
||||
# 5 6
|
||||
reference = consec((3, 2))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 2]),], np.array([[1, 2], [5, 6]]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([1]), ...], np.array([[3, 4]]))
|
||||
numpy_testing_assert_equal_helper(reference[..., ri([1])], np.array([[2], [4], [6]]))
|
||||
|
||||
# verify too many indices fails
|
||||
with self.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])]
|
||||
|
||||
# test invalid index fails
|
||||
reference = Tensor.empty(10)
|
||||
for err_idx in (10, -11):
|
||||
with self.assertRaises(IndexError):
|
||||
reference[err_idx]
|
||||
# NOTE cannot check for out of bounds with Tensor indexing
|
||||
# see tensor.py: __getitem__ (Tiny Things)
|
||||
'''
|
||||
with self.assertRaises(IndexError):
|
||||
reference[Tensor([err_idx], dtype=dtypes.int64)]
|
||||
with self.assertRaises(IndexError):
|
||||
reference[[err_idx]]
|
||||
'''
|
||||
|
||||
def tensor_indices_to_np(tensor: Tensor, indices):
|
||||
npt = tensor.numpy()
|
||||
idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype == dtypes.int64 else
|
||||
i for i in indices)
|
||||
return npt, idxs
|
||||
|
||||
def get_numpy(tensor, indices):
|
||||
npt, idxs = tensor_indices_to_np(tensor, indices)
|
||||
return Tensor(npt[idxs])
|
||||
|
||||
def set_numpy(tensor:Tensor, indices, value):
|
||||
if not isinstance(value, int):
|
||||
value = value.numpy()
|
||||
npt, idxs = tensor_indices_to_np(tensor, indices)
|
||||
npt[idxs] = value
|
||||
return npt
|
||||
|
||||
def assert_get_eq(tensor, indexer):
|
||||
numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer))
|
||||
|
||||
def assert_set_eq(tensor: Tensor, indexer, val):
|
||||
pyt = clone(tensor)
|
||||
numt = clone(tensor)
|
||||
pyt[indexer] = val
|
||||
numt = set_numpy(numt, indexer, val)
|
||||
numpy_testing_assert_equal_helper(pyt, numt)
|
||||
|
||||
# NOTE: torch initiates the gradients using g0cpu (rand as gradients)
|
||||
def assert_backward_eq(tensor: Tensor, indexer):
|
||||
cpu = clone(tensor.float())
|
||||
cpu.requires_grad = True
|
||||
outcpu = cpu[indexer].sum()
|
||||
outcpu.backward()
|
||||
dev = cpu.detach()
|
||||
dev.requires_grad = True
|
||||
outdev = dev[indexer].sum()
|
||||
outdev.backward()
|
||||
numpy_testing_assert_equal_helper(cpu.grad, dev.grad)
|
||||
|
||||
def get_set_tensor(indexed: Tensor, indexer):
|
||||
set_size = indexed[indexer].shape
|
||||
set_count = indexed[indexer].numel()
|
||||
set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size) #.cast(dtypes.float64)
|
||||
return set_tensor
|
||||
|
||||
# Tensor is 0 1 2 3 4
|
||||
# 5 6 7 8 9
|
||||
# 10 11 12 13 14
|
||||
# 15 16 17 18 19
|
||||
reference = Tensor.arange(0., 20).reshape(4, 5)
|
||||
|
||||
indices_to_test = [
|
||||
# grab the second, fourth columns
|
||||
[slice(None), [1, 3]],
|
||||
|
||||
# first, third rows,
|
||||
[[0, 2], slice(None)],
|
||||
|
||||
# weird shape
|
||||
[slice(None), [[0, 1],
|
||||
[2, 3]]],
|
||||
# negatives
|
||||
[[-1], [0]],
|
||||
[[0, 2], [-1]],
|
||||
[slice(None), [-1]],
|
||||
]
|
||||
|
||||
# only test dupes on gets
|
||||
get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
|
||||
|
||||
for indexer in get_indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
for indexer in indices_to_test:
|
||||
assert_set_eq(reference, indexer, 44)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
|
||||
reference = Tensor.arange(0., 160).reshape(4, 8, 5)
|
||||
|
||||
indices_to_test = [
|
||||
[slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), [2, 4, 5, 7], slice(None)],
|
||||
[[2, 3], slice(None), slice(None)],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [0], [1, 2, 4]],
|
||||
[slice(None), [0, 1, 3], [4]],
|
||||
[slice(None), [[0, 1], [1, 0]], [[2, 3]]],
|
||||
[slice(None), [[0, 1], [2, 3]], [[0]]],
|
||||
[slice(None), [[5, 6]], [[0, 3], [4, 4]]],
|
||||
[[0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0], [1, 2, 4], slice(None)],
|
||||
[[0, 1, 3], [4], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 3]], slice(None)],
|
||||
[[[0, 1], [2, 3]], [[0]], slice(None)],
|
||||
[[[2, 1]], [[0, 3], [4, 4]], slice(None)],
|
||||
[[[2]], [[0, 3], [4, 1]], slice(None)],
|
||||
# non-contiguous indexing subspace
|
||||
[[0, 2, 3], slice(None), [1, 3, 4]],
|
||||
|
||||
# less dim, ellipsis
|
||||
[[0, 2], ],
|
||||
[[0, 2], slice(None)],
|
||||
[[0, 2], Ellipsis],
|
||||
[[0, 2], slice(None), Ellipsis],
|
||||
[[0, 2], Ellipsis, slice(None)],
|
||||
[[0, 2], [1, 3]],
|
||||
[[0, 2], [1, 3], Ellipsis],
|
||||
[Ellipsis, [1, 3], [2, 3]],
|
||||
[Ellipsis, [2, 3, 4]],
|
||||
[Ellipsis, slice(None), [2, 3, 4]],
|
||||
[slice(None), Ellipsis, [2, 3, 4]],
|
||||
|
||||
# ellipsis counts for nothing
|
||||
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), Ellipsis, slice(None), [0, 3, 4]],
|
||||
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), slice(None), [0, 3, 4], Ellipsis],
|
||||
[Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
|
||||
]
|
||||
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
|
||||
assert_set_eq(reference, indexer, 212)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
reference = Tensor.arange(0., 1296).reshape(3, 9, 8, 6)
|
||||
|
||||
indices_to_test = [
|
||||
[slice(None), slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), slice(None), [2, 4, 5, 7], slice(None)],
|
||||
[slice(None), [2, 3], slice(None), slice(None)],
|
||||
[[1, 2], slice(None), slice(None), slice(None)],
|
||||
[slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), slice(None), [0], [1, 2, 4]],
|
||||
[slice(None), slice(None), [0, 1, 3], [4]],
|
||||
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
|
||||
[slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
|
||||
[slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[slice(None), [0], [1, 2, 4], slice(None)],
|
||||
[slice(None), [0, 1, 3], [4], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
|
||||
[slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
|
||||
[slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
|
||||
[[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
|
||||
[[0], [1, 2, 4], slice(None), slice(None)],
|
||||
[[0, 1, 2], [4], slice(None), slice(None)],
|
||||
[[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
|
||||
[[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
|
||||
[[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
|
||||
[[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
|
||||
[slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [2, 3, 4], [1, 3, 4], [4]],
|
||||
[slice(None), [0, 1, 3], [4], [1, 3, 4]],
|
||||
[slice(None), [6], [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [2, 3, 5], [3], [4]],
|
||||
[slice(None), [0], [4], [1, 3, 4]],
|
||||
[slice(None), [6], [0, 2, 3], [1]],
|
||||
[slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
|
||||
[[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[2, 0, 1], [1, 2, 3], [4], slice(None)],
|
||||
[[0, 1, 2], [4], [1, 3, 4], slice(None)],
|
||||
[[0], [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0, 2, 1], [3], [4], slice(None)],
|
||||
[[0], [4], [1, 3, 4], slice(None)],
|
||||
[[1], [0, 2, 3], [1], slice(None)],
|
||||
[[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
|
||||
|
||||
# less dim, ellipsis
|
||||
[Ellipsis, [0, 3, 4]],
|
||||
[Ellipsis, slice(None), [0, 3, 4]],
|
||||
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
|
||||
[Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0], [1, 2, 4]],
|
||||
[[0], [1, 2, 4], slice(None)],
|
||||
[[0], [1, 2, 4], Ellipsis],
|
||||
[[0], [1, 2, 4], Ellipsis, slice(None)],
|
||||
[[1], ],
|
||||
[[0, 2, 1], [3], [4]],
|
||||
[[0, 2, 1], [3], [4], slice(None)],
|
||||
[[0, 2, 1], [3], [4], Ellipsis],
|
||||
[Ellipsis, [0, 2, 1], [3], [4]],
|
||||
]
|
||||
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
assert_set_eq(reference, indexer, 1333)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
|
||||
indices_to_test += [
|
||||
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
|
||||
[slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
|
||||
]
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
assert_set_eq(reference, indexer, 1333)
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
# TODO setitem backward
|
||||
'''
|
||||
def test_set_item_to_scalar_tensor(self):
|
||||
@@ -1568,5 +1100,474 @@ class TestNumpy(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(kernel, kernel2)
|
||||
'''
|
||||
|
||||
def tensor_indices_to_np(tensor: Tensor, indices):
|
||||
npt = tensor.numpy()
|
||||
idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype == dtypes.int64 else
|
||||
i for i in indices)
|
||||
return npt, idxs
|
||||
|
||||
def get_numpy(tensor, indices):
|
||||
npt, idxs = tensor_indices_to_np(tensor, indices)
|
||||
return Tensor(npt[idxs])
|
||||
|
||||
def set_numpy(tensor:Tensor, indices, value):
|
||||
if not isinstance(value, int):
|
||||
value = value.numpy()
|
||||
npt, idxs = tensor_indices_to_np(tensor, indices)
|
||||
npt[idxs] = value
|
||||
return npt
|
||||
|
||||
def assert_get_eq(tensor, indexer):
|
||||
numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer))
|
||||
|
||||
def assert_set_eq(tensor: Tensor, indexer, val):
|
||||
pyt = clone(tensor)
|
||||
numt = clone(tensor)
|
||||
pyt[indexer] = val
|
||||
numt = set_numpy(numt, indexer, val)
|
||||
numpy_testing_assert_equal_helper(pyt, numt)
|
||||
|
||||
# NOTE: torch initiates the gradients using g0cpu (rand as gradients)
|
||||
def assert_backward_eq(tensor: Tensor, indexer):
|
||||
cpu = clone(tensor.float())
|
||||
cpu.requires_grad = True
|
||||
outcpu = cpu[indexer].sum()
|
||||
outcpu.backward()
|
||||
dev = cpu.detach()
|
||||
dev.requires_grad = True
|
||||
outdev = dev[indexer].sum()
|
||||
outdev.backward()
|
||||
numpy_testing_assert_equal_helper(cpu.grad, dev.grad)
|
||||
|
||||
def get_set_tensor(indexed: Tensor, indexer):
|
||||
set_size = indexed[indexer].shape
|
||||
set_count = indexed[indexer].numel()
|
||||
set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size) #.cast(dtypes.float64)
|
||||
return set_tensor
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["CPU", "CL", "METAL", "NV", "AMD"], "slow")
|
||||
class TestAdvancedIndexing(unittest.TestCase):
|
||||
def test_integer_array_indexing(self):
|
||||
# pick a random valid indexer type
|
||||
def ri(indices):
|
||||
choice = random.randint(0, 2)
|
||||
if choice == 0: return Tensor(indices)
|
||||
if choice == 1: return list(indices)
|
||||
return tuple(indices)
|
||||
|
||||
def validate_indexing(x):
|
||||
numpy_testing_assert_equal_helper(x[[0]], consec((1,)))
|
||||
numpy_testing_assert_equal_helper(x[ri([0]),], consec((1,)))
|
||||
numpy_testing_assert_equal_helper(x[ri([3]),], consec((1,), 4))
|
||||
numpy_testing_assert_equal_helper(x[[2, 3, 4]], consec((3,), 3))
|
||||
numpy_testing_assert_equal_helper(x[ri([2, 3, 4]),], consec((3,), 3))
|
||||
numpy_testing_assert_equal_helper(x[ri([0, 2, 4]),], np.array([1, 3, 5]))
|
||||
|
||||
def validate_setting(x):
|
||||
x[[0]] = -2
|
||||
numpy_testing_assert_equal_helper(x[[0]], np.array([-2]))
|
||||
x[[0]] = -1
|
||||
numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1]))
|
||||
x[[2, 3, 4]] = 4
|
||||
numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4]))
|
||||
x[ri([2, 3, 4]), ] = 3
|
||||
numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
|
||||
x[ri([0, 2, 4]), ] = Tensor([5, 4, 3])
|
||||
numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))
|
||||
|
||||
# Case 1: Purely Integer Array Indexing
|
||||
reference = consec((10,))
|
||||
validate_indexing(reference)
|
||||
# setting values
|
||||
validate_setting(reference)
|
||||
|
||||
# Tensor with stride != 1
|
||||
# strided is [1, 3, 5, 7]
|
||||
|
||||
# # TODO: set stride
|
||||
# reference = consec((10,))
|
||||
# strided = set_(reference, (4,), (2,), 0)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[[0]], np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7]))
|
||||
# numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ],
|
||||
# np.array([[5, 3], [1, 7]]))
|
||||
|
||||
# stride is [4, 8]
|
||||
|
||||
# strided = set_(reference, (2,), (4,), offset=4)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[[0]], np.array([5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9]))
|
||||
# numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ],
|
||||
# np.array([[5, 9], [9, 5]]))
|
||||
|
||||
# reference is 1 2
|
||||
# 3 4
|
||||
# 5 6
|
||||
reference = consec((3, 2))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([1, 3, 5]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], consec((1,)))
|
||||
numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], consec((1,), 6))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = [0],
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1],
|
||||
[3, 5]]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = ri([1, 0])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[2, 1],
|
||||
[4, 5]]))
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = ri([[0, 1],
|
||||
[1, 0]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 2],
|
||||
[4, 5]]))
|
||||
|
||||
# setting values
|
||||
reference[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
|
||||
reference[ri([0, 1, 2]), ri([0])] = Tensor([-1, 2, -4])
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
|
||||
np.array([-1, 2, -4]))
|
||||
reference[rows, columns] = Tensor([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns],
|
||||
np.array([[4, 6], [2, 3]]))
|
||||
|
||||
# Verify still works with Transposed (i.e. non-contiguous) Tensors
|
||||
reference = Tensor([[0, 1, 2, 3],
|
||||
[4, 5, 6, 7],
|
||||
[8, 9, 10, 11]]).T
|
||||
|
||||
# Transposed: [[0, 4, 8],
|
||||
# [1, 5, 9],
|
||||
# [2, 6, 10],
|
||||
# [3, 7, 11]]
|
||||
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([0, 1, 2]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], np.array([0]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], np.array([6]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = [0],
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = ri([1, 0])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[4, 0], [5, 2]]))
|
||||
rows = ri([[0, 0],
|
||||
[1, 3]])
|
||||
columns = ri([[0, 1],
|
||||
[1, 2]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 4], [5, 11]]))
|
||||
|
||||
# TODO: non contiguous setitem
|
||||
'''
|
||||
# setting values
|
||||
reference[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
|
||||
np.array([-1]))
|
||||
reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
|
||||
np.array([-1, 2, -4]))
|
||||
reference[rows, columns] = np.array([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns],
|
||||
np.array([[4, 6], [2, 3]]))
|
||||
'''
|
||||
|
||||
# stride != 1
|
||||
|
||||
# strided is [[1 3 5 7],
|
||||
# [9 11 13 15]]
|
||||
|
||||
# # TODO: set stride
|
||||
# reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
# strided = set_(reference, (2,4), (8,2), 1)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])], np.array([1, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])], np.array([3, 11]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])], np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])], np.array([15]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]], np.array([1, 7]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]], np.array([9, 11, 11, 9, 15]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 3, 9, 9]))
|
||||
|
||||
# rows = ri([[0, 0],
|
||||
# [1, 1]])
|
||||
# columns = [0],
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns], np.array([[1, 1], [9, 9]]))
|
||||
|
||||
# rows = ri([[0, 1],
|
||||
# [1, 0]])
|
||||
# columns = ri([1, 2])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns], np.array([[3, 13], [11, 5]]))
|
||||
# rows = ri([[0, 0],
|
||||
# [1, 1]])
|
||||
# columns = ri([[0, 1],
|
||||
# [1, 2]])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns], np.array([[1, 3], [11, 13]]))
|
||||
|
||||
# setting values
|
||||
|
||||
# strided is [[10, 11],
|
||||
# [17, 18]]
|
||||
|
||||
# # TODO: set stride
|
||||
# reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
# strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])], np.array([11]))
|
||||
|
||||
# TODO non contiguous setitem
|
||||
'''
|
||||
strided[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
|
||||
Tensor([-1]))
|
||||
'''
|
||||
# # TODO: set stride
|
||||
# reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
# strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])], np.array([11, 17]))
|
||||
|
||||
# TODO non contiguous setitem
|
||||
'''
|
||||
strided[ri([0, 1]), ri([1, 0])] = Tensor([-1, 2])
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
|
||||
Tensor([-1, 2]))
|
||||
'''
|
||||
|
||||
# # TODO: set stride
|
||||
# reference = Tensor.arange(0., 24).realize().reshape(3, 8)
|
||||
# strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
# rows = ri([[0],
|
||||
# [1]])
|
||||
# columns = ri([[0, 1],
|
||||
# [0, 1]])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns], np.array([[10, 11], [17, 18]]))
|
||||
|
||||
# TODO non contiguous setitem
|
||||
'''
|
||||
strided[rows, columns] = Tensor([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
Tensor([[4, 6], [2, 3]]))
|
||||
'''
|
||||
|
||||
# Tests using less than the number of dims, and ellipsis
|
||||
|
||||
# reference is 1 2
|
||||
# 3 4
|
||||
# 5 6
|
||||
reference = consec((3, 2))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 2]),], np.array([[1, 2], [5, 6]]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([1]), ...], np.array([[3, 4]]))
|
||||
numpy_testing_assert_equal_helper(reference[..., ri([1])], np.array([[2], [4], [6]]))
|
||||
|
||||
# verify too many indices fails
|
||||
with self.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])]
|
||||
|
||||
# test invalid index fails
|
||||
reference = Tensor.empty(10)
|
||||
for err_idx in (10, -11):
|
||||
with self.assertRaises(IndexError):
|
||||
reference[err_idx]
|
||||
# NOTE cannot check for out of bounds with Tensor indexing
|
||||
# see tensor.py: __getitem__ (Tiny Things)
|
||||
'''
|
||||
with self.assertRaises(IndexError):
|
||||
reference[Tensor([err_idx], dtype=dtypes.int64)]
|
||||
with self.assertRaises(IndexError):
|
||||
reference[[err_idx]]
|
||||
'''
|
||||
|
||||
def test_numpy_parity_and_backward_2d(self):
|
||||
# Tensor is 0 1 2 3 4
|
||||
# 5 6 7 8 9
|
||||
# 10 11 12 13 14
|
||||
# 15 16 17 18 19
|
||||
reference = Tensor.arange(0., 20).reshape(4, 5)
|
||||
|
||||
indices_to_test = [
|
||||
# grab the second, fourth columns
|
||||
[slice(None), [1, 3]],
|
||||
|
||||
# first, third rows,
|
||||
[[0, 2], slice(None)],
|
||||
|
||||
# weird shape
|
||||
[slice(None), [[0, 1],
|
||||
[2, 3]]],
|
||||
# negatives
|
||||
[[-1], [0]],
|
||||
[[0, 2], [-1]],
|
||||
[slice(None), [-1]],
|
||||
]
|
||||
|
||||
# only test dupes on gets
|
||||
get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
|
||||
|
||||
for indexer in get_indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
for indexer in indices_to_test:
|
||||
assert_set_eq(reference, indexer, 44)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
|
||||
def test_numpy_parity_and_backward_3d(self):
|
||||
reference = Tensor.arange(0., 160).reshape(4, 8, 5)
|
||||
|
||||
indices_to_test = [
|
||||
[slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), [2, 4, 5, 7], slice(None)],
|
||||
[[2, 3], slice(None), slice(None)],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [0], [1, 2, 4]],
|
||||
[slice(None), [0, 1, 3], [4]],
|
||||
[slice(None), [[0, 1], [1, 0]], [[2, 3]]],
|
||||
[slice(None), [[0, 1], [2, 3]], [[0]]],
|
||||
[slice(None), [[5, 6]], [[0, 3], [4, 4]]],
|
||||
[[0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0], [1, 2, 4], slice(None)],
|
||||
[[0, 1, 3], [4], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 3]], slice(None)],
|
||||
[[[0, 1], [2, 3]], [[0]], slice(None)],
|
||||
[[[2, 1]], [[0, 3], [4, 4]], slice(None)],
|
||||
[[[2]], [[0, 3], [4, 1]], slice(None)],
|
||||
# non-contiguous indexing subspace
|
||||
[[0, 2, 3], slice(None), [1, 3, 4]],
|
||||
|
||||
# less dim, ellipsis
|
||||
[[0, 2], ],
|
||||
[[0, 2], slice(None)],
|
||||
[[0, 2], Ellipsis],
|
||||
[[0, 2], slice(None), Ellipsis],
|
||||
[[0, 2], Ellipsis, slice(None)],
|
||||
[[0, 2], [1, 3]],
|
||||
[[0, 2], [1, 3], Ellipsis],
|
||||
[Ellipsis, [1, 3], [2, 3]],
|
||||
[Ellipsis, [2, 3, 4]],
|
||||
[Ellipsis, slice(None), [2, 3, 4]],
|
||||
[slice(None), Ellipsis, [2, 3, 4]],
|
||||
|
||||
# ellipsis counts for nothing
|
||||
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), Ellipsis, slice(None), [0, 3, 4]],
|
||||
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), slice(None), [0, 3, 4], Ellipsis],
|
||||
[Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
|
||||
]
|
||||
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
|
||||
assert_set_eq(reference, indexer, 212)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
def test_numpy_parity_and_backward_4d(self):
|
||||
reference = Tensor.arange(0., 1296).reshape(3, 9, 8, 6)
|
||||
|
||||
indices_to_test = [
|
||||
[slice(None), slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), slice(None), [2, 4, 5, 7], slice(None)],
|
||||
[slice(None), [2, 3], slice(None), slice(None)],
|
||||
[[1, 2], slice(None), slice(None), slice(None)],
|
||||
[slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), slice(None), [0], [1, 2, 4]],
|
||||
[slice(None), slice(None), [0, 1, 3], [4]],
|
||||
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
|
||||
[slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
|
||||
[slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[slice(None), [0], [1, 2, 4], slice(None)],
|
||||
[slice(None), [0, 1, 3], [4], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
|
||||
[slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
|
||||
[slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
|
||||
[[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
|
||||
[[0], [1, 2, 4], slice(None), slice(None)],
|
||||
[[0, 1, 2], [4], slice(None), slice(None)],
|
||||
[[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
|
||||
[[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
|
||||
[[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
|
||||
[[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
|
||||
[slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [2, 3, 4], [1, 3, 4], [4]],
|
||||
[slice(None), [0, 1, 3], [4], [1, 3, 4]],
|
||||
[slice(None), [6], [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [2, 3, 5], [3], [4]],
|
||||
[slice(None), [0], [4], [1, 3, 4]],
|
||||
[slice(None), [6], [0, 2, 3], [1]],
|
||||
[slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
|
||||
[[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[2, 0, 1], [1, 2, 3], [4], slice(None)],
|
||||
[[0, 1, 2], [4], [1, 3, 4], slice(None)],
|
||||
[[0], [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0, 2, 1], [3], [4], slice(None)],
|
||||
[[0], [4], [1, 3, 4], slice(None)],
|
||||
[[1], [0, 2, 3], [1], slice(None)],
|
||||
[[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
|
||||
|
||||
# less dim, ellipsis
|
||||
[Ellipsis, [0, 3, 4]],
|
||||
[Ellipsis, slice(None), [0, 3, 4]],
|
||||
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
|
||||
[Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0], [1, 2, 4]],
|
||||
[[0], [1, 2, 4], slice(None)],
|
||||
[[0], [1, 2, 4], Ellipsis],
|
||||
[[0], [1, 2, 4], Ellipsis, slice(None)],
|
||||
[[1], ],
|
||||
[[0, 2, 1], [3], [4]],
|
||||
[[0, 2, 1], [3], [4], slice(None)],
|
||||
[[0, 2, 1], [3], [4], Ellipsis],
|
||||
[Ellipsis, [0, 2, 1], [3], [4]],
|
||||
]
|
||||
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
assert_set_eq(reference, indexer, 1333)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
|
||||
indices_to_test += [
|
||||
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
|
||||
[slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
|
||||
]
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
assert_set_eq(reference, indexer, 1333)
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user