Files
concrete/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp

519 lines
22 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RTOpConverter.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Support/Constants.h"
namespace TFHE = mlir::concretelang::TFHE;
namespace Tracing = mlir::concretelang::Tracing;
using TFHE::GLWECipherTextType;
/// Type converter used by the TFHE simulation pass: ciphertexts become plain
/// 64-bit integers.
///
///  - `GLWECipherTextType`                  -> `i64`
///  - ranked tensor of `GLWECipherTextType` -> ranked tensor of `i64` (same shape)
///  - `RT::FutureType<T>` / `RT::PointerType<T>` -> same wrapper around the
///    converted `T`
///  - any other type is returned unchanged.
///
/// MLIR invokes the most recently added conversion first, which is why the
/// identity fallback is registered first: it only applies when no more
/// specific conversion matched.
class SimulateTFHETypeConverter : public mlir::TypeConverter {
public:
  SimulateTFHETypeConverter() {
    // Fallback: leave every other type untouched.
    addConversion([](mlir::Type type) { return type; });
    addConversion([](GLWECipherTextType type) {
      return mlir::IntegerType::get(type.getContext(), 64);
    });
    addConversion([](mlir::RankedTensorType type) -> mlir::Type {
      // Only tensors of ciphertexts need rewriting; keep the shape.
      if (!type.getElementType().isa<GLWECipherTextType>()) {
        return type;
      }
      return mlir::RankedTensorType::get(
          type.getShape(), mlir::IntegerType::get(type.getContext(), 64));
    });
    // RT wrapper types: recursively convert the wrapped element type, which
    // requires calling back into this converter.
    addConversion([this](mlir::concretelang::RT::FutureType type) {
      return mlir::concretelang::RT::FutureType::get(
          this->convertType(type.getElementType()));
    });
    addConversion([this](mlir::concretelang::RT::PointerType type) {
      return mlir::concretelang::RT::PointerType::get(
          this->convertType(type.getElementType()));
    });
  }
};
namespace {
/// Builds a ranked tensor type with the same rank and element type as
/// `staticSizedTensor`, but with every dimension marked dynamic.
///
/// The runtime calls in this file pass their tensors through a `tensor.cast`
/// to such a type before the call.
mlir::RankedTensorType toDynamicTensorType(mlir::TensorType staticSizedTensor) {
  mlir::Type elementType = staticSizedTensor.getElementType();
  const auto rank = staticSizedTensor.getShape().size();
  std::vector<int64_t> allDynamicDims(rank, mlir::ShapedType::kDynamic);
  return mlir::RankedTensorType::get(allDynamicDims, elementType);
}
/// Lowers `TFHE::NegGLWEOp` to a call of the runtime function
/// `sim_neg_lwe_u64`, which takes and returns the simulated (i64) ciphertext.
struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {
  NegOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::NegGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::NegGLWEOp negOp, TFHE::NegGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    const std::string funcName = "sim_neg_lwe_u64";
    mlir::Type i64Type = rewriter.getIntegerType(64);
    mlir::FunctionType calleeType =
        rewriter.getFunctionType({i64Type}, {i64Type});

    // Make sure the runtime function is declared in the module.
    mlir::LogicalResult declared =
        insertForwardDeclaration(negOp, rewriter, funcName, calleeType);
    if (declared.failed()) {
      return mlir::failure();
    }

    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        negOp, funcName, mlir::TypeRange{i64Type},
        mlir::ValueRange({adaptor.getA()}));
    return mlir::success();
  }
};
/// Rewrites `TFHE::SubGLWEIntOp` (computing `a - b`) as
/// `TFHE::AddGLWEIntOp(TFHE::NegGLWEOp(b), a)`, i.e. `(-b) + a`, so that it
/// is then lowered by the negation and addition patterns.
struct SubIntGLWEOpPattern : public mlir::OpRewritePattern<TFHE::SubGLWEIntOp> {
  SubIntGLWEOpPattern(mlir::MLIRContext *context)
      : mlir::OpRewritePattern<TFHE::SubGLWEIntOp>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::SubGLWEIntOp subOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = subOp.getLoc();
    mlir::Value subtrahend = subOp.getB();
    // -b, with the same type as b.
    mlir::Value negatedSubtrahend = rewriter.create<TFHE::NegGLWEOp>(
        loc, subtrahend.getType(), subtrahend);
    // (-b) + a
    rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
        subOp, subOp.getType(), negatedSubtrahend, subOp.getA());
    return mlir::success();
  }
};
/// Lowers `TFHE::EncodeExpandLutForBootstrapOp` to a call of the runtime
/// function `sim_encode_expand_lut_for_boostrap`.
///
/// A tensor with the op's result type is allocated and passed as the output
/// (`out_*`) argument; that buffer then replaces the op's result. Both the
/// output buffer and the input lookup table are cast to dynamically-sized
/// tensors before the call.
struct EncodeExpandLutForBootstrapOpPattern
    : public mlir::OpConversionPattern<TFHE::EncodeExpandLutForBootstrapOp> {
  EncodeExpandLutForBootstrapOpPattern(mlir::MLIRContext *context,
                                       mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::EncodeExpandLutForBootstrapOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::EncodeExpandLutForBootstrapOp eeOp,
                  TFHE::EncodeExpandLutForBootstrapOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // NOTE: the spelling ("boostrap") must match the runtime symbol exactly.
    const std::string funcName = "sim_encode_expand_lut_for_boostrap";

    mlir::Value polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
        eeOp.getLoc(), eeOp.getPolySize(), 32);
    mlir::Value outputBitsCst = rewriter.create<mlir::arith::ConstantIntOp>(
        eeOp.getLoc(), eeOp.getOutputBits(), 32);
    mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
        eeOp.getLoc(), eeOp.getIsSigned(), 1);

    // Output buffer passed to the runtime; it becomes the op's result.
    mlir::Value outputBuffer =
        rewriter.create<mlir::bufferization::AllocTensorOp>(
            eeOp.getLoc(),
            eeOp.getResult().getType().cast<mlir::RankedTensorType>(),
            mlir::ValueRange{});

    // The same dynamic types are used for the casts and for the forward
    // declaration below, so the call always matches the declaration.
    auto dynamicResultType = toDynamicTensorType(eeOp.getResult().getType());
    auto dynamicLutType =
        toDynamicTensorType(eeOp.getInputLookupTable().getType());
    mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
        eeOp.getLoc(), dynamicResultType, outputBuffer);
    mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
        eeOp.getLoc(), dynamicLutType, adaptor.getInputLookupTable());

    // sim_encode_expand_lut_for_boostrap(uint64_t *out_allocated, uint64_t
    // *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t
    // out_stride, uint64_t *in_allocated, uint64_t *in_aligned, uint64_t
    // in_offset, uint64_t in_size, uint64_t in_stride, uint32_t poly_size,
    // uint32_t output_bits, bool is_signed)
    if (insertForwardDeclaration(
            eeOp, rewriter, funcName,
            rewriter.getFunctionType(
                {dynamicResultType, dynamicLutType, rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
                {}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.create<mlir::func::CallOp>(
        eeOp.getLoc(), funcName, mlir::TypeRange{},
        mlir::ValueRange({castedOutputBuffer, castedLUT, polySizeCst,
                          outputBitsCst, isSignedCst}));
    rewriter.replaceOp(eeOp, outputBuffer);
    return mlir::success();
  }
};
/// Lowers `TFHE::EncodePlaintextWithCrtOp` to a call of the runtime function
/// `sim_encode_plaintext_with_crt`.
///
/// A tensor with the op's result type is allocated, passed to the runtime as
/// the output argument, and then replaces the op's result. Like the other
/// tensor-producing runtime calls in this file, the buffer is passed as a
/// dynamically-sized tensor. This keeps a single forward declaration valid for
/// every result size. A static shape in the declaration would conflict as
/// soon as two ops with different result sizes are lowered in the same
/// module.
struct EncodePlaintextWithCrtOpPattern
    : public mlir::OpConversionPattern<TFHE::EncodePlaintextWithCrtOp> {
  EncodePlaintextWithCrtOpPattern(mlir::MLIRContext *context,
                                  mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::EncodePlaintextWithCrtOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::EncodePlaintextWithCrtOp epOp,
                  TFHE::EncodePlaintextWithCrtOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    const std::string funcName = "sim_encode_plaintext_with_crt";

    mlir::Value modsProductCst = rewriter.create<mlir::arith::ConstantIntOp>(
        epOp.getLoc(), epOp.getModsProd(), 64);

    auto resultType =
        epOp.getResult().getType().cast<mlir::RankedTensorType>();
    mlir::Value outputBuffer =
        rewriter.create<mlir::bufferization::AllocTensorOp>(
            epOp.getLoc(), resultType, mlir::ValueRange{});

    // Pass the buffer with a dynamic shape so the declaration is
    // size-independent.
    auto dynamicResultType = toDynamicTensorType(resultType);
    mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
        epOp.getLoc(), dynamicResultType, outputBuffer);

    // TODO: add mods
    if (insertForwardDeclaration(
            epOp, rewriter, funcName,
            rewriter.getFunctionType({dynamicResultType,
                                      epOp.getInput().getType() /*, mods here*/,
                                      rewriter.getI64Type()},
                                     {}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.create<mlir::func::CallOp>(
        epOp.getLoc(), funcName, mlir::TypeRange{},
        mlir::ValueRange(
            {castedOutputBuffer, adaptor.getInput(), modsProductCst}));
    rewriter.replaceOp(epOp, outputBuffer);
    return mlir::success();
  }
};
/// Lowers `TFHE::BootstrapGLWEOp` to a call of the runtime function
/// `sim_bootstrap_lwe_u64`.
///
/// The call receives the following arguments:
///  - the simulated (i64) input ciphertext;
///  - the lookup table, cast to a dynamically-sized tensor;
///  - the input LWE dimension, polynomial size, level count, base log and
///    GLWE dimension, as i32 constants.
///
/// The input LWE dimension is read from the input ciphertext's normalized
/// key. If that key is not normalized, the pattern fails with a diagnostic
/// instead of dereferencing an empty optional.
struct BootstrapGLWEOpPattern
    : public mlir::OpConversionPattern<TFHE::BootstrapGLWEOp> {
  BootstrapGLWEOpPattern(mlir::MLIRContext *context,
                         mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::BootstrapGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
                  TFHE::BootstrapGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    const std::string funcName = "sim_bootstrap_lwe_u64";

    TFHE::GLWECipherTextType resultType =
        bsOp.getType().cast<TFHE::GLWECipherTextType>();
    TFHE::GLWECipherTextType inputType =
        bsOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();

    // Check before creating any op so nothing needs to be rolled back.
    auto inputKey = inputType.getKey().getNormalized();
    if (!inputKey) {
      return rewriter.notifyMatchFailure(
          bsOp, "bootstrap input ciphertext key is not normalized");
    }

    auto polySize = adaptor.getKey().getPolySize();
    auto glweDimension = adaptor.getKey().getGlweDim();
    auto levels = adaptor.getKey().getLevels();
    auto baseLog = adaptor.getKey().getBaseLog();
    auto inputLweDimension = inputKey->dimension;

    auto polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), polySize, 32);
    auto glweDimensionCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), glweDimension, 32);
    auto levelsCst =
        rewriter.create<mlir::arith::ConstantIntOp>(bsOp.getLoc(), levels, 32);
    auto baseLogCst =
        rewriter.create<mlir::arith::ConstantIntOp>(bsOp.getLoc(), baseLog, 32);
    auto inputLweDimensionCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), inputLweDimension, 32);

    auto dynamicLutType = toDynamicTensorType(bsOp.getLookupTable().getType());
    mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
        bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable());

    // uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t
    // *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t
    // tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t
    // poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim)
    if (insertForwardDeclaration(
            bsOp, rewriter, funcName,
            rewriter.getFunctionType(
                {rewriter.getIntegerType(64), dynamicLutType,
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32)},
                {rewriter.getIntegerType(64)}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        bsOp, funcName, this->getTypeConverter()->convertType(resultType),
        mlir::ValueRange({adaptor.getCiphertext(), castedLUT,
                          inputLweDimensionCst, polySizeCst, levelsCst,
                          baseLogCst, glweDimensionCst}));
    return mlir::success();
  }
};
/// Lowers `TFHE::KeySwitchGLWEOp` to a call of the runtime function
/// `sim_keyswitch_lwe_u64`.
///
/// The call receives the simulated (i64) ciphertext followed by four i32
/// constants: level count, base log, input LWE dimension and output LWE
/// dimension.
///
/// The LWE dimensions are read from the normalized keys of the input and
/// result ciphertext types. If either key is not normalized, the pattern
/// fails with a diagnostic instead of dereferencing an empty optional.
struct KeySwitchGLWEOpPattern
    : public mlir::OpConversionPattern<TFHE::KeySwitchGLWEOp> {
  KeySwitchGLWEOpPattern(mlir::MLIRContext *context,
                         mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::KeySwitchGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
                  TFHE::KeySwitchGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    const std::string funcName = "sim_keyswitch_lwe_u64";

    TFHE::GLWECipherTextType resultType =
        ksOp.getType().cast<TFHE::GLWECipherTextType>();
    TFHE::GLWECipherTextType inputType =
        ksOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();

    // Check before creating any op so nothing needs to be rolled back.
    auto inputKey = inputType.getKey().getNormalized();
    if (!inputKey) {
      return rewriter.notifyMatchFailure(
          ksOp, "keyswitch input ciphertext key is not normalized");
    }
    auto outputKey = resultType.getKey().getNormalized();
    if (!outputKey) {
      return rewriter.notifyMatchFailure(
          ksOp, "keyswitch result ciphertext key is not normalized");
    }

    auto levels = adaptor.getKey().getLevels();
    auto baseLog = adaptor.getKey().getBaseLog();
    auto inputDim = inputKey->dimension;
    auto outputDim = outputKey->dimension;

    mlir::Value levelCst =
        rewriter.create<mlir::arith::ConstantIntOp>(ksOp.getLoc(), levels, 32);
    mlir::Value baseLogCst =
        rewriter.create<mlir::arith::ConstantIntOp>(ksOp.getLoc(), baseLog, 32);
    mlir::Value inputDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
        ksOp.getLoc(), inputDim, 32);
    mlir::Value outputDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
        ksOp.getLoc(), outputDim, 32);

    // uint64_t sim_keyswitch_lwe_u64(uint64_t plaintext, uint32_t level,
    // uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim)
    if (insertForwardDeclaration(
            ksOp, rewriter, funcName,
            rewriter.getFunctionType(
                {rewriter.getIntegerType(64), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32)},
                {rewriter.getIntegerType(64)}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        ksOp, funcName, this->getTypeConverter()->convertType(resultType),
        mlir::ValueRange({adaptor.getCiphertext(), levelCst, baseLogCst,
                          inputDimCst, outputDimCst}));
    return mlir::success();
  }
};
/// Replaces `TFHE::ZeroGLWEOp` by an integer constant `0` of the converted
/// (simulated) ciphertext type.
struct ZeroOpPattern : public mlir::OpConversionPattern<TFHE::ZeroGLWEOp> {
  ZeroOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::ZeroGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::ZeroGLWEOp zeroOp, TFHE::ZeroGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto *converter = this->getTypeConverter();
    mlir::Type simulatedType = converter->convertType(zeroOp.getType());
    rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(zeroOp, 0,
                                                            simulatedType);
    return ::mlir::success();
  }
};
/// Replaces `TFHE::ZeroTensorGLWEOp` by an `arith.constant` holding a dense
/// all-zero tensor of the converted (simulated) tensor type.
struct ZeroTensorOpPattern
    : public mlir::OpConversionPattern<TFHE::ZeroTensorGLWEOp> {
  ZeroTensorOpPattern(mlir::MLIRContext *context,
                      mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::ZeroTensorGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::ZeroTensorGLWEOp zeroTensorOp,
                  TFHE::ZeroTensorGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type simulatedTensorType =
        this->getTypeConverter()->convertType(zeroTensorOp.getType());
    auto shapedType = simulatedTensorType.cast<mlir::ShapedType>();
    // A single 64-bit zero value yields a splat attribute. The width matches
    // the i64 elements produced by the simulation type converter.
    mlir::DenseElementsAttr allZeros =
        mlir::DenseElementsAttr::get(shapedType, {mlir::APInt::getZero(64)});
    rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
        zeroTensorOp, allZeros, simulatedTensorType);
    return ::mlir::success();
  }
};
/// Pass that lowers TFHE operations for simulation. Ciphertexts become `i64`
/// values (see `SimulateTFHETypeConverter`). Homomorphic additions and
/// multiplications become `arith` ops, and the remaining operations become
/// calls to `sim_*` runtime functions.
struct SimulateTFHEPass : public SimulateTFHEBase<SimulateTFHEPass> {
  void runOnOperation() final;
};

void SimulateTFHEPass::runOnOperation() {
  auto module = this->getOperation();

  mlir::ConversionTarget target(getContext());
  SimulateTFHETypeConverter converter;

  target.addLegalDialect<mlir::arith::ArithDialect>();
  target.addLegalOp<mlir::func::CallOp, mlir::bufferization::AllocTensorOp,
                    mlir::tensor::CastOp>();
  // Make sure that no ops from `TFHE` remain after the lowering
  target.addIllegalDialect<TFHE::TFHEDialect>();

  mlir::RewritePatternSet patterns(&getContext());

  // Replace ops and convert operand and result types
  patterns.insert<mlir::concretelang::GenericOneToOneOpConversionPattern<
                      TFHE::AddGLWEIntOp, mlir::arith::AddIOp>,
                  mlir::concretelang::GenericOneToOneOpConversionPattern<
                      TFHE::AddGLWEOp, mlir::arith::AddIOp>,
                  mlir::concretelang::GenericOneToOneOpConversionPattern<
                      TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(),
                                                                converter);

  // Convert operand and result types. Each pattern is registered once.
  patterns.insert<mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::bufferization::AllocTensorOp, true>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::scf::YieldOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::FromElementsOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::ExtractOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::ExtractSliceOp, true>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::InsertOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::InsertSliceOp, true>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::ExpandShapeOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::CollapseShapeOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::YieldOp>>(&getContext(), converter);

  // Legalize ops only if operand and result types are legal.
  // tensor::GenerateOp is handled with linalg::GenericOp below, because its
  // region arguments must be checked as well.
  target.addDynamicallyLegalOp<
      mlir::tensor::YieldOp, mlir::scf::YieldOp, mlir::tensor::ExtractSliceOp,
      mlir::tensor::ExtractOp, mlir::tensor::InsertOp,
      mlir::tensor::InsertSliceOp, mlir::tensor::FromElementsOp,
      mlir::tensor::ExpandShapeOp, mlir::tensor::CollapseShapeOp,
      mlir::bufferization::AllocTensorOp>([&](mlir::Operation *candidate) {
    return converter.isLegal(candidate->getResultTypes()) &&
           converter.isLegal(candidate->getOperandTypes());
  });

  // Make sure that no ops `linalg.generic` that have illegal types
  target
      .addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
          [&](mlir::Operation *candidate) {
            return converter.isLegal(candidate->getOperandTypes()) &&
                   converter.isLegal(candidate->getResultTypes()) &&
                   converter.isLegal(
                       candidate->getRegion(0).front().getArgumentTypes());
          });

  // Update scf::ForOp region with converted types
  patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
                                            SimulateTFHETypeConverter>>(
      &getContext(), converter);
  target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp forOp) {
    return converter.isLegal(forOp.getInitArgs().getTypes()) &&
           converter.isLegal(forOp.getResults().getTypes());
  });

  // TFHE-specific lowerings
  patterns.insert<ZeroOpPattern, ZeroTensorOpPattern, KeySwitchGLWEOpPattern,
                  BootstrapGLWEOpPattern, EncodeExpandLutForBootstrapOpPattern,
                  EncodePlaintextWithCrtOpPattern, NegOpPattern>(&getContext(),
                                                                 converter);
  patterns.insert<SubIntGLWEOpPattern>(&getContext());
  patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
      mlir::func::ReturnOp>>(&getContext(), converter);

  mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
                                                          converter);

  // Make sure that functions no longer operate on ciphertexts
  target.addDynamicallyLegalOp<mlir::func::FuncOp>(
      [&](mlir::func::FuncOp funcOp) {
        return converter.isSignatureLegal(funcOp.getFunctionType()) &&
               converter.isLegal(&funcOp.getBody());
      });
  target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
      [&](mlir::func::ConstantOp constantOp) {
        return FunctionConstantOpConversion<SimulateTFHETypeConverter>::isLegal(
            constantOp, converter);
      });
  mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
      patterns, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
      target, converter);
  patterns.add<FunctionConstantOpConversion<SimulateTFHETypeConverter>>(
      &getContext(), converter);

  // Apply conversion
  if (mlir::applyPartialConversion(module, target, std::move(patterns))
          .failed()) {
    this->signalPassFailure();
  }
}
} // namespace
namespace mlir {
namespace concretelang {
/// Creates a fresh instance of the TFHE simulation pass, which runs on a
/// `ModuleOp`.
std::unique_ptr<OperationPass<ModuleOp>> createSimulateTFHEPass() {
  std::unique_ptr<OperationPass<ModuleOp>> simulationPass =
      std::make_unique<SimulateTFHEPass>();
  return simulationPass;
}
} // namespace concretelang
} // namespace mlir