[FRONTEND][BACKEND] Add support for elementwise inline assembly (#2136)

Add a new operation to make it possible to implement packed inline assembly for
elementwise operations. This way inline assembly can be used to control
elementwise operations. It also allows packing several elements together in
order to manually vectorize operations.
This commit is contained in:
Thomas
2023-08-18 12:57:52 -07:00
committed by GitHub
parent c736ea8492
commit bf351b9ba2
10 changed files with 286 additions and 21 deletions

View File

@@ -510,6 +510,28 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// ElementwiseInlineAsm Op
//
def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [Elementwise,
                                      SameOperandsAndResultEncoding,
                                      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let summary = "inline assembly applying an elementwise operation to a group of packed elements.";
  let description = [{
    This applies the given inline assembly to `packed_element` elements of
    the inputs at a time. Which elements are packed together is unspecified
    and depends on the backend implementation.
  }];
  // asm_string/constraints follow the LLVM inline-asm format; `pure` marks the
  // asm as free of side effects, and `packed_element` is the number of
  // elements handed to one instance of the asm block.
  let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic<AnyTypeOf<[TT_Type]>>:$args);
  let results = (outs TT_Tensor:$result);
  let assemblyFormat = [{
    $asm_string attr-dict ($args^ `:` type($args))? `->` type($result)
  }];
}
//
// Print Op
//

View File

@@ -811,6 +811,86 @@ private:
}
};
// Lowers tt.elementwise_inline_asm to llvm.inline_asm. Sub-32-bit elements
// are packed into vectors that fit a 32-bit register before being handed to
// the asm block, and the (possibly vector) result is unpacked afterwards.
struct ElementwiseInlineAsmOpConversion
    : public ElementwiseOpConversionBase<ElementwiseInlineAsmOp,
                                         ElementwiseInlineAsmOpConversion> {
  using Base = ElementwiseOpConversionBase<ElementwiseInlineAsmOp,
                                           ElementwiseInlineAsmOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;
  // NOTE(review): redundant with the `using Adaptor` alias above — both name
  // Base::OpAdaptor; one of the two could be dropped.
  typedef typename Base::OpAdaptor OpAdaptor;

  // If operand size is smaller than 32bits pack by groups of 32bits.
  // Otherwise have separate inputs.
  //
  // `operands` is indexed as operands[element][operandIdx]. For each op
  // operand this emits packed_element / numElementPerReg values; when more
  // than one element fits a register, they are assembled into a
  // vector<numElementPerReg x elemTy> via insert_element.
  SmallVector<Value> packOperands(ElementwiseInlineAsmOp op,
                                  MultipleOperandsRange operands,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc) const {
    SmallVector<Value> packedOperands;
    unsigned numPackedElements = op.getPackedElement();
    for (int i = 0, e = op.getNumOperands(); i < e; i++) {
      // How many elements of this operand fit in one 32-bit register.
      unsigned bitWidth =
          getElementType(op.getOperand(i)).getIntOrFloatBitWidth();
      unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1;
      numElementPerReg = std::min(numElementPerReg, numPackedElements);
      for (int j = 0; j < numPackedElements; j += numElementPerReg) {
        if (numElementPerReg == 1) {
          // Nothing to pack; pass the scalar through unchanged.
          packedOperands.push_back(operands[j][i]);
          continue;
        }
        Type t = vec_ty(
            getTypeConverter()->convertType(getElementType(op.getOperand(i))),
            numElementPerReg);
        Value packed = undef(t);
        for (int k = 0; k < numElementPerReg; k++) {
          packed = insert_element(packed, operands[j + k][i], i32_val(k));
        }
        packedOperands.push_back(packed);
      }
    }
    return packedOperands;
  }

  // Emits one llvm.inline_asm for a group of `packed_element` elements and
  // splits the result back into the per-element scalars the elementwise base
  // class expects. Presumably the base class invokes this with
  // `operands.size()` equal to the packed-group size — the modulo check below
  // guards the divisibility requirement; confirm against the caller.
  SmallVector<Value> createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    int numPackedElements = op.getPackedElement();
    if (operands.size() % numPackedElements != 0)
      llvm::report_fatal_error("Inline asm op has more packed elements than "
                               "number of elements per thread.");
    SmallVector<Value> packedOperands =
        packOperands(op, operands, rewriter, loc);
    Type dstType =
        getTypeConverter()->convertType(getElementType(op.getResult()));
    // When packing, the asm block produces one vector holding all results.
    Type retType = dstType;
    if (numPackedElements > 1)
      retType = vec_ty(retType, numPackedElements);
    Value result = rewriter
                       .create<LLVM::InlineAsmOp>(
                           loc, retType,
                           packedOperands, // operands
                           op.getAsmString(), // asm_string
                           op.getConstraints(), // constraints
                           !op.getPure(), // has_side_effects
                           false, // is_align_stack
                           LLVM::AsmDialectAttr::get(
                               rewriter.getContext(),
                               LLVM::AsmDialect::AD_ATT), // asm_dialect
                           ArrayAttr() // operand_attrs
                           )
                       ->getResult(0);
    // Unpack a vector result back into scalar values; a scalar result is
    // forwarded as-is.
    SmallVector<Value> results;
    if (numPackedElements > 1) {
      for (int i = 0; i < numPackedElements; i++)
        results.push_back(extract_element(result, i32_val(i)));
    } else {
      results = {result};
    }
    return results;
  }
};
struct FDivOpConversion
: ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
using Base =
@@ -1213,6 +1293,7 @@ void populateElementwiseOpToLLVMPatterns(
patterns
.add<ExternElementwiseOpConversion<triton::ImpureExternElementwiseOp>>(
typeConverter, benefit);
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call

View File

@@ -695,26 +695,26 @@ public:
// Registers the Triton -> TritonGPU conversion patterns on `patterns`.
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                            RewritePatternSet &patterns, unsigned numCTAs) {
  MLIRContext *context = patterns.getContext();
  // NOTE(review): the two `patterns.insert<...>` calls below are
  // near-identical; the second additionally registers
  // TritonGenericPattern<triton::ElementwiseInlineAsmOp>. This looks like a
  // diff-rendering artifact (pre- and post-change registration lists
  // concatenated) — confirm that only the second list should remain, since
  // registering the same patterns twice is at best redundant.
  patterns
      .insert< // TODO: view should have custom pattern that views the layout
          TritonGenericPattern<triton::AdvanceOp>,
          TritonGenericPattern<triton::MakeTensorPtrOp>,
          TritonGenericPattern<triton::ViewOp>,
          TritonGenericPattern<triton::BitcastOp>,
          TritonGenericPattern<triton::FpToFpOp>,
          TritonGenericPattern<triton::IntToPtrOp>,
          TritonGenericPattern<triton::PtrToIntOp>,
          TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
          TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
          TritonReducePattern, TritonReduceReturnPattern, TritonScanPattern,
          TritonScanReturnPattern, TritonTransPattern, TritonExpandDimsPattern,
          TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
          TritonStorePattern,
          TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
          TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
          TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern,
          TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>(
          typeConverter, context);
  // Post-change list: identical to the above plus ElementwiseInlineAsmOp.
  patterns.insert< // TODO: view should have custom pattern that views the
                   // layout
      TritonGenericPattern<triton::AdvanceOp>,
      TritonGenericPattern<triton::MakeTensorPtrOp>,
      TritonGenericPattern<triton::ViewOp>,
      TritonGenericPattern<triton::BitcastOp>,
      TritonGenericPattern<triton::FpToFpOp>,
      TritonGenericPattern<triton::IntToPtrOp>,
      TritonGenericPattern<triton::PtrToIntOp>,
      TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
      TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
      TritonGenericPattern<triton::ElementwiseInlineAsmOp>, TritonReducePattern,
      TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern,
      TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
      TritonDotPattern, TritonLoadPattern, TritonStorePattern,
      TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
      TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
      TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern,
      TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>(
      typeConverter, context);
}
//

View File

@@ -870,5 +870,17 @@ LogicalResult triton::ReturnOp::verify() {
return success();
}
// -- ElementwiseInlineAsmOp --
// Reports the op's memory effects: none when marked pure, otherwise a
// conservative read+write on the default resource so the op is never
// CSE'd/DCE'd away.
void ElementwiseInlineAsmOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (!getPure()) {
    auto *resource = SideEffects::DefaultResource::get();
    effects.emplace_back(MemoryEffects::Write::get(), resource);
    effects.emplace_back(MemoryEffects::Read::get(), resource);
  }
}
} // namespace triton
} // namespace mlir

View File

@@ -1524,6 +1524,14 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::SelectOp>(condition, trueValue,
falseValue);
})
// Builds a tt.elementwise_inline_asm op. `type` is the result tensor type,
// `isPure` maps to the op's `pure` attribute, and `pack` to `packed_element`
// (number of elements processed per asm instance).
.def("create_inline_asm",
     [](TritonOpBuilder &self, const std::string &inlineAsm,
        const std::string &constraints,
        const std::vector<mlir::Value> &values, mlir::Type &type,
        bool isPure, int pack) -> mlir::Value {
       return self.create<mlir::triton::ElementwiseInlineAsmOp>(
           type, inlineAsm, constraints, isPure, pack, values);
     })
.def("create_print",
[](TritonOpBuilder &self, const std::string &prefix,
const std::vector<mlir::Value> &values) -> void {

View File

@@ -3034,6 +3034,60 @@ def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device):
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
# -----------------------
# test inline asm
# -----------------------
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_inline_asm(num_ctas, device):
    # The asm string is PTX, so this test is CUDA-only.
    check_cuda_only(device)

    @triton.jit
    def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = tl.load(Y + tl.arange(0, BLOCK))
        s = tl.full([BLOCK], n, tl.int32)
        # shf.l.wrap.b32 d, a, b, c: funnel shift left of the 64-bit pair
        # (b:a) by c (mod 32) bits, returning the high 32 bits.
        z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, is_pure=True, pack=1)
        tl.store(Z + tl.arange(0, BLOCK), z)

    shape = (128, )
    rs = RandomState(17)
    x = numpy_random(shape, dtype_str='uint32', rs=rs)
    y = numpy_random(shape, dtype_str='uint32', rs=rs)
    x_tri = to_triton(x, device=device)
    y_tri = to_triton(y, device=device)
    n = 17
    # z_tri only serves as an output buffer; its initial contents are irrelevant.
    z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device)
    kernel[(1,)](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas)
    # numpy reference for the funnel shift above.
    y_ref = (y << n) | (x >> (32 - n))
    # compare
    np.testing.assert_equal(y_ref, to_numpy(z_tri))
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_inline_asm_packed(num_ctas, device):
    # The asm string is PTX, so this test is CUDA-only.
    check_cuda_only(device)

    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        # shift 4x8bits values together.
        # pack=4: one asm instance handles four i8 elements packed in a 32-bit
        # register; the mask keeps each byte's shift from spilling into its
        # neighbor.
        y = tl.inline_asm_elementwise("and.b32 $0, $1, 0x1F1F1F1F; \
                        shl.b32 $0, $0, 3;",
                                      "=r,r", [x,], dtype=tl.int8, is_pure=True, pack=4)
        tl.store(Y + tl.arange(0, BLOCK), y)

    shape = (512, )
    rs = RandomState(17)
    x = numpy_random(shape, dtype_str='uint8', rs=rs)
    x_tri = to_triton(x, device=device)
    # y_tri only serves as an output buffer; its initial contents are irrelevant.
    y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device)
    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas)
    # uint8 arithmetic wraps, matching the masked 3-bit left shift per byte.
    y_ref = x << 3
    # compare
    np.testing.assert_equal(y_ref, to_numpy(y_tri))
# -----------------------
# test control flow
# -----------------------

View File

@@ -54,6 +54,7 @@ from .core import (
float8e4,
float8e5,
function_type,
inline_asm_elementwise,
int1,
int16,
int32,
@@ -154,6 +155,7 @@ __all__ = [
"float8e5",
"full",
"function_type",
"inline_asm_elementwise",
"int1",
"int16",
"int32",

View File

@@ -1788,6 +1788,46 @@ def device_assert(cond, msg="", _builder=None):
return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
@builtin
def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pure: bool, pack: int, _builder=None):
    '''
    Execute inline assembly over a tensor, applying the asm block to `pack`
    elements of the inputs at a time.

    :param asm: assembly to be inlined; it has to match the target assembly format
    :param constraints: string representing the mapping of operands to registers
    :param args: the arguments of the operation
    :param dtype: the element type of the returned tensor
    :param is_pure: whether the operation is free of side effects
    :param pack: the number of elements to be processed by one instance of inline assembly
    :param _builder: the builder
    :return: a tensor of `dtype` with the broadcast shape of the arguments
    '''
    asm = _constexpr_to_value(asm)
    constraints = _constexpr_to_value(constraints)
    pack = _constexpr_to_value(pack)
    is_pure = _constexpr_to_value(is_pure)
    # Materialize every argument as a tensor so constants are accepted too.
    dispatch_args = [_to_tensor(arg, _builder) for arg in args]
    ret_shape = None
    if len(dispatch_args) > 0:
        # Get the broadcast shape over all the arguments
        broadcast_arg = dispatch_args[0]
        for item in dispatch_args:
            _, broadcast_arg = semantic.binary_op_type_checking_impl(
                item, broadcast_arg, _builder, arithmetic_check=False)
        # Change the shape of each argument based on the broadcast shape
        for i in range(len(dispatch_args)):
            dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
                dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
        ret_shape = broadcast_arg.shape
    res_ty = block_type(dtype, ret_shape).to_ir(_builder)
    # Pass the broadcast tensors (not the raw user `args`): the original code
    # discarded the broadcast work by reading handles off `args`, which also
    # failed for non-tensor constants.
    call = _builder.create_inline_asm(asm, constraints, [t.handle for t in dispatch_args], res_ty, is_pure, pack)
    return tensor(call, block_type(dtype, ret_shape))
# -----------------------
# Iterators
# -----------------------

View File

@@ -201,5 +201,12 @@ tt.func @scan_op(%ptr: tensor<1x2x4x!tt.ptr<f32>>, %v : tensor<1x2x4xf32>) {
}) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
tt.store %ptr, %a : tensor<1x2x4xf32>
tt.return
}
// CHECK-LABEL: inline_asm
// CHECK: tt.elementwise_inline_asm "shl.b32 $0, $0, 3;"
// Round-trip (parse/print) test for tt.elementwise_inline_asm and its
// constraints/packed_element/pure attributes.
tt.func @inline_asm(%0: tensor<512xi8>) {
  %1 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;"
    {constraints = "=r,r", packed_element = 4 : i32, pure = true} %0 : tensor<512xi8> -> tensor<512xi8>
  tt.return
}

View File

@@ -1396,3 +1396,42 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
tt.return
}
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // CHECK-LABEL: inline_asm
  // Lowering test: with packed_element = 4 and i8 elements, the four packed
  // bytes must be handed to a single llvm.inline_asm as vector<4xi8>.
  tt.func public @inline_asm(%arg0: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<i8, 1>) -> tensor<512x!tt.ptr<i8, 1>, #blocked>
    %2 = tt.addptr %1, %0 : tensor<512x!tt.ptr<i8, 1>, #blocked>, tensor<512xi32, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8, #blocked>
    // CHECK: %{{.*}} = llvm.inline_asm asm_dialect = att "shl.b32 $0, $0, 3;", "=r,r" %{{.*}} : (vector<4xi8>) -> vector<4xi8>
    %4 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" {constraints = "=r,r", packed_element = 4 : i32, pure = true} %3 : tensor<512xi8, #blocked> -> tensor<512xi8, #blocked>
    %5 = tt.splat %arg1 : (!tt.ptr<i8, 1>) -> tensor<512x!tt.ptr<i8, 1>, #blocked>
    %6 = tt.addptr %5, %0 : tensor<512x!tt.ptr<i8, 1>, #blocked>, tensor<512xi32, #blocked>
    tt.store %6, %4 {cache = 1 : i32, evict = 1 : i32} : tensor<512xi8, #blocked>
    tt.return
  }
}
// -----
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // CHECK-LABEL: inline_asm_pack_16bit
  // Lowering test: packed_element = 2 with i8 elements still fits one 32-bit
  // register, so the operand/result are vector<2xi8> (16-bit "h" constraint).
  tt.func public @inline_asm_pack_16bit(%arg0: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<i8, 1>) -> tensor<512x!tt.ptr<i8, 1>, #blocked>
    %2 = tt.addptr %1, %0 : tensor<512x!tt.ptr<i8, 1>, #blocked>, tensor<512xi32, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8, #blocked>
    // CHECK: %{{.*}} = llvm.inline_asm asm_dialect = att "shl.b16 $0, $0, 3;", "=h,h" %{{.*}} : (vector<2xi8>) -> vector<2xi8>
    %4 = tt.elementwise_inline_asm "shl.b16 $0, $0, 3;" {constraints = "=h,h", packed_element = 2 : i32, pure = true} %3 : tensor<512xi8, #blocked> -> tensor<512xi8, #blocked>
    %5 = tt.splat %arg1 : (!tt.ptr<i8, 1>) -> tensor<512x!tt.ptr<i8, 1>, #blocked>
    %6 = tt.addptr %5, %0 : tensor<512x!tt.ptr<i8, 1>, #blocked>, tensor<512xi8, #blocked>
    tt.store %6, %4 {cache = 1 : i32, evict = 1 : i32} : tensor<512xi8, #blocked>
    tt.return
  }
}