cleaner code_for_op order [pr] (#7653)

* cleaner code_for_op order

* mantain unary-bin-tern order

* might as well reorder for cuda and amd
This commit is contained in:
ignaciosica
2024-11-12 17:13:56 -03:00
committed by GitHub
parent 962dafb467
commit 54c0abcb2b

View File

@@ -82,16 +82,13 @@ class CStyleLanguage(Renderer):
infinity: str = "INFINITY"
nan: str = "NAN"
code_for_op: Dict = {
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
UnaryOps.NEG: lambda x,dtype: f"-{x}",
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"(1/{x})", UnaryOps.NEG: lambda x,dtype: f"-{x}",
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})",
BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})",
BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})", BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
string_rewrite = base_rewrite
extra_matcher = extra_pm
@@ -311,12 +308,12 @@ class CUDARenderer(CStyleLanguage):
float4 = "make_float4"
code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
"i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
code_for_op = {**CStyleLanguage.code_for_op,
UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})",
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
code_for_op = { **CStyleLanguage.code_for_op,
UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
type_map = {dtypes.bfloat16: "nv_bfloat16"}
def render_vector_prefix(self, dt:DType) -> str:
@@ -375,10 +372,10 @@ class AMDRenderer(CStyleLanguage):
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
code_for_op = { **CStyleLanguage.code_for_op,
UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
smem_prefix = "__attribute__((shared))"
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'