fix tests

This commit is contained in:
George Hotz
2025-10-14 19:22:16 +08:00
parent 61855c24a8
commit 7a2e206a0d

View File

@@ -228,7 +228,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | \
Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL:
Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
return None
# some ops init the shape
@@ -247,13 +247,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
if self.dtype is not dtypes.void: return self.src[0].src[0].shape
return None
# TODO: disallow shape changing bitcast
case Ops.BITCAST:
# TODO: disallow shape changing bitcast
ps = self.src[0]._shape
if ps is None: return None
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
return ps
# TODO: disallow reshape from nothing. tested by TestOpenClip.test_multigpu_clip_score
case Ops.RESHAPE:
if self.src[0]._shape is None: return tuple(ssimplify(s) for s in self.arg)
# movement ops change the shape. this is the logic from the old ShapeTracker
# NOTE: ssimplify is required because the shape needs to be canonical
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):