minor dimensions (#1730)

This commit is contained in:
George Hotz
2023-09-01 06:42:00 -07:00
committed by GitHub
parent 5c403d43b9
commit 7780eb3c5a

View File

@@ -100,11 +100,6 @@ class CStyleLanguage(NamedTuple):
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
return f"*({buf_name}+{idx.render(render_cl, strip_parens=True)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
assert i < 3, "only 3 dims are supported"
local_size.append(var.max+1)
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
global_size: List[int] = []
local_size: List[int] = []
@@ -118,9 +113,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
if uop == UOps.LOOP:
for i,var in enumerate(args[0]):
if args[1] == "global" and lang.gid:
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
global_size.append(var.max+1)
kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
elif args[1] == "local" and lang.lid:
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
local_size.append(var.max+1)
kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
else:
if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))