mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 15:28:10 -05:00
Refactor amd kernel prefix (#6205)
* refactor amd kernel_prefix
* restore removed comment
* nit
This commit is contained in:
@@ -367,16 +367,13 @@ class AMDRenderer(CStyleLanguage):
|
||||
# Two 16x16x16 TensorCore configurations: half inputs with float output, and half inputs with half output.
tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=dtypes.half, dtype_out=out_dt)
                for out_dt in (dtypes.float, dtypes.half)]
|
||||
|
||||
# language options
|
||||
# Hand-written form of the kernel prefix: three `__ockl_get_*` declarations, then an `extern "C"` block
# declaring fmax/exp2/log2/sqrt/sin for float, double and _Float16, then the `global` qualifier line.
# The table-residue lines that were interleaved with this statement are removed, so the string literals
# contain only C declarations and the expression parses.
# NOTE(review): `kernel_prefix` is assigned again further down from the `ockl`/`ocml` tables, so this
# value is rebound before anything else in this excerpt reads it.
kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
extern "C" {\n""" + "".join([
f""" __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt}, {dt});
__attribute__((device)) __attribute__((pure)) {dt} __ocml_exp2_f{n}({dt});
__attribute__((device)) __attribute__((pure)) {dt} __ocml_log2_f{n}({dt});
__attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
__attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
'}\nextern "C" __attribute__((global))'
|
||||
# Declaration table for the `__ockl_get_*` queries.
# Each entry is (symbol name, parameter type, return type, function attribute).
ockl = [("__ockl_get_" + query, "unsigned int", "size_t", "const")
        for query in ("local_id", "group_id", "local_size")]
|
||||
# Declaration table for the `__ocml_*` math functions, one entry per (float type, function) pair.
# Each entry is (symbol name, parameter type list, return type, function attribute or "").
# Only fmax takes two operands; the other functions take one. sin carries no extra attribute.
# The table-residue lines that had been interleaved between the comprehension clauses are removed
# so the expression parses.
ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
        for dt, n in [("float", 32), ("double", 64), ("_Float16", 16)]
        for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
|
||||
|
||||
# Render one `extern "C"` declaration per entry of ockl followed by ocml, newline-separated.
# The attribute list is `device`, plus the entry's own attribute when it is non-empty.
# The combined table stays in the outermost iterable position so the names resolve in the enclosing scope.
kernel_prefix = "\n".join([
  f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});'
  for meth, dti, dto, atr in [*ockl, *ocml]])
|
||||
# Append the `extern "C" __attribute__((global))` qualifier as the final line of the prefix.
kernel_prefix = kernel_prefix + '\nextern "C" __attribute__((global))'
|
||||
# Maps a work-item kind to a function that renders the C expression for axis `x`:
#   "g" -> __ockl_get_group_id(x)
#   "l" -> __ockl_get_local_id(x)
#   "i" -> group_id(x) * local_size(x) + local_id(x), parenthesised
# The table-residue lines that had been interleaved inside the dict literal are removed so it parses.
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
                     "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
|
||||
# Built once by calling _make_hip_code_for_op(), which is defined outside this excerpt.
# NOTE(review): presumably the op -> code-string rendering table for this renderer — confirm against the helper.
code_for_op = _make_hip_code_for_op()
|
||||
|
||||
Reference in New Issue
Block a user