mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Fix cuda (#836)
* disabled float4 ALU ops for CUDA, small fix to add half_prekernel before kernel_prefix * added supports_float4_alu option, and disabled for ops_cuda
This commit is contained in:
@@ -146,7 +146,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
|
||||
else:
|
||||
if newvar.ltype == LocalTypes.float4:
|
||||
val = f"{lang.float4}((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
|
||||
val = f"({newvar.ltype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
|
||||
else:
|
||||
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
|
||||
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
|
||||
@@ -182,12 +182,15 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
[', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
|
||||
[") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
|
||||
|
||||
if lang.half_prekernel:
|
||||
prg =''.join([f"{lang.half_prekernel}", "\n", prg])
|
||||
return prg, global_size, local_size
|
||||
|
||||
class CStyleCodegen(Linearizer):
|
||||
lang: ClassVar[CStyleLanguage] = CStyleLanguage()
|
||||
supports_constant_folding: bool = True
|
||||
supports_float4: bool = True
|
||||
supports_float4_alu: bool = True
|
||||
|
||||
# for renaming
|
||||
kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)
|
||||
|
||||
@@ -88,6 +88,7 @@ class UOp(NamedTuple):
|
||||
|
||||
class Linearizer:
|
||||
supports_float4: bool = False
|
||||
supports_float4_alu: bool = False
|
||||
|
||||
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer):
|
||||
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
|
||||
@@ -340,9 +341,9 @@ class Linearizer:
|
||||
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
|
||||
# TODO: fold float4 into a single uop when possible.
|
||||
if isinstance(x.op, (ReduceOps, FusedOps)):
|
||||
ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values)]
|
||||
ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values, grouping_allowed=self.supports_float4_alu)]
|
||||
else:
|
||||
ret = [(idx, self.uop(UOps.ALU, ssa('alu', LocalTypes.float4) if any(x.ltype == LocalTypes.float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=x.op!=BinaryOps.CMPEQ)]
|
||||
ret = [(idx, self.uop(UOps.ALU, ssa('alu', LocalTypes.float4) if any(x.ltype == LocalTypes.float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
|
||||
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
|
||||
# scatter
|
||||
for i,j in ret:
|
||||
|
||||
@@ -50,5 +50,5 @@ class CUDACodegen(CStyleCodegen):
|
||||
half_prekernel = "#include <cuda_fp16.h>",
|
||||
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
|
||||
|
||||
supports_float4_alu = False
|
||||
CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram, cuda.Context.synchronize)
|
||||
|
||||
@@ -91,5 +91,6 @@ class CLCodegen(CStyleCodegen):
|
||||
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
|
||||
|
||||
supports_float4_alu = True
|
||||
supports_float4 = True
|
||||
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.synchronize)
|
||||
|
||||
Reference in New Issue
Block a user