diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py
index ffaaa6bdad..8e011eb120 100644
--- a/tinygrad/codegen/cstyle.py
+++ b/tinygrad/codegen/cstyle.py
@@ -146,7 +146,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
         val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
       else:
         if newvar.ltype == LocalTypes.float4:
-          val = f"{lang.float4}((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
+          val = f"({newvar.ltype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
         else:
           val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
       # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
@@ -182,12 +182,15 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
     [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
     [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
+  if lang.half_prekernel:
+    prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
   return prg, global_size, local_size
 
 class CStyleCodegen(Linearizer):
   lang: ClassVar[CStyleLanguage] = CStyleLanguage()
   supports_constant_folding: bool = True
   supports_float4: bool = True
+  supports_float4_alu: bool = True
 
   # for renaming
   kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index a204b1a498..589949a504 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -88,6 +88,7 @@ class UOp(NamedTuple):
 
 class Linearizer:
   supports_float4: bool = False
+  supports_float4_alu: bool = False
 
   def __init__(self, ast:LazyOp, output_buffer:LazyBuffer):
     # NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
@@ -340,9 +341,9 @@ class Linearizer:
     values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
     # TODO: fold float4 into a single uop when possible.
     if isinstance(x.op, (ReduceOps, FusedOps)):
-      ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values)]
+      ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values, grouping_allowed=self.supports_float4_alu)]
     else:
-      ret = [(idx, self.uop(UOps.ALU, ssa('alu', LocalTypes.float4) if any(x.ltype == LocalTypes.float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=x.op!=BinaryOps.CMPEQ)]
+      ret = [(idx, self.uop(UOps.ALU, ssa('alu', LocalTypes.float4) if any(x.ltype == LocalTypes.float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
     ordered_ret: List[Optional[Token]] = [None]*len(values[0])
     # scatter
     for i,j in ret:
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index b60bf37be8..4784beaae7 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -50,5 +50,5 @@ class CUDACodegen(CStyleCodegen):
     half_prekernel = "#include <cuda_fp16.h>",
     gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
     lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
-
+  supports_float4_alu = False
 CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram, cuda.Context.synchronize)
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 93cf469a06..c47d6ee05e 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -91,5 +91,6 @@ class CLCodegen(CStyleCodegen):
     half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
     barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
     gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
-
+  supports_float4_alu = True
+  supports_float4 = True
 GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.synchronize)
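
For context, a minimal sketch of what the `grouping_allowed` gate fed by the new `supports_float4_alu` flag does. This is not tinygrad's `get_grouped_maybe_float4`; `group_maybe_float4` below is a hypothetical stand-in that uses plain strings instead of `Token`s, just to show how disabling grouping falls back to one scalar ALU op per element.

```python
from typing import Iterator, List, Tuple

def group_maybe_float4(values: List[str], grouping_allowed: bool = True) -> Iterator[Tuple[List[int], List[str]]]:
  # Yield (indices, group) pairs: groups of 4 when vector ALU is allowed,
  # otherwise one element at a time (scalar fallback).
  if grouping_allowed and len(values) % 4 == 0:
    for i in range(0, len(values), 4):
      yield list(range(i, i+4)), values[i:i+4]
  else:
    for i, v in enumerate(values):
      yield [i], [v]

# OpenCL-style backend (supports_float4_alu=True): one float4 op covers 4 elements.
print(list(group_maybe_float4(["a0", "a1", "a2", "a3"], grouping_allowed=True)))
# CUDA-style backend (supports_float4_alu=False): four scalar ops instead.
print(list(group_maybe_float4(["a0", "a1", "a2", "a3"], grouping_allowed=False)))
```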
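And a tiny sketch of the `half_prekernel` prepend added to `uops_to_cstyle`: when the language defines a `half_prekernel`, it is joined ahead of the generated source so the fp16 pragma (OpenCL) or include (CUDA) lands before the kernel body. The kernel string below is made up for illustration, not real codegen output.

```python
# Hypothetical values standing in for lang.half_prekernel and the generated source.
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
prg = "__kernel void KERNEL_NAME_PLACEHOLDER(__global float* data0) {\n}"

# Mirrors the new branch in uops_to_cstyle: prepend the prekernel when it is set.
if half_prekernel:
  prg = ''.join([half_prekernel, "\n", prg])
print(prg)
```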