RDNA3 assembly support (#3637)

* amazing that i can use comgr for this

* compile empty kernel

* cleanups

* tiny_add compiles

* ugh

* more work

* put that in extra
This commit is contained in:
George Hotz
2024-06-13 09:09:24 +02:00
committed by GitHub
parent fd071ba27e
commit e63701fbd4
4 changed files with 162 additions and 13 deletions

1
.gitignore vendored
View File

@@ -51,5 +51,6 @@ quickstart.py
.hypothesis
weights
*.lprof
comgr_*
*.pkl
site/

View File

@@ -7,6 +7,18 @@ from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.driver.hip_comgr import compile_hip
from tinygrad.renderer.rdna import uops_to_rdna
class RDNACompiler(Compiler):
linearizer_opts = LinearizerOptions("HIP", has_tensor_cores=True)
def __init__(self, arch:str):
self.arch = arch
super().__init__(f"compile_rdna_{self.arch}")
def render(self, name:str, uops) -> str: return uops_to_rdna(name, uops)
def compile(self, src:str) -> bytes:
ret = compile_hip(src, self.arch, True)
#with open("/tmp/out.so", "wb") as f: f.write(ret)
return ret
class HIPCompiler(Compiler):
compiler_opts = CompilerOptions("HIP", has_tensor_cores=True, shared_max=65536)
@@ -169,8 +181,8 @@ class HIPDevice(Compiled):
else:
self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode()
from tinygrad.runtime.graph.hip import HIPGraph
super().__init__(device, HIPAllocator(self), HIPCompiler(self.arch),
functools.partial(HIPProgram, self.device), HIPGraph)
super().__init__(device, HIPAllocator(self), RDNACompiler(self.arch) if getenv("RDNA") else HIPCompiler(self.arch),
functools.partial(HIPProgram, self.device), HIPGraph)
def synchronize(self):
if getenv("HIPCPU"): return
hip_set_device(self.device)

127
extra/backends/rdna.py Normal file
View File

@@ -0,0 +1,127 @@
from typing import Dict, Set
import yaml
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
from tinygrad.ops import BinaryOps
from tinygrad.dtype import dtypes
def uops_to_rdna(function_name:str, uops:UOpGraph) -> str:
replace: Dict[UOp, UOp] = {}
seen: Set[UOp] = set()
for u in uops:
if u in seen: continue
seen.add(u)
for o,n in replace.items():
if o in u.vin and u is not n:
u.vin = tuple(n if x == o else x for x in u.vin)
# pointer indexing
if u.uop in {UOps.LOAD, UOps.STORE} and u.vin[0].dtype.itemsize > 1:
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_before=uops.uops.index(u))
ptr = uops.add(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(u))
u.vin = (u.vin[0], ptr) + u.vin[2:]
#uops.print()
args = []
ins = []
v_cnt = 3 # v[0:2] is local_xyz
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
r: Dict[UOp, str] = {}
for u in uops:
if u.uop == UOps.SPECIAL:
if u.arg[1].startswith("lidx"):
r[u] = f'v{u.arg[0]}'
elif u.arg[1].startswith("gidx"):
r[u] = f's{2+u.arg[0]}'
else:
raise NotImplementedError
elif u.uop == UOps.CONST:
#r[u] = u.arg
# TODO: sometimes we can use s
#r[u] = f"s{s_cnt}"
#s_cnt += 1
#ins.append(f"s_mov_b32 {r[u]}, {u.arg}")
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"v_mov_b32 {r[u]}, {u.arg}")
elif u.uop == UOps.ALU:
if u.arg == BinaryOps.ADD:
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"v_add_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
elif u.arg == BinaryOps.MUL:
r[u] = f"v{v_cnt}"
v_cnt += 1
if dtypes.is_float(u.dtype):
ins.append(f"v_mul_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
else:
ins.append(f"v_mul_u32_u24 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
else:
raise NotImplementedError
elif u.uop == UOps.LOAD:
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"global_load_b32 {r[u]}, {r[u.vin[1]]}, {r[u.vin[0]]}")
ins.append("s_waitcnt vmcnt(0)")
elif u.uop == UOps.STORE:
ins.append(f"global_store_b32 {r[u.vin[1]]}, {r[u.vin[2]]}, {r[u.vin[0]]}")
elif u.uop == UOps.DEFINE_GLOBAL:
i = u.arg[0]
args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8,
'.type_name': u.dtype.name+"*", '.value_kind': 'global_buffer'})
s_cnt += s_cnt%2 # skip
r[u] = f"s[{s_cnt}:{s_cnt+1}]"
s_cnt += 2
ins.append(f"s_load_b64 {r[u]}, s[0:1], {i*8}")
ins.append("s_waitcnt lgkmcnt(0)")
else:
raise NotImplementedError(f"can't render {u.uop}")
# *** boilerplate rendering ***
metadata = {
'amdhsa.kernels': [{'.args': args,
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
'.name': function_name, '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
'.symbol': f'{function_name}.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
'.wavefront_size': 32}],
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
boilerplate_start = f"""
.rodata
.global {function_name}.kd
.type {function_name}.kd,STT_OBJECT
.align 0x10
.amdhsa_kernel {function_name}"""
kernel_desc = {
'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
'.amdhsa_next_free_sgpr': s_cnt,
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3,
'.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1, '.amdhsa_fp16_overflow': 0,
'.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0,
'.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0,
'.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
code_start = f""".end_amdhsa_kernel
.text
.global {function_name}
.type {function_name},@function
.p2align 8
{function_name}:
"""
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
return ".amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata" + \
boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + \
'\n'.join(ins) + f"\n.size {function_name}, .-{function_name}"

View File

@@ -14,7 +14,7 @@ def _get_comgr_data(data_set, data_type):
return bytes(dat)
# AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
def compile_hip(prg:str, arch="gfx1100") -> bytes:
def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
check(comgr.amd_comgr_action_info_set_language(action_info, comgr.AMD_COMGR_LANGUAGE_HIP))
check(comgr.amd_comgr_action_info_set_isa_name(action_info, b"amdgcn-amd-amdhsa--" + arch.encode()))
@@ -27,17 +27,26 @@ def compile_hip(prg:str, arch="gfx1100") -> bytes:
check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(data_src := comgr.amd_comgr_data_t())))
check(comgr.amd_comgr_set_data(data_src, len(rprg := prg.encode()), rprg))
check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
# -include hiprtc_runtime.h was removed
check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
if status != 0:
print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
raise RuntimeError("compile failed")
check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
if asm:
check(comgr.amd_comgr_set_data_name(data_src, b"<null>.s"))
check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_ASSEMBLE_SOURCE_TO_RELOCATABLE, action_info, data_set_src, data_set_reloc)
if status != 0:
print(_get_comgr_data(data_set_reloc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
raise RuntimeError("assemble failed")
else:
check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
# -include hiprtc_runtime.h was removed
check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
if status != 0:
print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
raise RuntimeError("compile failed")
check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
check(comgr.amd_comgr_action_info_set_options(action_info, b""))
check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
ret = _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)