[FRONTEND] Change libdevice to mathlib and fix abs (#1361)

Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
rsanthanam-amd
2023-03-19 03:34:16 -05:00
committed by GitHub
parent 02caa8a652
commit c575911a01
9 changed files with 271 additions and 249 deletions

View File

@@ -136,16 +136,16 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
}
if (!funcs.empty()) {
static const std::string libdevice = "libdevice";
static const std::string mathlib = "mathlib";
// first search for environmental path
std::string env_path = ::triton::tools::getenv("TRITON_LIBDEVICE_PATH");
std::string env_path = ::triton::tools::getenv("TRITON_MATHLIB_PATH");
if (!env_path.empty()) {
externLibs.try_emplace(libdevice, env_path);
externLibs.try_emplace(mathlib, env_path);
return externLibs;
}
namespace fs = std::filesystem;
// Search for libdevice relative to its library path if used from Python
// Then native code is in `triton/_C/libtriton.so` and libdevice in
// Search for mathlib relative to its library path if used from Python
// Then native code is in `triton/_C/libtriton.so` and mathlib in
// `triton/third_party/cuda/lib/libdevice.10.bc`
static const auto this_library_path = [] {
Dl_info fileinfo;
@@ -158,13 +158,13 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
this_library_path.parent_path().parent_path() / "third_party" / "cuda" /
"lib" / "libdevice.10.bc";
if (fs::exists(runtime_path)) {
externLibs.try_emplace(libdevice, runtime_path.string());
externLibs.try_emplace(mathlib, runtime_path.string());
} else {
// When using the Math Dialect, it is possible that some ops (e.g., log)
// are lowered to a function call. In this case, we need to link libdevice
// are lowered to a function call. In this case, we need to link mathlib
// using its default path:
// [triton root dir]/python/triton/language/libdevice.10.bc
// TODO(Keren): handle external linkage other than libdevice?
// TODO(Keren): handle external linkage other than mathlib?
static const auto this_file_path = std::filesystem::path(__FILE__);
static const auto compiletime_path = this_file_path.parent_path()
.parent_path()
@@ -178,16 +178,16 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
compiletime_path.string();
llvm::report_fatal_error(error_msg.c_str());
}
externLibs.try_emplace(libdevice, compiletime_path.string());
externLibs.try_emplace(mathlib, compiletime_path.string());
}
}
return externLibs;
}
static void linkLibdevice(llvm::Module &module) {
static void linkMathlib(llvm::Module &module) {
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// this will enable fast math path in mathlib
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
@@ -221,8 +221,8 @@ static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
return true;
}
if (name == "libdevice") {
linkLibdevice(module);
if (name == "mathlib") {
linkMathlib(module);
} else {
assert(false && "unknown extern lib: ");
}
@@ -261,7 +261,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
}
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/0, /*sizeLevel=*/0,
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
if (auto err = optPipeline(llvmModule.get())) {

View File

@@ -56,7 +56,7 @@ matmul_data = {
'a100': {
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.59, 'float32': 0.57, 'int8': 0.34},
(2048, 2048, 2048): {'float16': 0.62, 'float32': 0.57, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
# tall-skinny

View File

@@ -520,6 +520,17 @@ def test_unary_op(dtype_x, expr, device='cuda'):
def test_math_op(expr, device='cuda'):
_test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
# ----------------
# test abs
# ----------------
@pytest.mark.parametrize("dtype_x", [
(dtype_x)
for dtype_x in dtypes_with_bfloat16
])
def test_abs(dtype_x, device='cuda'):
_test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device)
# ----------------
# test indexing
@@ -1791,11 +1802,11 @@ def test_num_warps_pow2():
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.log2', ''),
('float32', 'libdevice.pow', tl.libdevice.LIBDEVICE_PATH),
('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):
[('int32', 'mathlib.ffs', ''),
('float32', 'mathlib.log2', ''),
('float32', 'mathlib.pow', tl.mathlib.MATHLIB_PATH),
('float64', 'mathlib.norm4d', '')])
def test_mathlib_tensor(dtype_str, expr, lib_path):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
@@ -1808,37 +1819,37 @@ def test_libdevice_tensor(dtype_str, expr, lib_path):
# limit the range of integers so that the sum does not overflow
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
if expr == 'libdevice.log2':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.libdevice.log2(5.0), x.shape)'})
if expr == 'mathlib.log2':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.broadcast_to(tl.mathlib.log2(5.0), x.shape)'})
y_ref = np.log2(5.0)
elif expr == 'libdevice.ffs':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
elif expr == 'mathlib.ffs':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.ffs(x)'})
y_ref = np.zeros(shape, dtype=x.dtype)
for i in range(shape[0]):
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
elif expr == 'libdevice.pow':
elif expr == 'mathlib.pow':
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
y_ref = np.power(x, x)
elif expr == 'libdevice.norm4d':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
elif expr == 'mathlib.norm4d':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.norm4d(x, x, x, x)'})
y_ref = np.sqrt(4 * np.power(x, 2))
x_tri = to_triton(x)
# triton result
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
# compare
if expr == 'libdevice.ffs':
if expr == 'mathlib.ffs':
np.testing.assert_equal(y_ref, to_numpy(y_tri))
else:
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('float32', 'libdevice.pow', '')])
def test_libdevice_scalar(dtype_str, expr, lib_path):
[('float32', 'mathlib.pow', '')])
def test_mathlib_scalar(dtype_str, expr, lib_path):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
@@ -1854,13 +1865,13 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.mathlib.pow(x, x)'})
y_ref[:] = np.power(x, x)
# triton result
x_tri = to_triton(x)[0].item()
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'mathlib': lib_path})
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)

View File

@@ -341,11 +341,13 @@ class CodeGenerator(ast.NodeVisitor):
names = [names]
if not isinstance(values, tuple):
values = [values]
native_nontensor_types = (triton.language.dtype, )
for name, value in zip(names, values):
# by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr):
value = value.value
if not isinstance(value, triton.language.tensor):
if not isinstance(value, triton.language.tensor) and \
not isinstance(value, native_nontensor_types):
value = triton.language.core._to_tensor(value, self.builder)
self.set_value(name, value)

View File

@@ -5,7 +5,7 @@ from ..impl import (
ir,
builtin,
)
from . import libdevice
from . import mathlib
from .core import (
abs,
arange,
@@ -141,7 +141,7 @@ __all__ = [
"int64",
"int8",
"ir",
"libdevice",
"mathlib",
"load",
"log",
"max",

View File

@@ -1215,7 +1215,16 @@ def max_contiguous(input, values, _builder=None):
@triton.jit
def abs(x):
return where(x >= 0, x, -x)
x_dtype = x.dtype
if x_dtype.is_floating():
num_bits: constexpr = x.dtype.primitive_bitwidth
int_dtype = dtype(f'int{num_bits}')
mask = 2 ** (num_bits - 1) - 1
ret = x.to(int_dtype, bitcast=True) & mask.to(int_dtype)
ret = ret.to(x_dtype, bitcast=True)
else:
ret = where(x >= 0, x, -x)
return ret
@triton.jit

View File

@@ -1188,14 +1188,14 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
x, y = binary_op_type_checking_impl(x, y, builder)
# FIXME(Keren): not portable, should be fixed
from . import libdevice
return libdevice.mulhi(x, y, _builder=builder)
from . import mathlib
return mathlib.mulhi(x, y, _builder=builder)
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
# FIXME(Keren): not portable, should be fixed
from . import libdevice
return libdevice.floor(x, _builder=builder)
from . import mathlib
return mathlib.floor(x, _builder=builder)
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:

View File

@@ -152,9 +152,9 @@ class Libdevice(ExternLibrary):
def __init__(self, path) -> None:
'''
Constructor for Libdevice.
:param path: path of the libdevice library
:param path: path of the mathlib library
'''
super().__init__("libdevice", path)
super().__init__("mathlib", path)
self._symbol_groups = {}
@staticmethod
@@ -286,11 +286,11 @@ class Libdevice(ExternLibrary):
# @extern.extern
# def <op_name>(<args>, _builder=None):
# arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
# return extern.dispatch("mathlib", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n"
import_str += "import os\n"
header_str = "LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")\n"
header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", LOCAL_PATH)\n"
header_str += "MATHLIB_PATH = os.getenv(\"TRITON_MATHLIB_PATH\", LOCAL_PATH)\n"
func_str = ""
for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n"
@@ -299,7 +299,7 @@ class Libdevice(ExternLibrary):
func_name_str += f"{arg_name}, "
func_name_str += "_builder=None):\n"
return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
return_str = f"\treturn extern.elementwise(\"{self._name}\", MATHLIB_PATH, ["
for arg_name in symbols[0].arg_names:
return_str += f"{arg_name}, "
return_str += "], \n"
@@ -347,7 +347,7 @@ class LLVMDisassembler:
return self._path
extern_libs = ["libdevice"]
extern_libs = ["mathlib"]
def build(
@@ -363,7 +363,7 @@ def build(
:param lib_name: name of the library
:param output_dir: path to the output directory
'''
if lib_name == "libdevice":
if lib_name == "mathlib":
extern_lib = Libdevice(lib_path)
else:
raise Exception(f"Unknown extern library: {lib_name}")