mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Change libdevice to mathlib and fix abs (#1361)
Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -136,16 +136,16 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
}
|
||||
|
||||
if (!funcs.empty()) {
|
||||
static const std::string libdevice = "libdevice";
|
||||
static const std::string mathlib = "mathlib";
|
||||
// first search for environmental path
|
||||
std::string env_path = ::triton::tools::getenv("TRITON_LIBDEVICE_PATH");
|
||||
std::string env_path = ::triton::tools::getenv("TRITON_MATHLIB_PATH");
|
||||
if (!env_path.empty()) {
|
||||
externLibs.try_emplace(libdevice, env_path);
|
||||
externLibs.try_emplace(mathlib, env_path);
|
||||
return externLibs;
|
||||
}
|
||||
namespace fs = std::filesystem;
|
||||
// Search for libdevice relative to its library path if used from Python
|
||||
// Then native code is in `triton/_C/libtriton.so` and libdevice in
|
||||
// Search for mathlib relative to its library path if used from Python
|
||||
// Then native code is in `triton/_C/libtriton.so` and mathlib in
|
||||
// `triton/third_party/cuda/lib/libdevice.10.bc`
|
||||
static const auto this_library_path = [] {
|
||||
Dl_info fileinfo;
|
||||
@@ -158,13 +158,13 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
this_library_path.parent_path().parent_path() / "third_party" / "cuda" /
|
||||
"lib" / "libdevice.10.bc";
|
||||
if (fs::exists(runtime_path)) {
|
||||
externLibs.try_emplace(libdevice, runtime_path.string());
|
||||
externLibs.try_emplace(mathlib, runtime_path.string());
|
||||
} else {
|
||||
// When using the Math Dialect, it is possible that some ops (e.g., log)
|
||||
// are lowered to a function call. In this case, we need to link libdevice
|
||||
// are lowered to a function call. In this case, we need to link mathlib
|
||||
// using its default path:
|
||||
// [triton root dir]/python/triton/language/libdevice.10.bc
|
||||
// TODO(Keren): handle external linkage other than libdevice?
|
||||
// TODO(Keren): handle external linkage other than mathlib?
|
||||
static const auto this_file_path = std::filesystem::path(__FILE__);
|
||||
static const auto compiletime_path = this_file_path.parent_path()
|
||||
.parent_path()
|
||||
@@ -178,16 +178,16 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
compiletime_path.string();
|
||||
llvm::report_fatal_error(error_msg.c_str());
|
||||
}
|
||||
externLibs.try_emplace(libdevice, compiletime_path.string());
|
||||
externLibs.try_emplace(mathlib, compiletime_path.string());
|
||||
}
|
||||
}
|
||||
|
||||
return externLibs;
|
||||
}
|
||||
|
||||
static void linkLibdevice(llvm::Module &module) {
|
||||
static void linkMathlib(llvm::Module &module) {
|
||||
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
||||
// this will enable fast math path in libdevice
|
||||
// this will enable fast math path in mathlib
|
||||
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
||||
// sqrt.approx.ftz.f32
|
||||
auto &ctx = module.getContext();
|
||||
@@ -221,8 +221,8 @@ static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
return true;
|
||||
}
|
||||
|
||||
if (name == "libdevice") {
|
||||
linkLibdevice(module);
|
||||
if (name == "mathlib") {
|
||||
linkMathlib(module);
|
||||
} else {
|
||||
assert(false && "unknown extern lib: ");
|
||||
}
|
||||
@@ -261,7 +261,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
||||
}
|
||||
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||
/*optLevel=*/0, /*sizeLevel=*/0,
|
||||
/*optLevel=*/3, /*sizeLevel=*/0,
|
||||
/*targetMachine=*/nullptr);
|
||||
|
||||
if (auto err = optPipeline(llvmModule.get())) {
|
||||
|
||||
@@ -56,7 +56,7 @@ matmul_data = {
|
||||
'a100': {
|
||||
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
|
||||
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
|
||||
(2048, 2048, 2048): {'float16': 0.59, 'float32': 0.57, 'int8': 0.34},
|
||||
(2048, 2048, 2048): {'float16': 0.62, 'float32': 0.57, 'int8': 0.34},
|
||||
(4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46},
|
||||
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
|
||||
# tall-skinny
|
||||
|
||||
@@ -520,6 +520,17 @@ def test_unary_op(dtype_x, expr, device='cuda'):
|
||||
def test_math_op(expr, device='cuda'):
|
||||
_test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
|
||||
|
||||
# ----------------
|
||||
# test abs
|
||||
# ----------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", [
|
||||
(dtype_x)
|
||||
for dtype_x in dtypes_with_bfloat16
|
||||
])
|
||||
def test_abs(dtype_x, device='cuda'):
|
||||
_test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device)
|
||||
|
||||
# ----------------
|
||||
# test indexing
|
||||
@@ -1791,11 +1802,11 @@ def test_num_warps_pow2():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('int32', 'libdevice.ffs', ''),
|
||||
('float32', 'libdevice.log2', ''),
|
||||
('float32', 'libdevice.pow', tl.libdevice.LIBDEVICE_PATH),
|
||||
('float64', 'libdevice.norm4d', '')])
|
||||
def test_libdevice_tensor(dtype_str, expr, lib_path):
|
||||
[('int32', 'mathlib.ffs', ''),
|
||||
('float32', 'mathlib.log2', ''),
|
||||
('float32', 'mathlib.pow', tl.mathlib.MATHLIB_PATH),
|
||||
('float64', 'mathlib.norm4d', '')])
|
||||
def test_mathlib_tensor(dtype_str, expr, lib_path):
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
@@ -1808,37 +1819,37 @@ def test_libdevice_tensor(dtype_str, expr, lib_path):
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
|
||||
|
||||
if expr == 'libdevice.log2':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.libdevice.log2(5.0), x.shape)'})
|
||||
if expr == 'mathlib.log2':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.mathlib.log2(5.0), x.shape)'})
|
||||
y_ref = np.log2(5.0)
|
||||
elif expr == 'libdevice.ffs':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
|
||||
elif expr == 'mathlib.ffs':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.ffs(x)'})
|
||||
y_ref = np.zeros(shape, dtype=x.dtype)
|
||||
for i in range(shape[0]):
|
||||
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
|
||||
elif expr == 'libdevice.pow':
|
||||
elif expr == 'mathlib.pow':
|
||||
# numpy does not allow negative factors in power, so we use abs()
|
||||
x = np.abs(x)
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
|
||||
y_ref = np.power(x, x)
|
||||
elif expr == 'libdevice.norm4d':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
|
||||
elif expr == 'mathlib.norm4d':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.norm4d(x, x, x, x)'})
|
||||
y_ref = np.sqrt(4 * np.power(x, 2))
|
||||
|
||||
x_tri = to_triton(x)
|
||||
# triton result
|
||||
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
|
||||
# compare
|
||||
if expr == 'libdevice.ffs':
|
||||
if expr == 'mathlib.ffs':
|
||||
np.testing.assert_equal(y_ref, to_numpy(y_tri))
|
||||
else:
|
||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('float32', 'libdevice.pow', '')])
|
||||
def test_libdevice_scalar(dtype_str, expr, lib_path):
|
||||
[('float32', 'mathlib.pow', '')])
|
||||
def test_mathlib_scalar(dtype_str, expr, lib_path):
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
@@ -1854,13 +1865,13 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
|
||||
|
||||
# numpy does not allow negative factors in power, so we use abs()
|
||||
x = np.abs(x)
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
|
||||
y_ref[:] = np.power(x, x)
|
||||
|
||||
# triton result
|
||||
x_tri = to_triton(x)[0].item()
|
||||
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
|
||||
# compare
|
||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||
|
||||
|
||||
@@ -341,11 +341,13 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
values = [values]
|
||||
native_nontensor_types = (triton.language.dtype, )
|
||||
for name, value in zip(names, values):
|
||||
# by default, constexpr are assigned into python variable
|
||||
if isinstance(value, triton.language.constexpr):
|
||||
value = value.value
|
||||
if not isinstance(value, triton.language.tensor):
|
||||
if not isinstance(value, triton.language.tensor) and \
|
||||
not isinstance(value, native_nontensor_types):
|
||||
value = triton.language.core._to_tensor(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from ..impl import (
|
||||
ir,
|
||||
builtin,
|
||||
)
|
||||
from . import libdevice
|
||||
from . import mathlib
|
||||
from .core import (
|
||||
abs,
|
||||
arange,
|
||||
@@ -141,7 +141,7 @@ __all__ = [
|
||||
"int64",
|
||||
"int8",
|
||||
"ir",
|
||||
"libdevice",
|
||||
"mathlib",
|
||||
"load",
|
||||
"log",
|
||||
"max",
|
||||
|
||||
@@ -1215,7 +1215,16 @@ def max_contiguous(input, values, _builder=None):
|
||||
|
||||
@triton.jit
|
||||
def abs(x):
|
||||
return where(x >= 0, x, -x)
|
||||
x_dtype = x.dtype
|
||||
if x_dtype.is_floating():
|
||||
num_bits: constexpr = x.dtype.primitive_bitwidth
|
||||
int_dtype = dtype(f'int{num_bits}')
|
||||
mask = 2 ** (num_bits - 1) - 1
|
||||
ret = x.to(int_dtype, bitcast=True) & mask.to(int_dtype)
|
||||
ret = ret.to(x_dtype, bitcast=True)
|
||||
else:
|
||||
ret = where(x >= 0, x, -x)
|
||||
return ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1188,14 +1188,14 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
x, y = binary_op_type_checking_impl(x, y, builder)
|
||||
# FIXME(Keren): not portable, should be fixed
|
||||
from . import libdevice
|
||||
return libdevice.mulhi(x, y, _builder=builder)
|
||||
from . import mathlib
|
||||
return mathlib.mulhi(x, y, _builder=builder)
|
||||
|
||||
|
||||
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
# FIXME(Keren): not portable, should be fixed
|
||||
from . import libdevice
|
||||
return libdevice.floor(x, _builder=builder)
|
||||
from . import mathlib
|
||||
return mathlib.floor(x, _builder=builder)
|
||||
|
||||
|
||||
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
|
||||
@@ -152,9 +152,9 @@ class Libdevice(ExternLibrary):
|
||||
def __init__(self, path) -> None:
|
||||
'''
|
||||
Constructor for Libdevice.
|
||||
:param path: path of the libdevice library
|
||||
:param path: path of the mathlib library
|
||||
'''
|
||||
super().__init__("libdevice", path)
|
||||
super().__init__("mathlib", path)
|
||||
self._symbol_groups = {}
|
||||
|
||||
@staticmethod
|
||||
@@ -286,11 +286,11 @@ class Libdevice(ExternLibrary):
|
||||
# @extern.extern
|
||||
# def <op_name>(<args>, _builder=None):
|
||||
# arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
|
||||
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
|
||||
# return extern.dispatch("mathlib", <path>, <args>, <arg_type_symbol_dict>, _builder)
|
||||
import_str = "from . import core, extern\n"
|
||||
import_str += "import os\n"
|
||||
header_str = "LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")\n"
|
||||
header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", LOCAL_PATH)\n"
|
||||
header_str += "MATHLIB_PATH = os.getenv(\"TRITON_MATHLIB_PATH\", LOCAL_PATH)\n"
|
||||
func_str = ""
|
||||
for symbols in self._symbol_groups.values():
|
||||
func_str += "@extern.extern\n"
|
||||
@@ -299,7 +299,7 @@ class Libdevice(ExternLibrary):
|
||||
func_name_str += f"{arg_name}, "
|
||||
func_name_str += "_builder=None):\n"
|
||||
|
||||
return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
|
||||
return_str = f"\treturn extern.elementwise(\"{self._name}\", MATHLIB_PATH, ["
|
||||
for arg_name in symbols[0].arg_names:
|
||||
return_str += f"{arg_name}, "
|
||||
return_str += "], \n"
|
||||
@@ -347,7 +347,7 @@ class LLVMDisassembler:
|
||||
return self._path
|
||||
|
||||
|
||||
extern_libs = ["libdevice"]
|
||||
extern_libs = ["mathlib"]
|
||||
|
||||
|
||||
def build(
|
||||
@@ -363,7 +363,7 @@ def build(
|
||||
:param lib_name: name of the library
|
||||
:param output_dir: path to the output directory
|
||||
'''
|
||||
if lib_name == "libdevice":
|
||||
if lib_name == "mathlib":
|
||||
extern_lib = Libdevice(lib_path)
|
||||
else:
|
||||
raise Exception(f"Unknown extern library: {lib_name}")
|
||||
|
||||
Reference in New Issue
Block a user