From 99e7a1d5e9108e71b2d545e1f16eb50c93c553df Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 5 Jun 2024 16:01:19 -0400 Subject: [PATCH] support symbolic reshape with non-contiguous (#4844) * support symbolic reshape with non-contiguous pre-requisite for symbolic arange (make symbolic ones that can be folded). * test cases * typo * shorter --- test/test_symbolic_jit.py | 3 +-- test/test_symbolic_shapetracker.py | 37 +++++++++++++++++++++++++++++- tinygrad/shape/view.py | 16 +++++++++++-- 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index e0a55fe5b0..432c3dfa19 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -182,8 +182,7 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - # TODO: without contiguous, the CONST shape are different in jit - t = Tensor.ones(i).contiguous() + t = Tensor.ones(i) symbolic = jf(t.reshape(vi)).item() expected = f(t).item() np.testing.assert_equal(symbolic, expected) diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index ee1b5427fc..d1ffdfb4db 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -102,7 +102,7 @@ class TestShapeTrackerUnbind(unittest.TestCase): assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),)) assert var_val == {v: 2} -class TestSymbolicReshape(unittest.TestCase): +class TestSymbolicReshapeFromContiguous(unittest.TestCase): def test_reshape_into_symbols_simple(self): for i in range(1, 6): vi = Variable("i", 1, 5).bind(i) @@ -151,6 +151,41 @@ class TestSymbolicReshape(unittest.TestCase): new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64) assert view.reshape(new_shape) is None +class TestSymbolicReshapeFromNonContiguous(unittest.TestCase): + def test_reshape_from_const(self): + vi = Variable("i", 1, 5).bind(4) + t = Tensor.ones(3, 4).reshape(3, vi) + assert t.shape == (3, vi) + assert not t.lazydata.st.contiguous + assert len(t.lazydata.st.views) == 1 + + def test_reshape_not_allowed(self): + vi = Variable("i", 1, 5).bind(4) + with self.assertRaises(ValueError): + # different shape length # TODO: cases where contractions matched might be fine + Tensor.ones(3, 4, 1).reshape(3, vi) + with self.assertRaises(ValueError): + # size matched, but dimensions do not match + Tensor.ones(4, 3).reshape(3, vi) + + def test_reshape_from_padded(self): + vi = Variable("i", 1, 5).bind(4) + t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3))) + st = t.lazydata.st + assert len(st.views) == 1 + view = st.views[0] + assert view.shape == (4, 3, 2) + t = t.reshape(vi, 3, 2) + st2 = t.lazydata.st + assert len(st2.views) == 1 + view2 = st2.views[0] + # check only shape changed. strides, offset, mask, contiguous remained the same + assert view2.shape == (vi, 3, 2) + assert view.strides == view2.strides == (0, 4, 1) + assert view.offset == view2.offset == 1 + assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2)) + assert not view.contiguous and not view2.contiguous + class TestSymbolicExpand(unittest.TestCase): def test_expand_into_symbols(self): vi = Variable("i", 1, 5).bind(3) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 077eb1a65b..7995423a4d 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -3,7 +3,7 @@ import functools, operator, itertools, math from dataclasses import dataclass from typing import Tuple, List, Optional, Dict, Set, cast from tinygrad.helpers import prod, all_int, argsort -from tinygrad.shape.symbolic import Node, NumNode, Variable, sint +from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer @functools.lru_cache(maxsize=None) def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]: @@ -268,7 +268,7 @@ class View: assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}" return View.create(new_shape) # check for the same size - if all_int(self.shape): + if (self_all_int := all_int(self.shape)): assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") @@ -278,6 +278,18 @@ class View: # after the asserts, it's okay to check contiguous if self.contiguous: return View.create(new_shape) + # if it's not contiguous and new shape is symbolic, check if it's directly replaceable + if self_all_int and not all_int(new_shape): + if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") + for si, so in zip(self.shape, new_shape): + if isinstance(so, int): + if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") + else: + var_vals = {v: v.unbind()[1] for v in so.vars()} + if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") + # all dimensions matched, return the new view directly + return View(new_shape, self.strides, self.offset, self.mask, self.contiguous) + strides, r_new_shape = [], reversed(new_shape) for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)): acc = 1