refactor workitems (#2973)

This commit is contained in:
qazal
2024-01-02 19:16:52 +02:00
committed by GitHub
parent 01cdd6596f
commit deb3722aac

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
import math, functools
from collections import defaultdict, Counter
from tinygrad.codegen.linearizer import UOps, UOp
@@ -17,9 +17,7 @@ class CStyleLanguage(NamedTuple):
smem_prefix_for_cast: bool = True
arg_int_prefix: str = "const int"
barrier: str = ""
xid: List[str] = []
gid: List[str] = []
lid: List[str] = []
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
global_max: List[int] = []
local_max: List[int] = []
extra_args: List[str] = []
@@ -167,8 +165,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
elif uop == UOps.DEFINE_ACC:
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
elif uop == UOps.SPECIAL:
xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid)
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
kk(f"{lang.size_prefix} {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
if args[1].startswith("l"): local_size.append(args[2])
r[u] = args[1]
elif uop == UOps.CONST:
@@ -212,9 +209,7 @@ class OpenCLLanguage(CStyleLanguage):
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
float4 = "(float4)"
gid = [f'get_group_id({i})' for i in range(3)]
lid = [f'get_local_id({i})' for i in range(3)]
xid = [f'get_global_id({i})' for i in range(3)]
code_for_workitem ={ "g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})" }
uses_vload = True
# NOTE: mad is used so the loads aren't reordered into the math on 845
code_for_op = {**CStyleLanguage().code_for_op,
@@ -232,8 +227,7 @@ class MetalLanguage(CStyleLanguage):
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
float4 = "float4"
uses_ptr_arithmetic=True
gid = [f"gid.{chr(120+i)}" for i in range(3)]
lid = [f"lid.{chr(120+i)}" for i in range(3)]
code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
@@ -253,9 +247,10 @@ class CUDALanguage(CStyleLanguage):
smem_prefix_for_cast = False
barrier = "__syncthreads();"
float4 = "make_float4"
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
code_for_workitem = {
"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
"i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"
}
code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
half_prekernel = """
#include <cuda_fp16.h>
@@ -286,8 +281,7 @@ HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
# TODO: how much of this can be merged with above?
class WGSLLanguage(CStyleLanguage):
gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
code_for_workitem = {"g": lambda x: f"i32(gindex.{'xyz'[x]})", "l": lambda x: f"i32(lindex.{'xyz'[x]})"}
size_prefix = "let"
barrier="workgroupBarrier();"
generic_var_prefix = "var "