// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RTOpConverter.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Support/Constants.h"

namespace TFHE = mlir::concretelang::TFHE;
namespace Tracing = mlir::concretelang::Tracing;

using TFHE::GLWECipherTextType;

/// Converts ciphertexts to plaintext integer types
class SimulateTFHETypeConverter : public mlir::TypeConverter {

public:
  SimulateTFHETypeConverter() {
    addConversion([](mlir::Type type) { return type; });
    addConversion([&](GLWECipherTextType type) {
      return mlir::IntegerType::get(type.getContext(), 64);
    });
    addConversion([&](mlir::RankedTensorType type) {
      auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
      if (glwe == nullptr) {
        return (mlir::Type)(type);
      }
      return (mlir::Type)mlir::RankedTensorType::get(
          type.getShape(), mlir::IntegerType::get(type.getContext(), 64));
    });
    addConversion([&](mlir::concretelang::RT::FutureType type) {
      return mlir::concretelang::RT::FutureType::get(this->convertType(
          type.dyn_cast<mlir::concretelang::RT::FutureType>()
              .getElementType()));
    });
    addConversion([&](mlir::concretelang::RT::PointerType type) {
      return mlir::concretelang::RT::PointerType::get(this->convertType(
          type.dyn_cast<mlir::concretelang::RT::PointerType>()
              .getElementType()));
    });
  }
};

namespace {

mlir::RankedTensorType
toDynamicTensorType(mlir::TensorType staticSizedTensor) {
  std::vector<int64_t> dynSizedShape(staticSizedTensor.getShape().size(),
                                     mlir::ShapedType::kDynamic);
  return mlir::RankedTensorType::get(dynSizedShape,
                                     staticSizedTensor.getElementType());
}

struct NegOpPattern : public mlir::OpConversionPattern<TFHE::NegGLWEOp> {

  NegOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::NegGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::NegGLWEOp negOp, TFHE::NegGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    const std::string funcName = "sim_neg_lwe_u64";

    if (insertForwardDeclaration(
            negOp, rewriter, funcName,
            rewriter.getFunctionType({rewriter.getIntegerType(64)},
                                     {rewriter.getIntegerType(64)}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        negOp, funcName, mlir::TypeRange{rewriter.getIntegerType(64)},
        mlir::ValueRange({adaptor.getA()}));

    return mlir::success();
  }
};

struct SubIntGLWEOpPattern : public mlir::OpRewritePattern<TFHE::SubGLWEIntOp> {

  SubIntGLWEOpPattern(mlir::MLIRContext *context)
      : mlir::OpRewritePattern<TFHE::SubGLWEIntOp>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::SubGLWEIntOp subOp,
                  mlir::PatternRewriter &rewriter) const override {

    mlir::Value negated = rewriter.create<TFHE::NegGLWEOp>(
        subOp.getLoc(), subOp.getB().getType(), subOp.getB());

    rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(subOp, subOp.getType(),
                                                    negated, subOp.getA());

    return mlir::success();
  }
};
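// Illustrative lowering for the negation pattern above (a sketch only: the
// exact TFHE assembly syntax and the GLWE type parameters are elided, so the
// snippet is not round-trippable IR). In simulation, a ciphertext negation
// such as
//
//   %r = "TFHE.neg_glwe"(%ct) : (!TFHE.glwe<...>) -> !TFHE.glwe<...>
//
// becomes a call into the simulation runtime operating on plain 64-bit
// integers:
//
//   %r = func.call @sim_neg_lwe_u64(%ct) : (i64) -> i64
//
// SubGLWEIntOp is not lowered directly: it is first rewritten into a
// NegGLWEOp followed by an addition, which the patterns registered in
// runOnOperation() then convert.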
struct EncodeExpandLutForBootstrapOpPattern
    : public mlir::OpConversionPattern<TFHE::EncodeExpandLutForBootstrapOp> {

  EncodeExpandLutForBootstrapOpPattern(mlir::MLIRContext *context,
                                       mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::EncodeExpandLutForBootstrapOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::EncodeExpandLutForBootstrapOp eeOp,
                  TFHE::EncodeExpandLutForBootstrapOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    const std::string funcName = "sim_encode_expand_lut_for_boostrap";

    mlir::Value polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
        eeOp.getLoc(), eeOp.getPolySize(), 32);
    mlir::Value outputBitsCst = rewriter.create<mlir::arith::ConstantIntOp>(
        eeOp.getLoc(), eeOp.getOutputBits(), 32);
    mlir::Value isSignedCst = rewriter.create<mlir::arith::ConstantIntOp>(
        eeOp.getLoc(), eeOp.getIsSigned(), 1);

    mlir::Value outputBuffer =
        rewriter.create<mlir::bufferization::AllocTensorOp>(
            eeOp.getLoc(),
            eeOp.getResult().getType().cast<mlir::RankedTensorType>(),
            mlir::ValueRange{});

    auto dynamicResultType = toDynamicTensorType(eeOp.getResult().getType());
    auto dynamicLutType =
        toDynamicTensorType(eeOp.getInputLookupTable().getType());

    mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
        eeOp.getLoc(), dynamicResultType, outputBuffer);
    mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
        eeOp.getLoc(),
        toDynamicTensorType(eeOp.getInputLookupTable().getType()),
        adaptor.getInputLookupTable());

    // sim_encode_expand_lut_for_boostrap(uint64_t *out_allocated, uint64_t
    // *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t
    // out_stride, uint64_t *in_allocated, uint64_t *in_aligned, uint64_t
    // in_offset, uint64_t in_size, uint64_t in_stride, uint32_t poly_size,
    // uint32_t output_bits, bool is_signed)
    if (insertForwardDeclaration(
            eeOp, rewriter, funcName,
            rewriter.getFunctionType(
                {dynamicResultType, dynamicLutType, rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
                {}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.create<mlir::func::CallOp>(
        eeOp.getLoc(), funcName, mlir::TypeRange{},
        mlir::ValueRange({castedOutputBuffer, castedLUT, polySizeCst,
                          outputBitsCst, isSignedCst}));

    rewriter.replaceOp(eeOp, outputBuffer);

    return mlir::success();
  }
};

struct EncodePlaintextWithCrtOpPattern
    : public mlir::OpConversionPattern<TFHE::EncodePlaintextWithCrtOp> {

  EncodePlaintextWithCrtOpPattern(mlir::MLIRContext *context,
                                  mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::EncodePlaintextWithCrtOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::EncodePlaintextWithCrtOp epOp,
                  TFHE::EncodePlaintextWithCrtOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    const std::string funcName = "sim_encode_plaintext_with_crt";

    mlir::Value modsProductCst = rewriter.create<mlir::arith::ConstantIntOp>(
        epOp.getLoc(), epOp.getModsProd(), 64);

    mlir::Value outputBuffer =
        rewriter.create<mlir::bufferization::AllocTensorOp>(
            epOp.getLoc(),
            epOp.getResult().getType().cast<mlir::RankedTensorType>(),
            mlir::ValueRange{});

    // TODO: add mods
    if (insertForwardDeclaration(
            epOp, rewriter, funcName,
            rewriter.getFunctionType({epOp.getResult().getType(),
                                      epOp.getInput().getType() /*, mods here*/,
                                      rewriter.getI64Type()},
                                     {}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.create<mlir::func::CallOp>(
        epOp.getLoc(), funcName, mlir::TypeRange{},
        mlir::ValueRange({outputBuffer, adaptor.getInput(), modsProductCst}));

    rewriter.replaceOp(epOp, outputBuffer);

    return mlir::success();
  }
};
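// Illustrative shape of the lowered IR for the two encoding patterns above
// (a sketch; tensor sizes and SSA names are made up). The result buffer is
// allocated as a tensor, casted to a dynamically sized tensor and handed to
// the simulation runtime as an output parameter:
//
//   %buf = bufferization.alloc_tensor() : tensor<1024xi64>
//   %out = tensor.cast %buf : tensor<1024xi64> to tensor<?xi64>
//   %lut = tensor.cast %arg : tensor<16xi64> to tensor<?xi64>
//   func.call @sim_encode_expand_lut_for_boostrap(%out, %lut, %poly_size,
//       %output_bits, %is_signed)
//       : (tensor<?xi64>, tensor<?xi64>, i32, i32, i1) -> ()
//
// The original op is then replaced by %buf, so later users keep a statically
// shaped tensor.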
struct BootstrapGLWEOpPattern
    : public mlir::OpConversionPattern<TFHE::BootstrapGLWEOp> {

  BootstrapGLWEOpPattern(mlir::MLIRContext *context,
                         mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::BootstrapGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
                  TFHE::BootstrapGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    const std::string funcName = "sim_bootstrap_lwe_u64";

    TFHE::GLWECipherTextType resultType =
        bsOp.getType().cast<TFHE::GLWECipherTextType>();
    TFHE::GLWECipherTextType inputType =
        bsOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();

    auto polySize = adaptor.getKey().getPolySize();
    auto glweDimension = adaptor.getKey().getGlweDim();
    auto levels = adaptor.getKey().getLevels();
    auto baseLog = adaptor.getKey().getBaseLog();
    auto inputLweDimension =
        inputType.getKey().getNormalized().value().dimension;

    auto polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), polySize, 32);
    auto glweDimensionCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), glweDimension, 32);
    auto levelsCst = rewriter.create<mlir::arith::ConstantIntOp>(bsOp.getLoc(),
                                                                 levels, 32);
    auto baseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(bsOp.getLoc(),
                                                                  baseLog, 32);
    auto inputLweDimensionCst = rewriter.create<mlir::arith::ConstantIntOp>(
        bsOp.getLoc(), inputLweDimension, 32);

    auto dynamicLutType = toDynamicTensorType(bsOp.getLookupTable().getType());

    mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
        bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable());

    // uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t
    // *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t
    // tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t
    // poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim)
    if (insertForwardDeclaration(
            bsOp, rewriter, funcName,
            rewriter.getFunctionType(
                {rewriter.getIntegerType(64), dynamicLutType,
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32)},
                {rewriter.getIntegerType(64)}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        bsOp, funcName, this->getTypeConverter()->convertType(resultType),
        mlir::ValueRange({adaptor.getCiphertext(), castedLUT,
                          inputLweDimensionCst, polySizeCst, levelsCst,
                          baseLogCst, glweDimensionCst}));

    return mlir::success();
  }
};
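// Illustrative call emitted by the bootstrap pattern above (a sketch; the SSA
// names and key parameters are invented for the example, the operand order is
// the one built in the ValueRange above):
//
//   %r = func.call @sim_bootstrap_lwe_u64(
//       %ct, %lut, %input_lwe_dim, %poly_size, %level, %base_log, %glwe_dim)
//       : (i64, tensor<?xi64>, i32, i32, i32, i32, i32) -> i64
//
// The keyswitch pattern below follows the same scheme with
// `sim_keyswitch_lwe_u64` and scalar parameters only.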
struct KeySwitchGLWEOpPattern
    : public mlir::OpConversionPattern<TFHE::KeySwitchGLWEOp> {

  KeySwitchGLWEOpPattern(mlir::MLIRContext *context,
                         mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::KeySwitchGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
                  TFHE::KeySwitchGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    const std::string funcName = "sim_keyswitch_lwe_u64";

    TFHE::GLWECipherTextType resultType =
        ksOp.getType().cast<TFHE::GLWECipherTextType>();
    TFHE::GLWECipherTextType inputType =
        ksOp.getCiphertext().getType().cast<TFHE::GLWECipherTextType>();

    auto levels = adaptor.getKey().getLevels();
    auto baseLog = adaptor.getKey().getBaseLog();
    auto inputDim = inputType.getKey().getNormalized().value().dimension;
    auto outputDim = resultType.getKey().getNormalized().value().dimension;

    mlir::Value levelCst = rewriter.create<mlir::arith::ConstantIntOp>(
        ksOp.getLoc(), levels, 32);
    mlir::Value baseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
        ksOp.getLoc(), baseLog, 32);
    mlir::Value inputDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
        ksOp.getLoc(), inputDim, 32);
    mlir::Value outputDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
        ksOp.getLoc(), outputDim, 32);

    // uint64_t sim_keyswitch_lwe_u64(uint64_t plaintext, uint32_t level,
    // uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim)
    if (insertForwardDeclaration(
            ksOp, rewriter, funcName,
            rewriter.getFunctionType(
                {rewriter.getIntegerType(64), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32), rewriter.getIntegerType(32),
                 rewriter.getIntegerType(32)},
                {rewriter.getIntegerType(64)}))
            .failed()) {
      return mlir::failure();
    }

    rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
        ksOp, funcName, this->getTypeConverter()->convertType(resultType),
        mlir::ValueRange({adaptor.getCiphertext(), levelCst, baseLogCst,
                          inputDimCst, outputDimCst}));

    return mlir::success();
  }
};

struct ZeroOpPattern : public mlir::OpConversionPattern<TFHE::ZeroGLWEOp> {
  ZeroOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::ZeroGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::ZeroGLWEOp zeroOp, TFHE::ZeroGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto newResultTy = this->getTypeConverter()->convertType(zeroOp.getType());

    rewriter.replaceOpWithNewOp<mlir::arith::ConstantIntOp>(zeroOp, 0,
                                                            newResultTy);
    return ::mlir::success();
  };
};

struct ZeroTensorOpPattern
    : public mlir::OpConversionPattern<TFHE::ZeroTensorGLWEOp> {
  ZeroTensorOpPattern(mlir::MLIRContext *context,
                      mlir::TypeConverter &typeConverter)
      : mlir::OpConversionPattern<TFHE::ZeroTensorGLWEOp>(
            typeConverter, context,
            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(TFHE::ZeroTensorGLWEOp zeroTensorOp,
                  TFHE::ZeroTensorGLWEOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto newResultTy =
        this->getTypeConverter()->convertType(zeroTensorOp.getType());

    rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
        zeroTensorOp,
        mlir::DenseElementsAttr::get(newResultTy.cast<mlir::ShapedType>(),
                                     {mlir::APInt::getZero(64)}),
        newResultTy);
    return ::mlir::success();
  };
};

struct SimulateTFHEPass : public SimulateTFHEBase<SimulateTFHEPass> {
  void runOnOperation() final;
};
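// Pass driver: declares TFHE ops illegal, keeps arith ops and the inserted
// runtime calls legal, registers the simulation patterns defined above
// together with the generic type-conversion patterns, and runs a partial
// dialect conversion over the module.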
void SimulateTFHEPass::runOnOperation() {
  auto op = this->getOperation();

  mlir::ConversionTarget target(getContext());
  SimulateTFHETypeConverter converter;

  target.addLegalDialect<mlir::arith::ArithDialect>();
  target.addLegalOp<mlir::func::CallOp>();
  // Make sure that no ops from `TFHE` remain after the lowering
  target.addIllegalDialect<TFHE::TFHEDialect>();

  mlir::RewritePatternSet patterns(&getContext());

  // Replace ops and convert operand and result types
  patterns.insert<mlir::concretelang::GenericOneToOneOpConversionPattern<
                      TFHE::AddGLWEIntOp, mlir::arith::AddIOp>,
                  mlir::concretelang::GenericOneToOneOpConversionPattern<
                      TFHE::AddGLWEOp, mlir::arith::AddIOp>,
                  mlir::concretelang::GenericOneToOneOpConversionPattern<
                      TFHE::MulGLWEIntOp, mlir::arith::MulIOp>>(&getContext(),
                                                                converter);
  // Convert operand and result types
  patterns.insert<mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::func::ReturnOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::scf::YieldOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::FromElementsOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::ExtractOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::ExtractSliceOp, true>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::InsertOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::InsertSliceOp, true>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::ExpandShapeOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::CollapseShapeOp>,
                  mlir::concretelang::TypeConvertingReinstantiationPattern<
                      mlir::tensor::YieldOp>>(&getContext(), converter);

  // Legalize these ops only if their operand and result types are legal
  target.addDynamicallyLegalOp<
      mlir::tensor::YieldOp, mlir::scf::YieldOp, mlir::tensor::GenerateOp,
      mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp,
      mlir::tensor::InsertOp, mlir::tensor::InsertSliceOp,
      mlir::tensor::FromElementsOp, mlir::tensor::ExpandShapeOp,
      mlir::tensor::CollapseShapeOp, mlir::bufferization::AllocTensorOp>(
      [&](mlir::Operation *op) {
        return converter.isLegal(op->getResultTypes()) &&
               converter.isLegal(op->getOperandTypes());
      });

  // Make sure that no `linalg.generic` ops with illegal types remain
  target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
      [&](mlir::Operation *op) {
        return (
            converter.isLegal(op->getOperandTypes()) &&
            converter.isLegal(op->getResultTypes()) &&
            converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
      });

  // Update scf::ForOp region with converted types
  patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
                                            SimulateTFHETypeConverter>>(
      &getContext(), converter);
  target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp forOp) {
    return converter.isLegal(forOp.getInitArgs().getTypes()) &&
           converter.isLegal(forOp.getResults().getTypes());
  });

  patterns.insert<NegOpPattern, ZeroOpPattern, ZeroTensorOpPattern,
                  KeySwitchGLWEOpPattern, BootstrapGLWEOpPattern,
                  EncodeExpandLutForBootstrapOpPattern,
                  EncodePlaintextWithCrtOpPattern>(&getContext(), converter);
  patterns.insert<SubIntGLWEOpPattern>(&getContext());

  patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
                   mlir::func::ReturnOp>,
               mlir::concretelang::TypeConvertingReinstantiationPattern<
                   mlir::scf::YieldOp>,
               mlir::concretelang::TypeConvertingReinstantiationPattern<
                   mlir::bufferization::AllocTensorOp, true>>(&getContext(),
                                                              converter);

  mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
                                                          converter);

  // Make sure that functions no longer operate on ciphertexts
  target.addDynamicallyLegalOp<mlir::func::FuncOp>(
      [&](mlir::func::FuncOp funcOp) {
        return converter.isSignatureLegal(funcOp.getFunctionType()) &&
               converter.isLegal(&funcOp.getBody());
      });
  target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
      [&](mlir::func::ConstantOp op) {
        return FunctionConstantOpConversion<SimulateTFHETypeConverter>::isLegal(
            op, converter);
      });
  mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
      patterns, converter);

  mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
      target, converter);
  patterns.insert<FunctionConstantOpConversion<SimulateTFHETypeConverter>>(
      &getContext(), converter);
  patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
      mlir::func::ReturnOp>>(&getContext(), converter);

  // Apply conversion
  if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
    this->signalPassFailure();
  }
}

} // namespace

namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createSimulateTFHEPass() {
  return std::make_unique<SimulateTFHEPass>();
}
} // namespace concretelang
} // namespace mlir