work on shape property

This commit is contained in:
George Hotz
2025-10-14 16:50:43 +08:00
parent fb61f3519f
commit a73b59caa2

View File

@@ -225,6 +225,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def shape(self) -> tuple[sint, ...]:
match self.op:
case Ops.CONST: return ()
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
case Ops.RESHAPE:
if prod(self.src[0].shape) != prod(self.arg): raise RuntimeError(f"bad reshape: {self.src[0].shape} -> {self.arg}")
return self.arg
case Ops.PERMUTE:
if sorted(self.arg) != list(range(len(bs:=self.src[0].shape))): raise RuntimeError(f"invalid permutation {self.arg} of len {len(bs)}")
return tuple(bs[i] for i in self.arg)
# TODO: finish this and remove self.st.shape
assert self.st is not None, f"{self.op} doesn't have a shape"
return unwrap(self.st).shape
@property