support symbolic reshape with non-contiguous (#4844)

* support symbolic reshape with non-contiguous

pre-requisite for symbolic arange (make symbolic ones that can be folded).

* test cases

* typo

* shorter
This commit is contained in:
chenyu
2024-06-05 16:01:19 -04:00
committed by GitHub
parent a352b6d9ce
commit 99e7a1d5e9
3 changed files with 51 additions and 5 deletions

View File

@@ -182,8 +182,7 @@ class TestSymbolicJit(unittest.TestCase):
jf = TinyJit(f) jf = TinyJit(f)
for i in range(1, 5): for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i) vi = Variable("i", 1, 10).bind(i)
# TODO: without contiguous, the CONST shape are different in jit t = Tensor.ones(i)
t = Tensor.ones(i).contiguous()
symbolic = jf(t.reshape(vi)).item() symbolic = jf(t.reshape(vi)).item()
expected = f(t).item() expected = f(t).item()
np.testing.assert_equal(symbolic, expected) np.testing.assert_equal(symbolic, expected)

View File

@@ -102,7 +102,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),)) assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
assert var_val == {v: 2} assert var_val == {v: 2}
class TestSymbolicReshape(unittest.TestCase): class TestSymbolicReshapeFromContiguous(unittest.TestCase):
def test_reshape_into_symbols_simple(self): def test_reshape_into_symbols_simple(self):
for i in range(1, 6): for i in range(1, 6):
vi = Variable("i", 1, 5).bind(i) vi = Variable("i", 1, 5).bind(i)
@@ -151,6 +151,41 @@ class TestSymbolicReshape(unittest.TestCase):
new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64) new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
assert view.reshape(new_shape) is None assert view.reshape(new_shape) is None
class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
def test_reshape_from_const(self):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).reshape(3, vi)
assert t.shape == (3, vi)
assert not t.lazydata.st.contiguous
assert len(t.lazydata.st.views) == 1
def test_reshape_not_allowed(self):
vi = Variable("i", 1, 5).bind(4)
with self.assertRaises(ValueError):
# different shape length # TODO: cases where contractions matched might be fine
Tensor.ones(3, 4, 1).reshape(3, vi)
with self.assertRaises(ValueError):
# size matched, but dimensions do not match
Tensor.ones(4, 3).reshape(3, vi)
def test_reshape_from_padded(self):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
st = t.lazydata.st
assert len(st.views) == 1
view = st.views[0]
assert view.shape == (4, 3, 2)
t = t.reshape(vi, 3, 2)
st2 = t.lazydata.st
assert len(st2.views) == 1
view2 = st2.views[0]
# check only shape changed. strides, offset, mask, contiguous remained the same
assert view2.shape == (vi, 3, 2)
assert view.strides == view2.strides == (0, 4, 1)
assert view.offset == view2.offset == 1
assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
assert not view.contiguous and not view2.contiguous
class TestSymbolicExpand(unittest.TestCase): class TestSymbolicExpand(unittest.TestCase):
def test_expand_into_symbols(self): def test_expand_into_symbols(self):
vi = Variable("i", 1, 5).bind(3) vi = Variable("i", 1, 5).bind(3)

View File

@@ -3,7 +3,7 @@ import functools, operator, itertools, math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast from typing import Tuple, List, Optional, Dict, Set, cast
from tinygrad.helpers import prod, all_int, argsort from tinygrad.helpers import prod, all_int, argsort
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]: def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
@@ -268,7 +268,7 @@ class View:
assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}" assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
return View.create(new_shape) return View.create(new_shape)
# check for the same size # check for the same size
if all_int(self.shape): if (self_all_int := all_int(self.shape)):
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]): if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
@@ -278,6 +278,18 @@ class View:
# after the asserts, it's okay to check contiguous # after the asserts, it's okay to check contiguous
if self.contiguous: return View.create(new_shape) if self.contiguous: return View.create(new_shape)
# if it's not contiguous and new shape is symbolic, check if it's directly replaceable
if self_all_int and not all_int(new_shape):
if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
for si, so in zip(self.shape, new_shape):
if isinstance(so, int):
if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
else:
var_vals = {v: v.unbind()[1] for v in so.vars()}
if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
# all dimensions matched, return the new view directly
return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
strides, r_new_shape = [], reversed(new_shape) strides, r_new_shape = [], reversed(new_shape)
for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)): for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
acc = 1 acc = 1