Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Optimize PTX gated loads index calculation (#4304)
* WIP but working
* Cleanup
* Remove float4 pred and alt
* Cleanup
* this is somehow slowin it down
* Simplify
* add define var to ignore when optimizing gates
* Update assembly.py
* Test for optimizing gated loads
* Cleanup
* Fix NEG needed before if
* Remove unused parameters
* Update assembly.py
* Fix for cachable gone

Co-authored-by: oz <oz@oz-MS-7B86.NAT.gliwice.vectranet.pl>
Co-authored-by: chenyu <chenyu@fastmail.com>
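In a nutshell: previously a gated LOAD was rendered as a predicated ld plus a predicated mov of the alternate value, so the index math feeding the address always executed. This PR instead wraps the gated load in IF/ENDIF uops and sinks the uops that only feed it into that region, letting the renderer branch over the index calculation entirely when the gate is false. A hand-written sketch of the old shape (illustrative register names, not compiler output):

    ...index math, always executed...
    @%p0 ld.global.u32 %val0, [%rd2+0];
    @!%p0 mov.b32 %val0, %alt0;

The new branch-based shape is sketched after the ENDIF rendering hunk below.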
@@ -257,5 +257,25 @@ class TestAssembly(unittest.TestCase):
     self.assertEqual(u10.vin[1].uop, UOps.CONST)
     self.assertEqual(u10.vin[1].arg, u6.arg*dtypes.float.itemsize)
+
+  def test_gated_load(self):
+    from tinygrad.renderer.assembly import optimize_gated_loads
+    uops = UOpGraph()
+    u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
+    u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
+    u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42)
+    u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL)
+    u5 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0)
+    u6 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=1)
+    u7 = uops.add(UOps.CONST, dtypes.bool, tuple(), arg=1)
+    u8 = uops.add(UOps.ALU, dtypes.int, (u4, u5), BinaryOps.ADD)
+    u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u8, u7, u6))
+    optimize_gated_loads(uops)
+    if_op = next(filter(lambda x: x.uop is UOps.IF, uops.uops), None)
+    self.assertNotEqual(if_op, None)
+    self.assertNotEqual(next(filter(lambda x: x.uop is UOps.ENDIF, uops.uops), None), None)
+    for uu in [u2, u3, u4, u5, u6, u8, u9]:
+      self.assertLess(uops.uops.index(if_op), uops.uops.index(uu))
+
 
 if __name__ == '__main__':
   unittest.main(verbosity=2)
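A sketch of the uop order the assertions above expect: only u1 (the buffer) and u7 (the gate predicate, which the IF itself consumes) may remain ahead of the IF.

    # before: u1 u2 u3 u4 u5 u6 u7 u8 u9(LOAD)
    # after:  u1 u7 IF u2 u3 u4 u5 u6 u8 u9(LOAD) ENDIF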
@@ -34,6 +34,17 @@ def ptr_ar(root, uops):
   fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
   root.vin = (fptr, zero) + root.vin[2:]
 
+def optimize_gated_loads(uops: UOpGraph):
+  def successors(uop): return list(filter(lambda u: uop in u.vin, uops.uops))
+  for gl in list(filter(lambda u:u.uop is UOps.LOAD and len(u.vin)>3, uops.uops)):
+    uops.uops.insert(uops.uops.index(gl), gate:=UOp(UOps.IF, None, (gl.vin[2],)))
+    uops.uops.insert(uops.uops.index(gl)+1, end:=UOp(UOps.ENDIF, None, (gate,) + (gl, gl.vin[3])))
+    for u in reversed(uops.uops.copy()[:uops.uops.index(gate)]):
+      if (u.uop not in [UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR, UOps.DEFINE_LOCAL, UOps.PHI, UOps.STORE, UOps.ENDIF, UOps.ENDLOOP] and
+          all(uops.uops.index(s)>uops.uops.index(gate) and uops.uops.index(s)<=uops.uops.index(end) for s in successors(u))):
+        uops.uops.insert(uops.uops.index(gate)+1, uops.uops.pop(uops.uops.index(u)))
+    gl.vin = gl.vin[:2]
+
 class PTXRenderer(Renderer):
   device = "CUDA"
   suffix = "PTX"
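The sinking rule reads easier on a toy graph. A self-contained sketch (hypothetical N class, not the tinygrad API; the region test is simplified to "after the IF") of the same backwards scan: a node may move inside the gate only when every node that consumes it already sits inside the IF..ENDIF region, which is why the gate predicate, consumed by the IF itself, stays outside.

class N:
  def __init__(self, name, vin=()): self.name, self.vin = name, vin

idx  = N("idx")                     # index math that only feeds the load
gate = N("gate")                    # predicate, consumed by the IF itself
load = N("load", (idx,))
IF, END = N("if", (gate,)), N("endif", (load,))
order = [idx, gate, IF, load, END]

def consumers(n): return [m for m in order if n in m.vin]
for n in reversed(order[:order.index(IF)]):
  if all(order.index(c) > order.index(IF) for c in consumers(n)):
    order.insert(order.index(IF)+1, order.pop(order.index(n)))

print([n.name for n in order])      # ['gate', 'if', 'idx', 'load', 'endif']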
@@ -90,13 +101,12 @@ class PTXRenderer(Renderer):
 
   def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
 
-  def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]
+  def render_bra(self, b1, pred=None, neg=False) -> List[str]: return [f"@{'!' if neg else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"]
 
   def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]
 
-  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
+  def render_load(self, loc, dest, dtype, ss="", offset=0) -> List[str]:
     assert dtype is not dtypes.bool
-    if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
     return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
 
   def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
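For illustration, the new render_bra as a standalone copy of the one-liner above, with its three shapes:

from typing import List, Optional

def render_bra(b1: str, pred: Optional[str] = None, neg: bool = False) -> List[str]:
  return [f"@{'!' if neg else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"]

assert render_bra("$if_0", "%p0", neg=True) == ["@!%p0 bra $if_0;"]  # branch when the predicate is false
assert render_bra("$loop_0", "%p0") == ["@%p0 bra $loop_0;"]         # branch when it is true
assert render_bra("$exit") == ["bra $exit;"]                         # unconditional

The old taken/not-taken pair (b2) is gone because callers now fall through on the untaken side, and render_load drops its gate/alt parameters since gating moved up to the IF/ENDIF uops.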
@@ -155,6 +165,7 @@ class PTXRenderer(Renderer):
     for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
     uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
     uops.optimize_loops()
+    optimize_gated_loads(uops)
 
     def kk(*s: str): kernel.append("\n".join(s))
 
@@ -192,14 +203,20 @@ class PTXRenderer(Renderer):
       uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
       if uop is UOps.IF:
         assert vin[0].dtype is not None
-        kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
+        kk(*self.render_bra(ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), neg=True))
       elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
       elif uop is UOps.ENDLOOP:
         kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
            self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
-        kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
+        kk(*self.render_bra(r_label[vin[0]], pred))
       elif uop is UOps.ENDIF:
+        kk(f"@{_cast(r[vin[0].vin[0]], dtypes.bool, vin[0].vin[0].dtype, u=u, pred=True)} bra {r_label[vin[0]]}_true;")
         kk(f"{r_label[vin[0]]}:")
+        if len(vin) > 1 and vin[1].dtype.count > 1:
+          kk(*[f"mov.b{self.types[vin[1].dtype.scalar()][1:]} {dd}, {r[vin[2]][i]};" for i, dd in enumerate(r[vin[1]])])
+        elif len(vin) > 1:
+          kk(f"mov.b{self.types[vin[1].dtype][1:]} {r[vin[1]]}, {r[vin[2]]};")
+        kk(f"{r_label[vin[0]]}_true:")
       elif uop is UOps.STORE:
         assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
         if vin[2].dtype.count > 1:
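Putting the new IF and ENDIF pieces together, a gated scalar load now renders to control flow of this shape (hypothetical registers and labels; $if_1 stands in for the ssa_label):

    @!%p0 bra $if_1;                 // IF: gate false, branch over the body
    ...sunk index math...
    ld.global.u32 %val0, [%rd2+0];   // the load itself is no longer predicated
    @%p0 bra $if_1_true;             // ENDIF: gate true, skip the alt-value mov
    $if_1:
    mov.b32 %val0, %alt0;            // gate-false path lands here
    $if_1_true: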
@@ -235,13 +252,9 @@ class PTXRenderer(Renderer):
         assert vin[1].dtype is not None
         if dtype.count > 1:
           r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-          if(len(vin)>3):
-            for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
-          kk((f"@{r[vin[2]]}" if len(vin) > 3 else "")
-             + f" ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
+          kk(f"ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
         else:
-          kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
-                               alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
+          kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, ss=u.arg, offset=vin[1].arg))
       elif uop is UOps.PHI:
         kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
         r[u] = r[vin[0]]
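The vector path follows the same idea (illustrative sketch, float4 example). The removed code zero-filled each lane and predicated the vector load:

    mov.f32 %val0, 0f00000000;   // ...one mov per lane
    @%p0 ld.global.v4.f32 {%val0, %val1, %val2, %val3}, [%rd1+0];

The new code, which runs inside the IF/ENDIF region (with the alt-value movs handled at the ENDIF), always emits the bare load:

    ld.global.v4.f32 {%val0, %val1, %val2, %val3}, [%rd1+0];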