From 93a1e9eeb9a7385cda6029f9fb250f54def573cd Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Thu, 24 Apr 2025 14:03:57 -0300 Subject: [PATCH] improve `bf16` case for `is_dtype_supported` [pr] (#10034) * fix is_dtype_supported for bf16 * hotfix * add llvm and amd_llvm * gate on machine * separate gpu vs cpu cases * add arm case --- tinygrad/device.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tinygrad/device.py b/tinygrad/device.py index 3764ec487a..ac62a90a20 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -330,8 +330,10 @@ class Compiled: def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool: if device is None: device = Device.DEFAULT if dtype == dtypes.bfloat16: - # NOTE: this requires bf16 buffer support - return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) + if device in {"METAL", "AMD", "AMD_LLVM"}: return not CI + if device in {"CUDA", "NV"}: return not CI and not getenv("PTX") + if device in {"CPU", "LLVM"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} + return False if dtype in dtypes.fp8s: # not supported yet - in progress return False