diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 06f3624a09..fa5d22023b 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -13,6 +13,7 @@ class CStyleLanguage(NamedTuple): buffer_suffix: str = "" smem_align: str = "" smem_prefix: str = "" + smem_prefix_for_cast: bool = True arg_int_prefix: str = "" barrier: str = "" gid: List[str] = [] @@ -64,7 +65,7 @@ class CStyleLanguage(NamedTuple): if self.uses_vload and buf_dtype == dtypes.float16: return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})" if output_dtype.sz > 1: - out_val = f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" + out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" else: out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" @@ -101,7 +102,7 @@ class CStyleLanguage(NamedTuple): if self.uses_vload and buf_dtype == dtypes.float16: return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});" if var_dtype.sz > 1: - return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" + return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};" def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str: diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 71d0d29ee9..53456e4134 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -86,7 +86,7 @@ class CUDAProgram: return start.time_till(end)*1e-3 renderer = functools.partial(uops_to_cstyle, CStyleLanguage( - kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4", + kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", smem_prefix_for_cast=False, arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4", gid = [f'blockIdx.{chr(120+i)}' for i in range(3)], lid = [f'threadIdx.{chr(120+i)}' for i in range(3)], half_prekernel = """