mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
minor dimensions (#1730)
This commit is contained in:
@@ -100,11 +100,6 @@ class CStyleLanguage(NamedTuple):
|
||||
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
return f"*({buf_name}+{idx.render(render_cl, strip_parens=True)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
|
||||
|
||||
def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
|
||||
assert i < 3, "only 3 dims are supported"
|
||||
local_size.append(var.max+1)
|
||||
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
|
||||
global_size: List[int] = []
|
||||
local_size: List[int] = []
|
||||
@@ -118,9 +113,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
||||
if uop == UOps.LOOP:
|
||||
for i,var in enumerate(args[0]):
|
||||
if args[1] == "global" and lang.gid:
|
||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
|
||||
global_size.append(var.max+1)
|
||||
kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
|
||||
elif args[1] == "local" and lang.lid:
|
||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
|
||||
local_size.append(var.max+1)
|
||||
kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
|
||||
else:
|
||||
if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
|
||||
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
|
||||
|
||||
Reference in New Issue
Block a user