mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Refactor min/max to unify tl.maximum and tl.math.max (#2091)
`tl.maximum` used to generate a compare/select pair even for floating-point types. Always emitting the max op gives better code quality and avoids behaving differently from `tl.math.max`.
This commit is contained in:
@@ -1066,6 +1066,36 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::ShRSIOp>(lhs, rhs));
|
||||
})
|
||||
.def("create_minsi",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::MinSIOp>(lhs, rhs));
|
||||
})
|
||||
.def("create_minui",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::MinUIOp>(lhs, rhs));
|
||||
})
|
||||
.def("create_minf",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::MinFOp>(lhs, rhs));
|
||||
})
|
||||
.def("create_maxsi",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::MaxSIOp>(lhs, rhs));
|
||||
})
|
||||
.def("create_maxui",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::MaxUIOp>(lhs, rhs));
|
||||
})
|
||||
.def("create_maxf",
|
||||
[](TritonOpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::MaxFOp>(lhs, rhs));
|
||||
})
|
||||
// AddPtr (similar to GEP)
|
||||
.def("create_addptr",
|
||||
[](TritonOpBuilder &self, mlir::Value &ptr,
|
||||
|
||||
@@ -1382,7 +1382,7 @@ def minimum(x, y):
|
||||
:param other: the second input tensor
|
||||
:type other: Block
|
||||
"""
|
||||
return where(x < y, x, y)
|
||||
return math.min(x, y)
|
||||
|
||||
|
||||
@jit
|
||||
@@ -1395,7 +1395,7 @@ def maximum(x, y):
|
||||
:param other: the second input tensor
|
||||
:type other: Block
|
||||
"""
|
||||
return where(x > y, x, y)
|
||||
return math.max(x, y)
|
||||
|
||||
# max and argmax
|
||||
|
||||
@@ -1422,11 +1422,6 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
return _argmax_combine(value1, index1, value2, index2, False)
|
||||
|
||||
|
||||
@jit
def _fast_max(x, y):
    """Combine two values by forwarding them to ``math.max``."""
    larger = math.max(x, y)
    return larger
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("maximum",
|
||||
return_indices_arg="return_indices",
|
||||
@@ -1445,7 +1440,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
else:
|
||||
assert input.dtype.is_integer_type()
|
||||
input = input.to(int32)
|
||||
return reduce(input, axis, _fast_max)
|
||||
return reduce(input, axis, maximum)
|
||||
|
||||
|
||||
@jit
|
||||
@@ -1479,11 +1474,6 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
return _argmin_combine(value1, index1, value2, index2, False)
|
||||
|
||||
|
||||
@jit
def _fast_min(x, y):
    """Combine two values by forwarding them to ``math.min``."""
    smaller = math.min(x, y)
    return smaller
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("minimum",
|
||||
return_indices_arg="return_indices",
|
||||
@@ -1502,7 +1492,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
else:
|
||||
assert input.dtype.is_integer_type()
|
||||
input = input.to(int32)
|
||||
return reduce(input, axis, _fast_min)
|
||||
return reduce(input, axis, minimum)
|
||||
|
||||
|
||||
@jit
|
||||
@@ -1926,6 +1916,16 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
|
||||
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder)
|
||||
|
||||
|
||||
def binary_op_type_legalization(lhs, rhs, builder):
    """
    Convert both operands to a single common type.

    This is a thin public wrapper: the work is delegated unchanged to
    ``semantic.binary_op_type_checking_impl``.

    :param lhs: the left operand
    :param rhs: the right operand
    :param builder: the builder
    :return: the result of ``semantic.binary_op_type_checking_impl``
        (presumably the converted ``(lhs, rhs)`` pair -- confirm against
        the ``semantic`` module)
    """
    legalized = semantic.binary_op_type_checking_impl(lhs, rhs, builder)
    return legalized
|
||||
|
||||
|
||||
def extern(fn):
    """Decorator for external functions.

    Passes ``fn`` to ``builtin`` and returns whatever ``builtin`` produces.
    """
    decorated = builtin(fn)
    return decorated
|
||||
|
||||
@@ -40,26 +40,34 @@ def byte_perm(arg0, arg1, arg2, _builder=None):
|
||||
|
||||
@core.extern
def min(arg0, arg1, _builder=None):
    """Return the element-wise minimum of ``arg0`` and ``arg1``.

    Both operands are converted to tensors and legalized to a common type.
    The builder op matching the resulting dtype is then emitted:
    ``create_minf`` for floating point, ``create_minsi`` for signed integers
    and ``create_minui`` for unsigned integers.

    :param arg0: the first operand
    :param arg1: the second operand
    :param _builder: the IR builder used to emit the operation
    :return: a ``core.tensor`` wrapping the emitted op, typed with the
        legalized ``arg0.type``
    :raises AssertionError: if the legalized dtype is neither floating point
        nor a signed/unsigned integer
    """
    arg0 = core._to_tensor(arg0, _builder)
    arg1 = core._to_tensor(arg1, _builder)
    arg0, arg1 = core.binary_op_type_legalization(arg0, arg1, _builder)
    dtype = arg0.dtype
    if dtype.is_floating():
        handle = _builder.create_minf(arg0.handle, arg1.handle)
    elif dtype.is_int_signed():
        handle = _builder.create_minsi(arg0.handle, arg1.handle)
    elif dtype.is_int_unsigned():
        handle = _builder.create_minui(arg0.handle, arg1.handle)
    else:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(f"Unexpected dtype {dtype}")
    # Every branch wraps the result with the full operand type. The unsigned
    # branch used to pass `arg0.dtype`, unlike the float/signed branches.
    # NOTE(review): presumably `type` carries the block shape while `dtype`
    # is only the element type -- confirm against core.tensor.
    return core.tensor(handle, arg0.type)
|
||||
|
||||
|
||||
@core.extern
def max(arg0, arg1, _builder=None):
    """Return the element-wise maximum of ``arg0`` and ``arg1``.

    Both operands are converted to tensors and legalized to a common type.
    The builder op matching the resulting dtype is then emitted:
    ``create_maxf`` for floating point, ``create_maxsi`` for signed integers
    and ``create_maxui`` for unsigned integers.

    :param arg0: the first operand
    :param arg1: the second operand
    :param _builder: the IR builder used to emit the operation
    :return: a ``core.tensor`` wrapping the emitted op, typed with the
        legalized ``arg0.type``
    :raises AssertionError: if the legalized dtype is neither floating point
        nor a signed/unsigned integer
    """
    arg0 = core._to_tensor(arg0, _builder)
    arg1 = core._to_tensor(arg1, _builder)
    arg0, arg1 = core.binary_op_type_legalization(arg0, arg1, _builder)
    dtype = arg0.dtype
    if dtype.is_floating():
        handle = _builder.create_maxf(arg0.handle, arg1.handle)
    elif dtype.is_int_signed():
        handle = _builder.create_maxsi(arg0.handle, arg1.handle)
    elif dtype.is_int_unsigned():
        handle = _builder.create_maxui(arg0.handle, arg1.handle)
    else:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(f"Unexpected dtype {dtype}")
    # Every branch wraps the result with the full operand type. The unsigned
    # branch used to pass `arg0.dtype`, unlike the float/signed branches.
    # NOTE(review): presumably `type` carries the block shape while `dtype`
    # is only the element type -- confirm against core.tensor.
    return core.tensor(handle, arg0.type)
|
||||
|
||||
|
||||
@core.extern
|
||||
|
||||
Reference in New Issue
Block a user