Handle missing bfloat16 natives on CPU architectures (#13553)

* CPU: fix compiler-rt libcall by adding intermediate casts for bfloat16

* fix lint

* remove old manual bypass of bf16 for CPU tests, and add diversion converstion from bf16 to/from fp16

---------

Co-authored-by: Jakob Sachs <jakobs99@purelymail.com>
This commit is contained in:
Jakob Sachs
2025-12-11 21:38:43 +01:00
committed by GitHub
parent cbae33003d
commit ab2220b834
2 changed files with 8 additions and 3 deletions

View File

@@ -17,8 +17,6 @@ pytestmark = pytest.mark.filterwarnings("ignore")
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
if not is_dtype_supported(dtype): return []
# dont cast internal dtypes
@@ -435,6 +433,8 @@ class TestOpsBFloat16(unittest.TestCase):
data = [60000.0, 70000.0, 80000.0]
np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
# some CPUs there is no native bfloat16 sqrt
@unittest.skipIf(Device.DEFAULT == "CPU", "no approximation")
def test_no_approximation(self):
data = [326.0, 339.0, 10603200512.0]
expected = torch.tensor(data, dtype=torch.bfloat16).sqrt().float().numpy()
@@ -442,3 +442,4 @@ class TestOpsBFloat16(unittest.TestCase):
if __name__ == '__main__':
unittest.main()