diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index b0781e5e7..2e607ead3 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -10,6 +10,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Utils/Dialects/SCF.h" #include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" @@ -320,16 +321,15 @@ struct TracePlaintextOpPattern }; template -struct ZeroOpPattern : public mlir::OpRewritePattern { - ZeroOpPattern(mlir::MLIRContext *context) - : mlir::OpRewritePattern( - context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} +struct ZeroOpPattern : public mlir::OpConversionPattern { + ZeroOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &converter) + : mlir::OpConversionPattern( + converter, context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} ::mlir::LogicalResult - matchAndRewrite(ZeroOp zeroOp, - mlir::PatternRewriter &rewriter) const override { - TFHEToConcreteTypeConverter converter; - auto newResultTy = converter.convertType(zeroOp.getType()); + matchAndRewrite(ZeroOp zeroOp, typename ZeroOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto newResultTy = this->getTypeConverter()->convertType(zeroOp.getType()); auto generateBody = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, @@ -451,54 +451,43 @@ struct ExtractOpPattern }; }; -/// Pattern that rewrites the InsertSlice operation, taking into account the -/// additional LWE dimension introduced during type conversion -struct InsertSliceOpPattern - : public mlir::OpConversionPattern { +/// Pattern that rewrites the InsertSlice-like operation, taking into +/// account the additional LWE dimension introduced during type +/// conversion +template +struct InsertSliceOpPattern : public mlir::OpConversionPattern { InsertSliceOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter) - : ::mlir::OpConversionPattern( + : ::mlir::OpConversionPattern( typeConverter, context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} ::mlir::LogicalResult - matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp, - mlir::tensor::InsertSliceOp::Adaptor adaptor, + matchAndRewrite(OpTy insertSliceOp, typename OpTy::Adaptor adaptor, ::mlir::ConversionPatternRewriter &rewriter) const override { - // is not a tensor of GLWEs that need to be extended with the LWE dimension - if (this->getTypeConverter()->isLegal(insertSliceOp.getType())) { - return mlir::failure(); - } + mlir::RankedTensorType newDestTy = ((mlir::Type)adaptor.getDest().getType()) + .cast(); - auto newResultTy = this->getTypeConverter() - ->convertType(insertSliceOp.getResult().getType()) - .cast(); + // add 0 to offsets + mlir::SmallVector offsets = getMixedValues( + adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter); + offsets.push_back(rewriter.getI64IntegerAttr(0)); - // add 0 to static offsets - mlir::SmallVector staticOffsets; - staticOffsets.append(adaptor.getStaticOffsets().begin(), - adaptor.getStaticOffsets().end()); - staticOffsets.push_back(0); - - // add lweDimension+1 to static_sizes - mlir::SmallVector staticSizes; - staticSizes.append(adaptor.getStaticSizes().begin(), - adaptor.getStaticSizes().end()); - staticSizes.push_back(newResultTy.getDimSize(newResultTy.getRank() - 1)); + // add lweDimension+1 to sizes + mlir::SmallVector sizes = + getMixedValues(adaptor.getStaticSizes(), adaptor.getSizes(), rewriter); + sizes.push_back(rewriter.getI64IntegerAttr( + newDestTy.getDimSize(newDestTy.getRank() - 1))); // add 1 to the strides - mlir::SmallVector staticStrides; - staticStrides.append(adaptor.getStaticStrides().begin(), - adaptor.getStaticStrides().end()); - staticStrides.push_back(1); + mlir::SmallVector strides = getMixedValues( + adaptor.getStaticStrides(), adaptor.getStrides(), rewriter); + strides.push_back(rewriter.getI64IntegerAttr(1)); // replace tensor.insert_slice with the new one - rewriter.replaceOpWithNewOp( - insertSliceOp, newResultTy, adaptor.getSource(), adaptor.getDest(), - adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(), - rewriter.getDenseI64ArrayAttr(staticOffsets), - rewriter.getDenseI64ArrayAttr(staticSizes), - rewriter.getDenseI64ArrayAttr(staticStrides)); + rewriter.replaceOpWithNewOp(insertSliceOp, adaptor.getSource(), + adaptor.getDest(), offsets, sizes, + strides); return ::mlir::success(); }; @@ -798,17 +787,18 @@ void TFHEToConcretePass::runOnOperation() { mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>>( &getContext(), converter); // pattern of remaining TFHE ops + patterns.insert, - ZeroOpPattern>( - &getContext()); - patterns.insert, + SubIntGLWEOpPattern, BootstrapGLWEOpPattern, BatchedBootstrapGLWEOpPattern, KeySwitchGLWEOpPattern, BatchedKeySwitchGLWEOpPattern, WopPBSGLWEOpPattern>( &getContext(), converter); // Add patterns to rewrite tensor operators that works on tensors of TFHE GLWE // types - patterns.insert, InsertOpPattern, FromElementsOpPattern>(&getContext(), converter); // Add patterns to rewrite some of tensor ops that were introduced by the @@ -833,13 +823,11 @@ void TFHEToConcretePass::runOnOperation() { }); // rewrite scf for loops if working on illegal types - patterns.add>( - &getContext(), converter); - target.addDynamicallyLegalOp([&](mlir::scf::ForOp forOp) { - return converter.isLegal(forOp.getInitArgs().getTypes()) && - converter.isLegal(forOp.getResults().getTypes()); - }); + patterns.add>(&getContext(), converter); + + mlir::concretelang::addDynamicallyLegalTypeOp(target, + converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter);