Files
concrete/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp
youben11 1e07733257 fix(lower-LUT): keep output size of KS during lowering
output size of keyswiting wasn't set properly. As this information must
come from the selected parameters, it should goes down from the MidLFHE
to the appropriate call to ciphertext allocation
2021-09-09 20:35:28 +02:00

727 lines
30 KiB
C++

#include "mlir//IR/BuiltinTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
class LowLFHEToConcreteCAPITypeConverter : public mlir::TypeConverter {
public:
LowLFHEToConcreteCAPITypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](mlir::zamalang::LowLFHE::PlaintextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
addConversion([&](mlir::zamalang::LowLFHE::CleartextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
}
};
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
mlir::PatternRewriter &rewriter,
llvm::StringRef funcName,
mlir::FunctionType funcType) {
// Looking for the `funcName` Operation
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
mlir::SymbolTable::lookupSymbolIn(module, funcName));
if (!opFunc) {
// Insert the forward declaration of the funcName
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
funcType);
opFunc.setPrivate();
} else {
// Check if the `funcName` is well a private function
if (!opFunc.isPrivate()) {
op->emitError() << "the function \"" << funcName
<< "\" conflicts with the concrete C API, please rename";
return mlir::failure();
}
}
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
->template hasTrait<mlir::OpTrait::FunctionLike>());
return mlir::success();
}
/// LowLFHEOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
/// replace with a call to `funcName`, the funcName should be an external
/// function that was linked later. It insert the forward declaration of the
/// private `funcName` if it not already in the symbol table.
/// The C signature of the function should be `void funcName(int *err, out,
/// arg0, arg1)`, the pattern rewrite:
/// ```
/// out = op(arg0, arg1)
/// ```
/// to
/// ```
/// err = constant 0 : i64
/// call_op(err, out, arg0, arg1);
/// ```
template <typename Op>
struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
mlir::StringRef funcName,
mlir::StringRef allocName,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
allocName(allocName) {}
mlir::LogicalResult
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
LowLFHEToConcreteCAPITypeConverter typeConverter;
auto errType = mlir::IndexType::get(rewriter.getContext());
// Insert forward declaration of the operator function
{
mlir::SmallVector<mlir::Type, 4> operands{errType,
op->getResultTypes().front()};
for (auto ty : op->getOperandTypes()) {
operands.push_back(typeConverter.convertType(ty));
}
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), operands, {});
if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {errType, rewriter.getIndexType()},
{op->getResultTypes().front()});
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
.failed()) {
return mlir::failure();
}
}
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
// Replace the operation with a call to the `funcName`
{
// Create the err value
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// Add the call to the allocation
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize()));
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, allocName, op.getType(), allocOperands);
// Add err and allocated value to operands
mlir::SmallVector<mlir::Value, 4> newOperands{errOp, alloc.getResult(0)};
for (auto operand : op->getOperands()) {
newOperands.push_back(operand);
}
rewriter.create<mlir::CallOp>(op.getLoc(), funcName, mlir::TypeRange{},
newOperands);
}
return mlir::success();
};
private:
std::string funcName;
std::string allocName;
};
struct LowLFHEZeroOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp> {
LowLFHEZeroOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::ZeroLWEOp op,
mlir::PatternRewriter &rewriter) const override {
auto allocName = "allocate_lwe_ciphertext_u64";
auto errType = mlir::IndexType::get(rewriter.getContext());
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {errType, rewriter.getIndexType()},
{op->getResultTypes().front()});
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
.failed()) {
return mlir::failure();
}
}
// Replace the operation with a call to the `funcName`
{
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
// Create the err value
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// Add the call to the allocation
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize()));
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, allocName, op.getType(), allocOperands);
}
return mlir::success();
};
};
struct LowLFHEEncodeIntOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::EncodeIntOp> {
LowLFHEEncodeIntOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::EncodeIntOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::EncodeIntOp op,
mlir::PatternRewriter &rewriter) const override {
{
mlir::Value castedInt = rewriter.create<mlir::ZeroExtendIOp>(
op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front());
mlir::Value constantShiftOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP()));
mlir::Type resultType = rewriter.getIntegerType(64);
rewriter.replaceOpWithNewOp<mlir::ShiftLeftOp>(op, resultType, castedInt,
constantShiftOp);
}
return mlir::success();
};
};
struct LowLFHEIntToCleartextOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::IntToCleartextOp> {
LowLFHEIntToCleartextOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::IntToCleartextOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::IntToCleartextOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Value castedInt = rewriter.replaceOpWithNewOp<mlir::ZeroExtendIOp>(
op, rewriter.getIntegerType(64), op->getOperands().front());
return mlir::success();
};
};
struct GlweFromTableOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::GlweFromTable> {
GlweFromTableOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::GlweFromTable>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::GlweFromTable op,
mlir::PatternRewriter &rewriter) const override {
LowLFHEToConcreteCAPITypeConverter typeConverter;
auto errType = mlir::IndexType::get(rewriter.getContext());
// Insert forward declaration of the alloc_glwe function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::IntegerType::get(rewriter.getContext(), 32),
mlir::IntegerType::get(rewriter.getContext(), 32),
},
{mlir::zamalang::LowLFHE::GlweCiphertextType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType, mlir::IntegerType::get(rewriter.getContext(), 32)},
{mlir::zamalang::LowLFHE::PlaintextListType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the foregin_pt_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType,
// mlir::UnrankedTensorType::get(
// mlir::IntegerType::get(rewriter.getContext(), 64)),
op->getOperandTypes().front(),
mlir::IntegerType::get(rewriter.getContext(), 64)},
{mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
rewriter.getContext())});
if (insertForwardDeclaration(
op, rewriter, "runtime_foreign_plaintext_list_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the fill_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType,
mlir::zamalang::LowLFHE::PlaintextListType::get(
rewriter.getContext()),
mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
rewriter.getContext())},
{});
if (insertForwardDeclaration(
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_list_glwe function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType,
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
rewriter.getContext()),
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
rewriter.getContext()),
mlir::zamalang::LowLFHE::PlaintextListType::get(
rewriter.getContext())},
{});
if (insertForwardDeclaration(
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate two glwe to build accumulator
auto glweSizeOp =
rewriter.create<mlir::ConstantOp>(op.getLoc(), op->getAttr("k"));
auto polySizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), op->getAttr("polynomialSize"));
mlir::SmallVector<mlir::Value> allocGlweOperands{errOp, glweSizeOp,
polySizeOp};
// first accumulator would replace the op since it's the returned value
auto accumulatorOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_glwe_ciphertext_u64",
mlir::zamalang::LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
allocGlweOperands);
// second accumulator is just needed to build the actual accumulator
auto _accumulatorOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_glwe_ciphertext_u64",
mlir::zamalang::LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
allocGlweOperands);
// allocate plaintext list
mlir::SmallVector<mlir::Value> allocPlaintextListOperands{errOp,
polySizeOp};
auto plaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_plaintext_list_u64",
mlir::zamalang::LowLFHE::PlaintextListType::get(rewriter.getContext()),
allocPlaintextListOperands);
// create foreign plaintext
auto rankedTensorType =
op->getOperandTypes().front().cast<mlir::RankedTensorType>();
if (rankedTensorType.getRank() != 1) {
llvm::errs() << "table lookup must be of a single dimension";
return mlir::failure();
}
auto sizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(
mlir::IntegerType::get(rewriter.getContext(), 64),
rankedTensorType.getDimSize(0)));
mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
errOp, op->getOperand(0), sizeOp};
auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "runtime_foreign_plaintext_list_u64",
mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
rewriter.getContext()),
ForeignPlaintextListOperands);
// fill plaintext list
mlir::SmallVector<mlir::Value> FillPlaintextListOperands{
errOp, plaintextListOp.getResult(0),
foreignPlaintextListOp.getResult(0)};
rewriter.create<mlir::CallOp>(
op.getLoc(), "fill_plaintext_list_with_expansion_u64",
mlir::TypeRange({}), FillPlaintextListOperands);
// add plaintext list and glwe to build final accumulator for pbs
mlir::SmallVector<mlir::Value> AddPlaintextListGlweOperands{
errOp, accumulatorOp.getResult(0), _accumulatorOp.getResult(0),
plaintextListOp.getResult(0)};
rewriter.create<mlir::CallOp>(
op.getLoc(), "add_plaintext_list_glwe_ciphertext_u64",
mlir::TypeRange({}), AddPlaintextListGlweOperands);
return mlir::success();
};
};
struct LowLFHEBootstrapLweOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp> {
LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::BootstrapLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto errType = mlir::IndexType::get(rewriter.getContext());
auto lweOperandType = op->getOperandTypes().front();
// Insert forward declaration of the allocate_bsk_key function
// {
// auto funcType = mlir::FunctionType::get(
// rewriter.getContext(),
// {
// errType,
// // level
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // baselog
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // glwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // polynomial size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// },
// {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
// rewriter.getContext())});
// if (insertForwardDeclaration(op, rewriter,
// "allocate_lwe_bootstrap_key_u64",
// funcType)
// .failed()) {
// return mlir::failure();
// }
// }
// Insert forward declaration of the getBsk function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {},
{mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the allocate_lwe_ct function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::IntegerType::get(rewriter.getContext(), 32),
},
{lweOperandType});
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the bootstrap function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
rewriter.getContext()),
lweOperandType,
lweOperandType,
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
rewriter.getContext()),
},
{});
if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType)
.failed()) {
return mlir::failure();
}
}
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate the result lwe ciphertext
// TODO: use right value for output lwe size
// LweSize output_lwe_size = { (glwe_size._0 -1) * poly_size._0 + 1}
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("k").cast<mlir::IntegerAttr>().getInt()));
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_lwe_ciphertext_u64", lweOperandType, allocLweCtOperands);
// allocate bsk
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
// auto glweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32), -1));
// auto polySizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("polynomialSize").cast<mlir::IntegerAttr>().getInt()));
// mlir::SmallVector<mlir::Value> allocBskOperands{
// errOp, decompLevelCountOp, decompBaseLogOp,
// glweSizeOp, lweSizeOp, polySizeOp};
// auto allocateBskOp = rewriter.create<mlir::CallOp>(
// op.getLoc(), "allocate_lwe_bootstrap_key_u64",
// mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
// rewriter.getContext()),
// allocBskOperands);
// get bsk
mlir::SmallVector<mlir::Value> getBskOperands{};
auto getBskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "getGlobalBootstrapKey",
mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
rewriter.getContext()),
getBskOperands);
// bootstrap
mlir::SmallVector<mlir::Value> bootstrapOperands{
errOp, getBskOp.getResult(0), allocateLweCtOp.getResult(0),
op->getOperand(0), op->getOperand(1)};
rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
mlir::TypeRange({}), bootstrapOperands);
return mlir::success();
};
};
struct LowLFHEKeySwitchLweOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp> {
LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::KeySwitchLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto errType = mlir::IndexType::get(rewriter.getContext());
auto lweOperandType = op->getOperandTypes().front();
// Insert forward declaration of the allocate_ksk_key function
// {
// auto funcType = mlir::FunctionType::get(
// rewriter.getContext(),
// {
// errType,
// // level
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // baselog
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // input lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // output lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// },
// {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
// rewriter.getContext())});
// if (insertForwardDeclaration(op, rewriter,
// "allocate_lwe_keyswitch_key_u64",
// funcType)
// .failed()) {
// return mlir::failure();
// }
// }
// Insert forward declaration of the getKsk function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {},
{mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the allocate_lwe_ct function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::IntegerType::get(rewriter.getContext(), 32),
},
{lweOperandType});
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// TODO: build the right type here
auto lweOutputType = lweOperandType;
// Insert forward declaration of the keyswitch function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
// ksk
mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
rewriter.getContext()),
// output ct
lweOutputType,
// input ct
lweOperandType,
},
{});
if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType)
.failed()) {
return mlir::failure();
}
}
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate the result lwe ciphertext
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands);
// allocate ksk
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
// auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
// auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
// mlir::SmallVector<mlir::Value> allockskOperands{
// errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
// outputLweSizeOp};
// auto allocateKskOp = rewriter.create<mlir::CallOp>(
// op.getLoc(), "allocate_lwe_keyswitch_key_u64",
// mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
// rewriter.getContext()),
// allockskOperands);
// get ksk
mlir::SmallVector<mlir::Value> getkskOperands{};
auto getKskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "getGlobalKeyswitchKey",
mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
rewriter.getContext()),
getkskOperands);
// keyswitch
mlir::SmallVector<mlir::Value> keyswitchOperands{
errOp, getKskOp.getResult(0), allocateLweCtOp.getResult(0),
op->getOperand(0)};
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
mlir::TypeRange({}), keyswitchOperands);
return mlir::success();
};
};
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
/// operators to the corresponding function call to the `Concrete C API`.
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::zamalang::LowLFHE::AddLweCiphertextsOp>>(
patterns.getContext(), "add_lwe_ciphertexts_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::zamalang::LowLFHE::AddPlaintextLweCiphertextOp>>(
patterns.getContext(), "add_plaintext_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::zamalang::LowLFHE::MulCleartextLweCiphertextOp>>(
patterns.getContext(), "mul_cleartext_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::zamalang::LowLFHE::NegateLweCiphertextOp>>(
patterns.getContext(), "negate_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEEncodeIntOpPattern>(patterns.getContext());
patterns.add<LowLFHEIntToCleartextOpPattern>(patterns.getContext());
patterns.add<LowLFHEZeroOpPattern>(patterns.getContext());
patterns.add<GlweFromTableOpPattern>(patterns.getContext());
patterns.add<LowLFHEKeySwitchLweOpPattern>(patterns.getContext());
patterns.add<LowLFHEBootstrapLweOpPattern>(patterns.getContext());
}
namespace {
struct LowLFHEToConcreteCAPIPass
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
void runOnOperation() final;
};
} // namespace
void LowLFHEToConcreteCAPIPass::runOnOperation() {
// Setup the conversion target.
mlir::ConversionTarget target(getContext());
target.addIllegalDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
mlir::memref::MemRefDialect>();
// Setup rewrite patterns
mlir::RewritePatternSet patterns(&getContext());
populateLowLFHEToConcreteCAPICall(patterns);
// Apply the conversion
mlir::ModuleOp op = getOperation();
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
namespace mlir {
namespace zamalang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass() {
return std::make_unique<LowLFHEToConcreteCAPIPass>();
}
} // namespace zamalang
} // namespace mlir