Refactor amd kernel prefix (#6205)

* refactor amd kernel_prefix

* restore removed comment

* nit
This commit is contained in:
ignaciosica
2024-08-20 14:37:36 -03:00
committed by GitHub
parent 074cf780dd
commit e4bb63c1be

View File

@@ -367,16 +367,13 @@ class AMDRenderer(CStyleLanguage):
tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
# language options
kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
extern "C" {\n""" + "".join([
f""" __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt}, {dt});
__attribute__((device)) __attribute__((pure)) {dt} __ocml_exp2_f{n}({dt});
__attribute__((device)) __attribute__((pure)) {dt} __ocml_log2_f{n}({dt});
__attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
__attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
'}\nextern "C" __attribute__((global))'
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
for dt, n in [("float", 32), ("double", 64), ("_Float16", 16)]
for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
kernel_prefix += '\nextern "C" __attribute__((global))'
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
code_for_op = _make_hip_code_for_op()