From deb3722aacf6b63bf3803477f8ade8cfa1f94023 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 2 Jan 2024 19:16:52 +0200 Subject: [PATCH] refactor workitems (#2973) --- tinygrad/renderer/cstyle.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 286a32e601..a9372550b8 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast +from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable import math, functools from collections import defaultdict, Counter from tinygrad.codegen.linearizer import UOps, UOp @@ -17,9 +17,7 @@ class CStyleLanguage(NamedTuple): smem_prefix_for_cast: bool = True arg_int_prefix: str = "const int" barrier: str = "" - xid: List[str] = [] - gid: List[str] = [] - lid: List[str] = [] + code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {} global_max: List[int] = [] local_max: List[int] = [] extra_args: List[str] = [] @@ -167,8 +165,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st elif uop == UOps.DEFINE_ACC: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};") elif uop == UOps.SPECIAL: - xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid) - kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */") + kk(f"{lang.size_prefix} {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */") if args[1].startswith("l"): local_size.append(args[2]) r[u] = args[1] elif uop == UOps.CONST: @@ -212,9 +209,7 @@ class OpenCLLanguage(CStyleLanguage): half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable" barrier = "barrier(CLK_LOCAL_MEM_FENCE);" float4 = "(float4)" - gid = [f'get_group_id({i})' for i in range(3)] - lid = [f'get_local_id({i})' for i in range(3)] - xid = [f'get_global_id({i})' for i in range(3)] + code_for_workitem ={ "g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})" } uses_vload = True # NOTE: mad is used so the loads aren't reordered into the math on 845 code_for_op = {**CStyleLanguage().code_for_op, @@ -232,8 +227,7 @@ class MetalLanguage(CStyleLanguage): barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);" float4 = "float4" uses_ptr_arithmetic=True - gid = [f"gid.{chr(120+i)}" for i in range(3)] - lid = [f"lid.{chr(120+i)}" for i in range(3)] + code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"} extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'] def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str: return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype) @@ -253,9 +247,10 @@ class CUDALanguage(CStyleLanguage): smem_prefix_for_cast = False barrier = "__syncthreads();" float4 = "make_float4" - gid = [f'blockIdx.{chr(120+i)}' for i in range(3)] - lid = [f'threadIdx.{chr(120+i)}' for i in range(3)] - xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)] + code_for_workitem = { + "g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}", + "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})" + } code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half} half_prekernel = """ #include @@ -286,8 +281,7 @@ HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage()) # TODO: how much of this can be merged with above? class WGSLLanguage(CStyleLanguage): - gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)] - lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)] + code_for_workitem = {"g": lambda x: f"i32(gindex.{'xyz'[x]})", "l": lambda x: f"i32(lindex.{'xyz'[x]})"} size_prefix = "let" barrier="workgroupBarrier();" generic_var_prefix = "var "