mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
refactor workitems (#2973)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast
|
||||
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
|
||||
import math, functools
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
@@ -17,9 +17,7 @@ class CStyleLanguage(NamedTuple):
|
||||
smem_prefix_for_cast: bool = True
|
||||
arg_int_prefix: str = "const int"
|
||||
barrier: str = ""
|
||||
xid: List[str] = []
|
||||
gid: List[str] = []
|
||||
lid: List[str] = []
|
||||
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
|
||||
global_max: List[int] = []
|
||||
local_max: List[int] = []
|
||||
extra_args: List[str] = []
|
||||
@@ -167,8 +165,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
|
||||
elif uop == UOps.SPECIAL:
|
||||
xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid)
|
||||
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
|
||||
kk(f"{lang.size_prefix} {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
if args[1].startswith("l"): local_size.append(args[2])
|
||||
r[u] = args[1]
|
||||
elif uop == UOps.CONST:
|
||||
@@ -212,9 +209,7 @@ class OpenCLLanguage(CStyleLanguage):
|
||||
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
|
||||
float4 = "(float4)"
|
||||
gid = [f'get_group_id({i})' for i in range(3)]
|
||||
lid = [f'get_local_id({i})' for i in range(3)]
|
||||
xid = [f'get_global_id({i})' for i in range(3)]
|
||||
code_for_workitem ={ "g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})" }
|
||||
uses_vload = True
|
||||
# NOTE: mad is used so the loads aren't reordered into the math on 845
|
||||
code_for_op = {**CStyleLanguage().code_for_op,
|
||||
@@ -232,8 +227,7 @@ class MetalLanguage(CStyleLanguage):
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
|
||||
float4 = "float4"
|
||||
uses_ptr_arithmetic=True
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)]
|
||||
lid = [f"lid.{chr(120+i)}" for i in range(3)]
|
||||
code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
|
||||
def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
|
||||
return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
|
||||
@@ -253,9 +247,10 @@ class CUDALanguage(CStyleLanguage):
|
||||
smem_prefix_for_cast = False
|
||||
barrier = "__syncthreads();"
|
||||
float4 = "make_float4"
|
||||
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
|
||||
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
|
||||
code_for_workitem = {
|
||||
"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
|
||||
"i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"
|
||||
}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
|
||||
half_prekernel = """
|
||||
#include <cuda_fp16.h>
|
||||
@@ -286,8 +281,7 @@ HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
|
||||
|
||||
# TODO: how much of this can be merged with above?
|
||||
class WGSLLanguage(CStyleLanguage):
|
||||
gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
|
||||
lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
|
||||
code_for_workitem = {"g": lambda x: f"i32(gindex.{'xyz'[x]})", "l": lambda x: f"i32(lindex.{'xyz'[x]})"}
|
||||
size_prefix = "let"
|
||||
barrier="workgroupBarrier();"
|
||||
generic_var_prefix = "var "
|
||||
|
||||
Reference in New Issue
Block a user