// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

// TFHE key normalization pass.
//
// Rewrites parameterized GLWE secret keys into normalized keys that carry an
// index taken from the circuit's key table. It updates the key attributes of
// keyswitch / bootstrap / WoP-PBS operations the same way: their input/output
// keys are normalized and the attribute's own index is looked up in the key
// table. All ciphertext-carrying types in the IR (function signatures, ranked
// tensors, RT futures/pointers, generic TFHE / tensor / tracing ops) are
// converted to match.

#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include

#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
#include "concretelang/Conversion/Utils/RTOpConverter.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "concretelang/Support/Constants.h"
#include "concretelang/Support/TFHECircuitKeys.h"
#include
#include
#include

namespace TFHE = mlir::concretelang::TFHE;
using mlir::concretelang::TFHE::GLWECipherTextType;

namespace conversion {

/// Maps parameterized TFHE secret keys and key attributes to their normalized
/// counterparts, using the indices recorded in a `TFHECircuitKeys` table.
///
/// NOTE: the constructor takes the table by reference, but the `circuitKeys`
/// member is a value, so the table is copied into every KeyConverter.
class KeyConverter {

public:
  KeyConverter(mlir::concretelang::TFHE::TFHECircuitKeys &circuitKeys)
      : circuitKeys(circuitKeys){};

  /// Returns a normalized secret key with the same dimension and polynomial
  /// size as `sk`, whose index is `circuitKeys.getSecretKeyIndex(sk)`.
  ///
  /// Precondition: `sk` is parameterized and present in the key table; both
  /// `getParameterized()` and `getSecretKeyIndex()` are unwrapped with
  /// `.value()` without a prior check.
  TFHE::GLWESecretKey convertSecretKey(TFHE::GLWESecretKey sk) {
    auto parameterizedKey = sk.getParameterized().value();
    return TFHE::GLWESecretKey::newNormalized(
        parameterizedKey.dimension, parameterizedKey.polySize,
        circuitKeys.getSecretKeyIndex(sk).value());
  }

  /// Returns a copy of `bsk` with normalized input/output secret keys and
  /// the index recorded for `bsk` in the key table. Poly size, GLWE
  /// dimension, levels and base log are carried over unchanged.
  TFHE::GLWEBootstrapKeyAttr
  convertBootstrapKey(TFHE::GLWEBootstrapKeyAttr bsk) {
    return TFHE::GLWEBootstrapKeyAttr::get(
        bsk.getContext(), convertSecretKey(bsk.getInputKey()),
        convertSecretKey(bsk.getOutputKey()), bsk.getPolySize(),
        bsk.getGlweDim(), bsk.getLevels(), bsk.getBaseLog(),
        circuitKeys.getBootstrapKeyIndex(bsk).value());
  }

  /// Returns a copy of `ksk` with normalized input/output secret keys and
  /// the index recorded for `ksk` in the key table. Levels and base log are
  /// carried over unchanged.
  TFHE::GLWEKeyswitchKeyAttr
  convertKeyswitchKey(TFHE::GLWEKeyswitchKeyAttr ksk) {
    return TFHE::GLWEKeyswitchKeyAttr::get(
        ksk.getContext(), convertSecretKey(ksk.getInputKey()),
        convertSecretKey(ksk.getOutputKey()), ksk.getLevels(),
        ksk.getBaseLog(), circuitKeys.getKeyswitchKeyIndex(ksk).value());
  }

  /// Returns a copy of `pksk` with normalized input/output secret keys and
  /// the index recorded for `pksk` in the key table. Output poly size, input
  /// LWE dimension, GLWE dimension, levels and base log are carried over
  /// unchanged.
  TFHE::GLWEPackingKeyswitchKeyAttr
  convertPackingKeyswitchKey(TFHE::GLWEPackingKeyswitchKeyAttr pksk) {
    return TFHE::GLWEPackingKeyswitchKeyAttr::get(
        pksk.getContext(), convertSecretKey(pksk.getInputKey()),
        convertSecretKey(pksk.getOutputKey()), pksk.getOutputPolySize(),
        pksk.getInputLweDim(), pksk.getGlweDim(), pksk.getLevels(),
        pksk.getBaseLog(),
        circuitKeys.getPackingKeyswitchKeyIndex(pksk).value());
  }

private:
  mlir::concretelang::TFHE::TFHECircuitKeys circuitKeys;
};

/// Type converter that replaces GLWE ciphertext types over a parameterized
/// key with ciphertext types over the corresponding normalized key. It
/// recurses into ranked-tensor element types and RT future/pointer element
/// types; every other type is returned unchanged.
///
/// NOTE(review): the `[&]` lambdas below reach `keyConverter` and
/// `convertType` through the implicitly captured `this`. An instance should
/// therefore not be copied or moved after construction, because a copy's
/// conversions would still point at the original object. Also note that
/// `keyConverter` is held by value (another copy of the key table).
class TypeConverter : public mlir::TypeConverter {

public:
  TypeConverter(KeyConverter &keyConverter) : keyConverter(keyConverter) {
    // Fallback: keep any type as-is. mlir::TypeConverter tries the most
    // recently added conversion first, so this one is consulted last.
    addConversion([](mlir::Type type) { return type; });
    // Ciphertexts: only parameterized keys are rewritten; a ciphertext whose
    // key is not parameterized is returned unchanged.
    addConversion([&](GLWECipherTextType type) {
      auto key = type.getKey();
      if (key.isParameterized()) {
        return GLWECipherTextType::get(type.getContext(),
                                       keyConverter.convertSecretKey(key));
      } else {
        return type;
      }
    });
    // Ranked tensors: keep the shape, convert the element type.
    addConversion([&](mlir::RankedTensorType type) {
      mlir::Type r = mlir::RankedTensorType::get(
          type.getShape(), this->convertType(type.getElementType()));
      return r;
    });
    // RT futures: convert the wrapped element type.
    addConversion([&](mlir::concretelang::RT::FutureType type) {
      return mlir::concretelang::RT::FutureType::get(
          this->convertType(type.dyn_cast()
                                .getElementType()));
    });
    // RT pointers: convert the pointee element type.
    addConversion([&](mlir::concretelang::RT::PointerType type) {
      return mlir::concretelang::RT::PointerType::get(
          this->convertType(type.dyn_cast()
                                .getElementType()));
    });
  }

private:
  KeyConverter keyConverter;
};

} // namespace conversion

namespace patterns {

/// Rewrites a `TFHE::KeySwitchGLWEOp` so that its keyswitch key attribute is
/// normalized (via `KeyConverter::convertKeyswitchKey`). Its ciphertext
/// operand and result types are converted with the pass type converter.
///
/// The converters are held by reference and must outlive the pattern.
struct KeySwitchGLWEOpPattern
    : public mlir::OpRewritePattern {
  KeySwitchGLWEOpPattern(mlir::MLIRContext *context,
                         conversion::TypeConverter &typeConverter,
                         conversion::KeyConverter &keyConverter,
                         mlir::PatternBenefit benefit =
                             mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
      : mlir::OpRewritePattern(context, benefit),
        keyConverter(keyConverter), typeConverter(typeConverter) {}

  mlir::LogicalResult
  matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Converted operand/result types and the normalized key attribute.
    auto newInputTy = typeConverter.convertType(ksOp.getCiphertext().getType())
                          .cast();
    auto newOutputTy = typeConverter.convertType(ksOp.getResult().getType());
    auto newKeyswitchKey = keyConverter.convertKeyswitchKey(ksOp.getKeyAttr());
    // Rebuild the op with the new result type and key, still consuming the
    // original ciphertext value.
    auto newOp = rewriter.replaceOpWithNewOp(
        ksOp, newOutputTy, ksOp.getCiphertext(), newKeyswitchKey);
    // Retype the ciphertext operand in place, inside a root update.
    // NOTE(review): `setType` changes the type of the operand Value itself,
    // i.e. of whatever op or block argument produces it, not just this use.
    rewriter.startRootUpdate(newOp);
    newOp.getCiphertext().setType(newInputTy);
    rewriter.finalizeRootUpdate(newOp);
    return mlir::success();
  };

private:
  conversion::KeyConverter &keyConverter;
  conversion::TypeConverter &typeConverter;
};

/// Rewrites a `TFHE::BootstrapGLWEOp` so that its bootstrap key attribute is
/// normalized (via `KeyConverter::convertBootstrapKey`). Its ciphertext
/// operand and result types are converted; the lookup table operand is
/// forwarded unchanged.
///
/// The converters are held by reference and must outlive the pattern.
struct BootstrapGLWEOpPattern
    : public mlir::OpRewritePattern {
  BootstrapGLWEOpPattern(mlir::MLIRContext *context,
                         conversion::TypeConverter &typeConverter,
                         conversion::KeyConverter &keyConverter,
                         mlir::PatternBenefit benefit =
                             mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
      : mlir::OpRewritePattern(context, benefit),
        keyConverter(keyConverter), typeConverter(typeConverter) {}

  mlir::LogicalResult
  matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Converted operand/result types and the normalized key attribute.
    auto newInputTy = typeConverter.convertType(bsOp.getCiphertext().getType())
                          .cast();
    auto newOutputTy = typeConverter.convertType(bsOp.getResult().getType());
    auto newBootstrapKey = keyConverter.convertBootstrapKey(bsOp.getKeyAttr());
    // Rebuild the op with the new result type and key; ciphertext and lookup
    // table operands are the original values.
    auto newOp = rewriter.replaceOpWithNewOp(
        bsOp, newOutputTy, bsOp.getCiphertext(), bsOp.getLookupTable(),
        newBootstrapKey);
    // Retype the ciphertext operand in place, inside a root update.
    // NOTE(review): `setType` changes the type of the operand Value itself,
    // i.e. of whatever produces it, not just this use.
    rewriter.startRootUpdate(newOp);
    newOp.getCiphertext().setType(newInputTy.cast());
    rewriter.finalizeRootUpdate(newOp);
    return mlir::success();
  };

private:
  conversion::KeyConverter &keyConverter;
  conversion::TypeConverter &typeConverter;
};

/// Rewrites a `TFHE::WopPBSGLWEOp` so that its keyswitch, bootstrap and
/// packing-keyswitch key attributes are all normalized via `KeyConverter`.
/// Its ciphertexts operand and result types are converted. The lookup table
/// and the CRT decomposition / CBS levels / CBS base-log attributes are
/// forwarded unchanged.
///
/// The converters are held by reference and must outlive the pattern.
struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern {
  WopPBSGLWEOpPattern(mlir::MLIRContext *context,
                      conversion::TypeConverter &typeConverter,
                      conversion::KeyConverter &keyConverter,
                      mlir::PatternBenefit benefit =
                          mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
      : mlir::OpRewritePattern(context, benefit),
        keyConverter(keyConverter), typeConverter(typeConverter) {}

  mlir::LogicalResult
  matchAndRewrite(TFHE::WopPBSGLWEOp wopPBSOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Converted operand/result types.
    auto newInputTy =
        typeConverter.convertType(wopPBSOp.getCiphertexts().getType())
            .cast();
    auto newOutputType = typeConverter.convertType(wopPBSOp.getType());
    // Normalized versions of the three key attributes.
    auto newKeyswitchKey =
        keyConverter.convertKeyswitchKey(wopPBSOp.getKskAttr());
    auto newBootstrapKey =
        keyConverter.convertBootstrapKey(wopPBSOp.getBskAttr());
    auto newPackingKeyswitchKey =
        keyConverter.convertPackingKeyswitchKey(wopPBSOp.getPkskAttr());
    // Rebuild the op with the new result type and keys; the other operands
    // and attributes are taken from the original op.
    auto newOp = rewriter.replaceOpWithNewOp(
        wopPBSOp, newOutputType, wopPBSOp.getCiphertexts(),
        wopPBSOp.getLookupTable(), newKeyswitchKey, newBootstrapKey,
        newPackingKeyswitchKey, wopPBSOp.getCrtDecompositionAttr(),
        wopPBSOp.getCbsLevelsAttr(), wopPBSOp.getCbsBaseLogAttr());
    // Retype the ciphertexts operand in place, inside a root update.
    // NOTE(review): `setType` changes the type of the operand Value itself,
    // i.e. of whatever produces it, not just this use.
    rewriter.startRootUpdate(newOp);
    newOp.getCiphertexts().setType(newInputTy);
    rewriter.finalizeRootUpdate(newOp);
    return mlir::success();
  };

private:
  conversion::KeyConverter &keyConverter;
  conversion::TypeConverter &typeConverter;
};

} // namespace patterns

namespace {
/// The key-normalization pass; `runOnOperation` is defined below.
struct TFHEKeyNormalizationPass
    : public TFHEKeyNormalizationBase {
  void runOnOperation() final;
};

/// Registers a type-conversion pattern for the op type `Op`. It also marks
/// `Op` dynamically legal once all of its result types are legal under
/// `typeConverter`.
///
/// NOTE(review): legality only inspects result types, not operand types. The
/// legality callback keeps a reference to `typeConverter`, so the converter
/// must outlive `target`.
template
void populateWithTFHEOpTypeConversionPattern(
    mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
    mlir::TypeConverter &typeConverter) {
  patterns.add>(
      patterns.getContext(), typeConverter);

  target.addDynamicallyLegalOp(
      [&](Op op) { return typeConverter.isLegal(op->getResultTypes()); });
}

/// Populate `patterns` and `target` with type-conversion patterns and
/// result-type legality rules for the TFHE operations listed below, so that
/// their types are rewritten by `typeConverter`.
void populateWithTFHEOpTypeConversionPatterns(
    mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
    mlir::TypeConverter &typeConverter) {
  populateWithTFHEOpTypeConversionPattern(
      patterns, target, typeConverter);
  populateWithTFHEOpTypeConversionPattern<
      mlir::concretelang::TFHE::ZeroTensorGLWEOp>(patterns, target,
                                                  typeConverter);
  populateWithTFHEOpTypeConversionPattern<
      mlir::concretelang::TFHE::AddGLWEIntOp>(patterns, target, typeConverter);
  populateWithTFHEOpTypeConversionPattern(
      patterns, target, typeConverter);
  populateWithTFHEOpTypeConversionPattern<
      mlir::concretelang::TFHE::SubGLWEIntOp>(patterns, target, typeConverter);
  populateWithTFHEOpTypeConversionPattern(
      patterns, target, typeConverter);
  populateWithTFHEOpTypeConversionPattern<
      mlir::concretelang::TFHE::MulGLWEIntOp>(patterns, target, typeConverter);
}
} // namespace

/// Extracts the circuit key table from the pass's operation, then runs a
/// single partial dialect conversion. That conversion:
/// - normalizes the key attributes of keyswitch / bootstrap / WoP-PBS ops;
/// - converts types on functions, function constants, the TFHE ops handled
///   by `populateWithTFHEOpTypeConversionPatterns`, and the bufferization /
///   tensor / tracing / RT ops registered below.
/// Signals pass failure if the conversion fails.
void TFHEKeyNormalizationPass::runOnOperation() {
  auto op = this->getOperation();

  // `extractCircuitKeys` presumably gathers every key used under `op`; all
  // indices assigned by the converters come from this table.
  // The converters must stay alive until the conversion finishes, because
  // the patterns and legality callbacks below hold references to them.
  auto circuitKeys = TFHE::extractCircuitKeys(op);
  auto keyConverter = conversion::KeyConverter(circuitKeys);
  auto typeConverter = conversion::TypeConverter(keyConverter);

  // Single conversion: normalize keys and convert types.
  {
    mlir::ConversionTarget target(getContext());
    mlir::RewritePatternSet patterns(&getContext());

    // Functions are legal once their signature and the block-argument types
    // of their body are legal.
    target.addDynamicallyLegalOp(
        [&](mlir::func::FuncOp funcOp) {
          return typeConverter.isSignatureLegal(funcOp.getFunctionType()) &&
                 typeConverter.isLegal(&funcOp.getBody());
        });
    // `func.constant` legality and rewriting are delegated to
    // `FunctionConstantOpConversion`.
    target.addDynamicallyLegalOp(
        [&](mlir::func::ConstantOp op) {
          return FunctionConstantOpConversion<
              conversion::TypeConverter>::isLegal(op, typeConverter);
        });
    patterns.add>(
        &getContext(), typeConverter);
    mlir::populateFunctionOpInterfaceTypeConversionPattern(
        patterns, typeConverter);

    // Keyswitch: legal once both keys are normalized and the key attribute
    // has an index. An index of -1 appears to mean "not assigned yet" --
    // confirm against the attribute definition.
    target.addLegalOp();
    patterns.add(&getContext(), typeConverter,
                 keyConverter);
    target.addDynamicallyLegalOp(
        [&](TFHE::KeySwitchGLWEOp op) {
          return op.getKeyAttr().getInputKey().isNormalized() &&
                 op.getKeyAttr().getOutputKey().isNormalized() &&
                 op.getKeyAttr().getIndex() != -1;
        });

    // Bootstrap: same legality rule as keyswitch, on the bootstrap key.
    patterns.add(&getContext(), typeConverter,
                 keyConverter);
    target.addDynamicallyLegalOp(
        [&](TFHE::BootstrapGLWEOp op) {
          return op.getKeyAttr().getInputKey().isNormalized() &&
                 op.getKeyAttr().getOutputKey().isNormalized() &&
                 op.getKeyAttr().getIndex() != -1;
        });

    // WoP-PBS: legal once all three key attributes (ksk, bsk, pksk) are
    // normalized and indexed.
    patterns.add(&getContext(), typeConverter,
                 keyConverter);
    target.addDynamicallyLegalOp(
        [&](TFHE::WopPBSGLWEOp op) {
          return op.getKskAttr().getInputKey().isNormalized() &&
                 op.getKskAttr().getOutputKey().isNormalized() &&
                 op.getKskAttr().getIndex() != -1 &&
                 op.getBskAttr().getInputKey().isNormalized() &&
                 op.getBskAttr().getOutputKey().isNormalized() &&
                 op.getBskAttr().getIndex() != -1 &&
                 op.getPkskAttr().getInputKey().isNormalized() &&
                 op.getPkskAttr().getOutputKey().isNormalized() &&
                 op.getPkskAttr().getIndex() != -1;
        });

    // Type conversion for the remaining TFHE ops and for the surrounding
    // bufferization, tensor, tracing and RT ops whose types may change.
    populateWithTFHEOpTypeConversionPatterns(patterns, target, typeConverter);

    patterns.add>(&getContext(), typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::bufferization::AllocTensorOp>(target, typeConverter);
    patterns.add>(
        &getContext(), typeConverter);
    patterns.add>(
        &getContext(), typeConverter);
    patterns.add>(
        &getContext(), typeConverter);
    patterns.add>(
        &getContext(), typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp(
        target, typeConverter);
    patterns.add>(
        &getContext(), typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp(
        target, typeConverter);
    mlir::concretelang::populateWithTensorTypeConverterPatterns(
        patterns, target, typeConverter);

    patterns.add<
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::Tracing::TraceCiphertextOp>,
        mlir::concretelang::GenericTypeConverterPattern,
        mlir::concretelang::GenericTypeConverterPattern>(
        &getContext(), typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::Tracing::TraceCiphertextOp>(target, typeConverter);
    mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
                                                            typeConverter);

    // Partial conversion: ops with no registered legality rule are allowed
    // to remain as they are.
    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }
}

namespace mlir {
namespace concretelang {
/// Factory for the TFHE key-normalization pass.
std::unique_ptr>
createTFHEKeyNormalizationPass() {
  return std::make_unique();
}
} // namespace concretelang
} // namespace mlir