diff --git a/test/test_edgecases.py b/test/test_edgecases.py index ae4e26b489..d675442162 100644 --- a/test/test_edgecases.py +++ b/test/test_edgecases.py @@ -62,7 +62,6 @@ class TestNaNEdgeCases(unittest.TestCase): class TestEmptyTensorEdgeCases(unittest.TestCase): # we don't need more of these - @unittest.expectedFailure def test_sort_empty(self): # Sorting an empty tensor works in PyTorch and should return empty # values and indices. tinygrad raises an error instead. diff --git a/test/test_ops.py b/test/test_ops.py index fdd81c7b9e..f8d5fe280a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1093,6 +1093,9 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]]) def test_sort(self): + for shape in [(0,), (0,5), (1,), (1,5)]: + helper_test_op([shape], lambda x: x.sort(0).values, lambda x: x.sort(0)[0], forward_only=True) + helper_test_op([shape], lambda x: x.sort(0).indices.type(torch.int32), lambda x: x.sort(0)[1], forward_only=True) for dim in [-1, 0, 1]: for descending in [True, False]: helper_test_op([(8,8,6)], lambda x: x.sort(dim, descending).values, lambda x: x.sort(dim, descending)[0], forward_only=True) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f66102bc16..b9d71305d0 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2845,8 +2845,8 @@ class Tensor(MathTrait): ``` """ x, dim = self, self._resolve_dim(dim) + if (orig_len:= x.shape[dim]) <= 1: return x, x.zeros_like(dtype=dtypes.default_int) # pad to power of 2 - orig_len = x.shape[dim] n_stages = math.ceil(math.log2(orig_len)) fill_value = dtypes.min(x.dtype) if descending else dtypes.max(x.dtype) pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))