llvmir support for bool <-> float casting (#1492)
@@ -143,5 +143,9 @@ class TestInt32Dtype(unittest.TestCase):
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64")
   def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
 
+class TestBoolDtype(unittest.TestCase):
+  def test_casts_from_bool(self): _test_casts_from([0,1,1,0], source_dtype=dtypes.bool, target_dtypes=[dtypes.float32, dtypes.int32])
+  def test_casts_to_bool(self): _test_casts_to([0,1,1,0], source_dtypes=[dtypes.float32, dtypes.int32], target_dtype=dtypes.bool)
+
 if __name__ == '__main__':
   unittest.main()
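The new TestBoolDtype tests drive casts both ways through the file's shared _test_casts_from/_test_casts_to helpers. A minimal sketch of the round trip they cover, assuming the Tensor.cast API and the import paths of this era of tinygrad (illustrative, not part of the diff):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

# bool -> float32: False/True are stored as 0.0/1.0
b = Tensor([0, 1, 1, 0], dtype=dtypes.bool)
print(b.cast(dtypes.float32).numpy())  # [0. 1. 1. 0.]

# float32 -> bool: 0.0/1.0 map back to False/True
f = Tensor([0.0, 1.0, 1.0, 0.0])
print(f.cast(dtypes.bool).numpy())     # [False  True  True False]

Note the test data sticks to 0s and 1s: the float-to-bool store is a plain integer truncation, so values outside {0.0, 1.0} are not what these tests exercise.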
@@ -150,10 +150,6 @@ class TestOps(unittest.TestCase):
       lambda x: torch.where(x > 0.5, 4, 2).permute((1, 0)),
       lambda x: (x > 0.5).where(4, 2).permute((1, 0)), forward_only=True)
 
-  def test_where_bool(self): # Fixes #1479.
-    helper_test_op([(1,), (1,)], lambda x,y: torch.where(x==y, torch.tensor([1,1], dtype=torch.bool), 0),
-      lambda x,y: (x==y).where(Tensor([1,1], dtype=dtypes.bool), 0), forward_only=True, vals=[[0,1],[1,1]])
-
   def _test_cmp(self, fxn, reverse=True):
     for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
       helper_test_op(shps, fxn, fxn, forward_only=True)
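The removed test_where_bool (originally added for #1479) goes away here, presumably superseded by the new TestBoolDtype casts above; the Tensor.where pattern it relied on stays covered by the surrounding context lines. For reference, where selects elementwise between two values using a boolean mask, as in this standalone snippet (values illustrative):

from tinygrad.tensor import Tensor

x = Tensor([0.2, 0.7, 0.9])
mask = x > 0.5                   # boolean tensor: [False, True, True]
print(mask.where(4, 2).numpy())  # [2. 4. 4.] -- 4 where mask is True, else 2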
@@ -116,8 +116,8 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
       idx = args.idx.render(render_llvm, bb[-1])
       element = lvars[vin[0]]
       if args.memory_dtype != vin[0].dtype:
-        if dtypes.is_int(args.memory_dtype):
-          element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
+        if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
+          element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
         elif args.memory_dtype == dtypes.bfloat16:
           element = bb[-1].bitcast(element, ir.IntType(32))
           element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
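This is the store-side cast: bool now takes the same path as unsigned integers, so a float element is lowered with fptoui rather than fptosi before being written back. A minimal llvmlite sketch of the instruction this branch emits, assuming bool lowers to i1 (a standalone illustration, not the renderer's actual plumbing):

import llvmlite.ir as ir

module = ir.Module(name="bool_cast_demo")
fn = ir.Function(module, ir.FunctionType(ir.IntType(1), [ir.FloatType()]), name="f32_to_bool")
bb = ir.IRBuilder(fn.append_basic_block(name="entry"))

# bool is handled as an unsigned integer, so the float goes through fptoui
# (the patched branch above), never fptosi: 0.0 -> 0 (false), 1.0 -> 1 (true)
bb.ret(bb.fptoui(fn.args[0], ir.IntType(1)))
print(module)  # the function body contains: fptoui float ... to i1

The load direction (bool -> float) would use uitofp analogously, though that hunk is not shown here. The pre-existing bfloat16 branch is a bit-level trick rather than a numeric cast: bitcasting the float32 to i32 and shifting right by 16 keeps the high half of the word, which is exactly a truncating float32 -> bfloat16 conversion.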