mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
canonicalize 0 in shape in View.create (#4815)
set strides to 0, offset to 0, mask to None, and contiguous to True with size 0 view.
This commit is contained in:
@@ -379,17 +379,17 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
t = Tensor.empty(3, 2, 0)
|
||||
assert t.shape == (3, 2, 0)
|
||||
# numpy has stride 0, 0, 0; torch has stride 2, 1, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 1)
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 0)
|
||||
|
||||
t = Tensor.empty(3, 0, 2)
|
||||
assert t.shape == (3, 0, 2)
|
||||
# numpy has stride 0, 0, 0; torch has stride 2, 2, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 2, 1)
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 0)
|
||||
|
||||
t = Tensor.empty(0, 0, 0)
|
||||
assert t.shape == (0, 0, 0)
|
||||
# numpy has stride 0, 0, 0; torch has stride 1, 1, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 1)
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 0)
|
||||
|
||||
def test_rand(self):
|
||||
t = Tensor.rand(3, 2, 0)
|
||||
|
||||
@@ -96,6 +96,8 @@ class View:
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
|
||||
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
|
||||
# canonicalize 0 in shape
|
||||
if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
|
||||
# canonicalize empty mask
|
||||
if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
|
||||
# if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
|
||||
|
||||
Reference in New Issue
Block a user