mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
support symbolic reshape with non-contiguous (#4844)
* support symbolic reshape with non-contiguous pre-requisite for symbolic arange (make symbolic ones that can be folded). * test cases * typo * shorter
This commit is contained in:
@@ -182,8 +182,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
jf = TinyJit(f)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
# TODO: without contiguous, the CONST shape are different in jit
|
||||
t = Tensor.ones(i).contiguous()
|
||||
t = Tensor.ones(i)
|
||||
symbolic = jf(t.reshape(vi)).item()
|
||||
expected = f(t).item()
|
||||
np.testing.assert_equal(symbolic, expected)
|
||||
|
||||
@@ -102,7 +102,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
|
||||
assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
|
||||
assert var_val == {v: 2}
|
||||
|
||||
class TestSymbolicReshape(unittest.TestCase):
|
||||
class TestSymbolicReshapeFromContiguous(unittest.TestCase):
|
||||
def test_reshape_into_symbols_simple(self):
|
||||
for i in range(1, 6):
|
||||
vi = Variable("i", 1, 5).bind(i)
|
||||
@@ -151,6 +151,41 @@ class TestSymbolicReshape(unittest.TestCase):
|
||||
new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
|
||||
assert view.reshape(new_shape) is None
|
||||
|
||||
class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
|
||||
def test_reshape_from_const(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
t = Tensor.ones(3, 4).reshape(3, vi)
|
||||
assert t.shape == (3, vi)
|
||||
assert not t.lazydata.st.contiguous
|
||||
assert len(t.lazydata.st.views) == 1
|
||||
|
||||
def test_reshape_not_allowed(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
with self.assertRaises(ValueError):
|
||||
# different shape length # TODO: cases where contractions matched might be fine
|
||||
Tensor.ones(3, 4, 1).reshape(3, vi)
|
||||
with self.assertRaises(ValueError):
|
||||
# size matched, but dimensions do not match
|
||||
Tensor.ones(4, 3).reshape(3, vi)
|
||||
|
||||
def test_reshape_from_padded(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
|
||||
st = t.lazydata.st
|
||||
assert len(st.views) == 1
|
||||
view = st.views[0]
|
||||
assert view.shape == (4, 3, 2)
|
||||
t = t.reshape(vi, 3, 2)
|
||||
st2 = t.lazydata.st
|
||||
assert len(st2.views) == 1
|
||||
view2 = st2.views[0]
|
||||
# check only shape changed. strides, offset, mask, contiguous remained the same
|
||||
assert view2.shape == (vi, 3, 2)
|
||||
assert view.strides == view2.strides == (0, 4, 1)
|
||||
assert view.offset == view2.offset == 1
|
||||
assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
|
||||
assert not view.contiguous and not view2.contiguous
|
||||
|
||||
class TestSymbolicExpand(unittest.TestCase):
|
||||
def test_expand_into_symbols(self):
|
||||
vi = Variable("i", 1, 5).bind(3)
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools, operator, itertools, math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast
|
||||
from tinygrad.helpers import prod, all_int, argsort
|
||||
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint
|
||||
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
||||
@@ -268,7 +268,7 @@ class View:
|
||||
assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
|
||||
return View.create(new_shape)
|
||||
# check for the same size
|
||||
if all_int(self.shape):
|
||||
if (self_all_int := all_int(self.shape)):
|
||||
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
|
||||
if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
|
||||
raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
|
||||
@@ -278,6 +278,18 @@ class View:
|
||||
# after the asserts, it's okay to check contiguous
|
||||
if self.contiguous: return View.create(new_shape)
|
||||
|
||||
# if it's not contiguous and new shape is symbolic, check if it's directly replaceable
|
||||
if self_all_int and not all_int(new_shape):
|
||||
if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
for si, so in zip(self.shape, new_shape):
|
||||
if isinstance(so, int):
|
||||
if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
else:
|
||||
var_vals = {v: v.unbind()[1] for v in so.vars()}
|
||||
if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
# all dimensions matched, return the new view directly
|
||||
return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
|
||||
|
||||
strides, r_new_shape = [], reversed(new_shape)
|
||||
for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
|
||||
acc = 1
|
||||
|
||||
Reference in New Issue
Block a user