From 6572ca68350101bcb7eee165e92a8d223fca3597 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 2 Aug 2023 17:03:46 -0700 Subject: [PATCH] support symbolic expand (#1407) --- test/test_symbolic_shapetracker.py | 11 ++++++++++- tinygrad/shape/shapetracker.py | 4 ++-- tinygrad/shape/symbolic.py | 6 ++++-- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index 3f5182615e..113b74e5b3 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -83,4 +83,13 @@ class TestSymbolicReshape(unittest.TestCase): with self.assertRaises(AssertionError): t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4) with self.assertRaises(AssertionError): - t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4) \ No newline at end of file + t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4) + +class TestSymbolicReshape(unittest.TestCase): + def test_expand_into_symbols(self): + vi = Variable("i", 1, 10) + a = Tensor([[1], [2], [3]]).expand((3, vi)) + assert a.shape == (3, vi) + vj = Variable("j", 1, 10) + a = a.reshape(3, vi, 1).expand((3, vi, vj)) + assert a.shape == (3, vi, vj) \ No newline at end of file diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index af697f90aa..d5fa8c4e2a 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -227,9 +227,9 @@ class ShapeTracker: self.__unsafe_resize(arg) return self - def expand(self, new_shape: Tuple[int, ...]) -> ShapeTracker: + def expand(self, new_shape: Tuple[Union[Node,int], ...]) -> ShapeTracker: assert len(new_shape) == len(self.views[-1].shape) - assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}" + assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}" # NOTE: can the mask ever be (0,0)? mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask) diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index bd918eede5..05d4f8be98 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -39,7 +39,7 @@ class Node: def __gt__(self, b:Union[Node,int]): return (-self) < (-b) def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1) def __lt__(self, b:Union[Node,int]): - if self == b: return False + if self == b: return NumNode(0) lhs = self if isinstance(lhs, SumNode): muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b) @@ -159,7 +159,9 @@ class OpNode(Node): class LtNode(OpNode): def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b) def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b) - def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b) + def get_bounds(self) -> Tuple[int, int]: + if isinstance(self.b, int): return int(self.a.max < self.b), int(self.a.min < self.b) + return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min > self.b.max else (0, 1) class MulNode(OpNode): def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul