auto-select available compilers (#12094)

* device: auto select compilers

* fix

* metal+opencl

* nv/cuda

* test without ptx

* ptx

* fix tests

* fix

* fix test

* rename

* test + cleaner

* xx

* ops

* better test

* win?

* um?

* types

* debug

* win??

* sep rung

* wtf?

* debug

* skip win

* revert this

* types
nimlgen
2025-09-10 19:52:01 +03:00
committed by GitHub
parent bb67829e99
commit fb96394ff5
25 changed files with 123 additions and 58 deletions

View File

@@ -197,14 +197,14 @@ jobs:
- name: Test tensor cores
run: |
NV=1 ALLOW_TF32=1 python3 test/opt/test_tensor_cores.py
PTX=1 ALLOW_TF32=1 NV=1 python3 test/opt/test_tensor_cores.py
NV=1 NV_PTX=1 ALLOW_TF32=1 python3 test/opt/test_tensor_cores.py
- name: Run Tensor Core GEMM (CUDA)
run: |
CUDA=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
CUDA=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
CUDA=1 SHOULD_USE_TC=1 ALLOW_TF32=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee matmul_tf32.txt
- name: Run Tensor Core GEMM (PTX)
run: NV=1 PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
run: NV=1 NV_PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
- name: Run Tensor Core GEMM (NV)
run: NV=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
- name: Test NV=1
@@ -302,7 +302,7 @@ jobs:
- name: Fuzz Padded Tensor Core GEMM (NV)
run: NV=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Fuzz Padded Tensor Core GEMM (PTX)
run: NV=1 PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
run: NV=1 NV_PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Train MNIST
run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps

View File

@@ -694,7 +694,7 @@ jobs:
cuda: 'true'
ocelot: 'true'
- name: Set env
run: printf "${{ matrix.backend == 'PTX' && 'CUDA=1\nPTX=1' || matrix.backend == 'nv' && 'NV=1\nSKIP_SLOW_TEST=1' }}" >> $GITHUB_ENV
run: printf "${{ matrix.backend == 'PTX' && 'CUDA=1\nCUDA_PTX=1' || matrix.backend == 'nv' && 'NV=1\nSKIP_SLOW_TEST=1' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CUDA','NV'], Device.DEFAULT"

View File

@@ -75,7 +75,7 @@ if __name__ == "__main__":
if GEMM_VARIATION == "max" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
print("Using CUDA and triton-generated kernel")
# See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py`
# See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 CUDA_PTX=1 python3 extra/gemm/triton_nv_matmul.py`
# this kernel with M=N=K=4096 does 162TFLOPS, vs torch at 144TFLOPS and BEAM=8 tinygrad at 138TFLOPS. theo max is 165TFLOPS.
# WMMA element size is (M, N, K) = (16, 8, 16)

View File

@@ -43,7 +43,7 @@ def matmul_kernel(c_ptr, a_ptr, b_ptr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N:
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(c_ptrs, c)
# CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py
# CUDA=1 CUDA_PTX=1 python3 extra/gemm/triton_nv_matmul.py
if __name__ == "__main__":
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 64
M, N, K = 4096, 4096, 4096

View File

@@ -272,4 +272,4 @@ def compare_launch_state(states, good_states):
return True, "PASS"
# IOCTL=1 PTX=1 CUDA=1 python3 test/test_ops.py TestOps.test_tiny_add
# IOCTL=1 CUDA=1 CUDA_PTX=1 python3 test/test_ops.py TestOps.test_tiny_add

View File

@@ -2,7 +2,7 @@
import unittest, os, subprocess, sys
from tinygrad import Tensor
from tinygrad.device import Device, Compiler
from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context
from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context, WIN, CI
class TestDevice(unittest.TestCase):
def test_canonicalize(self):
@@ -28,6 +28,42 @@ class TestDevice(unittest.TestCase):
self.assertEqual(Device.canonicalize(None), device)
Device.DEFAULT = device
@unittest.skipIf(WIN and CI, "skipping windows test") # TODO: subprocess causes memory violation?
def test_env_overwrite_default_compiler(self):
expect_failure = "\ntry: assert Device[Device.DEFAULT].compiler is None;\nexcept RuntimeError: pass"
if Device.DEFAULT == "CPU":
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
try: _, _ = CPULLVMCompiler(), ClangJITCompiler()
except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}")
imports = "from tinygrad import Device; from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler"
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, CPULLVMCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_LLVM": "1"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_LLVM": "0"})
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "0", "CPU_LLVM": "0"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, CPULLVMCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "0"})
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "1", "CPU_LLVM": "1"})
elif Device.DEFAULT == "AMD":
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
try: _, _ = HIPCompiler(Device[Device.DEFAULT].arch), AMDLLVMCompiler(Device[Device.DEFAULT].arch)
except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}")
imports = "from tinygrad import Device; from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler"
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, AMDLLVMCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "1"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "0"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1"})
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
else: self.skipTest("only run on CPU/AMD")
class MockCompiler(Compiler):
def __init__(self, key): super().__init__(key)
def compile(self, src) -> bytes: return src.encode()

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Any, Generic, TypeVar, Iterator
from typing import Any, Generic, TypeVar, Iterator, Sequence, cast
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM, \
Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
from tinygrad.helpers import Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
from tinygrad.helpers import unwrap_class_type
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -272,12 +273,32 @@ class Compiler:
return lib
def disassemble(self, lib:bytes): pass
CompilerPairT = tuple[functools.partial|type[Renderer], functools.partial|type[Compiler]]
class Compiled:
profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None, group_id=None):
self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
self.renderer, self.group_id = renderer or Renderer(), group_id
def __init__(self, device:str, allocator:Allocator, compilers:Sequence[CompilerPairT]|None, runtime, graph=None, group_id=None):
self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id
compilers = cast(list[CompilerPairT], compilers or [(Renderer, Compiler)])
devname = device.split(':')[0].upper()
envnames = [f"{devname}_{unwrap_class_type(c).__name__.removesuffix('Compiler').removeprefix(devname).upper()}" for r,c in compilers]
enable_comps = set((en, comp_pair) for en, comp_pair in zip(envnames, compilers) if en is not None and getenv(en, -1) == 1)
disable_comps = set((en, comp_pair) for en, comp_pair in zip(envnames, compilers) if en is not None and getenv(en, -1) == 0)
if len(enable_comps) > 1: raise RuntimeError(f"{self.device}: multiple compilers set in env {enable_comps}")
for _, comp_pair in disable_comps: compilers.remove(comp_pair)
try: self.renderer, self.compiler = next(self._get_available_compilers([list(enable_comps)[0][1]] if len(enable_comps) == 1 else compilers))
except StopIteration as exc: raise RuntimeError(f"no usable compilers for {self.device}") from exc
if DEBUG >= 1: print(f"{self.device}: using {self.compiler.__class__.__name__}")
def _get_available_compilers(self, compilers) -> Iterator[tuple[Renderer, Compiler]]:
for renderer, compiler in compilers:
with contextlib.suppress(Exception): yield renderer(), compiler()
def synchronize(self):
"""
Synchronize all pending operations on the device.
@@ -302,7 +323,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if device is None: device = Device.DEFAULT
if dtype == dtypes.bfloat16:
if device == "METAL": return not CI
if device in {"CUDA", "NV"}: return not CI and not getenv("PTX")
if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX")
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
return device in {"AMD", "PYTHON"}
if dtype in dtypes.fp8s:
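To make the new selection rule in device.py concrete: the constructor derives one env var per compiler, named DEVNAME_ plus the compiler class name with the device prefix and the Compiler suffix stripped, treats a value of 1 as "force this pair", 0 as "skip this pair", and otherwise walks the list and keeps the first renderer/compiler pair that constructs without raising. Below is a standalone sketch of that behavior with made-up stand-in classes and a hypothetical pick() helper, not tinygrad's real code:

import functools, os

class Renderer: pass
class Compiler: pass
class ClangRenderer(Renderer): pass
class LLVMRenderer(Renderer): pass
class ClangJITCompiler(Compiler): pass
class CPULLVMCompiler(Compiler):
    def __init__(self): raise RuntimeError("llvm not installed")  # pretend LLVM is unavailable

def unwrap_class_type(c): return c.func if isinstance(c, functools.partial) else c

def pick(device, compilers):
    devname = device.split(':')[0].upper()
    envnames = [f"{devname}_{unwrap_class_type(c).__name__.removesuffix('Compiler').removeprefix(devname).upper()}" for _, c in compilers]
    enabled  = [(en, p) for en, p in zip(envnames, compilers) if os.environ.get(en) == "1"]
    disabled = {p for en, p in zip(envnames, compilers) if os.environ.get(en) == "0"}
    if len(enabled) > 1: raise RuntimeError(f"{device}: multiple compilers set in env {enabled}")
    for r, c in ([enabled[0][1]] if enabled else [p for p in compilers if p not in disabled]):
        try: return r(), c()  # first pair that constructs cleanly wins
        except Exception: continue
    raise RuntimeError(f"no usable compilers for {device}")

# Env names here come out as CPU_CLANGJIT and CPU_LLVM; with neither set, the LLVM
# stand-in raises, so the fallback picks ClangRenderer/ClangJITCompiler.
r, c = pick("CPU", [(ClangRenderer, ClangJITCompiler), (LLVMRenderer, CPULLVMCompiler)])
print(type(r).__name__, type(c).__name__)  # ClangRenderer ClangJITCompiler

This is also why the test above exercises CPU_LLVM / CPU_CLANGJIT and AMD_LLVM / AMD_HIP rather than a single global switch.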

View File

@@ -89,6 +89,8 @@ def suppress_finalizing(func):
if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
return wrapper
def unwrap_class_type(cls_t:T): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t
def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')
class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__
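Of the two helpers added here, unwrap_class_type is the one the device changes above rely on: it recovers the underlying class when a compiler entry is passed as a functools.partial carrying constructor arguments, which the env-var name derivation in device.py needs. A small standalone illustration (the class below is a stand-in with the same name, and the arch string is just an example value):

import functools

class AMDLLVMCompiler:  # stand-in, not the real compiler class
    def __init__(self, arch): self.arch = arch

def unwrap_class_type(cls_t): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t

entry = functools.partial(AMDLLVMCompiler, "gfx1100")  # how AMDDevice passes its compiler entries
print(unwrap_class_type(entry).__name__)               # AMDLLVMCompiler
print(unwrap_class_type(AMDLLVMCompiler).__name__)     # AMDLLVMCompiler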

View File

@@ -6,8 +6,8 @@ from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, FileIOInterface
from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator
from tinygrad.uop.ops import sint
from tinygrad.device import Compiled, DMAFdRef, BufferSpec
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, DEBUG, AMD_LLVM, PROFILE, ProfileEvent, suppress_finalizing, lo32, hi32
from tinygrad.device import Compiled, DMAFdRef, BufferSpec, CompilerPairT
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, suppress_finalizing, lo32, hi32
from tinygrad.renderer.cstyle import AMDRenderer
from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.runtime.autogen import kfd, hsa, pci, sqtt
@@ -785,9 +785,11 @@ class AMDDevice(HCQCompiled):
max_copy_size = 0x40000000 if self.iface.ip_versions[am.SDMA0_HWIP][0] >= 5 else 0x400000
self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x200 if self.is_usb() else (16 << 20))
super().__init__(device, AMDAllocator(self), AMDLLVMRenderer(self.arch) if AMD_LLVM else AMDRenderer(self.arch),
AMDLLVMCompiler(self.arch) if AMD_LLVM else HIPCompiler(self.arch), functools.partial(AMDProgram, self),
AMDSignal, functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self),
compilers:list[CompilerPairT] = [(functools.partial(AMDLLVMRenderer, self.arch), functools.partial(AMDLLVMCompiler, self.arch)),
(functools.partial(AMDRenderer, self.arch), functools.partial(HIPCompiler, self.arch))]
super().__init__(device, AMDAllocator(self), compilers, functools.partial(AMDProgram, self), AMDSignal,
functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self),
functools.partial(AMDCopyQueue, self, max_copy_size=max_copy_size),
kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import platform, sys, ctypes, functools, time, mmap, threading, queue
from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, CPU_LLVM, suppress_finalizing
from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, suppress_finalizing
from tinygrad.device import BufferSpec, DMACPURef
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
from tinygrad.renderer.cstyle import ClangRenderer
@@ -116,5 +116,5 @@ class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):
self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start()
super().__init__(device, CPUAllocator(self), LLVMRenderer() if CPU_LLVM else ClangRenderer(),
CPULLVMCompiler() if CPU_LLVM else ClangJITCompiler(), functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
compilers = [(ClangRenderer, ClangJITCompiler), (LLVMRenderer, CPULLVMCompiler)]
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
import ctypes, ctypes.util, functools
from tinygrad.helpers import DEBUG, getenv, mv_address, init_c_var, init_c_struct_t, suppress_finalizing
from tinygrad.device import Compiled, BufferSpec, LRUAllocator
from tinygrad.device import Compiled, BufferSpec, LRUAllocator, CompilerPairT
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.runtime.autogen import cuda
from tinygrad.runtime.support.compiler_cuda import pretty_ptx, CUDACompiler, PTXCompiler, PTX
from tinygrad.runtime.support.compiler_cuda import pretty_ptx, CUDACompiler, PTXCompiler
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported
@@ -115,8 +115,9 @@ class CUDADevice(Compiled):
CUDADevice.devices.append(self)
from tinygrad.runtime.graph.cuda import CUDAGraph
super().__init__(device, CUDAAllocator(self), PTXRenderer(self.arch) if PTX else CUDARenderer(self.arch),
PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph)
compilers:list[CompilerPairT] = [(functools.partial(CUDARenderer, self.arch), functools.partial(CUDACompiler, self.arch)),
(functools.partial(PTXRenderer, self.arch), functools.partial(PTXCompiler, self.arch))]
super().__init__(device, CUDAAllocator(self), compilers, functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph)
def synchronize(self):
check(cuda.cuCtxSetCurrent(self.context))
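With the global PTX flag removed (see compiler_cuda.py below), forcing the PTX path now goes through the derived per-device variables: CUDA_PTX=1 for the CUDA backend here and NV_PTX=1 for NV, matching the CI changes above. A quick check in the spirit of test_device.py, assuming a working CUDA setup (sketch, not part of the repo):

import os, subprocess
code = ("from tinygrad import Device; from tinygrad.renderer.ptx import PTXRenderer; "
        "assert isinstance(Device['CUDA'].renderer, PTXRenderer), type(Device['CUDA'].renderer)")
# CUDA_PTX=1 puts the (PTXRenderer, PTXCompiler) pair in the enabled set, so it must be used.
subprocess.run(["python3", "-c", code], check=True, env={**os.environ, "CUDA_PTX": "1"})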

View File

@@ -15,7 +15,7 @@ class DiskDevice(Compiled):
self.size: int|None = None
self.fd: int|None = None
self.count = 0
super().__init__(device, DiskAllocator(self), None, None, None)
super().__init__(device, DiskAllocator(self), None, None)
def _might_open(self, size:int):
assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
if self.size is not None and hasattr(self.device, "mem"):

View File

@@ -134,8 +134,8 @@ class DSPDevice(Compiled):
def __init__(self, device:str=""):
compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b"]
if getenv("MOCKDSP"):
super().__init__(device, CPUAllocator(self), MockDSPRenderer(),
ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram)
mock_compilers = [(MockDSPRenderer, functools.partial(ClangCompiler, None, ["-static"] + compiler_args, 'llvm-objdump'))]
super().__init__(device, CPUAllocator(self), mock_compilers, MockDSPProgram)
else:
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
@@ -146,8 +146,9 @@ class DSPDevice(Compiled):
self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
self.link_ld.flush()
super().__init__(device, DSPAllocator(self), DSPRenderer(),
ClangCompiler("compile_dsp", ["-shared"] + compiler_args + [f"-T{self.link_ld.name}"], 'llvm-objdump'), functools.partial(DSPProgram, self))
compilers = [(DSPRenderer, functools.partial(ClangCompiler, "compile_dsp", ["-shared"] + compiler_args + [f"-T{self.link_ld.name}"],
'llvm-objdump'))]
super().__init__(device, DSPAllocator(self), compilers, functools.partial(DSPProgram, self))
fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import cast
import ctypes, functools, hashlib
from tinygrad.runtime.autogen import opencl as cl
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, getenv, mv_address, suppress_finalizing
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing
from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError
@@ -108,9 +108,9 @@ class CLDevice(Compiled):
self.pending_copyin: list[memoryview] = []
self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096, ctypes.byref(buf := ctypes.create_string_buffer(4096)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501
compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
renderer = IntelRenderer() if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts and getenv("INTEL") else OpenCLRenderer()
super().__init__(device, CLAllocator(self), renderer, CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
compilers = [(IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer,
functools.partial(CLCompiler, self, f"compile_cl_{hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()}"))]
super().__init__(device, CLAllocator(self), compilers, functools.partial(CLProgram, self))
def synchronize(self):
check(cl.clFinish(self.queue))
self.pending_copyin.clear()

View File

@@ -14,7 +14,9 @@ class HIPDevice(Compiled):
self.device_id = int(device.split(":")[1]) if ":" in device else 0
self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device_id))).gcnArchName.decode()
self.time_event_st, self.time_event_en = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
super().__init__(device, HIPAllocator(self), HIPRenderer(self.arch), HIPCompiler(self.arch), functools.partial(HIPProgram, self))
compilers = [(functools.partial(HIPRenderer, self.arch), functools.partial(HIPCompiler, self.arch))]
super().__init__(device, HIPAllocator(self), compilers, functools.partial(HIPProgram, self))
def synchronize(self):
check(hip.hipSetDevice(self.device_id))
check(hip.hipDeviceSynchronize())

View File

@@ -76,7 +76,7 @@ class MetalDevice(Compiled):
from tinygrad.runtime.graph.metal import MetalGraph
# NOTE: GitHub CI macOS runners use paravirtualized metal which is broken with graph.
# This can be reproduced locally with any virtualization software (like utm) that can create macOS VMs with apple's own virtualization framework.
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
super().__init__(device, MetalAllocator(self), [(MetalRenderer, MetalCompiler), (MetalRenderer, Compiler)],
functools.partial(MetalProgram, self), MetalGraph if 'virtual' not in from_ns_str(msg('name')(self.sysdevice)).lower() else None)
def synchronize(self):

View File

@@ -8,4 +8,4 @@ class NpyAllocator(Allocator['NpyDevice']):
def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = self._as_buffer(src)
class NpyDevice(Compiled):
def __init__(self, device:str): super().__init__(device, NpyAllocator(self), None, None, None)
def __init__(self, device:str): super().__init__(device, NpyAllocator(self), None, None)

View File

@@ -29,5 +29,5 @@ class NullGraph(MultiGraphRunner):
def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-3
class NullDevice(Compiled):
def __init__(self, device:str): super().__init__(device, NullAllocator(self), NullRenderer(), Compiler(), functools.partial(NullProgram, device),
def __init__(self, device:str): super().__init__(device, NullAllocator(self), [(NullRenderer, Compiler)], functools.partial(NullProgram, device),
NullGraph)

View File

@@ -6,11 +6,11 @@ from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal, BumpAllocator
from tinygrad.runtime.support.hcq import MMIOInterface, FileIOInterface, MOCKGPU
from tinygrad.uop.ops import sint
from tinygrad.device import BufferSpec
from tinygrad.device import BufferSpec, CompilerPairT
from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, prod, OSX, to_mv, hi32, lo32, suppress_finalizing
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import NVRenderer
from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, PTX, NVPTXCompiler, NVCompiler
from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler
from tinygrad.runtime.autogen import nv_gpu, pci
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager
@@ -525,9 +525,9 @@ class NVDevice(HCQCompiled[HCQSignal]):
self.arch: str = "sm_120" if self.sm_version==0xa04 else f"sm_{(self.sm_version>>8)&0xff}{(val>>4) if (val:=self.sm_version&0xff) > 0xf else val}"
self.sass_version = ((self.sm_version & 0xf00) >> 4) | (self.sm_version & 0xf)
compiler_t = (PTXCompiler if PTX else CUDACompiler) if MOCKGPU else (NVPTXCompiler if PTX else NVCompiler)
super().__init__(device, NVAllocator(self), PTXRenderer(self.arch, device="NV") if PTX else NVRenderer(self.arch), compiler_t(self.arch),
functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)
compilers:list[CompilerPairT] = [(functools.partial(NVRenderer, self.arch),functools.partial(CUDACompiler if MOCKGPU else NVCompiler, self.arch)),
(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(PTXCompiler if MOCKGPU else NVPTXCompiler, self.arch))]
super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)
self._setup_gpfifos()

View File

@@ -236,4 +236,4 @@ class PythonAllocator(Allocator['PythonDevice']):
def _copyout(self, dest:memoryview, src): dest[:] = src
class PythonDevice(Compiled):
def __init__(self, device:str): super().__init__(device, PythonAllocator(self), PythonRenderer(), PythonCompiler(), PythonProgram)
def __init__(self, device:str): super().__init__(device, PythonAllocator(self), [(PythonRenderer, PythonCompiler)], PythonProgram)

View File

@@ -341,8 +341,8 @@ class QCOMDevice(HCQCompiled):
QCOMDevice.gpu_id = ((info.chip_id >> 24) & 0xFF) * 100 + ((info.chip_id >> 16) & 0xFF) * 10 + ((info.chip_id >> 8) & 0xFF)
if QCOMDevice.gpu_id >= 700: raise RuntimeError(f"Unsupported GPU: {QCOMDevice.gpu_id}")
super().__init__(device, QCOMAllocator(self), QCOMRenderer(), QCOMCompiler(device), functools.partial(QCOMProgram, self),
QCOMSignal, QCOMComputeQueue, None)
compilers = [(QCOMRenderer, functools.partial(QCOMCompiler, device))]
super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal, QCOMComputeQueue, None)
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
flags |= kgsl.KGSL_MEMALIGN(alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP

View File

@@ -471,10 +471,11 @@ class RemoteDevice(Compiled):
if not renderer[0].startswith("tinygrad.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
renderer_instance = renderer_class(*renderer[2])
renderer_instance.device = device
graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None
super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph, id(self.conn))
compilers = [(functools.partial(renderer_class, *renderer[2]), Compiler)]
super().__init__(device, RemoteAllocator(self), compilers, functools.partial(RemoteProgram, self), graph, id(self.conn))
self.renderer.device = device
def finalize(self):
with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)

View File

@@ -217,7 +217,7 @@ class WebGpuDevice(Compiled):
device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
webgpu.WGPURequestDeviceStatus__enumvalues, 1, 2, adapter_res, dev_desc)
super().__init__(device, WebGpuAllocator(device_res), WGSLRenderer(), Compiler(),
super().__init__(device, WebGpuAllocator(device_res), [(WGSLRenderer, Compiler)],
functools.partial(WebGPUProgram, (device_res, webgpu.WGPUFeatureName_TimestampQuery in supported)))
def synchronize(self):

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv
import tinygrad.runtime.autogen.nvrtc as nvrtc
from tinygrad.device import Compiler, CompileError
PTX, CUDA_PATH = getenv("PTX"), getenv("CUDA_PATH", "") # PTX shouldn't be here, in fact, it shouldn't exist
CUDA_PATH = getenv("CUDA_PATH", "") # PTX shouldn't be here, in fact, it shouldn't exist
def _get_bytes(arg, get_str, get_sz, check) -> bytes:
sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
from typing import cast, Callable, Type, TypeVar, Generic, Any
from typing import cast, Callable, Type, TypeVar, Generic, Any, Sequence
import contextlib, decimal, statistics, time, ctypes, array, os, struct, traceback, collections
try: import fcntl # windows misses that
except ImportError: fcntl = None #type:ignore[assignment]
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent, CompilerPairT
from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc
@@ -359,12 +358,12 @@ class HCQCompiled(Compiled, Generic[SignalType]):
signal_pool: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
cpu_devices: list[HCQCompiled] = []
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:Sequence[CompilerPairT], runtime, signal_t:Type[SignalType],
comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
super().__init__(device, allocator, compilers, runtime, HCQGraph)
# TODO: peer logic is determined based on device name.
self.peer_group = device.split(":")[0]