mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: implement concat operation
This commit is contained in:
@@ -601,4 +601,48 @@ def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def ConcatOp : FHELinalg_Op<"concat"> {
  let summary = "Concatenates a sequence of tensors along an existing axis.";

  let description = [{
    Concatenates several tensors along a given axis.

    Examples:

    ```mlir
    "FHELinalg.concat"(%a, %b) { axis = 0 } : (tensor<3x3x!FHE.eint<4>>, tensor<3x3x!FHE.eint<4>>) -> tensor<6x3x!FHE.eint<4>>
    //
    //        ( [1,2,3]  [1,2,3] )   [1,2,3]
    // concat ( [4,5,6], [4,5,6] ) = [4,5,6]
    //        ( [7,8,9]  [7,8,9] )   [7,8,9]
    //                               [1,2,3]
    //                               [4,5,6]
    //                               [7,8,9]
    //
    ```

    ```mlir
    "FHELinalg.concat"(%a, %b) { axis = 1 } : (tensor<3x3x!FHE.eint<4>>, tensor<3x3x!FHE.eint<4>>) -> tensor<3x6x!FHE.eint<4>>
    //
    //        ( [1,2,3]  [1,2,3] )   [1,2,3,1,2,3]
    // concat ( [4,5,6], [4,5,6] ) = [4,5,6,4,5,6]
    //        ( [7,8,9]  [7,8,9] )   [7,8,9,7,8,9]
    //
    ```
  }];

  // Inputs: two or more statically shaped tensors of encrypted integers
  // (the minimum arity is enforced by `verifyConcat`, not here), plus the
  // axis to concatenate along (defaults to the outermost dimension).
  let arguments = (ins
    Variadic<Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>>:$ins,
    DefaultValuedAttr<I64Attr, "0">:$axis
  );

  // Result: a statically shaped tensor of encrypted integers whose shape
  // matches the inputs on every dimension except `axis`.
  let results = (outs
    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$out
  );

  let verifier = [{
    return mlir::concretelang::FHELinalg::verifyConcat(*this);
  }];
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1126,6 +1126,121 @@ struct SumToLinalgGeneric
|
||||
};
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of operators
|
||||
// `FHELinalg.concat` to instances of `tensor.insert_slice`
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// %result = "FHELinalg.concat"(%x, %y) { axis = 1 } :
|
||||
// (tensor<2x3x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>)
|
||||
// -> tensor<2x7x!FHE.eint<4>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// %empty = "FHELinalg.zero"() : () -> tensor<2x7x!FHE.eint<4>>
|
||||
//
|
||||
// %x_copied = tensor.insert_slice %x into %empty[0, 0] [2, 3] [1, 1]
|
||||
// : tensor<2x3x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
|
||||
//
|
||||
// %y_copied = tensor.insert_slice %y into %x_copied[0, 3] [2, 4] [1, 1]
|
||||
// : tensor<2x4x!FHE.eint<4>> into tensor<2x7x!FHE.eint<4>>
|
||||
//
|
||||
struct ConcatRewritePattern
|
||||
: public mlir::OpRewritePattern<FHELinalg::ConcatOp> {
|
||||
ConcatRewritePattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<FHELinalg::ConcatOp>(
|
||||
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHELinalg::ConcatOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
size_t axis = op.axis();
|
||||
|
||||
mlir::Value output = op.getResult();
|
||||
auto outputType = output.getType().dyn_cast<mlir::TensorType>();
|
||||
|
||||
llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
size_t outputDimensions = outputShape.size();
|
||||
|
||||
mlir::Value result =
|
||||
rewriter.create<FHELinalg::ZeroOp>(location, outputType).getResult();
|
||||
|
||||
auto offsets = llvm::SmallVector<int64_t, 3>{};
|
||||
auto sizes = llvm::SmallVector<int64_t, 3>{};
|
||||
auto strides = llvm::SmallVector<int64_t, 3>{};
|
||||
|
||||
// set up the initial values of offsets, sizes, and strides
|
||||
// each one has exactly `outputDimensions` number of elements
|
||||
// - offsets will be [0, 0, 0, ..., 0, 0, 0]
|
||||
// - strides will be [1, 1, 1, ..., 1, 1, 1]
|
||||
// - sizes will be the output shape except at the 'axis' which will be 0
|
||||
for (size_t i = 0; i < outputDimensions; i++) {
|
||||
offsets.push_back(0);
|
||||
if (i == axis) {
|
||||
sizes.push_back(0);
|
||||
} else {
|
||||
sizes.push_back(outputShape[i]);
|
||||
}
|
||||
strides.push_back(1);
|
||||
}
|
||||
|
||||
// these are not used, but they are required
|
||||
// for the creation of InsertSliceOp operation
|
||||
auto dynamicOffsets = llvm::ArrayRef<mlir::Value>{};
|
||||
auto dynamicSizes = llvm::ArrayRef<mlir::Value>{};
|
||||
auto dynamicStrides = llvm::ArrayRef<mlir::Value>{};
|
||||
|
||||
for (mlir::Value input : op.getOperands()) {
|
||||
auto inputType = input.getType().dyn_cast<mlir::TensorType>();
|
||||
int64_t axisSize = inputType.getShape()[axis];
|
||||
|
||||
// offsets and sizes will be modified for each input tensor
|
||||
// if we have:
|
||||
// "FHELinalg.concat"(%x, %y, %z) :
|
||||
// (
|
||||
// tensor<3x!FHE.eint<7>>,
|
||||
// tensor<4x!FHE.eint<7>>,
|
||||
// tensor<2x!FHE.eint<7>>,
|
||||
// )
|
||||
// -> tensor<9x!FHE.eint<7>>
|
||||
//
|
||||
// for the first copy:
|
||||
// offsets = [0], sizes = [3], strides = [1]
|
||||
//
|
||||
// for the second copy:
|
||||
// offsets = [3], sizes = [4], strides = [1]
|
||||
//
|
||||
// for the third copy:
|
||||
// offsets = [7], sizes = [2], strides = [1]
|
||||
//
|
||||
// so in each iteration:
|
||||
// - the size is set to the axis size of the input
|
||||
// - the offset is increased by the size of the previous input
|
||||
|
||||
sizes[axis] = axisSize;
|
||||
|
||||
// these arrays are copied, so it's fine to modify and use them again
|
||||
mlir::ArrayAttr offsetsAttr = rewriter.getI64ArrayAttr(offsets);
|
||||
mlir::ArrayAttr sizesAttr = rewriter.getI64ArrayAttr(sizes);
|
||||
mlir::ArrayAttr stridesAttr = rewriter.getI64ArrayAttr(strides);
|
||||
|
||||
offsets[axis] += axisSize;
|
||||
|
||||
result = rewriter
|
||||
.create<mlir::tensor::InsertSliceOp>(
|
||||
location, outputType, input, result, dynamicOffsets,
|
||||
dynamicSizes, dynamicStrides, offsetsAttr, sizesAttr,
|
||||
stridesAttr)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct FHETensorOpsToLinalg
|
||||
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
|
||||
@@ -1186,6 +1301,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
|
||||
&getContext());
|
||||
patterns.insert<FHELinalgZeroToLinalgGenerate>(&getContext());
|
||||
patterns.insert<SumToLinalgGeneric>(&getContext());
|
||||
patterns.insert<ConcatRewritePattern>(&getContext());
|
||||
|
||||
if (mlir::applyPartialConversion(function, target, std::move(patterns))
|
||||
.failed())
|
||||
|
||||
@@ -880,6 +880,20 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUMul(noiseMultiplier, operandMANP);
|
||||
}
|
||||
|
||||
/// Computes the squared Minimal Arithmetic Noise Padding of a
/// `FHELinalg.concat` operation: every element of the result is taken
/// unchanged from exactly one operand, so the MANP of the concatenation
/// is the maximum MANP over all operands.
static llvm::APInt getSqMANP(
    mlir::concretelang::FHELinalg::ConcatOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {

  // Start from the smallest representable MANP; any operand value is at
  // least as large, so the loop below always picks an operand's value.
  llvm::APInt max = llvm::APInt{1, 0, false};
  for (auto *lattice : operandMANPs) {
    llvm::APInt manp = lattice->getValue().getMANP().getValue();
    // Compare through `getLimitedValue` because the operands' APInts may
    // have different bit widths.
    if (manp.getLimitedValue() >= max.getLimitedValue())
      max = manp;
  }
  return max;
}
|
||||
|
||||
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
|
||||
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
|
||||
@@ -955,6 +969,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto sumOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::SumOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(sumOp, operands);
|
||||
} else if (auto concatOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::ConcatOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(concatOp, operands);
|
||||
}
|
||||
// Tensor Operators
|
||||
// ExtractOp
|
||||
|
||||
@@ -490,6 +490,105 @@ mlir::LogicalResult verifySum(SumOp &op) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Returns true iff `shape1` and `shape2` have the same rank and agree on
/// every dimension, except possibly the one at `axis` (which is ignored).
static bool sameShapeExceptAxis(llvm::ArrayRef<int64_t> shape1,
                                llvm::ArrayRef<int64_t> shape2, size_t axis) {
  const size_t rank = shape1.size();
  if (rank != shape2.size())
    return false;

  for (size_t dim = 0; dim < rank; dim++) {
    // The dimension being concatenated along is allowed to differ.
    if (dim == axis)
      continue;
    if (shape1[dim] != shape2[dim])
      return false;
  }

  return true;
}
|
||||
|
||||
/// Verifies a `FHELinalg.concat` operation:
///  - there are at least two inputs,
///  - the `axis` attribute names a valid dimension of the output,
///  - every input has the same encrypted-integer width as the output,
///  - every input matches the output shape on all dimensions but `axis`,
///  - the output size along `axis` is the sum of the input sizes along it.
mlir::LogicalResult verifyConcat(ConcatOp &op) {
  unsigned numOperands = op.getNumOperands();
  if (numOperands < 2) {
    op->emitOpError() << "should have at least 2 inputs";
    return mlir::failure();
  }

  int64_t axis = op.axis();
  mlir::Value out = op.out();

  // The ODS definition constrains the result to a statically shaped
  // tensor of encrypted integers, so these casts cannot fail here.
  auto outVectorType = out.getType().dyn_cast<mlir::TensorType>();
  auto outElementType =
      outVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();

  llvm::ArrayRef<int64_t> outShape = outVectorType.getShape();
  size_t outDims = outShape.size();

  if (axis < 0 || (size_t)axis >= outDims) {
    op->emitOpError() << "has invalid axis attribute";
    return mlir::failure();
  }

  // Streams the expected shape (without the surrounding '<' and '>') into
  // a diagnostic, substituting `axisValue` for the dimension at `axis`.
  // `axisValue` is either a literal "?" (shape mismatch on a non-axis
  // dimension) or the expected element count along the axis.
  auto streamShapeWithAxis = [&](mlir::InFlightDiagnostic &stream,
                                 auto axisValue) {
    if (axis == 0) {
      stream << axisValue;
    } else {
      stream << outShape[0];
    }
    for (size_t i = 1; i < outDims; i++) {
      stream << "x";
      if (i == (size_t)axis) {
        stream << axisValue;
      } else {
        stream << outShape[i];
      }
    }
  };

  // Sum of the input sizes along `axis`; must equal the output size there.
  int64_t expectedOutputElementsInAxis = 0;

  size_t index = 0;
  for (mlir::Value in : op.ins()) {
    auto inVectorType = in.getType().dyn_cast<mlir::TensorType>();
    auto inElementType =
        inVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
    if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(op, inElementType,
                                                              outElementType)) {
      return ::mlir::failure();
    }

    llvm::ArrayRef<int64_t> inShape = inVectorType.getShape();
    if (!sameShapeExceptAxis(inShape, outShape, (size_t)axis)) {
      auto stream = op->emitOpError();
      stream << "does not have the proper shape of <";
      streamShapeWithAxis(stream, "?");
      stream << "> for input #" << index;
      return mlir::failure();
    }
    expectedOutputElementsInAxis += inShape[axis];

    index += 1;
  }

  if (outShape[axis] != expectedOutputElementsInAxis) {
    auto stream = op->emitOpError();
    stream << "does not have the proper output shape of <";
    streamShapeWithAxis(stream, expectedOutputElementsInAxis);
    stream << ">";
    return mlir::failure();
  }

  return mlir::success();
}
|
||||
|
||||
/// Verify the matmul shapes, the type of tensor elements should be checked by
|
||||
/// something else
|
||||
template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
// RUN: concretecompiler %s --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0] [3] [1] : tensor<3x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3] [4] [1] : tensor<4x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
return %0 : tensor<7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0] [3] [1] : tensor<3x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3] [4] [1] : tensor<4x!FHE.eint<7>> into tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
return %0 : tensor<7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [3, 4] [1, 1] : tensor<3x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3, 0] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
|
||||
return %0 : tensor<7x4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [3, 4] [1, 1] : tensor<3x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][3, 0] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
|
||||
return %0 : tensor<7x4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<4x3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<4x7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0] [4, 3] [1, 1] : tensor<4x3x!FHE.eint<7>> into tensor<4x7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 3] [4, 4] [1, 1] : tensor<4x4x!FHE.eint<7>> into tensor<4x7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<4x7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<4x3x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
|
||||
return %0 : tensor<4x7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<4x3x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][2, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<4x3x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
|
||||
return %0 : tensor<4x3x4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<4x3x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][2, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<4x3x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<4x3x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
|
||||
return %0 : tensor<4x3x4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<2x6x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x6x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 3, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x6x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<2x6x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
|
||||
return %0 : tensor<2x6x4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: index):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: tensor.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } : tensor<2x3x8x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.insert_slice %[[a0]] into %[[v0]][0, 0, 0] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x3x8x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = tensor.insert_slice %[[a1]] into %[[v1]][0, 0, 4] [2, 3, 4] [1, 1, 1] : tensor<2x3x4x!FHE.eint<7>> into tensor<2x3x8x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v2]] : tensor<2x3x8x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
|
||||
return %0 : tensor<2x3x8x!FHE.eint<7>>
|
||||
}
|
||||
@@ -546,3 +546,30 @@ func @sum() -> !FHE.eint<7> {
|
||||
|
||||
return %1 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @concat() -> tensor<3x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<7>>
|
||||
// CHECK: MANP = 2 : ui{{[0-9]+}}
|
||||
%1 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
|
||||
|
||||
%2 = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<7>>
|
||||
// CHECK: MANP = 3 : ui{{[0-9]+}}
|
||||
%3 = "FHELinalg.sum"(%2) { keep_dims = true } : (tensor<5x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
|
||||
|
||||
%4 = "FHELinalg.zero"() : () -> tensor<10x!FHE.eint<7>>
|
||||
// CHECK: MANP = 4 : ui{{[0-9]+}}
|
||||
%5 = "FHELinalg.sum"(%4) { keep_dims = true } : (tensor<10x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>>
|
||||
|
||||
// CHECK: MANP = 3 : ui{{[0-9]+}}
|
||||
%6 = "FHELinalg.concat"(%1, %3) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
|
||||
// CHECK: MANP = 4 : ui{{[0-9]+}}
|
||||
%7 = "FHELinalg.concat"(%1, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
|
||||
// CHECK: MANP = 4 : ui{{[0-9]+}}
|
||||
%8 = "FHELinalg.concat"(%3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>
|
||||
// CHECK: MANP = 4 : ui{{[0-9]+}}
|
||||
%9 = "FHELinalg.concat"(%1, %3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>>
|
||||
|
||||
return %9 : tensor<3x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
121
compiler/tests/Dialect/FHELinalg/FHELinalg/concat.invalid.mlir
Normal file
121
compiler/tests/Dialect/FHELinalg/FHELinalg/concat.invalid.mlir
Normal file
@@ -0,0 +1,121 @@
|
||||
// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s
|
||||
|
||||
// -----
|
||||
|
||||
func @main() -> tensor<0x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}}
|
||||
%0 = "FHELinalg.concat"() : () -> tensor<0x!FHE.eint<7>>
|
||||
return %0 : tensor<0x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op should have at least 2 inputs}}
|
||||
%0 = "FHELinalg.concat"(%x) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>>
|
||||
return %0 : tensor<4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>>
|
||||
return %0 : tensor<7x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>>
|
||||
return %0 : tensor<7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
return %0 : tensor<7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op has invalid axis attribute}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 3 } : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
return %0 : tensor<7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op has invalid axis attribute}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = -3 } : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
return %0 : tensor<7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<10x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <7>}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<10x!FHE.eint<7>>
|
||||
return %0 : tensor<10x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<5x4x!FHE.eint<7>>) -> tensor<10x4x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <8x4>}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<5x4x!FHE.eint<7>>) -> tensor<10x4x!FHE.eint<7>>
|
||||
return %0 : tensor<10x4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x10x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <3x9>}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x10x!FHE.eint<7>>
|
||||
return %0 : tensor<3x10x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper shape of <?x4x10> for input #0}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x10x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper shape of <3x?x10> for input #0}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x10x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper shape of <3x4x?> for input #0}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<3x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x10x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<3x4x4x!FHE.eint<7>>, %y: tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper shape of <3x4x?> for input #1}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<3x4x4x!FHE.eint<7>>, tensor<3x5x!FHE.eint<7>>) -> tensor<3x4x10x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x10x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @main(%x: tensor<3x4x4x!FHE.eint<7>>, %y: tensor<3x5x4x!FHE.eint<7>>) -> tensor<3x10x4x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.concat' op does not have the proper output shape of <3x9x4>}}
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x4x4x!FHE.eint<7>>, tensor<3x5x4x!FHE.eint<7>>) -> tensor<3x10x4x!FHE.eint<7>>
|
||||
return %0 : tensor<3x10x4x!FHE.eint<7>>
|
||||
}
|
||||
100
compiler/tests/Dialect/FHELinalg/FHELinalg/concat.mlir
Normal file
100
compiler/tests/Dialect/FHELinalg/FHELinalg/concat.mlir
Normal file
@@ -0,0 +1,100 @@
|
||||
// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
return %0 : tensor<7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<7x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
|
||||
return %0 : tensor<7x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<7x4x!FHE.eint<7>>
|
||||
// CHECK-NEXT: }
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.concat"(%x, %y) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
|
||||
return %0 : tensor<7x4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip: 2-D concat with an explicit axis = 0.
// NOTE(review): the CHECK line omits `{axis = 0 : i64}`; this only matches if
// the printer elides the explicitly-set default attribute — confirm.
// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<7x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<3x4x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x4x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<7x4x!FHE.eint<7>>
  return %0 : tensor<7x4x!FHE.eint<7>>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip: 2-D concat along axis = 1 (columns appended).
// NOTE(review): the CHECK line omits `{axis = 1 : i64}`. If the generic form
// prints the attribute this pattern will not match, and a round-trip that
// drops axis = 1 would lose semantics — confirm against the actual printer.
// CHECK: func @main(%[[a0:.*]]: tensor<4x3x!FHE.eint<7>>, %[[a1:.*]]: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<4x7x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<4x3x!FHE.eint<7>>, %y: tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<4x3x!FHE.eint<7>>, tensor<4x4x!FHE.eint<7>>) -> tensor<4x7x!FHE.eint<7>>
  return %0 : tensor<4x7x!FHE.eint<7>>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip: 3-D concat with the axis attribute omitted (defaults to 0).
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
  return %0 : tensor<4x3x4x!FHE.eint<7>>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip: 3-D concat with an explicit axis = 0.
// NOTE(review): the CHECK line omits `{axis = 0 : i64}`; this only matches if
// the printer elides the explicitly-set default attribute — confirm.
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<4x3x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x4x!FHE.eint<7>>
  return %0 : tensor<4x3x4x!FHE.eint<7>>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip: 3-D concat along the middle dimension (axis = 1).
// NOTE(review): the CHECK line omits `{axis = 1 : i64}`. If the generic form
// prints the attribute this pattern will not match, and a round-trip that
// drops axis = 1 would lose semantics — confirm against the actual printer.
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<2x6x4x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x6x4x!FHE.eint<7>>
  return %0 : tensor<2x6x4x!FHE.eint<7>>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip: 3-D concat along the innermost dimension (axis = 2).
// NOTE(review): the CHECK line omits `{axis = 2 : i64}`. If the generic form
// prints the attribute this pattern will not match, and a round-trip that
// drops axis = 2 would lose semantics — confirm against the actual printer.
// CHECK: func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>, %[[a1:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.concat"(%[[a0]], %[[a1]]) : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
// CHECK-NEXT: return %[[v0]] : tensor<2x3x8x!FHE.eint<7>>
// CHECK-NEXT: }
func @main(%x: tensor<2x3x4x!FHE.eint<7>>, %y: tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x3x4x!FHE.eint<7>>, tensor<2x3x4x!FHE.eint<7>>) -> tensor<2x3x8x!FHE.eint<7>>
  return %0 : tensor<2x3x8x!FHE.eint<7>>
}
|
||||
@@ -3059,6 +3059,445 @@ func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, concat_1D_axis_0) {
  namespace concretelang = mlir::concretelang;

  // JIT-compile a concat of a 3-element and a 4-element encrypted vector
  // along axis 0.
  concretelang::JitCompilerEngine::Lambda compiled = checkedJit(R"XXX(
func @main(%x: tensor<3x!FHE.eint<7>>, %y: tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<3x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
  return %0 : tensor<7x!FHE.eint<7>>
}
)XXX");

  const uint8_t lhs[3]{0, 1, 2};
  const uint8_t rhs[4]{3, 4, 5, 6};

  // The second operand is appended after the first one.
  const uint8_t expected[7]{0, 1, 2, 3, 4, 5, 6};

  llvm::ArrayRef<uint8_t> lhsView((const uint8_t *)lhs, 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      lhsArg(lhsView, {3});

  llvm::ArrayRef<uint8_t> rhsView((const uint8_t *)rhs, 4);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      rhsArg(rhsView, {4});

  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> invocation =
      compiled.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
          {&lhsArg, &rhsArg});
  ASSERT_EXPECTED_SUCCESS(invocation);

  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>>
      &output = (*invocation)
                    ->cast<concretelang::TensorLambdaArgument<
                        concretelang::IntLambdaArgument<>>>();

  // Check shape first, then every element of the decrypted result.
  ASSERT_EQ(output.getDimensions().size(), (size_t)1);
  ASSERT_EQ(output.getDimensions().at(0), 7);
  ASSERT_EXPECTED_VALUE(output.getNumElements(), 7);

  for (size_t i = 0; i < 7; i++) {
    EXPECT_EQ(output.getValue()[i], expected[i]) << ", at pos(" << i << ")";
  }
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, concat_2D_axis_0) {
  namespace concretelang = mlir::concretelang;

  // JIT-compile a concat of a 2x3 and a 3x3 encrypted matrix along the rows.
  concretelang::JitCompilerEngine::Lambda compiled = checkedJit(R"XXX(
func @main(%x: tensor<2x3x!FHE.eint<7>>, %y: tensor<3x3x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x3x!FHE.eint<7>>, tensor<3x3x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>>
  return %0 : tensor<5x3x!FHE.eint<7>>
}
)XXX");

  const uint8_t lhs[2][3]{
      {0, 1, 2},
      {3, 4, 5},
  };
  const uint8_t rhs[3][3]{
      {6, 7, 8},
      {9, 0, 1},
      {2, 3, 4},
  };

  // Rows of the second operand are stacked below the rows of the first.
  const uint8_t expected[5][3]{
      {0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 0, 1}, {2, 3, 4},
  };

  llvm::ArrayRef<uint8_t> lhsView((const uint8_t *)lhs, 2 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      lhsArg(lhsView, {2, 3});

  llvm::ArrayRef<uint8_t> rhsView((const uint8_t *)rhs, 3 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      rhsArg(rhsView, {3, 3});

  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> invocation =
      compiled.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
          {&lhsArg, &rhsArg});
  ASSERT_EXPECTED_SUCCESS(invocation);

  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>>
      &output = (*invocation)
                    ->cast<concretelang::TensorLambdaArgument<
                        concretelang::IntLambdaArgument<>>>();

  // Check shape first, then every element of the row-major result buffer.
  ASSERT_EQ(output.getDimensions().size(), (size_t)2);
  ASSERT_EQ(output.getDimensions().at(0), 5);
  ASSERT_EQ(output.getDimensions().at(1), 3);
  ASSERT_EXPECTED_VALUE(output.getNumElements(), 15);

  for (size_t i = 0; i < 5; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ(output.getValue()[(i * 3) + j], expected[i][j])
          << ", at pos(" << i << "," << j << ")";
    }
  }
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, concat_2D_axis_1) {
  namespace concretelang = mlir::concretelang;

  // JIT-compile a concat of a 3x2 and a 3x3 encrypted matrix along the
  // columns (axis 1).
  concretelang::JitCompilerEngine::Lambda compiled = checkedJit(R"XXX(
func @main(%x: tensor<3x2x!FHE.eint<7>>, %y: tensor<3x3x!FHE.eint<7>>) -> tensor<3x5x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<3x2x!FHE.eint<7>>, tensor<3x3x!FHE.eint<7>>) -> tensor<3x5x!FHE.eint<7>>
  return %0 : tensor<3x5x!FHE.eint<7>>
}
)XXX");

  const uint8_t lhs[3][2]{
      {0, 1},
      {2, 3},
      {4, 5},
  };
  const uint8_t rhs[3][3]{
      {6, 7, 8},
      {9, 0, 1},
      {2, 3, 4},
  };

  // Each row of the result is the corresponding lhs row followed by the
  // corresponding rhs row.
  const uint8_t expected[3][5]{
      {0, 1, 6, 7, 8},
      {2, 3, 9, 0, 1},
      {4, 5, 2, 3, 4},
  };

  llvm::ArrayRef<uint8_t> lhsView((const uint8_t *)lhs, 3 * 2);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      lhsArg(lhsView, {3, 2});

  llvm::ArrayRef<uint8_t> rhsView((const uint8_t *)rhs, 3 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      rhsArg(rhsView, {3, 3});

  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> invocation =
      compiled.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
          {&lhsArg, &rhsArg});
  ASSERT_EXPECTED_SUCCESS(invocation);

  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>>
      &output = (*invocation)
                    ->cast<concretelang::TensorLambdaArgument<
                        concretelang::IntLambdaArgument<>>>();

  // Check shape first, then every element of the row-major result buffer.
  ASSERT_EQ(output.getDimensions().size(), (size_t)2);
  ASSERT_EQ(output.getDimensions().at(0), 3);
  ASSERT_EQ(output.getDimensions().at(1), 5);
  ASSERT_EXPECTED_VALUE(output.getNumElements(), 15);

  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 5; j++) {
      EXPECT_EQ(output.getValue()[(i * 5) + j], expected[i][j])
          << ", at pos(" << i << "," << j << ")";
    }
  }
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, concat_3D_axis_0) {
  namespace concretelang = mlir::concretelang;

  // JIT-compile a concat of two 2x4x3 encrypted tensors along the outermost
  // dimension (axis 0).
  concretelang::JitCompilerEngine::Lambda compiled = checkedJit(R"XXX(
func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<4x4x3x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 0 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<4x4x3x!FHE.eint<7>>
  return %0 : tensor<4x4x3x!FHE.eint<7>>
}
)XXX");

  const uint8_t lhs[2][4][3]{
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
  };
  const uint8_t rhs[2][4][3]{
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
  };

  // The two 2x4x3 blocks of rhs follow the two blocks of lhs.
  const uint8_t expected[4][4][3]{
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
  };

  llvm::ArrayRef<uint8_t> lhsView((const uint8_t *)lhs, 2 * 4 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      lhsArg(lhsView, {2, 4, 3});

  llvm::ArrayRef<uint8_t> rhsView((const uint8_t *)rhs, 2 * 4 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      rhsArg(rhsView, {2, 4, 3});

  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> invocation =
      compiled.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
          {&lhsArg, &rhsArg});
  ASSERT_EXPECTED_SUCCESS(invocation);

  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>>
      &output = (*invocation)
                    ->cast<concretelang::TensorLambdaArgument<
                        concretelang::IntLambdaArgument<>>>();

  // Check shape first, then every element of the row-major result buffer.
  ASSERT_EQ(output.getDimensions().size(), (size_t)3);
  ASSERT_EQ(output.getDimensions().at(0), 4);
  ASSERT_EQ(output.getDimensions().at(1), 4);
  ASSERT_EQ(output.getDimensions().at(2), 3);
  ASSERT_EXPECTED_VALUE(output.getNumElements(), 48);

  for (size_t i = 0; i < 4; i++) {
    for (size_t j = 0; j < 4; j++) {
      for (size_t k = 0; k < 3; k++) {
        EXPECT_EQ(output.getValue()[(i * 4 * 3) + (j * 3) + k],
                  expected[i][j][k])
            << ", at pos(" << i << "," << j << "," << k << ")";
      }
    }
  }
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, concat_3D_axis_1) {
  namespace concretelang = mlir::concretelang;

  // JIT-compile a concat of two 2x4x3 encrypted tensors along the middle
  // dimension (axis 1).
  concretelang::JitCompilerEngine::Lambda compiled = checkedJit(R"XXX(
func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x8x3x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 1 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x8x3x!FHE.eint<7>>
  return %0 : tensor<2x8x3x!FHE.eint<7>>
}
)XXX");

  const uint8_t lhs[2][4][3]{
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
  };
  const uint8_t rhs[2][4][3]{
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
  };

  // Within each outer slice, the rhs rows follow the lhs rows.
  const uint8_t expected[2][8][3]{
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
  };

  llvm::ArrayRef<uint8_t> lhsView((const uint8_t *)lhs, 2 * 4 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      lhsArg(lhsView, {2, 4, 3});

  llvm::ArrayRef<uint8_t> rhsView((const uint8_t *)rhs, 2 * 4 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      rhsArg(rhsView, {2, 4, 3});

  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> invocation =
      compiled.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
          {&lhsArg, &rhsArg});
  ASSERT_EXPECTED_SUCCESS(invocation);

  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>>
      &output = (*invocation)
                    ->cast<concretelang::TensorLambdaArgument<
                        concretelang::IntLambdaArgument<>>>();

  // Check shape first, then every element of the row-major result buffer.
  ASSERT_EQ(output.getDimensions().size(), (size_t)3);
  ASSERT_EQ(output.getDimensions().at(0), 2);
  ASSERT_EQ(output.getDimensions().at(1), 8);
  ASSERT_EQ(output.getDimensions().at(2), 3);
  ASSERT_EXPECTED_VALUE(output.getNumElements(), 48);

  for (size_t i = 0; i < 2; i++) {
    for (size_t j = 0; j < 8; j++) {
      for (size_t k = 0; k < 3; k++) {
        EXPECT_EQ(output.getValue()[(i * 8 * 3) + (j * 3) + k],
                  expected[i][j][k])
            << ", at pos(" << i << "," << j << "," << k << ")";
      }
    }
  }
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, concat_3D_axis_2) {
  namespace concretelang = mlir::concretelang;

  // JIT-compile a concat of two 2x4x3 encrypted tensors along the innermost
  // dimension (axis 2).
  concretelang::JitCompilerEngine::Lambda compiled = checkedJit(R"XXX(
func @main(%x: tensor<2x4x3x!FHE.eint<7>>, %y: tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x4x6x!FHE.eint<7>> {
  %0 = "FHELinalg.concat"(%x, %y) { axis = 2 } : (tensor<2x4x3x!FHE.eint<7>>, tensor<2x4x3x!FHE.eint<7>>) -> tensor<2x4x6x!FHE.eint<7>>
  return %0 : tensor<2x4x6x!FHE.eint<7>>
}
)XXX");

  const uint8_t lhs[2][4][3]{
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
  };
  const uint8_t rhs[2][4][3]{
      {
          {0, 1, 2},
          {3, 4, 5},
          {6, 7, 8},
          {9, 0, 1},
      },
      {
          {2, 3, 4},
          {5, 6, 7},
          {8, 9, 0},
          {1, 2, 3},
      },
  };

  // Each innermost row of the result is the lhs row followed by the rhs row.
  const uint8_t expected[2][4][6]{
      {
          {0, 1, 2, 0, 1, 2},
          {3, 4, 5, 3, 4, 5},
          {6, 7, 8, 6, 7, 8},
          {9, 0, 1, 9, 0, 1},
      },
      {
          {2, 3, 4, 2, 3, 4},
          {5, 6, 7, 5, 6, 7},
          {8, 9, 0, 8, 9, 0},
          {1, 2, 3, 1, 2, 3},
      },
  };

  llvm::ArrayRef<uint8_t> lhsView((const uint8_t *)lhs, 2 * 4 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      lhsArg(lhsView, {2, 4, 3});

  llvm::ArrayRef<uint8_t> rhsView((const uint8_t *)rhs, 2 * 4 * 3);
  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
      rhsArg(rhsView, {2, 4, 3});

  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> invocation =
      compiled.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
          {&lhsArg, &rhsArg});
  ASSERT_EXPECTED_SUCCESS(invocation);

  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<>>
      &output = (*invocation)
                    ->cast<concretelang::TensorLambdaArgument<
                        concretelang::IntLambdaArgument<>>>();

  // Check shape first, then every element of the row-major result buffer.
  ASSERT_EQ(output.getDimensions().size(), (size_t)3);
  ASSERT_EQ(output.getDimensions().at(0), 2);
  ASSERT_EQ(output.getDimensions().at(1), 4);
  ASSERT_EQ(output.getDimensions().at(2), 6);
  ASSERT_EXPECTED_VALUE(output.getNumElements(), 48);

  for (size_t i = 0; i < 2; i++) {
    for (size_t j = 0; j < 4; j++) {
      for (size_t k = 0; k < 6; k++) {
        EXPECT_EQ(output.getValue()[(i * 4 * 6) + (j * 6) + k],
                  expected[i][j][k])
            << ", at pos(" << i << "," << j << "," << k << ")";
      }
    }
  }
}
|
||||
|
||||
// Value-parameterized fixture whose parameter is a vector of int64_t —
// presumably the tile sizes for the tiled matmul tests; the
// INSTANTIATE_TEST_SUITE_P that supplies the values is outside this view,
// confirm there.
class TiledMatMulParametric
    : public ::testing::TestWithParam<std::vector<int64_t>> {};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user