feat: implement concat operation

This commit is contained in:
Umut
2022-02-15 16:22:09 +03:00
parent d41a7f0b68
commit 20a89b7b42
9 changed files with 1119 additions and 0 deletions

View File

@@ -601,4 +601,48 @@ def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
}];
}
def ConcatOp : FHELinalg_Op<"concat"> {
let summary = "Concatenates a sequence of tensors along an existing axis.";
let description = [{
Concatenates several tensors along a given axis.
Examples:
```mlir
"FHELinalg.concat"(%a, %b) { axis = 0 } : (tensor<3x3x!FHE.eint<4>>, tensor<3x3x!FHE.eint<4>>) -> tensor<6x3x!FHE.eint<4>>
//
// ( [1,2,3] [1,2,3] ) [1,2,3]
// concat ( [4,5,6], [4,5,6] ) = [4,5,6]
// ( [7,8,9] [7,8,9] ) [7,8,9]
// [1,2,3]
// [4,5,6]
// [7,8,9]
//
```
```mlir
"FHELinalg.concat"(%a, %b) { axis = 1 } : (tensor<3x3x!FHE.eint<4>>, tensor<3x3x!FHE.eint<4>>) -> tensor<3x6x!FHE.eint<4>>
//
// ( [1,2,3] [1,2,3] ) [1,2,3,1,2,3]
// concat ( [4,5,6], [4,5,6] ) = [4,5,6,4,5,6]
// ( [7,8,9] [7,8,9] ) [7,8,9,7,8,9]
//
```
}];
let arguments = (ins
Variadic<Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>>:$ins,
DefaultValuedAttr<I64Attr, "0">:$axis
);
let results = (outs
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$out
);
let verifier = [{
return mlir::concretelang::FHELinalg::verifyConcat(*this);
}];
}
#endif

View File

@@ -1126,6 +1126,121 @@ struct SumToLinalgGeneric
};
};
// This rewrite pattern transforms any instance of operators
// `FHELinalg.concat` to instances of `tensor.insert_slice`
//
// Example:
//
// %result = "FHELinalg.concat"(%x, %y) { axis = 1 } :
// (tensor<2x3x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>)
// -> tensor<2x7x!FHE.eint<4>>
//
// becomes:
//
// %empty = "FHELinalg.zero"() : () -> tensor<2x7x!FHE.eint<4>>
//
// %x_copied = tensor.insert_slice %x into %empty[0, 0] [2, 3] [1, 1]
// : tensor<2x3x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
//
// %y_copied = tensor.insert_slice %y into %x_copied[0, 3] [2, 4] [1, 1]
// : tensor<2x4x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
//
struct ConcatRewritePattern
: public mlir::OpRewritePattern<FHELinalg::ConcatOp> {
ConcatRewritePattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<FHELinalg::ConcatOp>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
mlir::LogicalResult
matchAndRewrite(FHELinalg::ConcatOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
size_t axis = op.axis();
mlir::Value output = op.getResult();
auto outputType = output.getType().dyn_cast<mlir::TensorType>();
llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
size_t outputDimensions = outputShape.size();
mlir::Value result =
rewriter.create<FHELinalg::ZeroOp>(location, outputType).getResult();
auto offsets = llvm::SmallVector<int64_t, 3>{};
auto sizes = llvm::SmallVector<int64_t, 3>{};
auto strides = llvm::SmallVector<int64_t, 3>{};
// set up the initial values of offsets, sizes, and strides
// each one has exactly `outputDimensions` number of elements
// - offsets will be [0, 0, 0, ..., 0, 0, 0]
// - strides will be [1, 1, 1, ..., 1, 1, 1]
// - sizes will be the output shape except at the 'axis' which will be 0
for (size_t i = 0; i < outputDimensions; i++) {
offsets.push_back(0);
if (i == axis) {
sizes.push_back(0);
} else {
sizes.push_back(outputShape[i]);
}
strides.push_back(1);
}
// these are not used, but they are required
// for the creation of InsertSliceOp operation
auto dynamicOffsets = llvm::ArrayRef<mlir::Value>{};
auto dynamicSizes = llvm::ArrayRef<mlir::Value>{};
auto dynamicStrides = llvm::ArrayRef<mlir::Value>{};
for (mlir::Value input : op.getOperands()) {
auto inputType = input.getType().dyn_cast<mlir::TensorType>();
int64_t axisSize = inputType.getShape()[axis];
// offsets and sizes will be modified for each input tensor
// if we have:
// "FHELinalg.concat"(%x, %y, %z) :
// (
// tensor<3x!FHE.eint<7>>,
// tensor<4x!FHE.eint<7>>,
// tensor<2x!FHE.eint<7>>,
// )
// -> tensor<9x!FHE.eint<7>>
//
// for the first copy:
// offsets = [0], sizes = [3], strides = [1]
//
// for the second copy:
// offsets = [3], sizes = [4], strides = [1]
//
// for the third copy:
// offsets = [7], sizes = [2], strides = [1]
//
// so in each iteration:
// - the size is set to the axis size of the input
// - the offset is increased by the size of the previous input
sizes[axis] = axisSize;
// these arrays are copied, so it's fine to modify and use them again
mlir::ArrayAttr offsetsAttr = rewriter.getI64ArrayAttr(offsets);
mlir::ArrayAttr sizesAttr = rewriter.getI64ArrayAttr(sizes);
mlir::ArrayAttr stridesAttr = rewriter.getI64ArrayAttr(strides);
offsets[axis] += axisSize;
result = rewriter
.create<mlir::tensor::InsertSliceOp>(
location, outputType, input, result, dynamicOffsets,
dynamicSizes, dynamicStrides, offsetsAttr, sizesAttr,
stridesAttr)
.getResult();
}
rewriter.replaceOp(op, {result});
return mlir::success();
};
};
namespace {
struct FHETensorOpsToLinalg
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
@@ -1186,6 +1301,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
&getContext());
patterns.insert<FHELinalgZeroToLinalgGenerate>(&getContext());
patterns.insert<SumToLinalgGeneric>(&getContext());
patterns.insert<ConcatRewritePattern>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())

View File

@@ -880,6 +880,20 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUMul(noiseMultiplier, operandMANP);
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::ConcatOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
llvm::APInt result = llvm::APInt{1, 0, false};
for (mlir::LatticeElement<MANPLatticeValue> *operandMANP : operandMANPs) {
llvm::APInt candidate = operandMANP->getValue().getMANP().getValue();
if (candidate.getLimitedValue() >= result.getLimitedValue()) {
result = candidate;
}
}
return result;
}
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -955,6 +969,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (auto sumOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::SumOp>(op)) {
norm2SqEquiv = getSqMANP(sumOp, operands);
} else if (auto concatOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::ConcatOp>(
op)) {
norm2SqEquiv = getSqMANP(concatOp, operands);
}
// Tensor Operators
// ExtractOp

View File

@@ -490,6 +490,105 @@ mlir::LogicalResult verifySum(SumOp &op) {
return mlir::success();
}
static bool sameShapeExceptAxis(llvm::ArrayRef<int64_t> shape1,
llvm::ArrayRef<int64_t> shape2, size_t axis) {
if (shape1.size() != shape2.size()) {
return false;
}
for (size_t i = 0; i < shape1.size(); i++) {
if (i != axis && shape1[i] != shape2[i]) {
return false;
}
}
return true;
}
mlir::LogicalResult verifyConcat(ConcatOp &op) {
unsigned numOperands = op.getNumOperands();
if (numOperands < 2) {
op->emitOpError() << "should have at least 2 inputs";
return mlir::failure();
}
int64_t axis = op.axis();
mlir::Value out = op.out();
auto outVectorType = out.getType().dyn_cast<mlir::TensorType>();
auto outElementType =
outVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
llvm::ArrayRef<int64_t> outShape = outVectorType.getShape();
size_t outDims = outShape.size();
if (axis < 0 || (size_t)axis >= outDims) {
op->emitOpError() << "has invalid axis attribute";
return mlir::failure();
}
int64_t expectedOutputElementsInAxis = 0;
size_t index = 0;
for (mlir::Value in : op.ins()) {
auto inVectorType = in.getType().dyn_cast<mlir::TensorType>();
auto inElementType =
inVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(op, inElementType,
outElementType)) {
return ::mlir::failure();
}
llvm::ArrayRef<int64_t> inShape = inVectorType.getShape();
if (!sameShapeExceptAxis(inShape, outShape, (size_t)axis)) {
auto stream = op->emitOpError();
stream << "does not have the proper shape of <";
if (axis == 0) {
stream << "?";
} else {
stream << outShape[0];
}
for (size_t i = 1; i < outDims; i++) {
stream << "x";
if (i == (size_t)axis) {
stream << "?";
} else {
stream << outShape[i];
}
}
stream << "> for input #" << index;
return mlir::failure();
}
expectedOutputElementsInAxis += inShape[axis];
index += 1;
}
if (outShape[axis] != expectedOutputElementsInAxis) {
auto stream = op->emitOpError();
stream << "does not have the proper output shape of <";
if (axis == 0) {
stream << expectedOutputElementsInAxis;
} else {
stream << outShape[0];
}
for (size_t i = 1; i < outDims; i++) {
stream << "x";
if (i == (size_t)axis) {
stream << expectedOutputElementsInAxis;
} else {
stream << outShape[i];
}
}
stream << ">";
return mlir::failure();
}
return mlir::success();
}
/// Verify the matmul shapes, the type of tensor elements should be checked by
/// something else
template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {

View File

@@ -0,0 +1,155 @@
// RUN: concretecompiler %s --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<7x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0] [3] [1] : tensor<3x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3] [4] [1] : tensor<4x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<7x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<7x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0] [3] [1] : tensor<3x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3] [4] [1] : tensor<4x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<7x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [3, 4] [1, 1] : tensor<3x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3, 0] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
return %0 : tensor<7x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [3, 4] [1, 1] : tensor<3x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3, 0] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
return %0 : tensor<7x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<4x3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<4x7x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [4, 3] [1, 1] : tensor<4x3x!FHE.eint<7>> into tensor<4x7x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 3] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<4x7x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<4x7x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<4x3x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
return %0 : tensor<4x7x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][2, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
return %0 : tensor<4x3x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][2, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
return %0 : tensor<4x3x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<2x6x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x6x4x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 3, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x6x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<2x6x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
return %0 : tensor<2x6x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } : tensor<2x3x8x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x3x8x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 0, 4] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x3x8x!FHE.eint<7>>
// CHECK-NEXT: return %[[v2]] : tensor<2x3x8x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
return %0 : tensor<2x3x8x!FHE.eint<7>>
}

View File

@@ -546,3 +546,30 @@ func @sum() -> !FHE.eint<7> {
return %1 : !FHE.eint<7>
}
// -----
func @concat() -> tensor<3x!FHE.eint<7>> {
%0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<7>>
// CHECK: MANP = 2 : ui{{[0-9]+}}
%1 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
%2 = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<7>>
// CHECK: MANP = 3 : ui{{[0-9]+}}
%3 = "FHELinalg.sum"(%2) { keep_dims = true } : (tensor<5x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
%4 = "FHELinalg.zero"() : () -> tensor<10x!FHE.eint<7>>
// CHECK: MANP = 4 : ui{{[0-9]+}}
%5 = "FHELinalg.sum"(%4) { keep_dims = true } : (tensor<10x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
// CHECK: MANP = 3 : ui{{[0-9]+}}
%6 = "FHELinalg.concat"(%1, %3) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
// CHECK: MANP = 4 : ui{{[0-9]+}}
%7 = "FHELinalg.concat"(%1, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
// CHECK: MANP = 4 : ui{{[0-9]+}}
%8 = "FHELinalg.concat"(%3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
// CHECK: MANP = 4 : ui{{[0-9]+}}
%9 = "FHELinalg.concat"(%1, %3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>>
return %9 : tensor<3x!FHE.eint<7>>
}

View File

@@ -0,0 +1,121 @@
// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s
// -----
func @main() -> tensor<0x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}}
%0 = "FHELinalg.concat"() : () -> tensor<0x!FHE.eint<7>>
return %0 : tensor<0x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}}
%0 = "FHELinalg.concat"(%x) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>>
return %0 : tensor<4x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> {
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>>
return %0 : tensor<7x!FHE.eint<6>>
}
// -----
func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op has invalid axis attribute}}
%0 = "FHELinalg.concat"(%x, %y) { axis = 3 } : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op has invalid axis attribute}}
%0 = "FHELinalg.concat"(%x, %y) { axis = -3 } : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<10x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <7>}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<10x!FHE.eint<7>>
return %0 : tensor<10x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<5x4x!FHE.eint<7>>) -> tensor<10x4x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <8x4>}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<5x4x!FHE.eint<7>>) -> tensor<10x4x!FHE.eint<7>>
return %0 : tensor<10x4x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x10x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <3x9>}}
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x10x!FHE.eint<7>>
return %0 : tensor<3x10x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper shape of <?x4x10> for input #0}}
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>>
return %0 : tensor<3x4x10x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper shape of <3x?x10> for input #0}}
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>>
return %0 : tensor<3x4x10x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper shape of <3x4x?> for input #0}}
%0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<3x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>>
return %0 : tensor<3x4x10x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<3x4x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper shape of <3x4x?> for input #1}}
%0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<3x4x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>>
return %0 : tensor<3x4x10x!FHE.eint<7>>
}
// -----
func @main(%x: tensor<3x4x4x!FHE.eint<7>>, %y: tensor<3x5x4x!FHE.eint<7>>) -> tensor<3x10x4x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <3x9x4>}}
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x4x4x!FHE.eint<7>>, tensor<3x5x4x!FHE.eint<7>>) -> tensor<3x10x4x!FHE.eint<7>>
return %0 : tensor<3x10x4x!FHE.eint<7>>
}

View File

@@ -0,0 +1,100 @@
// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<7x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<7x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
return %0 : tensor<7x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
return %0 : tensor<7x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<4x3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<4x7x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<4x3x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
return %0 : tensor<4x7x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
return %0 : tensor<4x3x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
return %0 : tensor<4x3x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<2x6x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
return %0 : tensor<2x6x4x!FHE.eint<7>>
}
// -----
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<2x3x8x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
return %0 : tensor<2x3x8x!FHE.eint<7>>
}

View File

@@ -3059,6 +3059,445 @@ func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> {
}
}
TEST(End2EndJit_FHELinalg, concat_1D_axis_0) {
namespace concretelang = mlir::concretelang;
concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
)XXX");
const uint8_t x[3]{0, 1, 2};
const uint8_t y[4]{3, 4, 5, 6};
const uint8_t expected[7]{0, 1, 2, 3, 4, 5, 6};
llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
xArg(xRef, {3});
llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 4);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
yArg(yRef, {4});
llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
{&xArg, &yArg});
ASSERT_EXPECTED_SUCCESS(call);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>> &res =
(*call)
->cast<concretelang::TensorLambdaArgument<
concretelang::IntLambdaArgument<>>>();
ASSERT_EQ(res.getDimensions().size(), (size_t)1);
ASSERT_EQ(res.getDimensions().at(0), 7);
ASSERT_EXPECTED_VALUE(res.getNumElements(), 7);
for (size_t i = 0; i < 7; i++) {
EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")";
}
}
TEST(End2EndJit_FHELinalg, concat_2D_axis_0) {
namespace concretelang = mlir::concretelang;
concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<2x3x!FHE.eint<7>>, %y: tensor<3x3x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x!FHE.eint<7>>, tensor<3x3x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>>
return %0 : tensor<5x3x!FHE.eint<7>>
}
)XXX");
const uint8_t x[2][3]{
{0, 1, 2},
{3, 4, 5},
};
const uint8_t y[3][3]{
{6, 7, 8},
{9, 0, 1},
{2, 3, 4},
};
const uint8_t expected[5][3]{
{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 0, 1}, {2, 3, 4},
};
llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 2 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
xArg(xRef, {2, 3});
llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 3 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
yArg(yRef, {3, 3});
llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
{&xArg, &yArg});
ASSERT_EXPECTED_SUCCESS(call);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>> &res =
(*call)
->cast<concretelang::TensorLambdaArgument<
concretelang::IntLambdaArgument<>>>();
ASSERT_EQ(res.getDimensions().size(), (size_t)2);
ASSERT_EQ(res.getDimensions().at(0), 5);
ASSERT_EQ(res.getDimensions().at(1), 3);
ASSERT_EXPECTED_VALUE(res.getNumElements(), 15);
for (size_t i = 0; i < 5; i++) {
for (size_t j = 0; j < 3; j++) {
EXPECT_EQ(res.getValue()[(i * 3) + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
TEST(End2EndJit_FHELinalg, concat_2D_axis_1) {
namespace concretelang = mlir::concretelang;
concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<3x2x!FHE.eint<7>>, %y: tensor<3x3x!FHE.eint<7>>) -> tensor<3x5x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x2x!FHE.eint<7>>, tensor<3x3x!FHE.eint<7>>) -> tensor<3x5x!FHE.eint<7>>
return %0 : tensor<3x5x!FHE.eint<7>>
}
)XXX");
const uint8_t x[3][2]{
{0, 1},
{2, 3},
{4, 5},
};
const uint8_t y[3][3]{
{6, 7, 8},
{9, 0, 1},
{2, 3, 4},
};
const uint8_t expected[3][5]{
{0, 1, 6, 7, 8},
{2, 3, 9, 0, 1},
{4, 5, 2, 3, 4},
};
llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 2);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
xArg(xRef, {3, 2});
llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 3 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
yArg(yRef, {3, 3});
llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
{&xArg, &yArg});
ASSERT_EXPECTED_SUCCESS(call);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>> &res =
(*call)
->cast<concretelang::TensorLambdaArgument<
concretelang::IntLambdaArgument<>>>();
ASSERT_EQ(res.getDimensions().size(), (size_t)2);
ASSERT_EQ(res.getDimensions().at(0), 3);
ASSERT_EQ(res.getDimensions().at(1), 5);
ASSERT_EXPECTED_VALUE(res.getNumElements(), 15);
for (size_t i = 0; i < 3; i++) {
for (size_t j = 0; j < 5; j++) {
EXPECT_EQ(res.getValue()[(i * 5) + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
TEST(End2EndJit_FHELinalg, concat_3D_axis_0) {
namespace concretelang = mlir::concretelang;
concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<4x4x3x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<4x4x3x!FHE.eint<7>>
return %0 : tensor<4x4x3x!FHE.eint<7>>
}
)XXX");
const uint8_t x[2][4][3]{
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
};
const uint8_t y[2][4][3]{
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
};
const uint8_t expected[4][4][3]{
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
};
llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 2 * 4 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
xArg(xRef, {2, 4, 3});
llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 2 * 4 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
yArg(yRef, {2, 4, 3});
llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
{&xArg, &yArg});
ASSERT_EXPECTED_SUCCESS(call);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>> &res =
(*call)
->cast<concretelang::TensorLambdaArgument<
concretelang::IntLambdaArgument<>>>();
ASSERT_EQ(res.getDimensions().size(), (size_t)3);
ASSERT_EQ(res.getDimensions().at(0), 4);
ASSERT_EQ(res.getDimensions().at(1), 4);
ASSERT_EQ(res.getDimensions().at(2), 3);
ASSERT_EXPECTED_VALUE(res.getNumElements(), 48);
for (size_t i = 0; i < 4; i++) {
for (size_t j = 0; j < 4; j++) {
for (size_t k = 0; k < 3; k++) {
EXPECT_EQ(res.getValue()[(i * 4 * 3) + (j * 3) + k], expected[i][j][k])
<< ", at pos(" << i << "," << j << "," << k << ")";
}
}
}
}
TEST(End2EndJit_FHELinalg, concat_3D_axis_1) {
namespace concretelang = mlir::concretelang;
concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x8x3x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x8x3x!FHE.eint<7>>
return %0 : tensor<2x8x3x!FHE.eint<7>>
}
)XXX");
const uint8_t x[2][4][3]{
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
};
const uint8_t y[2][4][3]{
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
};
const uint8_t expected[2][8][3]{
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
};
llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 2 * 4 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
xArg(xRef, {2, 4, 3});
llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 2 * 4 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
yArg(yRef, {2, 4, 3});
llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
{&xArg, &yArg});
ASSERT_EXPECTED_SUCCESS(call);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>> &res =
(*call)
->cast<concretelang::TensorLambdaArgument<
concretelang::IntLambdaArgument<>>>();
ASSERT_EQ(res.getDimensions().size(), (size_t)3);
ASSERT_EQ(res.getDimensions().at(0), 2);
ASSERT_EQ(res.getDimensions().at(1), 8);
ASSERT_EQ(res.getDimensions().at(2), 3);
ASSERT_EXPECTED_VALUE(res.getNumElements(), 48);
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 8; j++) {
for (size_t k = 0; k < 3; k++) {
EXPECT_EQ(res.getValue()[(i * 8 * 3) + (j * 3) + k], expected[i][j][k])
<< ", at pos(" << i << "," << j << "," << k << ")";
}
}
}
}
TEST(End2EndJit_FHELinalg, concat_3D_axis_2) {
namespace concretelang = mlir::concretelang;
concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x4x6x!FHE.eint<7>> {
%0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x4x6x!FHE.eint<7>>
return %0 : tensor<2x4x6x!FHE.eint<7>>
}
)XXX");
const uint8_t x[2][4][3]{
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
};
const uint8_t y[2][4][3]{
{
{0, 1, 2},
{3, 4, 5},
{6, 7, 8},
{9, 0, 1},
},
{
{2, 3, 4},
{5, 6, 7},
{8, 9, 0},
{1, 2, 3},
},
};
const uint8_t expected[2][4][6]{
{
{0, 1, 2, 0, 1, 2},
{3, 4, 5, 3, 4, 5},
{6, 7, 8, 6, 7, 8},
{9, 0, 1, 9, 0, 1},
},
{
{2, 3, 4, 2, 3, 4},
{5, 6, 7, 5, 6, 7},
{8, 9, 0, 8, 9, 0},
{1, 2, 3, 1, 2, 3},
},
};
llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 2 * 4 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
xArg(xRef, {2, 4, 3});
llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 2 * 4 * 3);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
yArg(yRef, {2, 4, 3});
llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
{&xArg, &yArg});
ASSERT_EXPECTED_SUCCESS(call);
concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>> &res =
(*call)
->cast<concretelang::TensorLambdaArgument<
concretelang::IntLambdaArgument<>>>();
ASSERT_EQ(res.getDimensions().size(), (size_t)3);
ASSERT_EQ(res.getDimensions().at(0), 2);
ASSERT_EQ(res.getDimensions().at(1), 4);
ASSERT_EQ(res.getDimensions().at(2), 6);
ASSERT_EXPECTED_VALUE(res.getNumElements(), 48);
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 4; j++) {
for (size_t k = 0; k < 6; k++) {
EXPECT_EQ(res.getValue()[(i * 4 * 6) + (j * 6) + k], expected[i][j][k])
<< ", at pos(" << i << "," << j << "," << k << ")";
}
}
}
}
class TiledMatMulParametric
: public ::testing::TestWithParam<std::vector<int64_t>> {};