more mops

This commit is contained in:
George Hotz
2025-10-14 17:20:04 +08:00
parent 59512a49fa
commit 8721b6884c

View File

@@ -229,13 +229,28 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
case Ops.CONST: return ()
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
#case Ops.RESHAPE:
#if prod(self.src[0].shape) != prod(self.arg): raise RuntimeError(f"bad reshape: {self.src[0].shape} -> {self.arg}")
#return self.arg
case Ops.PERMUTE:
if sorted(self.arg) != list(range(len(bs:=self.src[0].shape))): raise RuntimeError(f"invalid permutation {self.arg} of len {len(bs)}")
return tuple(bs[i] for i in self.arg)
# TODO: finish this and remove self.st.shape
if self.op in GroupOp.Movement:
ps = self.src[0].shape
match self.op:
#case Ops.RESHAPE:
#if prod(ps) != prod(self.arg): raise RuntimeError(f"bad reshape: {ps} -> {self.arg}")
#return self.arg
case Ops.EXPAND:
if len(ps) != len(self.arg) or not all(s==ns or s==1 for s,ns in zip(ps, self.arg)): raise RuntimeError(f"bad expand: {ps} -> {self.arg}")
return self.arg
case Ops.PERMUTE:
if sorted(self.arg) != list(range(len(ps))): raise RuntimeError(f"invalid permutation {self.arg} of len {len(ps)}")
return tuple(ps[i] for i in self.arg)
case Ops.PAD:
if not all(b>=0 and e>=0 for b,e in self.arg): raise RuntimeError(f"invalid pad {self.arg}")
return tuple(ssimplify(s+b+e) for s,(b,e) in zip(ps, self.arg))
case Ops.SHRINK:
# TODO: why do i need resolve here?
if not all(resolve(0<=b) and resolve(b<=e) and resolve(e<=s) for s,(b,e) in zip(ps, self.arg)):
raise RuntimeError(f"invalid shrink {self.arg} for {ps}")
return tuple(ssimplify(e-s) for s,e in self.arg)
# TODO: finish this and remove self.st.shape
assert self.st is not None, f"{self.op} doesn't have a shape"
return unwrap(self.st).shape
@property