mostly works

This commit is contained in:
George Hotz
2025-10-14 18:19:02 +08:00
parent 04ead92ebd
commit 0b69698ad4

View File

@@ -225,16 +225,35 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@recursive_property
def _shape(self) -> tuple[sint, ...]|None:
# some ops init the shape
match self.op:
case Ops.CONST: return ()
# some ops init the shape
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE: return None
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
# passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD: return self.src[0]._shape
# ops with custom handling
case Ops.STORE:
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
if self.dtype is not dtypes.void: return self.src[0].src[0].shape
return None
case Ops.BITCAST:
# TODO: disallow shape changing bitcast
ps = self.src[0]._shape
if ps is None: return None
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
return ps
# movement ops change the shape
# NOTE: ssimplify is required because the shape needs to be canonical
if self.op in GroupOp.Movement:
ps = self.src[0].shape
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
ps = self.src[0]._shape
if ps is None: raise RuntimeError(f"movement op {self.op} requires shape")
match self.op:
case Ops.RESHAPE:
if prod(ps) != prod(self.arg): raise RuntimeError(f"bad reshape: {ps} -> {self.arg}")
@@ -257,16 +276,26 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
case Ops.FLIP:
if len(ps) != len(self.arg) or not all(isinstance(x, bool) for x in self.arg): raise RuntimeError(f"bad flip on {ps}, {self.arg}")
return ps
case Ops.MULTI: return tuple(s*len(self.device) if a == self.axis else s for a,s in enumerate(ps))
case Ops.REDUCE_AXIS | Ops.WMMA:
axis_arg = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
if not isinstance(axis_arg, tuple) or not all(isinstance(x, int) for x in axis_arg):
raise RuntimeError(f"invalid type for axis: {axis_arg}")
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
# elementwise ops keep the shape the same
if self.op in GroupOp.Elementwise-{Ops.BITCAST}:
# elementwise ops keep the shape the same. all with shape must match
if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN}):
input_shapes = [x._shape for x in self.src if x._shape is not None]
if len(input_shapes) == 0: return None
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
return input_shapes[0]
#raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")
# keep old behavior and get from st
if (st:=self.st) is None: return None
if (st:=self.st) is None:
#print(f"none on {self.op}")
return None
#print(f"proc on {self.op} {self.dtype} -> {st.shape}")
return st.shape
@property