diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 55ef0c2fd0..9c61746564 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -685,8 +685,9 @@ class Kernel: # the living definition of intermediate UOps def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> None: - if not uop.has_st or uop in sts: return + if uop in sts: return # restore globals from the two stage reduce + # this is because this LOAD has an implicit movement op if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL: _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts) sts[uop] = sts[local_reduce] diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 0ef45d9ce9..36d964e886 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -187,7 +187,7 @@ def elementwise_view_right(root:UOp) -> UOp|None: # push the swizzle from src to root output_swizzle = swizzles[0] new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape) - ret = root.replace(src=tuple(x if not x.has_st else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src)) + ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src)) # NOTE: swizzle resolves once we hit STORE return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(output_swizzle.shape)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index cfd3344073..7e87f72611 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -273,8 +273,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop shape stuff *** - @property - def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR} @functools.cached_property def st(self) -> ShapeTracker|None: # these ops define a ShapeTracker from the arg @@ -295,7 +293,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ShapeTracker.from_shape(shape) @functools.cached_property def full_shape(self) -> tuple[sint, ...]: - return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st])) + if self.op is Ops.VIEW: return self.shape + # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops + return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL,Ops.DEFINE_VAR,Ops.CONST}])) @property def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape @property @@ -1310,7 +1310,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda view_left = merge_views+PatternMatcher([ # VIEW before elementwise ops (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), - lambda e,v: e.replace(src=tuple(s if not s.has_st else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))), + lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))), # early merge VIEW buffer ops (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), ])