[FRONTEND] Refactor min/max to unify tl.maximum and tl.math.max (#2091)

`maximum` used to generate a cmp/sel sequence even for floating-point types. Always
using the max op allows better code quality and avoids behaving differently
from `tl.math.max`.
This commit is contained in:
Thomas
2023-08-11 17:46:20 -07:00
committed by GitHub
parent 5162871c6c
commit 421ce18988
5 changed files with 80 additions and 41 deletions

View File

@@ -1066,6 +1066,36 @@ void init_triton_ir(py::module &&m) {
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::ShRSIOp>(lhs, rhs));
})
.def("create_minsi",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MinSIOp>(lhs, rhs));
})
.def("create_minui",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MinUIOp>(lhs, rhs));
})
.def("create_minf",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MinFOp>(lhs, rhs));
})
.def("create_maxsi",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MaxSIOp>(lhs, rhs));
})
.def("create_maxui",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MaxUIOp>(lhs, rhs));
})
.def("create_maxf",
[](TritonOpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
return mlir::Value(self.create<mlir::arith::MaxFOp>(lhs, rhs));
})
// AddPtr (similar to GEP)
.def("create_addptr",
[](TritonOpBuilder &self, mlir::Value &ptr,

View File

@@ -1382,7 +1382,7 @@ def minimum(x, y):
:param other: the second input tensor
:type other: Block
"""
return where(x < y, x, y)
return math.min(x, y)
@jit
@@ -1395,7 +1395,7 @@ def maximum(x, y):
:param other: the second input tensor
:type other: Block
"""
return where(x > y, x, y)
return math.max(x, y)
# max and argmax
@@ -1422,11 +1422,6 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
return _argmax_combine(value1, index1, value2, index2, False)
@jit
def _fast_max(x, y):
    # Combine function for the `max` reduction: defers directly to math.max,
    # which lowers to a single max op instead of a compare+select sequence.
    return math.max(x, y)
@jit
@_add_reduction_docstr("maximum",
return_indices_arg="return_indices",
@@ -1445,7 +1440,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
else:
assert input.dtype.is_integer_type()
input = input.to(int32)
return reduce(input, axis, _fast_max)
return reduce(input, axis, maximum)
@jit
@@ -1479,11 +1474,6 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
return _argmin_combine(value1, index1, value2, index2, False)
@jit
def _fast_min(x, y):
    # Combine function for the `min` reduction: defers directly to math.min,
    # which lowers to a single min op instead of a compare+select sequence.
    return math.min(x, y)
@jit
@_add_reduction_docstr("minimum",
return_indices_arg="return_indices",
@@ -1502,7 +1492,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
else:
assert input.dtype.is_integer_type()
input = input.to(int32)
return reduce(input, axis, _fast_min)
return reduce(input, axis, minimum)
@jit
@@ -1926,6 +1916,16 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder)
def binary_op_type_legalization(lhs, rhs, builder):
    '''
    Convert both operands of a binary op to a single common type.

    :param lhs: the left operand
    :param rhs: the right operand
    :param builder: the IR builder used to emit any casts required by promotion
    :return: the (lhs, rhs) pair promoted to their common type
    '''
    # Delegates to the semantic layer's standard binary-op type checking,
    # which implements the promotion rules shared by all binary operators.
    return semantic.binary_op_type_checking_impl(lhs, rhs, builder)
def extern(fn):
    """A decorator for external functions.

    Currently an alias for :func:`builtin`; the separate name marks extern
    functions explicitly at their definition site.
    """
    return builtin(fn)

View File

@@ -40,26 +40,34 @@ def byte_perm(arg0, arg1, arg2, _builder=None):
@core.extern
def min(arg0, arg1, _builder=None):
    """Elementwise minimum of two values.

    Both operands are converted to tensors and promoted to a common type,
    then the dtype-appropriate arith min op is emitted: ``minf`` for floats,
    ``minsi`` for signed integers, ``minui`` for unsigned integers.

    :param arg0: the first operand
    :param arg1: the second operand
    :param _builder: the IR builder (injected by the frontend)
    :raises AssertionError: if the promoted dtype is neither floating nor integer
    """
    # NOTE(review): removed the stale extern_elementwise(libdevice) body that the
    # diff left in front of this implementation — it made everything below
    # unreachable dead code.
    arg0 = core._to_tensor(arg0, _builder)
    arg1 = core._to_tensor(arg1, _builder)
    arg0, arg1 = core.binary_op_type_legalization(arg0, arg1, _builder)
    dtype = arg0.dtype
    if dtype.is_floating():
        return core.tensor(_builder.create_minf(arg0.handle, arg1.handle), arg0.type)
    elif dtype.is_int_signed():
        return core.tensor(_builder.create_minsi(arg0.handle, arg1.handle), arg0.type)
    elif dtype.is_int_unsigned():
        # Fixed: pass arg0.type (block type) like the other branches,
        # not arg0.dtype — core.tensor expects the full type here.
        return core.tensor(_builder.create_minui(arg0.handle, arg1.handle), arg0.type)
    else:
        assert False, f"Unexpected dtype {dtype}"
@core.extern
def max(arg0, arg1, _builder=None):
    """Elementwise maximum of two values.

    Both operands are converted to tensors and promoted to a common type,
    then the dtype-appropriate arith max op is emitted: ``maxf`` for floats,
    ``maxsi`` for signed integers, ``maxui`` for unsigned integers.

    :param arg0: the first operand
    :param arg1: the second operand
    :param _builder: the IR builder (injected by the frontend)
    :raises AssertionError: if the promoted dtype is neither floating nor integer
    """
    # NOTE(review): removed the stale extern_elementwise(libdevice) body that the
    # diff left in front of this implementation — it made everything below
    # unreachable dead code.
    arg0 = core._to_tensor(arg0, _builder)
    arg1 = core._to_tensor(arg1, _builder)
    arg0, arg1 = core.binary_op_type_legalization(arg0, arg1, _builder)
    dtype = arg0.dtype
    if dtype.is_floating():
        return core.tensor(_builder.create_maxf(arg0.handle, arg1.handle), arg0.type)
    elif dtype.is_int_signed():
        return core.tensor(_builder.create_maxsi(arg0.handle, arg1.handle), arg0.type)
    elif dtype.is_int_unsigned():
        # Fixed: pass arg0.type (block type) like the other branches,
        # not arg0.dtype — core.tensor expects the full type here.
        return core.tensor(_builder.create_maxui(arg0.handle, arg1.handle), arg0.type)
    else:
        assert False, f"Unexpected dtype {dtype}"
@core.extern