// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" #include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/FHEToTFHECrt/Pass.h" #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Utils/Dialects/SCF.h" #include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" #include "concretelang/Dialect/RT/IR/RTDialect.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "concretelang/Dialect/RT/IR/RTTypes.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" namespace FHE = mlir::concretelang::FHE; namespace TFHE = mlir::concretelang::TFHE; namespace concretelang = mlir::concretelang; namespace fhe_to_tfhe_crt_conversion { namespace typing { /// Converts an encrypted integer into `TFHE::GlweCiphertext`. mlir::RankedTensorType convertEncrypted(mlir::MLIRContext *context, FHE::FheIntegerInterface enc, uint64_t crtLength) { return mlir::RankedTensorType::get( mlir::ArrayRef((int64_t)crtLength), TFHE::GLWECipherTextType::get(context, -1, -1, -1, enc.getWidth())); } /// Converts `Tensor` into a /// `Tensor` if the element type is appropriate. /// Otherwise return the input type. 
mlir::Type maybeConvertEncryptedTensor(mlir::MLIRContext *context, mlir::RankedTensorType maybeEncryptedTensor, uint64_t crtLength) { if (!maybeEncryptedTensor.getElementType().isa()) { return (mlir::Type)(maybeEncryptedTensor); } auto encType = maybeEncryptedTensor.getElementType().cast(); auto currentShape = maybeEncryptedTensor.getShape(); mlir::SmallVector newShape = mlir::SmallVector(currentShape.begin(), currentShape.end()); newShape.push_back((int64_t)crtLength); return mlir::RankedTensorType::get( llvm::ArrayRef(newShape), TFHE::GLWECipherTextType::get(context, -1, -1, -1, encType.getWidth())); } /// Converts any encrypted type to `TFHE::GlweCiphetext` if the /// input type is appropriate. Otherwise return the input type. mlir::Type maybeConvertEncrypted(mlir::MLIRContext *context, mlir::Type t, uint64_t crtLength) { if (auto eint = t.dyn_cast()) return convertEncrypted(context, eint, crtLength); return t; } /// The type converter used to convert `FHE` to `TFHE` types using the crt /// strategy. class TypeConverter : public mlir::TypeConverter { public: TypeConverter(concretelang::CrtLoweringParameters loweringParameters) { size_t nMods = loweringParameters.nMods; addConversion([](mlir::Type type) { return type; }); addConversion([=](FHE::FheIntegerInterface type) { return convertEncrypted(type.getContext(), type, nMods); }); addConversion([=](mlir::RankedTensorType type) { return maybeConvertEncryptedTensor(type.getContext(), type, nMods); }); addConversion([&](concretelang::RT::FutureType type) { return concretelang::RT::FutureType::get(this->convertType( type.dyn_cast().getElementType())); }); addConversion([&](concretelang::RT::PointerType type) { return concretelang::RT::PointerType::get(this->convertType( type.dyn_cast().getElementType())); }); } /// Returns a lambda that uses this converter to turn one type into another. 
std::function getConversionLambda() { return [&](mlir::MLIRContext *, mlir::Type t) { return convertType(t); }; } }; } // namespace typing namespace lowering { /// A pattern rewriter superclass used by most op rewriters during the /// conversion. template struct CrtOpPattern : public mlir::OpConversionPattern { /// The lowering parameters are bound to the op rewriter. concretelang::CrtLoweringParameters loweringParameters; CrtOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : mlir::OpConversionPattern(typeConverter, context, benefit), loweringParameters(params), typeConverter(params) {} /// Writes an `scf::for` that loops over the crt dimension of one tensor and /// execute the input lambda to write the loop body. Returns the first result /// of the op. /// /// Note: /// ----- /// /// + The type of `firstArgTensor` type is used as output type. mlir::Value writeUnaryTensorLoop( mlir::Location location, mlir::Type returnType, mlir::PatternRewriter &rewriter, mlir::function_ref body) const { mlir::Value tensor = rewriter.create( location, returnType.cast(), mlir::ValueRange{}); // Create the loop mlir::arith::ConstantOp zeroConstantOp = rewriter.create(location, 0); mlir::arith::ConstantOp oneConstantOp = rewriter.create(location, 1); mlir::arith::ConstantOp crtSizeConstantOp = rewriter.create(location, loweringParameters.nMods); mlir::scf::ForOp newOp = rewriter.create( location, zeroConstantOp, crtSizeConstantOp, oneConstantOp, tensor, body); return newOp.getResult(0); } /// Writes the crt encoding of a plaintext of arbitrary precision. 
mlir::Value writePlaintextCrtEncoding(mlir::Location location, mlir::Value rawPlaintext, mlir::PatternRewriter &rewriter) const { mlir::Value castedPlaintext = rewriter.create( location, rewriter.getI64Type(), rawPlaintext); return rewriter.create( location, mlir::RankedTensorType::get( mlir::ArrayRef(loweringParameters.nMods), rewriter.getI64Type()), castedPlaintext, rewriter.getI64ArrayAttr(loweringParameters.mods), rewriter.getI64IntegerAttr(loweringParameters.modsProd)); } protected: typing::TypeConverter typeConverter; }; /// Rewriter for the `FHE::add_eint_int` operation. struct AddEintIntOpPattern : public CrtOpPattern { AddEintIntOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::AddEintIntOp op, FHE::AddEintIntOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); mlir::Value eintOperand = adaptor.a(); mlir::Value intOperand = adaptor.b(); // Write plaintext encoding mlir::Value encodedPlaintextTensor = writePlaintextCrtEncoding(op.getLoc(), intOperand, rewriter); // Write add loop. mlir::Type ciphertextScalarType = converter->convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedEint = builder.create(loc, eintOperand, iter); mlir::Value extractedInt = builder.create( loc, encodedPlaintextTensor, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedEint, extractedInt); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. 
rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::sub_int_eint` operation. struct SubIntEintOpPattern : public CrtOpPattern { SubIntEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::SubIntEintOp op, FHE::SubIntEintOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); mlir::Value intOperand = adaptor.a(); mlir::Value eintOperand = adaptor.b(); // Write plaintext encoding mlir::Value encodedPlaintextTensor = writePlaintextCrtEncoding(op.getLoc(), intOperand, rewriter); // Write add loop. mlir::Type ciphertextScalarType = converter->convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedEint = builder.create(loc, eintOperand, iter); mlir::Value extractedInt = builder.create( loc, encodedPlaintextTensor, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedInt, extractedEint); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::sub_eint_int` operation. 
struct SubEintIntOpPattern : public CrtOpPattern {
  SubEintIntOpPattern(mlir::MLIRContext *context,
                      concretelang::CrtLoweringParameters params,
                      mlir::PatternBenefit benefit = 1)
      : CrtOpPattern(context, params, benefit) {}

  /// Rewrites `a - b` (encrypted `a`, clear `b`) as `a + (-b)`.
  ///  1. `b` is multiplied by a constant -1 of its own integer type.
  ///  2. The result is CRT-encoded.
  ///  3. Each CRT block of `a` gets the matching encoded block added to it.
  ::mlir::LogicalResult
  matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::TypeConverter *converter = this->getTypeConverter();
    mlir::Location location = op.getLoc();
    mlir::Value eintOperand = adaptor.a();
    mlir::Value intOperand = adaptor.b();

    // Write plaintext negation: -b is built as b * (-1) in b's own type.
    // NOTE(review): the negation happens in b's (possibly narrow) integer
    // type, and the result is only widened to i64 afterwards inside
    // writePlaintextCrtEncoding. Confirm that this widening (sign vs zero
    // extension) and the CRT encoding agree on how a negative value is
    // represented.
    mlir::Type intType = intOperand.getType();
    mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(intType, -1);
    mlir::Value minusOne =
        rewriter.create(location, minusOneAttr)
            .getResult();
    mlir::Value negative =
        rewriter.create(location, intOperand, minusOne)
            .getResult();

    // Write plaintext encoding
    mlir::Value encodedPlaintextTensor =
        writePlaintextCrtEncoding(op.getLoc(), negative, rewriter);

    // Write add loop: adding the encoded -b to every block implements a - b.
    mlir::Type ciphertextScalarType =
        converter->convertType(eintOperand.getType())
            .cast()
            .getElementType();
    mlir::Value output = writeUnaryTensorLoop(
        location, eintOperand.getType(), rewriter,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
            mlir::ValueRange args) {
          // CRT block `iter` of the ciphertext and of the encoded -b.
          mlir::Value extractedEint =
              builder.create(loc, eintOperand, iter);
          mlir::Value extractedInt = builder.create(
              loc, encodedPlaintextTensor, iter);
          mlir::Value output = builder.create(
              loc, ciphertextScalarType, extractedEint, extractedInt);
          // Write the block back into the loop-carried tensor and yield it.
          mlir::Value newTensor = builder.create(
              loc, output, args[0], iter);
          builder.create(loc, mlir::ValueRange{newTensor});
        });

    // Rewrite original op.
    rewriter.replaceOp(op, output);
    return mlir::success();
  }
};

/// Rewriter for the `FHE::add_eint` operation: the two ciphertexts are added
/// CRT block by CRT block.
struct AddEintOpPattern : CrtOpPattern { AddEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); mlir::Value lhsOperand = adaptor.a(); mlir::Value rhsOperand = adaptor.b(); // Write add loop. mlir::Type ciphertextScalarType = converter->convertType(lhsOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, lhsOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedLhs = builder.create(loc, lhsOperand, iter); mlir::Value extractedRhs = builder.create(loc, rhsOperand, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedLhs, extractedRhs); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::sub_eint` operation. struct SubEintOpPattern : CrtOpPattern { SubEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); mlir::Value lhsOperand = adaptor.a(); mlir::Value rhsOperand = adaptor.b(); // Write sub loop. 
mlir::Type ciphertextScalarType = converter->convertType(lhsOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( location, lhsOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedLhs = builder.create(loc, lhsOperand, iter); mlir::Value extractedRhs = builder.create(loc, rhsOperand, iter); mlir::Value negatedRhs = builder.create( loc, ciphertextScalarType, extractedRhs); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedLhs, negatedRhs); mlir::Value newTensor = builder.create( loc, output, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, output); return mlir::success(); } }; /// Rewriter for the `FHE::neg_eint` operation. struct NegEintOpPattern : CrtOpPattern { NegEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::NegEintOp op, FHE::NegEintOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); mlir::Value operand = adaptor.a(); // Write the loop nest. mlir::Type ciphertextScalarType = converter->convertType(operand.getType()) .cast() .getElementType(); mlir::Value loopRes = writeUnaryTensorLoop( location, operand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedCiphertext = builder.create(loc, operand, iter); mlir::Value negatedCiphertext = builder.create( loc, ciphertextScalarType, extractedCiphertext); mlir::Value newTensor = builder.create( loc, negatedCiphertext, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. 
rewriter.replaceOp(op, loopRes); return mlir::success(); } }; /// Rewriter for the `FHE::to_signed` operation. struct ToSignedOpPattern : public CrtOpPattern { ToSignedOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} mlir::LogicalResult matchAndRewrite(FHE::ToSignedOp op, FHE::ToSignedOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { typing::TypeConverter converter{loweringParameters}; rewriter.replaceOp(op, {adaptor.input()}); return mlir::success(); } }; /// Rewriter for the `FHE::to_unsigned` operation. struct ToUnsignedOpPattern : public CrtOpPattern { ToUnsignedOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} mlir::LogicalResult matchAndRewrite(FHE::ToUnsignedOp op, FHE::ToUnsignedOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { typing::TypeConverter converter{loweringParameters}; rewriter.replaceOp(op, {adaptor.input()}); return mlir::success(); } }; /// Rewriter for the `FHE::mul_eint_int` operation. struct MulEintIntOpPattern : CrtOpPattern { MulEintIntOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::MulEintIntOp op, FHE::MulEintIntOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); mlir::Value eintOperand = adaptor.a(); mlir::Value intOperand = adaptor.b(); // Write cleartext "encoding" mlir::Value encodedCleartext = rewriter.create( location, rewriter.getI64Type(), intOperand); // Write the loop nest. 
mlir::Type ciphertextScalarType = converter->convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value loopRes = writeUnaryTensorLoop( location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedCiphertext = builder.create(loc, eintOperand, iter); mlir::Value negatedCiphertext = builder.create( loc, ciphertextScalarType, extractedCiphertext, encodedCleartext); mlir::Value newTensor = builder.create( loc, negatedCiphertext, args[0], iter); builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. rewriter.replaceOp(op, loopRes); return mlir::success(); } }; /// Rewriter for the `FHE::apply_lookup_table` operation. struct ApplyLookupTableEintOpPattern : public CrtOpPattern { ApplyLookupTableEintOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(FHE::ApplyLookupTableEintOp op, FHE::ApplyLookupTableEintOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); auto originalInputType = op.a().getType().cast(); mlir::Value newLut = rewriter .create( op.getLoc(), mlir::RankedTensorType::get( mlir::ArrayRef(loweringParameters.lutSize), rewriter.getI64Type()), adaptor.lut(), rewriter.getI64ArrayAttr( mlir::ArrayRef(loweringParameters.mods)), rewriter.getI64ArrayAttr( mlir::ArrayRef(loweringParameters.bits)), rewriter.getI32IntegerAttr(loweringParameters.polynomialSize), rewriter.getI32IntegerAttr(loweringParameters.modsProd), rewriter.getBoolAttr(originalInputType.isSigned())) .getResult(); // Replace the lut with an encoded / expanded one. 
auto wopPBS = rewriter.create( op.getLoc(), converter->convertType(op.getType()), adaptor.a(), newLut, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, rewriter.getI64ArrayAttr({})); rewriter.replaceOp(op, {wopPBS.getResult()}); return ::mlir::success(); }; }; /// Rewriter for the `tensor::extract` operation. struct TensorExtractOpPattern : public CrtOpPattern { TensorExtractOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::ExtractOp op, mlir::tensor::ExtractOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); if (!op.getTensor() .getType() .cast() .getElementType() .isa() && !op.getTensor() .getType() .cast() .getElementType() .isa()) { return mlir::success(); } mlir::SmallVector offsets; mlir::SmallVector sizes; mlir::SmallVector strides; for (auto index : op.getIndices()) { offsets.push_back(index); sizes.push_back(rewriter.getI64IntegerAttr(1)); strides.push_back(rewriter.getI64IntegerAttr(1)); } offsets.push_back( rewriter.create(op.getLoc(), 0) .getResult()); sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods)); strides.push_back(rewriter.getI64IntegerAttr(1)); auto newOp = rewriter.create( op.getLoc(), converter->convertType(op.getResult().getType()) .cast(), adaptor.getTensor(), offsets, sizes, strides); rewriter.replaceOp(op, {newOp.getResult()}); return mlir::success(); } }; /// Rewriter for the `tensor::extract` operation. 
struct TensorInsertOpPattern : public CrtOpPattern {
  TensorInsertOpPattern(mlir::MLIRContext *context,
                        concretelang::CrtLoweringParameters params,
                        mlir::PatternBenefit benefit = 1)
      : CrtOpPattern(context, params, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::InsertOp op,
                  mlir::tensor::InsertOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Only insertions into tensors of encrypted integers are rewritten. The
    // two element-type checks presumably test the unsigned and signed
    // encrypted integer types. A conversion pattern that leaves the op
    // untouched must report failure, not success.
    if (!op.getDest()
             .getType()
             .cast()
             .getElementType()
             .isa() &&
        !op.getDest()
             .getType()
             .cast()
             .getElementType()
             .isa()) {
      return mlir::failure();
    }

    mlir::SmallVector offsets;
    mlir::SmallVector sizes;
    mlir::SmallVector strides;
    // Unit-size slice at each original index (remapped operands)...
    for (auto index : adaptor.getIndices()) {
      offsets.push_back(index);
      sizes.push_back(rewriter.getI64IntegerAttr(1));
      strides.push_back(rewriter.getI64IntegerAttr(1));
    }
    // ...plus the whole trailing CRT dimension.
    offsets.push_back(
        rewriter.create(op.getLoc(), 0)
            .getResult());
    sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods));
    strides.push_back(rewriter.getI64IntegerAttr(1));
    auto newOp = rewriter.create(
        op.getLoc(), adaptor.getScalar(), adaptor.getDest(), offsets, sizes,
        strides);
    rewriter.replaceOp(op, {newOp});
    return mlir::success();
  }
};

/// Rewriter for the `tensor::from_elements` operation on encrypted integers.
/// Each element (a tensor of `nMods` CRT blocks) is inserted with an
/// `insert_slice` into a new tensor of the converted shape.
struct TensorFromElementsOpPattern : public CrtOpPattern { TensorFromElementsOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::FromElementsOp op, mlir::tensor::FromElementsOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); if (!op.getResult() .getType() .cast() .getElementType() .isa() && !op.getResult() .getType() .cast() .getElementType() .isa()) { return mlir::success(); } // Create dest tensor allocation op mlir::Value outputTensor = rewriter.create( op.getLoc(), converter->convertType(op.getResult().getType()) .cast(), mlir::ValueRange{}); // Create insert_slice ops to insert the different pieces. auto oldOutputType = outputTensor.getType(); auto newOutputType = this->getTypeConverter()->convertType(oldOutputType); auto newOutputShape = newOutputType.cast().getShape(); mlir::SmallVector sizes(newOutputShape.size(), rewriter.getI64IntegerAttr(1)); sizes[sizes.size() - 1] = rewriter.getI64IntegerAttr(loweringParameters.nMods); mlir::SmallVector strides( newOutputShape.size(), rewriter.getI64IntegerAttr(1)); auto offsetGenerator = [&](size_t index) { mlir::SmallVector offsets( newOutputShape.size(), rewriter.getI64IntegerAttr(0)); size_t remainder = index * 5; for (int rankIndex = newOutputShape.size() - 1; rankIndex >= 0; --rankIndex) { offsets[rankIndex] = rewriter.getI64IntegerAttr(remainder % newOutputShape[rankIndex]); remainder = remainder / newOutputShape[rankIndex]; } return offsets; }; for (size_t insertionIndex = 0; insertionIndex < adaptor.getElements().size(); ++insertionIndex) { mlir::tensor::InsertSliceOp insertOp = rewriter.create( op.getLoc(), adaptor.getElements()[insertionIndex], outputTensor, offsetGenerator(insertionIndex), sizes, strides); outputTensor = insertOp.getResult(); } 
rewriter.replaceOp(op, {outputTensor}); return mlir::success(); } }; // Generic template for tensor operations that have reassociation map // attributes. template struct TensorReassociationOpPattern : public CrtOpPattern { TensorReassociationOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); auto reassocVal = (inRank ? adaptor.src() : op.result()); auto reassocTy = reassocVal.getType(); auto newReassocType = converter->convertType(reassocTy); mlir::SmallVector oldReassocs = op.getReassociationIndices(); mlir::SmallVector newReassocs{oldReassocs}; mlir::ReassociationIndices newReassocEnd; newReassocEnd.push_back( newReassocType.template cast().getRank() - 1); newReassocs.push_back(newReassocEnd); auto newOp = rewriter.create( op.getLoc(), converter->convertType(op.getResult().getType()), adaptor.src(), newReassocs); rewriter.replaceOp(op, {newOp}); return mlir::success(); }; }; struct ExtractSliceOpPattern : public CrtOpPattern { ExtractSliceOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp op, mlir::tensor::ExtractSliceOp::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); mlir::SmallVector newStaticOffsets{ op.static_offsets().template getAsRange()}; mlir::SmallVector newStaticSizes{ op.static_sizes().template getAsRange()}; mlir::SmallVector newStaticStrides{ op.static_strides().template getAsRange()}; newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0)); newStaticSizes.push_back( 
rewriter.getI64IntegerAttr(this->loweringParameters.nMods)); newStaticStrides.push_back(rewriter.getI64IntegerAttr(1)); mlir::RankedTensorType newType = converter->convertType(op.getResult().getType()) .template cast(); rewriter.replaceOpWithNewOp( op, newType, adaptor.source(), adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(), rewriter.getArrayAttr(newStaticOffsets), rewriter.getArrayAttr(newStaticSizes), rewriter.getArrayAttr(newStaticStrides)); return mlir::success(); }; }; struct InsertSliceOpPattern : public CrtOpPattern { InsertSliceOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::tensor::InsertSliceOp op, mlir::tensor::InsertSliceOp::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); mlir::SmallVector newStaticOffsets{ op.static_offsets().template getAsRange()}; mlir::SmallVector newStaticSizes{ op.static_sizes().template getAsRange()}; mlir::SmallVector newStaticStrides{ op.static_strides().template getAsRange()}; newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0)); newStaticSizes.push_back( rewriter.getI64IntegerAttr(this->loweringParameters.nMods)); newStaticStrides.push_back(rewriter.getI64IntegerAttr(1)); mlir::RankedTensorType newType = converter->convertType(op.getResult().getType()) .template cast(); rewriter.replaceOpWithNewOp( op, newType, adaptor.source(), adaptor.dest(), adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(), rewriter.getArrayAttr(newStaticOffsets), rewriter.getArrayAttr(newStaticSizes), rewriter.getArrayAttr(newStaticStrides)); return mlir::success(); }; }; } // namespace lowering struct FHEToTFHECrtPass : public FHEToTFHECrtBase { FHEToTFHECrtPass(concretelang::CrtLoweringParameters params) : loweringParameters(params) {} void runOnOperation() override { auto op 
= this->getOperation(); mlir::ConversionTarget target(getContext()); typing::TypeConverter converter(loweringParameters); //------------------------------------------- Marking legal/illegal dialects target.addIllegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addDynamicallyLegalOp( [&](mlir::Operation *op) { return ( converter.isLegal(op->getOperandTypes()) && converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getRegion(0).front().getArgumentTypes())); }); target.addDynamicallyLegalOp( [&](mlir::Operation *op) { return (converter.isLegal(op->getOperandTypes()) && converter.isLegal(op->getResultTypes())); }); target.addDynamicallyLegalOp( [&](mlir::func::FuncOp funcOp) { return converter.isSignatureLegal(funcOp.getFunctionType()) && converter.isLegal(&funcOp.getBody()); }); target.addDynamicallyLegalOp( [&](mlir::func::ConstantOp op) { return FunctionConstantOpConversion::isLegal( op, converter); }); target.addLegalOp(); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp(target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::MakeReadyFutureOp>(target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::CreateAsyncTaskOp>(target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter); 
concretelang::addDynamicallyLegalTypeOp< concretelang::RT::WorkFunctionReturnOp>(target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); //---------------------------------------------------------- Adding patterns mlir::RewritePatternSet patterns(&getContext()); // Patterns for `bufferization` dialect operations. patterns.add>(patterns.getContext(), converter); // Patterns for the `FHE` dialect operations patterns.add< // |_ `FHE::zero_eint` concretelang::GenericOneToOneOpConversionPattern, // |_ `FHE::zero_tensor` concretelang::GenericOneToOneOpConversionPattern< FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>>(&getContext(), converter); // |_ `FHE::add_eint_int` patterns.add(&getContext(), loweringParameters); // Patterns for the relics of the `FHELinalg` dialect operations. // |_ `linalg::generic` turned to nested `scf::for` patterns.add< concretelang::TypeConvertingReinstantiationPattern>( patterns.getContext(), converter); patterns.add< concretelang::TypeConvertingReinstantiationPattern>( patterns.getContext(), converter); patterns.add< concretelang::TypeConvertingReinstantiationPattern>( &getContext(), converter); patterns.add(&getContext(), loweringParameters); patterns.add(&getContext(), loweringParameters); patterns.add>(patterns.getContext(), loweringParameters); patterns.add>(patterns.getContext(), loweringParameters); patterns.add(patterns.getContext(), loweringParameters); patterns.add(patterns.getContext(), loweringParameters); patterns.add>(&getContext(), converter); // Patterns for `func` dialect operations. mlir::populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); patterns.add>(patterns.getContext(), converter); patterns.add>( &getContext(), converter); // Pattern for the `tensor::from_element` op. patterns.add(patterns.getContext(), loweringParameters); // Patterns for the `RT` dialect operations. 
patterns.add< // concretelang::TypeConvertingReinstantiationPattern< // mlir::func::ReturnOp>, concretelang::TypeConvertingReinstantiationPattern, concretelang::TypeConvertingReinstantiationPattern< concretelang::RT::MakeReadyFutureOp>, concretelang::TypeConvertingReinstantiationPattern< concretelang::RT::AwaitFutureOp>, concretelang::TypeConvertingReinstantiationPattern< concretelang::RT::CreateAsyncTaskOp, true>, concretelang::TypeConvertingReinstantiationPattern< concretelang::RT::BuildReturnPtrPlaceholderOp>, concretelang::TypeConvertingReinstantiationPattern< concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, concretelang::TypeConvertingReinstantiationPattern< concretelang::RT::DerefReturnPtrPlaceholderOp>, concretelang::TypeConvertingReinstantiationPattern< concretelang::RT::WorkFunctionReturnOp>, concretelang::TypeConvertingReinstantiationPattern< concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), converter); //--------------------------------------------------------- Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) .failed()) { this->signalPassFailure(); } } private: concretelang::CrtLoweringParameters loweringParameters; }; } // namespace fhe_to_tfhe_crt_conversion namespace mlir { namespace concretelang { std::unique_ptr> createConvertFHEToTFHECrtPass(CrtLoweringParameters lowering) { return std::make_unique( lowering); } } // namespace concretelang } // namespace mlir