fix(compiler): Pad constant tabulated lambda when the input precision of apply_lookup_table has been changed

This commit is contained in:
Quentin Bourgerie
2021-09-27 09:32:55 +02:00
parent 6204f93878
commit 8f4da14bdb
2 changed files with 136 additions and 21 deletions

View File

@@ -129,6 +129,70 @@ private:
mlir::zamalang::V0Parameter &v0Parameter;
};
struct MidLFHEApplyLookupTablePaddingPattern
: public mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable> {
MidLFHEApplyLookupTablePaddingPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable>(
context, benefit),
typeConverter(typeConverter), v0Parameter(v0Parameter) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::MidLFHE::ApplyLookupTable op,
mlir::PatternRewriter &rewriter) const override {
auto glweInType = op.getOperandTypes()[0]
.cast<mlir::zamalang::MidLFHE::GLWECipherTextType>();
auto tabulatedLambdaType =
op.l_cst().getType().cast<mlir::RankedTensorType>();
auto glweOutType =
op.getType().cast<mlir::zamalang::MidLFHE::GLWECipherTextType>();
auto expectedSize = 1 << glweInType.getP();
if (tabulatedLambdaType.getShape()[0] < expectedSize) {
auto constantOp =
mlir::dyn_cast_or_null<mlir::ConstantOp>(op.l_cst().getDefiningOp());
if (constantOp == nullptr) {
op.emitError() << "padding for non-constant operator is NYI";
return mlir::failure();
}
mlir::DenseIntElementsAttr denseVals =
constantOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
if (denseVals == nullptr) {
op.emitError() << "value should be dense";
return mlir::failure();
}
// Create the new constant dense op with padding
auto integerSize = 64;
llvm::SmallVector<llvm::APInt> rawNewDenseVals(
expectedSize, llvm::APInt(integerSize, 0));
for (auto i = 0; i < expectedSize; i++) {
rawNewDenseVals[i] = llvm::APInt(
integerSize,
denseVals.getFlatValue<llvm::APInt>(i % denseVals.size())
.getZExtValue());
}
auto newDenseValsType = mlir::RankedTensorType::get(
{expectedSize}, rewriter.getIntegerType(integerSize));
auto newDenseVals =
mlir::DenseIntElementsAttr::get(newDenseValsType, rawNewDenseVals);
auto newConstantOp =
rewriter.create<mlir::ConstantOp>(constantOp.getLoc(), newDenseVals);
// Replace the apply_lookup_table with the new constant
mlir::SmallVector<mlir::Type> newResultTypes{op.getType()};
llvm::SmallVector<mlir::Value> newOperands{op.ct(), newConstantOp};
llvm::ArrayRef<mlir::NamedAttribute> newAttrs = op->getAttrs();
rewriter.replaceOpWithNewOp<mlir::zamalang::MidLFHE::ApplyLookupTable>(
op, newResultTypes, newOperands, newAttrs);
return mlir::success();
}
return mlir::success();
};
private:
mlir::TypeConverter &typeConverter;
mlir::zamalang::V0Parameter &v0Parameter;
};
template <typename Op>
void populateWithMidLFHEOpTypeConversionPattern(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
@@ -160,6 +224,24 @@ void populateWithMidLFHEApplyLookupTableParametrizationPattern(
});
}
void populateWithMidLFHEApplyLookupTablePaddingPattern(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target) {
patterns.add<MidLFHEApplyLookupTablePaddingPattern>(patterns.getContext());
target.addLegalOp<mlir::ConstantOp>();
target.addDynamicallyLegalOp<mlir::zamalang::MidLFHE::ApplyLookupTable>(
[&](mlir::zamalang::MidLFHE::ApplyLookupTable op) {
auto glweInType =
op.getOperandTypes()[0]
.cast<mlir::zamalang::MidLFHE::GLWECipherTextType>();
auto tabulatedLambdaType =
op.getOperandTypes()[1].cast<mlir::RankedTensorType>();
auto glweOutType =
op.getType().cast<mlir::zamalang::MidLFHE::GLWECipherTextType>();
return tabulatedLambdaType.getShape()[0] == 1 << glweInType.getP();
});
}
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
/// operators to the corresponding function call to the `Concrete C API`.
void populateWithMidLFHEOpTypeConversionPatterns(
@@ -183,31 +265,47 @@ void populateWithMidLFHEOpTypeConversionPatterns(
void MidLFHEGlobalParametrizationPass::runOnOperation() {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
MidLFHEGlobalParametrizationTypeConverter converter(fheContext);
// Make sure that no ops from `MidLFHE` remain after the lowering
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
// Parametrize
{
mlir::ConversionTarget target(getContext());
mlir::OwningRewritePatternList patterns(&getContext());
// Make sure func has legal signature
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});
// Add all patterns required to lower all ops from `MidLFHE` to
// `LowLFHE`
mlir::OwningRewritePatternList patterns(&getContext());
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
fheContext.parameter);
patterns.add<LinalgGenericTypeConverterPattern<
MidLFHEGlobalParametrizationTypeConverter>>(&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// function signature
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
// Add all patterns to convert MidLFHE types
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
fheContext.parameter);
patterns.add<LinalgGenericTypeConverterPattern<
MidLFHEGlobalParametrizationTypeConverter>>(&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
// Pad lookup table
{
mlir::ConversionTarget target(getContext());
mlir::OwningRewritePatternList patterns(&getContext());
populateWithMidLFHEApplyLookupTablePaddingPattern(patterns, target);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
}

View File

@@ -414,6 +414,23 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
ASSERT_EQ(res, 14);
}
TEST(CompileAndRunTLU, identity_func_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%arg0: !HLFHE.eint<5>) -> !HLFHE.eint<5> {
%tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64>
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<5>, tensor<32xi64>) -> (!HLFHE.eint<5>)
return %1: !HLFHE.eint<5>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
uint64_t expected = 3;
auto maybeResult = engine.run({expected});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, expected);
}
TEST(CompileAndRunTLU, identity_func) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(