mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
assert bool dtype for valid [run_process_replay] (#5214)
* valid is always bool * prevent NumNode to begin with * part 2 * test: disable pattern matchers, asserts should pass * test: store without cast * test: if (0) * cleanup time * only pattern match bool literal * better for upstream debug
This commit is contained in:
@@ -186,5 +186,12 @@ class TestUOpGraph(TestUOps):
|
||||
self.assertEqual(len(uops.uops), 4)
|
||||
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
|
||||
|
||||
def test_asserts_bad_gate(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
bad_gate = UOp.const(dtypes.int, 1)
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
||||
with self.assertRaises(AssertionError): uops.linearize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -141,7 +141,8 @@ class Linearizer(Kernel):
|
||||
invalid_value = 0
|
||||
acc_count = 0
|
||||
for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)):
|
||||
this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid)
|
||||
this_const, idx = (invalid_value, NumNode(0)) if valid.max == 0 else (const, idx)
|
||||
valid_uop = UOp.const(dtypes.bool, valid.b) if valid.min == valid.max else valid.render(render_ops, self.loop_uops)
|
||||
key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
@@ -150,14 +151,13 @@ class Linearizer(Kernel):
|
||||
elif this_const is not None:
|
||||
self.load_cache[key] = UOp.const(localtype, this_const)
|
||||
if valid.min == 0 and valid.max == 1:
|
||||
valid_rendered = valid.render(render_ops, self.loop_uops)
|
||||
self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
|
||||
self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_uop, self.load_cache[key], UOp.const(localtype, invalid_value))
|
||||
elif isinstance(buf.dtype, ImageDType):
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
|
||||
valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
|
||||
valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
|
||||
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
|
||||
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
if localtype == localtype.scalar():
|
||||
@@ -173,7 +173,7 @@ class Linearizer(Kernel):
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
rendered_idx = idx.render(render_ops, self.loop_uops)
|
||||
valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
|
||||
valid_tuple = (valid_uop, UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
|
||||
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
||||
return ret
|
||||
@@ -213,7 +213,7 @@ class Linearizer(Kernel):
|
||||
if self.late_gate is not None: valid *= self.late_gate
|
||||
# TODO: let UPat check this once it's fast
|
||||
if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
|
||||
else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
|
||||
elif valid.max == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
|
||||
return stores
|
||||
|
||||
# render loop
|
||||
|
||||
@@ -265,13 +265,13 @@ constant_folder = PatternMatcher([
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
# fold gated LOAD/STORE
|
||||
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
|
||||
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")),
|
||||
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
|
||||
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),
|
||||
lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
|
||||
(UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var),
|
||||
(UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var),
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store),
|
||||
(UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)),
|
||||
(UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var")), lambda var: var),
|
||||
(UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var"), UOp.var()), lambda var: var),
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.bool, True)), UOp.store),
|
||||
(UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
|
||||
# remove NOOPs from SINK
|
||||
(UPat(UOps.SINK, name="root"),
|
||||
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
|
||||
|
||||
Reference in New Issue
Block a user