mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 23:54:58 -05:00
did vload do anything? [run_process_replay] (#6760)
This commit is contained in:
@@ -12,8 +12,6 @@ def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
assert load.dtype == dtypes.float.vec(4), f"images must be float4, getting {load.dtype}"
|
||||
val = f"read_imagef({r[buf]}, smp, {sidx})"
|
||||
elif r.uses_vload and buf.dtype.scalar() == dtypes.float16 and load.dtype.scalar() != dtypes.float16:
|
||||
val = f"vload_half{'' if load.dtype.count == 1 else str(load.dtype.count)}(0, {r[buf]}+{sidx})"
|
||||
elif load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
|
||||
else:
|
||||
@@ -28,8 +26,6 @@ def render_store(r:CStyleLanguage, store:UOp, buf:UOp, var:UOp) -> str:
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
assert var.dtype == dtypes.float.vec(4), f"images must be float4, getting {var.dtype}"
|
||||
val = f"write_imagef({r[buf]}, {sidx}, {r[var]});"
|
||||
elif r.uses_vload and buf.dtype.scalar() == dtypes.float16 and var.dtype.scalar() != dtypes.float16:
|
||||
val = f"vstore_half{'' if var.dtype.count == 1 else str(var.dtype.count)}({r[var]}, 0, {r[buf]}+{sidx});"
|
||||
elif var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
|
||||
val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
|
||||
@@ -104,7 +100,6 @@ class CStyleLanguage(Renderer):
|
||||
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
|
||||
extra_args: List[str] = []
|
||||
float4: Optional[str] = None
|
||||
uses_vload: bool = False
|
||||
uses_ptr_arithmetic: bool = False
|
||||
type_map: Dict[DType, str] = {}
|
||||
infinity: str = "INFINITY"
|
||||
@@ -242,7 +237,6 @@ class OpenCLRenderer(CStyleLanguage):
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
|
||||
float4 = "(float4)"
|
||||
code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
|
||||
uses_vload = True
|
||||
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
|
||||
def render_cast(self, x, var_dtype, bitcast=False) -> str:
|
||||
return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
|
||||
|
||||
Reference in New Issue
Block a user