feat(compiler): Add support for scf.forall and associated ops in FHE to TFHE passes

This commit is contained in:
Andi Drebes
2023-04-05 17:06:58 +02:00
parent 2cd06580ee
commit d10c1ca576
2 changed files with 48 additions and 23 deletions

View File

@@ -7,6 +7,7 @@
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Operation.h>
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
@@ -882,35 +883,34 @@ struct ExtractSliceOpPattern
};
};
struct InsertSliceOpPattern : public CrtOpPattern<mlir::tensor::InsertSliceOp> {
template <typename OpTy>
struct InsertSliceOpPattern : public CrtOpPattern<OpTy> {
InsertSliceOpPattern(mlir::MLIRContext *context,
mlir::concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: CrtOpPattern<mlir::tensor::InsertSliceOp>(context, params, benefit) {}
: CrtOpPattern<OpTy>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::InsertSliceOp op,
mlir::tensor::InsertSliceOp::Adaptor adaptor,
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
::mlir::ConversionPatternRewriter &rewriter) const override {
// add 0 to offsets
mlir::SmallVector<mlir::OpFoldResult> offsets = getMixedValues(
adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter);
offsets.push_back(rewriter.getI64IntegerAttr(0));
mlir::TypeConverter *converter = this->getTypeConverter();
// add modulus of the CRT decomposition to sizes
mlir::SmallVector<mlir::OpFoldResult> sizes =
getMixedValues(adaptor.getStaticSizes(), adaptor.getSizes(), rewriter);
sizes.push_back(rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
mlir::SmallVector<int64_t> newStaticOffsets{op.static_offsets()};
mlir::SmallVector<int64_t> newStaticSizes{op.static_sizes()};
mlir::SmallVector<int64_t> newStaticStrides{op.static_strides()};
newStaticOffsets.push_back(0);
newStaticSizes.push_back(this->loweringParameters.nMods);
newStaticStrides.push_back(1);
// add 1 to the strides
mlir::SmallVector<mlir::OpFoldResult> strides = getMixedValues(
adaptor.getStaticStrides(), adaptor.getStrides(), rewriter);
strides.push_back(rewriter.getI64IntegerAttr(1));
mlir::RankedTensorType newType =
converter->convertType(op.getResult().getType())
.template cast<mlir::RankedTensorType>();
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
op, newType, adaptor.getSource(), adaptor.getDest(),
adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(),
rewriter.getDenseI64ArrayAttr(newStaticOffsets),
rewriter.getDenseI64ArrayAttr(newStaticSizes),
rewriter.getDenseI64ArrayAttr(newStaticStrides));
// replace insert slice-like operation with the new one
rewriter.replaceOpWithNewOp<OpTy>(
op, adaptor.getSource(), adaptor.getDest(), offsets, sizes, strides);
return mlir::success();
};
@@ -980,6 +980,7 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
op, converter);
});
target.addLegalOp<mlir::func::CallOp>();
target.addLegalOp<mlir::scf::InParallelOp>();
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::bufferization::AllocTensorOp>(target, converter);
@@ -1001,6 +1002,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
mlir::concretelang::Optimizer::PartitionFrontierOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::EmptyOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::tensor::ParallelInsertSliceOp>(target, converter);
//---------------------------------------------------------- Adding patterns
mlir::RewritePatternSet patterns(&getContext());
@@ -1047,6 +1050,9 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
mlir::scf::YieldOp>>(patterns.getContext(), converter);
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::ForOp>>(&getContext(), converter);
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::ForallOp>>(&getContext(), converter);
patterns.add<lowering::TensorExtractOpPattern>(&getContext(),
loweringParameters);
patterns.add<lowering::TensorInsertOpPattern>(&getContext(),
@@ -1059,8 +1065,10 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
loweringParameters);
patterns.add<lowering::ExtractSliceOpPattern>(patterns.getContext(),
loweringParameters);
patterns.add<lowering::InsertSliceOpPattern>(patterns.getContext(),
loweringParameters);
patterns.add<
lowering::InsertSliceOpPattern<mlir::tensor::InsertSliceOp>,
lowering::InsertSliceOpPattern<mlir::tensor::ParallelInsertSliceOp>>(
patterns.getContext(), loweringParameters);
patterns.add<lowering::TraceCiphertextOpPattern>(patterns.getContext(),
loweringParameters);
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
@@ -1090,6 +1098,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
converter);
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::ForallOp>>(&getContext(), converter);
//--------------------------------------------------------- Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))

View File

@@ -793,6 +793,7 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
op, converter);
});
target.addLegalOp<mlir::func::CallOp>();
target.addLegalOp<mlir::scf::InParallelOp>();
//---------------------------------------------------------- Adding patterns
mlir::RewritePatternSet patterns(&getContext());
@@ -854,7 +855,14 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::GenerateOp, true>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::ForOp>>(&getContext(), converter);
mlir::scf::ForOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::ForallOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::EmptyOp>,
mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::ParallelInsertSliceOp, true>>(&getContext(),
converter);
mlir::concretelang::populateWithTensorTypeConverterPatterns(
patterns, target, converter);
@@ -870,6 +878,13 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForallOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::EmptyOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::tensor::ParallelInsertSliceOp>(target, converter);
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
&getContext(), converter);