mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
mockcuda (#8503)
* init mockcuda * run gpu ocelot * fix * sfixes * disable broken tests * linter * these fails as well * pylint * myypy * this fails on real platforms as well * mypy please
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -483,7 +483,7 @@ jobs:
|
||||
path: ~/.cache/tinygrad/downloads/
|
||||
key: downloads-cache-${{ matrix.backend }}-${{ env.DOWNLOAD_CACHE_VERSION }}
|
||||
- name: Set env
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nPTX=1\nMOCKGPU=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nMOCKGPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' || matrix.backend == 'nv' && 'NV=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nPTX=1\nMOCKGPU=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nMOCKGPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' || matrix.backend == 'nv' && 'NV=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
|
||||
- name: Install OpenCL
|
||||
if: matrix.backend == 'gpu'
|
||||
run: |
|
||||
|
||||
170
test/mockgpu/cuda/cuda.py
Normal file
170
test/mockgpu/cuda/cuda.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
import ctypes, time
|
||||
from tinygrad.runtime.autogen import cuda as orig_cuda
|
||||
from tinygrad.helpers import mv_address
|
||||
|
||||
for attr in dir(orig_cuda):
|
||||
if not attr.startswith('__'):
|
||||
globals()[attr] = getattr(orig_cuda, attr)
|
||||
|
||||
try:
|
||||
gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
|
||||
gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501
|
||||
except Exception: pass
|
||||
|
||||
# Global state
|
||||
class CUDAState:
|
||||
def __init__(self):
|
||||
self.memory: dict[int, bytearray] = {}
|
||||
self.events: dict[int, float] = {} # Event ID -> timestamp
|
||||
self.modules: dict[int, memoryview] = {} # Module ID -> code
|
||||
self.current_context: int|None = None
|
||||
self.contexts: dict[int, dict] = {} # Context ID -> context data
|
||||
self.devices: dict[int, dict] = {} # Device ID -> device data
|
||||
self.next_ptr = 1000 # For memory allocation
|
||||
self.next_event_id = 1
|
||||
self.next_module_id = 1
|
||||
self.next_context_id = 1
|
||||
|
||||
cuda_state = CUDAState()
|
||||
|
||||
# Helper functions
|
||||
def check_context():
|
||||
if cuda_state.current_context is None:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
# CUDA API simulation
|
||||
def cuInit(flags: int) -> int:
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuDeviceGet(device, ordinal: int) -> int:
|
||||
if ordinal < 0:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
device._obj.value = ordinal
|
||||
cuda_state.devices[ordinal] = {"compute_capability": (3, 5)}
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuCtxCreate_v2(pctx, flags: int, dev: int) -> int:
|
||||
ctx_id = cuda_state.next_context_id
|
||||
cuda_state.next_context_id += 1
|
||||
cuda_state.contexts[ctx_id] = {"device": dev, "flags": flags}
|
||||
pctx._obj.value = ctx_id
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuCtxSetCurrent(context: orig_cuda.CUcontext) -> int:
|
||||
if context.value not in cuda_state.contexts:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
cuda_state.current_context = context.value
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuMemAlloc_v2(dptr, bytesize: int) -> int:
|
||||
x = memoryview(bytearray(bytesize))
|
||||
dptr._obj.value = mv_address(x)
|
||||
cuda_state.memory[dptr._obj.value] = x
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuMemFree_v2(dptr: orig_cuda.CUdeviceptr_v2) -> int:
|
||||
if dptr.value in cuda_state.memory:
|
||||
del cuda_state.memory[dptr.value]
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
|
||||
def cuMemcpyHtoDAsync_v2(dst: orig_cuda.CUdeviceptr_v2, src: ctypes.c_void_p, bytesize: int, stream: Any) -> int:
|
||||
ctypes.memmove(dst.value, src, bytesize)
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuMemcpyDtoH_v2(dst: ctypes.c_void_p, src: orig_cuda.CUdeviceptr_v2, bytesize: int) -> int:
|
||||
ctypes.memmove(dst, src.value, bytesize)
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventCreate(phEvent, flags: int) -> int:
|
||||
event_id = cuda_state.next_event_id
|
||||
cuda_state.next_event_id += 1
|
||||
cuda_state.events[event_id] = 0.0
|
||||
phEvent._obj.value = event_id
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventRecord(hEvent: orig_cuda.CUevent, hStream: Any) -> int:
|
||||
if hEvent.value not in cuda_state.events:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
cuda_state.events[hEvent.value] = time.perf_counter_ns()
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventSynchronize(hEvent: orig_cuda.CUevent) -> int:
|
||||
if hEvent.value not in cuda_state.events:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventElapsedTime(pMilliseconds, hStart: orig_cuda.CUevent, hEnd: orig_cuda.CUevent) -> int:
|
||||
if hStart.value not in cuda_state.events or hEnd.value not in cuda_state.events:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
elapsed = (cuda_state.events[hEnd.value] - cuda_state.events[hStart.value]) * 1000
|
||||
pMilliseconds._obj.value = elapsed
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventDestroy_v2(hEvent: orig_cuda.CUevent) -> int:
|
||||
if hEvent.value in cuda_state.events:
|
||||
del cuda_state.events[hEvent.value]
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuModuleLoadData(module, image: bytes) -> int:
|
||||
module_id = cuda_state.next_module_id
|
||||
cuda_state.next_module_id += 1
|
||||
cuda_state.modules[module_id] = memoryview(bytearray(image))
|
||||
module._obj.value = module_id
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuModuleGetFunction(hfunc, hmod: orig_cuda.CUmodule, name: bytes) -> int:
|
||||
if hmod.value not in cuda_state.modules:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
hfunc._obj.value = mv_address(cuda_state.modules[hmod.value])
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuModuleUnload(hmod: orig_cuda.CUmodule) -> int:
|
||||
if hmod.value in cuda_state.modules:
|
||||
del cuda_state.modules[hmod.value]
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuLaunchKernel(f: orig_cuda.CUfunction, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, sharedMemBytes: int,
|
||||
hStream: Any, kernelParams: Any, extra: Any) -> int:
|
||||
cargs = [ctypes.cast(getattr(extra, field[0]), ctypes.c_void_p) for field in extra._fields_]
|
||||
gpuocelot_lib.ptx_run(ctypes.cast(f.value, ctypes.c_char_p), len(cargs), (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0)
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuDeviceComputeCapability(major, minor, dev: int) -> int:
|
||||
if dev not in cuda_state.devices:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
major._obj.value = 3
|
||||
minor._obj.value = 5
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuDeviceCanAccessPeer(canAccessPeer, dev: int, peerDev: int) -> int:
|
||||
canAccessPeer._obj.value = 1 # Always allow peer access in simulation
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuCtxEnablePeerAccess(peerContext: orig_cuda.CUcontext, flags: int) -> int:
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuMemHostAlloc(pp, bytesize: int, flags: int) -> int:
|
||||
return cuMemAlloc_v2(pp, bytesize)
|
||||
|
||||
def cuMemFreeHost(p: ctypes.c_void_p) -> int: return cuMemFree_v2(p)
|
||||
|
||||
def cuMemcpyDtoDAsync_v2(dst: orig_cuda.CUdeviceptr_v2, src: orig_cuda.CUdeviceptr_v2, bytesize: int, stream: Any) -> int:
|
||||
ctypes.memmove(dst.value, src.value, bytesize)
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuFuncSetAttribute(hfunc: orig_cuda.CUfunction, attrib: int, value: int) -> int:
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuStreamWaitEvent(stream: Any, event: orig_cuda.CUevent, flags: int) -> int: return orig_cuda.CUDA_SUCCESS
|
||||
def cuCtxSynchronize() -> int: return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuGetErrorString(error: int, pStr) -> int:
|
||||
error_str = orig_cuda.cudaError_enum__enumvalues.get(error, b"Unknown CUDA error")
|
||||
buf = ctypes.create_string_buffer(error_str)
|
||||
# Set the pointer to point to our error string buffer
|
||||
pStr._obj.value = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char)).value
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
@@ -41,7 +41,7 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.
|
||||
#binary_operations += [(Tensor.maximum, np.maximum)]
|
||||
|
||||
# TODO: CI CUDA segfaults on sin, WEBGPU sin is not precise enough for large numbers
|
||||
if (getenv("MOCKGPU") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU": unary_operations.remove((Tensor.sin, np.sin))
|
||||
if (getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}) or Device.DEFAULT == "WEBGPU": unary_operations.remove((Tensor.sin, np.sin))
|
||||
|
||||
class ht:
|
||||
float64 = strat.floats(width=64, allow_subnormal=False)
|
||||
@@ -162,7 +162,7 @@ class TestDTypeALU(unittest.TestCase):
|
||||
def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
|
||||
|
||||
# Metal and CUDA and HIP behave differently than numpy in CI for overflows
|
||||
skip_overflow = CI and Device.DEFAULT in {"AMD", "NV"}
|
||||
skip_overflow = CI and Device.DEFAULT in {"AMD", "NV", "CUDA"}
|
||||
@given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
|
||||
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
|
||||
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
|
||||
|
||||
@@ -76,7 +76,7 @@ class TestRandomness(unittest.TestCase):
|
||||
assert nx[nx == 0].size > 0
|
||||
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"NV", "CUDA"}, "gpuocelot doesn't support certain ops needed for threefry")
|
||||
def test_threefry_against_reference(self):
|
||||
Tensor.manual_seed(1337)
|
||||
|
||||
@@ -97,6 +97,7 @@ class TestRandomness(unittest.TestCase):
|
||||
|
||||
np.testing.assert_allclose(jr, r)
|
||||
|
||||
@unittest.skipIf(getenv("PTX"), "fails with PTX")
|
||||
def test_threefry_doesnt_use_long(self):
|
||||
for ei in lower_schedule(Tensor.rand(20).schedule()):
|
||||
if isinstance(ei.prg, CompiledRunner):
|
||||
|
||||
@@ -13,7 +13,7 @@ settings.load_profile("my_profile")
|
||||
|
||||
class TestTranscendentalMath(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float64, Device.DEFAULT), f"no float64 on {Device.DEFAULT}")
|
||||
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashed")
|
||||
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}, "crashed")
|
||||
@given(ht.float64, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
|
||||
def test_float64(self, x, op):
|
||||
if op[0] == Tensor.sin:
|
||||
@@ -24,7 +24,7 @@ class TestTranscendentalMath(unittest.TestCase):
|
||||
op[1](np.array([x], dtype=_to_np_dtype(dtypes.float64))),
|
||||
atol=3e-2, rtol=1e-5) # sin can have bigger atol for very big x
|
||||
|
||||
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashed")
|
||||
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}, "crashed")
|
||||
@given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
|
||||
def test_float32(self, x, op):
|
||||
with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
|
||||
@@ -56,7 +56,7 @@ class TestFromFuzzer(unittest.TestCase):
|
||||
if not is_dtype_supported(dtype): return
|
||||
if dtype == dtypes.float64:
|
||||
# crashes in CI CUDA
|
||||
if getenv("MOCKGPU") and Device.DEFAULT == "NV": return
|
||||
if getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}: return
|
||||
def _test_value(n: float, unit: float=1.0):
|
||||
next_float = np.nextafter(1.0, 2.0, dtype=_to_np_dtype(dtype))
|
||||
ulp = next_float - 1.0
|
||||
@@ -78,7 +78,7 @@ class TestFromFuzzer(unittest.TestCase):
|
||||
if not is_dtype_supported(dtype): return
|
||||
if dtype == dtypes.float64:
|
||||
# crashes in CI CUDA
|
||||
if getenv("MOCKGPU") and Device.DEFAULT == "NV": return
|
||||
if getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}: return
|
||||
def _test_value(n: float, unit: float=1.0):
|
||||
next_float = np.nextafter(1.0, 2.0, dtype=_to_np_dtype(dtype))
|
||||
ulp = next_float - 1.0
|
||||
|
||||
@@ -7,6 +7,7 @@ from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.runtime.autogen import cuda
|
||||
from tinygrad.runtime.support.compiler_cuda import cuda_disassemble, pretty_ptx, CUDACompiler, PTXCompiler, PTX
|
||||
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported
|
||||
|
||||
def check(status):
|
||||
if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501
|
||||
@@ -53,6 +54,9 @@ class CUDAProgram:
|
||||
check(cuda.cuCtxSetCurrent(self.dev.context))
|
||||
if not hasattr(self, "vargs"):
|
||||
self.c_args, self.vargs = encode_args(args, vals)
|
||||
|
||||
# HACK: For MOCKGPU send the args struct itself.
|
||||
if MOCKGPU: self.vargs = self.c_args # type: ignore[assignment]
|
||||
else:
|
||||
for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
|
||||
for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
|
||||
@@ -114,7 +118,7 @@ class CUDADevice(Compiled):
|
||||
|
||||
from tinygrad.runtime.graph.cuda import CUDAGraph
|
||||
super().__init__(device, CUDAAllocator(self), PTXRenderer(self.arch) if PTX else CUDARenderer(self.arch),
|
||||
PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), graph=CUDAGraph)
|
||||
PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph)
|
||||
|
||||
def synchronize(self):
|
||||
check(cuda.cuCtxSetCurrent(self.context))
|
||||
|
||||
Reference in New Issue
Block a user