test: delete the extra cast in cstyle load [run_process_replay] [no_assert] (#5310)

* test: delete the extra cast in cstyle load [run_process_replay] [no_assert]

* assert buf_uop

* ImageDType

* ptx is actually a 64bit address
This commit is contained in:
qazal
2024-07-07 09:12:49 +03:00
committed by GitHub
parent cededd8eb4
commit 2a7282c1e1

View File

@@ -57,10 +57,8 @@ class CStyleLanguage(Renderer):
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
if output_dtype.count > 1:
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
else:
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
def get_kernel_modifier(self, uops:UOpGraph) -> str:
    """Return the modifier text for a kernel.

    This implementation always returns an empty string; `uops` is accepted
    but not read here.
    NOTE(review): presumably backend-specific renderers override this to
    supply a real modifier -- confirm against the subclasses.
    """
    no_modifier = ""
    return no_modifier
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str: