From bd3d462384d730462330202379726dc19439e85a Mon Sep 17 00:00:00 2001 From: rudy Date: Wed, 10 Aug 2022 00:11:05 +0200 Subject: [PATCH] feat(multiprecision): enable real multiple precision computation --- .../lib/Conversion/FHEToTFHE/FHEToTFHE.cpp | 4 +- .../TFHEGlobalParametrization.cpp | 175 ++++++++---------- compiler/lib/Support/V0ClientParameters.cpp | 18 +- .../FHEToTFHE/apply_univariate.mlir | 4 +- 4 files changed, 85 insertions(+), 116 deletions(-) diff --git a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp index a77aee321..6f9695d84 100644 --- a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp +++ b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp @@ -99,8 +99,8 @@ struct ApplyLookupTableEintOpPattern .cast(); auto resultTy = converter.convertType(lutOp.getType()); // %glwe_lut = "TFHE.glwe_from_table"(%lut) - auto glweLut = rewriter.create(lutOp.getLoc(), - inputTy, lutOp.lut()); + auto glweLut = rewriter.create( + lutOp.getLoc(), resultTy, lutOp.lut()); // %glwe_ks = "TFHE.keyswitch_glwe"(%ct) auto glweKs = rewriter.create( lutOp.getLoc(), inputTy, lutOp.a(), -1, -1); diff --git a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 9813ed7fb..26706d19b 100644 --- a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -21,10 +21,11 @@ namespace TFHE = mlir::concretelang::TFHE; namespace { struct TFHEGlobalParametrizationPass : public TFHEGlobalParametrizationBase { - TFHEGlobalParametrizationPass(mlir::concretelang::V0FHEContext &fheContext) - : fheContext(fheContext){}; + TFHEGlobalParametrizationPass( + mlir::concretelang::V0Parameter &cryptoParameters) + : cryptoParameters(cryptoParameters){}; void runOnOperation() final; - mlir::concretelang::V0FHEContext &fheContext; + mlir::concretelang::V0Parameter &cryptoParameters; }; } // namespace @@ -37,114 +38,120 @@ class TFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter { public: TFHEGlobalParametrizationTypeConverter( - mlir::concretelang::V0FHEContext &fheContext) - : fheContext(fheContext) { - auto convertGLWECiphertextType = - [&](GLWECipherTextType type, - mlir::concretelang::V0FHEContext &fheContext) { - auto newTy = this->glweInterPBSType(type.getContext(), fheContext); - if (newTy.getDimension() == type.getDimension() && - newTy.getPolynomialSize() == type.getPolynomialSize() && - newTy.getP() == type.getP()) - return type; - return newTy; - }; + mlir::concretelang::V0Parameter &cryptoParameters) + : cryptoParameters(cryptoParameters) { addConversion([](mlir::Type type) { return type; }); - addConversion([&](GLWECipherTextType type) { - return convertGLWECiphertextType(type, fheContext); - }); + addConversion( + [&](GLWECipherTextType type) { return this->glweInterPBSType(type); }); addConversion([&](mlir::RankedTensorType type) { auto glwe = type.getElementType().dyn_cast_or_null(); if (glwe == nullptr) { return (mlir::Type)(type); } - mlir::Type r = mlir::RankedTensorType::get( - type.getShape(), convertGLWECiphertextType(glwe, fheContext)); + mlir::Type r = mlir::RankedTensorType::get(type.getShape(), + this->glweInterPBSType(glwe)); return r; }); } - TFHE::GLWECipherTextType - glweInterPBSType(mlir::MLIRContext *context, - mlir::concretelang::V0FHEContext fheContext) { - return TFHE::GLWECipherTextType::get( - context, fheContext.parameter.getNBigGlweDimension(), 1, 64, - fheContext.constraint.p); + TFHE::GLWECipherTextType glweInterPBSType(GLWECipherTextType &type) { + auto bits = 64; + auto dimension = cryptoParameters.getNBigGlweDimension(); + auto polynomialSize = 1; + auto precision = (signed)type.getP(); + if ((int)dimension == type.getDimension() && + (int)polynomialSize == type.getPolynomialSize()) { + return type; + } + return TFHE::GLWECipherTextType::get(type.getContext(), dimension, + polynomialSize, bits, precision); } - TFHE::GLWECipherTextType glweLookupTableType(mlir::MLIRContext *context) { - return TFHE::GLWECipherTextType::get( - context, fheContext.parameter.glweDimension, - fheContext.parameter.getPolynomialSize(), 64, fheContext.constraint.p); + TFHE::GLWECipherTextType glweLookupTableType(GLWECipherTextType &type) { + auto bits = 64; + auto dimension = cryptoParameters.glweDimension; + auto polynomialSize = cryptoParameters.getPolynomialSize(); + auto precision = (signed)type.getP(); + return TFHE::GLWECipherTextType::get(type.getContext(), dimension, + polynomialSize, bits, precision); } - TFHE::GLWECipherTextType glweIntraPBSType(mlir::MLIRContext *context) { - return TFHE::GLWECipherTextType::get(context, fheContext.parameter.nSmall, - 1, 64, fheContext.constraint.p); + TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) { + auto bits = 64; + auto dimension = cryptoParameters.nSmall; + auto polynomialSize = 1; + auto precision = (signed)type.getP(); + return TFHE::GLWECipherTextType::get(type.getContext(), dimension, + polynomialSize, bits, precision); } - mlir::concretelang::V0FHEContext fheContext; + mlir::concretelang::V0Parameter cryptoParameters; }; struct KeySwitchGLWEOpPattern : public mlir::OpRewritePattern { KeySwitchGLWEOpPattern(mlir::MLIRContext *context, TFHEGlobalParametrizationTypeConverter &converter, - mlir::concretelang::V0FHEContext &fheContext, + mlir::concretelang::V0Parameter &cryptoParameters, mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit), - converter(converter), fheContext(fheContext) {} + converter(converter), cryptoParameters(cryptoParameters) {} mlir::LogicalResult matchAndRewrite(TFHE::KeySwitchGLWEOp ksOp, mlir::PatternRewriter &rewriter) const override { - mlir::SmallVector newResultTypes; auto inputTy = ksOp.ciphertext().getType().cast(); - auto outputTy = converter.glweIntraPBSType(rewriter.getContext()); + auto newInputTy = converter.convertType(inputTy); + auto outputTy = ksOp.result().getType().cast(); + auto newOutputTy = converter.glweIntraPBSType(outputTy); auto newOp = rewriter.replaceOpWithNewOp( - ksOp, outputTy, ksOp.ciphertext(), fheContext.parameter.ksLevel, - fheContext.parameter.ksLogBase); + ksOp, newOutputTy, ksOp.ciphertext(), cryptoParameters.ksLevel, + cryptoParameters.ksLogBase); rewriter.startRootUpdate(newOp); - newOp.ciphertext().setType(converter.convertType(inputTy)); + newOp.ciphertext().setType(newInputTy); rewriter.finalizeRootUpdate(newOp); return mlir::success(); }; private: TFHEGlobalParametrizationTypeConverter &converter; - mlir::concretelang::V0FHEContext &fheContext; + mlir::concretelang::V0Parameter &cryptoParameters; }; struct BootstrapGLWEOpPattern : public mlir::OpRewritePattern { BootstrapGLWEOpPattern(mlir::MLIRContext *context, TFHEGlobalParametrizationTypeConverter &converter, - mlir::concretelang::V0FHEContext &fheContext, + mlir::concretelang::V0Parameter &cryptoParameters, mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit), - converter(converter), fheContext(fheContext) {} + converter(converter), cryptoParameters(cryptoParameters) {} mlir::LogicalResult matchAndRewrite(TFHE::BootstrapGLWEOp bsOp, mlir::PatternRewriter &rewriter) const override { + auto inputTy = bsOp.ciphertext().getType().cast(); + auto newInputTy = converter.glweIntraPBSType(inputTy); + auto outputTy = bsOp.result().getType().cast(); + auto newOutputTy = converter.convertType(outputTy); + auto tableTy = + bsOp.lookup_table().getType().cast(); + auto newTableTy = converter.glweLookupTableType(tableTy); auto newOp = rewriter.replaceOpWithNewOp( - bsOp, converter.convertType(bsOp.result().getType()), bsOp.ciphertext(), - bsOp.lookup_table(), fheContext.parameter.brLevel, - fheContext.parameter.brLogBase); + bsOp, newOutputTy, bsOp.ciphertext(), bsOp.lookup_table(), + cryptoParameters.brLevel, cryptoParameters.brLogBase); rewriter.startRootUpdate(newOp); - newOp.ciphertext().setType( - converter.glweIntraPBSType(rewriter.getContext())); - newOp.lookup_table().setType( - converter.glweLookupTableType(rewriter.getContext())); + newOp.ciphertext().setType(newInputTy); + newOp.lookup_table().setType(newTableTy); rewriter.finalizeRootUpdate(newOp); return mlir::success(); }; private: TFHEGlobalParametrizationTypeConverter &converter; - mlir::concretelang::V0FHEContext &fheContext; + mlir::concretelang::V0Parameter &cryptoParameters; }; /// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by @@ -170,58 +177,24 @@ struct GLWEFromTablePattern : public mlir::OpRewritePattern { GLWEFromTablePattern(mlir::MLIRContext *context, TFHEGlobalParametrizationTypeConverter &converter, - mlir::concretelang::V0FHEContext &fheContext, mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit), - converter(converter), fheContext(fheContext) {} + converter(converter) {} mlir::LogicalResult matchAndRewrite(TFHE::GLWEFromTableOp glweOp, mlir::PatternRewriter &rewriter) const override { - auto newTy = converter.glweLookupTableType(glweOp.getContext()); - - auto lutOp = glweOp.table(); - auto tableTy = lutOp.getType().cast(); - - auto expectedSize = 1 << newTy.getP(); - if (tableTy.getShape()[0] < expectedSize) { - // Create a new padded lookup table - auto constantOp = mlir::dyn_cast_or_null( - lutOp.getDefiningOp()); - if (constantOp == nullptr) { - glweOp.emitError() << "padding for non-constant operator is NYI"; - return mlir::failure(); - } - mlir::DenseIntElementsAttr denseVals = - constantOp->getAttrOfType("value"); - if (denseVals == nullptr) { - constantOp.emitError() << "value should be dense"; - return mlir::failure(); - } - auto integerSize = 64; - llvm::SmallVector rawNewDenseVals( - expectedSize, llvm::APInt(integerSize, 0)); - auto denseValsAP = denseVals.getValues(); - for (auto i = 0; i < expectedSize; i++) { - rawNewDenseVals[i] = llvm::APInt( - integerSize, denseValsAP[i % denseVals.size()].getZExtValue()); - } - auto newDenseValsType = mlir::RankedTensorType::get( - {expectedSize}, rewriter.getIntegerType(integerSize)); - auto newDenseVals = - mlir::DenseIntElementsAttr::get(newDenseValsType, rawNewDenseVals); - // Replace the lutOp by the new padded lookup table - lutOp = rewriter.create(constantOp.getLoc(), - newDenseVals); - } - rewriter.replaceOpWithNewOp(glweOp, newTy, lutOp); + auto outputTy = glweOp.result().getType().cast(); + auto newOutputTy = converter.glweLookupTableType(outputTy); + auto tableOp = glweOp.table(); + rewriter.replaceOpWithNewOp(glweOp, newOutputTy, + tableOp); return mlir::success(); }; private: TFHEGlobalParametrizationTypeConverter &converter; - mlir::concretelang::V0FHEContext &fheContext; }; template @@ -239,8 +212,7 @@ void populateWithTFHEOpTypeConversionPattern( /// operators to the corresponding function call to the `Concrete C API`. void populateWithTFHEOpTypeConversionPatterns( mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, - mlir::TypeConverter &typeConverter, - mlir::concretelang::V0Parameter &v0Parameter) { + mlir::TypeConverter &typeConverter) { populateWithTFHEOpTypeConversionPattern( patterns, target, typeConverter); populateWithTFHEOpTypeConversionPattern< @@ -261,7 +233,7 @@ void populateWithTFHEOpTypeConversionPatterns( void TFHEGlobalParametrizationPass::runOnOperation() { auto op = this->getOperation(); - TFHEGlobalParametrizationTypeConverter converter(fheContext); + TFHEGlobalParametrizationTypeConverter converter(cryptoParameters); // Parametrize { @@ -278,7 +250,7 @@ void TFHEGlobalParametrizationPass::runOnOperation() { patterns, converter); // Parametrize keyswitch bootstrap - patterns.add(&getContext(), converter, fheContext); + patterns.add(&getContext(), converter); target.addDynamicallyLegalOp( [&](TFHE::GLWEFromTableOp op) { return !op.getType() @@ -286,20 +258,21 @@ void TFHEGlobalParametrizationPass::runOnOperation() { .hasUnparametrizedParameters(); }); target.addLegalOp(); - patterns.add(&getContext(), converter, fheContext); + patterns.add(&getContext(), converter, + cryptoParameters); target.addDynamicallyLegalOp( [&](TFHE::KeySwitchGLWEOp op) { return op.level() != (uint32_t)-1 && op.baseLog() != (uint32_t)-1; }); - patterns.add(&getContext(), converter, fheContext); + patterns.add(&getContext(), converter, + cryptoParameters); target.addDynamicallyLegalOp( [&](TFHE::BootstrapGLWEOp op) { return converter.isLegal(op->getResultTypes()); }); // Add all patterns to convert TFHE types - populateWithTFHEOpTypeConversionPatterns(patterns, target, converter, - fheContext.parameter); + populateWithTFHEOpTypeConversionPatterns(patterns, target, converter); patterns.add>( &getContext(), converter); @@ -348,7 +321,7 @@ namespace concretelang { std::unique_ptr> createConvertTFHEGlobalParametrizationPass( mlir::concretelang::V0FHEContext &fheContext) { - return std::make_unique(fheContext); + return std::make_unique(fheContext.parameter); } } // namespace concretelang } // namespace mlir diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index 96c29e0eb..d8c0fe362 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -32,7 +32,6 @@ const auto v0Curve = getV0Curves(securityLevel, keyFormat); /// For the v0 the secretKeyID and precision are the same for all gates. llvm::Expected gateFromMLIRType(LweSecretKeyID secretKeyID, - Precision precision, Variance variance, mlir::Type type) { if (type.isIntOrIndex()) { @@ -53,9 +52,11 @@ llvm::Expected gateFromMLIRType(LweSecretKeyID secretKeyID, }, }; } - if (type.isa()) { + if (auto lweType = type.dyn_cast_or_null< + mlir::concretelang::Concrete::LweCiphertextType>()) { // TODO - Get the width from the LWECiphertextType instead of global // precision (could be possible after merge concrete-ciphertext-parameter) + size_t precision = (size_t)lweType.getP(); return CircuitGate{ /* .encryption = */ llvm::Optional({ /* .secretKeyID = */ secretKeyID, @@ -75,8 +76,8 @@ llvm::Expected gateFromMLIRType(LweSecretKeyID secretKeyID, } auto tensor = type.dyn_cast_or_null(); if (tensor != nullptr) { - auto gate = gateFromMLIRType(secretKeyID, precision, variance, - tensor.getElementType()); + auto gate = + gateFromMLIRType(secretKeyID, variance, tensor.getElementType()); if (auto err = gate.takeError()) { return std::move(err); } @@ -142,9 +143,6 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::inconvertibleErrorCode()); } - // For the v0 the precision is global - auto precision = fheContext.constraint.p; - // Create input and output circuit gate parameters auto funcType = (*funcOp).getFunctionType(); @@ -157,16 +155,14 @@ createClientParametersForV0(V0FHEContext fheContext, for (auto inType = funcType.getInputs().begin(); inType < funcType.getInputs().end() - hasContext; inType++) { - auto gate = - gateFromMLIRType(BIG_KEY, precision, encryptionVariance, *inType); + auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, *inType); if (auto err = gate.takeError()) { return std::move(err); } c.inputs.push_back(gate.get()); } for (auto outType : funcType.getResults()) { - auto gate = - gateFromMLIRType(BIG_KEY, precision, encryptionVariance, outType); + auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, outType); if (auto err = gate.takeError()) { return std::move(err); } diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate.mlir index ec870bfb5..3044dff73 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate.mlir @@ -1,9 +1,9 @@ // RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{2}>, %[[LUT:.*]]: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> { -// CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%[[LUT]]) : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%[[LUT]]) : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> // CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> -// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{3}> +// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{3}>) -> !TFHE.glwe<{_,_,_}{3}> // CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{3}> func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> { %1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)