mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
bugfix for cuda warning (#2078)
This commit is contained in:
@@ -13,6 +13,7 @@ class CStyleLanguage(NamedTuple):
|
||||
buffer_suffix: str = ""
|
||||
smem_align: str = ""
|
||||
smem_prefix: str = ""
|
||||
smem_prefix_for_cast: bool = True
|
||||
arg_int_prefix: str = ""
|
||||
barrier: str = ""
|
||||
gid: List[str] = []
|
||||
@@ -64,7 +65,7 @@ class CStyleLanguage(NamedTuple):
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.sz > 1:
|
||||
out_val = f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
|
||||
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
|
||||
else:
|
||||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
|
||||
@@ -101,7 +102,7 @@ class CStyleLanguage(NamedTuple):
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.sz > 1:
|
||||
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
|
||||
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
|
||||
|
||||
@@ -86,7 +86,7 @@ class CUDAProgram:
|
||||
return start.time_till(end)*1e-3
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
|
||||
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", smem_prefix_for_cast=False, arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
|
||||
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
|
||||
half_prekernel = """
|
||||
|
||||
Reference in New Issue
Block a user