mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] tl.mathlib -> tl.math; internally reverted to mathlib -> libdevice (#1368)
This commit is contained in:
@@ -136,16 +136,16 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
}
|
||||
|
||||
if (!funcs.empty()) {
|
||||
static const std::string mathlib = "mathlib";
|
||||
static const std::string libdevice = "libdevice";
|
||||
// first search for environmental path
|
||||
std::string env_path = ::triton::tools::getenv("TRITON_MATHLIB_PATH");
|
||||
std::string env_path = ::triton::tools::getenv("TRITON_LIBDEVICE_PATH");
|
||||
if (!env_path.empty()) {
|
||||
externLibs.try_emplace(mathlib, env_path);
|
||||
externLibs.try_emplace(libdevice, env_path);
|
||||
return externLibs;
|
||||
}
|
||||
namespace fs = std::filesystem;
|
||||
// Search for mathlib relative to its library path if used from Python
|
||||
// Then native code is in `triton/_C/libtriton.so` and mathlib in
|
||||
// Search for libdevice relative to its library path if used from Python
|
||||
// Then native code is in `triton/_C/libtriton.so` and libdevice in
|
||||
// `triton/third_party/cuda/lib/libdevice.10.bc`
|
||||
static const auto this_library_path = [] {
|
||||
Dl_info fileinfo;
|
||||
@@ -158,13 +158,13 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
this_library_path.parent_path().parent_path() / "third_party" / "cuda" /
|
||||
"lib" / "libdevice.10.bc";
|
||||
if (fs::exists(runtime_path)) {
|
||||
externLibs.try_emplace(mathlib, runtime_path.string());
|
||||
externLibs.try_emplace(libdevice, runtime_path.string());
|
||||
} else {
|
||||
// When using the Math Dialect, it is possible that some ops (e.g., log)
|
||||
// are lowered to a function call. In this case, we need to link mathlib
|
||||
// are lowered to a function call. In this case, we need to link libdevice
|
||||
// using its default path:
|
||||
// [triton root dir]/python/triton/language/libdevice.10.bc
|
||||
// TODO(Keren): handle external linkage other than mathlib?
|
||||
// TODO(Keren): handle external linkage other than libdevice?
|
||||
static const auto this_file_path = std::filesystem::path(__FILE__);
|
||||
static const auto compiletime_path = this_file_path.parent_path()
|
||||
.parent_path()
|
||||
@@ -178,16 +178,16 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
compiletime_path.string();
|
||||
llvm::report_fatal_error(error_msg.c_str());
|
||||
}
|
||||
externLibs.try_emplace(mathlib, compiletime_path.string());
|
||||
externLibs.try_emplace(libdevice, compiletime_path.string());
|
||||
}
|
||||
}
|
||||
|
||||
return externLibs;
|
||||
}
|
||||
|
||||
static void linkMathlib(llvm::Module &module) {
|
||||
static void linkLibdevice(llvm::Module &module) {
|
||||
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
||||
// this will enable fast math path in mathlib
|
||||
// this will enable fast math path in libdevice
|
||||
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
||||
// sqrt.approx.ftz.f32
|
||||
auto &ctx = module.getContext();
|
||||
@@ -221,8 +221,8 @@ static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
return true;
|
||||
}
|
||||
|
||||
if (name == "mathlib") {
|
||||
linkMathlib(module);
|
||||
if (name == "libdevice") {
|
||||
linkLibdevice(module);
|
||||
} else {
|
||||
assert(false && "unknown extern lib: ");
|
||||
}
|
||||
|
||||
@@ -1802,11 +1802,11 @@ def test_num_warps_pow2():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('int32', 'mathlib.ffs', ''),
|
||||
('float32', 'mathlib.log2', ''),
|
||||
('float32', 'mathlib.pow', tl.mathlib.MATHLIB_PATH),
|
||||
('float64', 'mathlib.norm4d', '')])
|
||||
def test_mathlib_tensor(dtype_str, expr, lib_path):
|
||||
[('int32', 'math.ffs', ''),
|
||||
('float32', 'math.log2', ''),
|
||||
('float32', 'math.pow', tl.math.LIBDEVICE_PATH),
|
||||
('float64', 'math.norm4d', '')])
|
||||
def test_math_tensor(dtype_str, expr, lib_path):
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
@@ -1819,37 +1819,37 @@ def test_mathlib_tensor(dtype_str, expr, lib_path):
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
|
||||
|
||||
if expr == 'mathlib.log2':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.mathlib.log2(5.0), x.shape)'})
|
||||
if expr == 'math.log2':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.math.log2(5.0), x.shape)'})
|
||||
y_ref = np.log2(5.0)
|
||||
elif expr == 'mathlib.ffs':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.ffs(x)'})
|
||||
elif expr == 'math.ffs':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.ffs(x)'})
|
||||
y_ref = np.zeros(shape, dtype=x.dtype)
|
||||
for i in range(shape[0]):
|
||||
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
|
||||
elif expr == 'mathlib.pow':
|
||||
elif expr == 'math.pow':
|
||||
# numpy does not allow negative factors in power, so we use abs()
|
||||
x = np.abs(x)
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.pow(x, x)'})
|
||||
y_ref = np.power(x, x)
|
||||
elif expr == 'mathlib.norm4d':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.norm4d(x, x, x, x)'})
|
||||
elif expr == 'math.norm4d':
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.norm4d(x, x, x, x)'})
|
||||
y_ref = np.sqrt(4 * np.power(x, 2))
|
||||
|
||||
x_tri = to_triton(x)
|
||||
# triton result
|
||||
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
||||
# compare
|
||||
if expr == 'mathlib.ffs':
|
||||
if expr == 'math.ffs':
|
||||
np.testing.assert_equal(y_ref, to_numpy(y_tri))
|
||||
else:
|
||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('float32', 'mathlib.pow', '')])
|
||||
def test_mathlib_scalar(dtype_str, expr, lib_path):
|
||||
[('float32', 'math.pow', '')])
|
||||
def test_math_scalar(dtype_str, expr, lib_path):
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
@@ -1865,13 +1865,13 @@ def test_mathlib_scalar(dtype_str, expr, lib_path):
|
||||
|
||||
# numpy does not allow negative factors in power, so we use abs()
|
||||
x = np.abs(x)
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.pow(x, x)'})
|
||||
y_ref[:] = np.power(x, x)
|
||||
|
||||
# triton result
|
||||
x_tri = to_triton(x)[0].item()
|
||||
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'math': lib_path})
|
||||
# compare
|
||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from ..impl import (
|
||||
ir,
|
||||
builtin,
|
||||
)
|
||||
from . import mathlib
|
||||
from . import math
|
||||
from .core import (
|
||||
abs,
|
||||
arange,
|
||||
@@ -141,7 +141,7 @@ __all__ = [
|
||||
"int64",
|
||||
"int8",
|
||||
"ir",
|
||||
"mathlib",
|
||||
"math",
|
||||
"load",
|
||||
"log",
|
||||
"max",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1188,14 +1188,14 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
x, y = binary_op_type_checking_impl(x, y, builder)
|
||||
# FIXME(Keren): not portable, should be fixed
|
||||
from . import mathlib
|
||||
return mathlib.mulhi(x, y, _builder=builder)
|
||||
from . import math
|
||||
return math.mulhi(x, y, _builder=builder)
|
||||
|
||||
|
||||
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
# FIXME(Keren): not portable, should be fixed
|
||||
from . import mathlib
|
||||
return mathlib.floor(x, _builder=builder)
|
||||
from . import math
|
||||
return math.floor(x, _builder=builder)
|
||||
|
||||
|
||||
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
|
||||
@@ -152,9 +152,9 @@ class Libdevice(ExternLibrary):
|
||||
def __init__(self, path) -> None:
|
||||
'''
|
||||
Constructor for Libdevice.
|
||||
:param path: path of the mathlib library
|
||||
:param path: path of the libdevice library
|
||||
'''
|
||||
super().__init__("mathlib", path)
|
||||
super().__init__("libdevice", path)
|
||||
self._symbol_groups = {}
|
||||
|
||||
@staticmethod
|
||||
@@ -286,11 +286,11 @@ class Libdevice(ExternLibrary):
|
||||
# @extern.extern
|
||||
# def <op_name>(<args>, _builder=None):
|
||||
# arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
|
||||
# return extern.dispatch("mathlib", <path>, <args>, <arg_type_symbol_dict>, _builder)
|
||||
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
|
||||
import_str = "from . import core, extern\n"
|
||||
import_str += "import os\n"
|
||||
header_str = "LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")\n"
|
||||
header_str += "MATHLIB_PATH = os.getenv(\"TRITON_MATHLIB_PATH\", LOCAL_PATH)\n"
|
||||
header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", LOCAL_PATH)\n"
|
||||
func_str = ""
|
||||
for symbols in self._symbol_groups.values():
|
||||
func_str += "@extern.extern\n"
|
||||
@@ -299,7 +299,7 @@ class Libdevice(ExternLibrary):
|
||||
func_name_str += f"{arg_name}, "
|
||||
func_name_str += "_builder=None):\n"
|
||||
|
||||
return_str = f"\treturn extern.elementwise(\"{self._name}\", MATHLIB_PATH, ["
|
||||
return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
|
||||
for arg_name in symbols[0].arg_names:
|
||||
return_str += f"{arg_name}, "
|
||||
return_str += "], \n"
|
||||
@@ -347,7 +347,7 @@ class LLVMDisassembler:
|
||||
return self._path
|
||||
|
||||
|
||||
extern_libs = ["mathlib"]
|
||||
extern_libs = ["libdevice"]
|
||||
|
||||
|
||||
def build(
|
||||
@@ -363,7 +363,7 @@ def build(
|
||||
:param lib_name: name of the library
|
||||
:param output_dir: path to the output directory
|
||||
'''
|
||||
if lib_name == "mathlib":
|
||||
if lib_name == "libdevice":
|
||||
extern_lib = Libdevice(lib_path)
|
||||
else:
|
||||
raise Exception(f"Unknown extern library: {lib_name}")
|
||||
|
||||
Reference in New Issue
Block a user