[FRONTEND][BACKEND] Lower tl.abs to math::Abs{I,F}Op (#1401)

This generates identical PTX for floating point, but for integer types
the resulting PTX is much better. For example `tl.abs` for int16
currently generates

```ptx
  cvt.s32.s16 %r1, %rs2;
  neg.s16     %rs4, %rs2;
  setp.lt.s32 %p4, %r1, 0;
  selp.b16    %rs3, %rs4, %rs2, %p4;
```

After, it becomes a single `abs.s16` instruction.

This also improves LLVM's ability to optimize floating-point code; e.g.,
`abs(t) * abs(t)` is now optimized to `t * t`, which didn't happen before.

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
peterbell10
2023-03-25 04:58:24 +00:00
committed by GitHub
parent a9c87245b4
commit 6063fccd0b
7 changed files with 102 additions and 13 deletions

View File

@@ -79,6 +79,7 @@ Math Ops
:toctree: generated
:nosignatures:
abs
exp
log
cos

View File

@@ -1011,6 +1011,48 @@ struct ExpOpConversionApprox
}
};
struct AbsIOpConversion
    : ElementwiseOpConversionBase<mlir::math::AbsIOp, AbsIOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::math::AbsIOp, AbsIOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Lower math.absi to llvm.intr.abs so the backend emits a single PTX
  // abs instruction rather than a compare/negate/select sequence.
  Value createDestOp(mlir::math::AbsIOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    // llvm.intr.abs takes a flag saying whether INT_MIN is poison; pass
    // `false` to keep the well-defined wrap-around result.
    auto noPoison =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(false));
    return rewriter.create<LLVM::AbsOp>(loc, elemTy, operands[0],
                                        /*is_int_min_poison=*/noPoison);
  }
};
struct AbsFOpConversion
    : ElementwiseOpConversionBase<mlir::math::AbsFOp, AbsFOpConversion> {
  using Base =
      ElementwiseOpConversionBase<mlir::math::AbsFOp, AbsFOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Lower math.absf: regular floats go to llvm.intr.fabs; integer element
  // types (presumably small floats such as fp8 carried around as integers —
  // see the fp8 abs test — TODO confirm) get the sign bit masked off.
  Value createDestOp(mlir::math::AbsFOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    if (!llvm::isa<IntegerType>(elemTy))
      return rewriter.create<LLVM::FAbsOp>(loc, elemTy, operands[0]);

    // |x| for a sign/magnitude encoding is x with the sign bit cleared.
    const auto bitWidth =
        getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
    assert(bitWidth <= 16);
    const auto signClearMask = (1u << (bitWidth - 1u)) - 1u;
    auto maskConst = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getIntegerAttr(elemTy, signClearMask));
    return and_(operands[0], maskConst);
  }
};
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
@@ -1056,6 +1098,8 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<AbsIOpConversion>(typeConverter, benefit);
patterns.add<AbsFOpConversion>(typeConverter, benefit);
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);

View File

@@ -174,6 +174,7 @@ void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
// Rewrite rule
patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::CosOp>,
GenericOpPattern<math::SinOp>, GenericOpPattern<math::LogOp>,
GenericOpPattern<math::AbsFOp>, GenericOpPattern<math::AbsIOp>,
GenericOpPattern<math::SqrtOp>>(typeConverter, context);
}

View File

@@ -1311,6 +1311,16 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::SqrtOp>(loc, val);
})
.def("create_fabs",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::AbsFOp>(loc, val);
})
.def("create_iabs",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::AbsIOp>(loc, val);
})
.def("create_reduce",
[](mlir::OpBuilder &self, mlir::Value &operand,
mlir::triton::RedOp redOp, int axis) -> mlir::Value {

View File

@@ -532,6 +532,33 @@ def test_math_op(expr, device='cuda'):
def test_abs(dtype_x, device='cuda'):
    # tl.abs must agree with numpy's elementwise abs for every parametrized
    # dtype; the shared _test_unary harness presumably compiles the triton
    # expression and compares it against the numpy one — see its definition.
    _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device)
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
def test_abs_f8(in_dtype):
    """tl.abs on fp8 inputs must clear exactly the sign bit.

    fp8 tensors are represented as int8 storage reinterpreted via
    triton.reinterpret, so this exercises the integer sign-mask path of
    the AbsF lowering.
    """

    @triton.jit
    def abs_kernel(X, Z, SIZE: tl.constexpr):
        # X is the input (load) pointer, Z the output (store) pointer.
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        z = tl.abs(x)
        tl.store(Z + off, z)

    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
    # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
    all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
    f8_tensor[all_exp_ones] = 0
    f8 = triton.reinterpret(f8_tensor, in_dtype)
    n_elements = f8_tensor.numel()
    out_f8 = torch.empty_like(f8_tensor)
    # Bug fix: the kernel previously took (Z, X, ...) while the launch passed
    # the input first, so it loaded from the uninitialized output buffer and
    # clobbered the input. The signature is now (X, Z, ...) to match the call.
    # (Also dropped an unused `grid` lambda that referenced a nonexistent
    # BLOCK_SIZE meta-parameter.)
    abs_kernel[(1,)](f8, triton.reinterpret(out_f8, in_dtype), n_elements)
    f32_tensor = convert_float_to_float32(f8_tensor, in_dtype)
    expect = f32_tensor.abs()
    actual_f8 = convert_float_to_float32(out_f8, in_dtype)
    torch.testing.assert_allclose(expect, actual_f8)
# ----------------
# test indexing
# ----------------

View File

@@ -1107,6 +1107,12 @@ def sqrt(x, _builder=None):
return semantic.sqrt(x, _builder)
@builtin
@_add_math_1arg_docstr("absolute value")
def abs(x, _builder=None):
    # The user-facing docstring is injected by _add_math_1arg_docstr above;
    # the dtype-dependent lowering (absf / absi / no-op) lives in semantic.abs.
    return semantic.abs(x, _builder)
# -----------------------
# Reductions
# -----------------------
@@ -1214,19 +1220,6 @@ def max_contiguous(input, values, _builder=None):
# Standard library
# -----------------------
@triton.jit
def abs(x):
    # Software implementation of |x| (the pre-math-dialect version shown as
    # removed in this diff): floats are handled by bit manipulation, integers
    # by a branchless select.
    x_dtype = x.dtype
    if x_dtype.is_floating():
        # |x| for an IEEE-style encoding is x with the sign bit cleared:
        # bitcast to a same-width integer, AND with 2**(n-1)-1, bitcast back.
        num_bits: constexpr = x.dtype.primitive_bitwidth
        int_dtype = dtype(f'int{num_bits}')
        mask = 2 ** (num_bits - 1) - 1
        # NOTE(review): `mask` is a plain Python int here; `.to(int_dtype)`
        # presumably relies on the jit promoting constants — verify.
        ret = x.to(int_dtype, bitcast=True) & mask.to(int_dtype)
        ret = ret.to(x_dtype, bitcast=True)
    else:
        # Integers: select x or -x depending on sign.
        ret = where(x >= 0, x, -x)
    return ret
@triton.jit
def cdiv(x, div):

View File

@@ -1223,8 +1223,21 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_sqrt(x.handle), x.type)
def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Elementwise absolute value.

    Floating-point tensors lower to ``math.absf`` (create_fabs), signed
    integers to ``math.absi`` (create_iabs), and unsigned integers are
    returned unchanged since ``|x| == x``.

    :param x: input tensor.
    :param builder: IR builder used to create the op.
    :raises ValueError: if ``x`` has an unexpected (non-numeric) dtype.
    """
    dtype = x.dtype
    if dtype.is_floating():
        return tl.tensor(builder.create_fabs(x.handle), x.type)
    if dtype.is_int_signed():
        return tl.tensor(builder.create_iabs(x.handle), x.type)
    if dtype.is_int_unsigned():
        return x  # no-op
    # Raise a real exception instead of `assert False`: asserts are stripped
    # under `python -O`, which would silently fall through and return None.
    raise ValueError(f"Unexpected dtype {dtype}")
##
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")