mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
Optimize PTX gated loads index calculation (#4304)
* WIP but working * Cleanup * Remove float4 pred and alt * Cleanup * this is somehow slowin it down * Simplify * add define var to ignore when optimizing gates * Update assembly.py * Test for optimizing gated loads * Cleanup * Fix NEG needed before if * Remove unused parameters * Update assembly.py * Fix for cachable gone --------- Co-authored-by: oz <oz@oz-MS-7B86.NAT.gliwice.vectranet.pl> Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -257,5 +257,25 @@ class TestAssembly(unittest.TestCase):
|
||||
self.assertEqual(u10.vin[1].uop, UOps.CONST)
|
||||
self.assertEqual(u10.vin[1].arg, u6.arg*dtypes.float.itemsize)
|
||||
|
||||
def test_gated_load(self):
|
||||
from tinygrad.renderer.assembly import optimize_gated_loads
|
||||
uops = UOpGraph()
|
||||
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
|
||||
u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
|
||||
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42)
|
||||
u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL)
|
||||
u5 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0)
|
||||
u6 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=1)
|
||||
u7 = uops.add(UOps.CONST, dtypes.bool, tuple(), arg=1)
|
||||
u8 = uops.add(UOps.ALU, dtypes.int, (u4, u5), BinaryOps.ADD)
|
||||
u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u8, u7, u6))
|
||||
optimize_gated_loads(uops)
|
||||
if_op = next(filter(lambda x: x.uop is UOps.IF, uops.uops), None)
|
||||
self.assertNotEqual(if_op, None)
|
||||
self.assertNotEqual(next(filter(lambda x: x.uop is UOps.ENDIF, uops.uops), None), None)
|
||||
for uu in [u2, u3, u4, u5, u6, u8, u9]:
|
||||
self.assertLess(uops.uops.index(if_op), uops.uops.index(uu))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -34,6 +34,17 @@ def ptr_ar(root, uops):
|
||||
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
|
||||
root.vin = (fptr, zero) + root.vin[2:]
|
||||
|
||||
def optimize_gated_loads(uops: UOpGraph):
|
||||
def successors(uop): return list(filter(lambda u: uop in u.vin, uops.uops))
|
||||
for gl in list(filter(lambda u:u.uop is UOps.LOAD and len(u.vin)>3, uops.uops)):
|
||||
uops.uops.insert(uops.uops.index(gl), gate:=UOp(UOps.IF, None, (gl.vin[2],)))
|
||||
uops.uops.insert(uops.uops.index(gl)+1, end:=UOp(UOps.ENDIF, None, (gate,) + (gl, gl.vin[3])))
|
||||
for u in reversed(uops.uops.copy()[:uops.uops.index(gate)]):
|
||||
if (u.uop not in [UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR, UOps.DEFINE_LOCAL, UOps.PHI, UOps.STORE, UOps.ENDIF, UOps.ENDLOOP] and
|
||||
all(uops.uops.index(s)>uops.uops.index(gate) and uops.uops.index(s)<=uops.uops.index(end) for s in successors(u))):
|
||||
uops.uops.insert(uops.uops.index(gate), uops.uops.pop(uops.uops.index(u)))
|
||||
gl.vin = gl.vin[:2]
|
||||
|
||||
class PTXRenderer(Renderer):
|
||||
device = "CUDA"
|
||||
suffix = "PTX"
|
||||
@@ -90,13 +101,12 @@ class PTXRenderer(Renderer):
|
||||
|
||||
def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
|
||||
|
||||
def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]
|
||||
def render_bra(self, b1, pred=None, neg=False) -> List[str]: return [f"@{'!' if neg else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"]
|
||||
|
||||
def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]
|
||||
|
||||
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
|
||||
def render_load(self, loc, dest, dtype, ss="", offset=0) -> List[str]:
|
||||
assert dtype is not dtypes.bool
|
||||
if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
|
||||
return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
|
||||
|
||||
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
|
||||
@@ -155,6 +165,7 @@ class PTXRenderer(Renderer):
|
||||
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
|
||||
uops.optimize_loops()
|
||||
optimize_gated_loads(uops)
|
||||
|
||||
def kk(*s: str): kernel.append("\n".join(s))
|
||||
|
||||
@@ -192,14 +203,20 @@ class PTXRenderer(Renderer):
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.IF:
|
||||
assert vin[0].dtype is not None
|
||||
kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
kk(*self.render_bra(ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), neg=True))
|
||||
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
|
||||
elif uop is UOps.ENDLOOP:
|
||||
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
|
||||
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
|
||||
kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
kk(*self.render_bra(r_label[vin[0]], pred))
|
||||
elif uop is UOps.ENDIF:
|
||||
kk(f"@{_cast(r[vin[0].vin[0]], dtypes.bool, vin[0].vin[0].dtype, u=u, pred=True)} bra {r_label[vin[0]]}_true;")
|
||||
kk(f"{r_label[vin[0]]}:")
|
||||
if len(vin) > 1 and vin[1].dtype.count > 1:
|
||||
kk(*[f"mov.b{self.types[vin[1].dtype.scalar()][1:]} {dd}, {r[vin[2]][i]};" for i, dd in enumerate(r[vin[1]])])
|
||||
elif len(vin) > 1:
|
||||
kk(*[f"mov.b{self.types[vin[1].dtype][1:]} {r[vin[1]]}, {r[vin[2]]};" ])
|
||||
kk(f"{r_label[vin[0]]}_true:")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
|
||||
if vin[2].dtype.count > 1:
|
||||
@@ -235,13 +252,9 @@ class PTXRenderer(Renderer):
|
||||
assert vin[1].dtype is not None
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
if(len(vin)>3):
|
||||
for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
|
||||
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
|
||||
+ f" ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
|
||||
kk(f"ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
|
||||
else:
|
||||
kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
|
||||
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, ss=u.arg, offset=vin[1].arg))
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
|
||||
Reference in New Issue
Block a user