mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Lower tl.abs to math::Abs{I,F}Op (#1401)
This generates identical PTX for floating point, but for integer types the resulting PTX is much better. For example, `tl.abs` for int16 currently generates ```ptx cvt.s32.s16 %r1, %rs2; neg.s16 %rs4, %rs2; setp.lt.s32 %p4, %r1, 0; selp.b16 %rs3, %rs4, %rs2, %p4; ``` After this change, it becomes a single `abs.s16` instruction. It also improves LLVM's ability to optimize floating-point code: e.g. `abs(t) * abs(t)` is now optimized to `t * t`, which did not happen before. --------- Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -79,6 +79,7 @@ Math Ops
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
abs
|
||||
exp
|
||||
log
|
||||
cos
|
||||
|
||||
@@ -1011,6 +1011,48 @@ struct ExpOpConversionApprox
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsIOpConversion
|
||||
: ElementwiseOpConversionBase<mlir::math::AbsIOp, AbsIOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<mlir::math::AbsIOp, AbsIOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
Value createDestOp(mlir::math::AbsIOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
auto boolFalse = rewriter.getBoolAttr(false);
|
||||
auto constFalse = rewriter.create<LLVM::ConstantOp>(loc, boolFalse);
|
||||
return rewriter.create<LLVM::AbsOp>(loc, elemTy, operands[0],
|
||||
/*is_int_min_poison=*/constFalse);
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsFOpConversion
|
||||
: ElementwiseOpConversionBase<mlir::math::AbsFOp, AbsFOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<mlir::math::AbsFOp, AbsFOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
Value createDestOp(mlir::math::AbsFOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
if (llvm::isa<IntegerType>(elemTy)) {
|
||||
// Mask out the sign bit
|
||||
auto num_bits =
|
||||
getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
|
||||
assert(num_bits <= 16);
|
||||
auto mask = (1u << (num_bits - 1u)) - 1u;
|
||||
auto maskAttr = rewriter.getIntegerAttr(elemTy, mask);
|
||||
auto maskConst = rewriter.create<LLVM::ConstantOp>(loc, maskAttr);
|
||||
return and_(operands[0], maskConst);
|
||||
}
|
||||
|
||||
return rewriter.create<LLVM::FAbsOp>(loc, elemTy, operands[0]);
|
||||
}
|
||||
};
|
||||
|
||||
void populateElementwiseOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
@@ -1056,6 +1098,8 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
|
||||
#undef POPULATE_UNARY_OP
|
||||
|
||||
patterns.add<AbsIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AbsFOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||
|
||||
|
||||
@@ -174,6 +174,7 @@ void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
// Rewrite rule
|
||||
patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::CosOp>,
|
||||
GenericOpPattern<math::SinOp>, GenericOpPattern<math::LogOp>,
|
||||
GenericOpPattern<math::AbsFOp>, GenericOpPattern<math::AbsIOp>,
|
||||
GenericOpPattern<math::SqrtOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
||||
@@ -1311,6 +1311,16 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::SqrtOp>(loc, val);
|
||||
})
|
||||
.def("create_fabs",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::AbsFOp>(loc, val);
|
||||
})
|
||||
.def("create_iabs",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::AbsIOp>(loc, val);
|
||||
})
|
||||
.def("create_reduce",
|
||||
[](mlir::OpBuilder &self, mlir::Value &operand,
|
||||
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
|
||||
|
||||
@@ -532,6 +532,33 @@ def test_math_op(expr, device='cuda'):
|
||||
# Elementwise |x|: compare tl.abs against np.abs via the shared unary-op harness.
def test_abs(dtype_x, device='cuda'):
|
||||
    # NOTE(review): the trailing space in 'np.abs(x) ' is preserved as-is —
    # presumably the harness evaluates the expression so it is harmless; confirm.
    _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
def test_abs_f8(in_dtype):
    """tl.abs on float8 inputs, checked against a float32 reference.

    Exercises all 256 int8 bit patterns reinterpreted as `in_dtype`,
    runs tl.abs on them, and compares the result (upcast to float32)
    with torch's abs on the float32-converted input.
    """

    @triton.jit
    def abs_kernel(Z, X, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        z = tl.abs(x)
        tl.store(Z + off, z)

    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
    # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
    all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
    f8_tensor[all_exp_ones] = 0
    f8 = triton.reinterpret(f8_tensor, in_dtype)
    n_elements = f8_tensor.numel()
    out_f8 = torch.empty_like(f8_tensor)
    # One program covers all 256 elements; SIZE is passed as a constexpr.
    # (Removed a dead `grid = lambda meta: ...` that referenced a nonexistent
    # 'BLOCK_SIZE' meta key and was never used — the launch grid is (1,).)
    abs_kernel[(1,)](f8, triton.reinterpret(out_f8, in_dtype), n_elements)

    f32_tensor = convert_float_to_float32(f8_tensor, in_dtype)
    expect = f32_tensor.abs()
    actual_f8 = convert_float_to_float32(out_f8, in_dtype)
    torch.testing.assert_allclose(expect, actual_f8)
|
||||
|
||||
|
||||
# ----------------
|
||||
# test indexing
|
||||
# ----------------
|
||||
|
||||
@@ -1107,6 +1107,12 @@ def sqrt(x, _builder=None):
|
||||
return semantic.sqrt(x, _builder)
|
||||
|
||||
|
||||
# Public tl.abs builtin; the docstring is supplied by _add_math_1arg_docstr.
@builtin
|
||||
@_add_math_1arg_docstr("absolute value")
|
||||
def abs(x, _builder=None):
|
||||
    # Dispatch to the semantic layer, which selects fabs/iabs/no-op by dtype.
    return semantic.abs(x, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Reductions
|
||||
# -----------------------
|
||||
@@ -1214,19 +1220,6 @@ def max_contiguous(input, values, _builder=None):
|
||||
# Standard library
|
||||
# -----------------------
|
||||
|
||||
# Legacy JIT-level absolute value (superseded by the tl.abs builtin):
# floats clear the sign bit via an integer bitcast; other dtypes use a select.
@triton.jit
|
||||
def abs(x):
|
||||
    x_dtype = x.dtype
|
||||
    if x_dtype.is_floating():
|
||||
        # |x| for IEEE-style floats is x with the sign bit masked off:
        # bitcast to the same-width integer type, AND with 0b0111...1,
        # then bitcast back to the original float type.
        num_bits: constexpr = x.dtype.primitive_bitwidth
|
||||
        int_dtype = dtype(f'int{num_bits}')
|
||||
        mask = 2 ** (num_bits - 1) - 1
|
||||
        ret = x.to(int_dtype, bitcast=True) & mask.to(int_dtype)
|
||||
        ret = ret.to(x_dtype, bitcast=True)
|
||||
    else:
|
||||
        # Integer path: select-based abs. NOTE(review): for the most negative
        # signed value, -x overflows — presumably wraps two's-complement; confirm.
        ret = where(x >= 0, x, -x)
|
||||
    return ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv(x, div):
|
||||
|
||||
@@ -1223,8 +1223,21 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_sqrt(x.handle), x.type)
|
||||
|
||||
|
||||
def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    # Elementwise absolute value, dispatched on the element dtype:
    #   floating     -> math.absf (via create_fabs)
    #   signed int   -> math.absi (via create_iabs)
    #   unsigned int -> identity (already non-negative)
    dtype = x.dtype
    if dtype.is_floating():
        return tl.tensor(builder.create_fabs(x.handle), x.type)
    if dtype.is_int_signed():
        return tl.tensor(builder.create_iabs(x.handle), x.type)
    if dtype.is_int_unsigned():
        return x  # no-op
    assert False, f"Unexpected dtype {dtype}"
|
||||
|
||||
|
||||
##
|
||||
|
||||
|
||||
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
if len(x.shape) != len(values):
|
||||
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
||||
|
||||
Reference in New Issue
Block a user