support symbolic reshape with non-contiguous (#4844)

* support symbolic reshape with non-contiguous

pre-requisite for symbolic arange (make symbolic ones that can be folded).

* test cases

* typo

* shorter
This commit is contained in:
chenyu
2024-06-05 16:01:19 -04:00
committed by GitHub
parent a352b6d9ce
commit 99e7a1d5e9
3 changed files with 51 additions and 5 deletions

View File

@@ -182,8 +182,7 @@ class TestSymbolicJit(unittest.TestCase):
jf = TinyJit(f)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
# TODO: without contiguous, the CONST shape are different in jit
t = Tensor.ones(i).contiguous()
t = Tensor.ones(i)
symbolic = jf(t.reshape(vi)).item()
expected = f(t).item()
np.testing.assert_equal(symbolic, expected)

View File

@@ -102,7 +102,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
assert var_val == {v: 2}
class TestSymbolicReshape(unittest.TestCase):
class TestSymbolicReshapeFromContiguous(unittest.TestCase):
def test_reshape_into_symbols_simple(self):
for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i)
@@ -151,6 +151,41 @@ class TestSymbolicReshape(unittest.TestCase):
new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
assert view.reshape(new_shape) is None
class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
def test_reshape_from_const(self):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).reshape(3, vi)
assert t.shape == (3, vi)
assert not t.lazydata.st.contiguous
assert len(t.lazydata.st.views) == 1
def test_reshape_not_allowed(self):
vi = Variable("i", 1, 5).bind(4)
with self.assertRaises(ValueError):
# different shape length # TODO: cases where contractions matched might be fine
Tensor.ones(3, 4, 1).reshape(3, vi)
with self.assertRaises(ValueError):
# size matched, but dimensions do not match
Tensor.ones(4, 3).reshape(3, vi)
def test_reshape_from_padded(self):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
st = t.lazydata.st
assert len(st.views) == 1
view = st.views[0]
assert view.shape == (4, 3, 2)
t = t.reshape(vi, 3, 2)
st2 = t.lazydata.st
assert len(st2.views) == 1
view2 = st2.views[0]
# check only shape changed. strides, offset, mask, contiguous remained the same
assert view2.shape == (vi, 3, 2)
assert view.strides == view2.strides == (0, 4, 1)
assert view.offset == view2.offset == 1
assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
assert not view.contiguous and not view2.contiguous
class TestSymbolicExpand(unittest.TestCase):
def test_expand_into_symbols(self):
vi = Variable("i", 1, 5).bind(3)

View File

@@ -3,7 +3,7 @@ import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast
from tinygrad.helpers import prod, all_int, argsort
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
@functools.lru_cache(maxsize=None)
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
@@ -268,7 +268,7 @@ class View:
assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
return View.create(new_shape)
# check for the same size
if all_int(self.shape):
if (self_all_int := all_int(self.shape)):
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
@@ -278,6 +278,18 @@ class View:
# after the asserts, it's okay to check contiguous
if self.contiguous: return View.create(new_shape)
# if it's not contiguous and new shape is symbolic, check if it's directly replaceable
if self_all_int and not all_int(new_shape):
if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
for si, so in zip(self.shape, new_shape):
if isinstance(so, int):
if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
else:
var_vals = {v: v.unbind()[1] for v in so.vars()}
if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
# all dimensions matched, return the new view directly
return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
strides, r_new_shape = [], reversed(new_shape)
for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
acc = 1