From d94993ede24bbcee2c30dfd0fe3d5bcc3ad86e18 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 5 Apr 2023 16:47:34 +0200 Subject: [PATCH] fix(compiler): Replace in-place updated with conversion patterns when lowering TFHE to Concrete The current scheme with in-place updates of the types of values may result in operations recognized as legal and thus preventing them from being converted when the operations producing their operands have been converted earlier, as their types have been updated and legality is solely based on types. For example, the conversion pattern for an `tensor.insert_slice` operation working on tensors of encrypted values may not trigger if the operations producing its operands have been converted, leaving the operation with updated operand types with the extra dimension added by the type conversion from TFHE to Concrete, but with unmodified sizes, strides and offsets, not taking into account the extra dimension. This causes the verifier of the affected operation to fail and the compilation to abort. By using op conversion patterns, the original types of each operation are preserved during the actual rewrite, correctly triggering all conversion patterns based on the legality of data types. --- .../TFHEToConcrete/TFHEToConcrete.cpp | 96 ++++++++----------- 1 file changed, 42 insertions(+), 54 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index b0781e5e7..2e607ead3 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -10,6 +10,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Utils/Dialects/SCF.h" #include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" @@ -320,16 +321,15 @@ struct TracePlaintextOpPattern }; template -struct ZeroOpPattern : public mlir::OpRewritePattern { - ZeroOpPattern(mlir::MLIRContext *context) - : mlir::OpRewritePattern( - context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} +struct ZeroOpPattern : public mlir::OpConversionPattern { + ZeroOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &converter) + : mlir::OpConversionPattern( + converter, context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} ::mlir::LogicalResult - matchAndRewrite(ZeroOp zeroOp, - mlir::PatternRewriter &rewriter) const override { - TFHEToConcreteTypeConverter converter; - auto newResultTy = converter.convertType(zeroOp.getType()); + matchAndRewrite(ZeroOp zeroOp, typename ZeroOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto newResultTy = this->getTypeConverter()->convertType(zeroOp.getType()); auto generateBody = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, @@ -451,54 +451,43 @@ struct ExtractOpPattern }; }; -/// Pattern that rewrites the InsertSlice operation, taking into account the -/// additional LWE dimension introduced during type conversion -struct InsertSliceOpPattern - : public mlir::OpConversionPattern { +/// Pattern that rewrites the InsertSlice-like operation, taking into +/// account the additional LWE dimension introduced during type +/// conversion +template +struct InsertSliceOpPattern : public mlir::OpConversionPattern { InsertSliceOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter) - : ::mlir::OpConversionPattern( + : ::mlir::OpConversionPattern( typeConverter, context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} ::mlir::LogicalResult - matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp, - mlir::tensor::InsertSliceOp::Adaptor adaptor, + matchAndRewrite(OpTy insertSliceOp, typename OpTy::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { - // is not a tensor of GLWEs that need to be extended with the LWE dimension - if (this->getTypeConverter()->isLegal(insertSliceOp.getType())) { - return mlir::failure(); - } + mlir::RankedTensorType newDestTy = ((mlir::Type)adaptor.getDest().getType()) + .cast(); - auto newResultTy = this->getTypeConverter() - ->convertType(insertSliceOp.getResult().getType()) - .cast(); + // add 0 to offsets + mlir::SmallVector offsets = getMixedValues( + adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter); + offsets.push_back(rewriter.getI64IntegerAttr(0)); - // add 0 to static offsets - mlir::SmallVector staticOffsets; - staticOffsets.append(adaptor.getStaticOffsets().begin(), - adaptor.getStaticOffsets().end()); - staticOffsets.push_back(0); - - // add lweDimension+1 to static_sizes - mlir::SmallVector staticSizes; - staticSizes.append(adaptor.getStaticSizes().begin(), - adaptor.getStaticSizes().end()); - staticSizes.push_back(newResultTy.getDimSize(newResultTy.getRank() - 1)); + // add lweDimension+1 to sizes + mlir::SmallVector sizes = + getMixedValues(adaptor.getStaticSizes(), adaptor.getSizes(), rewriter); + sizes.push_back(rewriter.getI64IntegerAttr( + newDestTy.getDimSize(newDestTy.getRank() - 1))); // add 1 to the strides - mlir::SmallVector staticStrides; - staticStrides.append(adaptor.getStaticStrides().begin(), - adaptor.getStaticStrides().end()); - staticStrides.push_back(1); + mlir::SmallVector strides = getMixedValues( + adaptor.getStaticStrides(), adaptor.getStrides(), rewriter); + strides.push_back(rewriter.getI64IntegerAttr(1)); // replace tensor.insert_slice with the new one - rewriter.replaceOpWithNewOp( - insertSliceOp, newResultTy, adaptor.getSource(), adaptor.getDest(), - adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(), - rewriter.getDenseI64ArrayAttr(staticOffsets), - rewriter.getDenseI64ArrayAttr(staticSizes), - rewriter.getDenseI64ArrayAttr(staticStrides)); + rewriter.replaceOpWithNewOp(insertSliceOp, adaptor.getSource(), + adaptor.getDest(), offsets, sizes, + strides); return ::mlir::success(); }; @@ -798,17 +787,18 @@ void TFHEToConcretePass::runOnOperation() { mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>>( &getContext(), converter); // pattern of remaining TFHE ops + patterns.insert, - ZeroOpPattern>( - &getContext()); - patterns.insert, + SubIntGLWEOpPattern, BootstrapGLWEOpPattern, BatchedBootstrapGLWEOpPattern, KeySwitchGLWEOpPattern, BatchedKeySwitchGLWEOpPattern, WopPBSGLWEOpPattern>( &getContext(), converter); // Add patterns to rewrite tensor operators that works on tensors of TFHE GLWE // types - patterns.insert, InsertOpPattern, FromElementsOpPattern>(&getContext(), converter); // Add patterns to rewrite some of tensor ops that were introduced by the @@ -833,13 +823,11 @@ void TFHEToConcretePass::runOnOperation() { }); // rewrite scf for loops if working on illegal types - patterns.add>( - &getContext(), converter); - target.addDynamicallyLegalOp([&](mlir::scf::ForOp forOp) { - return converter.isLegal(forOp.getInitArgs().getTypes()) && - converter.isLegal(forOp.getResults().getTypes()); - }); + patterns.add>(&getContext(), converter); + + mlir::concretelang::addDynamicallyLegalTypeOp(target, + converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter);