mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler): Introduce MidLFHE dag parametrization
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
/// Create a pass to inject fhe parameters to the MidLFHE types and operators.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertMidLFHEGlobalParametrizationPass(
|
||||
mlir::zamalang::V0FHEContext &fheContext);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -10,9 +10,11 @@
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
|
||||
#include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h"
|
||||
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h"
|
||||
#include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "zamalang/Conversion/Passes.h.inc"
|
||||
|
||||
@@ -17,6 +17,13 @@ def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def MidLFHEGlobalParametrization : Pass<"midlfhe-global-parametrization", "mlir::ModuleOp"> {
|
||||
let summary = "Inject global fhe parameters to the MidLFHE dialect";
|
||||
let constructor = "mlir::zamalang::createConvertMidLFHEToLowLFHEPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::zamalang::MidLFHE::MidLFHEDialect"];
|
||||
}
|
||||
|
||||
def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the MidLFHE dialect to LowLFHE";
|
||||
let description = [{ Lowers operations from the MidLFHE dialect to LowLFHE }];
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
add_subdirectory(HLFHEToMidLFHE)
|
||||
add_subdirectory(MidLFHEGlobalParametrization)
|
||||
add_subdirectory(MidLFHEToLowLFHE)
|
||||
add_subdirectory(HLFHETensorOpsToLinalg)
|
||||
add_subdirectory(LowLFHEToConcreteCAPI)
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
add_mlir_dialect_library(MidLFHEGlobalParametrization
|
||||
MidLFHEGlobalParametrization.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/MidLFHE
|
||||
|
||||
DEPENDS
|
||||
MidLFHEDialect
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(MidLFHEGlobalParametrization PUBLIC MLIRIR)
|
||||
@@ -0,0 +1,198 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
|
||||
namespace {
|
||||
struct MidLFHEGlobalParametrizationPass
|
||||
: public MidLFHEGlobalParametrizationBase<
|
||||
MidLFHEGlobalParametrizationPass> {
|
||||
MidLFHEGlobalParametrizationPass(mlir::zamalang::V0FHEContext &fheContext)
|
||||
: fheContext(fheContext){};
|
||||
void runOnOperation() final;
|
||||
mlir::zamalang::V0FHEContext &fheContext;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
using mlir::zamalang::MidLFHE::GLWECipherTextType;
|
||||
|
||||
/// MidLFHEGlobalParametrizationTypeConverter is a TypeConverter that transform
|
||||
/// `MidLFHE.gwle<{_,_,_}{p}>` to
|
||||
/// `MidLFHE.gwle<{glweSize,polynomialSize,bits}{p'}>`
|
||||
class MidLFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
MidLFHEGlobalParametrizationTypeConverter(
|
||||
mlir::zamalang::V0FHEContext &fheContext) {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](GLWECipherTextType type) {
|
||||
auto glweSize = fheContext.parameter.getNBigGlweSize();
|
||||
auto p = fheContext.constraint.p;
|
||||
if (type.getDimension() == glweSize && type.getP() == p) {
|
||||
return type;
|
||||
}
|
||||
return GLWECipherTextType::get(type.getContext(), glweSize,
|
||||
1 /*for the v0, is always lwe ciphertext*/,
|
||||
64 /*for the v0 we handle only q=64*/, p);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
struct MidLFHEOpTypeConversionPattern : public mlir::OpRewritePattern<Op> {
|
||||
MidLFHEOpTypeConversionPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &typeConverter,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<Op>(context, benefit),
|
||||
typeConverter(typeConverter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::SmallVector<mlir::Type, 1> newResultTypes;
|
||||
if (typeConverter.convertTypes(op->getResultTypes(), newResultTypes)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Op>(op, newResultTypes, op->getOperands());
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
mlir::TypeConverter &typeConverter;
|
||||
};
|
||||
|
||||
struct MidLFHEApplyLookupTableParametrizationPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable> {
|
||||
MidLFHEApplyLookupTableParametrizationPattern(
|
||||
mlir::MLIRContext *context, mlir::TypeConverter &typeConverter,
|
||||
mlir::zamalang::V0Parameter &v0Parameter,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable>(
|
||||
context, benefit),
|
||||
typeConverter(typeConverter), v0Parameter(v0Parameter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::MidLFHE::ApplyLookupTable op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::SmallVector<mlir::Type, 1> newResultTypes;
|
||||
if (typeConverter.convertTypes(op->getResultTypes(), newResultTypes)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
mlir::SmallVector<mlir::NamedAttribute, 6> newAttributes{
|
||||
mlir::NamedAttribute(rewriter.getIdentifier("k"),
|
||||
rewriter.getI32IntegerAttr(v0Parameter.k)),
|
||||
mlir::NamedAttribute(
|
||||
rewriter.getIdentifier("polynomialSize"),
|
||||
rewriter.getI32IntegerAttr(v0Parameter.polynomialSize)),
|
||||
mlir::NamedAttribute(rewriter.getIdentifier("levelKS"),
|
||||
rewriter.getI32IntegerAttr(v0Parameter.ksLevel)),
|
||||
mlir::NamedAttribute(rewriter.getIdentifier("baseLogKS"),
|
||||
rewriter.getI32IntegerAttr(v0Parameter.ksLogBase)),
|
||||
mlir::NamedAttribute(rewriter.getIdentifier("levelBS"),
|
||||
rewriter.getI32IntegerAttr(v0Parameter.brLevel)),
|
||||
mlir::NamedAttribute(rewriter.getIdentifier("baseLogBS"),
|
||||
rewriter.getI32IntegerAttr(v0Parameter.brLogBase)),
|
||||
};
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::zamalang::MidLFHE::ApplyLookupTable>(
|
||||
op, newResultTypes, op->getOperands(), newAttributes);
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
mlir::TypeConverter &typeConverter;
|
||||
mlir::zamalang::V0Parameter &v0Parameter;
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
void populateWithMidLFHEOpTypeConversionPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
patterns.add<MidLFHEOpTypeConversionPattern<Op>>(patterns.getContext(),
|
||||
typeConverter);
|
||||
target.addDynamicallyLegalOp<Op>(
|
||||
[&](Op op) { return typeConverter.isLegal(op->getResultTypes()); });
|
||||
}
|
||||
|
||||
void populateWithMidLFHEApplyLookupTableParametrizationPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter,
|
||||
mlir::zamalang::V0Parameter &v0Parameter) {
|
||||
patterns.add<MidLFHEApplyLookupTableParametrizationPattern>(
|
||||
patterns.getContext(), typeConverter, v0Parameter);
|
||||
target.addDynamicallyLegalOp<mlir::zamalang::MidLFHE::ApplyLookupTable>(
|
||||
[&](mlir::zamalang::MidLFHE::ApplyLookupTable op) {
|
||||
if (op.k() != v0Parameter.k ||
|
||||
op.polynomialSize() != v0Parameter.polynomialSize ||
|
||||
op.levelKS() != v0Parameter.ksLevel ||
|
||||
op.baseLogKS() != v0Parameter.ksLogBase ||
|
||||
op.levelBS() != v0Parameter.brLevel ||
|
||||
op.baseLogBS() != v0Parameter.brLogBase) {
|
||||
return false;
|
||||
}
|
||||
return typeConverter.isLegal(op->getResultTypes());
|
||||
});
|
||||
}
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateWithMidLFHEOpTypeConversionPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter,
|
||||
mlir::zamalang::V0Parameter &v0Parameter) {
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::AddGLWEIntOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::AddGLWEOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::SubIntGLWEOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEOpTypeConversionPattern<
|
||||
mlir::zamalang::MidLFHE::MulGLWEIntOp>(patterns, target, typeConverter);
|
||||
populateWithMidLFHEApplyLookupTableParametrizationPattern(
|
||||
patterns, target, typeConverter, v0Parameter);
|
||||
}
|
||||
|
||||
void MidLFHEGlobalParametrizationPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
MidLFHEGlobalParametrizationTypeConverter converter(fheContext);
|
||||
|
||||
// Make sure that no ops from `MidLFHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
|
||||
// Make sure func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
// Add all patterns required to lower all ops from `MidLFHE` to
|
||||
// `LowLFHE`
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter,
|
||||
fheContext.parameter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertMidLFHEGlobalParametrizationPass(
|
||||
mlir::zamalang::V0FHEContext &fheContext) {
|
||||
return std::make_unique<MidLFHEGlobalParametrizationPass>(fheContext);
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -44,7 +44,9 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
pm,
|
||||
mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(fheContext),
|
||||
enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
|
||||
7
compiler/tests/RunJit/hlfhe_add_eint_int_6.mlir
Normal file
7
compiler/tests/RunJit/hlfhe_add_eint_int_6.mlir
Normal file
@@ -0,0 +1,7 @@
|
||||
// RUN: zamacompiler %s --run-jit --jit-args 10 --jit-args 54 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: 64
|
||||
func @main(%arg0: !HLFHE.eint<6>, %arg1: i7) -> !HLFHE.eint<6> {
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<6>, i7) -> (!HLFHE.eint<6>)
|
||||
return %1: !HLFHE.eint<6>
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: zamacompiler %s --run-jit --jit-args 12 --jit-args 30 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --run-jit --jit-args 100 --jit-args 27 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: 42
|
||||
// CHECK-LABEL: 127
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
Reference in New Issue
Block a user