mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
improve bf16 case for is_dtype_supported [pr] (#10034)
* fix is_dtype_supported for bf16 * hotfix * add llvm and amd_llvm * gate on machine * separate gpu vs cpu cases * add arm case
This commit is contained in:
@@ -330,8 +330,10 @@ class Compiled:
|
||||
def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
|
||||
if device is None: device = Device.DEFAULT
|
||||
if dtype == dtypes.bfloat16:
|
||||
# NOTE: this requires bf16 buffer support
|
||||
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
|
||||
if device in {"METAL", "AMD", "AMD_LLVM"}: return not CI
|
||||
if device in {"CUDA", "NV"}: return not CI and not getenv("PTX")
|
||||
if device in {"CPU", "LLVM"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
|
||||
return False
|
||||
if dtype in dtypes.fp8s:
|
||||
# not supported yet - in progress
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user