Add simple fancy indexing exceptions (#2706)

* fancy indexing raise error

* updated error message

* improved error check

* oops

* fixed onnx

* oops typo

* merge

* add full_flatten

* try

* merged and updated some tests

* more cleaning

* done

* temp fix onnx

* try

* add todo in onnx_test

* reword

* gah
This commit is contained in:
geohotstan
2023-12-20 00:23:51 +08:00
committed by GitHub
parent 417d42a363
commit fec8e9060c
5 changed files with 33 additions and 20 deletions

View File

@@ -373,7 +373,7 @@ def NegativeLogLikelihoodLoss(input: Tensor, target: Tensor, weight=None, ignore
def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
N, C, *s_dimensions = scores.shape
if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels).cast(dtypes.int32)
mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
y = scores.log_softmax(axis=1)
if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
@@ -425,10 +425,10 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up")
elif nearest_mode == "floor": ret = x_resized.floor()
elif nearest_mode == "ceil": ret = x_resized.ceil()
return ret.clip(0, x_len-1)
return ret.cast(dtypes.int32).clip(0, x_len-1)
def _coordinate_transformation(x_out, y_out, output_shape, scales_, roi=None):
if coordinate_transformation_mode == "half_pixel":
x_out = (x_out + 0.5)/Tensor(scales_[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuacy.
x_out = (x_out + 0.5)/Tensor(scales_[-1]) - 0.5
y_out = (y_out + 0.5)/Tensor(scales_[-2]) - 0.5
elif coordinate_transformation_mode == "align_corners":
x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)

View File

@@ -190,6 +190,10 @@ backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
if isinstance(Device[Device.DEFAULT], Compiled):
backend_test.exclude('test_MaxPool3d_stride_padding_cpu')
# TODO: this somehow passes in CI but does not pass if run locally
if Device.DEFAULT == 'METAL':
backend_test.exclude('test_maxpool_2d_same_upper_cpu')
# TODO: inaccuracy only for numpy backend. will get back to this after dtype refactor.
if Device.DEFAULT == "CPU":
backend_test.exclude('test_sce_')

View File

@@ -195,7 +195,7 @@ class TestIndexing(unittest.TestCase):
# pick a random valid indexer type
def ri(indices):
choice = random.randint(0, 2)
if choice == 0: return Tensor(indices, dtype=dtypes.int32)
if choice == 0: return Tensor(indices)
if choice == 1: return list(indices)
return tuple(indices)
@@ -1134,12 +1134,9 @@ class TestIndexing(unittest.TestCase):
x[:, [0, 1]]
'''
# TODO empty Tensor fancy index
'''
def test_empty_ndim_index_bool(self):
x = Tensor.randn(5)
self.assertRaises(IndexError, lambda: x[Tensor.empty(0, 2, dtype=dtypes.uint8)])
'''
def test_empty_slice(self):
x = Tensor.randn(2, 3, 4, 5)
@@ -1798,13 +1795,10 @@ class TestNumpy(unittest.TestCase):
self.assertIsNot(a, a[...])
self.assertIsNot(a, a[:])
# TODO shape mismatch fancy indexing error
'''
def test_broaderrors_indexing(self):
a = Tensor.zeros(5, 5)
self.assertRaisesRegex(IndexError, 'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
self.assertRaisesRegex(IndexError, 'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
'''
# TODO setitem
'''

View File

@@ -1294,7 +1294,7 @@ class TestOps(unittest.TestCase):
c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False)
e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False)
i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) for tor in [a,b,c,d,e]]
i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), requires_grad=False) for tor in [a,b,c,d,e]]
return a,b,c,d,e,i,j,k,o,p
def test_slice_fancy_indexing_no_dim_collapse(self):
@@ -1340,16 +1340,20 @@ class TestOps(unittest.TestCase):
def test_slice_fancy_indexing_list_indices(self):
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[0]]], lambda x: x[[[0]]])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[0],b,c,d,:], lambda x: x[[0],j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[[0]]],b,c,d,[[1]]], lambda x: x[[[[0]]],j,k,o,[[1]]])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1],b,c,d,:], lambda x: x[[1],j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1,0],b,c,d,:], lambda x: x[[1,0],j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[1,2,3],...], lambda x: x[i,j,k,[1,2,3],...])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[[1],[2],[3]],...], lambda x: x[i,j,k,[[1],[2],[3]],...])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[2,1,0],e], lambda x: x[i,[2,1,0],k,[2,1,0],p])
def test_slice_fancy_indexing_tuple_indices(self):
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0),b,c,d,:], lambda x: x[(0),j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1),b,c,d,:], lambda x: x[(1),j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(((0,),),)], lambda x: x[(((0,),),)])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0,),b,c,d,:], lambda x: x[(0,),j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,),b,c,d,:], lambda x: x[(1,),j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,0),b,c,d,:], lambda x: x[(1,0),j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,(1,2,3),...], lambda x: x[i,j,k,(1,2,3),...])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,(2,1,0),c,(2,1,0),e], lambda x: x[i,(2,1,0),k,(2,1,0),p])
@@ -1371,8 +1375,14 @@ class TestOps(unittest.TestCase):
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,(1,1))], lambda x: x[(i,(1,1))])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,b,c,d,e)], lambda x: x[(i,j,k,o,p)])
def test_slice_fancy_indexing_errors(self): ...
# TODO: currently we not support IndexError for out of bounds idx values
def test_slice_fancy_indexing_errors(self):
a = Tensor.ones(10,11,12)
# tensors used as indices must be int or bool tensors
with self.assertRaises(IndexError): a[Tensor(1.1)]
# shape mismatch
with self.assertRaises(IndexError): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1), Tensor.randint(2,4,4,1)]
with self.assertRaises(IndexError): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1,1)]
# TODO: currently we do not support IndexError for out of bounds idx values
# any out of bounds in fancy indexing returns 0
# ex: Tensor([1,2])[Tensor([1,2,55])].numpy() -> array([2., 0., 0.], dtype=float32)
# TODO: currently we do not support tensor indexing for list of list tensor

View File

@@ -322,8 +322,8 @@ class Tensor:
# treat internal tuples and lists as Tensors and standardize indices to list type
if isinstance(indices, (tuple, list)):
# special case <indices: List[int]>, a lil ugly
if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, dtype=dtypes.int32, requires_grad=False, device=self.device)]
else: indices = [Tensor(list(i), dtype=dtypes.int32, requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices] # noqa: E501
if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, requires_grad=False, device=self.device)]
else: indices = [Tensor(list(i), requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]
else: indices = [indices]
# filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
@@ -386,18 +386,23 @@ class Tensor:
tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected)
# normalize the negative tensor indices
idx.append(((t := indices[tensor_dim + dims_injected]) < 0).where(ret.shape[td], 0) + t)
# TODO uint8 and bool tensor indexing
if not (dtypes.is_int(t.dtype) or t.dtype == dtypes.bool): raise IndexError("tensors used as indices must be int or bool tensors")
# compute sum_dim, arange, and idx
max_dim = max(i.ndim for i in idx)
sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
idx = first_idx + rest_idx
reshaped_idx = first_idx + rest_idx
ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
# iteratively eq -> mul -> sum fancy index
for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
try:
for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
except AssertionError as exc:
raise IndexError(f"shape mismatch: broadcasting not possible with index shapes {', '.join(str(i.shape) for i in idx)}") from exc
# special permute case
if tdim[0] != 0 and len(tdim) != 1 and tdim != list(range(tdim[0], tdim[-1]+1)):