From 3ddcb5c36ff0d7ab6de8bfde734692ee6bf02367 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 25 May 2023 20:21:15 -0700 Subject: [PATCH] Half4 load (#804) * support half4 load * cast to float4 * dead assert --- tinygrad/codegen/cstyle.py | 34 ++++++++++++++++++---------------- tinygrad/codegen/linearizer.py | 3 ++- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index a26fee0b84..92189c05d8 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -127,31 +127,33 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};") else: kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};") - elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float: - assert not isinstance(bufs[args.i].dtype, ImageDType), "image load must be float4" + elif uop == UOps.LOAD and newvar is not None: # TODO: merge with CONST? if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst): + assert newvar.ltype == LocalTypes.float, "const can't be float4" # nan? inf? val = f"{bufs[args.i].realized._buf}f" - else: - if lang.uses_vload and bufs[args.i].dtype == dtypes.float16: - val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})" - else: - val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]" - # NOTE: if min and max are both 0, it should be a CONST in the Linearizer - if args.valid.min == 1: kk(f"float {newvar.name} = {val};") - else: kk(f"float {newvar.name} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;") - elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float4: - assert newvar.offset is None, "load can't have an offset" - if isinstance(bufs[args.i].dtype, ImageDType): + elif isinstance(bufs[args.i].dtype, ImageDType): + assert newvar.ltype == LocalTypes.float4, "image must be float4" prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n") idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid) val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))" else: - val = f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}]" + if lang.uses_vload and bufs[args.i].dtype == dtypes.float16: + if newvar.ltype == LocalTypes.float4: + val = f"vload_half4({(args.idx//4).render(render_cl)}, {bufnames[args.i]})" + else: + val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})" + else: + if newvar.ltype == LocalTypes.float4: + val = f"{lang.float4}((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])" + else: + val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]" # NOTE: if min and max are both 0, it should be a CONST in the Linearizer - if args[2].min == 1: kk(f"{newvar.render(True)} = {val};") - else: kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? ({val}) : {lang.float4}(0.0f, 0.0f, 0.0f, 0.0f);") + if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};") + else: + zero = f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f);" if newvar.ltype == LocalTypes.float4 else "0.0f" + kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? ({val}) : {zero};") elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)): assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4" assert args.valid.min == 1, "store must be valid" diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 85321a10cd..46fe11be88 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -16,6 +16,7 @@ class LocalBuffer(NamedTuple): dtype: DType = dtypes.float32 realized: None = None +# NOTE: half and half4 are not actually used yet class LocalTypes(Enum): float = auto(); float4 = auto(); half = auto(); half4 = auto(); simdgroup_float8x8 = auto() # noqa: E702 class Token(NamedTuple): @@ -169,7 +170,7 @@ class Linearizer: load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(LocalTypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)} # float4 grouping (optional) - should_upcast = self.supports_float4 and self.bufs[i].dtype != dtypes.float16 and len(self.float4_axis(i)) == 1 + should_upcast = self.supports_float4 and len(self.float4_axis(i)) == 1 if should_upcast: load_offset_new = {} for k,out_tokens in self._group_float4(i, load_offset).items():