mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
test: delete the extra cast in cstyle load [run_process_replay] [no_assert] (#5310)
* test: delete the extra cast in cstyle load [run_process_replay] [no_assert] * assert buf_uop * ImageDType * ptx is actually a 64bit address
This commit is contained in:
@@ -57,10 +57,8 @@ class CStyleLanguage(Renderer):
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.count > 1:
|
||||
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
|
||||
else:
|
||||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
|
||||
return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
|
||||
def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
|
||||
|
||||
Reference in New Issue
Block a user