mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix tests
This commit is contained in:
@@ -228,7 +228,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL:
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
|
||||
return None
|
||||
|
||||
# some ops init the shape
|
||||
@@ -247,13 +247,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
|
||||
if self.dtype is not dtypes.void: return self.src[0].src[0].shape
|
||||
return None
|
||||
|
||||
# TODO: disallow shape changing bitcast
|
||||
case Ops.BITCAST:
|
||||
# TODO: disallow shape changing bitcast
|
||||
ps = self.src[0]._shape
|
||||
if ps is None: return None
|
||||
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
|
||||
return ps
|
||||
|
||||
# TODO: disallow reshape from nothing. tested by TestOpenClip.test_multigpu_clip_score
|
||||
case Ops.RESHAPE:
|
||||
if self.src[0]._shape is None: return tuple(ssimplify(s) for s in self.arg)
|
||||
|
||||
# movement ops change the shape. this is the logic from the old ShapeTracker
|
||||
# NOTE: ssimplify is required because the shape needs to be canonical
|
||||
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
|
||||
|
||||
Reference in New Issue
Block a user