// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

// Lowering pass from the TFHE dialect to the Concrete dialect: a type
// converter (TFHE GLWE ciphertexts -> Concrete LWE ciphertexts, also inside
// ranked tensors and RT futures/pointers), two hand-written rewrite patterns
// for the bootstrap / WoP-PBS ops, and the pass driver that wires them into an
// MLIR partial dialect conversion.
//
// NOTE(review): in this copy of the file most angle-bracketed template
// argument lists appear to have been stripped (the first `#include` has no
// header name; `TFHEToConcreteBase`, `mlir::OpRewritePattern`,
// `dyn_cast_or_null()`, `replaceOpWithNewOp(`, `patterns.add>(`,
// `std::unique_ptr>` ... have lost their type arguments). As shown, the file
// does not compile; restore them from the upstream source before building.
// Comments below that name a stripped type say "presumably" for that reason.

#include
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/TFHEToConcrete/Patterns.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"

namespace TFHE = mlir::concretelang::TFHE;
namespace Concrete = mlir::concretelang::Concrete;

namespace {
/// The TFHE -> Concrete conversion pass; the conversion itself is implemented
/// in runOnOperation() further down.
/// NOTE(review): the base-class template argument is missing in this copy
/// (presumably `TFHEToConcreteBase<TFHEToConcretePass>`).
struct TFHEToConcretePass : public TFHEToConcreteBase {
  void runOnOperation() final;
};
} // namespace

// NOTE(review): LweCiphertextType is not referenced anywhere else in this copy
// of the file.
using mlir::concretelang::Concrete::LweCiphertextType;
using mlir::concretelang::TFHE::GLWECipherTextType;

/// TFHEToConcreteTypeConverter is a TypeConverter that transforms
/// `TFHE.glwe<{_,_,_}{p}>` to Concrete.lwe_ciphertext
///
/// Conversions registered by the constructor:
///  - GLWECipherTextType: lowered with mlir::concretelang::convertTypeToLWE;
///  - RankedTensorType: if the element type is a ciphertext, a tensor of the
///    same shape with the lowered element type; otherwise unchanged;
///  - RT::FutureType / RT::PointerType: re-wrapped around the recursively
///    converted element type;
///  - any other type: returned unchanged (identity fallback).
class TFHEToConcreteTypeConverter : public mlir::TypeConverter {

public:
  TFHEToConcreteTypeConverter() {
    // Identity fallback. MLIR tries the most recently added conversion first,
    // so registering this one first makes it the catch-all for every type the
    // conversions below do not handle.
    addConversion([](mlir::Type type) { return type; });
    addConversion([&](GLWECipherTextType type) {
      return mlir::concretelang::convertTypeToLWE(type.getContext(), type);
    });
    addConversion([&](mlir::RankedTensorType type) {
      // NOTE(review): the dyn_cast_or_null target type is missing in this
      // copy; presumably GLWECipherTextType, given the variable name and the
      // convertTypeToLWE call below.
      auto glwe = type.getElementType().dyn_cast_or_null();
      if (glwe == nullptr) {
        // Not a tensor of ciphertexts: keep the tensor type as is.
        return (mlir::Type)(type);
      }
      // Same shape, element type lowered to its Concrete counterpart.
      mlir::Type r = mlir::RankedTensorType::get(
          type.getShape(),
          mlir::concretelang::convertTypeToLWE(glwe.getContext(), glwe));
      return r;
    });
    // Futures and pointers wrap an element type: convert that element type
    // with this same converter and re-wrap it in the same RT type.
    // NOTE(review): `type` is already declared as the RT type, so the
    // dyn_cast (presumably to that same type) looks redundant;
    // `type.getElementType()` should be equivalent.
    addConversion([&](mlir::concretelang::RT::FutureType type) {
      return mlir::concretelang::RT::FutureType::get(
          this->convertType(type.dyn_cast()
                                .getElementType()));
    });
    addConversion([&](mlir::concretelang::RT::PointerType type) {
      return mlir::concretelang::RT::PointerType::get(
          this->convertType(type.dyn_cast()
                                .getElementType()));
    });
  }
};

namespace {

/// Rewrites TFHE::BootstrapGLWEOp into a Concrete bootstrap op (the target op
/// type is missing from `replaceOpWithNewOp(` in this copy; presumably
/// Concrete::BootstrapLweOp, given the `input_ciphertext()` accessor and the
/// BootstrapLweOp legality rule registered in runOnOperation).
///
/// The result type is converted with `converter`; the ciphertext, lookup
/// table, level, baseLog, polySize and glweDimension operands/attributes are
/// forwarded unchanged, after which the forwarded ciphertext's type is
/// patched to its converted type.
struct BootstrapGLWEOpPattern
    : public mlir::OpRewritePattern {
  /// \param context   MLIR context the pattern is registered in.
  /// \param converter type converter for the result and input types; held by
  ///                  reference, so it must outlive the pattern.
  /// \param benefit   pattern benefit (100 by default).
  BootstrapGLWEOpPattern(mlir::MLIRContext *context,
                         mlir::TypeConverter &converter,
                         mlir::PatternBenefit benefit = 100)
      : mlir::OpRewritePattern(context, benefit),
        converter(converter) {}

  mlir::LogicalResult
  matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Type resultType = converter.convertType(bsOp.getType());

    auto newOp = rewriter.replaceOpWithNewOp(
        bsOp, resultType, bsOp.ciphertext(), bsOp.lookup_table(), bsOp.level(),
        bsOp.baseLog(), bsOp.polySize(), bsOp.glweDimension());

    // The ciphertext operand was forwarded with its original TFHE type;
    // overwrite it with the converted type.
    // NOTE(review): setType() mutates the SSA value itself (defined by another
    // op or a block argument), so every other user of that value sees the new
    // type as well, bypassing the conversion framework's materializations.
    // Also, `bsOp` is read after replaceOpWithNewOp(bsOp, ...); this relies on
    // the rewriter deferring erasure of `bsOp` (the dialect-conversion driver
    // used by this pass does; a greedy driver would not). Confirm intended.
    rewriter.startRootUpdate(newOp);
    newOp.input_ciphertext().setType(
        converter.convertType(bsOp.ciphertext().getType()));
    rewriter.finalizeRootUpdate(newOp);

    return ::mlir::success();
  }

private:
  mlir::TypeConverter &converter;
};

/// Rewrites TFHE::WopPBSGLWEOp into a Concrete WoP-PBS op (the target op type
/// is missing from `replaceOpWithNewOp(` in this copy).
///
/// The result type is converted with `converter`; the ciphertext, lookup table
/// and all bootstrap, keyswitch, packing-keyswitch and circuit-bootstrap
/// parameters are forwarded unchanged, after which the forwarded ciphertext's
/// type is patched to its converted type.
struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern {
  /// \param context   MLIR context the pattern is registered in.
  /// \param converter type converter for the result and input types; held by
  ///                  reference, so it must outlive the pattern.
  /// \param benefit   pattern benefit (100 by default).
  WopPBSGLWEOpPattern(mlir::MLIRContext *context,
                      mlir::TypeConverter &converter,
                      mlir::PatternBenefit benefit = 100)
      : mlir::OpRewritePattern(context, benefit),
        converter(converter) {}

  mlir::LogicalResult
  matchAndRewrite(TFHE::WopPBSGLWEOp wopOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Type resultType = converter.convertType(wopOp.getType());

    auto newOp = rewriter.replaceOpWithNewOp(
        wopOp, resultType, wopOp.ciphertext(), wopOp.lookupTable(),
        // Bootstrap parameters
        wopOp.bootstrapLevel(), wopOp.bootstrapBaseLog(),
        // Keyswitch parameters
        wopOp.keyswitchLevel(), wopOp.keyswitchBaseLog(),
        // Packing keyswitch key parameters
        wopOp.packingKeySwitchInputLweDimension(),
        wopOp.packingKeySwitchoutputPolynomialSize(),
        wopOp.packingKeySwitchLevel(), wopOp.packingKeySwitchBaseLog(),
        // Circuit bootstrap parameters
        wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog());

    // Same in-place operand type patch as in BootstrapGLWEOpPattern.
    // NOTE(review): setType() changes the type of the shared SSA value for all
    // of its users, and `wopOp` is read after being replaced (relies on
    // deferred erasure by the conversion driver). Confirm intended.
    rewriter.startRootUpdate(newOp);
    newOp.ciphertext().setType(
        converter.convertType(wopOp.ciphertext().getType()));
    rewriter.finalizeRootUpdate(newOp);

    return ::mlir::success();
  }

private:
  mlir::TypeConverter &converter;
};

/// Runs the TFHE -> Concrete partial conversion on the pass's root operation:
/// builds the conversion target (legality rules), registers all rewrite
/// patterns, and signals pass failure if the conversion fails.
void TFHEToConcretePass::runOnOperation() {
  auto op = this->getOperation();

  mlir::ConversionTarget target(getContext());
  TFHEToConcreteTypeConverter converter;

  // Mark ops from the target dialect as legal operations
  target.addLegalDialect();

  // Make sure that no ops from `TFHE` remain after the lowering
  target.addIllegalDialect();

  // Legalize arith.constant operations introduced by some patterns
  target.addLegalOp();

  // Make sure that no `linalg.generic` ops with illegal types remain: such an
  // op is legal only if its operand types, its result types and the argument
  // types of the first block of its first region are all legal for
  // `converter`. (The lambda parameter `op` shadows the outer `op`.)
  target.addDynamicallyLegalOp(
      [&](mlir::Operation *op) {
        return (
            converter.isLegal(op->getOperandTypes()) &&
            converter.isLegal(op->getResultTypes()) &&
            converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
      });

  // Make sure that each func has a legal signature and that the block
  // argument types of its body are legal
  target.addDynamicallyLegalOp(
      [&](mlir::func::FuncOp funcOp) {
        return converter.isSignatureLegal(funcOp.getFunctionType()) &&
               converter.isLegal(&funcOp.getBody());
      });
  // Legality of func.constant is delegated to FunctionConstantOpConversion.
  target.addDynamicallyLegalOp(
      [&](mlir::func::ConstantOp op) {
        return FunctionConstantOpConversion<
            TFHEToConcreteTypeConverter>::isLegal(op, converter);
      });

  // Add all patterns required to lower all ops from `TFHE` to
  // `Concrete`
  mlir::RewritePatternSet patterns(&getContext());

  // NOTE(review): the pattern types of most `patterns.add` calls below are
  // missing in this copy. The two plain `patterns.add(&getContext(),
  // converter)` calls presumably register BootstrapGLWEOpPattern and
  // WopPBSGLWEOpPattern, which are not registered anywhere else visible.
  patterns.add>(
      &getContext(), converter);
  populateWithGeneratedTFHEToConcrete(patterns);
  patterns.add>(&getContext(), converter);
  patterns.add(&getContext(), converter);
  patterns.add(&getContext(), converter);
  // A Concrete::BootstrapLweOp is legal only once its operand and result
  // types are legal for `converter`.
  target.addDynamicallyLegalOp(
      [&](Concrete::BootstrapLweOp op) {
        return (converter.isLegal(op->getOperandTypes()) &&
                converter.isLegal(op->getResultTypes()));
      });
  patterns.add>(&getContext(), converter);
  patterns.add>(
      &getContext(), converter);
  patterns.add<
      mlir::concretelang::GenericTypeConverterPattern>(
      patterns.getContext(), converter);
  patterns.add<
      mlir::concretelang::GenericTypeConverterPattern>(
      patterns.getContext(), converter);
  patterns.add>(
      &getContext(), converter);
  patterns.add>(
      &getContext(), converter);
  // Tensor-op type conversion; `target` is passed too, presumably so the
  // helper can register the matching legality rules.
  mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
                                                              converter);

  // Convert function signatures with `converter`.
  mlir::populateFunctionOpInterfaceTypeConversionPattern(
      patterns, converter);

  // Conversion of RT Dialect Ops
  patterns.add<
      mlir::concretelang::GenericTypeConverterPattern,
      mlir::concretelang::GenericTypeConverterPattern,
      mlir::concretelang::GenericTypeConverterPattern<
          mlir::concretelang::RT::MakeReadyFutureOp>,
      mlir::concretelang::GenericTypeConverterPattern<
          mlir::concretelang::RT::AwaitFutureOp>,
      mlir::concretelang::GenericTypeConverterPattern<
          mlir::concretelang::RT::CreateAsyncTaskOp>,
      mlir::concretelang::GenericTypeConverterPattern<
          mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
      mlir::concretelang::GenericTypeConverterPattern<
          mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
      mlir::concretelang::GenericTypeConverterPattern<
          mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
      mlir::concretelang::GenericTypeConverterPattern<
          mlir::concretelang::RT::WorkFunctionReturnOp>,
      mlir::concretelang::GenericTypeConverterPattern<
          mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
                                                                converter);
  // The RT ops above are legal only once their types are legal for
  // `converter`.
  mlir::concretelang::addDynamicallyLegalTypeOp<
      mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp<
      mlir::concretelang::RT::AwaitFutureOp>(target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp<
      mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp<
      mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp<
      mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
      target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp<
      mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp<
      mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp<
      mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp(
      target, converter);
  mlir::concretelang::addDynamicallyLegalTypeOp(
      target, converter);

  // Apply conversion. Partial conversion: ops not marked illegal may remain,
  // but the conversion fails if an illegal op cannot be legalized.
  if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
    this->signalPassFailure();
  }
}
} // namespace

namespace mlir {
namespace concretelang {
/// Creates a new instance of the TFHE -> Concrete conversion pass.
/// NOTE(review): the return type's and make_unique's template arguments are
/// missing in this copy.
std::unique_ptr> createConvertTFHEToConcretePass() {
  return std::make_unique();
}
} // namespace concretelang
} // namespace mlir