enhance(compiler): Introduce MidLFHE dag parametrization

This commit is contained in:
Quentin Bourgerie
2021-08-16 17:10:21 +02:00
parent 70fb5fcd8e
commit 67f0fc0f45
9 changed files with 253 additions and 3 deletions

View File

@@ -0,0 +1,18 @@
#ifndef ZAMALANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_
#define ZAMALANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_
#include "mlir/Pass/Pass.h"
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
namespace mlir {
namespace zamalang {
/// Create a pass to inject fhe parameters to the MidLFHE types and operators.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertMidLFHEGlobalParametrizationPass(
mlir::zamalang::V0FHEContext &fheContext);
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -10,9 +10,11 @@
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h"
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h"
#include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#define GEN_PASS_CLASSES
#include "zamalang/Conversion/Passes.h.inc"

View File

@@ -17,6 +17,13 @@ def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}
def MidLFHEGlobalParametrization : Pass<"midlfhe-global-parametrization", "mlir::ModuleOp"> {
let summary = "Inject global fhe parameters to the MidLFHE dialect";
let constructor = "mlir::zamalang::createConvertMidLFHEToLowLFHEPass()";
let options = [];
let dependentDialects = ["mlir::zamalang::MidLFHE::MidLFHEDialect"];
}
def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> {
let summary = "Lowers operations from the MidLFHE dialect to LowLFHE";
let description = [{ Lowers operations from the MidLFHE dialect to LowLFHE }];

View File

@@ -1,4 +1,5 @@
add_subdirectory(HLFHEToMidLFHE)
add_subdirectory(MidLFHEGlobalParametrization)
add_subdirectory(MidLFHEToLowLFHE)
add_subdirectory(HLFHETensorOpsToLinalg)
add_subdirectory(LowLFHEToConcreteCAPI)

View File

@@ -0,0 +1,15 @@
add_mlir_dialect_library(MidLFHEGlobalParametrization
MidLFHEGlobalParametrization.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/MidLFHE
DEPENDS
MidLFHEDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
)
target_link_libraries(MidLFHEGlobalParametrization PUBLIC MLIRIR)

View File

@@ -0,0 +1,198 @@
#include <iostream>
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
namespace {
struct MidLFHEGlobalParametrizationPass
: public MidLFHEGlobalParametrizationBase<
MidLFHEGlobalParametrizationPass> {
MidLFHEGlobalParametrizationPass(mlir::zamalang::V0FHEContext &fheContext)
: fheContext(fheContext){};
void runOnOperation() final;
mlir::zamalang::V0FHEContext &fheContext;
};
} // namespace
using mlir::zamalang::MidLFHE::GLWECipherTextType;
/// MidLFHEGlobalParametrizationTypeConverter is a TypeConverter that transform
/// `MidLFHE.gwle<{_,_,_}{p}>` to
/// `MidLFHE.gwle<{glweSize,polynomialSize,bits}{p'}>`
class MidLFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter {
public:
MidLFHEGlobalParametrizationTypeConverter(
mlir::zamalang::V0FHEContext &fheContext) {
addConversion([](mlir::Type type) { return type; });
addConversion([&](GLWECipherTextType type) {
auto glweSize = fheContext.parameter.getNBigGlweSize();
auto p = fheContext.constraint.p;
if (type.getDimension() == glweSize && type.getP() == p) {
return type;
}
return GLWECipherTextType::get(type.getContext(), glweSize,
1 /*for the v0, is always lwe ciphertext*/,
64 /*for the v0 we handle only q=64*/, p);
});
}
};
template <typename Op>
struct MidLFHEOpTypeConversionPattern : public mlir::OpRewritePattern<Op> {
MidLFHEOpTypeConversionPattern(mlir::MLIRContext *context,
mlir::TypeConverter &typeConverter,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<Op>(context, benefit),
typeConverter(typeConverter) {}
mlir::LogicalResult
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
mlir::SmallVector<mlir::Type, 1> newResultTypes;
if (typeConverter.convertTypes(op->getResultTypes(), newResultTypes)
.failed()) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<Op>(op, newResultTypes, op->getOperands());
return mlir::success();
};
private:
mlir::TypeConverter &typeConverter;
};
struct MidLFHEApplyLookupTableParametrizationPattern
: public mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable> {
MidLFHEApplyLookupTableParametrizationPattern(
mlir::MLIRContext *context, mlir::TypeConverter &typeConverter,
mlir::zamalang::V0Parameter &v0Parameter,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable>(
context, benefit),
typeConverter(typeConverter), v0Parameter(v0Parameter) {}
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::MidLFHE::ApplyLookupTable op,
mlir::PatternRewriter &rewriter) const override {
mlir::SmallVector<mlir::Type, 1> newResultTypes;
if (typeConverter.convertTypes(op->getResultTypes(), newResultTypes)
.failed()) {
return mlir::failure();
}
mlir::SmallVector<mlir::NamedAttribute, 6> newAttributes{
mlir::NamedAttribute(rewriter.getIdentifier("k"),
rewriter.getI32IntegerAttr(v0Parameter.k)),
mlir::NamedAttribute(
rewriter.getIdentifier("polynomialSize"),
rewriter.getI32IntegerAttr(v0Parameter.polynomialSize)),
mlir::NamedAttribute(rewriter.getIdentifier("levelKS"),
rewriter.getI32IntegerAttr(v0Parameter.ksLevel)),
mlir::NamedAttribute(rewriter.getIdentifier("baseLogKS"),
rewriter.getI32IntegerAttr(v0Parameter.ksLogBase)),
mlir::NamedAttribute(rewriter.getIdentifier("levelBS"),
rewriter.getI32IntegerAttr(v0Parameter.brLevel)),
mlir::NamedAttribute(rewriter.getIdentifier("baseLogBS"),
rewriter.getI32IntegerAttr(v0Parameter.brLogBase)),
};
rewriter.replaceOpWithNewOp<mlir::zamalang::MidLFHE::ApplyLookupTable>(
op, newResultTypes, op->getOperands(), newAttributes);
return mlir::success();
};
private:
mlir::TypeConverter &typeConverter;
mlir::zamalang::V0Parameter &v0Parameter;
};
template <typename Op>
void populateWithMidLFHEOpTypeConversionPattern(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter) {
patterns.add<MidLFHEOpTypeConversionPattern<Op>>(patterns.getContext(),
typeConverter);
target.addDynamicallyLegalOp<Op>(
[&](Op op) { return typeConverter.isLegal(op->getResultTypes()); });
}
void populateWithMidLFHEApplyLookupTableParametrizationPattern(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter,
mlir::zamalang::V0Parameter &v0Parameter) {
patterns.add<MidLFHEApplyLookupTableParametrizationPattern>(
patterns.getContext(), typeConverter, v0Parameter);
target.addDynamicallyLegalOp<mlir::zamalang::MidLFHE::ApplyLookupTable>(
[&](mlir::zamalang::MidLFHE::ApplyLookupTable op) {
if (op.k() != v0Parameter.k ||
op.polynomialSize() != v0Parameter.polynomialSize ||
op.levelKS() != v0Parameter.ksLevel ||
op.baseLogKS() != v0Parameter.ksLogBase ||
op.levelBS() != v0Parameter.brLevel ||
op.baseLogBS() != v0Parameter.brLogBase) {
return false;
}
return typeConverter.isLegal(op->getResultTypes());
});
}
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
/// operators to the corresponding function call to the `Concrete C API`.
void populateWithMidLFHEOpTypeConversionPatterns(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter,
mlir::zamalang::V0Parameter &v0Parameter) {
populateWithMidLFHEOpTypeConversionPattern<
mlir::zamalang::MidLFHE::AddGLWEIntOp>(patterns, target, typeConverter);
populateWithMidLFHEOpTypeConversionPattern<
mlir::zamalang::MidLFHE::AddGLWEOp>(patterns, target, typeConverter);
populateWithMidLFHEOpTypeConversionPattern<
mlir::zamalang::MidLFHE::SubIntGLWEOp>(patterns, target, typeConverter);
populateWithMidLFHEOpTypeConversionPattern<
mlir::zamalang::MidLFHE::MulGLWEIntOp>(patterns, target, typeConverter);
populateWithMidLFHEApplyLookupTableParametrizationPattern(
patterns, target, typeConverter, v0Parameter);
}
void MidLFHEGlobalParametrizationPass::runOnOperation() {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
MidLFHEGlobalParametrizationTypeConverter converter(fheContext);
// Make sure that no ops from `MidLFHE` remain after the lowering
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
// Make sure func has legal signature
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});
// Add all patterns required to lower all ops from `MidLFHE` to
// `LowLFHE`
mlir::OwningRewritePatternList patterns(&getContext());
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
fheContext.parameter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
namespace mlir {
namespace zamalang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertMidLFHEGlobalParametrizationPass(
mlir::zamalang::V0FHEContext &fheContext) {
return std::make_unique<MidLFHEGlobalParametrizationPass>(fheContext);
}
} // namespace zamalang
} // namespace mlir

View File

@@ -44,7 +44,9 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
pm,
mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(fheContext),
enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
addFilteredPassToPassManager(

View File

@@ -0,0 +1,7 @@
// RUN: zamacompiler %s --run-jit --jit-args 10 --jit-args 54 2>&1| FileCheck %s
// CHECK-LABEL: 64
func @main(%arg0: !HLFHE.eint<6>, %arg1: i7) -> !HLFHE.eint<6> {
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<6>, i7) -> (!HLFHE.eint<6>)
return %1: !HLFHE.eint<6>
}

View File

@@ -1,6 +1,6 @@
// RUN: zamacompiler %s --run-jit --jit-args 12 --jit-args 30 2>&1| FileCheck %s
// RUN: zamacompiler %s --run-jit --jit-args 100 --jit-args 27 2>&1| FileCheck %s
// CHECK-LABEL: 42
// CHECK-LABEL: 127
func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>