mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
clean up hip renderer (#13063)
* clean up hip renderer * ocml --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -450,16 +450,9 @@ class AMDRenderer(CStyleLanguage):
|
||||
]) + base_rewrite
|
||||
def __reduce__(self): return self.__class__, (self.arch,)
|
||||
|
||||
# language options
|
||||
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
|
||||
ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
|
||||
for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
|
||||
for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", ""), ("trunc", "")]]
|
||||
|
||||
kernel_typedef = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
|
||||
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
|
||||
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
|
||||
kernel_typedef += '\nextern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, {launch_bounds})))'
|
||||
kernel_typedef = 'extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, {launch_bounds})))'
|
||||
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
|
||||
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
|
||||
code_for_op = { **CStyleLanguage.code_for_op,
|
||||
@@ -490,15 +483,25 @@ class AMDRenderer(CStyleLanguage):
|
||||
f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
|
||||
prefix, ockl = [], []
|
||||
type_map = { dtypes.bfloat16: "bf16", dtypes.float: "f32", dtypes.half: "f16", dtypes.fp8e4m3: "_fp8_fp8", dtypes.fp8e5m2: "_bf8_bf8" }
|
||||
used_dtypes = uops_to_dtypes(uops)
|
||||
if any(u.op is Ops.CONST and not math.isfinite(u.arg) for u in uops):
|
||||
prefix += ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))"]
|
||||
if any(u.op is Ops.SPECIAL for u in uops):
|
||||
prefix.append("typedef long unsigned int size_t;")
|
||||
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
|
||||
ocml_ops = {Ops.EXP2: ("exp2", "pure"), Ops.LOG2: ("log2", "pure"), Ops.SQRT: ("sqrt", "const"), Ops.SIN: ("sin", ""), Ops.TRUNC: ("trunc", "")}
|
||||
ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.itemsize * 8}", dt.name, dt.name, ocml_ops[op][1])
|
||||
for op, dt in dedup((u.op, u.dtype.scalar()) for u in uops) if op in ocml_ops and dt in (dtypes.half, dtypes.float, dtypes.double)]
|
||||
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
|
||||
if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#define half _Float16")
|
||||
if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes):
|
||||
prefix += ["typedef unsigned char hip_bf8;", "typedef unsigned char hip_fp8;"]
|
||||
prefix.append("""static inline __attribute__((device)) unsigned char f32_to_fp8(float v, int is_bf8) {
|
||||
v = (((*(unsigned*)&v)&0x7F800000)!=0x7F800000)?__builtin_amdgcn_fmed3f(v,is_bf8?57344.0f:448.0f,is_bf8?-57344.0f:-448.0f) : v;
|
||||
return (unsigned char)(is_bf8?__builtin_amdgcn_cvt_pk_bf8_f32(v,v,0,false):__builtin_amdgcn_cvt_pk_fp8_f32(v,v,0,false));\n}""")
|
||||
prefix += [f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml]
|
||||
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
|
||||
|
||||
for name, (N, M, K), dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
|
||||
|
||||
Reference in New Issue
Block a user