mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] fixed issue for fp64 literals and added tests (#1698)
fixes #1686
This commit is contained in:
@@ -1969,6 +1969,23 @@ def test_full(dtype_str):
|
||||
assert torch.all(out_dynamic == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("literal, dtype_str",
|
||||
[(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"),
|
||||
('float("inf")', "f32"), ('float("-inf")', "f32"),
|
||||
('float("nan")', "f32"), ('float("-nan")', "f32"),
|
||||
(0., "f32"),
|
||||
(5, "i32"), (2**40, "i64"),])
|
||||
def test_constexpr(literal, dtype_str):
|
||||
@triton.jit
|
||||
def kernel(out_ptr):
|
||||
val = GENERATE_TEST_HERE
|
||||
tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val)
|
||||
|
||||
kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"})
|
||||
out = torch.zeros((1,), dtype=torch.float32, device="cuda")
|
||||
h = kernel_patched[(1,)](out)
|
||||
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None
|
||||
|
||||
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
||||
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
# def test_dot_without_load(dtype_str):
|
||||
|
||||
@@ -55,7 +55,17 @@ def _to_tensor(x, builder):
|
||||
else:
|
||||
raise RuntimeError(f'Nonrepresentable integer {x}.')
|
||||
elif isinstance(x, float):
|
||||
return tensor(builder.get_fp32(x), float32)
|
||||
min_float32 = 2 ** -126
|
||||
max_float32 = (2 - 2**-23) * 2**127
|
||||
abs_x = __builtins__['abs'](x)
|
||||
if abs_x == float("inf") or\
|
||||
abs_x == 0.0 or \
|
||||
x != x or \
|
||||
min_float32 <= abs_x <= max_float32:
|
||||
return tensor(builder.get_fp32(x), float32)
|
||||
else:
|
||||
return tensor(builder.get_fp64(x), float64)
|
||||
|
||||
elif isinstance(x, constexpr):
|
||||
return _to_tensor(x.value, builder)
|
||||
elif isinstance(x, tensor):
|
||||
|
||||
Reference in New Issue
Block a user