// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/FHEToTFHECrt/Pass.h" #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" #include "concretelang/Dialect/RT/IR/RTDialect.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "concretelang/Dialect/RT/IR/RTTypes.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" namespace FHE = mlir::concretelang::FHE; namespace TFHE = mlir::concretelang::TFHE; namespace concretelang = mlir::concretelang; namespace fhe_to_tfhe_crt_conversion { namespace typing { /// Converts `FHE::EncryptedInteger` into `Tensor`. mlir::RankedTensorType convertEint(mlir::MLIRContext *context, FHE::EncryptedIntegerType eint, uint64_t crtLength) { return mlir::RankedTensorType::get( mlir::ArrayRef((int64_t)crtLength), TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth())); } /// Converts `Tensor` into a /// `Tensor` if the element type is appropriate. Otherwise /// return the input type. 
mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context, mlir::RankedTensorType maybeEintTensor, uint64_t crtLength) { if (!maybeEintTensor.getElementType().isa()) { return (mlir::Type)(maybeEintTensor); } auto eint = maybeEintTensor.getElementType().cast(); auto currentShape = maybeEintTensor.getShape(); mlir::SmallVector newShape = mlir::SmallVector(currentShape.begin(), currentShape.end()); newShape.push_back((int64_t)crtLength); return mlir::RankedTensorType::get( llvm::ArrayRef(newShape), TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth())); } /// Converts the type `FHE::EncryptedInteger` to `Tensor` /// if the input type is appropriate. Otherwise return the input type. mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t, uint64_t crtLength) { if (auto eint = t.dyn_cast()) return convertEint(context, eint, crtLength); return t; } /// The type converter used to convert `FHE` to `TFHE` types using the crt /// strategy. class TypeConverter : public mlir::TypeConverter { public: TypeConverter(concretelang::CrtLoweringParameters loweringParameters) { size_t nMods = loweringParameters.nMods; addConversion([](mlir::Type type) { return type; }); addConversion([=](FHE::EncryptedIntegerType type) { return convertEint(type.getContext(), type, nMods); }); addConversion([=](mlir::RankedTensorType type) { return maybeConvertEintTensor(type.getContext(), type, nMods); }); addConversion([&](concretelang::RT::FutureType type) { return concretelang::RT::FutureType::get(this->convertType( type.dyn_cast().getElementType())); }); addConversion([&](concretelang::RT::PointerType type) { return concretelang::RT::PointerType::get(this->convertType( type.dyn_cast().getElementType())); }); } /// Returns a lambda that uses this converter to turn one type into another. 
std::function getConversionLambda() { return [&](mlir::MLIRContext *, mlir::Type t) { return convertType(t); }; } }; } // namespace typing namespace lowering { /// A pattern rewriter superclass used by most op rewriters during the /// conversion. template struct CrtOpPattern : public mlir::OpRewritePattern { /// The lowering parameters are bound to the op rewriter. concretelang::CrtLoweringParameters loweringParameters; CrtOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : mlir::OpRewritePattern(context, benefit), loweringParameters(params) {} /// Writes an `scf::for` that loops over the crt dimension of one tensor and /// execute the input lambda to write the loop body. Returns the first result /// of the op. /// /// Note: /// ----- /// /// + The type of `firstArgTensor` type is used as output type. mlir::Value writeUnaryTensorLoop( mlir::Location location, mlir::Type returnType, mlir::PatternRewriter &rewriter, mlir::function_ref body) const { mlir::Value tensor = rewriter.create( location, returnType.cast(), mlir::ValueRange{}); // Create the loop mlir::arith::ConstantOp zeroConstantOp = rewriter.create(location, 0); mlir::arith::ConstantOp oneConstantOp = rewriter.create(location, 1); mlir::arith::ConstantOp crtSizeConstantOp = rewriter.create(location, loweringParameters.nMods); mlir::scf::ForOp newOp = rewriter.create( location, zeroConstantOp, crtSizeConstantOp, oneConstantOp, tensor, body); // Convert the types of the new operation typing::TypeConverter converter(loweringParameters); concretelang::convertOperandAndResultTypes(rewriter, newOp, converter.getConversionLambda()); return newOp.getResult(0); } /// Writes the crt encoding of a plaintext of arbitrary precision. 
mlir::Value writePlaintextCrtEncoding(mlir::Location location, mlir::Value rawPlaintext, mlir::PatternRewriter &rewriter) const { mlir::Value castedPlaintext = rewriter.create( location, rewriter.getI64Type(), rawPlaintext); return rewriter.create( location, mlir::RankedTensorType::get( mlir::ArrayRef(loweringParameters.nMods), rewriter.getI64Type()), castedPlaintext, rewriter.getI64ArrayAttr(loweringParameters.mods), rewriter.getI64IntegerAttr(loweringParameters.modsProd)); } }; /// Rewriter for the `FHE::add_eint_int` operation. struct AddEintIntOpPattern : public CrtOpPattern { AddEintIntOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::AddEintIntOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); mlir::Value eintOperand = op.a(); mlir::Value intOperand = op.b(); // Convert operand type to glwe tensor. typing::TypeConverter converter(loweringParameters); intOperand.setType(converter.convertType(intOperand.getType())); eintOperand.setType(converter.convertType(eintOperand.getType())); // Write plaintext encoding mlir::Value encodedPlaintextTensor = writePlaintextCrtEncoding(op.getLoc(), intOperand, rewriter); // Write add loop. 
mlir::Type ciphertextScalarType = converter.convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedEint = builder.create(loc, eintOperand, iter); mlir::Value extractedInt = builder.create( loc, encodedPlaintextTensor, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedEint, extractedInt); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::sub_int_eint` operation. struct SubIntEintOpPattern : public CrtOpPattern { SubIntEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::SubIntEintOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); mlir::Value intOperand = op.a(); mlir::Value eintOperand = op.b(); // Convert operand type to glwe tensor. typing::TypeConverter converter(loweringParameters); intOperand.setType(converter.convertType(intOperand.getType())); eintOperand.setType(converter.convertType(eintOperand.getType())); // Write plaintext encoding mlir::Value encodedPlaintextTensor = writePlaintextCrtEncoding(op.getLoc(), intOperand, rewriter); // Write add loop. 
mlir::Type ciphertextScalarType = converter.convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedEint = builder.create(loc, eintOperand, iter); mlir::Value extractedInt = builder.create( loc, encodedPlaintextTensor, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedInt, extractedEint); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::sub_eint_int` operation. struct SubEintIntOpPattern : public CrtOpPattern { SubEintIntOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::SubEintIntOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); mlir::Value eintOperand = op.a(); mlir::Value intOperand = op.b(); // Convert operand type to glwe tensor. typing::TypeConverter converter(loweringParameters); intOperand.setType(converter.convertType(intOperand.getType())); eintOperand.setType(converter.convertType(eintOperand.getType())); // Write plaintext negation mlir::Type intType = intOperand.getType(); mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(intType, -1); mlir::Value minusOne = rewriter.create(location, minusOneAttr) .getResult(); mlir::Value negative = rewriter.create(location, intOperand, minusOne) .getResult(); // Write plaintext encoding mlir::Value encodedPlaintextTensor = writePlaintextCrtEncoding(op.getLoc(), negative, rewriter); // Write add loop. 
mlir::Type ciphertextScalarType = converter.convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedEint = builder.create(loc, eintOperand, iter); mlir::Value extractedInt = builder.create( loc, encodedPlaintextTensor, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedEint, extractedInt); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::add_eint` operation. struct AddEintOpPattern : CrtOpPattern { AddEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::AddEintOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); mlir::Value lhsOperand = op.a(); mlir::Value rhsOperand = op.b(); // Convert operand type to glwe tensor. typing::TypeConverter converter(loweringParameters); lhsOperand.setType(converter.convertType(lhsOperand.getType())); rhsOperand.setType(converter.convertType(rhsOperand.getType())); // Write add loop. 
mlir::Type ciphertextScalarType = converter.convertType(lhsOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, lhsOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedLhs = builder.create(loc, lhsOperand, iter); mlir::Value extractedRhs = builder.create(loc, rhsOperand, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedLhs, extractedRhs); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::sub_eint` operation. struct SubEintOpPattern : CrtOpPattern { SubEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::SubEintOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); mlir::Value lhsOperand = op.a(); mlir::Value rhsOperand = op.b(); // Convert operand type to glwe tensor. typing::TypeConverter converter(loweringParameters); lhsOperand.setType(converter.convertType(lhsOperand.getType())); rhsOperand.setType(converter.convertType(rhsOperand.getType())); // Write sub loop. 
mlir::Type ciphertextScalarType = converter.convertType(lhsOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, lhsOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedLhs = builder.create(loc, lhsOperand, iter); mlir::Value extractedRhs = builder.create(loc, rhsOperand, iter); mlir::Value negatedRhs = builder.create( loc, ciphertextScalarType, extractedRhs); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedLhs, negatedRhs); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::neg_eint` operation. struct NegEintOpPattern : CrtOpPattern { NegEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::NegEintOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); mlir::Value operand = op.a(); // Convert operand type to glwe tensor. typing::TypeConverter converter{loweringParameters}; operand.setType(converter.convertType(operand.getType())); // Write the loop nest. 
mlir::Type ciphertextScalarType = converter.convertType(operand.getType()) .cast() .getElementType(); mlir::Value loopRes = writeUnaryTensorLoop( location, operand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedCiphertext = builder.create(loc, operand, iter); mlir::Value negatedCiphertext = builder.create( loc, ciphertextScalarType, extractedCiphertext); mlir::Value newTensor = builder.create( loc, negatedCiphertext, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, loopRes); return mlir::success(); } }; /// Rewriter for the `FHE::mul_eint_int` operation. struct MulEintIntOpPattern : CrtOpPattern { MulEintIntOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::MulEintIntOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); mlir::Value eintOperand = op.a(); mlir::Value intOperand = op.b(); // Convert operand type to glwe tensor. typing::TypeConverter converter{loweringParameters}; eintOperand.setType(converter.convertType(eintOperand.getType())); // Write cleartext "encoding" mlir::Value encodedCleartext = rewriter.create( location, rewriter.getI64Type(), intOperand); // Write the loop nest. 
mlir::Type ciphertextScalarType = converter.convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value loopRes = writeUnaryTensorLoop( location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedCiphertext = builder.create(loc, eintOperand, iter); mlir::Value negatedCiphertext = builder.create( loc, ciphertextScalarType, extractedCiphertext, encodedCleartext); mlir::Value newTensor = builder.create( loc, negatedCiphertext, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, loopRes); return mlir::success(); } }; /// Rewriter for the `FHE::apply_lookup_table` operation. struct ApplyLookupTableEintOpPattern : public CrtOpPattern { ApplyLookupTableEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::ApplyLookupTableEintOp op, mlir::PatternRewriter &rewriter) const override { typing::TypeConverter converter(loweringParameters); mlir::Value newLut = rewriter .create( op.getLoc(), mlir::RankedTensorType::get( mlir::ArrayRef(loweringParameters.lutSize), rewriter.getI64Type()), op.lut(), rewriter.getI64ArrayAttr( mlir::ArrayRef(loweringParameters.mods)), rewriter.getI64ArrayAttr( mlir::ArrayRef(loweringParameters.bits)), rewriter.getI32IntegerAttr(loweringParameters.polynomialSize), rewriter.getI32IntegerAttr(loweringParameters.modsProd)) .getResult(); // Replace the lut with an encoded / expanded one. 
auto wopPBS = rewriter.create( op.getLoc(), op.getType(), op.a(), newLut, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, rewriter.getI64ArrayAttr({})); concretelang::convertOperandAndResultTypes(rewriter, wopPBS, converter.getConversionLambda()); rewriter.replaceOp(op, {wopPBS.getResult()}); return ::mlir::success(); }; }; /// Rewriter for the `tensor::extract` operation. struct TensorExtractOpPattern : public CrtOpPattern { TensorExtractOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::ExtractOp op, mlir::PatternRewriter &rewriter) const override { if (!op.getTensor() .getType() .cast() .getElementType() .isa() && !op.getTensor() .getType() .cast() .getElementType() .isa()) { return mlir::success(); } typing::TypeConverter converter{loweringParameters}; mlir::SmallVector offsets; mlir::SmallVector sizes; mlir::SmallVector strides; for (auto index : op.getIndices()) { offsets.push_back(index); sizes.push_back(rewriter.getI64IntegerAttr(1)); strides.push_back(rewriter.getI64IntegerAttr(1)); } offsets.push_back( rewriter.create(op.getLoc(), 0) .getResult()); sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods)); strides.push_back(rewriter.getI64IntegerAttr(1)); auto newOp = rewriter.create( op.getLoc(), converter.convertType(op.getResult().getType()) .cast(), op.getTensor(), offsets, sizes, strides); concretelang::convertOperandAndResultTypes(rewriter, newOp, converter.getConversionLambda()); rewriter.replaceOp(op, {newOp.getResult()}); return mlir::success(); } }; /// Rewriter for the `tensor::extract` operation. 
struct TensorInsertOpPattern : public CrtOpPattern {
  TensorInsertOpPattern(mlir::MLIRContext *context,
                        concretelang::CrtLoweringParameters params,
                        mlir::PatternBenefit benefit = 1)
      : CrtOpPattern(context, params, benefit) {}

  /// Rewrites a scalar insertion into an encrypted tensor as the insertion of
  /// a whole crt slice. Each original index selects one position (size 1,
  /// stride 1); the trailing crt dimension is written in full (size `nMods`).
  ///
  /// Fails without touching the IR when the destination's element type is
  /// neither of the two types this pattern handles.
  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::InsertOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Bail out before creating any IR. Reporting success without rewriting
    // the root op breaks the rewrite drivers' contract.
    if (!op.getDest()
             .getType()
             .cast()
             .getElementType()
             .isa() &&
        !op.getDest()
             .getType()
             .cast()
             .getElementType()
             .isa()) {
      return mlir::failure();
    }

    typing::TypeConverter converter{loweringParameters};
    mlir::SmallVector offsets;
    mlir::SmallVector sizes;
    mlir::SmallVector strides;
    for (auto index : op.getIndices()) {
      offsets.push_back(index);
      sizes.push_back(rewriter.getI64IntegerAttr(1));
      strides.push_back(rewriter.getI64IntegerAttr(1));
    }
    // Trailing crt dimension: the whole range [0, nMods).
    offsets.push_back(
        rewriter.create(op.getLoc(), 0)
            .getResult());
    sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods));
    strides.push_back(rewriter.getI64IntegerAttr(1));

    auto newOp = rewriter.create(
        op.getLoc(), op.getScalar(), op.getDest(), offsets, sizes, strides);
    concretelang::convertOperandAndResultTypes(rewriter, newOp,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {newOp});
    return mlir::success();
  }
};

/// Rewriter for the `tensor::from_elements` operation.
struct TensorFromElementsOpPattern : public CrtOpPattern { TensorFromElementsOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::FromElementsOp op, mlir::PatternRewriter &rewriter) const override { if (!op.getResult() .getType() .cast() .getElementType() .isa() && !op.getResult() .getType() .cast() .getElementType() .isa()) { return mlir::success(); } typing::TypeConverter converter{loweringParameters}; // Create dest tensor allocation op mlir::Value outputTensor = rewriter.create( op.getLoc(), converter.convertType(op.getResult().getType()) .cast(), mlir::ValueRange{}); // Create insert_slice ops to insert the different pieces. auto outputShape = outputTensor.getType().cast().getShape(); mlir::SmallVector offsets{ rewriter.getI64IntegerAttr(0)}; mlir::SmallVector sizes{rewriter.getI64IntegerAttr(1)}; mlir::SmallVector strides{ rewriter.getI64IntegerAttr(1)}; for (size_t dimIndex = 1; dimIndex < outputShape.size(); ++dimIndex) { sizes.push_back(rewriter.getI64IntegerAttr(outputShape[dimIndex])); strides.push_back(rewriter.getI64IntegerAttr(1)); offsets.push_back(rewriter.getI64IntegerAttr(0)); } for (size_t insertionIndex = 0; insertionIndex < op.getElements().size(); ++insertionIndex) { offsets[0] = rewriter.getI64IntegerAttr(insertionIndex); mlir::tensor::InsertSliceOp insertOp = rewriter.create( op.getLoc(), op.getElements()[insertionIndex], outputTensor, offsets, sizes, strides); concretelang::convertOperandAndResultTypes( rewriter, insertOp, converter.getConversionLambda()); outputTensor = insertOp.getResult(); } rewriter.replaceOp(op, {outputTensor}); return mlir::success(); } }; } // namespace lowering struct FHEToTFHECrtPass : public FHEToTFHECrtBase { FHEToTFHECrtPass(concretelang::CrtLoweringParameters params) : loweringParameters(params) {} void runOnOperation() override { auto op = 
this->getOperation(); mlir::ConversionTarget target(getContext()); typing::TypeConverter converter(loweringParameters); //------------------------------------------- Marking legal/illegal dialects target.addIllegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addDynamicallyLegalOp( [&](mlir::Operation *op) { return ( converter.isLegal(op->getOperandTypes()) && converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getRegion(0).front().getArgumentTypes())); }); target.addDynamicallyLegalOp( [&](mlir::Operation *op) { return (converter.isLegal(op->getOperandTypes()) && converter.isLegal(op->getResultTypes())); }); target.addDynamicallyLegalOp( [&](mlir::func::FuncOp funcOp) { return converter.isSignatureLegal(funcOp.getFunctionType()) && converter.isLegal(&funcOp.getBody()); }); target.addDynamicallyLegalOp( [&](mlir::func::ConstantOp op) { return FunctionConstantOpConversion::isLegal( op, converter); }); target.addLegalOp(); target.addLegalOp(); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::MakeReadyFutureOp>(target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::CreateAsyncTaskOp>(target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::WorkFunctionReturnOp>(target, converter); 
concretelang::addDynamicallyLegalTypeOp< concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); //---------------------------------------------------------- Adding patterns mlir::RewritePatternSet patterns(&getContext()); // Patterns for the `FHE` dialect operations patterns.add< // |_ `FHE::zero_eint` concretelang::GenericTypeAndOpConverterPattern, // |_ `FHE::zero_tensor` concretelang::GenericTypeAndOpConverterPattern>( &getContext(), converter); // |_ `FHE::add_eint_int` patterns.add(&getContext(), loweringParameters); // Patterns for the relics of the `FHELinalg` dialect operations. // |_ `linalg::generic` turned to nested `scf::for` patterns.add>( patterns.getContext(), converter); patterns.add>( patterns.getContext(), converter); patterns.add< RegionOpTypeConverterPattern>( &getContext(), converter); patterns.add(&getContext(), loweringParameters); patterns.add(&getContext(), loweringParameters); patterns.add>(patterns.getContext(), converter); patterns.add< concretelang::GenericTypeConverterPattern>( patterns.getContext(), converter); patterns.add>(patterns.getContext(), converter); patterns.add< concretelang::GenericTypeConverterPattern>( patterns.getContext(), converter); patterns.add>( &getContext(), converter); // Patterns for `func` dialect operations. mlir::populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); patterns .add>( patterns.getContext(), converter); patterns.add>( &getContext(), converter); // Pattern for the `tensor::from_element` op. patterns.add(patterns.getContext(), loweringParameters); // Patterns for the `RT` dialect operations. 
patterns .add, concretelang::GenericTypeConverterPattern, concretelang::GenericTypeConverterPattern< concretelang::RT::MakeReadyFutureOp>, concretelang::GenericTypeConverterPattern< concretelang::RT::AwaitFutureOp>, concretelang::GenericTypeConverterPattern< concretelang::RT::CreateAsyncTaskOp>, concretelang::GenericTypeConverterPattern< concretelang::RT::BuildReturnPtrPlaceholderOp>, concretelang::GenericTypeConverterPattern< concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, concretelang::GenericTypeConverterPattern< concretelang::RT::DerefReturnPtrPlaceholderOp>, concretelang::GenericTypeConverterPattern< concretelang::RT::WorkFunctionReturnOp>, concretelang::GenericTypeConverterPattern< concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), converter); //--------------------------------------------------------- Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) .failed()) { this->signalPassFailure(); } } private: concretelang::CrtLoweringParameters loweringParameters; }; } // namespace fhe_to_tfhe_crt_conversion namespace mlir { namespace concretelang { std::unique_ptr> createConvertFHEToTFHECrtPass(CrtLoweringParameters lowering) { return std::make_unique( lowering); } } // namespace concretelang } // namespace mlir