# Patch summary (leading commentary; patch tools skip text before the first "diff --git" header).
#
# Files touched: tinygrad/codegen/kernel.py and tinygrad/ops.py.
#
# tinygrad/ops.py, class UOp:
#   * adds a `has_st` property that is False when the op is UOps.DEFINE_LOCAL or UOps.DEFINE_GLOBAL
#     and True for every other op.
#   * `st` now returns None via `if not self.has_st` instead of repeating the
#     {DEFINE_LOCAL, DEFINE_GLOBAL} set inline.
#   * the function-scope `ShapeTracker` import in `st` moves from the top of the method to directly
#     before the final return statement, which is the only visible statement that uses the name.
#
# tinygrad/codegen/kernel.py, _assert_valid_uop:
#   * the early return becomes `if not uop.has_st or uop in sts`, replacing the separate
#     DEFINE_LOCAL/DEFINE_GLOBAL check and its NOTE comment.
#   * for ops other than SHAPETRACKER, the reference ShapeTracker is taken from `src[uop.st_loc]`
#     for BUFFER_UOPS and from `src[0]` otherwise (the removed line used `src[-1]`).
#   * the shape check now iterates over every entry of `src` and skips entries whose `has_st` is
#     False (the removed lines skipped `src[0]` for BUFFER_UOPS by slicing `src[1:]`).
#
# NOTE(review): the hunks below appear with their line breaks and indentation collapsed; the text
# is kept exactly as found. Restore the original line structure before trying to apply the patch.
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index cc3b85dfd4..b4357b9cf3 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -777,10 +777,8 @@ class Kernel: # the living definition of intermediate UOps def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None: - if uop in sts: return + if not uop.has_st or uop in sts: return op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg - # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape - if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return # restore globals from the two stage reduce if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL: _assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts) @@ -794,9 +792,9 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}" # movementops are pushed to the edges with SHAPETRACKER # elementwise inherits shape - st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else -1]] - for x in (src[1:] if op in BUFFER_UOPS else src): - if sts[x].shape != st.shape: + st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]] + for x in src: + if x.has_st and sts[x].shape != st.shape: if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}") raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}") sts[uop] = st diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ddc4442125..f013fa7f27 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -143,14 +143,16 @@ class UOp(MathTrait): self.op, self.dtype, self.src, self.arg = op, dtype, src, arg def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: 
Optional[Tuple[UOp,...]]=None, arg:Any=None): return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg) + @property + def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL} @functools.cached_property def st(self) -> Optional[ShapeTracker]: - from tinygrad.shape.shapetracker import ShapeTracker - if self.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return None + if not self.has_st: return None if self.op in BUFFER_UOPS: return self.st_arg if self.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: return self.arg src_sts = [x.st for x in self.src if x.st is not None] assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}" + from tinygrad.shape.shapetracker import ShapeTracker return ShapeTracker.from_shape(src_sts[0].reduce(self.arg[1])) if self.op is UOps.REDUCE_AXIS else src_sts[0] @functools.cached_property def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]: