need to cache it

This commit is contained in:
George Hotz
2025-10-14 17:35:29 +08:00
parent a659cb18a4
commit faddebef07
2 changed files with 18 additions and 6 deletions

View File

@@ -177,8 +177,8 @@ class TestSymbolicExpand(unittest.TestCase):
vi = Variable("i", 1, 10).bind(3)
a = Tensor(1).unsqueeze(0).pad((0, 24)).unsqueeze(0).expand((vi, 25))
self.assertEqual(a.shape, (vi, 25))
self.assertIn(a.reshape(25*vi).shape, {(vi*25,), (25*vi,)})
self.assertIn(a.reshape(vi*25).shape, {(vi*25,), (25*vi,)})
self.assertEqual(a.reshape(25*vi).shape, (vi*25,))
self.assertEqual(a.reshape(vi*25).shape, (vi*25,))
class TestSymbolicShrink(unittest.TestCase):
def test_shrink_symbols_simple(self):

View File

@@ -223,13 +223,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
shape = tuple(1 if i in axis_arg else s for i,s in enumerate(shape))
return ShapeTracker.from_shape(shape)
@property
@functools.cached_property
def shape(self) -> tuple[sint, ...]:
# some ops init the shape
match self.op:
case Ops.CONST: return ()
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
# movement ops change the shape
# NOTE: ssimplify is required because the shape needs to be canonical
if self.op in GroupOp.Movement:
ps = self.src[0].shape
match self.op:
@@ -243,14 +246,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if sorted(self.arg) != list(range(len(ps))): raise RuntimeError(f"invalid permutation {self.arg} of len {len(ps)}")
return tuple(ps[i] for i in self.arg)
case Ops.PAD:
if not all(b>=0 and e>=0 for b,e in self.arg): raise RuntimeError(f"invalid pad {self.arg}")
if len(ps) != len(self.arg) or not all(b>=0 and e>=0 for b,e in self.arg): raise RuntimeError(f"invalid pad {self.arg}")
return tuple(ssimplify(s+b+e) for s,(b,e) in zip(ps, self.arg))
case Ops.SHRINK:
# TODO: why do i need resolve here?
if not all(resolve(0<=b) and resolve(b<=e) and resolve(e<=s) for s,(b,e) in zip(ps, self.arg)):
if len(ps) != len(self.arg) or not all(resolve(0<=b) and resolve(b<=e) and resolve(e<=s) for s,(b,e) in zip(ps, self.arg)):
raise RuntimeError(f"invalid shrink {self.arg} for {ps}")
return tuple(ssimplify(e-s) for s,e in self.arg)
case Ops.FLIP: return ps
case Ops.FLIP:
if len(ps) != len(self.arg) or not all(isinstance(x, bool) for x in self.arg): raise RuntimeError(f"bad flip on {ps}, {self.arg}")
return ps
# elementwise ops keep the shape the same
if self.op in GroupOp.Elementwise-{Ops.BITCAST}:
input_shapes = [x.shape for x in self.src]
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
return input_shapes[0]
# TODO: finish this and remove self.st.shape
assert self.st is not None, f"{self.op} doesn't have a shape"
return unwrap(self.st).shape