Handle missing bfloat16 natives on CPU architectures (#13553)

* CPU: fix compiler-rt libcall by adding intermediate casts for bfloat16

* fix lint

* remove old manual bypass of bf16 for CPU tests, and add diversion converstion from bf16 to/from fp16

---------

Co-authored-by: Jakob Sachs <jakobs99@purelymail.com>
This commit is contained in:
Jakob Sachs
2025-12-11 21:38:43 +01:00
committed by GitHub
parent cbae33003d
commit ab2220b834
2 changed files with 8 additions and 3 deletions

View File

@@ -17,8 +17,6 @@ pytestmark = pytest.mark.filterwarnings("ignore")
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
if not is_dtype_supported(dtype): return []
# dont cast internal dtypes
@@ -435,6 +433,8 @@ class TestOpsBFloat16(unittest.TestCase):
data = [60000.0, 70000.0, 80000.0]
np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
# some CPUs there is no native bfloat16 sqrt
@unittest.skipIf(Device.DEFAULT == "CPU", "no approximation")
def test_no_approximation(self):
data = [326.0, 339.0, 10603200512.0]
expected = torch.tensor(data, dtype=torch.bfloat16).sqrt().float().numpy()
@@ -442,3 +442,4 @@ class TestOpsBFloat16(unittest.TestCase):
if __name__ == '__main__':
unittest.main()

View File

@@ -224,8 +224,12 @@ class ClangRenderer(CStyleLanguage):
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})",
Ops.FDIV: lambda a,b,dtype: f"({a}/{b})"}
# LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
# LLVM legalizes double => half/bf16 cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
# there is also no native bfl16 <-> fp16 conversion on those CPUs
extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
(UPat.var("x", dtypes.float64).cast(dtypes.bfloat16), lambda x: x.cast(dtypes.float32).cast(dtypes.bfloat16)),
(UPat.var("x", dtypes.bfloat16).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
(UPat((Ops.SQRT, Ops.TRUNC), name="alu"), no_vectorized_alu)]) + CStyleLanguage.extra_matcher
if sys.platform == 'win32':