Optimize PTX gated loads index calculation (#4304)

* WIP but working

* Cleanup

* Remove float4 pred and alt

* Cleanup

* this is somehow slowin it down

* Simplify

* add define var to ignore when optimizing gates

* Update assembly.py

* Test for optimizing gated loads

* Cleanup

* Fix NEG needed before if

* Remove unused parameters

* Update assembly.py

* Fix for cachable gone

---------

Co-authored-by: oz <oz@oz-MS-7B86.NAT.gliwice.vectranet.pl>
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Szymon Ożóg
2024-05-13 19:14:01 +02:00
committed by GitHub
parent c67b70ca67
commit d97d5a7689
2 changed files with 44 additions and 11 deletions

View File

@@ -257,5 +257,25 @@ class TestAssembly(unittest.TestCase):
self.assertEqual(u10.vin[1].uop, UOps.CONST)
self.assertEqual(u10.vin[1].arg, u6.arg*dtypes.float.itemsize)
def test_gated_load(self):
from tinygrad.renderer.assembly import optimize_gated_loads
uops = UOpGraph()
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42)
u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL)
u5 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0)
u6 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=1)
u7 = uops.add(UOps.CONST, dtypes.bool, tuple(), arg=1)
u8 = uops.add(UOps.ALU, dtypes.int, (u4, u5), BinaryOps.ADD)
u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u8, u7, u6))
optimize_gated_loads(uops)
if_op = next(filter(lambda x: x.uop is UOps.IF, uops.uops), None)
self.assertNotEqual(if_op, None)
self.assertNotEqual(next(filter(lambda x: x.uop is UOps.ENDIF, uops.uops), None), None)
for uu in [u2, u3, u4, u5, u6, u8, u9]:
self.assertLess(uops.uops.index(if_op), uops.uops.index(uu))
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -34,6 +34,17 @@ def ptr_ar(root, uops):
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
root.vin = (fptr, zero) + root.vin[2:]
def optimize_gated_loads(uops: UOpGraph):
def successors(uop): return list(filter(lambda u: uop in u.vin, uops.uops))
for gl in list(filter(lambda u:u.uop is UOps.LOAD and len(u.vin)>3, uops.uops)):
uops.uops.insert(uops.uops.index(gl), gate:=UOp(UOps.IF, None, (gl.vin[2],)))
uops.uops.insert(uops.uops.index(gl)+1, end:=UOp(UOps.ENDIF, None, (gate,) + (gl, gl.vin[3])))
for u in reversed(uops.uops.copy()[:uops.uops.index(gate)]):
if (u.uop not in [UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR, UOps.DEFINE_LOCAL, UOps.PHI, UOps.STORE, UOps.ENDIF, UOps.ENDLOOP] and
all(uops.uops.index(s)>uops.uops.index(gate) and uops.uops.index(s)<=uops.uops.index(end) for s in successors(u))):
uops.uops.insert(uops.uops.index(gate), uops.uops.pop(uops.uops.index(u)))
gl.vin = gl.vin[:2]
class PTXRenderer(Renderer):
device = "CUDA"
suffix = "PTX"
@@ -90,13 +101,12 @@ class PTXRenderer(Renderer):
def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]
def render_bra(self, b1, pred=None, neg=False) -> List[str]: return [f"@{'!' if neg else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"]
def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
def render_load(self, loc, dest, dtype, ss="", offset=0) -> List[str]:
assert dtype is not dtypes.bool
if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
@@ -155,6 +165,7 @@ class PTXRenderer(Renderer):
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
uops.optimize_loops()
optimize_gated_loads(uops)
def kk(*s: str): kernel.append("\n".join(s))
@@ -192,14 +203,20 @@ class PTXRenderer(Renderer):
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.IF:
assert vin[0].dtype is not None
kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
kk(*self.render_bra(ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), neg=True))
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDLOOP:
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
kk(*self.render_bra(r_label[vin[0]], pred))
elif uop is UOps.ENDIF:
kk(f"@{_cast(r[vin[0].vin[0]], dtypes.bool, vin[0].vin[0].dtype, u=u, pred=True)} bra {r_label[vin[0]]}_true;")
kk(f"{r_label[vin[0]]}:")
if len(vin) > 1 and vin[1].dtype.count > 1:
kk(*[f"mov.b{self.types[vin[1].dtype.scalar()][1:]} {dd}, {r[vin[2]][i]};" for i, dd in enumerate(r[vin[1]])])
elif len(vin) > 1:
kk(*[f"mov.b{self.types[vin[1].dtype][1:]} {r[vin[1]]}, {r[vin[2]]};" ])
kk(f"{r_label[vin[0]]}_true:")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
if vin[2].dtype.count > 1:
@@ -235,13 +252,9 @@ class PTXRenderer(Renderer):
assert vin[1].dtype is not None
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(vin)>3):
for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
+ f" ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
kk(f"ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
else:
kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, ss=u.arg, offset=vin[1].arg))
elif uop is UOps.PHI:
kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
r[u] = r[vin[0]]