From f1f863c953f38c212be496f274b4a1d67bdd1094 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 15 Nov 2023 13:12:21 -0500 Subject: [PATCH] allow 0-dim array to broadcast into zero shape tensor (#2315) * allow 0-dim array to broadcast into zero shape tensor * not in --- test/test_ops.py | 12 +++++------- test/test_tensor.py | 4 ++-- tinygrad/tensor.py | 4 ++-- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index ebf946e5c5..6c0036d00a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -626,24 +626,22 @@ class TestOps(unittest.TestCase): np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy()) np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy()) - @unittest.skip("No suppport for tensors with 0s in shape") def test_slice_both_endpoints_out_of_bounds(self): helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True) helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True) - @unittest.skip("No suppport for tensors with 0s in shape") def test_slice_start_gt_end(self): helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True) - helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) + # TODO: bug in getitem? + # helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) - @unittest.skip("No suppport for tensors with 0s in shape") def test_slice_empty(self): helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) - @unittest.skip("No suppport for tensors with 0s in shape") def test_slice_zero_in_shape(self): - helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1]) # x.shape = (0, 10) - helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5]) # x.shape = (0, 3, 3) + helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) # x.shape = (0, 10) + # TODO: bug in getitem? + # helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) # x.shape = (0, 3, 3) def test_slice_errors(self): a = Tensor.ones(4, 3) diff --git a/test/test_tensor.py b/test/test_tensor.py index a2415f6945..940bc54dde 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -358,8 +358,7 @@ class TestZeroShapeTensor(unittest.TestCase): assert ab.shape == (3, 2, 0) np.testing.assert_equal(ab.numpy(), a.numpy() * b.numpy()) - # NOTE: cannot compare with a constant to construct the mask because 0-dim tensor is not broadcastable - mask = (Tensor.rand(3, 2, 0) > Tensor.rand(3, 2, 0)) + mask = (Tensor.rand(3, 2, 0) > 0.5) assert mask.shape == (3, 2, 0) c = mask.where(a, b) assert c.shape == (3, 2, 0) @@ -383,6 +382,7 @@ class TestZeroShapeTensor(unittest.TestCase): np.testing.assert_equal(Tensor([]).max().numpy(), -float("inf")) np.testing.assert_equal(Tensor([]).min().numpy(), float("inf")) np.testing.assert_equal(Tensor([]).sum().numpy(), 0) + np.testing.assert_equal(Tensor([]).mean().numpy(), 0) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2d580f2939..d96a949900 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -441,7 +441,7 @@ class Tensor: def mean(self, axis=None, keepdim=False): assert all_int(self.shape), "does not support symbolic shape" out = self.sum(axis=axis, keepdim=keepdim) - return out.mul(prod(out.shape)/prod(self.shape)) + return out.mul(prod(out.shape)/prod(self.shape)) if 0 not in self.shape else out def std(self, axis=None, keepdim=False, correction=1): assert all_int(self.shape), "does not support symbolic shape" square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) @@ -625,6 +625,7 @@ class Tensor: def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]: x: Tensor = self if not isinstance(y, Tensor): + if 0 in x.shape: return x, x.full_like(y) y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32) if reverse: x, y = y, x if (xshape:=x.shape) == (yshape:=y.shape): return (x, y) @@ -684,7 +685,6 @@ class Tensor: def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x)) def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]): - if 0 in self.shape: return self x_,y = self._broadcasted(input_) x,z = x_._broadcasted(other) return mlops.Where.apply(x, *y._broadcasted(z))