* disabled float4 ALU ops for CUDA, small fix to add half_prekernel before kernel_prefix

* added supports_float4_alu option, and disabled for ops_cuda
This commit is contained in:
crthilakraj
2023-05-29 16:59:36 +02:00
committed by GitHub
parent 6ea5df19b2
commit 7925fa58d9
4 changed files with 10 additions and 5 deletions

View File

@@ -146,7 +146,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
else:
if newvar.ltype == LocalTypes.float4:
val = f"{lang.float4}((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
val = f"({newvar.ltype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
else:
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
@@ -182,12 +182,15 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
[', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
[") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
if lang.half_prekernel:
prg =''.join([f"{lang.half_prekernel}", "\n", prg])
return prg, global_size, local_size
class CStyleCodegen(Linearizer):
lang: ClassVar[CStyleLanguage] = CStyleLanguage()
supports_constant_folding: bool = True
supports_float4: bool = True
supports_float4_alu: bool = True
# for renaming
kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)

View File

@@ -88,6 +88,7 @@ class UOp(NamedTuple):
class Linearizer:
supports_float4: bool = False
supports_float4_alu: bool = False
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer):
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
@@ -340,9 +341,9 @@ class Linearizer:
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
# TODO: fold float4 into a single uop when possible.
if isinstance(x.op, (ReduceOps, FusedOps)):
ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values)]
ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values, grouping_allowed=self.supports_float4_alu)]
else:
ret = [(idx, self.uop(UOps.ALU, ssa('alu', LocalTypes.float4) if any(x.ltype == LocalTypes.float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=x.op!=BinaryOps.CMPEQ)]
ret = [(idx, self.uop(UOps.ALU, ssa('alu', LocalTypes.float4) if any(x.ltype == LocalTypes.float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
# scatter
for i,j in ret:

View File

@@ -50,5 +50,5 @@ class CUDACodegen(CStyleCodegen):
half_prekernel = "#include <cuda_fp16.h>",
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
supports_float4_alu = False
CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram, cuda.Context.synchronize)

View File

@@ -91,5 +91,6 @@ class CLCodegen(CStyleCodegen):
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
supports_float4_alu = True
supports_float4 = True
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.synchronize)