mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Add support for scf.forall and associated ops in FHE to TFHE passes
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
@@ -882,35 +883,34 @@ struct ExtractSliceOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
struct InsertSliceOpPattern : public CrtOpPattern<mlir::tensor::InsertSliceOp> {
|
||||
template <typename OpTy>
|
||||
struct InsertSliceOpPattern : public CrtOpPattern<OpTy> {
|
||||
InsertSliceOpPattern(mlir::MLIRContext *context,
|
||||
mlir::concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<mlir::tensor::InsertSliceOp>(context, params, benefit) {}
|
||||
: CrtOpPattern<OpTy>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::InsertSliceOp op,
|
||||
mlir::tensor::InsertSliceOp::Adaptor adaptor,
|
||||
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
|
||||
::mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
// add 0 to offsets
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets = getMixedValues(
|
||||
adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter);
|
||||
offsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
// add modulus of the CRT decomposition to sizes
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes =
|
||||
getMixedValues(adaptor.getStaticSizes(), adaptor.getSizes(), rewriter);
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
|
||||
|
||||
mlir::SmallVector<int64_t> newStaticOffsets{op.static_offsets()};
|
||||
mlir::SmallVector<int64_t> newStaticSizes{op.static_sizes()};
|
||||
mlir::SmallVector<int64_t> newStaticStrides{op.static_strides()};
|
||||
newStaticOffsets.push_back(0);
|
||||
newStaticSizes.push_back(this->loweringParameters.nMods);
|
||||
newStaticStrides.push_back(1);
|
||||
// add 1 to the strides
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides = getMixedValues(
|
||||
adaptor.getStaticStrides(), adaptor.getStrides(), rewriter);
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
mlir::RankedTensorType newType =
|
||||
converter->convertType(op.getResult().getType())
|
||||
.template cast<mlir::RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
|
||||
op, newType, adaptor.getSource(), adaptor.getDest(),
|
||||
adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticOffsets),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticSizes),
|
||||
rewriter.getDenseI64ArrayAttr(newStaticStrides));
|
||||
// replace insert slice-like operation with the new one
|
||||
rewriter.replaceOpWithNewOp<OpTy>(
|
||||
op, adaptor.getSource(), adaptor.getDest(), offsets, sizes, strides);
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -980,6 +980,7 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
op, converter);
|
||||
});
|
||||
target.addLegalOp<mlir::func::CallOp>();
|
||||
target.addLegalOp<mlir::scf::InParallelOp>();
|
||||
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::bufferization::AllocTensorOp>(target, converter);
|
||||
@@ -1001,6 +1002,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
mlir::concretelang::Optimizer::PartitionFrontierOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::EmptyOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::tensor::ParallelInsertSliceOp>(target, converter);
|
||||
|
||||
//---------------------------------------------------------- Adding patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
@@ -1047,6 +1050,9 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
mlir::scf::YieldOp>>(patterns.getContext(), converter);
|
||||
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::ForOp>>(&getContext(), converter);
|
||||
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::ForallOp>>(&getContext(), converter);
|
||||
|
||||
patterns.add<lowering::TensorExtractOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::TensorInsertOpPattern>(&getContext(),
|
||||
@@ -1059,8 +1065,10 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
loweringParameters);
|
||||
patterns.add<lowering::ExtractSliceOpPattern>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::InsertSliceOpPattern>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<
|
||||
lowering::InsertSliceOpPattern<mlir::tensor::InsertSliceOp>,
|
||||
lowering::InsertSliceOpPattern<mlir::tensor::ParallelInsertSliceOp>>(
|
||||
patterns.getContext(), loweringParameters);
|
||||
patterns.add<lowering::TraceCiphertextOpPattern>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
@@ -1090,6 +1098,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
|
||||
mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::ForallOp>>(&getContext(), converter);
|
||||
|
||||
//--------------------------------------------------------- Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
|
||||
@@ -793,6 +793,7 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
op, converter);
|
||||
});
|
||||
target.addLegalOp<mlir::func::CallOp>();
|
||||
target.addLegalOp<mlir::scf::InParallelOp>();
|
||||
|
||||
//---------------------------------------------------------- Adding patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
@@ -854,7 +855,14 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::GenerateOp, true>,
|
||||
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::ForOp>>(&getContext(), converter);
|
||||
mlir::scf::ForOp>,
|
||||
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::ForallOp>,
|
||||
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::EmptyOp>,
|
||||
mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::ParallelInsertSliceOp, true>>(&getContext(),
|
||||
converter);
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(
|
||||
patterns, target, converter);
|
||||
|
||||
@@ -870,6 +878,13 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForallOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::EmptyOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::tensor::ParallelInsertSliceOp>(target, converter);
|
||||
|
||||
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user