[FRONTEND][BACKEND] tl.mathlib -> tl.math; internally reverted to mathlib -> libdevice (#1368)

This commit is contained in:
Philippe Tillet
2023-03-19 02:14:57 -07:00
committed by GitHub
parent c575911a01
commit 39139258c8
6 changed files with 245 additions and 245 deletions

View File

@@ -136,16 +136,16 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
}
if (!funcs.empty()) {
static const std::string mathlib = "mathlib";
static const std::string libdevice = "libdevice";
// first search for environmental path
std::string env_path = ::triton::tools::getenv("TRITON_MATHLIB_PATH");
std::string env_path = ::triton::tools::getenv("TRITON_LIBDEVICE_PATH");
if (!env_path.empty()) {
externLibs.try_emplace(mathlib, env_path);
externLibs.try_emplace(libdevice, env_path);
return externLibs;
}
namespace fs = std::filesystem;
// Search for mathlib relative to its library path if used from Python
// Then native code is in `triton/_C/libtriton.so` and mathlib in
// Search for libdevice relative to its library path if used from Python
// Then native code is in `triton/_C/libtriton.so` and libdevice in
// `triton/third_party/cuda/lib/libdevice.10.bc`
static const auto this_library_path = [] {
Dl_info fileinfo;
@@ -158,13 +158,13 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
this_library_path.parent_path().parent_path() / "third_party" / "cuda" /
"lib" / "libdevice.10.bc";
if (fs::exists(runtime_path)) {
externLibs.try_emplace(mathlib, runtime_path.string());
externLibs.try_emplace(libdevice, runtime_path.string());
} else {
// When using the Math Dialect, it is possible that some ops (e.g., log)
// are lowered to a function call. In this case, we need to link mathlib
// are lowered to a function call. In this case, we need to link libdevice
// using its default path:
// [triton root dir]/python/triton/language/libdevice.10.bc
// TODO(Keren): handle external linkage other than mathlib?
// TODO(Keren): handle external linkage other than libdevice?
static const auto this_file_path = std::filesystem::path(__FILE__);
static const auto compiletime_path = this_file_path.parent_path()
.parent_path()
@@ -178,16 +178,16 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
compiletime_path.string();
llvm::report_fatal_error(error_msg.c_str());
}
externLibs.try_emplace(mathlib, compiletime_path.string());
externLibs.try_emplace(libdevice, compiletime_path.string());
}
}
return externLibs;
}
static void linkMathlib(llvm::Module &module) {
static void linkLibdevice(llvm::Module &module) {
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in mathlib
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
@@ -221,8 +221,8 @@ static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
return true;
}
if (name == "mathlib") {
linkMathlib(module);
if (name == "libdevice") {
linkLibdevice(module);
} else {
assert(false && "unknown extern lib: ");
}

View File

@@ -1802,11 +1802,11 @@ def test_num_warps_pow2():
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'mathlib.ffs', ''),
('float32', 'mathlib.log2', ''),
('float32', 'mathlib.pow', tl.mathlib.MATHLIB_PATH),
('float64', 'mathlib.norm4d', '')])
def test_mathlib_tensor(dtype_str, expr, lib_path):
[('int32', 'math.ffs', ''),
('float32', 'math.log2', ''),
('float32', 'math.pow', tl.math.LIBDEVICE_PATH),
('float64', 'math.norm4d', '')])
def test_math_tensor(dtype_str, expr, lib_path):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
@@ -1819,37 +1819,37 @@ def test_mathlib_tensor(dtype_str, expr, lib_path):
# limit the range of integers so that the sum does not overflow
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
if expr == 'mathlib.log2':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.mathlib.log2(5.0), x.shape)'})
if expr == 'math.log2':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.math.log2(5.0), x.shape)'})
y_ref = np.log2(5.0)
elif expr == 'mathlib.ffs':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.ffs(x)'})
elif expr == 'math.ffs':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.ffs(x)'})
y_ref = np.zeros(shape, dtype=x.dtype)
for i in range(shape[0]):
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
elif expr == 'mathlib.pow':
elif expr == 'math.pow':
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.pow(x, x)'})
y_ref = np.power(x, x)
elif expr == 'mathlib.norm4d':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.norm4d(x, x, x, x)'})
elif expr == 'math.norm4d':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.norm4d(x, x, x, x)'})
y_ref = np.sqrt(4 * np.power(x, 2))
x_tri = to_triton(x)
# triton result
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
if expr == 'mathlib.ffs':
if expr == 'math.ffs':
np.testing.assert_equal(y_ref, to_numpy(y_tri))
else:
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('float32', 'mathlib.pow', '')])
def test_mathlib_scalar(dtype_str, expr, lib_path):
[('float32', 'math.pow', '')])
def test_math_scalar(dtype_str, expr, lib_path):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
@@ -1865,13 +1865,13 @@ def test_mathlib_scalar(dtype_str, expr, lib_path):
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.pow(x, x)'})
y_ref[:] = np.power(x, x)
# triton result
x_tri = to_triton(x)[0].item()
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'math': lib_path})
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)

View File

@@ -5,7 +5,7 @@ from ..impl import (
ir,
builtin,
)
from . import mathlib
from . import math
from .core import (
abs,
arange,
@@ -141,7 +141,7 @@ __all__ = [
"int64",
"int8",
"ir",
"mathlib",
"math",
"load",
"log",
"max",

View File

@@ -1188,14 +1188,14 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
x, y = binary_op_type_checking_impl(x, y, builder)
# FIXME(Keren): not portable, should be fixed
from . import mathlib
return mathlib.mulhi(x, y, _builder=builder)
from . import math
return math.mulhi(x, y, _builder=builder)
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
# FIXME(Keren): not portable, should be fixed
from . import mathlib
return mathlib.floor(x, _builder=builder)
from . import math
return math.floor(x, _builder=builder)
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:

View File

@@ -152,9 +152,9 @@ class Libdevice(ExternLibrary):
def __init__(self, path) -> None:
'''
Constructor for Libdevice.
:param path: path of the mathlib library
:param path: path of the libdevice library
'''
super().__init__("mathlib", path)
super().__init__("libdevice", path)
self._symbol_groups = {}
@staticmethod
@@ -286,11 +286,11 @@ class Libdevice(ExternLibrary):
# @extern.extern
# def <op_name>(<args>, _builder=None):
# arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
# return extern.dispatch("mathlib", <path>, <args>, <arg_type_symbol_dict>, _builder)
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n"
import_str += "import os\n"
header_str = "LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")\n"
header_str += "MATHLIB_PATH = os.getenv(\"TRITON_MATHLIB_PATH\", LOCAL_PATH)\n"
header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", LOCAL_PATH)\n"
func_str = ""
for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n"
@@ -299,7 +299,7 @@ class Libdevice(ExternLibrary):
func_name_str += f"{arg_name}, "
func_name_str += "_builder=None):\n"
return_str = f"\treturn extern.elementwise(\"{self._name}\", MATHLIB_PATH, ["
return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
for arg_name in symbols[0].arg_names:
return_str += f"{arg_name}, "
return_str += "], \n"
@@ -347,7 +347,7 @@ class LLVMDisassembler:
return self._path
extern_libs = ["mathlib"]
extern_libs = ["libdevice"]
def build(
@@ -363,7 +363,7 @@ def build(
:param lib_name: name of the library
:param output_dir: path to the output directory
'''
if lib_name == "mathlib":
if lib_name == "libdevice":
extern_lib = Libdevice(lib_path)
else:
raise Exception(f"Unknown extern library: {lib_name}")