diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 95cbfd8ac7..ac9c1eb4d2 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -186,5 +186,12 @@ class TestUOpGraph(TestUOps): self.assertEqual(len(uops.uops), 4) self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val)) + def test_asserts_bad_gate(self): + glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True)) + idx = UOp.const(dtypes.int, 0) + bad_gate = UOp.const(dtypes.int, 1) + uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))]) + with self.assertRaises(AssertionError): uops.linearize() + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 22bbdb68e1..c92c302ea6 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -141,7 +141,8 @@ class Linearizer(Kernel): invalid_value = 0 acc_count = 0 for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)): - this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid) + this_const, idx = (invalid_value, NumNode(0)) if valid.max == 0 else (const, idx) + valid_uop = UOp.const(dtypes.bool, valid.b) if valid.min == valid.max else valid.render(render_ops, self.loop_uops) key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501 if key not in self.load_cache: if acc is not None: @@ -150,14 +151,13 @@ class Linearizer(Kernel): elif this_const is not None: self.load_cache[key] = UOp.const(localtype, this_const) if valid.min == 0 and valid.max == 1: - valid_rendered = valid.render(render_ops, self.loop_uops) - self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value)) + self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_uop, self.load_cache[key], UOp.const(localtype, invalid_value)) elif isinstance(buf.dtype, ImageDType): buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx)) - valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple() + valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple() self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) if localtype == localtype.scalar(): @@ -173,7 +173,7 @@ class Linearizer(Kernel): buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" rendered_idx = idx.render(render_ops, self.loop_uops) - valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple() + valid_tuple = (valid_uop, UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple() self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key]) return ret @@ -213,7 +213,7 @@ class Linearizer(Kernel): if self.late_gate is not None: valid *= self.late_gate # TODO: let UPat check this once it's fast if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var))) - else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops)))) + elif valid.max == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops)))) return stores # render loop diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 07278f47ef..6155f24e30 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -265,13 +265,13 @@ constant_folder = PatternMatcher([ # cast NOOP (NOTE: it's str to deal with PtrDType) (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None), # fold gated LOAD/STORE - (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)), - (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")), + (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)), + (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")), lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)), - (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var), - (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var), - (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store), - (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)), + (UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var")), lambda var: var), + (UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var"), UOp.var()), lambda var: var), + (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.bool, True)), UOp.store), + (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)), # remove NOOPs from SINK (UPat(UOps.SINK, name="root"), lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)