add UOps.VALID (#6387)

* uops valid

* broke full_shape

* fixup that st (hardcoded asts still red)

* fixup DEFINE_VAR

debug

more debug

* start moving stuff to ast_const

* move test_linearizer

* move test_linearizer_failures to ast_const

* fixup test_schedule

* small diff change

* regenerate dataset

* fixup test_multitensor

* regen dataset try 2

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2024-09-09 16:58:43 +08:00
committed by GitHub
parent e1d61b048b
commit 8186e4e7d6
9 changed files with 23 additions and 20 deletions

Binary file not shown.

View File

@@ -7,7 +7,7 @@ from test.external.process_replay.process_replay import _run_differ
PAGE_SIZE = 100
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
LOGOPS = os.getenv("LOGOPS", "/tmp/ops")
def extract_ast(offset:int):
logops = open(LOGOPS, "a")

View File

@@ -64,8 +64,8 @@ def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
if st_src is None:
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
return UOp(UOps.CONST, dtype, st_src, dtypes.as_const(val, dtype))
st_src = st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),
return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
T = TypeVar("T")
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:

View File

@@ -572,13 +572,13 @@ class TestMultiTensor(unittest.TestCase):
assert ast.op is UOps.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0]
assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 1
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
t = 2 * t
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.src[2].arg is BinaryOps.MUL
assert ast.src[2].src[0].op is UOps.CONST and ast.src[2].src[0].arg == 2
assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is UOps.LOAD
t = t + t.full_like(3)
for si in t.schedule():
@@ -586,7 +586,7 @@ class TestMultiTensor(unittest.TestCase):
assert ast.op is UOps.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD
assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 3
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3
def test_shard_memory(self):
devices = (d0, d1, d2, d3)

View File

@@ -79,7 +79,7 @@ class TestVerifyAST(unittest.TestCase):
uop_sts = verify_ast(a.schedule()[-1].ast)
store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0]
const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
if __name__ == '__main__':

View File

@@ -639,7 +639,7 @@ class Kernel:
# for locals, we use the ShapeTracker that's in the srcs
st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
if op.op is UOps.CONST: return replace(op, src=(st_uop,))
if op.op is UOps.VALID: return replace(op, src=(st_uop,))
if op.op is UOps.STORE: return replace(op, src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
return replace(op, src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
if op.op is UOps.REDUCE_AXIS:
@@ -777,8 +777,8 @@ class Kernel:
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
if uop in sts: return
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
# NOTE: UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL, UOps.CONST and UOps.DEFINE_VAR don't have shape
if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.CONST, UOps.DEFINE_VAR}: return
# restore globals from the two stage reduce
if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
@@ -791,8 +791,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
# movementops are pushed to the edges with SHAPETRACKER
# elementwise inherits shape
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else -1]]
for x in (src[1:] if op in BUFFER_UOPS else src):
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]]
for x in (src[0:1] if len(src) and src[0].op is UOps.VALID else src[1:] if op in BUFFER_UOPS else src):
if sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")

View File

@@ -91,9 +91,9 @@ class IndependentLowerer:
def _to_uop(self, x:UOp) -> UOp:
if x.op in BUFFER_UOPS:
idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
if x.op is UOps.VALID: return valid
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
if x.op is UOps.CONST: return valid.where(x.const_like(x.arg), x.const_like(0))
buf = x.src[0]
if x.op is UOps.LOAD:
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()

View File

@@ -66,7 +66,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
val, var_val = val.unbind()
var_vals[val] = var_val
else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val)
return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
# otherwise, it's a load and we add it to the inputs
if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
@@ -136,8 +136,9 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> Optional[ShapeTracker]:
if (st:=uop_sts.get(uop)): return st
if uop.op in BUFFER_UOPS: return uop.st_arg
src_sts = [xst for x in uop.src if (xst:=get_output_st(x, uop_sts)) is not None]
if len(src_sts) != len(uop.src) or not all_same([x.shape for x in src_sts]): return None
src = [x for x in uop.src if x.op not in {UOps.CONST, UOps.DEFINE_VAR}]
src_sts = [xst for x in src if (xst:=get_output_st(x, uop_sts)) is not None]
if len(src_sts) != len(src) or not all_same([x.shape for x in src_sts]): return None
uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
return st

View File

@@ -319,8 +319,9 @@ class UOps(Enum):
# ops that are not graph nodes
ENDRANGE = auto()
ENDIF = auto()
VALID = auto()
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
@@ -353,7 +354,7 @@ class UOp(MathTrait):
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
# *** uop syntactic sugar
@property
def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
def st_loc(self) -> int: return 0 if self.op is UOps.VALID else 1
@property
def st_arg(self) -> ShapeTracker:
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
@@ -390,11 +391,12 @@ class UOp(MathTrait):
def full_shape(self) -> Tuple[sint, ...]:
if self.op is UOps.SHAPETRACKER: return self.arg.shape
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL, \
UOps.CONST, UOps.DEFINE_VAR}]))
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.expr)
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1].arg, x.arg[2].arg) for x in self.vars()]), key=lambda v: v.expr)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg