did vload do anything? [run_process_replay] (#6760)

This commit is contained in:
George Hotz
2024-09-26 14:46:16 +08:00
committed by GitHub
parent ee4feedb77
commit 0c7d34ceb7

View File

@@ -12,8 +12,6 @@ def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
if isinstance(buf.dtype, ImageDType):
assert load.dtype == dtypes.float.vec(4), f"images must be float4, getting {load.dtype}"
val = f"read_imagef({r[buf]}, smp, {sidx})"
elif r.uses_vload and buf.dtype.scalar() == dtypes.float16 and load.dtype.scalar() != dtypes.float16:
val = f"vload_half{'' if load.dtype.count == 1 else str(load.dtype.count)}(0, {r[buf]}+{sidx})"
elif load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
else:
@@ -28,8 +26,6 @@ def render_store(r:CStyleLanguage, store:UOp, buf:UOp, var:UOp) -> str:
if isinstance(buf.dtype, ImageDType):
assert var.dtype == dtypes.float.vec(4), f"images must be float4, getting {var.dtype}"
val = f"write_imagef({r[buf]}, {sidx}, {r[var]});"
elif r.uses_vload and buf.dtype.scalar() == dtypes.float16 and var.dtype.scalar() != dtypes.float16:
val = f"vstore_half{'' if var.dtype.count == 1 else str(var.dtype.count)}({r[var]}, 0, {r[buf]}+{sidx});"
elif var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
@@ -104,7 +100,6 @@ class CStyleLanguage(Renderer):
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
extra_args: List[str] = []
float4: Optional[str] = None
uses_vload: bool = False
uses_ptr_arithmetic: bool = False
type_map: Dict[DType, str] = {}
infinity: str = "INFINITY"
@@ -242,7 +237,6 @@ class OpenCLRenderer(CStyleLanguage):
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
float4 = "(float4)"
code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
uses_vload = True
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
# Render a cast of the expression string `x` to `var_dtype`.
# - bitcast=True: returns "as_<T>(x)", where <T> is self.render_dtype(var_dtype) and
#   `x` is interpolated verbatim as the call argument.
#   NOTE(review): presumably this targets OpenCL C's as_<type>() reinterpretation
#   built-in -- confirm against the OpenCL C spec.
# - bitcast=False (default): defers to the parent class's render_cast(x, var_dtype).
def render_cast(self, x, var_dtype, bitcast=False) -> str:
return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)