diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 15d8af7374..72aeadbdd0 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -319,9 +319,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def full_shape(self) -> tuple[sint, ...]: if self.op is Ops.VIEW: return self.shape - # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this - parent_shapes = [x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} and not (x.op is Ops.CONST and x.st is None)] - # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops + # NOTE: if a parent doesn't have st its full_shape is empty + parent_shapes = [x.full_shape for x in self.src] return tuple(smax(x) for x in zip(*[x for x in parent_shapes if x != ()])) @property def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape @@ -1031,4 +1030,4 @@ merge_views = PatternMatcher([ lambda v: v.const_like(0) if (mask:=v.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None), # movement ops apply a new view on the base (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)), -]) \ No newline at end of file +])