From 45083ccb43557c09c697c8fe13f5cdb4f9e1f38f Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 3 Jun 2024 13:37:37 -0400 Subject: [PATCH] canonicalize 0 in shape in View.create (#4815) set strides to 0, offset to 0, mask to None, and contiguous to True with size 0 view. --- test/test_tensor.py | 6 +++--- tinygrad/shape/view.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index 05c3d6b1a6..025731037d 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -379,17 +379,17 @@ class TestZeroShapeTensor(unittest.TestCase): t = Tensor.empty(3, 2, 0) assert t.shape == (3, 2, 0) # numpy has stride 0, 0, 0; torch has stride 2, 1, 1 - assert t.lazydata.st.real_strides() == (0, 0, 1) + assert t.lazydata.st.real_strides() == (0, 0, 0) t = Tensor.empty(3, 0, 2) assert t.shape == (3, 0, 2) # numpy has stride 0, 0, 0; torch has stride 2, 2, 1 - assert t.lazydata.st.real_strides() == (0, 2, 1) + assert t.lazydata.st.real_strides() == (0, 0, 0) t = Tensor.empty(0, 0, 0) assert t.shape == (0, 0, 0) # numpy has stride 0, 0, 0; torch has stride 1, 1, 1 - assert t.lazydata.st.real_strides() == (0, 0, 1) + assert t.lazydata.st.real_strides() == (0, 0, 0) def test_rand(self): t = Tensor.rand(3, 2, 0) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 92da2b5a35..077eb1a65b 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -96,6 +96,8 @@ class View: @functools.lru_cache(maxsize=None) def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None): strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape) + # canonicalize 0 in shape + if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True) # canonicalize empty mask if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked