[FRONTEND] fixed issue for fp64 literals and added tests (#1698)

fixes #1686
This commit is contained in:
Philippe Tillet
2023-05-20 18:36:28 -07:00
committed by GitHub
parent fb30d84069
commit b5ba639bae
2 changed files with 28 additions and 1 deletions

View File

@@ -1969,6 +1969,23 @@ def test_full(dtype_str):
assert torch.all(out_dynamic == 2)
@pytest.mark.parametrize("literal, dtype_str",
[(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"),
('float("inf")', "f32"), ('float("-inf")', "f32"),
('float("nan")', "f32"), ('float("-nan")', "f32"),
(0., "f32"),
(5, "i32"), (2**40, "i64"),])
def test_constexpr(literal, dtype_str):
@triton.jit
def kernel(out_ptr):
val = GENERATE_TEST_HERE
tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val)
kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"})
out = torch.zeros((1,), dtype=torch.float32, device="cuda")
h = kernel_patched[(1,)](out)
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):

View File

@@ -55,7 +55,17 @@ def _to_tensor(x, builder):
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
return tensor(builder.get_fp32(x), float32)
min_float32 = 2 ** -126
max_float32 = (2 - 2**-23) * 2**127
abs_x = __builtins__['abs'](x)
if abs_x == float("inf") or\
abs_x == 0.0 or \
x != x or \
min_float32 <= abs_x <= max_float32:
return tensor(builder.get_fp32(x), float32)
else:
return tensor(builder.get_fp64(x), float64)
elif isinstance(x, constexpr):
return _to_tensor(x.value, builder)
elif isinstance(x, tensor):