mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move gpuctypes in tree (#3253)
* move gpuctypes in tree * fix mypy * regex exclude * autogen sh * mypy exclude * does that fix it * fix mypy * add hip confirm * verify all autogens * build clang2py * opencl headers * gpu on 22.04
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
import gpuctypes.cuda as cuda
|
||||
import tinygrad.autogen.cuda as cuda
|
||||
from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, colored, cpu_time_execution, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style # noqa: E501
|
||||
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
@@ -22,7 +22,7 @@ CUDACPU = getenv("CUDACPU") == 1
|
||||
if CUDACPU:
|
||||
gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
|
||||
gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501
|
||||
cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared) # noqa: E501
|
||||
cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared) # type: ignore # noqa: E501
|
||||
|
||||
def check(status):
|
||||
if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user