feat(multiprecision): enable real multiple precision computation

This commit is contained in:
rudy
2022-08-10 00:11:05 +02:00
committed by Quentin Bourgerie
parent 80f36c14de
commit bd3d462384
4 changed files with 85 additions and 116 deletions

View File

@@ -99,8 +99,8 @@ struct ApplyLookupTableEintOpPattern
.cast<TFHE::GLWECipherTextType>();
auto resultTy = converter.convertType(lutOp.getType());
// %glwe_lut = "TFHE.glwe_from_table"(%lut)
auto glweLut = rewriter.create<TFHE::GLWEFromTableOp>(lutOp.getLoc(),
inputTy, lutOp.lut());
auto glweLut = rewriter.create<TFHE::GLWEFromTableOp>(
lutOp.getLoc(), resultTy, lutOp.lut());
// %glwe_ks = "TFHE.keyswitch_glwe"(%ct)
auto glweKs = rewriter.create<TFHE::KeySwitchGLWEOp>(
lutOp.getLoc(), inputTy, lutOp.a(), -1, -1);

View File

@@ -21,10 +21,11 @@ namespace TFHE = mlir::concretelang::TFHE;
namespace {
struct TFHEGlobalParametrizationPass
: public TFHEGlobalParametrizationBase<TFHEGlobalParametrizationPass> {
TFHEGlobalParametrizationPass(mlir::concretelang::V0FHEContext &fheContext)
: fheContext(fheContext){};
TFHEGlobalParametrizationPass(
mlir::concretelang::V0Parameter &cryptoParameters)
: cryptoParameters(cryptoParameters){};
void runOnOperation() final;
mlir::concretelang::V0FHEContext &fheContext;
mlir::concretelang::V0Parameter &cryptoParameters;
};
} // namespace
@@ -37,114 +38,120 @@ class TFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter {
public:
TFHEGlobalParametrizationTypeConverter(
mlir::concretelang::V0FHEContext &fheContext)
: fheContext(fheContext) {
auto convertGLWECiphertextType =
[&](GLWECipherTextType type,
mlir::concretelang::V0FHEContext &fheContext) {
auto newTy = this->glweInterPBSType(type.getContext(), fheContext);
if (newTy.getDimension() == type.getDimension() &&
newTy.getPolynomialSize() == type.getPolynomialSize() &&
newTy.getP() == type.getP())
return type;
return newTy;
};
mlir::concretelang::V0Parameter &cryptoParameters)
: cryptoParameters(cryptoParameters) {
addConversion([](mlir::Type type) { return type; });
addConversion([&](GLWECipherTextType type) {
return convertGLWECiphertextType(type, fheContext);
});
addConversion(
[&](GLWECipherTextType type) { return this->glweInterPBSType(type); });
addConversion([&](mlir::RankedTensorType type) {
auto glwe = type.getElementType().dyn_cast_or_null<GLWECipherTextType>();
if (glwe == nullptr) {
return (mlir::Type)(type);
}
mlir::Type r = mlir::RankedTensorType::get(
type.getShape(), convertGLWECiphertextType(glwe, fheContext));
mlir::Type r = mlir::RankedTensorType::get(type.getShape(),
this->glweInterPBSType(glwe));
return r;
});
}
TFHE::GLWECipherTextType
glweInterPBSType(mlir::MLIRContext *context,
mlir::concretelang::V0FHEContext fheContext) {
return TFHE::GLWECipherTextType::get(
context, fheContext.parameter.getNBigGlweDimension(), 1, 64,
fheContext.constraint.p);
TFHE::GLWECipherTextType glweInterPBSType(GLWECipherTextType &type) {
auto bits = 64;
auto dimension = cryptoParameters.getNBigGlweDimension();
auto polynomialSize = 1;
auto precision = (signed)type.getP();
if ((int)dimension == type.getDimension() &&
(int)polynomialSize == type.getPolynomialSize()) {
return type;
}
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision);
}
TFHE::GLWECipherTextType glweLookupTableType(mlir::MLIRContext *context) {
return TFHE::GLWECipherTextType::get(
context, fheContext.parameter.glweDimension,
fheContext.parameter.getPolynomialSize(), 64, fheContext.constraint.p);
TFHE::GLWECipherTextType glweLookupTableType(GLWECipherTextType &type) {
auto bits = 64;
auto dimension = cryptoParameters.glweDimension;
auto polynomialSize = cryptoParameters.getPolynomialSize();
auto precision = (signed)type.getP();
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision);
}
TFHE::GLWECipherTextType glweIntraPBSType(mlir::MLIRContext *context) {
return TFHE::GLWECipherTextType::get(context, fheContext.parameter.nSmall,
1, 64, fheContext.constraint.p);
TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) {
auto bits = 64;
auto dimension = cryptoParameters.nSmall;
auto polynomialSize = 1;
auto precision = (signed)type.getP();
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision);
}
mlir::concretelang::V0FHEContext fheContext;
mlir::concretelang::V0Parameter cryptoParameters;
};
struct KeySwitchGLWEOpPattern
: public mlir::OpRewritePattern<TFHE::KeySwitchGLWEOp> {
KeySwitchGLWEOpPattern(mlir::MLIRContext *context,
TFHEGlobalParametrizationTypeConverter &converter,
mlir::concretelang::V0FHEContext &fheContext,
mlir::concretelang::V0Parameter &cryptoParameters,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<TFHE::KeySwitchGLWEOp>(context, benefit),
converter(converter), fheContext(fheContext) {}
converter(converter), cryptoParameters(cryptoParameters) {}
mlir::LogicalResult
matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp,
mlir::PatternRewriter &rewriter) const override {
mlir::SmallVector<mlir::Type, 1> newResultTypes;
auto inputTy = ksOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
auto outputTy = converter.glweIntraPBSType(rewriter.getContext());
auto newInputTy = converter.convertType(inputTy);
auto outputTy = ksOp.result().getType().cast<TFHE::GLWECipherTextType>();
auto newOutputTy = converter.glweIntraPBSType(outputTy);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::KeySwitchGLWEOp>(
ksOp, outputTy, ksOp.ciphertext(), fheContext.parameter.ksLevel,
fheContext.parameter.ksLogBase);
ksOp, newOutputTy, ksOp.ciphertext(), cryptoParameters.ksLevel,
cryptoParameters.ksLogBase);
rewriter.startRootUpdate(newOp);
newOp.ciphertext().setType(converter.convertType(inputTy));
newOp.ciphertext().setType(newInputTy);
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
private:
TFHEGlobalParametrizationTypeConverter &converter;
mlir::concretelang::V0FHEContext &fheContext;
mlir::concretelang::V0Parameter &cryptoParameters;
};
struct BootstrapGLWEOpPattern
: public mlir::OpRewritePattern<TFHE::BootstrapGLWEOp> {
BootstrapGLWEOpPattern(mlir::MLIRContext *context,
TFHEGlobalParametrizationTypeConverter &converter,
mlir::concretelang::V0FHEContext &fheContext,
mlir::concretelang::V0Parameter &cryptoParameters,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<TFHE::BootstrapGLWEOp>(context, benefit),
converter(converter), fheContext(fheContext) {}
converter(converter), cryptoParameters(cryptoParameters) {}
mlir::LogicalResult
matchAndRewrite(TFHE::BootstrapGLWEOp bsOp,
mlir::PatternRewriter &rewriter) const override {
auto inputTy = bsOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
auto newInputTy = converter.glweIntraPBSType(inputTy);
auto outputTy = bsOp.result().getType().cast<TFHE::GLWECipherTextType>();
auto newOutputTy = converter.convertType(outputTy);
auto tableTy =
bsOp.lookup_table().getType().cast<TFHE::GLWECipherTextType>();
auto newTableTy = converter.glweLookupTableType(tableTy);
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
bsOp, converter.convertType(bsOp.result().getType()), bsOp.ciphertext(),
bsOp.lookup_table(), fheContext.parameter.brLevel,
fheContext.parameter.brLogBase);
bsOp, newOutputTy, bsOp.ciphertext(), bsOp.lookup_table(),
cryptoParameters.brLevel, cryptoParameters.brLogBase);
rewriter.startRootUpdate(newOp);
newOp.ciphertext().setType(
converter.glweIntraPBSType(rewriter.getContext()));
newOp.lookup_table().setType(
converter.glweLookupTableType(rewriter.getContext()));
newOp.ciphertext().setType(newInputTy);
newOp.lookup_table().setType(newTableTy);
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
private:
TFHEGlobalParametrizationTypeConverter &converter;
mlir::concretelang::V0FHEContext &fheContext;
mlir::concretelang::V0Parameter &cryptoParameters;
};
/// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by
@@ -170,58 +177,24 @@ struct GLWEFromTablePattern
: public mlir::OpRewritePattern<TFHE::GLWEFromTableOp> {
GLWEFromTablePattern(mlir::MLIRContext *context,
TFHEGlobalParametrizationTypeConverter &converter,
mlir::concretelang::V0FHEContext &fheContext,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<TFHE::GLWEFromTableOp>(context, benefit),
converter(converter), fheContext(fheContext) {}
converter(converter) {}
mlir::LogicalResult
matchAndRewrite(TFHE::GLWEFromTableOp glweOp,
mlir::PatternRewriter &rewriter) const override {
auto newTy = converter.glweLookupTableType(glweOp.getContext());
auto lutOp = glweOp.table();
auto tableTy = lutOp.getType().cast<mlir::RankedTensorType>();
auto expectedSize = 1 << newTy.getP();
if (tableTy.getShape()[0] < expectedSize) {
// Create a new padded lookup table
auto constantOp = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
lutOp.getDefiningOp());
if (constantOp == nullptr) {
glweOp.emitError() << "padding for non-constant operator is NYI";
return mlir::failure();
}
mlir::DenseIntElementsAttr denseVals =
constantOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
if (denseVals == nullptr) {
constantOp.emitError() << "value should be dense";
return mlir::failure();
}
auto integerSize = 64;
llvm::SmallVector<llvm::APInt> rawNewDenseVals(
expectedSize, llvm::APInt(integerSize, 0));
auto denseValsAP = denseVals.getValues<llvm::APInt>();
for (auto i = 0; i < expectedSize; i++) {
rawNewDenseVals[i] = llvm::APInt(
integerSize, denseValsAP[i % denseVals.size()].getZExtValue());
}
auto newDenseValsType = mlir::RankedTensorType::get(
{expectedSize}, rewriter.getIntegerType(integerSize));
auto newDenseVals =
mlir::DenseIntElementsAttr::get(newDenseValsType, rawNewDenseVals);
// Replace the lutOp by the new padded lookup table
lutOp = rewriter.create<mlir::arith::ConstantOp>(constantOp.getLoc(),
newDenseVals);
}
rewriter.replaceOpWithNewOp<TFHE::GLWEFromTableOp>(glweOp, newTy, lutOp);
auto outputTy = glweOp.result().getType().cast<TFHE::GLWECipherTextType>();
auto newOutputTy = converter.glweLookupTableType(outputTy);
auto tableOp = glweOp.table();
rewriter.replaceOpWithNewOp<TFHE::GLWEFromTableOp>(glweOp, newOutputTy,
tableOp);
return mlir::success();
};
private:
TFHEGlobalParametrizationTypeConverter &converter;
mlir::concretelang::V0FHEContext &fheContext;
};
template <typename Op>
@@ -239,8 +212,7 @@ void populateWithTFHEOpTypeConversionPattern(
/// operators to the corresponding function call to the `Concrete C API`.
void populateWithTFHEOpTypeConversionPatterns(
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter,
mlir::concretelang::V0Parameter &v0Parameter) {
mlir::TypeConverter &typeConverter) {
populateWithTFHEOpTypeConversionPattern<mlir::concretelang::TFHE::ZeroGLWEOp>(
patterns, target, typeConverter);
populateWithTFHEOpTypeConversionPattern<
@@ -261,7 +233,7 @@ void populateWithTFHEOpTypeConversionPatterns(
void TFHEGlobalParametrizationPass::runOnOperation() {
auto op = this->getOperation();
TFHEGlobalParametrizationTypeConverter converter(fheContext);
TFHEGlobalParametrizationTypeConverter converter(cryptoParameters);
// Parametrize
{
@@ -278,7 +250,7 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
patterns, converter);
// Parametrize keyswitch bootstrap
patterns.add<GLWEFromTablePattern>(&getContext(), converter, fheContext);
patterns.add<GLWEFromTablePattern>(&getContext(), converter);
target.addDynamicallyLegalOp<TFHE::GLWEFromTableOp>(
[&](TFHE::GLWEFromTableOp op) {
return !op.getType()
@@ -286,20 +258,21 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
.hasUnparametrizedParameters();
});
target.addLegalOp<mlir::arith::ConstantOp>();
patterns.add<KeySwitchGLWEOpPattern>(&getContext(), converter, fheContext);
patterns.add<KeySwitchGLWEOpPattern>(&getContext(), converter,
cryptoParameters);
target.addDynamicallyLegalOp<TFHE::KeySwitchGLWEOp>(
[&](TFHE::KeySwitchGLWEOp op) {
return op.level() != (uint32_t)-1 && op.baseLog() != (uint32_t)-1;
});
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter, fheContext);
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter,
cryptoParameters);
target.addDynamicallyLegalOp<TFHE::BootstrapGLWEOp>(
[&](TFHE::BootstrapGLWEOp op) {
return converter.isLegal(op->getResultTypes());
});
// Add all patterns to convert TFHE types
populateWithTFHEOpTypeConversionPatterns(patterns, target, converter,
fheContext.parameter);
populateWithTFHEOpTypeConversionPatterns(patterns, target, converter);
patterns.add<RegionOpTypeConverterPattern<
mlir::linalg::GenericOp, TFHEGlobalParametrizationTypeConverter>>(
&getContext(), converter);
@@ -348,7 +321,7 @@ namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTFHEGlobalParametrizationPass(
mlir::concretelang::V0FHEContext &fheContext) {
return std::make_unique<TFHEGlobalParametrizationPass>(fheContext);
return std::make_unique<TFHEGlobalParametrizationPass>(fheContext.parameter);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -32,7 +32,6 @@ const auto v0Curve = getV0Curves(securityLevel, keyFormat);
/// For the v0 the secretKeyID and precision are the same for all gates.
llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
Precision precision,
Variance variance,
mlir::Type type) {
if (type.isIntOrIndex()) {
@@ -53,9 +52,11 @@ llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
},
};
}
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
if (auto lweType = type.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>()) {
// TODO - Get the width from the LWECiphertextType instead of global
// precision (could be possible after merge concrete-ciphertext-parameter)
size_t precision = (size_t)lweType.getP();
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
@@ -75,8 +76,8 @@ llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
}
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensor != nullptr) {
auto gate = gateFromMLIRType(secretKeyID, precision, variance,
tensor.getElementType());
auto gate =
gateFromMLIRType(secretKeyID, variance, tensor.getElementType());
if (auto err = gate.takeError()) {
return std::move(err);
}
@@ -142,9 +143,6 @@ createClientParametersForV0(V0FHEContext fheContext,
llvm::inconvertibleErrorCode());
}
// For the v0 the precision is global
auto precision = fheContext.constraint.p;
// Create input and output circuit gate parameters
auto funcType = (*funcOp).getFunctionType();
@@ -157,16 +155,14 @@ createClientParametersForV0(V0FHEContext fheContext,
for (auto inType = funcType.getInputs().begin();
inType < funcType.getInputs().end() - hasContext; inType++) {
auto gate =
gateFromMLIRType(BIG_KEY, precision, encryptionVariance, *inType);
auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, *inType);
if (auto err = gate.takeError()) {
return std::move(err);
}
c.inputs.push_back(gate.get());
}
for (auto outType : funcType.getResults()) {
auto gate =
gateFromMLIRType(BIG_KEY, precision, encryptionVariance, outType);
auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, outType);
if (auto err = gate.takeError()) {
return std::move(err);
}

View File

@@ -1,9 +1,9 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{2}>, %[[LUT:.*]]: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> {
// CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%[[LUT]]) : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%[[LUT]]) : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}>
// CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{3}>
// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{3}>) -> !TFHE.glwe<{_,_,_}{3}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{3}>
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)