diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index aa333da5fb..c6220bb42d 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -225,16 +225,35 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   @recursive_property
   def _shape(self) -> tuple[sint, ...]|None:
-    # some ops init the shape
     match self.op:
-      case Ops.CONST: return ()
+      # some ops init the shape
+      case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE: return None
+      case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
       case Ops.BUFFER: return (self.arg,)
       case Ops.BUFFER_VIEW: return (self.arg[0],)
+      case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
+      case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
+
+      # passthrough ops
+      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD: return self.src[0]._shape
+
+      # ops with custom handling
+      case Ops.STORE:
+        if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
+        if self.dtype is not dtypes.void: return self.src[0].src[0].shape
+        return None
+      case Ops.BITCAST:
+        # TODO: disallow shape changing bitcast
+        ps = self.src[0]._shape
+        if ps is None: return None
+        if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
+        return ps
 
     # movement ops change the shape
     # NOTE: ssimplify is required because the shape needs to be canonical
-    if self.op in GroupOp.Movement:
-      ps = self.src[0].shape
+    if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
+      ps = self.src[0]._shape
+      if ps is None: raise RuntimeError(f"movement op {self.op} requires shape")
       match self.op:
         case Ops.RESHAPE:
           if prod(ps) != prod(self.arg): raise RuntimeError(f"bad reshape: {ps} -> {self.arg}")
@@ -257,16 +276,26 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
         case Ops.FLIP:
           if len(ps) != len(self.arg) or not all(isinstance(x, bool) for x in self.arg): raise RuntimeError(f"bad flip on {ps}, {self.arg}")
           return ps
+        case Ops.MULTI: return tuple(s*len(self.device) if a == self.axis else s for a,s in enumerate(ps))
+        case Ops.REDUCE_AXIS | Ops.WMMA:
+          axis_arg = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
+          if not isinstance(axis_arg, tuple) or not all(isinstance(x, int) for x in axis_arg):
+            raise RuntimeError(f"invalid type for axis: {axis_arg}")
+          return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
 
-    # elementwise ops keep the shape the same
-    if self.op in GroupOp.Elementwise-{Ops.BITCAST}:
+    # elementwise ops keep the shape the same. all with shape must match
+    if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN}):
       input_shapes = [x._shape for x in self.src if x._shape is not None]
       if len(input_shapes) == 0: return None
       if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
       return input_shapes[0]
 
+    #raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")
     # keep old behavior and get from st
-    if (st:=self.st) is None: return None
+    if (st:=self.st) is None:
+      #print(f"none on {self.op}")
+      return None
+    #print(f"proc on {self.op} {self.dtype} -> {st.shape}")
     return st.shape
 
   @property