mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
enable test_index and test_advancedindex (#2648)
* enable test_index and test_advancedindex with pretty diff * removed contig * created set_ helper function * comment change * del empty line --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -4,6 +4,9 @@ import math, unittest, random
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
random.seed(42)
|
||||
|
||||
@@ -15,6 +18,15 @@ def numpy_testing_assert_equal_helper(a, b):
|
||||
def consec(shape, start=1):
|
||||
return Tensor(np.arange(math.prod(shape)).reshape(shape)+start)
|
||||
|
||||
# creates strided tensor with base set to reference tensor's base, equivalent to torch.set_()
|
||||
def set_(reference: Tensor, shape, strides, offset):
|
||||
if reference.lazydata.base.realized is None: reference.realize()
|
||||
assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
|
||||
strided = Tensor(LazyBuffer(device=reference.device, st=ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)), optype=None, op=None, dtype=reference.dtype, src=None, base=reference.lazydata.base))
|
||||
assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
|
||||
assert strided.lazydata in reference.lazydata.base.views, "base.views should contain strided.lazydata"
|
||||
return strided
|
||||
|
||||
def make_tensor(shape, dtype:dtypes, noncontiguous):
|
||||
r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
|
||||
values uniformly drawn from ``[low, high)``.
|
||||
@@ -39,7 +51,7 @@ def make_tensor(shape, dtype:dtypes, noncontiguous):
|
||||
+---------------------------+------------+----------+
|
||||
"""
|
||||
contiguous = not noncontiguous # lol
|
||||
if dtype is dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, dtype=dtype, contiguous=contiguous)
|
||||
if dtype is dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, dtype=dtypes.bool, contiguous=contiguous)
|
||||
elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, dtype=dtype, contiguous=contiguous)
|
||||
elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, dtype=dtype, contiguous=contiguous) # signed int
|
||||
elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
|
||||
@@ -118,20 +130,23 @@ class TestIndexing(unittest.TestCase):
|
||||
tensor_indexed = tensor[idx1]
|
||||
numpy_testing_assert_equal_helper(tensor_indexed, np.array(lst_indexed))
|
||||
|
||||
# self.assertRaises(ValueError, lambda: reference[1:9:0])
|
||||
self.assertRaises(ValueError, lambda: reference[1:9:0])
|
||||
# NOTE torch doesn't support this but numpy does so we should too. Torch raises ValueError
|
||||
# see test_slice_negative_strides in test_ops.py
|
||||
# self.assertRaises(ValueError, lambda: reference[1:9:-1])
|
||||
|
||||
# self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
|
||||
# self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
|
||||
# self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])
|
||||
self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
|
||||
self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
|
||||
self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])
|
||||
|
||||
# self.assertRaises(IndexError, lambda: reference[0.0])
|
||||
# self.assertRaises(TypeError, lambda: reference[0.0:2.0])
|
||||
# self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
|
||||
# self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
|
||||
# self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
|
||||
# self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])
|
||||
self.assertRaises(IndexError, lambda: reference[0.0])
|
||||
self.assertRaises(TypeError, lambda: reference[0.0:2.0])
|
||||
self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
|
||||
self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
|
||||
self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
|
||||
self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])
|
||||
|
||||
# TODO: delitem
|
||||
# def delitem(): del reference[0]
|
||||
# self.assertRaises(TypeError, delitem)
|
||||
|
||||
@@ -140,8 +155,7 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
# pick a random valid indexer type
|
||||
def ri(indices):
|
||||
choice = random.randint(0, 1)
|
||||
# TODO: we do not support tuple of list for index now
|
||||
choice = random.randint(0, 2)
|
||||
if choice == 0: return Tensor(indices)
|
||||
if choice == 1: return list(indices)
|
||||
return tuple(indices)
|
||||
@@ -156,17 +170,19 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
def validate_setting(x):
|
||||
pass
|
||||
# # TODO: we don't support setitem now
|
||||
# x[[0]] = -2
|
||||
# numpy_testing_assert_equal_helper(x[[0]], np.array([-2]))
|
||||
# x[[0]] = -1
|
||||
# numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1]))
|
||||
# x[[2, 3, 4]] = 4
|
||||
# numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4]))
|
||||
# x[ri([2, 3, 4]), ] = 3
|
||||
# numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
|
||||
# x[ri([0, 2, 4]), ] = np.array([5, 4, 3])
|
||||
# numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))
|
||||
# TODO: setitem
|
||||
'''
|
||||
x[[0]] = -2
|
||||
numpy_testing_assert_equal_helper(x[[0]], np.array([-2]))
|
||||
x[[0]] = -1
|
||||
numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1]))
|
||||
x[[2, 3, 4]] = 4
|
||||
numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4]))
|
||||
x[ri([2, 3, 4]), ] = 3
|
||||
numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
|
||||
x[ri([0, 2, 4]), ] = np.array([5, 4, 3])
|
||||
numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))
|
||||
'''
|
||||
|
||||
# Case 1: Purely Integer Array Indexing
|
||||
reference = consec((10,))
|
||||
@@ -175,32 +191,31 @@ class TestIndexing(unittest.TestCase):
|
||||
# setting values
|
||||
validate_setting(reference)
|
||||
|
||||
# # Tensor with stride != 1
|
||||
# # strided is [1, 3, 5, 7]
|
||||
# reference = consec((10,))
|
||||
# strided = np.array(())
|
||||
# strided.set_(reference.storage(), storage_offset=0,
|
||||
# size=torch.Size([4]), stride=[2])
|
||||
# Tensor with stride != 1
|
||||
# strided is [1, 3, 5, 7]
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[[0]], np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7]))
|
||||
# numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ],
|
||||
# np.array([[5, 3], [1, 7]]))
|
||||
reference = consec((10,))
|
||||
strided = set_(reference, (4,), (2,), 0)
|
||||
|
||||
# # stride is [4, 8]
|
||||
# strided = np.array(())
|
||||
# strided.set_(reference.storage(), storage_offset=4,
|
||||
# size=torch.Size([2]), stride=[4])
|
||||
# numpy_testing_assert_equal_helper(strided[[0]], np.array([5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9]))
|
||||
# numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ],
|
||||
# np.array([[5, 9], [9, 5]]))
|
||||
numpy_testing_assert_equal_helper(strided[[0]], np.array([1]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7]))
|
||||
numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ],
|
||||
np.array([[5, 3], [1, 7]]))
|
||||
|
||||
# stride is [4, 8]
|
||||
|
||||
strided = set_(reference, (2,), (4,), offset=4)
|
||||
|
||||
numpy_testing_assert_equal_helper(strided[[0]], np.array([5]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9]))
|
||||
numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ],
|
||||
np.array([[5, 9], [9, 5]]))
|
||||
|
||||
# reference is 1 2
|
||||
# 3 4
|
||||
@@ -210,16 +225,15 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], consec((1,)))
|
||||
numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], consec((1,), 6))
|
||||
# # TODO: we don't support list of Tensors as index
|
||||
# numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
|
||||
# numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
|
||||
# numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))
|
||||
|
||||
# rows = ri([[0, 0],
|
||||
# [1, 2]])
|
||||
# columns = [0],
|
||||
# numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1],
|
||||
# [3, 5]]))
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = [0],
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1],
|
||||
[3, 5]]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
@@ -233,17 +247,20 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 2],
|
||||
[4, 5]]))
|
||||
|
||||
# # setting values
|
||||
# reference[ri([0]), ri([1])] = -1
|
||||
# numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
|
||||
# reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
|
||||
# numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
|
||||
# np.array([-1, 2, -4]))
|
||||
# reference[rows, columns] = np.array([[4, 6], [2, 3]])
|
||||
# numpy_testing_assert_equal_helper(reference[rows, columns],
|
||||
# np.array([[4, 6], [2, 3]]))
|
||||
# TODO: setitem
|
||||
'''
|
||||
# setting values
|
||||
reference[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
|
||||
reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
|
||||
np.array([-1, 2, -4]))
|
||||
reference[rows, columns] = np.array([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns],
|
||||
np.array([[4, 6], [2, 3]]))
|
||||
'''
|
||||
|
||||
# Verify still works with Transposed (i.e. non-contiguous) Tensors
|
||||
# Verify still works with Transposed (i.e. non-contiguous) Tensors
|
||||
|
||||
reference = Tensor([[0, 1, 2, 3],
|
||||
[4, 5, 6, 7],
|
||||
@@ -258,15 +275,14 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], np.array([0]))
|
||||
numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], np.array([6]))
|
||||
# # TODO: we don't support list of Tensors as index
|
||||
# numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
|
||||
# numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
|
||||
# numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
|
||||
numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))
|
||||
|
||||
# rows = ri([[0, 0],
|
||||
# [1, 2]])
|
||||
# columns = [0],
|
||||
# numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]]))
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
columns = [0],
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]]))
|
||||
|
||||
rows = ri([[0, 0],
|
||||
[1, 2]])
|
||||
@@ -278,99 +294,106 @@ class TestIndexing(unittest.TestCase):
|
||||
[1, 2]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 4], [5, 11]]))
|
||||
|
||||
# # setting values
|
||||
# reference[ri([0]), ri([1])] = -1
|
||||
# numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
|
||||
# np.array([-1]))
|
||||
# reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
|
||||
# numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
|
||||
# np.array([-1, 2, -4]))
|
||||
# reference[rows, columns] = np.array([[4, 6], [2, 3]])
|
||||
# numpy_testing_assert_equal_helper(reference[rows, columns],
|
||||
# np.array([[4, 6], [2, 3]]))
|
||||
# TODO: setitem
|
||||
'''
|
||||
# setting values
|
||||
reference[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
|
||||
np.array([-1]))
|
||||
reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
|
||||
numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
|
||||
np.array([-1, 2, -4]))
|
||||
reference[rows, columns] = np.array([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(reference[rows, columns],
|
||||
np.array([[4, 6], [2, 3]]))
|
||||
'''
|
||||
|
||||
# # stride != 1
|
||||
# stride != 1
|
||||
|
||||
# # strided is [[1 3 5 7],
|
||||
# # [9 11 13 15]]
|
||||
# strided is [[1 3 5 7],
|
||||
# [9 11 13 15]]
|
||||
|
||||
# reference = torch.arange(0., 24).view(3, 8)
|
||||
# strided = np.array(())
|
||||
# strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
|
||||
# stride=[8, 2])
|
||||
reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
strided = set_(reference, (2,4), (8,2), 1)
|
||||
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])],
|
||||
# np.array([1, 9]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])],
|
||||
# np.array([3, 11]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])],
|
||||
# np.array([1]))
|
||||
# numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])],
|
||||
# np.array([15]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]],
|
||||
# np.array([1, 7]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
|
||||
# np.array([9, 11, 11, 9, 15]))
|
||||
# numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
|
||||
# np.array([1, 3, 9, 9]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])],
|
||||
np.array([1, 9]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])],
|
||||
np.array([3, 11]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])],
|
||||
np.array([1]))
|
||||
numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])],
|
||||
np.array([15]))
|
||||
numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]],
|
||||
np.array([1, 7]))
|
||||
numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
|
||||
np.array([9, 11, 11, 9, 15]))
|
||||
numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
|
||||
np.array([1, 3, 9, 9]))
|
||||
|
||||
# rows = ri([[0, 0],
|
||||
# [1, 1]])
|
||||
# columns = [0],
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
# np.array([[1, 1], [9, 9]]))
|
||||
rows = ri([[0, 0],
|
||||
[1, 1]])
|
||||
columns = [0],
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
np.array([[1, 1], [9, 9]]))
|
||||
|
||||
# rows = ri([[0, 1],
|
||||
# [1, 0]])
|
||||
# columns = ri([1, 2])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
# np.array([[3, 13], [11, 5]]))
|
||||
# rows = ri([[0, 0],
|
||||
# [1, 1]])
|
||||
# columns = ri([[0, 1],
|
||||
# [1, 2]])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
# np.array([[1, 3], [11, 13]]))
|
||||
rows = ri([[0, 1],
|
||||
[1, 0]])
|
||||
columns = ri([1, 2])
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
np.array([[3, 13], [11, 5]]))
|
||||
rows = ri([[0, 0],
|
||||
[1, 1]])
|
||||
columns = ri([[0, 1],
|
||||
[1, 2]])
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
np.array([[1, 3], [11, 13]]))
|
||||
|
||||
# # setting values
|
||||
|
||||
# # strided is [[10, 11],
|
||||
# # [17, 18]]
|
||||
# setting values
|
||||
|
||||
# reference = torch.arange(0., 24).view(3, 8)
|
||||
# strided = np.array(())
|
||||
# strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
|
||||
# stride=[7, 1])
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
|
||||
# np.array([11]))
|
||||
# strided[ri([0]), ri([1])] = -1
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
|
||||
# np.array([-1]))
|
||||
# strided is [[10, 11],
|
||||
# [17, 18]]
|
||||
|
||||
# reference = torch.arange(0., 24).view(3, 8)
|
||||
# strided = np.array(())
|
||||
# strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
|
||||
# stride=[7, 1])
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
|
||||
# np.array([11, 17]))
|
||||
# strided[ri([0, 1]), ri([1, 0])] = np.array([-1, 2])
|
||||
# numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
|
||||
# np.array([-1, 2]))
|
||||
reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
# reference = torch.arange(0., 24).view(3, 8)
|
||||
# strided = np.array(())
|
||||
# strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
|
||||
# stride=[7, 1])
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
|
||||
np.array([11]))
|
||||
# TODO setitem
|
||||
'''
|
||||
strided[ri([0]), ri([1])] = -1
|
||||
numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
|
||||
np.array([-1]))
|
||||
'''
|
||||
|
||||
# rows = ri([[0],
|
||||
# [1]])
|
||||
# columns = ri([[0, 1],
|
||||
# [0, 1]])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
# np.array([[10, 11], [17, 18]]))
|
||||
# strided[rows, columns] = np.array([[4, 6], [2, 3]])
|
||||
# numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
# np.array([[4, 6], [2, 3]]))
|
||||
reference = Tensor.arange(0., 24).reshape(3, 8)
|
||||
strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
|
||||
np.array([11, 17]))
|
||||
# TODO setitem
|
||||
'''
|
||||
strided[ri([0, 1]), ri([1, 0])] = np.array([-1, 2])
|
||||
numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
|
||||
np.array([-1, 2]))
|
||||
'''
|
||||
|
||||
reference = Tensor.arange(0., 24).realize().reshape(3, 8)
|
||||
strided = set_(reference, (2,2), (7,1), 10)
|
||||
|
||||
rows = ri([[0],
|
||||
[1]])
|
||||
columns = ri([[0, 1],
|
||||
[0, 1]])
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
np.array([[10, 11], [17, 18]]))
|
||||
# TODO setitem
|
||||
'''
|
||||
strided[rows, columns] = np.array([[4, 6], [2, 3]])
|
||||
numpy_testing_assert_equal_helper(strided[rows, columns],
|
||||
np.array([[4, 6], [2, 3]]))
|
||||
'''
|
||||
|
||||
# Tests using less than the number of dims, and ellipsis
|
||||
|
||||
@@ -385,40 +408,26 @@ class TestIndexing(unittest.TestCase):
|
||||
# verify too many indices fails
|
||||
with self.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])]
|
||||
|
||||
# # test invalid index fails
|
||||
# reference = torch.empty(10)
|
||||
# # can't test cuda because it is a device assert
|
||||
# if not reference.is_cuda:
|
||||
# for err_idx in (10, -11):
|
||||
# with self.assertRaisesRegex(IndexError, r'out of'):
|
||||
# reference[err_idx]
|
||||
# with self.assertRaisesRegex(IndexError, r'out of'):
|
||||
# reference[torch.LongTensor([err_idx]).to(device)]
|
||||
# with self.assertRaisesRegex(IndexError, r'out of'):
|
||||
# reference[[err_idx]]
|
||||
# test invalid index fails
|
||||
reference = Tensor.empty(10)
|
||||
for err_idx in (10, -11):
|
||||
with self.assertRaisesRegex(IndexError, r'out of'):
|
||||
reference[err_idx]
|
||||
# TODO: cannot check for out of bounds with Tensor indexing
|
||||
# see test_ops.py: test_slice_fancy_indexing_errors()
|
||||
'''
|
||||
with self.assertRaisesRegex(IndexError, r'out of'):
|
||||
reference[Tensor([err_idx], dtype=dtypes.int64)]
|
||||
with self.assertRaisesRegex(IndexError, r'out of'):
|
||||
reference[[err_idx]]
|
||||
'''
|
||||
|
||||
'''
|
||||
def tensor_indices_to_np(tensor, indices):
|
||||
# convert the Torch Tensor to a numpy array
|
||||
tensor = tensor.to(device='cpu')
|
||||
npt = tensor.numpy()
|
||||
# convert indices
|
||||
idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else
|
||||
i for i in indices)
|
||||
return npt, idxs
|
||||
'''
|
||||
def tensor_indices_to_np(tensor: Tensor, indices):
|
||||
npt = tensor.numpy()
|
||||
idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype is dtypes.int64 else
|
||||
i for i in indices)
|
||||
return npt, idxs
|
||||
|
||||
'''
|
||||
def get_numpy(tensor, indices):
|
||||
npt, idxs = tensor_indices_to_np(tensor, indices)
|
||||
# index and return as a Torch Tensor
|
||||
return np.array(npt[idxs])
|
||||
'''
|
||||
def get_numpy(tensor, indices):
|
||||
npt, idxs = tensor_indices_to_np(tensor, indices)
|
||||
return Tensor(npt[idxs])
|
||||
@@ -440,10 +449,6 @@ class TestIndexing(unittest.TestCase):
|
||||
npt[idxs] = value
|
||||
return npt
|
||||
|
||||
'''
|
||||
def assert_get_eq(tensor, indexer):
|
||||
numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer))
|
||||
'''
|
||||
def assert_get_eq(tensor, indexer):
|
||||
numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer))
|
||||
|
||||
@@ -459,20 +464,9 @@ class TestIndexing(unittest.TestCase):
|
||||
pyt = tensor.detach()
|
||||
numt = tensor.detach()
|
||||
pyt[indexer] = val
|
||||
numt = np.array(set_numpy(numt, indexer, val)) #TODO: shouldn't this already be a numpy array? Why wrap numpy array again???
|
||||
numt = set_numpy(numt, indexer, val)
|
||||
numpy_testing_assert_equal_helper(pyt, numt)
|
||||
|
||||
'''
|
||||
def assert_backward_eq(tensor, indexer):
|
||||
cpu = tensor.float().clone().detach().requires_grad_(True)
|
||||
outcpu = cpu[indexer]
|
||||
gOcpu = torch.rand_like(outcpu)
|
||||
outcpu.backward(gOcpu)
|
||||
dev = cpu.to(device).detach().requires_grad_(True)
|
||||
outdev = dev[indexer]
|
||||
outdev.backward(gOcpu.to(device))
|
||||
numpy_testing_assert_equal_helper(cpu.grad, dev.grad)
|
||||
'''
|
||||
# NOTE: torch initiates the gradients using g0cpu (rand as gradients)
|
||||
def assert_backward_eq(tensor: Tensor, indexer):
|
||||
cpu = tensor.float().detach()
|
||||
@@ -498,177 +492,180 @@ class TestIndexing(unittest.TestCase):
|
||||
set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size).cast(dtypes.float64)
|
||||
return set_tensor
|
||||
|
||||
# # Tensor is 0 1 2 3 4
|
||||
# # 5 6 7 8 9
|
||||
# # 10 11 12 13 14
|
||||
# # 15 16 17 18 19
|
||||
# reference = torch.arange(0., 20).view(4, 5)
|
||||
# Tensor is 0 1 2 3 4
|
||||
# 5 6 7 8 9
|
||||
# 10 11 12 13 14
|
||||
# 15 16 17 18 19
|
||||
reference = Tensor.arange(0., 20).reshape(4, 5)
|
||||
|
||||
# indices_to_test = [
|
||||
# # grab the second, fourth columns
|
||||
# [slice(None), [1, 3]],
|
||||
indices_to_test = [
|
||||
# grab the second, fourth columns
|
||||
[slice(None), [1, 3]],
|
||||
|
||||
# # first, third rows,
|
||||
# [[0, 2], slice(None)],
|
||||
# first, third rows,
|
||||
[[0, 2], slice(None)],
|
||||
|
||||
# # weird shape
|
||||
# [slice(None), [[0, 1],
|
||||
# [2, 3]]],
|
||||
# # negatives
|
||||
# [[-1], [0]],
|
||||
# [[0, 2], [-1]],
|
||||
# [slice(None), [-1]],
|
||||
# ]
|
||||
# weird shape
|
||||
[slice(None), [[0, 1],
|
||||
[2, 3]]],
|
||||
# negatives
|
||||
[[-1], [0]],
|
||||
[[0, 2], [-1]],
|
||||
[slice(None), [-1]],
|
||||
]
|
||||
|
||||
# # only test dupes on gets
|
||||
# get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
|
||||
# only test dupes on gets
|
||||
get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
|
||||
|
||||
# for indexer in get_indices_to_test:
|
||||
# assert_get_eq(reference, indexer)
|
||||
# if self.device_type != 'cpu':
|
||||
# assert_backward_eq(reference, indexer)
|
||||
for indexer in get_indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
# for indexer in indices_to_test:
|
||||
# assert_set_eq(reference, indexer, 44)
|
||||
# assert_set_eq(reference,
|
||||
# indexer,
|
||||
# get_set_tensor(reference, indexer))
|
||||
# TODO setitem
|
||||
'''
|
||||
for indexer in indices_to_test:
|
||||
assert_set_eq(reference, indexer, 44)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
'''
|
||||
|
||||
# reference = torch.arange(0., 160).view(4, 8, 5)
|
||||
reference = Tensor.arange(0., 160).reshape(4, 8, 5)
|
||||
|
||||
# indices_to_test = [
|
||||
# [slice(None), slice(None), [0, 3, 4]],
|
||||
# [slice(None), [2, 4, 5, 7], slice(None)],
|
||||
# [[2, 3], slice(None), slice(None)],
|
||||
# [slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
# [slice(None), [0], [1, 2, 4]],
|
||||
# [slice(None), [0, 1, 3], [4]],
|
||||
# [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
|
||||
# [slice(None), [[0, 1], [2, 3]], [[0]]],
|
||||
# [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
|
||||
# [[0, 2, 3], [1, 3, 4], slice(None)],
|
||||
# [[0], [1, 2, 4], slice(None)],
|
||||
# [[0, 1, 3], [4], slice(None)],
|
||||
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
|
||||
# [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
|
||||
# [[[0, 1], [2, 3]], [[0]], slice(None)],
|
||||
# [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
|
||||
# [[[2]], [[0, 3], [4, 1]], slice(None)],
|
||||
# # non-contiguous indexing subspace
|
||||
# [[0, 2, 3], slice(None), [1, 3, 4]],
|
||||
indices_to_test = [
|
||||
[slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), [2, 4, 5, 7], slice(None)],
|
||||
[[2, 3], slice(None), slice(None)],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [0], [1, 2, 4]],
|
||||
[slice(None), [0, 1, 3], [4]],
|
||||
[slice(None), [[0, 1], [1, 0]], [[2, 3]]],
|
||||
[slice(None), [[0, 1], [2, 3]], [[0]]],
|
||||
[slice(None), [[5, 6]], [[0, 3], [4, 4]]],
|
||||
[[0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0], [1, 2, 4], slice(None)],
|
||||
[[0, 1, 3], [4], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 3]], slice(None)],
|
||||
[[[0, 1], [2, 3]], [[0]], slice(None)],
|
||||
[[[2, 1]], [[0, 3], [4, 4]], slice(None)],
|
||||
[[[2]], [[0, 3], [4, 1]], slice(None)],
|
||||
# non-contiguous indexing subspace
|
||||
[[0, 2, 3], slice(None), [1, 3, 4]],
|
||||
|
||||
# # less dim, ellipsis
|
||||
# [[0, 2], ],
|
||||
# [[0, 2], slice(None)],
|
||||
# [[0, 2], Ellipsis],
|
||||
# [[0, 2], slice(None), Ellipsis],
|
||||
# [[0, 2], Ellipsis, slice(None)],
|
||||
# [[0, 2], [1, 3]],
|
||||
# [[0, 2], [1, 3], Ellipsis],
|
||||
# [Ellipsis, [1, 3], [2, 3]],
|
||||
# [Ellipsis, [2, 3, 4]],
|
||||
# [Ellipsis, slice(None), [2, 3, 4]],
|
||||
# [slice(None), Ellipsis, [2, 3, 4]],
|
||||
# less dim, ellipsis
|
||||
[[0, 2], ],
|
||||
[[0, 2], slice(None)],
|
||||
[[0, 2], Ellipsis],
|
||||
[[0, 2], slice(None), Ellipsis],
|
||||
[[0, 2], Ellipsis, slice(None)],
|
||||
[[0, 2], [1, 3]],
|
||||
[[0, 2], [1, 3], Ellipsis],
|
||||
[Ellipsis, [1, 3], [2, 3]],
|
||||
[Ellipsis, [2, 3, 4]],
|
||||
[Ellipsis, slice(None), [2, 3, 4]],
|
||||
[slice(None), Ellipsis, [2, 3, 4]],
|
||||
|
||||
# # ellipsis counts for nothing
|
||||
# [Ellipsis, slice(None), slice(None), [0, 3, 4]],
|
||||
# [slice(None), Ellipsis, slice(None), [0, 3, 4]],
|
||||
# [slice(None), slice(None), Ellipsis, [0, 3, 4]],
|
||||
# [slice(None), slice(None), [0, 3, 4], Ellipsis],
|
||||
# [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
|
||||
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
|
||||
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
|
||||
# ]
|
||||
# ellipsis counts for nothing
|
||||
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), Ellipsis, slice(None), [0, 3, 4]],
|
||||
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), slice(None), [0, 3, 4], Ellipsis],
|
||||
[Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
|
||||
[[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
|
||||
]
|
||||
|
||||
# for indexer in indices_to_test:
|
||||
# assert_get_eq(reference, indexer)
|
||||
# assert_set_eq(reference, indexer, 212)
|
||||
# assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
# if torch.cuda.is_available():
|
||||
# assert_backward_eq(reference, indexer)
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
# TODO setitem
|
||||
'''
|
||||
assert_set_eq(reference, indexer, 212)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
'''
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
# reference = torch.arange(0., 1296).view(3, 9, 8, 6)
|
||||
reference = Tensor.arange(0., 1296).reshape(3, 9, 8, 6)
|
||||
|
||||
# indices_to_test = [
|
||||
# [slice(None), slice(None), slice(None), [0, 3, 4]],
|
||||
# [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
|
||||
# [slice(None), [2, 3], slice(None), slice(None)],
|
||||
# [[1, 2], slice(None), slice(None), slice(None)],
|
||||
# [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
# [slice(None), slice(None), [0], [1, 2, 4]],
|
||||
# [slice(None), slice(None), [0, 1, 3], [4]],
|
||||
# [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
|
||||
# [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
|
||||
# [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
|
||||
# [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
# [slice(None), [0], [1, 2, 4], slice(None)],
|
||||
# [slice(None), [0, 1, 3], [4], slice(None)],
|
||||
# [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
|
||||
# [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
|
||||
# [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
|
||||
# [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
|
||||
# [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
|
||||
# [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
|
||||
# [[0], [1, 2, 4], slice(None), slice(None)],
|
||||
# [[0, 1, 2], [4], slice(None), slice(None)],
|
||||
# [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
|
||||
# [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
|
||||
# [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
|
||||
# [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
|
||||
# [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
|
||||
# [slice(None), [2, 3, 4], [1, 3, 4], [4]],
|
||||
# [slice(None), [0, 1, 3], [4], [1, 3, 4]],
|
||||
# [slice(None), [6], [0, 2, 3], [1, 3, 4]],
|
||||
# [slice(None), [2, 3, 5], [3], [4]],
|
||||
# [slice(None), [0], [4], [1, 3, 4]],
|
||||
# [slice(None), [6], [0, 2, 3], [1]],
|
||||
# [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
|
||||
# [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
# [[2, 0, 1], [1, 2, 3], [4], slice(None)],
|
||||
# [[0, 1, 2], [4], [1, 3, 4], slice(None)],
|
||||
# [[0], [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
# [[0, 2, 1], [3], [4], slice(None)],
|
||||
# [[0], [4], [1, 3, 4], slice(None)],
|
||||
# [[1], [0, 2, 3], [1], slice(None)],
|
||||
# [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
|
||||
indices_to_test = [
|
||||
[slice(None), slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), slice(None), [2, 4, 5, 7], slice(None)],
|
||||
[slice(None), [2, 3], slice(None), slice(None)],
|
||||
[[1, 2], slice(None), slice(None), slice(None)],
|
||||
[slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), slice(None), [0], [1, 2, 4]],
|
||||
[slice(None), slice(None), [0, 1, 3], [4]],
|
||||
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
|
||||
[slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
|
||||
[slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[slice(None), [0], [1, 2, 4], slice(None)],
|
||||
[slice(None), [0, 1, 3], [4], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
|
||||
[slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
|
||||
[slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
|
||||
[slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
|
||||
[[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
|
||||
[[0], [1, 2, 4], slice(None), slice(None)],
|
||||
[[0, 1, 2], [4], slice(None), slice(None)],
|
||||
[[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
|
||||
[[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
|
||||
[[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
|
||||
[[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
|
||||
[slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [2, 3, 4], [1, 3, 4], [4]],
|
||||
[slice(None), [0, 1, 3], [4], [1, 3, 4]],
|
||||
[slice(None), [6], [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [2, 3, 5], [3], [4]],
|
||||
[slice(None), [0], [4], [1, 3, 4]],
|
||||
[slice(None), [6], [0, 2, 3], [1]],
|
||||
[slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
|
||||
[[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[2, 0, 1], [1, 2, 3], [4], slice(None)],
|
||||
[[0, 1, 2], [4], [1, 3, 4], slice(None)],
|
||||
[[0], [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0, 2, 1], [3], [4], slice(None)],
|
||||
[[0], [4], [1, 3, 4], slice(None)],
|
||||
[[1], [0, 2, 3], [1], slice(None)],
|
||||
[[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
|
||||
|
||||
# # less dim, ellipsis
|
||||
# [Ellipsis, [0, 3, 4]],
|
||||
# [Ellipsis, slice(None), [0, 3, 4]],
|
||||
# [Ellipsis, slice(None), slice(None), [0, 3, 4]],
|
||||
# [slice(None), Ellipsis, [0, 3, 4]],
|
||||
# [slice(None), slice(None), Ellipsis, [0, 3, 4]],
|
||||
# [slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
# [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
|
||||
# [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
# [[0], [1, 2, 4]],
|
||||
# [[0], [1, 2, 4], slice(None)],
|
||||
# [[0], [1, 2, 4], Ellipsis],
|
||||
# [[0], [1, 2, 4], Ellipsis, slice(None)],
|
||||
# [[1], ],
|
||||
# [[0, 2, 1], [3], [4]],
|
||||
# [[0, 2, 1], [3], [4], slice(None)],
|
||||
# [[0, 2, 1], [3], [4], Ellipsis],
|
||||
# [Ellipsis, [0, 2, 1], [3], [4]],
|
||||
# ]
|
||||
# less dim, ellipsis
|
||||
[Ellipsis, [0, 3, 4]],
|
||||
[Ellipsis, slice(None), [0, 3, 4]],
|
||||
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
|
||||
[slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4]],
|
||||
[slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
|
||||
[Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
|
||||
[[0], [1, 2, 4]],
|
||||
[[0], [1, 2, 4], slice(None)],
|
||||
[[0], [1, 2, 4], Ellipsis],
|
||||
[[0], [1, 2, 4], Ellipsis, slice(None)],
|
||||
[[1], ],
|
||||
[[0, 2, 1], [3], [4]],
|
||||
[[0, 2, 1], [3], [4], slice(None)],
|
||||
[[0, 2, 1], [3], [4], Ellipsis],
|
||||
[Ellipsis, [0, 2, 1], [3], [4]],
|
||||
]
|
||||
|
||||
# for indexer in indices_to_test:
|
||||
# assert_get_eq(reference, indexer)
|
||||
# assert_set_eq(reference, indexer, 1333)
|
||||
# assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
# indices_to_test += [
|
||||
# [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
|
||||
# [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
|
||||
# ]
|
||||
# for indexer in indices_to_test:
|
||||
# assert_get_eq(reference, indexer)
|
||||
# assert_set_eq(reference, indexer, 1333)
|
||||
# if self.device_type != 'cpu':
|
||||
# assert_backward_eq(reference, indexer)
|
||||
|
||||
# def test_advancedindex_big(self):
|
||||
# reference = Tensor.arange(123344)
|
||||
# numpy_testing_assert_equal_helper(reference[[0, 123, 44488, 68807, 123343],], np.array([0, 123, 44488, 68807, 123343]))
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
# TODO setitem
|
||||
'''
|
||||
assert_set_eq(reference, indexer, 1333)
|
||||
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
|
||||
'''
|
||||
indices_to_test += [
|
||||
[slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
|
||||
[slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
|
||||
]
|
||||
for indexer in indices_to_test:
|
||||
assert_get_eq(reference, indexer)
|
||||
# TODO setitem
|
||||
'''
|
||||
assert_set_eq(reference, indexer, 1333)
|
||||
'''
|
||||
assert_backward_eq(reference, indexer)
|
||||
|
||||
# def test_set_item_to_scalar_tensor(self):
|
||||
# m = random.randint(1, 10)
|
||||
|
||||
Reference in New Issue
Block a user