mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: support chunked integer during enc/dec/exec
This commit is contained in:
@@ -166,9 +166,21 @@ static inline bool operator==(const CircuitGateShape &lhs,
|
||||
lhs.size == rhs.size;
|
||||
}
|
||||
|
||||
struct ChunkInfo {
|
||||
/// total number of bits used for the chunk including the carry.
|
||||
/// size should be at least width + 1
|
||||
unsigned int size;
|
||||
/// number of bits used for the chunk excluding the carry
|
||||
unsigned int width;
|
||||
};
|
||||
static inline bool operator==(const ChunkInfo &lhs, const ChunkInfo &rhs) {
|
||||
return lhs.width == rhs.width && lhs.size == rhs.size;
|
||||
}
|
||||
|
||||
struct CircuitGate {
|
||||
llvm::Optional<EncryptionGate> encryption;
|
||||
CircuitGateShape shape;
|
||||
llvm::Optional<ChunkInfo> chunkInfo;
|
||||
|
||||
bool isEncrypted() { return encryption.hasValue(); }
|
||||
|
||||
@@ -186,7 +198,8 @@ struct CircuitGate {
|
||||
}
|
||||
};
|
||||
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
|
||||
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape;
|
||||
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape &&
|
||||
lhs.chunkInfo == rhs.chunkInfo;
|
||||
}
|
||||
|
||||
struct ClientParameters {
|
||||
|
||||
@@ -90,6 +90,18 @@ struct PublicResult {
|
||||
/// Serialize into an output stream.
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
/// Get the original integer that was decomposed into chunks of `chunkWidth`
|
||||
/// bits each
|
||||
uint64_t fromChunks(std::vector<uint64_t> chunks, unsigned int chunkWidth) {
|
||||
uint64_t value = 0;
|
||||
uint64_t mask = (1 << chunkWidth) - 1;
|
||||
for (size_t i = 0; i < chunks.size(); i++) {
|
||||
auto chunk = chunks[i] & mask;
|
||||
value += chunk << (chunkWidth * i);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
/// Get the result at `pos` as a scalar. Decryption happens if the
|
||||
/// result is encrypted.
|
||||
template <typename T>
|
||||
@@ -99,6 +111,16 @@ struct PublicResult {
|
||||
if (!gate.isEncrypted())
|
||||
return buffers[pos].getScalar().getValue<T>();
|
||||
|
||||
// Chunked integers are represented as tensors at a lower level, so we need
|
||||
// to deal with them as tensors, then build the resulting scalar out of the
|
||||
// tensor values
|
||||
if (gate.chunkInfo.hasValue()) {
|
||||
OUTCOME_TRY(std::vector<uint64_t> decryptedChunks,
|
||||
this->asClearTextVector<uint64_t>(keySet, pos));
|
||||
uint64_t decrypted = fromChunks(decryptedChunks, gate.chunkInfo->width);
|
||||
return (T)decrypted;
|
||||
}
|
||||
|
||||
auto &buffer = buffers[pos].getTensor();
|
||||
|
||||
auto ciphertext = buffer.getOpaqueElementPointer(0);
|
||||
|
||||
@@ -74,6 +74,7 @@ struct CompilationOptions {
|
||||
/// When decomposing big integers into chunks, chunkSize is the total number
|
||||
/// of bits used for the message, including the carry, while chunkWidth is
|
||||
/// only the number of bits used during encoding and decoding of a big integer
|
||||
bool chunkIntegers;
|
||||
unsigned int chunkSize;
|
||||
unsigned int chunkWidth;
|
||||
|
||||
@@ -83,8 +84,8 @@ struct CompilationOptions {
|
||||
emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false),
|
||||
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
|
||||
clientParametersFuncName(llvm::None),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkSize(4),
|
||||
chunkWidth(2){};
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
|
||||
chunkSize(4), chunkWidth(2){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
clientParametersFuncName = funcname;
|
||||
|
||||
@@ -206,6 +206,23 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint8_t>(keySet, result);
|
||||
}
|
||||
} else if (gate.chunkInfo.hasValue()) {
|
||||
// chunked scalar case
|
||||
assert(gate.shape.dimensions.size() == 1);
|
||||
width = gate.shape.size * gate.chunkInfo->width;
|
||||
if (width > 32) {
|
||||
return (sign) ? buildScalarLambdaResult<int64_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint64_t>(keySet, result);
|
||||
} else if (width > 16) {
|
||||
return (sign) ? buildScalarLambdaResult<int32_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint32_t>(keySet, result);
|
||||
} else if (width > 8) {
|
||||
return (sign) ? buildScalarLambdaResult<int16_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint16_t>(keySet, result);
|
||||
} else if (width <= 8) {
|
||||
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint8_t>(keySet, result);
|
||||
}
|
||||
} else {
|
||||
// tensor case
|
||||
if (width > 32) {
|
||||
|
||||
@@ -15,11 +15,13 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using ::concretelang::clientlib::ChunkInfo;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
|
||||
mlir::ModuleOp module, int bitsOfSecurity);
|
||||
mlir::ModuleOp module, int bitsOfSecurity,
|
||||
llvm::Optional<ChunkInfo> chunkInfo = llvm::None);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -18,11 +18,33 @@ EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
|
||||
clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers));
|
||||
}
|
||||
|
||||
/// Split the input integer into `size` chunks of `chunkWidth` bits each
|
||||
std::vector<uint64_t> chunkInput(uint64_t value, size_t size,
|
||||
unsigned int chunkWidth) {
|
||||
std::vector<uint64_t> chunks;
|
||||
chunks.reserve(size);
|
||||
uint64_t mask = (1 << chunkWidth) - 1;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto chunk = value & mask;
|
||||
chunks.push_back((uint64_t)chunk);
|
||||
value >>= chunkWidth;
|
||||
}
|
||||
return chunks;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(currentPos));
|
||||
// a chunked input is represented as a tensor in lower levels, and need to to
|
||||
// splitted into chunks and encrypted as such
|
||||
if (input.chunkInfo.hasValue()) {
|
||||
std::vector<uint64_t> chunks =
|
||||
chunkInput(arg, input.shape.size, input.chunkInfo.getPointer()->width);
|
||||
return this->pushArg(chunks.data(), input.shape.size, keySet);
|
||||
}
|
||||
// we only increment if we don't forward the call to another pushArg method
|
||||
auto pos = currentPos++;
|
||||
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(pos));
|
||||
if (input.shape.size != 0) {
|
||||
return StringError("argument #") << pos << " is not a scalar";
|
||||
}
|
||||
|
||||
@@ -363,10 +363,15 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
emptyParams.functionName = funcName;
|
||||
res.clientParameters = emptyParams;
|
||||
} else {
|
||||
llvm::Optional<::concretelang::clientlib::ChunkInfo> chunkInfo =
|
||||
llvm::None;
|
||||
if (options.chunkIntegers) {
|
||||
chunkInfo = ::concretelang::clientlib::ChunkInfo{4, 2};
|
||||
}
|
||||
auto clientParametersOrErr =
|
||||
mlir::concretelang::createClientParametersForV0(
|
||||
*res.fheContext, funcName, module,
|
||||
options.optimizerConfig.security);
|
||||
options.optimizerConfig.security, chunkInfo);
|
||||
if (!clientParametersOrErr)
|
||||
return clientParametersOrErr.takeError();
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
using ::concretelang::clientlib::ChunkInfo;
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
using ::concretelang::clientlib::Encoding;
|
||||
@@ -33,10 +34,10 @@ using ::concretelang::clientlib::Variance;
|
||||
const auto keyFormat = concrete::BINARY;
|
||||
|
||||
/// For the v0 the secretKeyID and precision are the same for all gates.
|
||||
llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
LweSecretKeyID secretKeyID,
|
||||
Variance variance,
|
||||
mlir::Type type) {
|
||||
llvm::Expected<CircuitGate>
|
||||
gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
||||
Variance variance, llvm::Optional<ChunkInfo> chunkInfo,
|
||||
mlir::Type type) {
|
||||
if (type.isIntOrIndex()) {
|
||||
// TODO - The index type is dependant of the target architecture, so
|
||||
// actually we assume we target only 64 bits, we need to have some the size
|
||||
@@ -55,6 +56,7 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/* .sign */ sign},
|
||||
/*.chunkInfo = */ llvm::None,
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
@@ -64,22 +66,76 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
if (fheContext.parameter.largeInteger.has_value()) {
|
||||
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
|
||||
}
|
||||
size_t width;
|
||||
uint64_t size = 0;
|
||||
std::vector<int64_t> dims;
|
||||
if (chunkInfo.hasValue()) {
|
||||
width = chunkInfo->size;
|
||||
assert(lweTy.getWidth() % chunkInfo->width == 0);
|
||||
size = lweTy.getWidth() / chunkInfo->width;
|
||||
dims.push_back(size);
|
||||
} else {
|
||||
width = (size_t)lweTy.getWidth();
|
||||
}
|
||||
return CircuitGate{
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ lweTy.getWidth(),
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ crt,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{/*.width = */ (size_t)lweTy.getWidth(),
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ sign},
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ dims,
|
||||
/*.size = */ size,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
/*.chunkInfo = */ chunkInfo,
|
||||
};
|
||||
}
|
||||
// TODO: this a duplicate of the last if: should be removed when we remove
|
||||
// chinked eint
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::ChunkedEncryptedIntegerType>()) {
|
||||
bool sign = lweTy.isSignedInteger();
|
||||
std::vector<int64_t> crt;
|
||||
if (fheContext.parameter.largeInteger.has_value()) {
|
||||
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
|
||||
}
|
||||
size_t width;
|
||||
uint64_t size;
|
||||
std::vector<int64_t> dims;
|
||||
if (chunkInfo.hasValue()) {
|
||||
width = chunkInfo->size;
|
||||
assert(lweTy.getWidth() % chunkInfo->width == 0);
|
||||
size = lweTy.getWidth() / chunkInfo->width;
|
||||
dims.push_back(size);
|
||||
} else {
|
||||
width = (size_t)lweTy.getWidth();
|
||||
}
|
||||
return CircuitGate{
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ crt,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ dims,
|
||||
/*.size = */ size,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
/*.chunkInfo = */ chunkInfo,
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
@@ -103,11 +159,12 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ false,
|
||||
},
|
||||
/*.chunkInfo = */ llvm::None,
|
||||
};
|
||||
}
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensor != nullptr) {
|
||||
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance,
|
||||
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance, chunkInfo,
|
||||
tensor.getElementType());
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
@@ -126,7 +183,8 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext fheContext,
|
||||
llvm::StringRef functionName, mlir::ModuleOp module,
|
||||
int bitsOfSecurity) {
|
||||
int bitsOfSecurity,
|
||||
llvm::Optional<ChunkInfo> chunkInfo) {
|
||||
const auto v0Curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat);
|
||||
|
||||
if (v0Curve == nullptr) {
|
||||
@@ -216,7 +274,8 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
auto inputs = funcType.getInputs();
|
||||
|
||||
auto gateFromType = [&](mlir::Type ty) {
|
||||
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance, ty);
|
||||
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance,
|
||||
chunkInfo, ty);
|
||||
};
|
||||
for (auto inType : inputs) {
|
||||
auto gate = gateFromType(inType);
|
||||
|
||||
@@ -206,6 +206,12 @@ llvm::cl::list<uint64_t>
|
||||
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore,
|
||||
llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
chunkIntegers("chunk-integers",
|
||||
llvm::cl::desc("Whether to decompose integer into chunks or "
|
||||
"not, default is false (to not chunk)"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::opt<unsigned int> chunkSize(
|
||||
"chunk-size",
|
||||
llvm::cl::desc(
|
||||
@@ -350,6 +356,7 @@ cmdlineCompilationOptions() {
|
||||
cmdline::unrollLoopsWithSDFGConvertibleOps;
|
||||
options.optimizeConcrete = cmdline::optimizeConcrete;
|
||||
options.emitGPUOps = cmdline::emitGPUOps;
|
||||
options.chunkIntegers = cmdline::chunkIntegers;
|
||||
options.chunkSize = cmdline::chunkSize;
|
||||
options.chunkWidth = cmdline::chunkWidth;
|
||||
|
||||
|
||||
@@ -45,11 +45,13 @@ TEST(Support, client_parameters_json_serde) {
|
||||
/*.encryption = */ {
|
||||
{clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}, false}}},
|
||||
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4, false},
|
||||
/*.chunkInfo = */ llvm::None,
|
||||
},
|
||||
{
|
||||
/*.encryption = */ {
|
||||
{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false},
|
||||
/*.chunkInfo = */ llvm::None,
|
||||
},
|
||||
};
|
||||
params0.outputs = {
|
||||
@@ -57,6 +59,7 @@ TEST(Support, client_parameters_json_serde) {
|
||||
/*.encryption = */ {
|
||||
{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false},
|
||||
/*.chunkInfo = */ llvm::None,
|
||||
},
|
||||
};
|
||||
auto json = clientlib::toJSON(params0);
|
||||
|
||||
Reference in New Issue
Block a user