diff --git a/test/test_dtype.py b/test/test_dtype.py
index 1df71aeba2..8849f835d8 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -143,5 +143,9 @@ class TestInt32Dtype(unittest.TestCase):
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64")
   def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
 
+class TestBoolDtype(unittest.TestCase):
+  def test_casts_from_bool(self): _test_casts_from([0,1,1,0], source_dtype=dtypes.bool, target_dtypes=[dtypes.float32, dtypes.int32])
+  def test_casts_to_bool(self): _test_casts_to([0,1,1,0], source_dtypes=[dtypes.float32, dtypes.int32], target_dtype=dtypes.bool)
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_ops.py b/test/test_ops.py
index c4b9cb1ab4..d048adc4b0 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -150,10 +150,6 @@ class TestOps(unittest.TestCase):
       lambda x: torch.where(x > 0.5, 4, 2).permute((1, 0)),
       lambda x: (x > 0.5).where(4, 2).permute((1, 0)), forward_only=True)
 
-  def test_where_bool(self): # Fixes #1479.
-    helper_test_op([(1,), (1,)], lambda x,y: torch.where(x==y, torch.tensor([1,1], dtype=torch.bool), 0),
-      lambda x,y: (x==y).where(Tensor([1,1], dtype=dtypes.bool), 0), forward_only=True, vals=[[0,1],[1,1]])
-
   def _test_cmp(self, fxn, reverse=True):
     for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
       helper_test_op(shps, fxn, fxn, forward_only=True)
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 8927e4212f..cf017b6949 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -116,8 +116,8 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
       idx = args.idx.render(render_llvm, bb[-1])
       element = lvars[vin[0]]
       if args.memory_dtype != vin[0].dtype:
-        if dtypes.is_int(args.memory_dtype):
-          element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
+        if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
+          element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
         elif args.memory_dtype == dtypes.bfloat16:
           element = bb[-1].bitcast(element, ir.IntType(32))
           element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))