Files
concrete/compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp
Quentin Bourgerie 8cd3a3a599 feat(compiler): First draft to support FHE.eint up to 16bits
For now, only levelled ops with user-provided parameters work (take a look at the tests).

Done:
- Add parameters to the fhe parameters to support CRT-based large integers
- Add command line options and test options to allow the user to give those new parameters
- Update the dialects and pipeline to handle new fhe parameters for CRT-based large integers
- Update the client parameters and the client library to handle the CRT-based large integers

Todo:
- Plug the optimizer to compute the CRT-based large integer parameters
- Plug the pbs for the CRT-based large integer
2022-08-12 16:35:11 +02:00

562 lines
23 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"
namespace arith = mlir::arith;
namespace tensor = mlir::tensor;
namespace bufferization = mlir::bufferization;
namespace scf = mlir::scf;
namespace BConcrete = mlir::concretelang::BConcrete;
namespace crt = concretelang::clientlib::crt;
namespace {
// Symbol name of the external function that is forward-declared and called
// when a plaintext has to be CRT-encoded at runtime (i.e. when its value is
// not a compile-time constant). The name never changes, so it is a constant.
constexpr char encode_crt[] = "encode_crt";
// This template rewrite pattern transforms any instance of the unary
// `BConcreteCRTOp` operator into a loop that applies `BConcreteOp` to each
// block of the ciphertext. `BConcreteCRTOp` must expose its operand through a
// `ciphertext()` accessor.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0) {crtDecomposition = [...]}
//   : (tensor<nbBlocksxlweSizexi64>) -> (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
//   (tensor<nbBlocksxlweSizexi64>) {
//   %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64>
//   %tmp = "BConcreteOp"(%blockArg)
//     : (tensor<lweSizexi64>) -> (tensor<lweSizexi64>)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
//   scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
//
// The pattern fails to match (instead of asserting) when the result is not a
// statically shaped rank-2 tensor.
template <typename BConcreteCRTOp, typename BConcreteOp>
struct BConcreteCRTUnaryOpPattern
    : public mlir::OpRewritePattern<BConcreteCRTOp> {
  BConcreteCRTUnaryOpPattern(mlir::MLIRContext *context,
                             mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(BConcreteCRTOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // The conversion to the non-dependent `mlir::Type` keeps `.dyn_cast<>`
    // usable without the `template` disambiguator.
    auto resultTy = static_cast<mlir::Type>(op.getResult().getType())
                        .dyn_cast<mlir::RankedTensorType>();
    // Validate before creating any operation: the loop bound and the slice
    // sizes are emitted as constants and the result buffer is allocated
    // without dynamic sizes, so both dimensions must be statically known.
    if (!resultTy || resultTy.getRank() != 2 || !resultTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "expected a statically shaped rank-2 tensor result "
              "(nbBlocks x lweSize)");
    }
    auto loc = op.getLoc();
    auto shape = resultTy.getShape();

    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);

    // %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        loc, resultTy, mlir::ValueRange{});

    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    //   (tensor<nbBlocksxlweSizexi64>) {
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // [%i, 0]
          mlir::SmallVector<mlir::OpFoldResult> offsets{
              i, builder.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector<mlir::OpFoldResult> sizes{
              builder.getI64IntegerAttr(1),
              builder.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector<mlir::OpFoldResult> strides{
              builder.getI64IntegerAttr(1), builder.getI64IntegerAttr(1)};

          // Type of a single block: tensor<lweSizexi64>
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());

          // %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
          //   : tensor<lweSizexi64>
          auto blockArg = builder.create<tensor::ExtractSliceOp>(
              loc, blockTy, op.ciphertext(), offsets, sizes, strides);

          // %tmp = "BConcreteOp"(%blockArg)
          //   : (tensor<lweSizexi64>) -> (tensor<lweSizexi64>)
          auto tmp = builder.create<BConcreteOp>(loc, blockTy, blockArg);

          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize]
          //   [1, 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
          auto res = builder.create<tensor::InsertSliceOp>(
              loc, tmp, iterArgs[0], offsets, sizes, strides);

          // scf.yield %res : tensor<nbBlocksxlweSizexi64>
          builder.create<scf::YieldOp>(loc, res.getResult());
        });
    return mlir::success();
  }
};
// This template rewrite pattern transforms any instance of the binary
// `BConcreteCRTOp` operator into a loop that applies `BConcreteOp` to each
// pair of matching blocks of the two ciphertexts. `BConcreteCRTOp` must expose
// its operands through `lhs()` and `rhs()` accessors.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
//   : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
//     (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
//   (tensor<nbBlocksxlweSizexi64>) {
//   %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64>
//   %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64>
//   %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
//     : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
//       (tensor<lweSizexi64>)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
//   scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
//
// The pattern fails to match (instead of asserting) when the result is not a
// statically shaped rank-2 tensor.
template <typename BConcreteCRTOp, typename BConcreteOp>
struct BConcreteCRTBinaryOpPattern
    : public mlir::OpRewritePattern<BConcreteCRTOp> {
  BConcreteCRTBinaryOpPattern(mlir::MLIRContext *context,
                              mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(BConcreteCRTOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // The conversion to the non-dependent `mlir::Type` keeps `.dyn_cast<>`
    // usable without the `template` disambiguator.
    auto resultTy = static_cast<mlir::Type>(op.getResult().getType())
                        .dyn_cast<mlir::RankedTensorType>();
    // Validate before creating any operation: the loop bound and the slice
    // sizes are emitted as constants and the result buffer is allocated
    // without dynamic sizes, so both dimensions must be statically known.
    if (!resultTy || resultTy.getRank() != 2 || !resultTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "expected a statically shaped rank-2 tensor result "
              "(nbBlocks x lweSize)");
    }
    auto loc = op.getLoc();
    auto shape = resultTy.getShape();

    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);

    // %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        loc, resultTy, mlir::ValueRange{});

    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    //   (tensor<nbBlocksxlweSizexi64>) {
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // [%i, 0]
          mlir::SmallVector<mlir::OpFoldResult> offsets{
              i, builder.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector<mlir::OpFoldResult> sizes{
              builder.getI64IntegerAttr(1),
              builder.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector<mlir::OpFoldResult> strides{
              builder.getI64IntegerAttr(1), builder.getI64IntegerAttr(1)};

          // Type of a single block: tensor<lweSizexi64>
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());

          // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize]
          //   [1, 1] : tensor<lweSizexi64>
          auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
              loc, blockTy, op.lhs(), offsets, sizes, strides);

          // %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize]
          //   [1, 1] : tensor<lweSizexi64>
          auto blockArg1 = builder.create<tensor::ExtractSliceOp>(
              loc, blockTy, op.rhs(), offsets, sizes, strides);

          // %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
          //   : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
          //     (tensor<lweSizexi64>)
          auto tmp =
              builder.create<BConcreteOp>(loc, blockTy, blockArg0, blockArg1);

          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize]
          //   [1, 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
          auto res = builder.create<tensor::InsertSliceOp>(
              loc, tmp, iterArgs[0], offsets, sizes, strides);

          // scf.yield %res : tensor<nbBlocksxlweSizexi64>
          builder.create<scf::YieldOp>(loc, res.getResult());
        });
    return mlir::success();
  }
};
// This rewrite pattern transforms `BConcrete::AddPlaintextCRTLweBufferOp`
// into a loop of `BConcrete::AddPlaintextLweBufferOp`, one per block, where
// block `i` receives the plaintext encoded for the i-th modulus of the crt
// decomposition.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
//   : (tensor<nbBlocksxlweSizexi64>, iN) -> (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes (when %x is only known at runtime):
//
// ```mlir
// // Build the decomposition of the plaintext
// %x_ext = arith.extsi %x : iN to i64
// %prod = arith.constant (d0 * ... * dn) : i64
// %d0 = arith.constant d0 : i64
// %x0 = call @encode_crt(%x_ext, %d0, %prod) : (i64, i64, i64) -> i64
// ...
// %dn = arith.constant dn : i64
// %xn = call @encode_crt(%x_ext, %dn, %prod) : (i64, i64, i64) -> i64
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
// // Loop on blocks
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
//   (tensor<nbBlocksxlweSizexi64>) {
//   %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64>
//   %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
//   %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
//     : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
//   scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
//
// When %x is defined by a constant integer operation, each %xi is computed at
// compile time with `crt::encode` and materialized as an `arith.constant`
// instead of a call to `@encode_crt`.
//
// The pattern fails to match (instead of asserting) when the result is not a
// statically shaped rank-2 tensor, or when the crt decomposition is empty or
// provides fewer moduli than there are blocks.
struct AddPlaintextCRTLweBufferOpPattern
    : public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp> {
  AddPlaintextCRTLweBufferOpPattern(mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp>(context,
                                                                      benefit) {
  }

  mlir::LogicalResult
  matchAndRewrite(BConcrete::AddPlaintextCRTLweBufferOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto resultTy = static_cast<mlir::Type>(op.getResult().getType())
                        .dyn_cast<mlir::RankedTensorType>();
    // Validate before creating any operation: the loop bound and the slice
    // sizes are emitted as constants and the result buffer is allocated
    // without dynamic sizes, so both dimensions must be statically known.
    if (!resultTy || resultTy.getRank() != 2 || !resultTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "expected a statically shaped rank-2 tensor result "
              "(nbBlocks x lweSize)");
    }
    auto loc = op.getLoc();
    auto shape = resultTy.getShape();
    auto rhs = op.rhs();

    // Read the moduli once; they are needed for the product and for the
    // per-block encoding below.
    mlir::SmallVector<uint64_t, 5> moduli;
    uint64_t moduliProduct = 1;
    for (mlir::Attribute di : op.crtDecomposition()) {
      moduli.push_back(di.cast<mlir::IntegerAttr>().getValue().getZExtValue());
      moduliProduct *= moduli.back();
    }
    // The loop below reads %x_decomp[%i] for every block index, so there must
    // be at least one encoded plaintext per block (and at least one element
    // to build the tensor from).
    if (moduli.empty() || moduli.size() < static_cast<size_t>(shape[0])) {
      return rewriter.notifyMatchFailure(
          op, "crt decomposition must provide one modulus per block");
    }

    mlir::SmallVector<mlir::Value, 5> plaintextElements;
    if (auto cst =
            mlir::dyn_cast_or_null<arith::ConstantIntOp>(rhs.getDefiningOp())) {
      // constant value, encode at compile time
      auto apCst = cst.getValue().cast<mlir::IntegerAttr>().getValue();
      auto value = apCst.getSExtValue();
      for (uint64_t modulus : moduli) {
        auto encoded = crt::encode(value, modulus, moduliProduct);
        plaintextElements.push_back(
            rewriter.create<arith::ConstantIntOp>(loc, encoded, 64));
      }
    } else {
      // dynamic value, encode at runtime by calling
      // @encode_crt(value, modulus, moduliProduct) for each modulus
      if (insertForwardDeclaration(
              op, rewriter, encode_crt,
              mlir::FunctionType::get(rewriter.getContext(),
                                      {rewriter.getI64Type(),
                                       rewriter.getI64Type(),
                                       rewriter.getI64Type()},
                                      {rewriter.getI64Type()}))
              .failed()) {
        return mlir::failure();
      }
      // NOTE(review): presumably the plaintext is always narrower than i64
      // here, as a sign extension to i64 is emitted unconditionally -- confirm.
      auto extOp =
          rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), rhs);
      auto moduliProductOp =
          rewriter.create<arith::ConstantIntOp>(loc, moduliProduct, 64);
      for (uint64_t modulus : moduli) {
        auto modulusOp =
            rewriter.create<arith::ConstantIntOp>(loc, modulus, 64);
        plaintextElements.push_back(
            rewriter
                .create<mlir::func::CallOp>(
                    loc, encode_crt, mlir::TypeRange{rewriter.getI64Type()},
                    mlir::ValueRange{extOp, modulusOp, moduliProductOp})
                .getResult(0));
      }
    }

    // %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
    auto x_decomp =
        rewriter.create<tensor::FromElementsOp>(loc, plaintextElements);

    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);

    // %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        loc, resultTy, mlir::ValueRange{});

    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    //   (tensor<nbBlocksxlweSizexi64>) {
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // [%i, 0]
          mlir::SmallVector<mlir::OpFoldResult> offsets{
              i, builder.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector<mlir::OpFoldResult> sizes{
              builder.getI64IntegerAttr(1),
              builder.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector<mlir::OpFoldResult> strides{
              builder.getI64IntegerAttr(1), builder.getI64IntegerAttr(1)};

          // Type of a single block: tensor<lweSizexi64>
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());

          // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize]
          //   [1, 1] : tensor<lweSizexi64>
          auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
              loc, blockTy, op.lhs(), offsets, sizes, strides);

          // %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
          auto blockArg1 = builder.create<tensor::ExtractOp>(loc, x_decomp, i);

          // %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
          //   : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
          auto tmp = builder.create<BConcrete::AddPlaintextLweBufferOp>(
              loc, blockTy, blockArg0, blockArg1);

          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize]
          //   [1, 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
          auto res = builder.create<tensor::InsertSliceOp>(
              loc, tmp, iterArgs[0], offsets, sizes, strides);

          // scf.yield %res : tensor<nbBlocksxlweSizexi64>
          builder.create<scf::YieldOp>(loc, res.getResult());
        });
    return mlir::success();
  }
};
// This rewrite pattern transforms `BConcrete::MulCleartextCRTLweBufferOp`
// into a loop of `BConcrete::MulCleartextLweBufferOp`, one per block. The
// cleartext is not decomposed: it is zero-extended to i64 once and the same
// value multiplies every block.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
//   : (tensor<nbBlocksxlweSizexi64>, iN) -> (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
// %x_ext = arith.extui %x : iN to i64
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
//   (tensor<nbBlocksxlweSizexi64>) {
//   %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64>
//   %tmp = "BConcreteOp"(%blockArg0, %x_ext)
//     : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
//   scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
//
// The pattern fails to match (instead of asserting) when the result is not a
// statically shaped rank-2 tensor.
struct MulCleartextCRTLweBufferOpPattern
    : public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp> {
  MulCleartextCRTLweBufferOpPattern(mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp>(context,
                                                                      benefit) {
  }

  mlir::LogicalResult
  matchAndRewrite(BConcrete::MulCleartextCRTLweBufferOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto resultTy = static_cast<mlir::Type>(op.getResult().getType())
                        .dyn_cast<mlir::RankedTensorType>();
    // Validate before creating any operation: the loop bound and the slice
    // sizes are emitted as constants and the result buffer is allocated
    // without dynamic sizes, so both dimensions must be statically known.
    if (!resultTy || resultTy.getRank() != 2 || !resultTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "expected a statically shaped rank-2 tensor result "
              "(nbBlocks x lweSize)");
    }
    auto loc = op.getLoc();
    auto shape = resultTy.getShape();

    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);

    // %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        loc, resultTy, mlir::ValueRange{});

    // %x_ext = arith.extui %x : iN to i64
    // NOTE(review): the cleartext is zero-extended here while
    // AddPlaintextCRTLweBufferOpPattern sign-extends its plaintext -- confirm
    // that the difference is intended.
    auto rhs =
        rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), op.rhs());

    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    //   (tensor<nbBlocksxlweSizexi64>) {
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // [%i, 0]
          mlir::SmallVector<mlir::OpFoldResult> offsets{
              i, builder.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector<mlir::OpFoldResult> sizes{
              builder.getI64IntegerAttr(1),
              builder.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector<mlir::OpFoldResult> strides{
              builder.getI64IntegerAttr(1), builder.getI64IntegerAttr(1)};

          // Type of a single block: tensor<lweSizexi64>
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());

          // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize]
          //   [1, 1] : tensor<lweSizexi64>
          auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
              loc, blockTy, op.lhs(), offsets, sizes, strides);

          // %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x_ext)
          //   : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
          auto tmp = builder.create<BConcrete::MulCleartextLweBufferOp>(
              loc, blockTy, blockArg0, rhs);

          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize]
          //   [1, 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
          auto res = builder.create<tensor::InsertSliceOp>(
              loc, tmp, iterArgs[0], offsets, sizes, strides);

          // scf.yield %res : tensor<nbBlocksxlweSizexi64>
          builder.create<scf::YieldOp>(loc, res.getResult());
        });
    return mlir::success();
  }
};
// Pass that rewrites every CRT-level BConcrete operation into a loop of the
// corresponding per-block BConcrete operation, using a partial dialect
// conversion in which the CRT operations are marked illegal.
struct EliminateCRTOpsPass : public EliminateCRTOpsBase<EliminateCRTOpsPass> {
  // Registers one rewrite pattern per CRT operation, then runs the partial
  // conversion on the operation this pass is scheduled on. The pass is
  // signalled as failed if the conversion does not succeed.
  void runOnOperation() final {
    mlir::MLIRContext *context = &getContext();
    auto root = getOperation();

    mlir::ConversionTarget conversionTarget(*context);
    mlir::RewritePatternSet rewritePatterns(context);

    // add_crt_lwe_buffers
    using AddCRTPattern =
        BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweBuffersOp,
                                    BConcrete::AddLweBuffersOp>;
    conversionTarget.addIllegalOp<BConcrete::AddCRTLweBuffersOp>();
    rewritePatterns.add<AddCRTPattern>(context);

    // add_plaintext_crt_lwe_buffers
    conversionTarget.addIllegalOp<BConcrete::AddPlaintextCRTLweBufferOp>();
    rewritePatterns.add<AddPlaintextCRTLweBufferOpPattern>(context);

    // mul_cleartext_crt_lwe_buffer
    conversionTarget.addIllegalOp<BConcrete::MulCleartextCRTLweBufferOp>();
    rewritePatterns.add<MulCleartextCRTLweBufferOpPattern>(context);

    // negate_crt_lwe_buffer
    using NegateCRTPattern =
        BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweBufferOp,
                                   BConcrete::NegateLweBufferOp>;
    conversionTarget.addIllegalOp<BConcrete::NegateCRTLweBufferOp>();
    rewritePatterns.add<NegateCRTPattern>(context);

    // These dialects are the ones the patterns emit operations from when
    // transforming crt ops to bconcrete ops
    conversionTarget
        .addLegalDialect<arith::ArithmeticDialect, tensor::TensorDialect,
                         scf::SCFDialect, bufferization::BufferizationDialect,
                         mlir::func::FuncDialect, BConcrete::BConcreteDialect>();

    // Apply the conversion
    if (mlir::failed(mlir::applyPartialConversion(
            root, conversionTarget, std::move(rewritePatterns)))) {
      signalPassFailure();
    }
  }
};
} // namespace
namespace mlir {
namespace concretelang {
// Factory returning a new instance of the pass that eliminates the CRT
// operations of the BConcrete dialect; the pass operates on `func::FuncOp`.
std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<EliminateCRTOpsPass>();
  return pass;
}
} // namespace concretelang
} // namespace mlir