mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
fix(compiler): Replace in-place updated with conversion patterns when lowering TFHE to Concrete
The current scheme with in-place updates of the types of values may result in operations recognized as legal and thus preventing them from being converted when the operations producing their operands have been converted earlier, as their types have been updated and legality is solely based on types. For example, the conversion pattern for an `tensor.insert_slice` operation working on tensors of encrypted values may not trigger if the operations producing its operands have been converted, leaving the operation with updated operand types with the extra dimension added by the type conversion from TFHE to Concrete, but with unmodified sizes, strides and offsets, not taking into account the extra dimension. This causes the verifier of the affected operation to fail and the compilation to abort. By using op conversion patterns, the original types of each operation are preserved during the actual rewrite, correctly triggering all conversion patterns based on the legality of data types.
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
|
||||
@@ -320,16 +321,15 @@ struct TracePlaintextOpPattern
|
||||
};
|
||||
|
||||
template <typename ZeroOp>
|
||||
struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
|
||||
ZeroOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<ZeroOp>(
|
||||
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
struct ZeroOpPattern : public mlir::OpConversionPattern<ZeroOp> {
|
||||
ZeroOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &converter)
|
||||
: mlir::OpConversionPattern<ZeroOp>(
|
||||
converter, context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(ZeroOp zeroOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
TFHEToConcreteTypeConverter converter;
|
||||
auto newResultTy = converter.convertType(zeroOp.getType());
|
||||
matchAndRewrite(ZeroOp zeroOp, typename ZeroOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
auto newResultTy = this->getTypeConverter()->convertType(zeroOp.getType());
|
||||
|
||||
auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
@@ -451,54 +451,43 @@ struct ExtractOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// Pattern that rewrites the InsertSlice operation, taking into account the
|
||||
/// additional LWE dimension introduced during type conversion
|
||||
struct InsertSliceOpPattern
|
||||
: public mlir::OpConversionPattern<mlir::tensor::InsertSliceOp> {
|
||||
/// Pattern that rewrites the InsertSlice-like operation, taking into
|
||||
/// account the additional LWE dimension introduced during type
|
||||
/// conversion
|
||||
template <typename OpTy>
|
||||
struct InsertSliceOpPattern : public mlir::OpConversionPattern<OpTy> {
|
||||
InsertSliceOpPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &typeConverter)
|
||||
: ::mlir::OpConversionPattern<mlir::tensor::InsertSliceOp>(
|
||||
: ::mlir::OpConversionPattern<OpTy>(
|
||||
typeConverter, context,
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp,
|
||||
mlir::tensor::InsertSliceOp::Adaptor adaptor,
|
||||
matchAndRewrite(OpTy insertSliceOp, typename OpTy::Adaptor adaptor,
|
||||
::mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
// is not a tensor of GLWEs that need to be extended with the LWE dimension
|
||||
if (this->getTypeConverter()->isLegal(insertSliceOp.getType())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
mlir::RankedTensorType newDestTy = ((mlir::Type)adaptor.getDest().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
|
||||
auto newResultTy = this->getTypeConverter()
|
||||
->convertType(insertSliceOp.getResult().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
// add 0 to offsets
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets = getMixedValues(
|
||||
adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter);
|
||||
offsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
|
||||
// add 0 to static offsets
|
||||
mlir::SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.append(adaptor.getStaticOffsets().begin(),
|
||||
adaptor.getStaticOffsets().end());
|
||||
staticOffsets.push_back(0);
|
||||
|
||||
// add lweDimension+1 to static_sizes
|
||||
mlir::SmallVector<int64_t> staticSizes;
|
||||
staticSizes.append(adaptor.getStaticSizes().begin(),
|
||||
adaptor.getStaticSizes().end());
|
||||
staticSizes.push_back(newResultTy.getDimSize(newResultTy.getRank() - 1));
|
||||
// add lweDimension+1 to sizes
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes =
|
||||
getMixedValues(adaptor.getStaticSizes(), adaptor.getSizes(), rewriter);
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newDestTy.getDimSize(newDestTy.getRank() - 1)));
|
||||
|
||||
// add 1 to the strides
|
||||
mlir::SmallVector<int64_t> staticStrides;
|
||||
staticStrides.append(adaptor.getStaticStrides().begin(),
|
||||
adaptor.getStaticStrides().end());
|
||||
staticStrides.push_back(1);
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides = getMixedValues(
|
||||
adaptor.getStaticStrides(), adaptor.getStrides(), rewriter);
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// replace tensor.insert_slice with the new one
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
|
||||
insertSliceOp, newResultTy, adaptor.getSource(), adaptor.getDest(),
|
||||
adaptor.getOffsets(), adaptor.getSizes(), adaptor.getStrides(),
|
||||
rewriter.getDenseI64ArrayAttr(staticOffsets),
|
||||
rewriter.getDenseI64ArrayAttr(staticSizes),
|
||||
rewriter.getDenseI64ArrayAttr(staticStrides));
|
||||
rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, adaptor.getSource(),
|
||||
adaptor.getDest(), offsets, sizes,
|
||||
strides);
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
@@ -798,17 +787,18 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>>(
|
||||
&getContext(), converter);
|
||||
// pattern of remaining TFHE ops
|
||||
|
||||
patterns.insert<ZeroOpPattern<mlir::concretelang::TFHE::ZeroGLWEOp>,
|
||||
ZeroOpPattern<mlir::concretelang::TFHE::ZeroTensorGLWEOp>>(
|
||||
&getContext());
|
||||
patterns.insert<SubIntGLWEOpPattern, BootstrapGLWEOpPattern,
|
||||
ZeroOpPattern<mlir::concretelang::TFHE::ZeroTensorGLWEOp>,
|
||||
SubIntGLWEOpPattern, BootstrapGLWEOpPattern,
|
||||
BatchedBootstrapGLWEOpPattern, KeySwitchGLWEOpPattern,
|
||||
BatchedKeySwitchGLWEOpPattern, WopPBSGLWEOpPattern>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Add patterns to rewrite tensor operators that works on tensors of TFHE GLWE
|
||||
// types
|
||||
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern, InsertSliceOpPattern,
|
||||
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern,
|
||||
InsertSliceOpPattern<mlir::tensor::InsertSliceOp>,
|
||||
InsertOpPattern, FromElementsOpPattern>(&getContext(),
|
||||
converter);
|
||||
// Add patterns to rewrite some of tensor ops that were introduced by the
|
||||
@@ -833,13 +823,11 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
});
|
||||
|
||||
// rewrite scf for loops if working on illegal types
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
|
||||
TFHEToConcreteTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp forOp) {
|
||||
return converter.isLegal(forOp.getInitArgs().getTypes()) &&
|
||||
converter.isLegal(forOp.getResults().getTypes());
|
||||
});
|
||||
patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::scf::ForOp>>(&getContext(), converter);
|
||||
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
|
||||
converter);
|
||||
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
|
||||
target, converter);
|
||||
|
||||
Reference in New Issue
Block a user