// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #include #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/TFHEToConcrete/Patterns.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" namespace { struct TFHEToConcretePass : public TFHEToConcreteBase { void runOnOperation() final; }; } // namespace using mlir::concretelang::Concrete::LweCiphertextType; using mlir::concretelang::TFHE::GLWECipherTextType; /// TFHEToConcreteTypeConverter is a TypeConverter that transform /// `TFHE.glwe<{_,_,_}{p}>` to Concrete.lwe_ciphertext class TFHEToConcreteTypeConverter : public mlir::TypeConverter { public: TFHEToConcreteTypeConverter() { addConversion([](mlir::Type type) { return type; }); addConversion([&](GLWECipherTextType type) { return mlir::concretelang::convertTypeToLWE(type.getContext(), type); }); addConversion([&](mlir::RankedTensorType type) { auto glwe = type.getElementType().dyn_cast_or_null(); if (glwe == nullptr) { return (mlir::Type)(type); } mlir::Type r = mlir::RankedTensorType::get( type.getShape(), mlir::concretelang::convertTypeToLWE(glwe.getContext(), glwe)); return r; }); } }; void TFHEToConcretePass::runOnOperation() { auto op = this->getOperation(); mlir::ConversionTarget target(getContext()); TFHEToConcreteTypeConverter converter; // Mark ops from the target dialect as legal operations target.addLegalDialect(); // Make sure that no ops from `TFHE` remain after the lowering target.addIllegalDialect(); // Make sure that no ops `linalg.generic` that have illegal types target.addDynamicallyLegalOp( [&](mlir::Operation *op) { return (converter.isLegal(op->getOperandTypes()) && converter.isLegal(op->getResultTypes()) && converter.isLegal(op->getRegion(0).front().getArgumentTypes())); }); // Make sure that func has legal signature target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { return converter.isSignatureLegal(funcOp.getType()) && converter.isLegal(&funcOp.getBody()); }); // Add all patterns required to lower all ops from `TFHE` to // `Concrete` mlir::OwningRewritePatternList patterns(&getContext()); populateWithGeneratedTFHEToConcrete(patterns); patterns.add>( &getContext(), converter); patterns.add>( &getContext(), converter); patterns.add>( &getContext(), converter); mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target, converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); // Conversion of RT Dialect Ops patterns.add>(patterns.getContext(), converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::DataflowTaskOp>(target, converter); // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { this->signalPassFailure(); } } namespace mlir { namespace concretelang { std::unique_ptr> createConvertTFHEToConcretePass() { return std::make_unique(); } } // namespace concretelang } // namespace mlir