Handle missing native bfloat16 support on CPU architectures (#13553)
* CPU: fix the compiler-rt libcall issue by adding intermediate float32 casts for bfloat16
* fix lint
* remove the old manual bypass of bf16 in the CPU tests, and divert bf16 <-> fp16 conversions through float32

---------

Co-authored-by: Jakob Sachs <jakobs99@purelymail.com>
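In effect, the CPU renderer now routes casts that LLVM cannot lower natively through float32 instead of letting them become compiler-rt libcalls (e.g. __truncdfbf2). A minimal sketch of the equivalent cast chain at the Tensor level, purely illustrative (the real rewrite happens in ClangRenderer.extra_matcher, see the last hunk below):

from tinygrad import Tensor, dtypes

x = Tensor([60000.0, 70000.0, 80000.0], dtype=dtypes.float64)

# what a float64 -> bfloat16 cast effectively becomes on CPUs without
# native bf16 support: narrow to float32 first, then to bfloat16
y = x.cast(dtypes.float32).cast(dtypes.bfloat16)

# bf16 <-> fp16 conversions take the same float32 detour
z = y.cast(dtypes.float32).cast(dtypes.float16)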
@@ -17,8 +17,6 @@ pytestmark = pytest.mark.filterwarnings("ignore")
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
 
-if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
-
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   if not is_dtype_supported(dtype): return []
   # dont cast internal dtypes
@@ -435,6 +433,8 @@ class TestOpsBFloat16(unittest.TestCase):
     data = [60000.0, 70000.0, 80000.0]
     np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
 
+  # on some CPUs there is no native bfloat16 sqrt
+  @unittest.skipIf(Device.DEFAULT == "CPU", "no approximation")
   def test_no_approximation(self):
     data = [326.0, 339.0, 10603200512.0]
     expected = torch.tensor(data, dtype=torch.bfloat16).sqrt().float().numpy()
@@ -442,3 +442,4 @@ class TestOpsBFloat16(unittest.TestCase):
 
 if __name__ == '__main__':
   unittest.main()
@@ -224,8 +224,12 @@ class ClangRenderer(CStyleLanguage):
                  Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
                  Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})",
                  Ops.FDIV: lambda a,b,dtype: f"({a}/{b})"}
-  # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
+  # LLVM legalizes double => half/bf16 cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
+  # there is also no native bf16 <-> fp16 conversion on those CPUs
   extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
+    (UPat.var("x", dtypes.float64).cast(dtypes.bfloat16), lambda x: x.cast(dtypes.float32).cast(dtypes.bfloat16)),
+    (UPat.var("x", dtypes.bfloat16).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
     (UPat((Ops.SQRT, Ops.TRUNC), name="alu"), no_vectorized_alu)]) + CStyleLanguage.extra_matcher
 
   if sys.platform == 'win32':
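Why float32 works as the intermediate: bfloat16 is simply the upper 16 bits of an IEEE-754 float32, so once a value is in float32 the bf16 conversion needs nothing more than integer shifts, even on CPUs with no bf16 instructions. A small pure-Python sketch of that idea (it truncates rather than rounding to nearest-even as a real backend would; the helper names are illustrative, not tinygrad APIs):

import struct

def f32_to_bf16_bits(x: float) -> int:
  # bfloat16 keeps the sign, the full 8-bit exponent and the top 7 mantissa
  # bits of a float32, i.e. just the upper half of its bit pattern
  (bits,) = struct.unpack("<I", struct.pack("<f", x))
  return bits >> 16

def bf16_bits_to_f32(bits: int) -> float:
  # widening back to float32 is exact: pad the low 16 bits with zeros
  (x,) = struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))
  return x

val = 70000.0                                            # Python floats are float64
as_f32 = struct.unpack("<f", struct.pack("<f", val))[0]  # float64 -> float32
print(bf16_bits_to_f32(f32_to_bf16_bits(as_f32)))        # 69632.0, the bf16 just below 70000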