mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
@@ -62,7 +62,7 @@ class AssemblyCodegen(Linearizer):
|
||||
def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
|
||||
key = (op, a, b)
|
||||
if key not in tor:
|
||||
if not isinstance(b, Register): b = render_numnode(b)
|
||||
#if not isinstance(b, Register): b = render_numnode(b)
|
||||
ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
|
||||
return tor[key]
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ code:
|
||||
# RDNA3 is actually a SIMD machine!
|
||||
class RDNACodegen(AssemblyCodegen):
|
||||
supports_float4: bool = True
|
||||
supports_float4_alu: bool = False
|
||||
supports_float4_alu: bool = True
|
||||
supports_load3: bool = True
|
||||
sin_is_sin2pi: bool = True
|
||||
no_div: bool = True
|
||||
@@ -127,11 +127,8 @@ class RDNACodegen(AssemblyCodegen):
|
||||
alu_arg = "fmac"
|
||||
vin = vin[0:2]
|
||||
if out.dtype == dtypes._float4:
|
||||
tins = []
|
||||
for rr in zip(*[x.subregs() if x.dtype == dtypes._float4 else [x,x,x,x] for x in [out]+vin]):
|
||||
tins.append(f"{'s_' if rr[0].scalar else 'v_'}dual_{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
|
||||
ins.append(tins[0] + " :: " + tins[1])
|
||||
ins.append(tins[2] + " :: " + tins[3])
|
||||
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
elif uop == UOps.LOAD:
|
||||
@@ -152,6 +149,29 @@ class RDNACodegen(AssemblyCodegen):
|
||||
raise NotImplementedError(uop)
|
||||
|
||||
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
|
||||
|
||||
# dual alu group
|
||||
seen = set()
|
||||
new_ins = []
|
||||
for i,tins in enumerate(ins):
|
||||
if tins in seen: continue
|
||||
if tins.startswith("v_fmac_f32"):
|
||||
for gins in reversed(ins[i+1:]):
|
||||
if gins in seen: continue
|
||||
if gins.startswith("v_fmac_f32"):
|
||||
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
|
||||
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
|
||||
if r0[0]%2 == r1[0]%2: continue
|
||||
if r0[1]%2 == r1[1]%2: continue
|
||||
if r0[2]%2 == r1[2]%2: continue
|
||||
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
|
||||
seen.add(tins)
|
||||
seen.add(gins)
|
||||
break
|
||||
if tins not in seen:
|
||||
new_ins.append(tins)
|
||||
ins = new_ins
|
||||
|
||||
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
|
||||
|
||||
def assemble(self, args, ins, v_cnt, s_cnt):
|
||||
|
||||
Reference in New Issue
Block a user