From e7248b677cf780140bbb2eadd8d86d0b2806a1a2 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Wed, 13 Dec 2023 18:04:17 +0100 Subject: [PATCH] Remove wgsl custom render_for (#2729) * Generic for * remove custom render_if * Simplify for loop * 150 line-length constraint * Put custom render_if back --- tinygrad/renderer/cstyle.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 3e7939ef67..c8ef1e723a 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -72,7 +72,7 @@ class CStyleLanguage(NamedTuple): return self.smem_align + self.smem_prefix + f"float {name}[{size}];" def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str: - return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{" + return f"for ({self.generic_var_prefix if self.generic_var_prefix else 'int'} {expr} = {_min}; {expr} < {_max}; {expr}++) {{" def render_if(self, cond: str): return f"if ({cond}) {{" @@ -358,9 +358,6 @@ class WGSLLanguage(CStyleLanguage): prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3, @builtin(local_invocation_id) lindex: vec3) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501 return prg - def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str: - return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{" - def render_if(self, cond: str): return f"if (bool({cond})) {{"