assert bool dtype for valid [run_process_replay] (#5214)

* valid is always bool

* prevent NumNode to begin with

* part 2

* test: disable pattern matchers, asserts should pass

* test: store without cast

* test: if (0)

* cleanup time

* only pattern match bool literal

* better for upstream debug
This commit is contained in:
qazal
2024-06-29 21:20:32 +03:00
committed by GitHub
parent 3f4eeb8b54
commit f374fb77af
3 changed files with 19 additions and 12 deletions

View File

@@ -186,5 +186,12 @@ class TestUOpGraph(TestUOps):
self.assertEqual(len(uops.uops), 4)
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
def test_asserts_bad_gate(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
idx = UOp.const(dtypes.int, 0)
bad_gate = UOp.const(dtypes.int, 1)
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
with self.assertRaises(AssertionError): uops.linearize()
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -141,7 +141,8 @@ class Linearizer(Kernel):
invalid_value = 0
acc_count = 0
for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
this_const, idx = (invalid_value, NumNode(0)) if valid.max == 0 else (const, idx)
valid_uop = UOp.const(dtypes.bool, valid.b) if valid.min == valid.max else valid.render(render_ops, self.loop_uops)
key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
if key not in self.load_cache:
if acc is not None:
@@ -150,14 +151,13 @@ class Linearizer(Kernel):
elif this_const is not None:
self.load_cache[key] = UOp.const(localtype, this_const)
if valid.min == 0 and valid.max == 1:
valid_rendered = valid.render(render_ops, self.loop_uops)
self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_uop, self.load_cache[key], UOp.const(localtype, invalid_value))
elif isinstance(buf.dtype, ImageDType):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
if localtype == localtype.scalar():
@@ -173,7 +173,7 @@ class Linearizer(Kernel):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
rendered_idx = idx.render(render_ops, self.loop_uops)
valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
valid_tuple = (valid_uop, UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret
@@ -213,7 +213,7 @@ class Linearizer(Kernel):
if self.late_gate is not None: valid *= self.late_gate
# TODO: let UPat check this once it's fast
if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
elif valid.max == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
return stores
# render loop

View File

@@ -265,13 +265,13 @@ constant_folder = PatternMatcher([
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
# fold gated LOAD/STORE
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")),
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),
lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
(UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var),
(UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store),
(UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)),
(UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var")), lambda var: var),
(UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var"), UOp.var()), lambda var: var),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.bool, True)), UOp.store),
(UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
# remove NOOPs from SINK
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)