diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 76c279a744..778368e303 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -76,7 +76,7 @@ jobs:
     - name: Lint tinygrad with pylint
       run: python -m pylint tinygrad/
     - name: Run mypy
-      run: python -m mypy
+      run: python -m mypy --strict-equality
     - name: Test Docs
       run: |
         python docs/abstractions.py
diff --git a/test/test_dtype.py b/test/test_dtype.py
index 31c6fc148a..980f8eed4e 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -122,6 +122,7 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return
   if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return
   _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
+  _assert_eq((Tensor([1], dtype=a_dtype).cast(b_dtype)+Tensor([1], dtype=a_dtype).cast(b_dtype)).cast(a_dtype), a_dtype, [2])
   _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
   _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
   _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 9fc56fbff8..ec90358f41 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -20,7 +20,6 @@ class CStyleLanguage(NamedTuple):
   local_max: List[int] = []
   extra_args: List[str] = []
   float4: Optional[str] = None
-  half_prekernel: Optional[str] = None
   uses_vload: bool = False
   uses_ptr_arithmetic: bool = False
   launch_bounds: bool = False
@@ -62,7 +61,7 @@ class CStyleLanguage(NamedTuple):
     out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
     return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val

-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int]) -> str:
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], uops:List[UOp], prefix=None) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""  # noqa: E501
     buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
                 ("const " if i > 0 else "")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
@@ -70,8 +69,7 @@ class CStyleLanguage(NamedTuple):
     prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
     [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
     [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
-    if self.half_prekernel and any(dtype in [dtypes.float16, dtypes.bfloat16] for _,dtype in bufs): prg = ''.join((self.half_prekernel, "\n", prg))
-    return prg
+    return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"

   # returns a str statement that does the store
   def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
@@ -169,14 +167,13 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
     elif uop is UOps.GEP: r[u] = f"({r[vin[0]]})[{args}]" if cast(DType, vin[0].dtype).sz > 4 else f"({r[vin[0]]}).{'xyzw'[args]}"
f"({r[vin[0]]})[{args}]" if cast(DType, vin[0].dtype).sz > 4 else f"({r[vin[0]]}).{'xyzw'[args]}" else: raise RuntimeError(f"failed to render {uop}") - return lang.render_kernel(function_name, kernel, bufs, local_size) + return lang.render_kernel(function_name, kernel, bufs, local_size, uops) class OpenCLLanguage(CStyleLanguage): kernel_prefix = "__kernel " buffer_prefix = "__global " smem_align = "__attribute__ ((aligned (16))) " smem_prefix = "__local " - half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable" barrier = "barrier(CLK_LOCAL_MEM_FENCE);" float4 = "(float4)" code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"} @@ -187,6 +184,10 @@ class OpenCLLanguage(CStyleLanguage): type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" } def render_cast(self, x, var_dtype, bitcast=False) -> str: return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype) + + def render_kernel(self, function_name, kernel, bufs, local_size, uops, prefix=None) -> str: + if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + return super().render_kernel(function_name, kernel, bufs, local_size, uops, prefix) OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage()) class MetalLanguage(CStyleLanguage): @@ -215,7 +216,7 @@ code_for_op_half = { } class CUDALanguage(CStyleLanguage): - kernel_prefix = "#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern \"C\" __global__ " + kernel_prefix = "extern \"C\" __global__ " smem_prefix = "__shared__ " smem_prefix_for_cast = False barrier = "__syncthreads();" @@ -223,11 +224,15 @@ class CUDALanguage(CStyleLanguage): code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}", "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"} code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half} - half_prekernel = "#include \n"+"#include \n"+""" - struct half4 { half x, y, z, w; }; - __device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; } - """ type_map = {dtypes.bfloat16: "nv_bfloat16"} + + def render_kernel(self, function_name, kernel, bufs, local_size, uops, prefix=None): + prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"] + if any(uop.dtype == dtypes.half for uop in uops): + prefix += ["#include ", "struct half4 { half x, y, z, w; };", + "__device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }"] + if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include ") + return super().render_kernel(function_name, kernel, bufs, local_size, uops, prefix=prefix) CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage()) code_for_op_hip = { @@ -289,4 +294,4 @@ class HIPLanguage(CStyleLanguage): launch_bounds = True uses_ptr_arithmetic = True type_map = {dtypes.bfloat16: "hip_bfloat16"} -HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage()) \ No newline at end of file +HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage()) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 9f03265f54..819444b177 100644 --- a/tinygrad/runtime/ops_gpu.py +++ 
@@ -23,8 +23,8 @@ class CLCompiler(Compiler):
   def compile(self, src:str) -> bytes:
     program = checked(cl.clCreateProgramWithSource(self.device.context, 1, to_char_p_p([prg_bytes := src.encode()]),
                       ctypes.byref(ctypes.c_size_t(len(prg_bytes))), ctypes.byref(status := ctypes.c_int32())), status)
-    status = cl.clBuildProgram(program, 1, ctypes.byref(self.device.device_id), None, cl.clBuildProgram.argtypes[4](), None)
-    if status != 0:
+    build_status: int = cl.clBuildProgram(program, 1, ctypes.byref(self.device.device_id), None, cl.clBuildProgram.argtypes[4](), None)
+    if build_status != 0:
       cl.clGetProgramBuildInfo(program, self.device.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, ctypes.byref(log_size := ctypes.c_size_t()))
       cl.clGetProgramBuildInfo(program, self.device.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None)  # noqa: E501
       raise RuntimeError(f"OpenCL Compile Error\n\n{ctypes.string_at(mstr, size=log_size.value).decode()}")
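
Note on the pattern above (not part of the patch): the half/bfloat16 boilerplate moves from a static half_prekernel string, gated on buffer dtypes, to per-language render_kernel overrides that scan every uop's dtype and pass the needed lines down as a prefix list. Scanning uops rather than bufs also covers kernels whose buffers are, say, float32 but which cast through half internally, which is exactly what the new _test_ops case exercises. The status -> build_status rename in ops_gpu.py presumably keeps the newly enabled --strict-equality happy: rebinding status (a ctypes.c_int32 from the walrus above) to clBuildProgram's int return would make the != 0 comparison look non-overlapping to mypy. A minimal standalone sketch of the prefix mechanism follows; the class names and the dtypes_used argument are hypothetical stand-ins for tinygrad's CStyleLanguage/uops machinery, not its actual API.

from typing import List, Optional

class BaseLang:
  def render_kernel(self, name:str, body:List[str], prefix:Optional[List[str]]=None) -> str:
    prg = f"void {name}() {{\n" + "\n".join(body) + "\n}"
    # prepend pragmas/#defines only when a subclass asked for them
    return prg if prefix is None else "\n".join(prefix) + f"\n{prg}"

class OpenCLLikeLang(BaseLang):
  def render_kernel(self, name, body, dtypes_used=frozenset(), prefix=None):
    # emit the fp16 pragma only if some op in this kernel actually uses half
    if "half" in dtypes_used: prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
    return super().render_kernel(name, body, prefix)

print(OpenCLLikeLang().render_kernel("E_4", ["  /* kernel body */"], dtypes_used={"half"}))

The print emits the pragma line followed by the kernel; dropping dtypes_used={"half"} yields the bare kernel, which is the point of the change: per-kernel prefixes instead of a one-size-fits-all prekernel string.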