fix(compiler): Replace in-place updated with conversion patterns when lowering TFHE to Concrete

The current scheme with in-place updates of the types of values may
result in operations recognized as legal and thus preventing them from
being converted when the operations producing their operands have been
converted earlier, as their types have been updated and legality is
solely based on types.

For example, the conversion pattern for an `tensor.insert_slice`
operation working on tensors of encrypted values may not trigger if
the operations producing its operands have been converted, leaving the
operation with updated operand types with the extra dimension added by
the type conversion from TFHE to Concrete, but with unmodified sizes,
strides and offsets, not taking into account the extra dimension. This
causes the verifier of the affected operation to fail and the
compilation to abort.

By using op conversion patterns, the original types of each operation
are preserved during the actual rewrite, correctly triggering all
conversion patterns based on the legality of data types.
This commit is contained in:
Andi Drebes
2023-04-05 16:47:34 +02:00
committed by Antoniu Pop
parent 3ec17a74b6
commit d94993ede2

View File

@@ -10,6 +10,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
@@ -320,16 +321,15 @@ struct TracePlaintextOpPattern
};
template <typename ZeroOp>
struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
ZeroOpPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<ZeroOp>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
struct ZeroOpPattern : public mlir::OpConversionPattern<ZeroOp> {
ZeroOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &converter)
: mlir::OpConversionPattern<ZeroOp>(
converter, context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
::mlir::LogicalResult
matchAndRewrite(ZeroOp zeroOp,
mlir::PatternRewriter &rewriter) const override {
TFHEToConcreteTypeConverter converter;
auto newResultTy = converter.convertType(zeroOp.getType());
matchAndRewrite(ZeroOp zeroOp, typename ZeroOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto newResultTy = this->getTypeConverter()->convertType(zeroOp.getType());
auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
@@ -451,54 +451,43 @@ struct ExtractOpPattern
};
};
/// Pattern that rewrites the InsertSlice operation, taking into account the
/// additional LWE dimension introduced during type conversion
struct InsertSliceOpPattern
: public mlir::OpConversionPattern<mlir::tensor::InsertSliceOp> {
/// Pattern that rewrites the InsertSlice-like operation, taking into
/// account the additional LWE dimension introduced during type
/// conversion
template <typename OpTy>
struct InsertSliceOpPattern : public mlir::OpConversionPattern<OpTy> {
InsertSliceOpPattern(mlir::MLIRContext *context,
mlir::TypeConverter &typeConverter)
: ::mlir::OpConversionPattern<mlir::tensor::InsertSliceOp>(
: ::mlir::OpConversionPattern<OpTy>(
typeConverter, context,
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp,
mlir::tensor::InsertSliceOp::Adaptor adaptor,
matchAndRewrite(OpTy insertSliceOp, typename OpTy::Adaptor adaptor,
::mlir::ConversionPatternRewriter &rewriter) const override {
// is not a tensor of GLWEs that need to be extended with the LWE dimension
if (this->getTypeConverter()->isLegal(insertSliceOp.getType())) {
return mlir::failure();
}
mlir::RankedTensorType newDestTy = ((mlir::Type)adaptor.getDest().getType())
.cast<mlir::RankedTensorType>();
auto newResultTy = this->getTypeConverter()
->convertType(insertSliceOp.getResult().getType())
.cast<mlir::RankedTensorType>();
// add 0 to offsets
mlir::SmallVector<mlir::OpFoldResult> offsets = getMixedValues(
adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter);
offsets.push_back(rewriter.getI64IntegerAttr(0));
// add 0 to static offsets
mlir::SmallVector<int64_t> staticOffsets;
staticOffsets.append(adaptor.getStaticOffsets().begin(),
adaptor.getStaticOffsets().end());
staticOffsets.push_back(0);
// add lweDimension+1 to static_sizes
mlir::SmallVector<int64_t> staticSizes;
staticSizes.append(adaptor.getStaticSizes().begin(),
adaptor.getStaticSizes().end());
staticSizes.push_back(newResultTy.getDimSize(newResultTy.getRank() - 1));
// add lweDimension+1 to sizes
mlir::SmallVector<mlir::OpFoldResult> sizes =
getMixedValues(adaptor.getStaticSizes(), adaptor.getSizes(), rewriter);
sizes.push_back(rewriter.getI64IntegerAttr(
newDestTy.getDimSize(newDestTy.getRank() - 1)));
// add 1 to the strides
mlir::SmallVector<int64_t> staticStrides;
staticStrides.append(adaptor.getStaticStrides().begin(),
adaptor.getStaticStrides().end());
staticStrides.push_back(1);
mlir::SmallVector<mlir::OpFoldResult> strides = getMixedValues(
adaptor.getStaticStrides(), adaptor.getStrides(), rewriter);
strides.push_back(rewriter.getI64IntegerAttr(1));
// replace tensor.insert_slice with the new one
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
insertSliceOp, newResultTy, adaptor.getSource(), adaptor.getDest(),
adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(),
rewriter.getDenseI64ArrayAttr(staticOffsets),
rewriter.getDenseI64ArrayAttr(staticSizes),
rewriter.getDenseI64ArrayAttr(staticStrides));
rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, adaptor.getSource(),
adaptor.getDest(), offsets, sizes,
strides);
return ::mlir::success();
};
@@ -798,17 +787,18 @@ void TFHEToConcretePass::runOnOperation() {
mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>>(
&getContext(), converter);
// pattern of remaining TFHE ops
patterns.insert<ZeroOpPattern<mlir::concretelang::TFHE::ZeroGLWEOp>,
ZeroOpPattern<mlir::concretelang::TFHE::ZeroTensorGLWEOp>>(
&getContext());
patterns.insert<SubIntGLWEOpPattern, BootstrapGLWEOpPattern,
ZeroOpPattern<mlir::concretelang::TFHE::ZeroTensorGLWEOp>,
SubIntGLWEOpPattern, BootstrapGLWEOpPattern,
BatchedBootstrapGLWEOpPattern, KeySwitchGLWEOpPattern,
BatchedKeySwitchGLWEOpPattern, WopPBSGLWEOpPattern>(
&getContext(), converter);
// Add patterns to rewrite tensor operators that works on tensors of TFHE GLWE
// types
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern, InsertSliceOpPattern,
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern,
InsertSliceOpPattern<mlir::tensor::InsertSliceOp>,
InsertOpPattern, FromElementsOpPattern>(&getContext(),
converter);
// Add patterns to rewrite some of tensor ops that were introduced by the
@@ -833,13 +823,11 @@ void TFHEToConcretePass::runOnOperation() {
});
// rewrite scf for loops if working on illegal types
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
TFHEToConcreteTypeConverter>>(
&getContext(), converter);
target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp forOp) {
return converter.isLegal(forOp.getInitArgs().getTypes()) &&
converter.isLegal(forOp.getResults().getTypes());
});
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
mlir::scf::ForOp>>(&getContext(), converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
target, converter);