mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTENT][BACKEND] Add support for elementwise inline assembly (#2136)
Add a new operation to be able to implement packed inline assembly for elementwise operations. This way inline assembly can be used to control elementwise operations. It also allows to pack elements to be able to manually vectorize operations.
This commit is contained in:
@@ -510,6 +510,28 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// ElementwiseInlineAsm Op
|
||||
//
|
||||
def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [Elementwise,
|
||||
SameOperandsAndResultEncoding,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
let summary = "inline assembly applying elementwise operation to a group of packed element.";
|
||||
let description = [{
|
||||
This will apply the given in inline assembly to `packed_element` number of
|
||||
elements of the inputs. The elements packed together is unknown and will
|
||||
depend on the backend implementation.
|
||||
}];
|
||||
|
||||
let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic<AnyTypeOf<[TT_Type]>>:$args);
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
$asm_string attr-dict ($args^ `:` type($args))? `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
//
|
||||
// Print Op
|
||||
//
|
||||
|
||||
@@ -811,6 +811,86 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
struct ElementwiseInlineAsmOpConversion
|
||||
: public ElementwiseOpConversionBase<ElementwiseInlineAsmOp,
|
||||
ElementwiseInlineAsmOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<ElementwiseInlineAsmOp,
|
||||
ElementwiseInlineAsmOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
typedef typename Base::OpAdaptor OpAdaptor;
|
||||
|
||||
// If operand size is smaller than 32bits pack by groups of 32bits.
|
||||
// Otherwise have separate inputs.
|
||||
SmallVector<Value> packOperands(ElementwiseInlineAsmOp op,
|
||||
MultipleOperandsRange operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) const {
|
||||
SmallVector<Value> packedOperands;
|
||||
unsigned numPackedElements = op.getPackedElement();
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; i++) {
|
||||
unsigned bitWidth =
|
||||
getElementType(op.getOperand(i)).getIntOrFloatBitWidth();
|
||||
unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1;
|
||||
numElementPerReg = std::min(numElementPerReg, numPackedElements);
|
||||
for (int j = 0; j < numPackedElements; j += numElementPerReg) {
|
||||
if (numElementPerReg == 1) {
|
||||
packedOperands.push_back(operands[j][i]);
|
||||
continue;
|
||||
}
|
||||
Type t = vec_ty(
|
||||
getTypeConverter()->convertType(getElementType(op.getOperand(i))),
|
||||
numElementPerReg);
|
||||
Value packed = undef(t);
|
||||
for (int k = 0; k < numElementPerReg; k++) {
|
||||
packed = insert_element(packed, operands[j + k][i], i32_val(k));
|
||||
}
|
||||
packedOperands.push_back(packed);
|
||||
}
|
||||
}
|
||||
return packedOperands;
|
||||
}
|
||||
|
||||
SmallVector<Value> createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
int numPackedElements = op.getPackedElement();
|
||||
if (operands.size() % numPackedElements != 0)
|
||||
llvm::report_fatal_error("Inline asm op has more packed elements than "
|
||||
"number of elements per thread.");
|
||||
SmallVector<Value> packedOperands =
|
||||
packOperands(op, operands, rewriter, loc);
|
||||
Type dstType =
|
||||
getTypeConverter()->convertType(getElementType(op.getResult()));
|
||||
Type retType = dstType;
|
||||
if (numPackedElements > 1)
|
||||
retType = vec_ty(retType, numPackedElements);
|
||||
Value result = rewriter
|
||||
.create<LLVM::InlineAsmOp>(
|
||||
loc, retType,
|
||||
packedOperands, // operands
|
||||
op.getAsmString(), // asm_string
|
||||
op.getConstraints(), // constraints
|
||||
!op.getPure(), // has_side_effects
|
||||
false, // is_align_stack
|
||||
LLVM::AsmDialectAttr::get(
|
||||
rewriter.getContext(),
|
||||
LLVM::AsmDialect::AD_ATT), // asm_dialect
|
||||
ArrayAttr() // operand_attrs
|
||||
)
|
||||
->getResult(0);
|
||||
SmallVector<Value> results;
|
||||
if (numPackedElements > 1) {
|
||||
for (int i = 0; i < numPackedElements; i++)
|
||||
results.push_back(extract_element(result, i32_val(i)));
|
||||
} else {
|
||||
results = {result};
|
||||
}
|
||||
return results;
|
||||
}
|
||||
};
|
||||
|
||||
struct FDivOpConversion
|
||||
: ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
|
||||
using Base =
|
||||
@@ -1213,6 +1293,7 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
patterns
|
||||
.add<ExternElementwiseOpConversion<triton::ImpureExternElementwiseOp>>(
|
||||
typeConverter, benefit);
|
||||
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
|
||||
// ExpOpConversionApprox will try using ex2.approx if the input type is
|
||||
// FP32. For other input types, ExpOpConversionApprox will return failure and
|
||||
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
|
||||
|
||||
@@ -695,26 +695,26 @@ public:
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, unsigned numCTAs) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns
|
||||
.insert< // TODO: view should have custom pattern that views the layout
|
||||
TritonGenericPattern<triton::AdvanceOp>,
|
||||
TritonGenericPattern<triton::MakeTensorPtrOp>,
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::BitcastOp>,
|
||||
TritonGenericPattern<triton::FpToFpOp>,
|
||||
TritonGenericPattern<triton::IntToPtrOp>,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
||||
TritonReducePattern, TritonReduceReturnPattern, TritonScanPattern,
|
||||
TritonScanReturnPattern, TritonTransPattern, TritonExpandDimsPattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern,
|
||||
TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
|
||||
TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
|
||||
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern,
|
||||
TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>(
|
||||
typeConverter, context);
|
||||
patterns.insert< // TODO: view should have custom pattern that views the
|
||||
// layout
|
||||
TritonGenericPattern<triton::AdvanceOp>,
|
||||
TritonGenericPattern<triton::MakeTensorPtrOp>,
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::BitcastOp>,
|
||||
TritonGenericPattern<triton::FpToFpOp>,
|
||||
TritonGenericPattern<triton::IntToPtrOp>,
|
||||
TritonGenericPattern<triton::PtrToIntOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
||||
TritonGenericPattern<triton::ElementwiseInlineAsmOp>, TritonReducePattern,
|
||||
TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern,
|
||||
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
|
||||
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
|
||||
TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
|
||||
TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
|
||||
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern,
|
||||
TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -870,5 +870,17 @@ LogicalResult triton::ReturnOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
// -- ElementwiseInlineAsmOp --
|
||||
void ElementwiseInlineAsmOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||
&effects) {
|
||||
if (getPure())
|
||||
return;
|
||||
effects.emplace_back(MemoryEffects::Write::get(),
|
||||
SideEffects::DefaultResource::get());
|
||||
effects.emplace_back(MemoryEffects::Read::get(),
|
||||
SideEffects::DefaultResource::get());
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1524,6 +1524,14 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<mlir::arith::SelectOp>(condition, trueValue,
|
||||
falseValue);
|
||||
})
|
||||
.def("create_inline_asm",
|
||||
[](TritonOpBuilder &self, const std::string &inlineAsm,
|
||||
const std::string &constraints,
|
||||
const std::vector<mlir::Value> &values, mlir::Type &type,
|
||||
bool isPure, int pack) -> mlir::Value {
|
||||
return self.create<mlir::triton::ElementwiseInlineAsmOp>(
|
||||
type, inlineAsm, constraints, isPure, pack, values);
|
||||
})
|
||||
.def("create_print",
|
||||
[](TritonOpBuilder &self, const std::string &prefix,
|
||||
const std::vector<mlir::Value> &values) -> void {
|
||||
|
||||
@@ -3034,6 +3034,60 @@ def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device):
|
||||
# compare
|
||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# test inline asm
|
||||
# -----------------------
|
||||
|
||||
@pytest.mark.parametrize("num_ctas", num_ctas_list)
|
||||
def test_inline_asm(num_ctas, device):
|
||||
check_cuda_only(device)
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
y = tl.load(Y + tl.arange(0, BLOCK))
|
||||
s = tl.full([BLOCK], n, tl.int32)
|
||||
z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, is_pure=True, pack=1)
|
||||
tl.store(Z + tl.arange(0, BLOCK), z)
|
||||
|
||||
shape = (128, )
|
||||
rs = RandomState(17)
|
||||
x = numpy_random(shape, dtype_str='uint32', rs=rs)
|
||||
y = numpy_random(shape, dtype_str='uint32', rs=rs)
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(y, device=device)
|
||||
n = 17
|
||||
z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device)
|
||||
kernel[(1,)](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas)
|
||||
y_ref = (y << n) | (x >> (32 - n))
|
||||
# compare
|
||||
np.testing.assert_equal(y_ref, to_numpy(z_tri))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_ctas", num_ctas_list)
|
||||
def test_inline_asm_packed(num_ctas, device):
|
||||
check_cuda_only(device)
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
# shift 4x8bits values together.
|
||||
y = tl.inline_asm_elementwise("and.b32 $0, $1, 0x1F1F1F1F; \
|
||||
shl.b32 $0, $0, 3;",
|
||||
"=r,r", [x,], dtype=tl.int8, is_pure=True, pack=4)
|
||||
tl.store(Y + tl.arange(0, BLOCK), y)
|
||||
|
||||
shape = (512, )
|
||||
rs = RandomState(17)
|
||||
x = numpy_random(shape, dtype_str='uint8', rs=rs)
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device)
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas)
|
||||
y_ref = x << 3
|
||||
# compare
|
||||
np.testing.assert_equal(y_ref, to_numpy(y_tri))
|
||||
|
||||
# -----------------------
|
||||
# test control flow
|
||||
# -----------------------
|
||||
|
||||
@@ -54,6 +54,7 @@ from .core import (
|
||||
float8e4,
|
||||
float8e5,
|
||||
function_type,
|
||||
inline_asm_elementwise,
|
||||
int1,
|
||||
int16,
|
||||
int32,
|
||||
@@ -154,6 +155,7 @@ __all__ = [
|
||||
"float8e5",
|
||||
"full",
|
||||
"function_type",
|
||||
"inline_asm_elementwise",
|
||||
"int1",
|
||||
"int16",
|
||||
"int32",
|
||||
|
||||
@@ -1788,6 +1788,46 @@ def device_assert(cond, msg="", _builder=None):
|
||||
return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pure: bool, pack: int, _builder=None):
|
||||
'''
|
||||
Execute the inline assembly to a packed of elements of the tensor
|
||||
:param asm: assembly to be inlined, it has to match the target assembly format
|
||||
:param constraints: string representing the mapping of operands to register
|
||||
:param args: the arguments of the operation
|
||||
:param dtype: the element type of the returned variable
|
||||
:param is_pure: whether the operation is pure
|
||||
:param pack: the number of elements to be processed by one instance of inline assembly
|
||||
:param _builder: the builder
|
||||
:return: the return value of the function
|
||||
'''
|
||||
dispatch_args = args.copy()
|
||||
asm = _constexpr_to_value(asm)
|
||||
constraints = _constexpr_to_value(constraints)
|
||||
pack = _constexpr_to_value(pack)
|
||||
is_pure = _constexpr_to_value(is_pure)
|
||||
ret_shape = None
|
||||
arg_types = []
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i] = _to_tensor(dispatch_args[i], _builder)
|
||||
arg_types.append(dispatch_args[i].dtype)
|
||||
if len(arg_types) > 0:
|
||||
arg_types = tuple(arg_types)
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder, arithmetic_check=False)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
|
||||
ret_shape = broadcast_arg.shape
|
||||
res_ty = block_type(dtype, ret_shape).to_ir(_builder)
|
||||
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty, is_pure, pack)
|
||||
return tensor(call, block_type(dtype, ret_shape))
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Iterators
|
||||
# -----------------------
|
||||
|
||||
@@ -201,5 +201,12 @@ tt.func @scan_op(%ptr: tensor<1x2x4x!tt.ptr<f32>>, %v : tensor<1x2x4xf32>) {
|
||||
}) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
|
||||
tt.store %ptr, %a : tensor<1x2x4xf32>
|
||||
tt.return
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: inline_asm
|
||||
// CHECK: tt.elementwise_inline_asm "shl.b32 $0, $0, 3;"
|
||||
tt.func @inline_asm(%0: tensor<512xi8>) {
|
||||
%1 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;"
|
||||
{constraints = "=r,r", packed_element = 4 : i32, pure = true} %0 : tensor<512xi8> -> tensor<512xi8>
|
||||
tt.return
|
||||
}
|
||||
|
||||
@@ -1396,3 +1396,42 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
// CHECK-LABEL: inline_asm
|
||||
tt.func public @inline_asm(%arg0: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
|
||||
%0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked>
|
||||
%1 = tt.splat %arg0 : (!tt.ptr<i8, 1>) -> tensor<512x!tt.ptr<i8, 1>, #blocked>
|
||||
%2 = tt.addptr %1, %0 : tensor<512x!tt.ptr<i8, 1>, #blocked>, tensor<512xi32, #blocked>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8, #blocked>
|
||||
// CHECK: %{{.*}} = llvm.inline_asm asm_dialect = att "shl.b32 $0, $0, 3;", "=r,r" %{{.*}} : (vector<4xi8>) -> vector<4xi8>
|
||||
%4 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" {constraints = "=r,r", packed_element = 4 : i32, pure = true} %3 : tensor<512xi8, #blocked> -> tensor<512xi8, #blocked>
|
||||
%5 = tt.splat %arg1 : (!tt.ptr<i8, 1>) -> tensor<512x!tt.ptr<i8, 1>, #blocked>
|
||||
%6 = tt.addptr %5, %0 : tensor<512x!tt.ptr<i8, 1>, #blocked>, tensor<512xi32, #blocked>
|
||||
tt.store %6, %4 {cache = 1 : i32, evict = 1 : i32} : tensor<512xi8, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
// CHECK-LABEL: inline_asm_pack_16bit
|
||||
tt.func public @inline_asm_pack_16bit(%arg0: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
|
||||
%0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked>
|
||||
%1 = tt.splat %arg0 : (!tt.ptr<i8, 1>) -> tensor<512x!tt.ptr<i8, 1>, #blocked>
|
||||
%2 = tt.addptr %1, %0 : tensor<512x!tt.ptr<i8, 1>, #blocked>, tensor<512xi32, #blocked>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8, #blocked>
|
||||
// CHECK: %{{.*}} = llvm.inline_asm asm_dialect = att "shl.b16 $0, $0, 3;", "=h,h" %{{.*}} : (vector<2xi8>) -> vector<2xi8>
|
||||
%4 = tt.elementwise_inline_asm "shl.b16 $0, $0, 3;" {constraints = "=h,h", packed_element = 2 : i32, pure = true} %3 : tensor<512xi8, #blocked> -> tensor<512xi8, #blocked>
|
||||
%5 = tt.splat %arg1 : (!tt.ptr<i8, 1>) -> tensor<512x!tt.ptr<i8, 1>, #blocked>
|
||||
%6 = tt.addptr %5, %0 : tensor<512x!tt.ptr<i8, 1>, #blocked>, tensor<512xi32, #blocked>
|
||||
tt.store %6, %4 {cache = 1 : i32, evict = 1 : i32} : tensor<512xi8, #blocked>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user