disable bfloat16 from ptx tests (#4305)

This commit is contained in:
Szymon Ożóg
2024-04-26 07:20:10 +02:00
committed by GitHub
parent ec65aea32f
commit de832d26c6

View File

@@ -25,7 +25,7 @@ def assert_jit_cache_len(fxn, expected_len):
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"RHIP", "HSA"} or (device == "CUDA" and not CI)
return device in {"RHIP", "HSA"} or (device == "CUDA" and not CI and not getenv("PTX"))
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
if device == "CUDA" and getenv("PTX") and dtype in (dtypes.int8, dtypes.uint8): return False
# for CI GPU and OSX, cl_khr_fp16 isn't supported