mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
support symbolic reshape with non-contiguous (#4844)
* support symbolic reshape with non-contiguous pre-requisite for symbolic arange (make symbolic ones that can be folded). * test cases * typo * shorter
This commit is contained in:
@@ -182,8 +182,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
jf = TinyJit(f)
|
||||
for i in range(1, 5):
|
||||
vi = Variable("i", 1, 10).bind(i)
|
||||
# TODO: without contiguous, the CONST shape are different in jit
|
||||
t = Tensor.ones(i).contiguous()
|
||||
t = Tensor.ones(i)
|
||||
symbolic = jf(t.reshape(vi)).item()
|
||||
expected = f(t).item()
|
||||
np.testing.assert_equal(symbolic, expected)
|
||||
|
||||
@@ -102,7 +102,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
|
||||
assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
|
||||
assert var_val == {v: 2}
|
||||
|
||||
class TestSymbolicReshape(unittest.TestCase):
|
||||
class TestSymbolicReshapeFromContiguous(unittest.TestCase):
|
||||
def test_reshape_into_symbols_simple(self):
|
||||
for i in range(1, 6):
|
||||
vi = Variable("i", 1, 5).bind(i)
|
||||
@@ -151,6 +151,41 @@ class TestSymbolicReshape(unittest.TestCase):
|
||||
new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
|
||||
assert view.reshape(new_shape) is None
|
||||
|
||||
class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
|
||||
def test_reshape_from_const(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
t = Tensor.ones(3, 4).reshape(3, vi)
|
||||
assert t.shape == (3, vi)
|
||||
assert not t.lazydata.st.contiguous
|
||||
assert len(t.lazydata.st.views) == 1
|
||||
|
||||
def test_reshape_not_allowed(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
with self.assertRaises(ValueError):
|
||||
# different shape length # TODO: cases where contractions matched might be fine
|
||||
Tensor.ones(3, 4, 1).reshape(3, vi)
|
||||
with self.assertRaises(ValueError):
|
||||
# size matched, but dimensions do not match
|
||||
Tensor.ones(4, 3).reshape(3, vi)
|
||||
|
||||
def test_reshape_from_padded(self):
|
||||
vi = Variable("i", 1, 5).bind(4)
|
||||
t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
|
||||
st = t.lazydata.st
|
||||
assert len(st.views) == 1
|
||||
view = st.views[0]
|
||||
assert view.shape == (4, 3, 2)
|
||||
t = t.reshape(vi, 3, 2)
|
||||
st2 = t.lazydata.st
|
||||
assert len(st2.views) == 1
|
||||
view2 = st2.views[0]
|
||||
# check only shape changed. strides, offset, mask, contiguous remained the same
|
||||
assert view2.shape == (vi, 3, 2)
|
||||
assert view.strides == view2.strides == (0, 4, 1)
|
||||
assert view.offset == view2.offset == 1
|
||||
assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2))
|
||||
assert not view.contiguous and not view2.contiguous
|
||||
|
||||
class TestSymbolicExpand(unittest.TestCase):
|
||||
def test_expand_into_symbols(self):
|
||||
vi = Variable("i", 1, 5).bind(3)
|
||||
|
||||
Reference in New Issue
Block a user