diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
index 95858873f..f83617429 100644
--- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
+++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
@@ -601,4 +601,48 @@ def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
   }];
 }
 
+def ConcatOp : FHELinalg_Op<"concat"> {
+  let summary = "Concatenates a sequence of tensors along an existing axis.";
+
+  let description = [{
+    Concatenates several tensors along a given axis.
+
+    Examples:
+
+    ```mlir
+    "FHELinalg.concat"(%a, %b) { axis = 0 } : (tensor<3x3x!FHE.eint<4>>, tensor<3x3x!FHE.eint<4>>) -> tensor<6x3x!FHE.eint<4>>
+    //
+    //         ( [1,2,3]  [1,2,3] )   [1,2,3]
+    //  concat ( [4,5,6], [4,5,6] ) = [4,5,6]
+    //         ( [7,8,9]  [7,8,9] )   [7,8,9]
+    //                                [1,2,3]
+    //                                [4,5,6]
+    //                                [7,8,9]
+    //
+    ```
+
+    ```mlir
+    "FHELinalg.concat"(%a, %b) { axis = 1 } : (tensor<3x3x!FHE.eint<4>>, tensor<3x3x!FHE.eint<4>>) -> tensor<3x6x!FHE.eint<4>>
+    //
+    //         ( [1,2,3]  [1,2,3] )   [1,2,3,1,2,3]
+    //  concat ( [4,5,6], [4,5,6] ) = [4,5,6,4,5,6]
+    //         ( [7,8,9]  [7,8,9] )   [7,8,9,7,8,9]
+    //
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>>:$ins,
+    DefaultValuedAttr<I64Attr, "0">:$axis
+  );
+
+  let results = (outs
+    Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$out
+  );
+
+  let verifier = [{
+    return mlir::concretelang::FHELinalg::verifyConcat(*this);
+  }];
+}
+
 #endif
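The constraints above encode the usual concatenation shape rule: every input must be a statically shaped tensor of encrypted integers, all inputs must agree with the output on every dimension except `axis`, and the output's `axis` dimension must be the sum of the inputs' `axis` dimensions; `verifyConcat`, added further down, enforces exactly this. As a quick standalone illustration of the rule (a sketch, not part of the patch; `concatShape` is an invented name):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the FHELinalg.concat shape rule: inputs must match on every
// dimension except `axis`, and the output accumulates the axis sizes.
std::vector<int64_t> concatShape(const std::vector<std::vector<int64_t>> &ins,
                                 size_t axis) {
  std::vector<int64_t> out = ins.front();
  out[axis] = 0;
  for (const std::vector<int64_t> &shape : ins) {
    assert(shape.size() == out.size()); // ranks must match
    for (size_t i = 0; i < shape.size(); i++)
      assert(i == axis || shape[i] == out[i]); // non-axis dims must agree
    out[axis] += shape[axis];
  }
  return out;
}

// e.g. concatShape({{3, 4}, {4, 4}}, 0) == {7, 4}, matching
// (tensor<3x4x...>, tensor<4x4x...>) -> tensor<7x4x...>
```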
diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index 0995b9a0c..13cefce59 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -1126,6 +1126,121 @@ struct SumToLinalgGeneric
   };
 };
 
+// This rewrite pattern transforms any instance of the operator
+// `FHELinalg.concat` into instances of `tensor.insert_slice`
+//
+// Example:
+//
+//   %result = "FHELinalg.concat"(%x, %y) { axis = 1 } :
+//     (tensor<2x3x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>)
+//     -> tensor<2x7x!FHE.eint<4>>
+//
+// becomes:
+//
+//   %empty = "FHELinalg.zero"() : () -> tensor<2x7x!FHE.eint<4>>
+//
+//   %x_copied = tensor.insert_slice %x into %empty[0, 0] [2, 3] [1, 1]
+//     : tensor<2x3x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
+//
+//   %y_copied = tensor.insert_slice %y into %x_copied[0, 3] [2, 4] [1, 1]
+//     : tensor<2x4x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
+//
+struct ConcatRewritePattern
+    : public mlir::OpRewritePattern<FHELinalg::ConcatOp> {
+  ConcatRewritePattern(mlir::MLIRContext *context)
+      : mlir::OpRewritePattern<FHELinalg::ConcatOp>(
+            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(FHELinalg::ConcatOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    mlir::Location location = op.getLoc();
+    size_t axis = op.axis();
+
+    mlir::Value output = op.getResult();
+    auto outputType = output.getType().dyn_cast<mlir::RankedTensorType>();
+
+    llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
+    size_t outputDimensions = outputShape.size();
+
+    mlir::Value result =
+        rewriter.create<FHELinalg::ZeroOp>(location, outputType).getResult();
+
+    auto offsets = llvm::SmallVector<int64_t>{};
+    auto sizes = llvm::SmallVector<int64_t>{};
+    auto strides = llvm::SmallVector<int64_t>{};
+
+    // set up the initial values of offsets, sizes, and strides
+    // each one has exactly `outputDimensions` elements
+    // - offsets will be [0, 0, 0, ..., 0, 0, 0]
+    // - strides will be [1, 1, 1, ..., 1, 1, 1]
+    // - sizes will be the output shape except at the `axis`, which will be 0
+    for (size_t i = 0; i < outputDimensions; i++) {
+      offsets.push_back(0);
+      if (i == axis) {
+        sizes.push_back(0);
+      } else {
+        sizes.push_back(outputShape[i]);
+      }
+      strides.push_back(1);
+    }
+
+    // these are not used, but they are required
+    // for the creation of the InsertSliceOp operation
+    auto dynamicOffsets = llvm::ArrayRef<mlir::Value>{};
+    auto dynamicSizes = llvm::ArrayRef<mlir::Value>{};
+    auto dynamicStrides = llvm::ArrayRef<mlir::Value>{};
+
+    for (mlir::Value input : op.getOperands()) {
+      auto inputType = input.getType().dyn_cast<mlir::RankedTensorType>();
+      int64_t axisSize = inputType.getShape()[axis];
+
+      // offsets and sizes will be modified for each input tensor
+      // if we have:
+      //   "FHELinalg.concat"(%x, %y, %z) :
+      //     (
+      //       tensor<3x!FHE.eint<7>>,
+      //       tensor<4x!FHE.eint<7>>,
+      //       tensor<2x!FHE.eint<7>>,
+      //     )
+      //     -> tensor<9x!FHE.eint<7>>
+      //
+      // for the first copy:
+      //   offsets = [0], sizes = [3], strides = [1]
+      //
+      // for the second copy:
+      //   offsets = [3], sizes = [4], strides = [1]
+      //
+      // for the third copy:
+      //   offsets = [7], sizes = [2], strides = [1]
+      //
+      // so in each iteration:
+      // - the size is set to the axis size of the input
+      // - the offset is increased by the size of the previous input
+
+      sizes[axis] = axisSize;
+
+      // these arrays are copied, so it's fine to modify and use them again
+      mlir::ArrayAttr offsetsAttr = rewriter.getI64ArrayAttr(offsets);
+      mlir::ArrayAttr sizesAttr = rewriter.getI64ArrayAttr(sizes);
+      mlir::ArrayAttr stridesAttr = rewriter.getI64ArrayAttr(strides);
+
+      offsets[axis] += axisSize;
+
+      result = rewriter
+                   .create<mlir::tensor::InsertSliceOp>(
+                       location, outputType, input, result, dynamicOffsets,
+                       dynamicSizes, dynamicStrides, offsetsAttr, sizesAttr,
+                       stridesAttr)
+                   .getResult();
+    }
+
+    rewriter.replaceOp(op, {result});
+    return mlir::success();
+  };
+};
+
 namespace {
 struct FHETensorOpsToLinalg
     : public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
@@ -1186,6 +1301,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
                                                               &getContext());
   patterns.insert<DotToLinalgGeneric>(&getContext());
   patterns.insert<SumToLinalgGeneric>(&getContext());
+  patterns.insert<ConcatRewritePattern>(&getContext());
 
   if (mlir::applyPartialConversion(function, target, std::move(patterns))
          .failed())
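The only values that change from one generated `tensor.insert_slice` to the next are the size and the offset along the concatenation axis. The bookkeeping in the loop above can be sanity-checked with a few lines of plain C++ reproducing the `%x, %y, %z` example from the comment (a standalone sketch, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Axis sizes of the inputs from the example: %x, %y and %z.
  const std::vector<int64_t> axisSizes = {3, 4, 2};

  int64_t offset = 0;
  for (int64_t size : axisSizes) {
    // Each input occupies [offset, offset + size) along the axis.
    std::printf("offsets = [%lld], sizes = [%lld], strides = [1]\n",
                static_cast<long long>(offset), static_cast<long long>(size));
    offset += size; // the next input starts where this one ends
  }
  // prints:
  //   offsets = [0], sizes = [3], strides = [1]
  //   offsets = [3], sizes = [4], strides = [1]
  //   offsets = [7], sizes = [2], strides = [1]
}
```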
diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
index 8331cd716..4e9c8d07f 100644
--- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
+++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
@@ -880,6 +880,20 @@ static llvm::APInt getSqMANP(
   return APIntWidthExtendUMul(noiseMultiplier, operandMANP);
 }
 
+static llvm::APInt getSqMANP(
+    mlir::concretelang::FHELinalg::ConcatOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  llvm::APInt result = llvm::APInt{1, 0, false};
+  for (mlir::LatticeElement<MANPLatticeValue> *operandMANP : operandMANPs) {
+    llvm::APInt candidate = operandMANP->getValue().getMANP().getValue();
+    if (candidate.getLimitedValue() >= result.getLimitedValue()) {
+      result = candidate;
+    }
+  }
+  return result;
+}
+
 struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
   using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis;
   MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -955,6 +969,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
     } else if (auto sumOp =
                    llvm::dyn_cast<mlir::concretelang::FHELinalg::SumOp>(op)) {
       norm2SqEquiv = getSqMANP(sumOp, operands);
+    } else if (auto concatOp =
+                   llvm::dyn_cast<mlir::concretelang::FHELinalg::ConcatOp>(
+                       op)) {
+      norm2SqEquiv = getSqMANP(concatOp, operands);
     }
     // Tensor Operators
     // ExtractOp
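Concatenation performs no homomorphic arithmetic; it only moves ciphertexts into a bigger tensor, so the worst-case noise of any element of the result is the worst case over the operands. That is why the new `getSqMANP` overload is a plain max-fold, which in isolation looks like this (a standalone sketch with invented names, not part of the patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch: the squared MANP of a concatenation is the maximum of the squared
// MANPs of its operands, since no arithmetic touches the ciphertexts.
uint64_t concatSqManp(const std::vector<uint64_t> &operandSqManps) {
  uint64_t result = 0;
  for (uint64_t sqManp : operandSqManps)
    result = std::max(result, sqManp);
  return result;
}

// e.g. concatSqManp({4, 9, 1}) == 9: the noisiest operand dominates
```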
diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
index 574195c1a..847a825d2 100644
--- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
+++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
@@ -490,6 +490,105 @@ mlir::LogicalResult verifySum(SumOp &op) {
   return mlir::success();
 }
 
+static bool sameShapeExceptAxis(llvm::ArrayRef<int64_t> shape1,
+                                llvm::ArrayRef<int64_t> shape2, size_t axis) {
+  if (shape1.size() != shape2.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < shape1.size(); i++) {
+    if (i != axis && shape1[i] != shape2[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+mlir::LogicalResult verifyConcat(ConcatOp &op) {
+  unsigned numOperands = op.getNumOperands();
+  if (numOperands < 2) {
+    op->emitOpError() << "should have at least 2 inputs";
+    return mlir::failure();
+  }
+
+  int64_t axis = op.axis();
+  mlir::Value out = op.out();
+
+  auto outVectorType = out.getType().dyn_cast<mlir::RankedTensorType>();
+  auto outElementType =
+      outVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
+
+  llvm::ArrayRef<int64_t> outShape = outVectorType.getShape();
+  size_t outDims = outShape.size();
+
+  if (axis < 0 || (size_t)axis >= outDims) {
+    op->emitOpError() << "has invalid axis attribute";
+    return mlir::failure();
+  }
+
+  int64_t expectedOutputElementsInAxis = 0;
+
+  size_t index = 0;
+  for (mlir::Value in : op.ins()) {
+    auto inVectorType = in.getType().dyn_cast<mlir::RankedTensorType>();
+    auto inElementType =
+        inVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
+    if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
+            op, inElementType, outElementType)) {
+      return ::mlir::failure();
+    }
+
+    llvm::ArrayRef<int64_t> inShape = inVectorType.getShape();
+    if (!sameShapeExceptAxis(inShape, outShape, (size_t)axis)) {
+      auto stream = op->emitOpError();
+
+      stream << "does not have the proper shape of <";
+      if (axis == 0) {
+        stream << "?";
+      } else {
+        stream << outShape[0];
+      }
+      for (size_t i = 1; i < outDims; i++) {
+        stream << "x";
+        if (i == (size_t)axis) {
+          stream << "?";
+        } else {
+          stream << outShape[i];
+        }
+      }
+      stream << "> for input #" << index;
+
+      return mlir::failure();
+    }
+    expectedOutputElementsInAxis += inShape[axis];
+
+    index += 1;
+  }
+
+  if (outShape[axis] != expectedOutputElementsInAxis) {
+    auto stream = op->emitOpError();
+
+    stream << "does not have the proper output shape of <";
+    if (axis == 0) {
+      stream << expectedOutputElementsInAxis;
+    } else {
+      stream << outShape[0];
+    }
+    for (size_t i = 1; i < outDims; i++) {
+      stream << "x";
+      if (i == (size_t)axis) {
+        stream << expectedOutputElementsInAxis;
+      } else {
+        stream << outShape[i];
+      }
+    }
+    stream << ">";
+
+    return mlir::failure();
+  }
+
+  return mlir::success();
+}
+
 /// Verify the matmul shapes, the type of tensor elements should be checked by
 /// something else
 template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
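Both diagnostics built above follow the same scheme: print the output shape with the `axis` dimension replaced, either by `?` (any size would be accepted for an input there) or by the accumulated sum (the output size that would have been correct). A standalone sketch of that formatting (invented helper name, not part of the patch):

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Sketch: render a shape such as "<3x?x10>", substituting `placeholder`
// (e.g. "?" or "9") for the dimension at `axis`.
std::string shapeWithAxisReplaced(const std::vector<int64_t> &outShape,
                                  size_t axis,
                                  const std::string &placeholder) {
  std::string result = "<";
  for (size_t i = 0; i < outShape.size(); i++) {
    if (i != 0)
      result += "x";
    result += (i == axis) ? placeholder : std::to_string(outShape[i]);
  }
  return result + ">";
}

// shapeWithAxisReplaced({3, 4, 10}, 1, "?") == "<3x?x10>", as in the
// "does not have the proper shape" diagnostics tested below
```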
diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/concat.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/concat.mlir
new file mode 100644
index 000000000..fa6f9e5c9
--- /dev/null
+++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/concat.mlir
@@ -0,0 +1,155 @@
+// RUN: concretecompiler %s --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<7x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0] [3] [1] : tensor<3x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3] [4] [1] : tensor<4x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<7x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
+  return %0 : tensor<7x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<7x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0] [3] [1] : tensor<3x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3] [4] [1] : tensor<4x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<7x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
+  return %0 : tensor<7x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [3, 4] [1, 1] : tensor<3x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3, 0] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
+  return %0 : tensor<7x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [3, 4] [1, 1] : tensor<3x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3, 0] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
+  return %0 : tensor<7x4x!FHE.eint<7>>
+}
+
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<4x3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<4x7x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [4, 3] [1, 1] : tensor<4x3x!FHE.eint<7>> into tensor<4x7x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 3] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<4x7x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<4x7x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<4x3x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
+  return %0 : tensor<4x7x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][2, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
+  return %0 : tensor<4x3x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][2, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
+  return %0 : tensor<4x3x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<2x6x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x6x4x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 3, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x6x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<2x6x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
+  return %0 : tensor<2x6x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: index):
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT:     tensor.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT:   } : tensor<2x3x8x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x3x8x!FHE.eint<7>>
+// CHECK-NEXT:   %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 0, 4] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x3x8x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v2]] : tensor<2x3x8x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
+  return %0 : tensor<2x3x8x!FHE.eint<7>>
+}
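Note that the `FHELinalg.zero` materialized by `ConcatRewritePattern` never appears in the checked output: every test starts with a `tensor.generate` yielding `FHE.zero` instead, presumably because the zero-lowering pattern registered in the same pass rewrites it before the IR is dumped.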
diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
index a3bdd86d0..814a062e2 100644
--- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
+++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
@@ -546,3 +546,30 @@ func @sum() -> !FHE.eint<7> {
 
   return %1 : !FHE.eint<7>
 }
+
+// -----
+
+func @concat() -> tensor<3x!FHE.eint<7>> {
+  %0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<7>>
+  // CHECK: MANP = 2 : ui{{[0-9]+}}
+  %1 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
+
+  %2 = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<7>>
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
+  %3 = "FHELinalg.sum"(%2) { keep_dims = true } : (tensor<5x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
+
+  %4 = "FHELinalg.zero"() : () -> tensor<10x!FHE.eint<7>>
+  // CHECK: MANP = 4 : ui{{[0-9]+}}
+  %5 = "FHELinalg.sum"(%4) { keep_dims = true } : (tensor<10x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
+
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
+  %6 = "FHELinalg.concat"(%1, %3) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
+  // CHECK: MANP = 4 : ui{{[0-9]+}}
+  %7 = "FHELinalg.concat"(%1, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
+  // CHECK: MANP = 4 : ui{{[0-9]+}}
+  %8 = "FHELinalg.concat"(%3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
+  // CHECK: MANP = 4 : ui{{[0-9]+}}
+  %9 = "FHELinalg.concat"(%1, %3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>>
+
+  return %9 : tensor<3x!FHE.eint<7>>
+}
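The expected values in this test are consistent with the reported MANP being the ceiling of the square root of the squared norm tracked by the analysis (an inference from the sum cases, not something this diff states explicitly): summing n ciphertexts multiplies the squared norm by n, giving ceil(sqrt(4)) = 2, ceil(sqrt(5)) = 3 and ceil(sqrt(10)) = 4 for the three sums, and each concat then takes the maximum over its operands: max(2, 3) = 3, max(2, 4) = 4, max(3, 4) = 4 and max(2, 3, 4) = 4.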
tensor<3x!FHE.eint<7>> +} diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/concat.invalid.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/concat.invalid.mlir new file mode 100644 index 000000000..50b1fabde --- /dev/null +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/concat.invalid.mlir @@ -0,0 +1,121 @@ +// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s + +// ----- + +func @main() -> tensor<0x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}} + %0 = "FHELinalg.concat"() : () -> tensor<0x!FHE.eint<7>> + return %0 : tensor<0x!FHE.eint<7>> +} + +// ----- + +func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}} + %0 = "FHELinalg.concat"(%x) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %0 : tensor<4x!FHE.eint<7>> +} + +// ----- + +func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> { + // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}} + %0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> + return %0 : tensor<7x!FHE.eint<6>> +} + +// ----- + +func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}} + %0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>> + return %0 : tensor<7x!FHE.eint<7>> +} + +// ----- + +func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}} + %0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> + return %0 : tensor<7x!FHE.eint<7>> +} + +// ----- + +func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.concat' op has invalid axis attribute}} + %0 = "FHELinalg.concat"(%x, %y) { axis = 3 } : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> + return %0 : tensor<7x!FHE.eint<7>> +} + +// ----- + +func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.concat' op has invalid axis attribute}} + %0 = "FHELinalg.concat"(%x, %y) { axis = -3 } : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> + return %0 : tensor<7x!FHE.eint<7>> +} + +// ----- + +func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<10x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <7>}} + %0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<10x!FHE.eint<7>> + return %0 : tensor<10x!FHE.eint<7>> +} + +// ----- + +func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<5x4x!FHE.eint<7>>) -> tensor<10x4x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <8x4>}} + %0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<5x4x!FHE.eint<7>>) -> tensor<10x4x!FHE.eint<7>> + return %0 : tensor<10x4x!FHE.eint<7>> +} + +// ----- + +func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: 
diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/concat.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/concat.mlir
new file mode 100644
index 000000000..95b88e9d5
--- /dev/null
+++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/concat.mlir
@@ -0,0 +1,100 @@
+// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<7x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
+  return %0 : tensor<7x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<7x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
+  return %0 : tensor<7x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
+  return %0 : tensor<7x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<7x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
+  return %0 : tensor<7x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<4x3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<4x7x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<4x3x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
+  return %0 : tensor<4x7x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
+  return %0 : tensor<4x3x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<4x3x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
+  return %0 : tensor<4x3x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<2x6x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
+  return %0 : tensor<2x6x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
+// CHECK-NEXT:   return %[[v0]] : tensor<2x3x8x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
+  return %0 : tensor<2x3x8x!FHE.eint<7>>
+}
diff --git a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc
index 6c8f17d84..2df3c8c4c 100644
--- a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc
+++ b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc
@@ -3059,6 +3059,445 @@ func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> {
   }
 }
 
+TEST(End2EndJit_FHELinalg, concat_1D_axis_0) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
+  return %0 : tensor<7x!FHE.eint<7>>
+}
+)XXX");
+
+  const uint8_t x[3]{0, 1, 2};
+  const uint8_t y[4]{3, 4, 5, 6};
+
+  const uint8_t expected[7]{0, 1, 2, 3, 4, 5, 6};
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {3});
+
+  llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 4);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      yArg(yRef, {4});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg, &yArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)1);
+  ASSERT_EQ(res.getDimensions().at(0), 7);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 7);
+
+  for (size_t i = 0; i < 7; i++) {
+    EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")";
+  }
+}
+
+TEST(End2EndJit_FHELinalg, concat_2D_axis_0) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+func @main(%x: tensor<2x3x!FHE.eint<7>>, %y: tensor<3x3x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x!FHE.eint<7>>, tensor<3x3x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>>
+  return %0 : tensor<5x3x!FHE.eint<7>>
+}
+)XXX");
+
+  const uint8_t x[2][3]{
+      {0, 1, 2},
+      {3, 4, 5},
+  };
+  const uint8_t y[3][3]{
+      {6, 7, 8},
+      {9, 0, 1},
+      {2, 3, 4},
+  };
+
+  const uint8_t expected[5][3]{
+      {0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 0, 1}, {2, 3, 4},
+  };
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 2 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {2, 3});
+
+  llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 3 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      yArg(yRef, {3, 3});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg, &yArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)2);
+  ASSERT_EQ(res.getDimensions().at(0), 5);
+  ASSERT_EQ(res.getDimensions().at(1), 3);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 15);
+
+  for (size_t i = 0; i < 5; i++) {
+    for (size_t j = 0; j < 3; j++) {
+      EXPECT_EQ(res.getValue()[(i * 3) + j], expected[i][j])
+          << ", at pos(" << i << "," << j << ")";
+    }
+  }
+}
+
+TEST(End2EndJit_FHELinalg, concat_2D_axis_1) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+func @main(%x: tensor<3x2x!FHE.eint<7>>, %y: tensor<3x3x!FHE.eint<7>>) -> tensor<3x5x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x2x!FHE.eint<7>>, tensor<3x3x!FHE.eint<7>>) -> tensor<3x5x!FHE.eint<7>>
+  return %0 : tensor<3x5x!FHE.eint<7>>
+}
+)XXX");
+
+  const uint8_t x[3][2]{
+      {0, 1},
+      {2, 3},
+      {4, 5},
+  };
+  const uint8_t y[3][3]{
+      {6, 7, 8},
+      {9, 0, 1},
+      {2, 3, 4},
+  };
+
+  const uint8_t expected[3][5]{
+      {0, 1, 6, 7, 8},
+      {2, 3, 9, 0, 1},
+      {4, 5, 2, 3, 4},
+  };
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 2);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {3, 2});
+
+  llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 3 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      yArg(yRef, {3, 3});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg, &yArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)2);
+  ASSERT_EQ(res.getDimensions().at(0), 3);
+  ASSERT_EQ(res.getDimensions().at(1), 5);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 15);
+
+  for (size_t i = 0; i < 3; i++) {
+    for (size_t j = 0; j < 5; j++) {
+      EXPECT_EQ(res.getValue()[(i * 5) + j], expected[i][j])
+          << ", at pos(" << i << "," << j << ")";
+    }
+  }
+}
+
+TEST(End2EndJit_FHELinalg, concat_3D_axis_0) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<4x4x3x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<4x4x3x!FHE.eint<7>>
+  return %0 : tensor<4x4x3x!FHE.eint<7>>
+}
+)XXX");
+
+  const uint8_t x[2][4][3]{
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+  };
+  const uint8_t y[2][4][3]{
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+  };
+
+  const uint8_t expected[4][4][3]{
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+  };
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 2 * 4 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {2, 4, 3});
+
+  llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 2 * 4 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      yArg(yRef, {2, 4, 3});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg, &yArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)3);
+  ASSERT_EQ(res.getDimensions().at(0), 4);
+  ASSERT_EQ(res.getDimensions().at(1), 4);
+  ASSERT_EQ(res.getDimensions().at(2), 3);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 48);
+
+  for (size_t i = 0; i < 4; i++) {
+    for (size_t j = 0; j < 4; j++) {
+      for (size_t k = 0; k < 3; k++) {
+        EXPECT_EQ(res.getValue()[(i * 4 * 3) + (j * 3) + k], expected[i][j][k])
+            << ", at pos(" << i << "," << j << "," << k << ")";
+      }
+    }
+  }
+}
+
+TEST(End2EndJit_FHELinalg, concat_3D_axis_1) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x8x3x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x8x3x!FHE.eint<7>>
+  return %0 : tensor<2x8x3x!FHE.eint<7>>
+}
+)XXX");
+
+  const uint8_t x[2][4][3]{
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+  };
+  const uint8_t y[2][4][3]{
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+  };
+
+  const uint8_t expected[2][8][3]{
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+  };
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 2 * 4 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {2, 4, 3});
+
+  llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 2 * 4 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      yArg(yRef, {2, 4, 3});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg, &yArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)3);
+  ASSERT_EQ(res.getDimensions().at(0), 2);
+  ASSERT_EQ(res.getDimensions().at(1), 8);
+  ASSERT_EQ(res.getDimensions().at(2), 3);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 48);
+
+  for (size_t i = 0; i < 2; i++) {
+    for (size_t j = 0; j < 8; j++) {
+      for (size_t k = 0; k < 3; k++) {
+        EXPECT_EQ(res.getValue()[(i * 8 * 3) + (j * 3) + k], expected[i][j][k])
+            << ", at pos(" << i << "," << j << "," << k << ")";
+      }
+    }
+  }
+}
+
+TEST(End2EndJit_FHELinalg, concat_3D_axis_2) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x4x6x!FHE.eint<7>> {
+  %0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x4x6x!FHE.eint<7>>
+  return %0 : tensor<2x4x6x!FHE.eint<7>>
+}
+)XXX");
+
+  const uint8_t x[2][4][3]{
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+  };
+  const uint8_t y[2][4][3]{
+      {
+          {0, 1, 2},
+          {3, 4, 5},
+          {6, 7, 8},
+          {9, 0, 1},
+      },
+      {
+          {2, 3, 4},
+          {5, 6, 7},
+          {8, 9, 0},
+          {1, 2, 3},
+      },
+  };
+
+  const uint8_t expected[2][4][6]{
+      {
+          {0, 1, 2, 0, 1, 2},
+          {3, 4, 5, 3, 4, 5},
+          {6, 7, 8, 6, 7, 8},
+          {9, 0, 1, 9, 0, 1},
+      },
+      {
+          {2, 3, 4, 2, 3, 4},
+          {5, 6, 7, 5, 6, 7},
+          {8, 9, 0, 8, 9, 0},
+          {1, 2, 3, 1, 2, 3},
+      },
+  };
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 2 * 4 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {2, 4, 3});
+
+  llvm::ArrayRef<uint8_t> yRef((const uint8_t *)y, 2 * 4 * 3);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      yArg(yRef, {2, 4, 3});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg, &yArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)3);
+  ASSERT_EQ(res.getDimensions().at(0), 2);
+  ASSERT_EQ(res.getDimensions().at(1), 4);
+  ASSERT_EQ(res.getDimensions().at(2), 6);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 48);
+
+  for (size_t i = 0; i < 2; i++) {
+    for (size_t j = 0; j < 4; j++) {
+      for (size_t k = 0; k < 6; k++) {
+        EXPECT_EQ(res.getValue()[(i * 4 * 6) + (j * 6) + k], expected[i][j][k])
+            << ", at pos(" << i << "," << j << "," << k << ")";
+      }
+    }
+  }
+}
+
 class TiledMatMulParametric
     : public ::testing::TestWithParam<std::tuple<size_t, size_t, size_t>> {};