mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Change libdevice to mathlib and fix abs (#1361)
Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -56,7 +56,7 @@ matmul_data = {
|
||||
'a100': {
|
||||
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
|
||||
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
|
||||
(2048, 2048, 2048): {'float16': 0.59, 'float32': 0.57, 'int8': 0.34},
|
||||
(2048, 2048, 2048): {'float16': 0.62, 'float32': 0.57, 'int8': 0.34},
|
||||
(4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46},
|
||||
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
|
||||
# tall-skinny
|
||||
|
||||
@@ -520,6 +520,17 @@ def test_unary_op(dtype_x, expr, device='cuda'):
|
||||
def test_math_op(expr, device='cuda'):
|
||||
_test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
|
||||
|
||||
# ----------------
|
||||
# test abs
|
||||
# ----------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", [
|
||||
(dtype_x)
|
||||
for dtype_x in dtypes_with_bfloat16
|
||||
])
|
||||
def test_abs(dtype_x, device='cuda'):
|
||||
_test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device)
|
||||
|
||||
# ----------------
|
||||
# test indexing
|
||||
@@ -1791,11 +1802,11 @@ def test_num_warps_pow2():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('int32', 'libdevice.ffs', ''),
|
||||
('float32', 'libdevice.log2', ''),
|
||||
('float32', 'libdevice.pow', tl.libdevice.LIBDEVICE_PATH),
|
||||
('float64', 'libdevice.norm4d', '')])
|
||||
def test_libdevice_tensor(dtype_str, expr, lib_path):
|
||||
[('int32', 'mathlib.ffs', ''),
|
||||
('float32', 'mathlib.log2', ''),
|
||||
('float32', 'mathlib.pow', tl.mathlib.MATHLIB_PATH),
|
||||
('float64', 'mathlib.norm4d', '')])
|
||||
def test_mathlib_tensor(dtype_str, expr, lib_path):
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
@@ -1808,37 +1819,37 @@ def test_libdevice_tensor(dtype_str, expr, lib_path):
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
|
||||
|
||||
if expr == 'libdevice.log2':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.libdevice.log2(5.0), x.shape)'})
|
||||
if expr == 'mathlib.log2':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.mathlib.log2(5.0), x.shape)'})
|
||||
y_ref = np.log2(5.0)
|
||||
elif expr == 'libdevice.ffs':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
|
||||
elif expr == 'mathlib.ffs':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.ffs(x)'})
|
||||
y_ref = np.zeros(shape, dtype=x.dtype)
|
||||
for i in range(shape[0]):
|
||||
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
|
||||
elif expr == 'libdevice.pow':
|
||||
elif expr == 'mathlib.pow':
|
||||
# numpy does not allow negative factors in power, so we use abs()
|
||||
x = np.abs(x)
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
|
||||
y_ref = np.power(x, x)
|
||||
elif expr == 'libdevice.norm4d':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
|
||||
elif expr == 'mathlib.norm4d':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.norm4d(x, x, x, x)'})
|
||||
y_ref = np.sqrt(4 * np.power(x, 2))
|
||||
|
||||
x_tri = to_triton(x)
|
||||
# triton result
|
||||
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
|
||||
# compare
|
||||
if expr == 'libdevice.ffs':
|
||||
if expr == 'mathlib.ffs':
|
||||
np.testing.assert_equal(y_ref, to_numpy(y_tri))
|
||||
else:
|
||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('float32', 'libdevice.pow', '')])
|
||||
def test_libdevice_scalar(dtype_str, expr, lib_path):
|
||||
[('float32', 'mathlib.pow', '')])
|
||||
def test_mathlib_scalar(dtype_str, expr, lib_path):
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
@@ -1854,13 +1865,13 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
|
||||
|
||||
# numpy does not allow negative factors in power, so we use abs()
|
||||
x = np.abs(x)
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
|
||||
y_ref[:] = np.power(x, x)
|
||||
|
||||
# triton result
|
||||
x_tri = to_triton(x)[0].item()
|
||||
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
|
||||
# compare
|
||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user