diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index 979ea691b8..0abb767ef5 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -29,8 +29,9 @@ class TestLinearizerDumb(unittest.TestCase): k.required_optimizations() for opt in opts: k.apply_opt(opt) prg = k.to_program() - prg.uops.print() + #prg.uops.print() print(prg.src) + #print("\n".join(prg.src.splitlines()[-4:])) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 9c2302dd88..f2cc040d86 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -1,4 +1,5 @@ from __future__ import annotations +from dataclasses import replace from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING import functools, itertools, heapq, math from tinygrad.dtype import dtypes, PtrDType, ImageDType @@ -473,12 +474,20 @@ class UOpGraph: if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER: if_uop = UOp(UOps.IF, None, (gate, u.src[-1])) return UOp(u.op, u.dtype, u.src[:-1]+(if_uop,), u.arg) - if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg) + if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src: + if u.op is UOps.STORE and replace_source[3] is gate: replace_source = replace_source[:3] + return UOp(u.op, u.dtype, replace_source, u.arg) + if u.op is UOps.STORE: + assert len(u.src) == 4 + assert gate is u.src[3] + if_uop = UOp(UOps.IF, None, (gate,)) + replace_source = u.src[:3]+(if_uop,) + return UOp(u.op, u.dtype, replace_source, u.arg) return u sink_srcs = list(self.sink.src) for i, s in enumerate(sink_srcs): if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_replace_gates(s, s.src[3])) != s: - sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg) + sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src, rw.arg) sink = UOp(UOps.SINK, None, tuple(sink_srcs)) # do graph rewrite diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index a670776f1c..3c2dc74ec4 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -43,7 +43,9 @@ class UOp: return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \ self.arg.value, self.dtype, self.src) def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple - def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))") + #def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))") + def __repr__(self): + return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}" # *** uop syntactic sugar def ufix(self, x): return UOp.const(self.dtype, x) if not isinstance(x, UOp) else x def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,)) @@ -177,7 +179,7 @@ def type_verify(uops): if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype if uop is UOps.STORE: assert dtype is None, f"{uop} dtype must be None, got {dtype}" - if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}" + #if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}" if uop is UOps.ALU: if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}" diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index d87c9f48d4..3e90caf6f3 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -148,10 +148,9 @@ class PTXRenderer(Renderer): assert src[1].op is UOps.CONST, f"store isn't const {u}" mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global' if src[2].dtype.count > 1: - kk((f"@{r[src[3]]} " if len(src)>3 else "") + \ - f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};") + kk(f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};") else: - kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg)) + kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=None, ss=mem_type, offset=src[1].arg)) else: assert dtype is not None, f"None dtype for uop {uop}" if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:])) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index f8ca7f501d..04a1e05a3c 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -123,7 +123,7 @@ class CStyleLanguage(Renderer): elif uop is UOps.STORE: assert src[0].dtype is not None and src[2].dtype is not None rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) - kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store) + kk(rendered_store) else: assert dtype is not None, f"None dtype for uop {uop}" if uop is UOps.RANGE: diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 15cb65203c..05f290a3c5 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -100,11 +100,7 @@ class LLVMRenderer(Renderer): uop,dtype,src,args = u.op,u.dtype,u.src,u.arg if uop is UOps.STORE: element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype) - if len(src) > 3: - with bb[-1].if_then(lvars[src[3]]): - bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)) - else: - bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)) + bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)) elif uop is UOps.ENDRANGE: loop_entry_bb, phis = loop_blocks.pop() idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))