bugfix for cuda warning (#2078)

This commit is contained in:
George Hotz
2023-10-15 18:35:35 -07:00
committed by GitHub
parent 0d3410d93f
commit 566660675c
2 changed files with 4 additions and 3 deletions

View File

@@ -13,6 +13,7 @@ class CStyleLanguage(NamedTuple):
buffer_suffix: str = ""
smem_align: str = ""
smem_prefix: str = ""
smem_prefix_for_cast: bool = True
arg_int_prefix: str = ""
barrier: str = ""
gid: List[str] = []
@@ -64,7 +65,7 @@ class CStyleLanguage(NamedTuple):
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
if output_dtype.sz > 1:
out_val = f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
else:
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
@@ -101,7 +102,7 @@ class CStyleLanguage(NamedTuple):
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
if var_dtype.sz > 1:
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:

View File

@@ -86,7 +86,7 @@ class CUDAProgram:
return start.time_till(end)*1e-3
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", smem_prefix_for_cast=False, arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
half_prekernel = """