diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index 4b4cc5057..1275cca0d 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -166,9 +166,21 @@ static inline bool operator==(const CircuitGateShape &lhs, lhs.size == rhs.size; } +struct ChunkInfo { + /// total number of bits used for the chunk including the carry. + /// size should be at least width + 1 + unsigned int size; + /// number of bits used for the chunk excluding the carry + unsigned int width; +}; +static inline bool operator==(const ChunkInfo &lhs, const ChunkInfo &rhs) { + return lhs.width == rhs.width && lhs.size == rhs.size; +} + struct CircuitGate { llvm::Optional encryption; CircuitGateShape shape; + llvm::Optional chunkInfo; bool isEncrypted() { return encryption.hasValue(); } @@ -186,7 +198,8 @@ struct CircuitGate { } }; static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) { - return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape; + return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape && + lhs.chunkInfo == rhs.chunkInfo; } struct ClientParameters { diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h index c2caaf405..97afa2ca8 100644 --- a/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -90,6 +90,18 @@ struct PublicResult { /// Serialize into an output stream. outcome::checked serialize(std::ostream &ostream); + /// Get the original integer that was decomposed into chunks of `chunkWidth` + /// bits each + uint64_t fromChunks(std::vector chunks, unsigned int chunkWidth) { + uint64_t value = 0; + uint64_t mask = (1 << chunkWidth) - 1; + for (size_t i = 0; i < chunks.size(); i++) { + auto chunk = chunks[i] & mask; + value += chunk << (chunkWidth * i); + } + return value; + } + /// Get the result at `pos` as a scalar. Decryption happens if the /// result is encrypted. template @@ -99,6 +111,16 @@ struct PublicResult { if (!gate.isEncrypted()) return buffers[pos].getScalar().getValue(); + // Chunked integers are represented as tensors at a lower level, so we need + // to deal with them as tensors, then build the resulting scalar out of the + // tensor values + if (gate.chunkInfo.hasValue()) { + OUTCOME_TRY(std::vector decryptedChunks, + this->asClearTextVector(keySet, pos)); + uint64_t decrypted = fromChunks(decryptedChunks, gate.chunkInfo->width); + return (T)decrypted; + } + auto &buffer = buffers[pos].getTensor(); auto ciphertext = buffer.getOpaqueElementPointer(0); diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 624c2cede..3af76dd52 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -74,6 +74,7 @@ struct CompilationOptions { /// When decomposing big integers into chunks, chunkSize is the total number /// of bits used for the message, including the carry, while chunkWidth is /// only the number of bits used during encoding and decoding of a big integer + bool chunkIntegers; unsigned int chunkSize; unsigned int chunkWidth; @@ -83,8 +84,8 @@ struct CompilationOptions { emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false), dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false), clientParametersFuncName(llvm::None), - optimizerConfig(optimizer::DEFAULT_CONFIG), chunkSize(4), - chunkWidth(2){}; + optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false), + chunkSize(4), chunkWidth(2){}; CompilationOptions(std::string funcname) : CompilationOptions() { clientParametersFuncName = funcname; diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h index d3e68ced1..68dfb6d5d 100644 --- a/compiler/include/concretelang/Support/LambdaSupport.h +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -206,6 +206,23 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return (sign) ? buildScalarLambdaResult(keySet, result) : buildScalarLambdaResult(keySet, result); } + } else if (gate.chunkInfo.hasValue()) { + // chunked scalar case + assert(gate.shape.dimensions.size() == 1); + width = gate.shape.size * gate.chunkInfo->width; + if (width > 32) { + return (sign) ? buildScalarLambdaResult(keySet, result) + : buildScalarLambdaResult(keySet, result); + } else if (width > 16) { + return (sign) ? buildScalarLambdaResult(keySet, result) + : buildScalarLambdaResult(keySet, result); + } else if (width > 8) { + return (sign) ? buildScalarLambdaResult(keySet, result) + : buildScalarLambdaResult(keySet, result); + } else if (width <= 8) { + return (sign) ? buildScalarLambdaResult(keySet, result) + : buildScalarLambdaResult(keySet, result); + } } else { // tensor case if (width > 32) { diff --git a/compiler/include/concretelang/Support/V0ClientParameters.h b/compiler/include/concretelang/Support/V0ClientParameters.h index ecba29be8..4ca97d44e 100644 --- a/compiler/include/concretelang/Support/V0ClientParameters.h +++ b/compiler/include/concretelang/Support/V0ClientParameters.h @@ -15,11 +15,13 @@ namespace mlir { namespace concretelang { +using ::concretelang::clientlib::ChunkInfo; using ::concretelang::clientlib::ClientParameters; llvm::Expected createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName, - mlir::ModuleOp module, int bitsOfSecurity); + mlir::ModuleOp module, int bitsOfSecurity, + llvm::Optional chunkInfo = llvm::None); } // namespace concretelang } // namespace mlir diff --git a/compiler/lib/ClientLib/EncryptedArguments.cpp b/compiler/lib/ClientLib/EncryptedArguments.cpp index 6818c2309..999a9cadf 100644 --- a/compiler/lib/ClientLib/EncryptedArguments.cpp +++ b/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -18,11 +18,33 @@ EncryptedArguments::exportPublicArguments(ClientParameters clientParameters, clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers)); } +/// Split the input integer into `size` chunks of `chunkWidth` bits each +std::vector chunkInput(uint64_t value, size_t size, + unsigned int chunkWidth) { + std::vector chunks; + chunks.reserve(size); + uint64_t mask = (1 << chunkWidth) - 1; + for (size_t i = 0; i < size; i++) { + auto chunk = value & mask; + chunks.push_back((uint64_t)chunk); + value >>= chunkWidth; + } + return chunks; +} + outcome::checked EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) { OUTCOME_TRYV(checkPushTooManyArgs(keySet)); + OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(currentPos)); + // a chunked input is represented as a tensor in lower levels, and need to to + // splitted into chunks and encrypted as such + if (input.chunkInfo.hasValue()) { + std::vector chunks = + chunkInput(arg, input.shape.size, input.chunkInfo.getPointer()->width); + return this->pushArg(chunks.data(), input.shape.size, keySet); + } + // we only increment if we don't forward the call to another pushArg method auto pos = currentPos++; - OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(pos)); if (input.shape.size != 0) { return StringError("argument #") << pos << " is not a scalar"; } diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 88c56fa92..10a37291c 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -363,10 +363,15 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { emptyParams.functionName = funcName; res.clientParameters = emptyParams; } else { + llvm::Optional<::concretelang::clientlib::ChunkInfo> chunkInfo = + llvm::None; + if (options.chunkIntegers) { + chunkInfo = ::concretelang::clientlib::ChunkInfo{4, 2}; + } auto clientParametersOrErr = mlir::concretelang::createClientParametersForV0( *res.fheContext, funcName, module, - options.optimizerConfig.security); + options.optimizerConfig.security, chunkInfo); if (!clientParametersOrErr) return clientParametersOrErr.takeError(); diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index da15d8c2a..4042cac80 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -22,6 +22,7 @@ namespace mlir { namespace concretelang { namespace clientlib = ::concretelang::clientlib; +using ::concretelang::clientlib::ChunkInfo; using ::concretelang::clientlib::CircuitGate; using ::concretelang::clientlib::ClientParameters; using ::concretelang::clientlib::Encoding; @@ -33,10 +34,10 @@ using ::concretelang::clientlib::Variance; const auto keyFormat = concrete::BINARY; /// For the v0 the secretKeyID and precision are the same for all gates. -llvm::Expected gateFromMLIRType(V0FHEContext fheContext, - LweSecretKeyID secretKeyID, - Variance variance, - mlir::Type type) { +llvm::Expected +gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID, + Variance variance, llvm::Optional chunkInfo, + mlir::Type type) { if (type.isIntOrIndex()) { // TODO - The index type is dependant of the target architecture, so // actually we assume we target only 64 bits, we need to have some the size @@ -55,6 +56,7 @@ llvm::Expected gateFromMLIRType(V0FHEContext fheContext, /*.dimensions = */ std::vector(), /*.size = */ 0, /* .sign */ sign}, + /*.chunkInfo = */ llvm::None, }; } if (auto lweTy = type.dyn_cast_or_null< @@ -64,22 +66,76 @@ llvm::Expected gateFromMLIRType(V0FHEContext fheContext, if (fheContext.parameter.largeInteger.has_value()) { crt = fheContext.parameter.largeInteger.value().crtDecomposition; } + size_t width; + uint64_t size = 0; + std::vector dims; + if (chunkInfo.hasValue()) { + width = chunkInfo->size; + assert(lweTy.getWidth() % chunkInfo->width == 0); + size = lweTy.getWidth() / chunkInfo->width; + dims.push_back(size); + } else { + width = (size_t)lweTy.getWidth(); + } return CircuitGate{ /* .encryption = */ llvm::Optional({ /* .secretKeyID = */ secretKeyID, /* .variance = */ variance, /* .encoding = */ { - /* .precision = */ lweTy.getWidth(), + /* .precision = */ width, /* .crt = */ crt, /*.sign = */ sign, }, }), /*.shape = */ - {/*.width = */ (size_t)lweTy.getWidth(), - /*.dimensions = */ std::vector(), - /*.size = */ 0, - /*.sign = */ sign}, + { + /*.width = */ width, + /*.dimensions = */ dims, + /*.size = */ size, + /*.sign = */ sign, + }, + /*.chunkInfo = */ chunkInfo, + }; + } + // TODO: this a duplicate of the last if: should be removed when we remove + // chinked eint + if (auto lweTy = type.dyn_cast_or_null< + mlir::concretelang::FHE::ChunkedEncryptedIntegerType>()) { + bool sign = lweTy.isSignedInteger(); + std::vector crt; + if (fheContext.parameter.largeInteger.has_value()) { + crt = fheContext.parameter.largeInteger.value().crtDecomposition; + } + size_t width; + uint64_t size; + std::vector dims; + if (chunkInfo.hasValue()) { + width = chunkInfo->size; + assert(lweTy.getWidth() % chunkInfo->width == 0); + size = lweTy.getWidth() / chunkInfo->width; + dims.push_back(size); + } else { + width = (size_t)lweTy.getWidth(); + } + return CircuitGate{ + /* .encryption = */ llvm::Optional({ + /* .secretKeyID = */ secretKeyID, + /* .variance = */ variance, + /* .encoding = */ + { + /* .precision = */ width, + /* .crt = */ crt, + }, + }), + /*.shape = */ + { + /*.width = */ width, + /*.dimensions = */ dims, + /*.size = */ size, + /*.sign = */ sign, + }, + /*.chunkInfo = */ chunkInfo, }; } if (auto lweTy = type.dyn_cast_or_null< @@ -103,11 +159,12 @@ llvm::Expected gateFromMLIRType(V0FHEContext fheContext, /*.size = */ 0, /*.sign = */ false, }, + /*.chunkInfo = */ llvm::None, }; } auto tensor = type.dyn_cast_or_null(); if (tensor != nullptr) { - auto gate = gateFromMLIRType(fheContext, secretKeyID, variance, + auto gate = gateFromMLIRType(fheContext, secretKeyID, variance, chunkInfo, tensor.getElementType()); if (auto err = gate.takeError()) { return std::move(err); @@ -126,7 +183,8 @@ llvm::Expected gateFromMLIRType(V0FHEContext fheContext, llvm::Expected createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef functionName, mlir::ModuleOp module, - int bitsOfSecurity) { + int bitsOfSecurity, + llvm::Optional chunkInfo) { const auto v0Curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat); if (v0Curve == nullptr) { @@ -216,7 +274,8 @@ createClientParametersForV0(V0FHEContext fheContext, auto inputs = funcType.getInputs(); auto gateFromType = [&](mlir::Type ty) { - return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance, ty); + return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance, + chunkInfo, ty); }; for (auto inType : inputs) { auto gate = gateFromType(inType); diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index cffb97950..5e1381108 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -206,6 +206,12 @@ llvm::cl::list llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); +llvm::cl::opt + chunkIntegers("chunk-integers", + llvm::cl::desc("Whether to decompose integer into chunks or " + "not, default is false (to not chunk)"), + llvm::cl::init(false)); + llvm::cl::opt chunkSize( "chunk-size", llvm::cl::desc( @@ -350,6 +356,7 @@ cmdlineCompilationOptions() { cmdline::unrollLoopsWithSDFGConvertibleOps; options.optimizeConcrete = cmdline::optimizeConcrete; options.emitGPUOps = cmdline::emitGPUOps; + options.chunkIntegers = cmdline::chunkIntegers; options.chunkSize = cmdline::chunkSize; options.chunkWidth = cmdline::chunkWidth; diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp index ceae4f269..bd170a5b6 100644 --- a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp +++ b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp @@ -45,11 +45,13 @@ TEST(Support, client_parameters_json_serde) { /*.encryption = */ { {clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}, false}}}, /*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4, false}, + /*.chunkInfo = */ llvm::None, }, { /*.encryption = */ { {clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}}, /*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false}, + /*.chunkInfo = */ llvm::None, }, }; params0.outputs = { @@ -57,6 +59,7 @@ TEST(Support, client_parameters_json_serde) { /*.encryption = */ { {clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}}, /*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false}, + /*.chunkInfo = */ llvm::None, }, }; auto json = clientlib::toJSON(params0);