support symbolic reshape with non-contiguous (#4844)
* support symbolic reshape with non-contiguous

  pre-requisite for symbolic arange (make symbolic ones that can be folded).

* test cases
* typo
* shorter
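In short: a reshape into a symbolic shape now also works when the source view is non-contiguous (for example the broadcast CONST behind Tensor.ones), as long as every dimension matches the symbolic dimension's bound value. A minimal sketch of the new behavior, adapted from the test cases in this commit:

    from tinygrad import Tensor
    from tinygrad.shape.symbolic import Variable

    vi = Variable("i", 1, 5).bind(4)      # symbolic dimension, currently bound to 4
    t = Tensor.ones(3, 4).reshape(3, vi)  # ones() is non-contiguous; this used to require .contiguous()
    assert t.shape == (3, vi)
    assert not t.lazydata.st.contiguous   # the existing view is reused, no copy is forced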
@@ -182,8 +182,7 @@ class TestSymbolicJit(unittest.TestCase):
     jf = TinyJit(f)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      # TODO: without contiguous, the CONST shape are different in jit
-      t = Tensor.ones(i).contiguous()
+      t = Tensor.ones(i)
       symbolic = jf(t.reshape(vi)).item()
       expected = f(t).item()
       np.testing.assert_equal(symbolic, expected)
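Tensor.ones is lazily backed by a single expanded CONST view (stride 0, hence non-contiguous), which is exactly the case the new fast path in View.reshape covers, so the jit test can drop the .contiguous() workaround. A quick way to confirm the starting point, using the same introspection the tests rely on:

    from tinygrad import Tensor

    t = Tensor.ones(4)
    assert not t.lazydata.st.contiguous  # broadcast CONST: stride 0, not a real buffer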
@@ -102,7 +102,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
     assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
     assert var_val == {v: 2}
 
-class TestSymbolicReshape(unittest.TestCase):
+class TestSymbolicReshapeFromContiguous(unittest.TestCase):
   def test_reshape_into_symbols_simple(self):
     for i in range(1, 6):
       vi = Variable("i", 1, 5).bind(i)
@@ -151,6 +151,41 @@ class TestSymbolicReshape(unittest.TestCase):
     new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
     assert view.reshape(new_shape) is None
 
+class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
+  def test_reshape_from_const(self):
+    vi = Variable("i", 1, 5).bind(4)
+    t = Tensor.ones(3, 4).reshape(3, vi)
+    assert t.shape == (3, vi)
+    assert not t.lazydata.st.contiguous
+    assert len(t.lazydata.st.views) == 1
+
+  def test_reshape_not_allowed(self):
+    vi = Variable("i", 1, 5).bind(4)
+    with self.assertRaises(ValueError):
+      # different shape length  # TODO: cases where contractions matched might be fine
+      Tensor.ones(3, 4, 1).reshape(3, vi)
+    with self.assertRaises(ValueError):
+      # size matched, but dimensions do not match
+      Tensor.ones(4, 3).reshape(3, vi)
+
+  def test_reshape_from_padded(self):
+    vi = Variable("i", 1, 5).bind(4)
+    t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
+    st = t.lazydata.st
+    assert len(st.views) == 1
+    view = st.views[0]
+    assert view.shape == (4, 3, 2)
+    t = t.reshape(vi, 3, 2)
+    st2 = t.lazydata.st
+    assert len(st2.views) == 1
+    view2 = st2.views[0]
+    # check only shape changed. strides, offset, mask, contiguous remained the same
+    assert view2.shape == (vi, 3, 2)
+    assert view.strides == view2.strides == (0, 4, 1)
+    assert view.offset == view2.offset == 1
+    assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
+    assert not view.contiguous and not view2.contiguous
+
 class TestSymbolicExpand(unittest.TestCase):
   def test_expand_into_symbols(self):
     vi = Variable("i", 1, 5).bind(3)
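For reference, the view asserted in test_reshape_from_padded can be derived by tracking each movement op; a sketch of the bookkeeping (the final numbers are exactly what the asserts check):

    # ones(3, 4).contiguous()       -> shape (3, 4),    strides (4, 1),    offset 0
    # .expand(2, 3, 4)              -> shape (2, 3, 4), strides (0, 4, 1)  (broadcast dim gets stride 0)
    # .pad(((1, 1), None, None))    -> shape (4, 3, 4), mask ((1, 3), (0, 3), (0, 4))
    # .shrink((None, None, (1, 3))) -> shape (4, 3, 2), offset 0 + 1*strides[2] = 1, mask on dim 2 -> (0, 2)

The reshape to (vi, 3, 2) then only swaps the shape tuple: strides, offset, and mask carry over verbatim, which is what the view-vs-view2 asserts verify.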
@@ -3,7 +3,7 @@ import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast
 from tinygrad.helpers import prod, all_int, argsort
-from tinygrad.shape.symbolic import Node, NumNode, Variable, sint
+from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
 
 @functools.lru_cache(maxsize=None)
 def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
@@ -268,7 +268,7 @@ class View:
       assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
       return View.create(new_shape)
     # check for the same size
-    if all_int(self.shape):
+    if (self_all_int := all_int(self.shape)):
       assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
       if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
         raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
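The walrus capture here is not just style: `self_all_int` is reused by the non-contiguous fast path added in the next hunk, saving a second `all_int(self.shape)` call.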
@@ -278,6 +278,18 @@ class View:
     # after the asserts, it's okay to check contiguous
     if self.contiguous: return View.create(new_shape)
 
+    # if it's not contiguous and new shape is symbolic, check if it's directly replaceable
+    if self_all_int and not all_int(new_shape):
+      if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+      for si, so in zip(self.shape, new_shape):
+        if isinstance(so, int):
+          if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+        else:
+          var_vals = {v: v.unbind()[1] for v in so.vars()}
+          if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+      # all dimensions matched, return the new view directly
+      return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
+
     strides, r_new_shape = [], reversed(new_shape)
     for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
       acc = 1
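The per-dimension check resolves each symbolic dim to its currently bound value with sym_infer and compares it against the concrete size; only when every pair matches is the shape tuple swapped, keeping strides, offset, and mask verbatim (any reordering or merging of dims would invalidate them, hence the strict pairwise rule). The same check in isolation, reusing the exact expressions from the diff:

    from tinygrad.shape.symbolic import Variable, sym_infer

    so = Variable("i", 1, 5).bind(4)                  # candidate symbolic dim
    var_vals = {v: v.unbind()[1] for v in so.vars()}  # {i: 4}, built as in the diff
    assert sym_infer(so, var_vals) == 4               # matches a concrete dim of size 4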