Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 14:43:57 -05:00)
auto-select available compilers (#12094)
* device: auto select compilers * fix * metal+opencl * nv/cuda * test without ptx * ptx * fix tests * fix * fix test * rename * test + cleaner * xx * ops * better test * win? * um? * types * debug * win?? * sep rung * wtf? * debug * skip win * revert this * types
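For context, a rough standalone sketch (simplified, not tinygrad's actual code) of the selection scheme this PR adds to Compiled.__init__: each backend now passes an ordered list of (renderer, compiler) pairs, a per-device env var can force or disable a pair, and otherwise the first pair whose compiler constructs successfully wins.

import contextlib, os
from typing import Callable, Sequence

class Renderer: pass     # stand-ins for tinygrad.renderer.Renderer / tinygrad.device.Compiler
class Compiler: pass

def select_compiler(device: str, pairs: Sequence[tuple[Callable, Callable]], envnames: Sequence[str]):
  # envnames[i] is the env var controlling pairs[i], e.g. "CPU_LLVM" or "AMD_HIP"
  enabled  = [p for n, p in zip(envnames, pairs) if os.environ.get(n) == "1"]
  disabled = [p for n, p in zip(envnames, pairs) if os.environ.get(n) == "0"]
  if len(enabled) > 1: raise RuntimeError(f"{device}: multiple compilers set in env")
  candidates = enabled if enabled else [p for p in pairs if p not in disabled]
  for renderer_t, compiler_t in candidates:
    # a compiler whose toolchain is missing raises on construction; just try the next pair
    with contextlib.suppress(Exception):
      return renderer_t(), compiler_t()
  raise RuntimeError(f"no usable compilers for {device}")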
.github/workflows/benchmark.yml (vendored, 6 lines changed)
@@ -197,14 +197,14 @@ jobs:
- name: Test tensor cores
run: |
NV=1 ALLOW_TF32=1 python3 test/opt/test_tensor_cores.py
-PTX=1 ALLOW_TF32=1 NV=1 python3 test/opt/test_tensor_cores.py
+NV=1 NV_PTX=1 ALLOW_TF32=1 python3 test/opt/test_tensor_cores.py
- name: Run Tensor Core GEMM (CUDA)
run: |
CUDA=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
CUDA=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
CUDA=1 SHOULD_USE_TC=1 ALLOW_TF32=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee matmul_tf32.txt
- name: Run Tensor Core GEMM (PTX)
-run: NV=1 PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
+run: NV=1 NV_PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
- name: Run Tensor Core GEMM (NV)
run: NV=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
- name: Test NV=1
@@ -302,7 +302,7 @@ jobs:
- name: Fuzz Padded Tensor Core GEMM (NV)
run: NV=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Fuzz Padded Tensor Core GEMM (PTX)
-run: NV=1 PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
+run: NV=1 NV_PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Train MNIST
run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps

.github/workflows/test.yml (vendored, 2 lines changed)
@@ -694,7 +694,7 @@ jobs:
cuda: 'true'
ocelot: 'true'
- name: Set env
-run: printf "${{ matrix.backend == 'PTX' && 'CUDA=1\nPTX=1' || matrix.backend == 'nv' && 'NV=1\nSKIP_SLOW_TEST=1' }}" >> $GITHUB_ENV
+run: printf "${{ matrix.backend == 'PTX' && 'CUDA=1\nCUDA_PTX=1' || matrix.backend == 'nv' && 'NV=1\nSKIP_SLOW_TEST=1' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CUDA','NV'], Device.DEFAULT"

@@ -75,7 +75,7 @@ if __name__ == "__main__":

if GEMM_VARIATION == "max" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
print("Using CUDA and triton-generated kernel")
-# See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py`
+# See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 CUDA_PTX=1 python3 extra/gemm/triton_nv_matmul.py`
# this kernel with M=N=K=4096 does 162TFLOPS, vs torch at 144TFLOPS and BEAM=8 tinygrad at 138TFLOPS. theo max is 165TFLOPS.

# WMMA element size is (M, N, K) = (16, 8, 16)

@@ -43,7 +43,7 @@ def matmul_kernel(c_ptr, a_ptr, b_ptr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N:
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(c_ptrs, c)

-# CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py
+# CUDA=1 CUDA_PTX=1 python3 extra/gemm/triton_nv_matmul.py
if __name__ == "__main__":
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 64
M, N, K = 4096, 4096, 4096

@@ -272,4 +272,4 @@ def compare_launch_state(states, good_states):

return True, "PASS"

-# IOCTL=1 PTX=1 CUDA=1 python3 test/test_ops.py TestOps.test_tiny_add
+# IOCTL=1 CUDA=1 CUDA_PTX=1 python3 test/test_ops.py TestOps.test_tiny_add

@@ -2,7 +2,7 @@
import unittest, os, subprocess, sys
from tinygrad import Tensor
from tinygrad.device import Device, Compiler
-from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context
+from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context, WIN, CI

class TestDevice(unittest.TestCase):
def test_canonicalize(self):
@@ -28,6 +28,42 @@ class TestDevice(unittest.TestCase):
self.assertEqual(Device.canonicalize(None), device)
Device.DEFAULT = device

+@unittest.skipIf(WIN and CI, "skipping windows test") # TODO: subproccess causes memory violation?
+def test_env_overwrite_default_compiler(self):
+expect_failure = "\ntry: assert Device[Device.DEFAULT].compiler is None;\nexcept RuntimeError: pass"
+
+if Device.DEFAULT == "CPU":
+from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
+try: _, _ = CPULLVMCompiler(), ClangJITCompiler()
+except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}")
+
+imports = "from tinygrad import Device; from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler"
+subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, CPULLVMCompiler)"'],
+shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_LLVM": "1"})
+subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'],
+shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_LLVM": "0"})
+subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
+shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "0", "CPU_LLVM": "0"})
+subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, CPULLVMCompiler)"'],
+shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "0"})
+subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
+shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "1", "CPU_LLVM": "1"})
+elif Device.DEFAULT == "AMD":
+from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
+try: _, _ = HIPCompiler(Device[Device.DEFAULT].arch), AMDLLVMCompiler(Device[Device.DEFAULT].arch)
+except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}")
+
+imports = "from tinygrad import Device; from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler"
+subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, AMDLLVMCompiler)"'],
+shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "1"})
+subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'],
+shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "0"})
+subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'],
+shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1"})
+subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
+shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
+else: self.skipTest("only run on CPU/AMD")

class MockCompiler(Compiler):
def __init__(self, key): super().__init__(key)
def compile(self, src) -> bytes: return src.encode()
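The test above forces a compiler by setting the env var in a subprocess before tinygrad is imported; the same check can be reproduced by hand. A hypothetical one-off probe, assuming a CPU backend with both the clang and LLVM toolchains installed:

import os, subprocess, sys

# which compiler did auto-selection pick when CPU_LLVM=1 is set? (mirrors the test above)
code = "from tinygrad import Device; print(type(Device[Device.DEFAULT].compiler).__name__)"
out = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True,
                     env={**os.environ, "DEV": "CPU", "CPU_LLVM": "1"})
print(out.stdout.strip())  # expected: CPULLVMCompiler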

@@ -1,10 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
-from typing import Any, Generic, TypeVar, Iterator
+from typing import Any, Generic, TypeVar, Iterator, Sequence, cast
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
-from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM, \
-Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
+from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
+from tinygrad.helpers import Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
+from tinygrad.helpers import unwrap_class_type
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer

@@ -272,12 +273,32 @@ class Compiler:
return lib
def disassemble(self, lib:bytes): pass

+CompilerPairT = tuple[functools.partial|type[Renderer], functools.partial|type[Compiler]]
class Compiled:
profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.

-def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None, group_id=None):
-self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
-self.renderer, self.group_id = renderer or Renderer(), group_id
+def __init__(self, device:str, allocator:Allocator, compilers:Sequence[CompilerPairT]|None, runtime, graph=None, group_id=None):
+self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id
+compilers = cast(list[CompilerPairT], compilers or [(Renderer, Compiler)])
+
+devname = device.split(':')[0].upper()
+envnames = [f"{devname}_{unwrap_class_type(c).__name__.removesuffix('Compiler').removeprefix(devname).upper()}" for r,c in compilers]
+
+enable_comps = set((en, comp_pair) for en, comp_pair in zip(envnames, compilers) if en is not None and getenv(en, -1) == 1)
+disable_comps = set((en, comp_pair) for en, comp_pair in zip(envnames, compilers) if en is not None and getenv(en, -1) == 0)
+
+if len(enable_comps) > 1: raise RuntimeError(f"{self.device}: multiple compilers set in env {enable_comps}")
+for _, comp_pair in disable_comps: compilers.remove(comp_pair)
+
+try: self.renderer, self.compiler = next(self._get_available_compilers([list(enable_comps)[0][1]] if len(enable_comps) == 1 else compilers))
+except StopIteration as exc: raise RuntimeError(f"no usable compilers for {self.device}") from exc
+
+if DEBUG >= 1: print(f"{self.device}: using {self.compiler.__class__.__name__}")
+
+def _get_available_compilers(self, compilers) -> Iterator[tuple[Renderer, Compiler]]:
+for renderer, compiler in compilers:
+with contextlib.suppress(Exception): yield renderer(), compiler()

def synchronize(self):
"""
Synchronize all pending operations on the device.
@@ -302,7 +323,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if device is None: device = Device.DEFAULT
if dtype == dtypes.bfloat16:
if device == "METAL": return not CI
-if device in {"CUDA", "NV"}: return not CI and not getenv("PTX")
+if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX")
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
return device in {"AMD", "PYTHON"}
if dtype in dtypes.fp8s:
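The envnames expression above derives one env var per (renderer, compiler) pair from the compiler's class name: strip the Compiler suffix, strip the device-name prefix, upper-case, and prepend the device name. A standalone sketch with dummy stand-in classes (only the names matter for the derivation):

import functools

class CPULLVMCompiler: pass
class ClangJITCompiler: pass
class AMDLLVMCompiler: pass
class HIPCompiler: pass

def env_name(devname: str, comp) -> str:
  # mirrors the expression in Compiled.__init__ above
  cls = comp.func if isinstance(comp, functools.partial) else comp
  return f"{devname}_{cls.__name__.removesuffix('Compiler').removeprefix(devname).upper()}"

assert env_name("CPU", CPULLVMCompiler) == "CPU_LLVM"
assert env_name("CPU", ClangJITCompiler) == "CPU_CLANGJIT"
assert env_name("AMD", AMDLLVMCompiler) == "AMD_LLVM"
assert env_name("AMD", HIPCompiler) == "AMD_HIP"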

@@ -89,6 +89,8 @@ def suppress_finalizing(func):
if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
return wrapper

+def unwrap_class_type(cls_t:T): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t

def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')

class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__
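The new helper exists because the compiler tables mix bare classes with functools.partial wrappers; unwrap_class_type recovers the underlying class so its name can be turned into an env var. A standalone restatement, with SomeCompiler as a dummy class used only for illustration:

import functools

def unwrap_class_type(cls_t): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t

class SomeCompiler:
  def __init__(self, arch): self.arch = arch

assert unwrap_class_type(functools.partial(SomeCompiler, "gfx1100")) is SomeCompiler  # partial -> wrapped class
assert unwrap_class_type(SomeCompiler) is SomeCompiler                                # bare class passes through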

@@ -6,8 +6,8 @@ from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, FileIOInterface
from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator
from tinygrad.uop.ops import sint
-from tinygrad.device import Compiled, DMAFdRef, BufferSpec
-from tinygrad.helpers import getenv, to_mv, round_up, data64_le, DEBUG, AMD_LLVM, PROFILE, ProfileEvent, suppress_finalizing, lo32, hi32
+from tinygrad.device import Compiled, DMAFdRef, BufferSpec, CompilerPairT
+from tinygrad.helpers import getenv, to_mv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, suppress_finalizing, lo32, hi32
from tinygrad.renderer.cstyle import AMDRenderer
from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.runtime.autogen import kfd, hsa, pci, sqtt
@@ -785,9 +785,11 @@ class AMDDevice(HCQCompiled):
max_copy_size = 0x40000000 if self.iface.ip_versions[am.SDMA0_HWIP][0] >= 5 else 0x400000
self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x200 if self.is_usb() else (16 << 20))

-super().__init__(device, AMDAllocator(self), AMDLLVMRenderer(self.arch) if AMD_LLVM else AMDRenderer(self.arch),
-AMDLLVMCompiler(self.arch) if AMD_LLVM else HIPCompiler(self.arch), functools.partial(AMDProgram, self),
-AMDSignal, functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self),
+compilers:list[CompilerPairT] = [(functools.partial(AMDLLVMRenderer, self.arch), functools.partial(AMDLLVMCompiler, self.arch)),
+(functools.partial(AMDRenderer, self.arch), functools.partial(HIPCompiler, self.arch))]
+
+super().__init__(device, AMDAllocator(self), compilers, functools.partial(AMDProgram, self), AMDSignal,
+functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self),
functools.partial(AMDCopyQueue, self, max_copy_size=max_copy_size),
kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000)

@@ -1,6 +1,6 @@
from __future__ import annotations
import platform, sys, ctypes, functools, time, mmap, threading, queue
-from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, CPU_LLVM, suppress_finalizing
+from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, suppress_finalizing
from tinygrad.device import BufferSpec, DMACPURef
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
from tinygrad.renderer.cstyle import ClangRenderer
@@ -116,5 +116,5 @@ class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):
self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start()
-super().__init__(device, CPUAllocator(self), LLVMRenderer() if CPU_LLVM else ClangRenderer(),
-CPULLVMCompiler() if CPU_LLVM else ClangJITCompiler(), functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
+compilers = [(ClangRenderer, ClangJITCompiler), (LLVMRenderer, CPULLVMCompiler)]
+super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)

@@ -1,11 +1,11 @@
from __future__ import annotations
import ctypes, ctypes.util, functools
from tinygrad.helpers import DEBUG, getenv, mv_address, init_c_var, init_c_struct_t, suppress_finalizing
-from tinygrad.device import Compiled, BufferSpec, LRUAllocator
+from tinygrad.device import Compiled, BufferSpec, LRUAllocator, CompilerPairT
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.runtime.autogen import cuda
-from tinygrad.runtime.support.compiler_cuda import pretty_ptx, CUDACompiler, PTXCompiler, PTX
+from tinygrad.runtime.support.compiler_cuda import pretty_ptx, CUDACompiler, PTXCompiler
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported

@@ -115,8 +115,9 @@ class CUDADevice(Compiled):
CUDADevice.devices.append(self)

from tinygrad.runtime.graph.cuda import CUDAGraph
-super().__init__(device, CUDAAllocator(self), PTXRenderer(self.arch) if PTX else CUDARenderer(self.arch),
-PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph)
+compilers:list[CompilerPairT] = [(functools.partial(CUDARenderer, self.arch), functools.partial(CUDACompiler, self.arch)),
+(functools.partial(PTXRenderer, self.arch), functools.partial(PTXCompiler, self.arch))]
+super().__init__(device, CUDAAllocator(self), compilers, functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph)

def synchronize(self):
check(cuda.cuCtxSetCurrent(self.context))

@@ -15,7 +15,7 @@ class DiskDevice(Compiled):
self.size: int|None = None
self.fd: int|None = None
self.count = 0
-super().__init__(device, DiskAllocator(self), None, None, None)
+super().__init__(device, DiskAllocator(self), None, None)
def _might_open(self, size:int):
assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
if self.size is not None and hasattr(self.device, "mem"):

@@ -134,8 +134,8 @@ class DSPDevice(Compiled):
def __init__(self, device:str=""):
compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b"]
if getenv("MOCKDSP"):
-super().__init__(device, CPUAllocator(self), MockDSPRenderer(),
-ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram)
+mock_compilers = [(MockDSPRenderer, functools.partial(ClangCompiler, None, ["-static"] + compiler_args, 'llvm-objdump'))]
+super().__init__(device, CPUAllocator(self), mock_compilers, MockDSPProgram)
else:
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
@@ -146,8 +146,9 @@ class DSPDevice(Compiled):
self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
self.link_ld.flush()

-super().__init__(device, DSPAllocator(self), DSPRenderer(),
-ClangCompiler("compile_dsp", ["-shared"] + compiler_args + [f"-T{self.link_ld.name}"], 'llvm-objdump'), functools.partial(DSPProgram, self))
+compilers = [(DSPRenderer, functools.partial(ClangCompiler, "compile_dsp", ["-shared"] + compiler_args + [f"-T{self.link_ld.name}"],
+'llvm-objdump'))]
+super().__init__(device, DSPAllocator(self), compilers, functools.partial(DSPProgram, self))
fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import cast
import ctypes, functools, hashlib
from tinygrad.runtime.autogen import opencl as cl
-from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, getenv, mv_address, suppress_finalizing
+from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing
from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError

@@ -108,9 +108,9 @@ class CLDevice(Compiled):
self.pending_copyin: list[memoryview] = []
self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096, ctypes.byref(buf := ctypes.create_string_buffer(4096)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501

-compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
-renderer = IntelRenderer() if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts and getenv("INTEL") else OpenCLRenderer()
-super().__init__(device, CLAllocator(self), renderer, CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
+compilers = [(IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer,
+functools.partial(CLCompiler, self, f"compile_cl_{hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()}"))]
+super().__init__(device, CLAllocator(self), compilers, functools.partial(CLProgram, self))
def synchronize(self):
check(cl.clFinish(self.queue))
self.pending_copyin.clear()

@@ -14,7 +14,9 @@ class HIPDevice(Compiled):
self.device_id = int(device.split(":")[1]) if ":" in device else 0
self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device_id))).gcnArchName.decode()
self.time_event_st, self.time_event_en = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
-super().__init__(device, HIPAllocator(self), HIPRenderer(self.arch), HIPCompiler(self.arch), functools.partial(HIPProgram, self))
+
+compilers = [(functools.partial(HIPRenderer, self.arch), functools.partial(HIPCompiler, self.arch))]
+super().__init__(device, HIPAllocator(self), compilers, functools.partial(HIPProgram, self))
def synchronize(self):
check(hip.hipSetDevice(self.device_id))
check(hip.hipDeviceSynchronize())

@@ -76,7 +76,7 @@ class MetalDevice(Compiled):
from tinygrad.runtime.graph.metal import MetalGraph
# NOTE: GitHub CI macOS runners use paravirtualized metal which is broken with graph.
# This can be reproduced locally with any virtualization software (like utm) that can create macOS VMs with apple's own virtualization framework.
-super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
+super().__init__(device, MetalAllocator(self), [(MetalRenderer, MetalCompiler), (MetalRenderer, Compiler)],
functools.partial(MetalProgram, self), MetalGraph if 'virtual' not in from_ns_str(msg('name')(self.sysdevice)).lower() else None)

def synchronize(self):

@@ -8,4 +8,4 @@ class NpyAllocator(Allocator['NpyDevice']):
def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = self._as_buffer(src)

class NpyDevice(Compiled):
-def __init__(self, device:str): super().__init__(device, NpyAllocator(self), None, None, None)
+def __init__(self, device:str): super().__init__(device, NpyAllocator(self), None, None)

@@ -29,5 +29,5 @@ class NullGraph(MultiGraphRunner):
def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-3

class NullDevice(Compiled):
-def __init__(self, device:str): super().__init__(device, NullAllocator(self), NullRenderer(), Compiler(), functools.partial(NullProgram, device),
+def __init__(self, device:str): super().__init__(device, NullAllocator(self), [(NullRenderer, Compiler)], functools.partial(NullProgram, device),
NullGraph)

@@ -6,11 +6,11 @@ from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal, BumpAllocator
from tinygrad.runtime.support.hcq import MMIOInterface, FileIOInterface, MOCKGPU
from tinygrad.uop.ops import sint
-from tinygrad.device import BufferSpec
+from tinygrad.device import BufferSpec, CompilerPairT
from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, prod, OSX, to_mv, hi32, lo32, suppress_finalizing
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import NVRenderer
-from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, PTX, NVPTXCompiler, NVCompiler
+from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler
from tinygrad.runtime.autogen import nv_gpu, pci
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager
@@ -525,9 +525,9 @@ class NVDevice(HCQCompiled[HCQSignal]):
self.arch: str = "sm_120" if self.sm_version==0xa04 else f"sm_{(self.sm_version>>8)&0xff}{(val>>4) if (val:=self.sm_version&0xff) > 0xf else val}"
self.sass_version = ((self.sm_version & 0xf00) >> 4) | (self.sm_version & 0xf)

-compiler_t = (PTXCompiler if PTX else CUDACompiler) if MOCKGPU else (NVPTXCompiler if PTX else NVCompiler)
-super().__init__(device, NVAllocator(self), PTXRenderer(self.arch, device="NV") if PTX else NVRenderer(self.arch), compiler_t(self.arch),
-functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)
+compilers:list[CompilerPairT] = [(functools.partial(NVRenderer, self.arch),functools.partial(CUDACompiler if MOCKGPU else NVCompiler, self.arch)),
+(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(PTXCompiler if MOCKGPU else NVPTXCompiler, self.arch))]
+super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)

self._setup_gpfifos()

@@ -236,4 +236,4 @@ class PythonAllocator(Allocator['PythonDevice']):
def _copyout(self, dest:memoryview, src): dest[:] = src

class PythonDevice(Compiled):
-def __init__(self, device:str): super().__init__(device, PythonAllocator(self), PythonRenderer(), PythonCompiler(), PythonProgram)
+def __init__(self, device:str): super().__init__(device, PythonAllocator(self), [(PythonRenderer, PythonCompiler)], PythonProgram)

@@ -341,8 +341,8 @@ class QCOMDevice(HCQCompiled):
QCOMDevice.gpu_id = ((info.chip_id >> 24) & 0xFF) * 100 + ((info.chip_id >> 16) & 0xFF) * 10 + ((info.chip_id >> 8) & 0xFF)
if QCOMDevice.gpu_id >= 700: raise RuntimeError(f"Unsupported GPU: {QCOMDevice.gpu_id}")

-super().__init__(device, QCOMAllocator(self), QCOMRenderer(), QCOMCompiler(device), functools.partial(QCOMProgram, self),
-QCOMSignal, QCOMComputeQueue, None)
+compilers = [(QCOMRenderer, functools.partial(QCOMCompiler, device))]
+super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal, QCOMComputeQueue, None)

def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
flags |= kgsl.KGSL_MEMALIGN(alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP

@@ -471,10 +471,11 @@ class RemoteDevice(Compiled):
if not renderer[0].startswith("tinygrad.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
-renderer_instance = renderer_class(*renderer[2])
-renderer_instance.device = device

graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None
-super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph, id(self.conn))
+compilers = [(functools.partial(renderer_class, *renderer[2]), Compiler)]
+super().__init__(device, RemoteAllocator(self), compilers, functools.partial(RemoteProgram, self), graph, id(self.conn))
+self.renderer.device = device

def finalize(self):
with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)

@@ -217,7 +217,7 @@ class WebGpuDevice(Compiled):
device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
webgpu.WGPURequestDeviceStatus__enumvalues, 1, 2, adapter_res, dev_desc)

-super().__init__(device, WebGpuAllocator(device_res), WGSLRenderer(), Compiler(),
+super().__init__(device, WebGpuAllocator(device_res), [(WGSLRenderer, Compiler)],
functools.partial(WebGPUProgram, (device_res, webgpu.WGPUFeatureName_TimestampQuery in supported)))

def synchronize(self):

@@ -4,7 +4,7 @@ from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv
import tinygrad.runtime.autogen.nvrtc as nvrtc
from tinygrad.device import Compiler, CompileError

-PTX, CUDA_PATH = getenv("PTX"), getenv("CUDA_PATH", "") # PTX shouldn't be here, in fact, it shouldn't exist
+CUDA_PATH = getenv("CUDA_PATH", "") # PTX shouldn't be here, in fact, it shouldn't exist

def _get_bytes(arg, get_str, get_sz, check) -> bytes:
sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))

@@ -1,11 +1,10 @@
from __future__ import annotations
-from typing import cast, Callable, Type, TypeVar, Generic, Any
+from typing import cast, Callable, Type, TypeVar, Generic, Any, Sequence
import contextlib, decimal, statistics, time, ctypes, array, os, struct, traceback, collections
try: import fcntl # windows misses that
except ImportError: fcntl = None #type:ignore[assignment]
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
-from tinygrad.renderer import Renderer
-from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
+from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent, CompilerPairT
from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc

@@ -359,12 +358,12 @@ class HCQCompiled(Compiled, Generic[SignalType]):
signal_pool: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
cpu_devices: list[HCQCompiled] = []

-def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
+def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:Sequence[CompilerPairT], runtime, signal_t:Type[SignalType],
comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0

from tinygrad.runtime.graph.hcq import HCQGraph
-super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
+super().__init__(device, allocator, compilers, runtime, HCQGraph)

# TODO: peer logic is determined based on device name.
self.peer_group = device.split(":")[0]