improve bf16 case for is_dtype_supported [pr] (#10034)

* fix is_dtype_supported for bf16

* hotfix

* add llvm and amd_llvm

* gate on machine

* separate gpu vs cpu cases

* add arm case
This commit is contained in:
Ignacio Sica
2025-04-24 14:03:57 -03:00
committed by GitHub
parent 754d789f51
commit 93a1e9eeb9

View File

@@ -330,8 +330,10 @@ class Compiled:
def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
if device is None: device = Device.DEFAULT
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
if device in {"METAL", "AMD", "AMD_LLVM"}: return not CI
if device in {"CUDA", "NV"}: return not CI and not getenv("PTX")
if device in {"CPU", "LLVM"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
return False
if dtype in dtypes.fp8s:
# not supported yet - in progress
return False