move gpuctypes in tree (#3253)

* move gpuctypes in tree

* fix mypy

* regex exclude

* autogen sh

* mypy exclude

* does that fix it

* fix mypy

* add hip confirm

* verify all autogens

* build clang2py

* opencl headers

* gpu on 22.04
This commit is contained in:
George Hotz
2024-01-26 12:25:03 -08:00
committed by GitHub
parent bc92c4cc32
commit a3869ffd46
17 changed files with 14597 additions and 21 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
from pathlib import Path
from typing import Tuple, Optional
import gpuctypes.cuda as cuda
import tinygrad.autogen.cuda as cuda
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, colored, cpu_time_execution, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style # noqa: E501
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
from tinygrad.codegen.kernel import LinearizerOptions
@@ -22,7 +22,7 @@ CUDACPU = getenv("CUDACPU") == 1
if CUDACPU:
gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501
cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared) # noqa: E501
cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared) # type: ignore # noqa: E501
def check(status):
if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501