// Mirrored from https://github.com/zama-ai/concrete.git
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RTOpConverter.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Support/Constants.h"
namespace TFHE = mlir::concretelang::TFHE;
namespace Tracing = mlir::concretelang::Tracing;

using TFHE::GLWECipherTextType;
/// Converts ciphertexts to plaintext integer types
|
|
class SimulateTFHETypeConverter : public mlir::TypeConverter {
|
|
|
|
public:
|
|
SimulateTFHETypeConverter() {
|
|
addConversion([](mlir::Type type) { return type; });
|
|
addConversion([&](GLWECipherTextType type) {
|
|
return mlir::IntegerType::get(type.getContext(), 64);
|
|
});
|
|
addConversion([&](mlir::RankedTensorType type) {
|
|
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
|
|
if (glwe == nullptr) {
|
|
return (mlir::Type)(type);
|
|
}
|
|
return (mlir::Type)mlir::RankedTensorType::get(
|
|
type.getShape(), mlir::IntegerType::get(type.getContext(), 64));
|
|
});
|
|
addConversion([&](mlir::concretelang::RT::FutureType type) {
|
|
return mlir::concretelang::RT::FutureType::get(
|
|
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
|
|
.getElementType()));
|
|
});
|
|
addConversion([&](mlir::concretelang::RT::PointerType type) {
|
|
return mlir::concretelang::RT::PointerType::get(
|
|
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
|
|
.getElementType()));
|
|
});
|
|
}
|
|
};
namespace {
/// Returns a tensor type with the same rank and element type as
/// `staticSizedTensor`, but with every dimension replaced by a dynamic one.
mlir::RankedTensorType toDynamicTensorType(mlir::TensorType staticSizedTensor) {
  const size_t rank = staticSizedTensor.getShape().size();
  std::vector<int64_t> allDynamicShape(rank, mlir::ShapedType::kDynamic);
  return mlir::RankedTensorType::get(allDynamicShape,
                                     staticSizedTensor.getElementType());
}
/// Lowers `TFHE.neg_glwe` to a call to the simulation runtime function
/// `sim_neg_lwe_u64`, which negates the plaintext i64 representation.
struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {

  NegOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::NegGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::NegGLWEOp negOp, TFHE::NegGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    const std::string funcName = "sim_neg_lwe_u64";

    // uint64_t sim_neg_lwe_u64(uint64_t plaintext)
    mlir::IntegerType i64Type = rewriter.getIntegerType(64);
    auto simFuncType = rewriter.getFunctionType({i64Type}, {i64Type});

    // Declare the runtime symbol once per module; bail out on failure.
    if (insertForwardDeclaration(negOp, rewriter, funcName, simFuncType)
            .failed()) {
      return mlir::failure();
    }

    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        negOp, funcName, mlir::TypeRange{i64Type},
        mlir::ValueRange({adaptor.getA()}));

    return mlir::success();
  }
};
/// Rewrites `int - glwe` into `(- glwe) + int`, so only negation and
/// addition need a direct lowering to the simulation runtime.
struct SubIntGLWEOpPattern : public mlir::OpRewritePattern<TFHE::SubGLWEIntOp> {

  SubIntGLWEOpPattern(mlir::MLIRContext *context)
      : mlir::OpRewritePattern<TFHE::SubGLWEIntOp>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::SubGLWEIntOp subOp,
                  mlir::PatternRewriter &rewriter) const override {
    // First negate the ciphertext operand.
    mlir::Value minusB = rewriter.create<TFHE::NegGLWEOp>(
        subOp.getLoc(), subOp.getB().getType(), subOp.getB());

    // Then add the plaintext integer operand to the negated ciphertext.
    rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(subOp, subOp.getType(),
                                                    minusB, subOp.getA());

    return mlir::success();
  }
};
struct EncodeExpandLutForBootstrapOpPattern
|
|
: public mlir::OpConversionPattern<TFHE::EncodeExpandLutForBootstrapOp> {
|
|
|
|
EncodeExpandLutForBootstrapOpPattern(mlir::MLIRContext *context,
|
|
mlir::TypeConverter &typeConverter)
|
|
: mlir::OpConversionPattern<TFHE::EncodeExpandLutForBootstrapOp>(
|
|
typeConverter, context,
|
|
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
|
|
|
::mlir::LogicalResult
|
|
matchAndRewrite(TFHE::EncodeExpandLutForBootstrapOp eeOp,
|
|
TFHE::EncodeExpandLutForBootstrapOp::Adaptor adaptor,
|
|
mlir::ConversionPatternRewriter &rewriter) const override {
|
|
|
|
const std::string funcName = "sim_encode_expand_lut_for_boostrap";
|
|
|
|
mlir::Value polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
|
eeOp.getLoc(), eeOp.getPolySize(), 32);
|
|
mlir::Value outputBitsCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
|
eeOp.getLoc(), eeOp.getOutputBits(), 32);
|
|
mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
|
eeOp.getLoc(), eeOp.getIsSigned(), 1);
|
|
|
|
mlir::Value outputBuffer =
|
|
rewriter.create<mlir::bufferization::AllocTensorOp>(
|
|
eeOp.getLoc(),
|
|
eeOp.getResult().getType().cast<mlir::RankedTensorType>(),
|
|
mlir::ValueRange{});
|
|
|
|
auto dynamicResultType = toDynamicTensorType(eeOp.getResult().getType());
|
|
auto dynamicLutType =
|
|
toDynamicTensorType(eeOp.getInputLookupTable().getType());
|
|
|
|
mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
|
|
eeOp.getLoc(), dynamicResultType, outputBuffer);
|
|
|
|
mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
|
|
eeOp.getLoc(),
|
|
toDynamicTensorType(eeOp.getInputLookupTable().getType()),
|
|
adaptor.getInputLookupTable());
|
|
|
|
// sim_encode_expand_lut_for_boostrap(uint64_t *out_allocated, uint64_t
|
|
// *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t
|
|
// out_stride, uint64_t *in_allocated, uint64_t *in_aligned, uint64_t
|
|
// in_offset, uint64_t in_size, uint64_t in_stride, uint32_t poly_size,
|
|
// uint32_t output_bits, bool is_signed)
|
|
if (insertForwardDeclaration(
|
|
eeOp, rewriter, funcName,
|
|
rewriter.getFunctionType(
|
|
{dynamicResultType, dynamicLutType, rewriter.getIntegerType(32),
|
|
rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
|
|
{}))
|
|
.failed()) {
|
|
return mlir::failure();
|
|
}
|
|
|
|
rewriter.create<mlir::func::CallOp>(
|
|
eeOp.getLoc(), funcName, mlir::TypeRange{},
|
|
mlir::ValueRange({castedOutputBuffer, castedLUT, polySizeCst,
|
|
outputBitsCst, isSignedCst}));
|
|
|
|
rewriter.replaceOp(eeOp, outputBuffer);
|
|
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Lowers `TFHE.encode_plaintext_with_crt` to a call to the simulation
/// runtime function `sim_encode_plaintext_with_crt`, which writes the CRT
/// encoding of the input into a freshly allocated tensor.
struct EncodePlaintextWithCrtOpPattern
    : public mlir::OpConversionPattern<TFHE::EncodePlaintextWithCrtOp> {

  EncodePlaintextWithCrtOpPattern(mlir::MLIRContext *context,
                                  mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::EncodePlaintextWithCrtOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::EncodePlaintextWithCrtOp epOp,
                  TFHE::EncodePlaintextWithCrtOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    // Name of the simulation runtime symbol called below.
    const std::string funcName = "sim_encode_plaintext_with_crt";

    // Product of the CRT moduli, passed to the runtime as an i64 constant.
    mlir::Value modsProductCst = rewriter.create<mlir::arith::ConstantIntOp>(
        epOp.getLoc(), epOp.getModsProd(), 64);

    // Output tensor the runtime writes the CRT residues into.
    mlir::Value outputBuffer =
        rewriter.create<mlir::bufferization::AllocTensorOp>(
            epOp.getLoc(),
            epOp.getResult().getType().cast<mlir::RankedTensorType>(),
            mlir::ValueRange{});

    // NOTE(review): unlike the bootstrap-LUT pattern, the declared function
    // type uses the op's original statically sized types (no tensor.cast to
    // dynamic shapes), and the moduli list itself is not passed yet —
    // presumably intentional for now; confirm against the runtime signature.
    // TODO: add mods
    if (insertForwardDeclaration(
            epOp, rewriter, funcName,
            rewriter.getFunctionType({epOp.getResult().getType(),
                                      epOp.getInput().getType() /*, mods here*/,
                                      rewriter.getI64Type()},
                                     {}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.create<mlir::func::CallOp>(
        epOp.getLoc(), funcName, mlir::TypeRange{},
        mlir::ValueRange({outputBuffer, adaptor.getInput(), modsProductCst}));

    // The op's users read from the allocated (now filled) buffer.
    rewriter.replaceOp(epOp, outputBuffer);

    return mlir::success();
  }
};
/// Lowers `TFHE.bootstrap_glwe` to a call to the simulation runtime function
/// `sim_bootstrap_lwe_u64`. No bootstrap key is materialized in simulation;
/// only its parameters (poly size, GLWE dim, levels, base log) are passed.
struct BootstrapGLWEOpPattern
    : public mlir::OpConversionPattern<TFHE::BootstrapGLWEOp> {

  BootstrapGLWEOpPattern(mlir::MLIRContext *context,
                         mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::BootstrapGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
                  TFHE::BootstrapGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    const std::string funcName = "sim_bootstrap_lwe_u64";

    TFHE::GLWECipherTextType resultType =
        bsOp.getType().cast<TFHE::GLWECipherTextType>();
    TFHE::GLWECipherTextType inputType =
        bsOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();

    // Bootstrap key / ciphertext parameters forwarded as i32 constants.
    auto polySize = adaptor.getKey().getPolySize();
    auto glweDimension = adaptor.getKey().getGlweDim();
    auto levels = adaptor.getKey().getLevels();
    auto baseLog = adaptor.getKey().getBaseLog();
    auto inputLweDimension =
        inputType.getKey().getNormalized().value().dimension;

    auto polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), polySize, 32);
    auto glweDimensionCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), glweDimension, 32);
    auto levelsCst =
        rewriter.create<mlir::arith::ConstantIntOp>(bsOp.getLoc(), levels, 32);
    auto baseLogCst =
        rewriter.create<mlir::arith::ConstantIntOp>(bsOp.getLoc(), baseLog, 32);
    auto inputLweDimensionCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), inputLweDimension, 32);

    // The runtime expects the lookup table as a dynamically sized tensor.
    auto dynamicLutType = toDynamicTensorType(bsOp.getLookupTable().getType());

    mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
        bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable());

    // uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t
    // *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t
    // tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t
    // poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim)
    if (insertForwardDeclaration(
            bsOp, rewriter, funcName,
            rewriter.getFunctionType(
                {rewriter.getIntegerType(64), dynamicLutType,
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32)},
                {rewriter.getIntegerType(64)}))
            .failed()) {
      return mlir::failure();
    }

    // Replace the op by the runtime call; the result type is the converted
    // (i64) form of the GLWE result type.
    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        bsOp, funcName, this->getTypeConverter()->convertType(resultType),
        mlir::ValueRange({adaptor.getCiphertext(), castedLUT,
                          inputLweDimensionCst, polySizeCst, levelsCst,
                          baseLogCst, glweDimensionCst}));

    return mlir::success();
  }
};
/// Lowers `TFHE.keyswitch_glwe` to a call to the simulation runtime function
/// `sim_keyswitch_lwe_u64`. As with bootstrap, no key material is passed —
/// only the key-switch parameters as i32 constants.
struct KeySwitchGLWEOpPattern
    : public mlir::OpConversionPattern<TFHE::KeySwitchGLWEOp> {

  KeySwitchGLWEOpPattern(mlir::MLIRContext *context,
                         mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::KeySwitchGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
                  TFHE::KeySwitchGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    const std::string funcName = "sim_keyswitch_lwe_u64";

    TFHE::GLWECipherTextType resultType =
        ksOp.getType().cast<TFHE::GLWECipherTextType>();
    TFHE::GLWECipherTextType inputType =
        ksOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();

    // Key-switch parameters: decomposition levels/base log from the key,
    // LWE dimensions from the input and result ciphertext types.
    auto levels = adaptor.getKey().getLevels();
    auto baseLog = adaptor.getKey().getBaseLog();
    auto inputDim = inputType.getKey().getNormalized().value().dimension;
    auto outputDim = resultType.getKey().getNormalized().value().dimension;

    mlir::Value levelCst =
        rewriter.create<mlir::arith::ConstantIntOp>(ksOp.getLoc(), levels, 32);
    mlir::Value baseLogCst =
        rewriter.create<mlir::arith::ConstantIntOp>(ksOp.getLoc(), baseLog, 32);
    mlir::Value inputDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
        ksOp.getLoc(), inputDim, 32);
    mlir::Value outputDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
        ksOp.getLoc(), outputDim, 32);

    // uint64_t sim_keyswitch_lwe_u64(uint64_t plaintext, uint32_t level,
    // uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim)
    if (insertForwardDeclaration(
            ksOp, rewriter, funcName,
            rewriter.getFunctionType(
                {rewriter.getIntegerType(64), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32)},
                {rewriter.getIntegerType(64)}))
            .failed()) {
      return mlir::failure();
    }

    // Replace the op by the runtime call, returning the converted (i64)
    // form of the GLWE result type.
    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        ksOp, funcName, this->getTypeConverter()->convertType(resultType),
        mlir::ValueRange({adaptor.getCiphertext(), levelCst, baseLogCst,
                          inputDimCst, outputDimCst}));

    return mlir::success();
  }
};
/// Lowers `TFHE.zero` to the i64 constant 0 — in simulation the trivial
/// encryption of zero is simply the plaintext value itself.
struct ZeroOpPattern : public mlir::OpConversionPattern<TFHE::ZeroGLWEOp> {
  ZeroOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::ZeroGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::ZeroGLWEOp zeroOp, TFHE::ZeroGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // The converted result type is i64 (see SimulateTFHETypeConverter).
    auto newResultTy = this->getTypeConverter()->convertType(zeroOp.getType());
    rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(zeroOp, 0,
                                                            newResultTy);
    return ::mlir::success();
  } // stray ';' after the method body removed (redundant empty declaration)
};
struct ZeroTensorOpPattern
|
|
: public mlir::OpConversionPattern<TFHE::ZeroTensorGLWEOp> {
|
|
ZeroTensorOpPattern(mlir::MLIRContext *context,
|
|
mlir::TypeConverter &typeConverter)
|
|
: mlir::OpConversionPattern<TFHE::ZeroTensorGLWEOp>(
|
|
typeConverter, context,
|
|
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
|
|
|
::mlir::LogicalResult
|
|
matchAndRewrite(TFHE::ZeroTensorGLWEOp zeroTensorOp,
|
|
TFHE::ZeroTensorGLWEOp::Adaptor adaptor,
|
|
mlir::ConversionPatternRewriter &rewriter) const override {
|
|
auto newResultTy =
|
|
this->getTypeConverter()->convertType(zeroTensorOp.getType());
|
|
rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
|
|
zeroTensorOp,
|
|
mlir::DenseElementsAttr::get(newResultTy, {mlir::APInt::getZero(64)}),
|
|
newResultTy);
|
|
return ::mlir::success();
|
|
};
|
|
};
|
|
|
|
/// Pass that lowers every TFHE op and type to its plaintext simulation
/// counterpart (i64 arithmetic and calls into the `sim_*` runtime functions).
struct SimulateTFHEPass : public SimulateTFHEBase<SimulateTFHEPass> {
  void runOnOperation() final;
};
void SimulateTFHEPass::runOnOperation() {
|
|
auto op = this->getOperation();
|
|
|
|
mlir::ConversionTarget target(getContext());
|
|
SimulateTFHETypeConverter converter;
|
|
|
|
target.addLegalDialect<mlir::arith::ArithDialect>();
|
|
target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp,
|
|
mlir::tensor::CastOp>();
|
|
// Make sure that no ops from `TFHE` remain after the lowering
|
|
target.addIllegalDialect<TFHE::TFHEDialect>();
|
|
|
|
mlir::RewritePatternSet patterns(&getContext());
|
|
|
|
// Replace ops and convert operand and result types
|
|
patterns.insert<mlir::concretelang::GenericOneToOneOpConversionPattern<
|
|
TFHE::AddGLWEIntOp, mlir::arith::AddIOp>,
|
|
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
|
TFHE::AddGLWEOp, mlir::arith::AddIOp>,
|
|
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
|
TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(),
|
|
converter);
|
|
// Convert operand and result types
|
|
patterns.insert<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::bufferization::AllocTensorOp, true>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::scf::YieldOp>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::tensor::FromElementsOp>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::tensor::ExtractOp>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::tensor::ExtractSliceOp, true>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::tensor::InsertOp>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::tensor::InsertSliceOp, true>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::tensor::ExpandShapeOp>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::tensor::CollapseShapeOp>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::tensor::YieldOp>>(&getContext(), converter);
|
|
// legalize ops only if operand and result types are legal
|
|
target.addDynamicallyLegalOp<
|
|
mlir::tensor::YieldOp, mlir::scf::YieldOp, mlir::tensor::GenerateOp,
|
|
mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp,
|
|
mlir::tensor::InsertOp, mlir::tensor::InsertSliceOp,
|
|
mlir::tensor::FromElementsOp, mlir::tensor::ExpandShapeOp,
|
|
mlir::tensor::CollapseShapeOp, mlir::bufferization::AllocTensorOp>(
|
|
[&](mlir::Operation *op) {
|
|
return converter.isLegal(op->getResultTypes()) &&
|
|
converter.isLegal(op->getOperandTypes());
|
|
});
|
|
// Make sure that no ops `linalg.generic` that have illegal types
|
|
target
|
|
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
|
|
[&](mlir::Operation *op) {
|
|
return (
|
|
converter.isLegal(op->getOperandTypes()) &&
|
|
converter.isLegal(op->getResultTypes()) &&
|
|
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
|
});
|
|
// Update scf::ForOp region with converted types
|
|
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
|
|
SimulateTFHETypeConverter>>(
|
|
&getContext(), converter);
|
|
target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp forOp) {
|
|
return converter.isLegal(forOp.getInitArgs().getTypes()) &&
|
|
converter.isLegal(forOp.getResults().getTypes());
|
|
});
|
|
|
|
patterns.insert<ZeroOpPattern, ZeroTensorOpPattern, KeySwitchGLWEOpPattern,
|
|
BootstrapGLWEOpPattern, EncodeExpandLutForBootstrapOpPattern,
|
|
EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(),
|
|
converter);
|
|
patterns.insert<SubIntGLWEOpPattern>(&getContext());
|
|
|
|
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::func::ReturnOp>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::scf::YieldOp>,
|
|
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::bufferization::AllocTensorOp, true>>(&getContext(),
|
|
converter);
|
|
|
|
mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
|
|
converter);
|
|
|
|
// Make sure that functions no longer operate on ciphertexts
|
|
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
|
[&](mlir::func::FuncOp funcOp) {
|
|
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
|
converter.isLegal(&funcOp.getBody());
|
|
});
|
|
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
|
[&](mlir::func::ConstantOp op) {
|
|
return FunctionConstantOpConversion<SimulateTFHETypeConverter>::isLegal(
|
|
op, converter);
|
|
});
|
|
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
|
patterns, converter);
|
|
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
|
|
target, converter);
|
|
patterns.insert<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
|
mlir::func::ReturnOp>>(&getContext(), converter);
|
|
|
|
patterns.add<FunctionConstantOpConversion<SimulateTFHETypeConverter>>(
|
|
&getContext(), converter);
|
|
|
|
// Apply conversion
|
|
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
|
this->signalPassFailure();
|
|
}
|
|
}
|
|
} // namespace

namespace mlir {
namespace concretelang {
/// Creates the TFHE-to-plaintext-simulation lowering pass.
std::unique_ptr<OperationPass<ModuleOp>> createSimulateTFHEPass() {
  return std::make_unique<SimulateTFHEPass>();
}
} // namespace concretelang
} // namespace mlir