faster RDNA assembly backend (#990)

* fast asm

* torch gemm
This commit is contained in:
George Hotz
2023-06-16 12:06:38 -07:00
committed by GitHub
parent ba56ee6020
commit fe71282ba1
4 changed files with 51 additions and 13 deletions

View File

@@ -62,7 +62,7 @@ class AssemblyCodegen(Linearizer):
def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
# Emit an ALU instruction computing `op` over `a` and `b`, and return the register holding the result.
# `tor` is consulted with the (op, a, b) key first, so an instruction is only appended when that
# key has not been seen; `newreg(key, ...)` is presumably what records the new register in `tor`
# (its definition is not shown here — confirm).
key = (op, a, b)
if key not in tor:
# NOTE(review): this text comes from a diff whose +/- markers were stripped; the two lines below
# look like the removed and the added form of the same statement — confirm which one is live.
if not isinstance(b, Register): b = render_numnode(b)
#if not isinstance(b, Register): b = render_numnode(b)
# The result register is flagged scalar only when `a` is scalar and `b` is either not a Register
# or is itself a scalar Register.
ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return tor[key]

View File

@@ -29,7 +29,7 @@ code:
# RDNA3 is actually a SIMD machine!
class RDNACodegen(AssemblyCodegen):
supports_float4: bool = True
supports_float4_alu: bool = False
supports_float4_alu: bool = True
supports_load3: bool = True
sin_is_sin2pi: bool = True
no_div: bool = True
@@ -127,11 +127,8 @@ class RDNACodegen(AssemblyCodegen):
alu_arg = "fmac"
vin = vin[0:2]
if out.dtype == dtypes._float4:
tins = []
for rr in zip(*[x.subregs() if x.dtype == dtypes._float4 else [x,x,x,x] for x in [out]+vin]):
tins.append(f"{'s_' if rr[0].scalar else 'v_'}dual_{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
ins.append(tins[0] + " :: " + tins[1])
ins.append(tins[2] + " :: " + tins[3])
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
elif uop == UOps.LOAD:
@@ -152,6 +149,29 @@ class RDNACodegen(AssemblyCodegen):
raise NotImplementedError(uop)
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
# dual alu group
seen = set()
new_ins = []
for i,tins in enumerate(ins):
if tins in seen: continue
if tins.startswith("v_fmac_f32"):
for gins in reversed(ins[i+1:]):
if gins in seen: continue
if gins.startswith("v_fmac_f32"):
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
if r0[0]%2 == r1[0]%2: continue
if r0[1]%2 == r1[1]%2: continue
if r0[2]%2 == r1[2]%2: continue
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
seen.add(tins)
seen.add(gins)
break
if tins not in seen:
new_ins.append(tins)
ins = new_ins
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
def assemble(self, args, ins, v_cnt, s_cnt):