From d2e9245de8cb5ec3376a3f2bccef786e91ad52b0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 21 Dec 2023 21:15:28 +0200 Subject: [PATCH] render_locals takes a dtype (#2873) Co-authored-by: chenyu --- tinygrad/renderer/cstyle.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index d0cced4f7b..3c32038f16 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -68,8 +68,8 @@ class CStyleLanguage(NamedTuple): return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val - def render_local(self, name:str, size:int): - return self.smem_align + self.smem_prefix + f"float {name}[{size}];" + def render_local(self, name:str, dtype:DType, size:int): + return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];" def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str: return f"for ({self.generic_var_prefix if self.generic_var_prefix else 'int'} {expr} = {_min}; {expr} < {_max}; {expr}++) {{" @@ -190,10 +190,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu if child_count[u] <= 1: r[u] = val else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};") elif uop == UOps.DEFINE_LOCAL: + assert dtype is not None if lang.external_local_bufs: - prekernel.append(lang.render_local(args[0], args[1])) + prekernel.append(lang.render_local(args[0], dtype, args[1])) else: - kk(lang.render_local(args[0], args[1])) + kk(lang.render_local(args[0], dtype, args[1])) r[u] = args[0] elif uop == UOps.DEFINE_GLOBAL: assert dtype is not None @@ -309,8 +310,8 @@ class WGSLLanguage(CStyleLanguage): # HACK: write bool as f32. 
remove after elementwise op cast inputs properly type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "f32"} - def render_local(self, name: str, size: int): - return f"var<workgroup> {name}: array<f32,{size}>;" + def render_local(self, name: str, dtype:DType, size: int): + return f"var<workgroup> {name}: array<{self.type_map[dtype]},{size}>;" def render_const(self, x:Union[float,int], var_dtype) -> str: if math.isnan(x): return "nan()"