diff --git a/test/test_uops.py b/test/test_uops.py index cad3f00660..15e03e26d7 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -257,5 +257,25 @@ class TestAssembly(unittest.TestCase): self.assertEqual(u10.vin[1].uop, UOps.CONST) self.assertEqual(u10.vin[1].arg, u6.arg*dtypes.float.itemsize) + def test_gated_load(self): + from tinygrad.renderer.assembly import optimize_gated_loads + uops = UOpGraph() + u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True)) + u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9)) + u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42) + u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL) + u5 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0) + u6 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=1) + u7 = uops.add(UOps.CONST, dtypes.bool, tuple(), arg=1) + u8 = uops.add(UOps.ALU, dtypes.int, (u4, u5), BinaryOps.ADD) + u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u8, u7, u6)) + optimize_gated_loads(uops) + if_op = next(filter(lambda x: x.uop is UOps.IF, uops.uops), None) + self.assertNotEqual(if_op, None) + self.assertNotEqual(next(filter(lambda x: x.uop is UOps.ENDIF, uops.uops), None), None) + for uu in [u2, u3, u4, u5, u6, u8, u9]: + self.assertLess(uops.uops.index(if_op), uops.uops.index(uu)) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 4617946718..78da0abaf2 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -34,6 +34,17 @@ def ptr_ar(root, uops): fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root)) root.vin = (fptr, zero) + root.vin[2:] +def optimize_gated_loads(uops: UOpGraph): + def successors(uop): return list(filter(lambda u: uop in u.vin, uops.uops)) + for gl in list(filter(lambda u:u.uop is UOps.LOAD and len(u.vin)>3, uops.uops)): + uops.uops.insert(uops.uops.index(gl), gate:=UOp(UOps.IF, None, (gl.vin[2],))) + uops.uops.insert(uops.uops.index(gl)+1, end:=UOp(UOps.ENDIF, None, (gate,) + (gl, gl.vin[3]))) + for u in reversed(uops.uops.copy()[:uops.uops.index(gate)]): + if (u.uop not in [UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR, UOps.DEFINE_LOCAL, UOps.PHI, UOps.STORE, UOps.ENDIF, UOps.ENDLOOP] and + all(uops.uops.index(s)>uops.uops.index(gate) and uops.uops.index(s)<=uops.uops.index(end) for s in successors(u))): + uops.uops.insert(uops.uops.index(gate), uops.uops.pop(uops.uops.index(u))) + gl.vin = gl.vin[:2] + class PTXRenderer(Renderer): device = "CUDA" suffix = "PTX" @@ -90,13 +101,12 @@ class PTXRenderer(Renderer): def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"] - def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"] + def render_bra(self, b1, pred=None, neg=False) -> List[str]: return [f"@{'!' if neg else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"] def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype] - def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: + def render_load(self, loc, dest, dtype, ss="", offset=0) -> List[str]: assert dtype is not dtypes.bool - if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"] return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"] def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: @@ -155,6 +165,7 @@ class PTXRenderer(Renderer): for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops) uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE})) uops.optimize_loops() + optimize_gated_loads(uops) def kk(*s: str): kernel.append("\n".join(s)) @@ -192,14 +203,20 @@ class PTXRenderer(Renderer): uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg if uop is UOps.IF: assert vin[0].dtype is not None - kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:") + kk(*self.render_bra(ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), neg=True)) elif uop is UOps.BARRIER and self.barrier: kk(self.barrier) elif uop is UOps.ENDLOOP: kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]), self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int])) - kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:") + kk(*self.render_bra(r_label[vin[0]], pred)) elif uop is UOps.ENDIF: + kk(f"@{_cast(r[vin[0].vin[0]], dtypes.bool, vin[0].vin[0].dtype, u=u, pred=True)} bra {r_label[vin[0]]}_true;") kk(f"{r_label[vin[0]]}:") + if len(vin) > 1 and vin[1].dtype.count > 1: + kk(*[f"mov.b{self.types[vin[1].dtype.scalar()][1:]} {dd}, {r[vin[2]][i]};" for i, dd in enumerate(r[vin[1]])]) + elif len(vin) > 1: + kk(*[f"mov.b{self.types[vin[1].dtype][1:]} {r[vin[1]]}, {r[vin[2]]};" ]) + kk(f"{r_label[vin[0]]}_true:") elif uop is UOps.STORE: assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None if vin[2].dtype.count > 1: @@ -235,13 +252,9 @@ class PTXRenderer(Renderer): assert vin[1].dtype is not None if dtype.count > 1: r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] - if(len(vin)>3): - for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};") - kk((f"@{r[vin[2]]}"if len(vin) > 3 else "") - + f" ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];") + kk(f"ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];") else: - kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None, - alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg)) + kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, ss=u.arg, offset=vin[1].arg)) elif uop is UOps.PHI: kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};") r[u] = r[vin[0]]