mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
mostly works
This commit is contained in:
@@ -225,16 +225,35 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
@recursive_property
|
||||
def _shape(self) -> tuple[sint, ...]|None:
|
||||
# some ops init the shape
|
||||
match self.op:
|
||||
case Ops.CONST: return ()
|
||||
# some ops init the shape
|
||||
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE: return None
|
||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
|
||||
# passthrough ops
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD: return self.src[0]._shape
|
||||
|
||||
# ops with custom handling
|
||||
case Ops.STORE:
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
|
||||
if self.dtype is not dtypes.void: return self.src[0].src[0].shape
|
||||
return None
|
||||
case Ops.BITCAST:
|
||||
# TODO: disallow shape changing bitcast
|
||||
ps = self.src[0]._shape
|
||||
if ps is None: return None
|
||||
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
|
||||
return ps
|
||||
|
||||
# movement ops change the shape
|
||||
# NOTE: ssimplify is required because the shape needs to be canonical
|
||||
if self.op in GroupOp.Movement:
|
||||
ps = self.src[0].shape
|
||||
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
|
||||
ps = self.src[0]._shape
|
||||
if ps is None: raise RuntimeError(f"movement op {self.op} requires shape")
|
||||
match self.op:
|
||||
case Ops.RESHAPE:
|
||||
if prod(ps) != prod(self.arg): raise RuntimeError(f"bad reshape: {ps} -> {self.arg}")
|
||||
@@ -257,16 +276,26 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
case Ops.FLIP:
|
||||
if len(ps) != len(self.arg) or not all(isinstance(x, bool) for x in self.arg): raise RuntimeError(f"bad flip on {ps}, {self.arg}")
|
||||
return ps
|
||||
case Ops.MULTI: return tuple(s*len(self.device) if a == self.axis else s for a,s in enumerate(ps))
|
||||
case Ops.REDUCE_AXIS | Ops.WMMA:
|
||||
axis_arg = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
||||
if not isinstance(axis_arg, tuple) or not all(isinstance(x, int) for x in axis_arg):
|
||||
raise RuntimeError(f"invalid type for axis: {axis_arg}")
|
||||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||
|
||||
# elementwise ops keep the shape the same
|
||||
if self.op in GroupOp.Elementwise-{Ops.BITCAST}:
|
||||
# elementwise ops keep the shape the same; all sources that have a shape must have the same shape
|
||||
if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN}):
|
||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||
if len(input_shapes) == 0: return None
|
||||
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
||||
return input_shapes[0]
|
||||
|
||||
#raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")
|
||||
# keep old behavior and get from st
|
||||
if (st:=self.st) is None: return None
|
||||
if (st:=self.st) is None:
|
||||
#print(f"none on {self.op}")
|
||||
return None
|
||||
#print(f"proc on {self.op} {self.dtype} -> {st.shape}")
|
||||
return st.shape
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user