From a7439af786faff265f2f245aa3b7698073b76d97 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 30 Oct 2023 18:28:23 -0400
Subject: [PATCH] Fix llvm int->bool cast (#2164)

* add to ir

* add test case

* minimize diff

* todo

* enable fast math

* added both False and True case
---
 test/test_dtype.py          | 2 ++
 tinygrad/renderer/llvmir.py | 4 +++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/test/test_dtype.py b/test/test_dtype.py
index feb9221723..61fedfc6eb 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -105,6 +105,8 @@ class TestInt8Dtype(unittest.TestCase):
   def test_int64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4])
 
   def test_casts_to_int8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.float32, target_dtypes=[dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64])
+  def test_casts_to_bool_1(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.bool], target_contents=[True, True, True, True])
+  def test_casts_to_bool_2(self): _test_casts_from([1,0,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.bool], target_contents=[True, False, True, True])
   def test_casts_from_int8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.float32, dtypes.uint8, dtypes.int32, dtypes.int64])
   def test_casts_from_uint8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.uint8, target_dtypes=[dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64])
 
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index b0ffe8f41a..f01fcaeb8c 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -44,7 +44,9 @@ def cast(bb, val, input_type, output_type):
   if input_type == dtypes.float32:
     if dtypes.is_int(output_type) or output_type == dtypes.bool:
-      val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_unsigned(output_type) or output_type == dtypes.bool else bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
+      if dtypes.is_unsigned(output_type): val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type])
+      elif output_type == dtypes.bool: val = bb[-1].fcmp_ordered("!=", val, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS)
+      else: val = bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
     elif output_type == dtypes.bfloat16:
       val = bb[-1].bitcast(val, ir.IntType(32))
       val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
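
Context for the one-line change in llvmir.py: LLVM's fptoui to i1 yields poison when the float value does not fit in one bit (e.g. 2.0), so nonzero inputs could cast to False; comparing against 0.0 gives the intended bool semantics. Below is a minimal llvmlite sketch contrasting the two lowerings. The module and function names and the ("fast",) flag tuple are illustrative assumptions, not tinygrad's actual LLVM_FAST_MATH_FLAGS.

# Sketch only: contrasts the old and new float->bool lowering with llvmlite,
# the library tinygrad's llvmir.py builds on. Names here are placeholders.
import llvmlite.ir as ir

module = ir.Module(name="cast_demo")
fnty = ir.FunctionType(ir.IntType(1), [ir.FloatType()])

# Old behaviour: fptoui float -> i1. Values like 2.0 do not fit in one bit,
# so the result is poison and nonzero inputs may not read back as True.
old_fn = ir.Function(module, fnty, name="cast_bool_old")
bb = ir.IRBuilder(old_fn.append_basic_block(name="entry"))
bb.ret(bb.fptoui(old_fn.args[0], ir.IntType(1)))

# New behaviour from this patch: ordered compare against 0.0, so any nonzero
# float becomes True and 0.0 becomes False.
new_fn = ir.Function(module, fnty, name="cast_bool_new")
bb = ir.IRBuilder(new_fn.append_basic_block(name="entry"))
bb.ret(bb.fcmp_ordered("!=", new_fn.args[0], ir.Constant(ir.FloatType(), 0), flags=("fast",)))

print(module)  # dump the generated IR for both lowerings

The two added tests exercise exactly this: casting int8 inputs that include values other than 0 and 1 (and a 0 in the second case) to dtypes.bool and checking that the results are element-wise True/False.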