diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 0fad2fa2c9..9df05684e7 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -64,7 +64,7 @@ class CStyleLanguage(NamedTuple): if self.uses_vload and buf_dtype == dtypes.float16: return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})" if output_dtype.sz > 1: - out_val = f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" + out_val = f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" else: out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 40a3431e8a..c4b3487d25 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -45,7 +45,7 @@ class WGSLLanguage(CStyleLanguage): def render_conditional(self, cond:str, x:str, y:str) -> str: return f"select(f32({y}), {x}, bool({cond}))" - + def render_cast(self, x:List[str], var_dtype:DType) -> str: if type_map[var_dtype]: return f"{type_map[var_dtype]}({x[0]})" raise NotImplementedError(f"no cast for {var_dtype}")