support symbolic expand (#1407)

This commit is contained in:
chenyu
2023-08-02 17:03:46 -07:00
committed by GitHub
parent a367f71fea
commit 6572ca6835
3 changed files with 16 additions and 5 deletions

View File

@@ -83,4 +83,13 @@ class TestSymbolicReshape(unittest.TestCase):
with self.assertRaises(AssertionError):
t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4)
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
class TestSymbolicReshape(unittest.TestCase):
def test_expand_into_symbols(self):
vi = Variable("i", 1, 10)
a = Tensor([[1], [2], [3]]).expand((3, vi))
assert a.shape == (3, vi)
vj = Variable("j", 1, 10)
a = a.reshape(3, vi, 1).expand((3, vi, vj))
assert a.shape == (3, vi, vj)

View File

@@ -227,9 +227,9 @@ class ShapeTracker:
self.__unsafe_resize(arg)
return self
def expand(self, new_shape: Tuple[int, ...]) -> ShapeTracker:
def expand(self, new_shape: Tuple[Union[Node,int], ...]) -> ShapeTracker:
assert len(new_shape) == len(self.views[-1].shape)
assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
# NOTE: can the mask ever be (0,0)?
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None
self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask)

View File

@@ -39,7 +39,7 @@ class Node:
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
def __lt__(self, b:Union[Node,int]):
if self == b: return False
if self == b: return NumNode(0)
lhs = self
if isinstance(lhs, SumNode):
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
@@ -159,7 +159,9 @@ class OpNode(Node):
class LtNode(OpNode):
def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b)
def get_bounds(self) -> Tuple[int, int]:
if isinstance(self.b, int): return int(self.a.max < self.b), int(self.a.min < self.b)
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min > self.b.max else (0, 1)
class MulNode(OpNode):
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul