From dc9a6b4bb7285e8190834ac1854ba8feefe511c5 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 11 Mar 2023 21:51:22 -0800 Subject: [PATCH] fix float16 in CLANG on linux --- test/test_dtype.py | 4 ++++ tinygrad/runtime/ops_clang.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index e1119b32f6..2eb5b618f4 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -14,6 +14,7 @@ class TestDtype(unittest.TestCase): na = a.numpy() print(na, na.dtype, a.lazydata.realized) assert na.dtype == np.float16 + np.testing.assert_allclose(na, [1,2,3,4]) def test_half_add(self): a = Tensor([1,2,3,4], dtype=dtypes.float16) @@ -21,6 +22,7 @@ class TestDtype(unittest.TestCase): c = a+b print(c.numpy()) assert c.dtype == dtypes.float16 + np.testing.assert_allclose(c.numpy(), [2,4,6,8]) def test_upcast_float(self): # NOTE: there's no downcasting support @@ -29,6 +31,7 @@ class TestDtype(unittest.TestCase): na = a.numpy() print(na, na.dtype) assert na.dtype == np.float32 + np.testing.assert_allclose(na, [1,2,3,4]) def test_half_add_upcast(self): a = Tensor([1,2,3,4], dtype=dtypes.float16) @@ -36,6 +39,7 @@ class TestDtype(unittest.TestCase): c = a+b print(c.numpy()) assert c.dtype == dtypes.float32 + np.testing.assert_allclose(c.numpy(), [2,4,6,8]) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 8f2276622b..03c6ce6352 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -14,8 +14,9 @@ class ClangProgram: prg = "#include \n#define max(x,y) ((x>y)?x:y)\n#define half __fp16\n" + prg # TODO: is there a way to not write this to disk? fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if platform.system() == 'Darwin' else 'so'}" + # NOTE: --rtlib=compiler-rt fixes float16 on Linux, it defines __gnu_h2f_ieee and __gnu_f2h_ieee if not os.path.exists(fn): - subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8')) + subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '--rtlib=compiler-rt', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8')) os.rename(fn+".tmp", fn) self.lib = ctypes.CDLL(fn) self.fxn = self.lib[name]