Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
Add simple fancy indexing exceptions (#2706)
* fancy indexing raise error
* updated error message
* improved error check
* oops
* fixed onnx
* oops typo
* merge
* add full_flatten
* try
* merged and updated some tests
* more cleaning
* done
* temp fix onnx
* try
* add todo in onnx_test
* reword
* gah
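In short: Tensor.__getitem__ now raises IndexError for two simple fancy-indexing misuses: index tensors that are not int or bool typed, and index tensors whose shapes cannot be broadcast together. A minimal sketch of the new behavior, mirroring the tests added below (import path assumed for this era of tinygrad; not run against this exact revision):

from tinygrad.tensor import Tensor

a = Tensor.ones(10, 11, 12)

# index tensors must be int or bool: a float tensor index now raises IndexError
try:
  a[Tensor(1.1)]
except IndexError as e:
  print(e)  # "tensors used as indices must be int or bool tensors"

# index tensors whose shapes cannot be broadcast together now raise IndexError
try:
  a[Tensor.randint(3, 1, 1, 1), Tensor.randint(1, 4, 1, 1, 1)]
except IndexError as e:
  print(e)  # "shape mismatch: broadcasting not possible with index shapes ..."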
@@ -373,7 +373,7 @@ def NegativeLogLikelihoodLoss(input: Tensor, target: Tensor, weight=None, ignore

def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
  N, C, *s_dimensions = scores.shape
-  if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
+  if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels).cast(dtypes.int32)
  mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
  y = scores.log_softmax(axis=1)
  if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
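Note on the added .cast(dtypes.int32): labels is used as a fancy index into weights a few lines down (the weights.__getitem__ call), and index tensors must now be int or bool typed, so the where() result presumably needs to stay an integer dtype. For context, the mask line builds a one-hot selector over the class axis by broadcasting; a small sketch of that idea with simplified shapes (illustrative, not the exact ONNX op):

from tinygrad.tensor import Tensor

scores = Tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])  # (N, C) = (2, 3)
labels = Tensor([2, 0])                               # one class id per sample
C = 3

# (N, 1) compared against (1, C) broadcasts to an (N, C) one-hot mask of the target class
mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C)
# cross entropy: keep only the log-probability of the target class, then negate
loss = -mask.where(scores.log_softmax(axis=1), 0).sum(axis=1)
print(loss.numpy())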
@@ -425,10 +425,10 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
    elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up")
    elif nearest_mode == "floor": ret = x_resized.floor()
    elif nearest_mode == "ceil": ret = x_resized.ceil()
-    return ret.clip(0, x_len-1)
+    return ret.cast(dtypes.int32).clip(0, x_len-1)
  def _coordinate_transformation(x_out, y_out, output_shape, scales_, roi=None):
    if coordinate_transformation_mode == "half_pixel":
-      x_out = (x_out + 0.5)/Tensor(scales_[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuacy.
+      x_out = (x_out + 0.5)/Tensor(scales_[-1]) - 0.5
      y_out = (y_out + 0.5)/Tensor(scales_[-2]) - 0.5
    elif coordinate_transformation_mode == "align_corners":
      x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
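Note on the Resize change: the rounded nearest-neighbor coordinates are float tensors, and float tensors can no longer be used as fancy indices, hence the explicit cast to int32 before clipping. The half_pixel branch is the usual ONNX mapping from output pixel centers back to input coordinates; a short worked sketch of that formula (plain Python, illustrative values):

# half_pixel: x_in = (x_out + 0.5) / scale - 0.5
scale = 2.0                          # e.g. upscaling an axis by 2x
for x_out in range(6):
  x_in = (x_out + 0.5) / scale - 0.5
  print(x_out, "->", x_in)           # 0 -> -0.25, 1 -> 0.25, ..., 5 -> 2.25
# nearest_mode then rounds x_in, and the rounded result is what gets cast to int32
# and clipped into range before being used as an index tensor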
test/external/external_test_onnx_backend.py (vendored)
@@ -190,6 +190,10 @@ backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
if isinstance(Device[Device.DEFAULT], Compiled):
  backend_test.exclude('test_MaxPool3d_stride_padding_cpu')

  # TODO: this somehow passes in CI but does not pass if run locally
  if Device.DEFAULT == 'METAL':
    backend_test.exclude('test_maxpool_2d_same_upper_cpu')
+
+  # TODO: inaccuracy only for numpy backend. will get back to this after dtype refactor.
+  if Device.DEFAULT == "CPU":
+    backend_test.exclude('test_sce_')
@@ -195,7 +195,7 @@ class TestIndexing(unittest.TestCase):
    # pick a random valid indexer type
    def ri(indices):
      choice = random.randint(0, 2)
-      if choice == 0: return Tensor(indices, dtype=dtypes.int32)
+      if choice == 0: return Tensor(indices)
      if choice == 1: return list(indices)
      return tuple(indices)

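Note: the explicit dtype=dtypes.int32 is dropped here (and in tensor.py below), presumably because a Tensor built from Python ints now defaults to an integer dtype after the full_flatten change mentioned in the commit message, which is exactly what the new index-dtype check expects. A quick check one could run (import path assumed for this era):

from tinygrad.tensor import Tensor

# index tensors built from Python ints should already carry an int dtype,
# so the explicit dtype=dtypes.int32 becomes redundant
print(Tensor([0, 2, 1]).dtype)    # expected: an integer dtype
print(Tensor([0.5, 1.5]).dtype)   # floats stay float, and would now be rejected as indices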
@@ -1134,12 +1134,9 @@ class TestIndexing(unittest.TestCase):
    x[:, [0, 1]]
    '''

  # TODO empty Tensor fancy index
  '''
  def test_empty_ndim_index_bool(self):
    x = Tensor.randn(5)
    self.assertRaises(IndexError, lambda: x[Tensor.empty(0, 2, dtype=dtypes.uint8)])
  '''

  def test_empty_slice(self):
    x = Tensor.randn(2, 3, 4, 5)
@@ -1798,13 +1795,10 @@ class TestNumpy(unittest.TestCase):
    self.assertIsNot(a, a[...])
    self.assertIsNot(a, a[:])

-  # TODO shape mismatch fancy indexing error
-  '''
  def test_broaderrors_indexing(self):
    a = Tensor.zeros(5, 5)
    self.assertRaisesRegex(IndexError, 'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
    self.assertRaisesRegex(IndexError, 'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
-  '''

  # TODO setitem
  '''
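For reference, the 'shape mismatch' wording matches what NumPy reports when fancy indices cannot be broadcast against each other, which is the case this test exercises with a (2,) index next to a (3,) index:

import numpy as np

a = np.zeros((5, 5))
try:
  a[[0, 1], [0, 1, 2]]  # a (2,) index and a (3,) index have no common broadcast shape
except IndexError as e:
  print(e)              # "shape mismatch: indexing arrays could not be broadcast together ..."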
@@ -1294,7 +1294,7 @@ class TestOps(unittest.TestCase):
    c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
    d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False)
    e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False)
-    i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) for tor in [a,b,c,d,e]]
+    i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), requires_grad=False) for tor in [a,b,c,d,e]]
    return a,b,c,d,e,i,j,k,o,p

  def test_slice_fancy_indexing_no_dim_collapse(self):
@@ -1340,16 +1340,20 @@ class TestOps(unittest.TestCase):

  def test_slice_fancy_indexing_list_indices(self):
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[0]]], lambda x: x[[[0]]])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[0],b,c,d,:], lambda x: x[[0],j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[[0]]],b,c,d,[[1]]], lambda x: x[[[[0]]],j,k,o,[[1]]])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1],b,c,d,:], lambda x: x[[1],j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1,0],b,c,d,:], lambda x: x[[1,0],j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[1,2,3],...], lambda x: x[i,j,k,[1,2,3],...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[[1],[2],[3]],...], lambda x: x[i,j,k,[[1],[2],[3]],...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[2,1,0],e], lambda x: x[i,[2,1,0],k,[2,1,0],p])

  def test_slice_fancy_indexing_tuple_indices(self):
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0),b,c,d,:], lambda x: x[(0),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1),b,c,d,:], lambda x: x[(1),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(((0,),),)], lambda x: x[(((0,),),)])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0,),b,c,d,:], lambda x: x[(0,),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,),b,c,d,:], lambda x: x[(1,),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,0),b,c,d,:], lambda x: x[(1,0),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,(1,2,3),...], lambda x: x[i,j,k,(1,2,3),...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,(2,1,0),c,(2,1,0),e], lambda x: x[i,(2,1,0),k,(2,1,0),p])
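Several of the cases above use nested list/tuple indices such as x[[[0]]]. Under NumPy-style fancy indexing the result takes on the shape of the index itself; a small NumPy sketch of the expected semantics (NumPy stands in for the torch reference that helper_test_op compares against):

import numpy as np

x = np.arange(10).reshape(2, 5)
print(x[[0]].shape)     # (1, 5): a one-element list index keeps an axis of size 1
print(x[[[0]]].shape)   # (1, 1, 5): the index's own (1, 1) shape is spliced into the result
print(x[[1, 0]].shape)  # (2, 5): rows selected in the order given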
@@ -1371,8 +1375,14 @@ class TestOps(unittest.TestCase):
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,(1,1))], lambda x: x[(i,(1,1))])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,b,c,d,e)], lambda x: x[(i,j,k,o,p)])

-  def test_slice_fancy_indexing_errors(self): ...
-  # TODO: currently we not support IndexError for out of bounds idx values
+  def test_slice_fancy_indexing_errors(self):
+    a = Tensor.ones(10,11,12)
+    # tensors used as indices must be int or bool tensors
+    with self.assertRaises(IndexError): a[Tensor(1.1)]
+    # shape mismatch
+    with self.assertRaises(IndexError): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1), Tensor.randint(2,4,4,1)]
+    with self.assertRaises(IndexError): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1,1)]
+    # TODO: currently we do not support IndexError for out of bounds idx values
+    # any out of bounds in fancy indexing returns 0
+    # ex: Tensor([1,2])[Tensor([1,2,55])].numpy() -> array([2., 0., 0.], dtype=float32)

  # TODO: currently we do not support tensor indexing for list of list tensor

@@ -322,8 +322,8 @@ class Tensor:
    # treat internal tuples and lists as Tensors and standardize indices to list type
    if isinstance(indices, (tuple, list)):
      # special case <indices: List[int]>, a lil ugly
-      if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, dtype=dtypes.int32, requires_grad=False, device=self.device)]
-      else: indices = [Tensor(list(i), dtype=dtypes.int32, requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]  # noqa: E501
+      if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, requires_grad=False, device=self.device)]
+      else: indices = [Tensor(list(i), requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]
    else: indices = [indices]

    # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
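The "special case" branch means a bare list of ints is treated as a single fancy index (one index Tensor), not as one index per dimension, while lists or tuples nested inside a tuple of indices each become their own index Tensor. A hedged sketch of the intended behavior (matches NumPy semantics; not verified against this exact revision):

from tinygrad.tensor import Tensor

x = Tensor.arange(12).reshape(4, 3)
# a plain list of ints becomes one index Tensor: pick rows 2 and 0
print(x[[2, 0]].numpy())          # same as x[Tensor([2, 0])].numpy()
# a tuple of lists is standardized element-wise: one index Tensor per dimension
print(x[[2, 0], [1, 1]].numpy())  # elements (2,1) and (0,1)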
@@ -386,18 +386,23 @@ class Tensor:
      tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected)
      # normalize the negative tensor indices
      idx.append(((t := indices[tensor_dim + dims_injected]) < 0).where(ret.shape[td], 0) + t)
+      # TODO uint8 and bool tensor indexing
+      if not (dtypes.is_int(t.dtype) or t.dtype == dtypes.bool): raise IndexError("tensors used as indices must be int or bool tensors")

    # compute sum_dim, arange, and idx
    max_dim = max(i.ndim for i in idx)
    sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
-    arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))]  # noqa: E501
+    arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))]  # noqa: E501
    first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
    rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
-    idx = first_idx + rest_idx
+    reshaped_idx = first_idx + rest_idx
    ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])

    # iteratively eq -> mul -> sum fancy index
-    for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
+    try:
+      for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
+    except AssertionError as exc:
+      raise IndexError(f"shape mismatch: broadcasting not possible with index shapes {', '.join(str(i.shape) for i in idx)}") from exc

    # special permute case
    if tdim[0] != 0 and len(tdim) != 1 and tdim != list(range(tdim[0], tdim[-1]+1)):
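The "eq -> mul -> sum" loop is the core of fancy indexing here: each index tensor is compared against an arange over the axis it indexes, producing a one-hot selector that is multiplied into the data and summed away; a broadcast failure inside that loop surfaces as an AssertionError, which is what the new try/except rewraps as IndexError. A minimal NumPy sketch of the same gather-by-one-hot idea, with much simpler shapes than the reshapes above:

import numpy as np

x = np.arange(12).reshape(4, 3).astype(np.float32)
index = np.array([2, 0, 3])                      # fancy index into axis 0

# one-hot selector: arange over the indexed axis compared against the index
arange = np.arange(x.shape[0]).reshape(1, 4, 1)  # shape (1, 4, 1)
onehot = index.reshape(3, 1, 1) == arange        # shape (3, 4, 1), True where the row matches

# multiply and sum over the indexed axis: equivalent to x[index]
gathered = (onehot * x.reshape(1, 4, 3)).sum(axis=1)
assert np.array_equal(gathered, x[index])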