move is_dtype_supported logic to renderer (#14188)

* move is_dtype_supported logic to renderer

* fix CPU_COUNT

* mypy happy

* early import libclang too with llvm

* run with debug

* skip autogen tests if MTLCompiler or llvm is loaded

* run autogen tests separately in CI

* lint
This commit is contained in:
Christopher Milan
2026-01-18 19:37:04 -08:00
committed by GitHub
parent 7abe9b020f
commit 161fee9a48
13 changed files with 75 additions and 46 deletions

View File

@@ -781,9 +781,11 @@ jobs:
ocelot: 'true'
llvm: 'true'
- name: Run unit tests
run: METAL=1 python -m pytest -n=auto test/unit/ --durations=20
- name: Run autogen tests
env:
LIBCLANG_PATH: '/opt/homebrew/opt/llvm@20/lib/libclang.dylib'
run: METAL=1 python -m pytest -n=auto test/unit/ --durations=20
run: METAL=1 python -m pytest -n=auto test/unit/test_autogen.py --durations=20
- name: Run ONNX
run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- name: Test tensor core ops (fake)

View File

@@ -243,6 +243,7 @@ class TestAutogen(unittest.TestCase):
self.assertEqual(out.contents.b, 10)
@unittest.skipIf(WIN, "doesn't compile on windows")
@unittest.skipIf(OSX and ('MTLCompiler' in DLL._loaded_ or 'llvm' in DLL._loaded_), "libclang can't be loaded after MTLCompiler or llvm on OSX")
def test_packed_structs(self):
ns = self.run_gen("""
typedef unsigned NvU32;
@@ -333,6 +334,7 @@ typedef struct ip_discovery_header
assert ihdr.base_addr_64_bit == 1
@unittest.skipIf(WIN, "doesn't compile on windows")
@unittest.skipIf(OSX and ('MTLCompiler' in DLL._loaded_ or 'llvm' in DLL._loaded_), "libclang can't be loaded after MTLCompiler or llvm on OSX")
def test_gen_from_header(self):
namespace = self.run_gen("""
typedef struct {
@@ -379,6 +381,7 @@ typedef struct ip_discovery_header
self.assertTrue(hasattr(rect, 'color'))
@unittest.skipIf(WIN, "doesn't compile on windows")
@unittest.skipIf(OSX and ('MTLCompiler' in DLL._loaded_ or 'llvm' in DLL._loaded_), "libclang can't be loaded after MTLCompiler or llvm on OSX")
def test_struct_ordering(self):
namespace = self.run_gen("""
struct A;
@@ -479,6 +482,7 @@ typedef struct ip_discovery_header
self.assertEqual(result, 43) # 42 + 1
@unittest.skipIf(WIN, "doesn't compile on windows")
@unittest.skipIf(OSX and ('MTLCompiler' in DLL._loaded_ or 'llvm' in DLL._loaded_), "libclang can't be loaded after MTLCompiler or llvm on OSX")
def test_anonymous_children(self):
namespace = self.run_gen("""
struct foo {
@@ -492,6 +496,7 @@ typedef struct ip_discovery_header
self.assertIn('struct_foo_bar', namespace)
@unittest.skipIf(WIN, "doesn't compile on windows")
@unittest.skipIf(OSX and ('MTLCompiler' in DLL._loaded_ or 'llvm' in DLL._loaded_), "libclang can't be loaded after MTLCompiler or llvm on OSX")
def test_enums(self):
namespace = self.run_gen("""
enum Foo { A, B, C };

View File

@@ -2,10 +2,10 @@ from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Any, Generic, TypeVar, Iterator, Generator
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
import importlib, inspect, functools, pathlib, os, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup, ContextVar
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, NV_PTX, CUDA_PTX, NV_NAK
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -342,35 +342,8 @@ class Compiled:
"""
# override this in your device implementation
# TODO: move this to each Device
def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if dtype == dtypes.index: return False
if device is None: device = Device.DEFAULT
if dtype == dtypes.bfloat16:
if device == "METAL": return not CI
if device == "CUDA": return not CI and not CUDA_PTX
if device == "NV": return not CI and not NV_PTX and not NV_NAK
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP
return device in {"AMD", "CL", "PYTHON", "NULL"}
if dtype in dtypes.fp8s:
if device == "CUDA": return not CI and not CUDA_PTX
if device == "NV": return not CI and not NV_PTX and not NV_NAK
if device == "AMD": return not CI and getattr(Device["AMD"], "target") in {(9,4,2), (9,5,0)}
return device in {"PYTHON", "NULL"}
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
# CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
# PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
if dtype == dtypes.half:
if device == "CL": return not CI and not OSX
if device == "QCOM": return False # QCOM compiler is flaky with half
if device in ["CUDA", "NV"]: return not CI
if device == "CPU" and CPU_LLVM: return OSX
if device == "PYTHON": return sys.version_info >= (3, 12)
if dtype == dtypes.float64: return device not in {"METAL", "QCOM"} and not (OSX and device == "CL") and not getenv("NULL_IR3")
return True
return dtype != dtypes.index and Device[device or Device.DEFAULT].renderer.is_dtype_supported(dtype)
if PROFILE:
@atexit.register

View File

@@ -4,7 +4,7 @@ import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.dtype import AddrSpace, DType, PtrDType
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.codegen.opt import Opt
if TYPE_CHECKING: from tinygrad.device import Compiler
@@ -150,3 +150,5 @@ class Renderer:
def __reduce__(self): return self.__class__, ()
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
def aux(self, uops:list[UOp]) -> dict: raise NotImplementedError("needs aux")
def is_dtype_supported(self, dtype:DType) -> bool: return True

View File

@@ -1,9 +1,9 @@
from typing import Literal, Callable, cast
import os, math, sys, struct
import os, math, platform, sys, struct
from collections import defaultdict, Counter
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str, axis_letters
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CI, CPU_COUNT, OSX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate, float_to_bf16
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
@@ -224,7 +224,8 @@ class ClangRenderer(CStyleLanguage):
gep_arr_threshold = 0
has_local = False
has_threads = bool(getenv("THREADS", 1))
global_max = (CPU_COUNT.value, 0, 0)
@property
def global_max(self): return (CPU_COUNT.value, 0, 0) # type: ignore
infinity = "__builtin_inff()"
nan = '__builtin_nanf("")'
code_for_workitem = {"g": lambda _: "core_id"}
@@ -284,10 +285,16 @@ class ClangJITRenderer(ClangRenderer):
from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler
self.compiler = ClangJITCompiler()
def is_dtype_supported(self, dtype:DType) -> bool:
return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} if dtype == dtypes.bfloat16 else dtype not in dtypes.fp8s
class OpenCLRenderer(CStyleLanguage):
device = "CL"
has_aux = True
# CI and OSX, cl_khr_fp16 is not supported
def is_dtype_supported(self, dtype:DType) -> bool: return dtype not in ((dtypes.half,) * (OSX or CI)) + ((dtypes.float64,) * OSX) + dtypes.fp8s
# language options
kernel_typedef = "__kernel void"
buffer_prefix = "__global "
@@ -342,6 +349,7 @@ class MetalRenderer(CStyleLanguage):
device = "METAL"
shared_max = 32768
def __init__(self): self.tensor_cores = tc.metal if hasattr(os, 'uname') and os.uname().machine == "arm64" else []
def is_dtype_supported(self, dtype:DType) -> bool: return not CI if dtype == dtypes.bfloat16 else dtype not in (dtypes.float64,) + dtypes.fp8s
# language options
kernel_typedef = "kernel void"
@@ -393,6 +401,9 @@ class CUDARenderer(CStyleLanguage):
self.tensor_cores = tc.cuda_sm89 if arch_ver >= 89 else tc.cuda_sm80 if arch_ver >= 80 else tc.cuda_sm75 if arch_ver >= 75 else []
def __reduce__(self): return self.__class__, (self.arch,)
# CI CUDA is sm_35, so no fp16 ALUs
def is_dtype_supported(self, dtype:DType) -> bool: return not CI or dtype not in (dtypes.bfloat16, dtypes.half) + dtypes.fp8s
# language options
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
kernel_typedef = "extern \"C\" __global__ void __launch_bounds__({launch_bounds})"
@@ -462,6 +473,8 @@ class AMDHIPRenderer(CStyleLanguage):
# NOTE: this is only really needed on gfx12, even though gfx11 reports the same limitation
global_max = (2147483647, 65535, 65535)
def is_dtype_supported(self, dtype:DType) -> bool: return dtype not in dtypes.fp8s or self.arch in {"gfx942", "gfx950"}
@staticmethod
def get_tensor_cores(arch):
return {"gfx942": tc.amd_cdna3, "gfx950": tc.amd_cdna4, "gfx1200": tc.amd_rdna4, "gfx1201": tc.amd_rdna4}.get(arch.split(":")[0], tc.amd_rdna3)
@@ -556,4 +569,8 @@ class AMDHIPCCRenderer(AMDHIPRenderer):
super().__init__(arch)
self.compiler = HIPCCCompiler(arch)
class QCOMRenderer(OpenCLRenderer): device = "QCOM"
class QCOMRenderer(OpenCLRenderer):
device = "QCOM"
# QCOM compiler is flaky with half
def is_dtype_supported(self, dtype:DType) -> bool: return dtype not in (dtypes.bfloat16, dtypes.half, dtypes.float64) + dtypes.fp8s

View File

@@ -1,12 +1,12 @@
from typing import cast
import math, struct, sys
import math, platform, struct, sys
from tinygrad.codegen.opt import tc
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import AMDHIPRenderer, create_non_native_float_pats, pm_manual_bf16_cast
from tinygrad.uop.decompositions import xexp2, xlog2
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
from tinygrad.helpers import prod, AMX
from tinygrad.helpers import prod, AMX, CI, OSX
def ldt(dt:DType):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
@@ -142,6 +142,11 @@ class LLVMRenderer(Renderer):
code_for_op = {Ops.FDIV: lambda: None, Ops.CMPLT: lambda: None}
if AMX: tensor_cores = tc.amx
def is_dtype_supported(self, dtype:DType) -> bool:
if dtype == dtypes.bfloat16: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
# LLVM can't link to casting function?
return OSX if dtype == dtypes.half else dtype not in dtypes.fp8s
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
@@ -211,6 +216,9 @@ class AMDLLVMRenderer(LLVMRenderer):
has_local = True
shared_max = AMDHIPRenderer.shared_max
global_max = AMDHIPRenderer.global_max
def is_dtype_supported(self, dtype:DType) -> bool: return dtype not in dtypes.fp8s or self.arch in {"gfx942", "gfx950"}
abi = "amdgpu_kernel"
code_for_op = {**LLVMRenderer.code_for_op, **{op: lambda: None for op in llvm_intrinsics}}
string_rewrite = PatternMatcher([

View File

@@ -119,6 +119,9 @@ class NIRRenderer(Renderer):
suffix = "NIR"
nir_options: bytes
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
def is_dtype_supported(self, dtype:DType) -> bool: return dtype not in (dtypes.bfloat16,) + dtypes.fp8s
code_for_op = {**{k:lambda:None for k in u_aop.keys()}, **{k:lambda:None for k in s_aop.keys()}, **{k:lambda:None for k in f_aop.keys()}}
extra_matcher = PatternMatcher([
@@ -263,6 +266,9 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
class IR3Renderer(NIRRenderer):
device = "QCOM"
# TODO: can IR3 support half?
def is_dtype_supported(self, dtype:DType) -> bool: return dtype not in (dtypes.bfloat16, dtypes.half, dtypes.float64) + dtypes.fp8s
def nload_img(ctx,img,coord):
ctx.texs.add(img)
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)

View File

@@ -6,7 +6,7 @@ from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.helpers import flatten, get_single_element, prod, unwrap
from tinygrad.helpers import flatten, get_single_element, prod, unwrap, CI
def render_val(x, dtype):
if dtypes.is_float(dtype):
@@ -149,6 +149,9 @@ class PTXRenderer(Renderer):
self.tensor_cores = PTXRenderer.tc_sm80 if arch_ver >= 80 else tc.cuda_sm75 if arch_ver >= 75 else []
def __reduce__(self): return self.__class__, (self.arch, self.device)
# CI CUDA is sm_35, so no fp16 ALUs
def is_dtype_supported(self, dtype:DType) -> bool: return dtype not in ((dtypes.half,) if CI else ()) + (dtypes.bfloat16,) + dtypes.fp8s
# language options
kernel_prefix = """.version VERSION
.target TARGET

View File

@@ -53,6 +53,9 @@ class WGSLRenderer(CStyleLanguage):
code_for_workitem = {"g": lambda x: f"i32(gindex.{'xyz'[int(x)]})", "l": lambda x: f"i32(lindex.{'xyz'[int(x)]})"}
extra_matcher = wgsl_matcher
supports_float4 = False
def is_dtype_supported(self, dtype:DType) -> bool: return dtype in self.type_map
barrier = "workgroupBarrier();"
code_for_op = {**CStyleLanguage.code_for_op, Ops.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a})"}
nan = "nan()"

View File

@@ -36,6 +36,9 @@ class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
has_threads = False
def is_dtype_supported(self, dtype:DType) -> bool: return dtype not in (dtypes.bfloat16,) + dtypes.fp8s
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_typedef = "__attribute__((noinline)) void"
extra_args = []

View File

@@ -4,12 +4,13 @@ import tinygrad.runtime.support.objc as objc
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, ProfileDeviceEvent, CompilerSet, CompilerPair
from tinygrad.renderer.cstyle import MetalRenderer
from tinygrad.runtime.autogen import metal
from tinygrad.runtime.support.c import DLL
# 13 is the requestType that Metal uses to compile source code into MTLB; there aren't any docs or symbols.
REQUEST_TYPE_COMPILE = 13
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
DLL("CoreGraphics", "CoreGraphics")
# FIXME: these need autogen to support objc categories
# https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ObjectiveC/Chapters/ocCategories.html
@@ -67,7 +68,7 @@ class MetalCompiler(Compiler):
# doesn't seem to be anything we can do.
with contextlib.suppress(FileNotFoundError, ModuleNotFoundError):
import tinygrad.runtime.autogen.llvm # noqa: F401
support = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
support = DLL("MTLCompiler", "MTLCompiler")
support.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
def __init__(self):

View File

@@ -231,6 +231,9 @@ class PythonRenderer(Renderer):
case "": pass
case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}")
# supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
def is_dtype_supported(self, dtype:DType) -> bool: return dtype != dtypes.half or sys.version_info >= (3, 12)
def render(self, uops:list[UOp]) -> str:
# the value of SPECIAL comes from local/global_size, not from its source
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops]

View File

@@ -131,13 +131,15 @@ def init_c_struct_t(sz:int, fields: tuple[tuple, ...]):
def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]
class DLL(ctypes.CDLL):
_loaded_: set[str] = set()
@staticmethod
def findlib(nm:str, paths:list[str], extra_paths=[]):
if nm == 'libc' and OSX: return '/usr/lib/libc.dylib'
if pathlib.Path(path:=getenv(nm.replace('-', '_').upper()+"_PATH", '')).is_file(): return path
for p in paths:
libpaths = {"posix": ["/usr/lib64", "/usr/lib", "/usr/local/lib"], "nt": os.environ['PATH'].split(os.pathsep),
"darwin": ["/opt/homebrew/lib", f"/System/Library/Frameworks/{p}.framework"],
"darwin": ["/opt/homebrew/lib", f"/System/Library/Frameworks/{p}.framework", f"/System/Library/PrivateFrameworks/{p}.framework"],
'linux': ['/lib', '/lib64', f"/lib/{sysconfig.get_config_var('MULTIARCH')}", "/usr/lib/wsl/lib/"]}
if (pth:=pathlib.Path(p)).is_absolute():
if pth.is_file(): return p
@@ -154,12 +156,12 @@ class DLL(ctypes.CDLL):
if f.read(4) == b'\x7FELF': return str(l)
def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs):
self.nm, self.emsg, self.loaded = nm, emsg, False
self.nm, self.emsg = nm, emsg
if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])):
if DEBUG >= 3: print(f"loading {nm} from {path}")
try:
super().__init__(path, **kwargs)
self.loaded = True
self._loaded_.add(self.nm)
except OSError as e:
self.emsg = str(e)
if DEBUG >= 3: print(f"loading {nm} failed: {e}")
@@ -175,5 +177,6 @@ class DLL(ctypes.CDLL):
return wrapper
def __getattr__(self, nm):
if not self.loaded: raise AttributeError(f"failed to load library {self.nm}: " + (self.emsg or f"try setting {self.nm.upper()+'_PATH'}?"))
if self.nm not in self._loaded_:
raise AttributeError(f"failed to load library {self.nm}: " + (self.emsg or f"try setting {self.nm.upper()+'_PATH'}?"))
return super().__getattr__(nm)