UOp st prereqs for valid [run_process_replay] (#6618)

This commit is contained in:
qazal
2024-09-20 15:55:35 +08:00
committed by GitHub
parent 74f8f86631
commit 2dfb1e022c
2 changed files with 8 additions and 8 deletions

View File

@@ -777,10 +777,8 @@ class Kernel:
# the living definition of intermediate UOps
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
if uop in sts: return
if not uop.has_st or uop in sts: return
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
# restore globals from the two stage reduce
if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
@@ -794,9 +792,9 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
# movementops are pushed to the edges with SHAPETRACKER
# elementwise inherits shape
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else -1]]
for x in (src[1:] if op in BUFFER_UOPS else src):
if sts[x].shape != st.shape:
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]]
for x in src:
if x.has_st and sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
sts[uop] = st

View File

@@ -143,14 +143,16 @@ class UOp(MathTrait):
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
# Build a new UOp from this one, overriding only the fields that are passed in.
# op and dtype fall back to the current value when the argument is falsy (`or`);
# src and arg fall back only when the argument is None.
# NOTE: since None means "keep the current value", this cannot be used to reset arg to None.
def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)
# True for every op except DEFINE_LOCAL and DEFINE_GLOBAL.
@property
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}
# ShapeTracker for this UOp, or None for ops that do not have one.
# Declared as a cached_property, so the value is computed on first access and then reused.
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
# function-scope import; this view shows it both here and again right before the final return
from tinygrad.shape.shapetracker import ShapeTracker
# DEFINE_LOCAL / DEFINE_GLOBAL: no ShapeTracker
if self.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return None
# the same check expressed through has_st
if not self.has_st: return None
# buffer ops return their st_arg
if self.op in BUFFER_UOPS: return self.st_arg
# for SHAPETRACKER and SWIZZLE the arg itself is returned
if self.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: return self.arg
# every other op derives its result from the sources whose st is not None
src_sts = [x.st for x in self.src if x.st is not None]
# all contributing sources must agree on shape
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
from tinygrad.shape.shapetracker import ShapeTracker
# REDUCE_AXIS: new ShapeTracker built from src_sts[0].reduce(self.arg[1]) (arg[1] is presumably the reduce axes -- TODO confirm);
# any other op returns the first source's st unchanged.
# NOTE(review): src_sts[0] raises IndexError when no source has an st.
return ShapeTracker.from_shape(src_sts[0].reduce(self.arg[1])) if self.op is UOps.REDUCE_AXIS else src_sts[0]
@functools.cached_property
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]: