feat: implement basic sum operation

Umut
2022-02-01 11:17:43 +03:00
parent a7c63a5494
commit 4203e86998
10 changed files with 457 additions and 2 deletions

View File

@@ -516,4 +516,29 @@ def ZeroOp : FHELinalg_Op<"zero", []> {
  let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$aggregate);
}

def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
  let summary = "Returns the sum of all elements of a tensor of encrypted integers.";

  let description = [{
    Computes the sum of all elements of a tensor of encrypted integers.

    Example:
    ```mlir
    // Returns the sum of all elements of `%a0`
    "FHELinalg.sum"(%a0) : (tensor<3x3x!FHE.eint<4>>) -> !FHE.eint<4>
    //
    //       ( [1,2,3] )
    //   sum ( [4,5,6] ) = 45
    //       ( [7,8,9] )
    //
    ```
  }];

  let arguments = (ins
    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor
  );

  let results = (outs EncryptedIntegerType:$out);
}
#endif

View File

@@ -961,6 +961,128 @@ struct FHELinalgZeroToLinalgGenerate
};
};
// This rewrite pattern transforms any instance of the operator
// `FHELinalg.sum` into an instance of `linalg.generic` with an
// appropriate region accumulating all elements of the input tensor.
//
// Example:
//
//   %result = "FHELinalg.sum"(%input) :
//     (tensor<d0xd1x...xdNx!FHE.eint<p>>) -> !FHE.eint<p>
//
// becomes:
//
//   #map0 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
//   #map1 = affine_map<(i0, i1, ..., iN) -> (0)>
//
//   %zero = "FHE.zero"() : () -> !FHE.eint<p>
//   %accumulator = tensor.from_elements %zero : tensor<1x!FHE.eint<p>>
//
//   %accumulation = linalg.generic
//     { indexing_maps = [#map0, #map1],
//       iterator_types = ["reduction", "reduction", ..., "reduction"] }
//     ins(%input : tensor<d0xd1x...xdNx!FHE.eint<p>>)
//     outs(%accumulator : tensor<1x!FHE.eint<p>>)
//   {
//     ^bb0(%a: !FHE.eint<p>, %b: !FHE.eint<p>):
//       %c = "FHE.add_eint"(%a, %b) :
//         (!FHE.eint<p>, !FHE.eint<p>) -> !FHE.eint<p>
//       linalg.yield %c : !FHE.eint<p>
//   } -> tensor<1x!FHE.eint<p>>
//
//   %index = arith.constant 0 : index
//   %result = tensor.extract %accumulation[%index] : tensor<1x!FHE.eint<p>>
//
struct SumToLinalgGeneric
    : public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::SumOp> {

  SumToLinalgGeneric(::mlir::MLIRContext *context)
      : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::SumOp>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(::mlir::concretelang::FHELinalg::SumOp sumOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    namespace arith = mlir::arith;
    namespace linalg = mlir::linalg;
    namespace tensor = mlir::tensor;
    namespace FHE = mlir::concretelang::FHE;

    mlir::Location location = sumOp.getLoc();
    mlir::Value input = sumOp.getOperand();
    mlir::Value output = sumOp.getResult();

    auto inputType = input.getType().dyn_cast_or_null<mlir::TensorType>();
    assert(inputType != nullptr);

    llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
    size_t inputDimensions = inputShape.size();
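
    // The sum of an empty tensor is zero, so if any dimension of the input
    // is 0, the op can be replaced with a fresh `FHE.zero` directly.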
    mlir::Value zero =
        rewriter.create<FHE::ZeroEintOp>(location, output.getType())
            .getResult();
    for (size_t i = 0; i < inputDimensions; i++) {
      if (inputShape[i] == 0) {
        rewriter.replaceOp(sumOp, {zero});
        return mlir::success();
      }
    }
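
    // Initialize a 1-element accumulator tensor to zero; it becomes the
    // `outs` operand of the `linalg.generic` below.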
    mlir::Value accumulator =
        rewriter.create<tensor::FromElementsOp>(location, zero).getResult();

    auto ins = llvm::SmallVector<mlir::Value, 1>{input};
    auto outs = llvm::SmallVector<mlir::Value, 1>{accumulator};
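
    // The input map is the identity over all input dimensions, the output
    // map sends every iteration to index 0 of the accumulator, and every
    // dimension is iterated as a reduction.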
    mlir::AffineMap inputMap = mlir::AffineMap::getMultiDimIdentityMap(
        inputDimensions, this->getContext());
    mlir::AffineMap outputMap = mlir::AffineMap::get(
        inputDimensions, 0, {rewriter.getAffineConstantExpr(0)},
        rewriter.getContext());
    auto maps = llvm::SmallVector<mlir::AffineMap, 2>{inputMap, outputMap};

    auto iteratorTypes = llvm::SmallVector<llvm::StringRef, 3>{};
    for (size_t i = 0; i < inputDimensions; i++) {
      iteratorTypes.push_back("reduction");
    }
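
    // The region adds the current input element to the accumulator.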
    auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
                             mlir::Location nestedLoc,
                             mlir::ValueRange blockArgs) {
      mlir::Value lhs = blockArgs[0];
      mlir::Value rhs = blockArgs[1];

      mlir::Value addition =
          nestedBuilder.create<FHE::AddEintOp>(location, lhs, rhs).getResult();
      nestedBuilder.create<linalg::YieldOp>(location, addition);
    };

    auto resultTypes = llvm::SmallVector<mlir::Type, 1>{accumulator.getType()};
    mlir::Value accumulation =
        rewriter
            .create<linalg::GenericOp>(location, resultTypes, ins, outs, maps,
                                       iteratorTypes, regionBuilder)
            .getResult(0);
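
    // Extract the single accumulated element as the scalar result.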
    mlir::Value index =
        rewriter.create<arith::ConstantIndexOp>(location, 0).getResult();
    auto indices = llvm::SmallVector<mlir::Value, 1>{index};
    mlir::Value result =
        rewriter.create<tensor::ExtractOp>(location, accumulation, indices)
            .getResult();

    rewriter.replaceOp(sumOp, {result});
    return mlir::success();
  };
};
namespace {
struct FHETensorOpsToLinalg
    : public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
@@ -1020,6 +1142,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
  patterns.insert<FHELinalgApplyMappedLookupTableToLinalgGeneric>(
      &getContext());
  patterns.insert<FHELinalgZeroToLinalgGenerate>(&getContext());
  patterns.insert<SumToLinalgGeneric>(&getContext());

  if (mlir::applyPartialConversion(function, target, std::move(patterns))
          .failed())

View File

@@ -837,6 +837,29 @@ static llvm::APInt getSqMANP(
  return operandMANPs[0]->getValue().getMANP().getValue();
}
static llvm::APInt getSqMANP(
    mlir::concretelang::FHELinalg::SumOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  auto type = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
  assert(type != nullptr);

  uint64_t numberOfElements = type.getNumElements();
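
  // The sum of an empty tensor lowers to a fresh `FHE.zero`, whose MANP is 1.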
  if (numberOfElements == 0) {
    return llvm::APInt{1, 1, false};
  }

  assert(operandMANPs.size() == 1 &&
         operandMANPs[0]->getValue().getMANP().hasValue() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().getValue();
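
  // Summing N encrypted values accumulates their noise, so the squared MANP
  // of the result is N times the squared MANP of the operand; ceilLog2(N + 1)
  // bits are enough to represent the multiplier N.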
  unsigned int multiplierBits = ceilLog2(numberOfElements + 1);
  auto multiplier = llvm::APInt{multiplierBits, numberOfElements, false};
  return APIntWidthExtendUMul(multiplier, operandMANP);
}
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
  using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
  MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -909,6 +932,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
                   mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(
                   op)) {
      norm2SqEquiv = llvm::APInt{1, 1, false};
    } else if (auto sumOp =
                   llvm::dyn_cast<mlir::concretelang::FHELinalg::SumOp>(op)) {
      norm2SqEquiv = getSqMANP(sumOp, operands);
    }
// Tensor Operators
// ExtractOp

View File

@@ -0,0 +1,21 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)>
// CHECK: func @sum_1D(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
// CHECK-NEXT: }
func @sum_1D(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
  return %0 : !FHE.eint<7>
}

View File

@@ -0,0 +1,21 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0)>
// CHECK: func @sum_2D(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
// CHECK-NEXT: }
func @sum_2D(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
  return %0 : !FHE.eint<7>
}

View File

@@ -0,0 +1,21 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0)>
// CHECK: func @sum_3D(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
// CHECK-NEXT: }
func @sum_3D(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
  return %0 : !FHE.eint<7>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
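// The sum of an empty tensor lowers directly to a fresh `FHE.zero`; no
// `linalg.generic` is emitted.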
// CHECK: func @sum_empty(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
// CHECK-NEXT: }
func @sum_empty(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
  return %0 : !FHE.eint<7>
}

View File

@@ -227,7 +227,7 @@ func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
// manp(add_eint(mul, acc)) = 64 + 1 = 65
// ceil(sqrt(65)) = 9
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -310,7 +310,7 @@ func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
// manp(add_eint(mul, acc)) = 64 + 1 = 65
// ceil(sqrt(65)) = 9
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -390,3 +390,25 @@ func @zero() -> tensor<8x!FHE.eint<2>>
  return %0 : tensor<8x!FHE.eint<2>>
}
// -----
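// The recorded MANP of a sum over N elements of MANP 1 is ceil(sqrt(N * 1^2)):
// N = 4 -> 2, N = 5 -> 3, N = 9 -> 3, N = 10 -> 4.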
func @sum() -> !FHE.eint<7> {
  %0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<7>>
  // CHECK: "FHELinalg.sum"(%0) {MANP = 2 : ui{{[0-9]+}}} : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
  %1 = "FHELinalg.sum"(%0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
  %2 = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<7>>
  // CHECK: "FHELinalg.sum"(%2) {MANP = 3 : ui{{[0-9]+}}} : (tensor<5x!FHE.eint<7>>) -> !FHE.eint<7>
  %3 = "FHELinalg.sum"(%2) : (tensor<5x!FHE.eint<7>>) -> !FHE.eint<7>
  %4 = "FHELinalg.zero"() : () -> tensor<9x!FHE.eint<7>>
  // CHECK: "FHELinalg.sum"(%4) {MANP = 3 : ui{{[0-9]+}}} : (tensor<9x!FHE.eint<7>>) -> !FHE.eint<7>
  %5 = "FHELinalg.sum"(%4) : (tensor<9x!FHE.eint<7>>) -> !FHE.eint<7>
  %6 = "FHELinalg.zero"() : () -> tensor<10x!FHE.eint<7>>
  // CHECK: "FHELinalg.sum"(%6) {MANP = 4 : ui{{[0-9]+}}} : (tensor<10x!FHE.eint<7>>) -> !FHE.eint<7>
  %7 = "FHELinalg.sum"(%6) : (tensor<10x!FHE.eint<7>>) -> !FHE.eint<7>
  return %7 : !FHE.eint<7>
}

View File

@@ -367,3 +367,43 @@ func @zero_2D() -> tensor<4x9x!FHE.eint<2>> {
%0 = "FHELinalg.zero"() : () -> tensor<4x9x!FHE.eint<2>>
return %0 : tensor<4x9x!FHE.eint<2>>
}
/////////////////////////////////////////////////
// FHELinalg.sum
/////////////////////////////////////////////////
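// These round-trip tests check that `FHELinalg.sum` parses and prints
// unchanged.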
// CHECK: func @sum_empty(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
// CHECK-NEXT: }
func @sum_empty(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
  return %0 : !FHE.eint<7>
}
// CHECK: func @sum_1D(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
// CHECK-NEXT: }
func @sum_1D(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
  return %0 : !FHE.eint<7>
}
// CHECK: func @sum_2D(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
// CHECK-NEXT: }
func @sum_2D(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
  return %0 : !FHE.eint<7>
}
// CHECK: func @sum_3D(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
// CHECK-NEXT: }
func @sum_3D(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
  return %0 : !FHE.eint<7>
}

View File

@@ -1579,6 +1579,152 @@ func @main() -> tensor<2x2x4x!FHE.eint<6>> {
}
}
///////////////////////////////////////////////////////////////////////////////
// FHELinalg sum //////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_FHELinalg, sum_empty) {
  using llvm::ArrayRef;
  using llvm::Expected;
  using mlir::concretelang::IntLambdaArgument;
  using mlir::concretelang::JitCompilerEngine;
  using mlir::concretelang::TensorLambdaArgument;

  JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%x) : (tensor<0x!FHE.eint<7>>) -> (!FHE.eint<7>)
  return %0 : !FHE.eint<7>
}
)XXX");
  const uint8_t expected = 0;

  ArrayRef<uint8_t> xRef(nullptr, (size_t)0);
  TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {0});

  Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
  ASSERT_EXPECTED_SUCCESS(result);
  ASSERT_EQ(*result, expected);
}
TEST(End2EndJit_FHELinalg, sum_1D) {
  using llvm::ArrayRef;
  using llvm::Expected;
  using mlir::concretelang::IntLambdaArgument;
  using mlir::concretelang::JitCompilerEngine;
  using mlir::concretelang::TensorLambdaArgument;

  JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%x) : (tensor<4x!FHE.eint<7>>) -> (!FHE.eint<7>)
  return %0 : !FHE.eint<7>
}
)XXX");

  const uint8_t x[4]{0, 1, 2, 3};
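  // 0 + 1 + 2 + 3 = 6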
  const uint8_t expected = 6;

  ArrayRef<uint8_t> xRef((const uint8_t *)x, 4);
  TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {4});

  Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
  ASSERT_EXPECTED_SUCCESS(result);
  ASSERT_EQ(*result, expected);
}
TEST(End2EndJit_FHELinalg, sum_2D) {
  using llvm::ArrayRef;
  using llvm::Expected;
  using mlir::concretelang::IntLambdaArgument;
  using mlir::concretelang::JitCompilerEngine;
  using mlir::concretelang::TensorLambdaArgument;

  JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%x) : (tensor<3x4x!FHE.eint<7>>) -> (!FHE.eint<7>)
  return %0 : !FHE.eint<7>
}
)XXX");

  const uint8_t x[3][4]{
      {0, 1, 2, 3},
      {4, 5, 6, 7},
      {8, 9, 0, 1},
  };
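  // Sum of all 12 entries: 6 + 22 + 18 = 46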
  const uint8_t expected = 46;

  ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4);
  TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {3, 4});

  Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
  ASSERT_EXPECTED_SUCCESS(result);
  ASSERT_EQ(*result, expected);
}
TEST(End2EndJit_FHELinalg, sum_3D) {
  using llvm::ArrayRef;
  using llvm::Expected;
  using mlir::concretelang::IntLambdaArgument;
  using mlir::concretelang::JitCompilerEngine;
  using mlir::concretelang::TensorLambdaArgument;

  JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
  %0 = "FHELinalg.sum"(%x) : (tensor<3x4x2x!FHE.eint<7>>) -> (!FHE.eint<7>)
  return %0 : !FHE.eint<7>
}
)XXX");

  const uint8_t x[3][4][2]{
      {
          {0, 1},
          {2, 3},
          {4, 5},
          {6, 7},
      },
      {
          {8, 9},
          {0, 1},
          {2, 3},
          {4, 5},
      },
      {
          {6, 7},
          {8, 9},
          {0, 1},
          {2, 3},
      },
  };
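  // Sum of all 24 entries: 28 + 32 + 36 = 96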
  const uint8_t expected = 96;

  ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2);
  TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {3, 4, 2});

  Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
  ASSERT_EXPECTED_SUCCESS(result);
  ASSERT_EQ(*result, expected);
}
class TiledMatMulParametric
: public ::testing::TestWithParam<std::vector<int64_t>> {};