mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 09:28:04 -05:00
support symbolic expand (#1407)
This commit is contained in:
@@ -83,4 +83,13 @@ class TestSymbolicReshape(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4)
|
||||
with self.assertRaises(AssertionError):
|
||||
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
|
||||
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
|
||||
|
||||
class TestSymbolicReshape(unittest.TestCase):
|
||||
def test_expand_into_symbols(self):
|
||||
vi = Variable("i", 1, 10)
|
||||
a = Tensor([[1], [2], [3]]).expand((3, vi))
|
||||
assert a.shape == (3, vi)
|
||||
vj = Variable("j", 1, 10)
|
||||
a = a.reshape(3, vi, 1).expand((3, vi, vj))
|
||||
assert a.shape == (3, vi, vj)
|
||||
@@ -227,9 +227,9 @@ class ShapeTracker:
|
||||
self.__unsafe_resize(arg)
|
||||
return self
|
||||
|
||||
def expand(self, new_shape: Tuple[int, ...]) -> ShapeTracker:
|
||||
def expand(self, new_shape: Tuple[Union[Node,int], ...]) -> ShapeTracker:
|
||||
assert len(new_shape) == len(self.views[-1].shape)
|
||||
assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
|
||||
assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
|
||||
# NOTE: can the mask ever be (0,0)?
|
||||
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None
|
||||
self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask)
|
||||
|
||||
@@ -39,7 +39,7 @@ class Node:
|
||||
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
|
||||
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
|
||||
def __lt__(self, b:Union[Node,int]):
|
||||
if self == b: return False
|
||||
if self == b: return NumNode(0)
|
||||
lhs = self
|
||||
if isinstance(lhs, SumNode):
|
||||
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
||||
@@ -159,7 +159,9 @@ class OpNode(Node):
|
||||
class LtNode(OpNode):
|
||||
def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
|
||||
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
|
||||
def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
if isinstance(self.b, int): return int(self.a.max < self.b), int(self.a.min < self.b)
|
||||
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min > self.b.max else (0, 1)
|
||||
|
||||
class MulNode(OpNode):
|
||||
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
|
||||
|
||||
Reference in New Issue
Block a user