mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compiler): Pad constant tabulated lambda when the input precision of apply_lookup_table has been changed
This commit is contained in:
@@ -129,6 +129,70 @@ private:
|
||||
mlir::zamalang::V0Parameter &v0Parameter;
|
||||
};
|
||||
|
||||
struct MidLFHEApplyLookupTablePaddingPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable> {
|
||||
MidLFHEApplyLookupTablePaddingPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable>(
|
||||
context, benefit),
|
||||
typeConverter(typeConverter), v0Parameter(v0Parameter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::MidLFHE::ApplyLookupTable op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto glweInType = op.getOperandTypes()[0]
|
||||
.cast<mlir::zamalang::MidLFHE::GLWECipherTextType>();
|
||||
auto tabulatedLambdaType =
|
||||
op.l_cst().getType().cast<mlir::RankedTensorType>();
|
||||
auto glweOutType =
|
||||
op.getType().cast<mlir::zamalang::MidLFHE::GLWECipherTextType>();
|
||||
auto expectedSize = 1 << glweInType.getP();
|
||||
if (tabulatedLambdaType.getShape()[0] < expectedSize) {
|
||||
auto constantOp =
|
||||
mlir::dyn_cast_or_null<mlir::ConstantOp>(op.l_cst().getDefiningOp());
|
||||
if (constantOp == nullptr) {
|
||||
op.emitError() << "padding for non-constant operator is NYI";
|
||||
return mlir::failure();
|
||||
}
|
||||
mlir::DenseIntElementsAttr denseVals =
|
||||
constantOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
|
||||
if (denseVals == nullptr) {
|
||||
op.emitError() << "value should be dense";
|
||||
return mlir::failure();
|
||||
}
|
||||
// Create the new constant dense op with padding
|
||||
auto integerSize = 64;
|
||||
llvm::SmallVector<llvm::APInt> rawNewDenseVals(
|
||||
expectedSize, llvm::APInt(integerSize, 0));
|
||||
for (auto i = 0; i < expectedSize; i++) {
|
||||
rawNewDenseVals[i] = llvm::APInt(
|
||||
integerSize,
|
||||
denseVals.getFlatValue<llvm::APInt>(i % denseVals.size())
|
||||
.getZExtValue());
|
||||
}
|
||||
auto newDenseValsType = mlir::RankedTensorType::get(
|
||||
{expectedSize}, rewriter.getIntegerType(integerSize));
|
||||
auto newDenseVals =
|
||||
mlir::DenseIntElementsAttr::get(newDenseValsType, rawNewDenseVals);
|
||||
auto newConstantOp =
|
||||
rewriter.create<mlir::ConstantOp>(constantOp.getLoc(), newDenseVals);
|
||||
// Replace the apply_lookup_table with the new constant
|
||||
mlir::SmallVector<mlir::Type> newResultTypes{op.getType()};
|
||||
llvm::SmallVector<mlir::Value> newOperands{op.ct(), newConstantOp};
|
||||
llvm::ArrayRef<mlir::NamedAttribute> newAttrs = op->getAttrs();
|
||||
rewriter.replaceOpWithNewOp<mlir::zamalang::MidLFHE::ApplyLookupTable>(
|
||||
op, newResultTypes, newOperands, newAttrs);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
mlir::TypeConverter &typeConverter;
|
||||
mlir::zamalang::V0Parameter &v0Parameter;
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
void populateWithMidLFHEOpTypeConversionPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
|
||||
@@ -160,6 +224,24 @@ void populateWithMidLFHEApplyLookupTableParametrizationPattern(
|
||||
});
|
||||
}
|
||||
|
||||
void populateWithMidLFHEApplyLookupTablePaddingPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target) {
|
||||
patterns.add<MidLFHEApplyLookupTablePaddingPattern>(patterns.getContext());
|
||||
target.addLegalOp<mlir::ConstantOp>();
|
||||
target.addDynamicallyLegalOp<mlir::zamalang::MidLFHE::ApplyLookupTable>(
|
||||
[&](mlir::zamalang::MidLFHE::ApplyLookupTable op) {
|
||||
auto glweInType =
|
||||
op.getOperandTypes()[0]
|
||||
.cast<mlir::zamalang::MidLFHE::GLWECipherTextType>();
|
||||
auto tabulatedLambdaType =
|
||||
op.getOperandTypes()[1].cast<mlir::RankedTensorType>();
|
||||
auto glweOutType =
|
||||
op.getType().cast<mlir::zamalang::MidLFHE::GLWECipherTextType>();
|
||||
|
||||
return tabulatedLambdaType.getShape()[0] == 1 << glweInType.getP();
|
||||
});
|
||||
}
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateWithMidLFHEOpTypeConversionPatterns(
|
||||
@@ -183,31 +265,47 @@ void populateWithMidLFHEOpTypeConversionPatterns(
|
||||
void MidLFHEGlobalParametrizationPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
MidLFHEGlobalParametrizationTypeConverter converter(fheContext);
|
||||
|
||||
// Make sure that no ops from `MidLFHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
// Parametrize
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
// Make sure func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
// Add all patterns required to lower all ops from `MidLFHE` to
|
||||
// `LowLFHE`
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
|
||||
fheContext.parameter);
|
||||
patterns.add<LinalgGenericTypeConverterPattern<
|
||||
MidLFHEGlobalParametrizationTypeConverter>>(&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
// function signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
// Add all patterns to convert MidLFHE types
|
||||
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
|
||||
fheContext.parameter);
|
||||
patterns.add<LinalgGenericTypeConverterPattern<
|
||||
MidLFHEGlobalParametrizationTypeConverter>>(&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
// Pad lookup table
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
populateWithMidLFHEApplyLookupTablePaddingPattern(patterns, target);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -414,6 +414,23 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
|
||||
ASSERT_EQ(res, 14);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTLU, identity_func_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<5>) -> !HLFHE.eint<5> {
|
||||
%tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64>
|
||||
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<5>, tensor<32xi64>) -> (!HLFHE.eint<5>)
|
||||
return %1: !HLFHE.eint<5>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_FALSE(engine.compile(mlirStr));
|
||||
uint64_t expected = 3;
|
||||
auto maybeResult = engine.run({expected});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
uint64_t result = maybeResult.get();
|
||||
ASSERT_EQ(result, expected);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTLU, identity_func) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
|
||||
Reference in New Issue
Block a user