mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 18:35:12 -05:00
cuda: fix fp16, uint8, int64, half4 codegen (#968)
* cuda: add uchar, int64 typedefs
* cuda: fix float16 codegen
* fuck it, half4 stub. llama time!
* inline fp16 half4, revert changes to CStyleLanguage
* add inline just in case
* remove half4 operators
* use dict
This commit is contained in:
@@ -141,8 +141,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
|
||||
if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};")
|
||||
else:
|
||||
zero = f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f);" if newvar.ltype == LocalTypes.float4 else "0.0f"
|
||||
kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? ({val}) : {zero};")
|
||||
casts = {LocalTypes.float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), LocalTypes.half: ("(half)", "(half)(0.0f)"), LocalTypes.float: ("(float)", "0.0f")}[newvar.ltype]
|
||||
kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? {casts[0]}({val}) : {casts[1]};")
|
||||
elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
|
||||
assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4"
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
|
||||
@@ -47,8 +47,16 @@ class CUDAProgram:
|
||||
class CUDACodegen(CStyleCodegen):
|
||||
lang = CStyleLanguage(
|
||||
kernel_prefix = "__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
|
||||
half_prekernel = "#include <cuda_fp16.h>",
|
||||
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
|
||||
half_prekernel = """
|
||||
#include <cuda_fp16.h>
|
||||
struct __align__(8) half4 {
|
||||
half2 x, y;
|
||||
__device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
|
||||
};
|
||||
typedef unsigned char uchar;
|
||||
typedef long long int64;
|
||||
""")
|
||||
supports_float4_alu = False
|
||||
CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram, cuda.Context.synchronize)
|
||||
|
||||
Reference in New Issue
Block a user