feat: support chunked integer during enc/dec/exec

This commit is contained in:
youben11
2023-02-03 15:30:27 +01:00
committed by Ayoub Benaissa
parent d41d14dbb8
commit bb87d29934
10 changed files with 169 additions and 18 deletions

View File

@@ -166,9 +166,21 @@ static inline bool operator==(const CircuitGateShape &lhs,
lhs.size == rhs.size;
}
struct ChunkInfo {
/// total number of bits used for the chunk including the carry.
/// size should be at least width + 1
unsigned int size;
/// number of bits used for the chunk excluding the carry
unsigned int width;
};
static inline bool operator==(const ChunkInfo &lhs, const ChunkInfo &rhs) {
return lhs.width == rhs.width && lhs.size == rhs.size;
}
struct CircuitGate {
llvm::Optional<EncryptionGate> encryption;
CircuitGateShape shape;
llvm::Optional<ChunkInfo> chunkInfo;
bool isEncrypted() { return encryption.hasValue(); }
@@ -186,7 +198,8 @@ struct CircuitGate {
}
};
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape;
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape &&
lhs.chunkInfo == rhs.chunkInfo;
}
struct ClientParameters {

View File

@@ -90,6 +90,18 @@ struct PublicResult {
/// Serialize into an output stream.
outcome::checked<void, StringError> serialize(std::ostream &ostream);
/// Get the original integer that was decomposed into chunks of `chunkWidth`
/// bits each
uint64_t fromChunks(std::vector<uint64_t> chunks, unsigned int chunkWidth) {
uint64_t value = 0;
uint64_t mask = (1 << chunkWidth) - 1;
for (size_t i = 0; i < chunks.size(); i++) {
auto chunk = chunks[i] & mask;
value += chunk << (chunkWidth * i);
}
return value;
}
/// Get the result at `pos` as a scalar. Decryption happens if the
/// result is encrypted.
template <typename T>
@@ -99,6 +111,16 @@ struct PublicResult {
if (!gate.isEncrypted())
return buffers[pos].getScalar().getValue<T>();
// Chunked integers are represented as tensors at a lower level, so we need
// to deal with them as tensors, then build the resulting scalar out of the
// tensor values
if (gate.chunkInfo.hasValue()) {
OUTCOME_TRY(std::vector<uint64_t> decryptedChunks,
this->asClearTextVector<uint64_t>(keySet, pos));
uint64_t decrypted = fromChunks(decryptedChunks, gate.chunkInfo->width);
return (T)decrypted;
}
auto &buffer = buffers[pos].getTensor();
auto ciphertext = buffer.getOpaqueElementPointer(0);

View File

@@ -74,6 +74,7 @@ struct CompilationOptions {
/// When decomposing big integers into chunks, chunkSize is the total number
/// of bits used for the message, including the carry, while chunkWidth is
/// only the number of bits used during encoding and decoding of a big integer
bool chunkIntegers;
unsigned int chunkSize;
unsigned int chunkWidth;
@@ -83,8 +84,8 @@ struct CompilationOptions {
emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false),
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
clientParametersFuncName(llvm::None),
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkSize(4),
chunkWidth(2){};
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
chunkSize(4), chunkWidth(2){};
CompilationOptions(std::string funcname) : CompilationOptions() {
clientParametersFuncName = funcname;

View File

@@ -206,6 +206,23 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
: buildScalarLambdaResult<uint8_t>(keySet, result);
}
} else if (gate.chunkInfo.hasValue()) {
// chunked scalar case
assert(gate.shape.dimensions.size() == 1);
width = gate.shape.size * gate.chunkInfo->width;
if (width > 32) {
return (sign) ? buildScalarLambdaResult<int64_t>(keySet, result)
: buildScalarLambdaResult<uint64_t>(keySet, result);
} else if (width > 16) {
return (sign) ? buildScalarLambdaResult<int32_t>(keySet, result)
: buildScalarLambdaResult<uint32_t>(keySet, result);
} else if (width > 8) {
return (sign) ? buildScalarLambdaResult<int16_t>(keySet, result)
: buildScalarLambdaResult<uint16_t>(keySet, result);
} else if (width <= 8) {
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
: buildScalarLambdaResult<uint8_t>(keySet, result);
}
} else {
// tensor case
if (width > 32) {

View File

@@ -15,11 +15,13 @@
namespace mlir {
namespace concretelang {
using ::concretelang::clientlib::ChunkInfo;
using ::concretelang::clientlib::ClientParameters;
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
mlir::ModuleOp module, int bitsOfSecurity);
mlir::ModuleOp module, int bitsOfSecurity,
llvm::Optional<ChunkInfo> chunkInfo = llvm::None);
} // namespace concretelang
} // namespace mlir

View File

@@ -18,11 +18,33 @@ EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers));
}
/// Split the input integer into `size` chunks of `chunkWidth` bits each
std::vector<uint64_t> chunkInput(uint64_t value, size_t size,
unsigned int chunkWidth) {
std::vector<uint64_t> chunks;
chunks.reserve(size);
uint64_t mask = (1 << chunkWidth) - 1;
for (size_t i = 0; i < size; i++) {
auto chunk = value & mask;
chunks.push_back((uint64_t)chunk);
value >>= chunkWidth;
}
return chunks;
}
outcome::checked<void, StringError>
EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(currentPos));
// a chunked input is represented as a tensor in lower levels, and need to to
// splitted into chunks and encrypted as such
if (input.chunkInfo.hasValue()) {
std::vector<uint64_t> chunks =
chunkInput(arg, input.shape.size, input.chunkInfo.getPointer()->width);
return this->pushArg(chunks.data(), input.shape.size, keySet);
}
// we only increment if we don't forward the call to another pushArg method
auto pos = currentPos++;
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(pos));
if (input.shape.size != 0) {
return StringError("argument #") << pos << " is not a scalar";
}

View File

@@ -363,10 +363,15 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
emptyParams.functionName = funcName;
res.clientParameters = emptyParams;
} else {
llvm::Optional<::concretelang::clientlib::ChunkInfo> chunkInfo =
llvm::None;
if (options.chunkIntegers) {
chunkInfo = ::concretelang::clientlib::ChunkInfo{4, 2};
}
auto clientParametersOrErr =
mlir::concretelang::createClientParametersForV0(
*res.fheContext, funcName, module,
options.optimizerConfig.security);
options.optimizerConfig.security, chunkInfo);
if (!clientParametersOrErr)
return clientParametersOrErr.takeError();

View File

@@ -22,6 +22,7 @@ namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
using ::concretelang::clientlib::ChunkInfo;
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::ClientParameters;
using ::concretelang::clientlib::Encoding;
@@ -33,10 +34,10 @@ using ::concretelang::clientlib::Variance;
const auto keyFormat = concrete::BINARY;
/// For the v0 the secretKeyID and precision are the same for all gates.
llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
LweSecretKeyID secretKeyID,
Variance variance,
mlir::Type type) {
llvm::Expected<CircuitGate>
gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
Variance variance, llvm::Optional<ChunkInfo> chunkInfo,
mlir::Type type) {
if (type.isIntOrIndex()) {
// TODO - The index type is dependant of the target architecture, so
// actually we assume we target only 64 bits, we need to have some the size
@@ -55,6 +56,7 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/* .sign */ sign},
/*.chunkInfo = */ llvm::None,
};
}
if (auto lweTy = type.dyn_cast_or_null<
@@ -64,22 +66,76 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
if (fheContext.parameter.largeInteger.has_value()) {
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
}
size_t width;
uint64_t size = 0;
std::vector<int64_t> dims;
if (chunkInfo.hasValue()) {
width = chunkInfo->size;
assert(lweTy.getWidth() % chunkInfo->width == 0);
size = lweTy.getWidth() / chunkInfo->width;
dims.push_back(size);
} else {
width = (size_t)lweTy.getWidth();
}
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ lweTy.getWidth(),
/* .precision = */ width,
/* .crt = */ crt,
/*.sign = */ sign,
},
}),
/*.shape = */
{/*.width = */ (size_t)lweTy.getWidth(),
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/*.sign = */ sign},
{
/*.width = */ width,
/*.dimensions = */ dims,
/*.size = */ size,
/*.sign = */ sign,
},
/*.chunkInfo = */ chunkInfo,
};
}
// TODO: this a duplicate of the last if: should be removed when we remove
// chinked eint
if (auto lweTy = type.dyn_cast_or_null<
mlir::concretelang::FHE::ChunkedEncryptedIntegerType>()) {
bool sign = lweTy.isSignedInteger();
std::vector<int64_t> crt;
if (fheContext.parameter.largeInteger.has_value()) {
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
}
size_t width;
uint64_t size;
std::vector<int64_t> dims;
if (chunkInfo.hasValue()) {
width = chunkInfo->size;
assert(lweTy.getWidth() % chunkInfo->width == 0);
size = lweTy.getWidth() / chunkInfo->width;
dims.push_back(size);
} else {
width = (size_t)lweTy.getWidth();
}
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ crt,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ dims,
/*.size = */ size,
/*.sign = */ sign,
},
/*.chunkInfo = */ chunkInfo,
};
}
if (auto lweTy = type.dyn_cast_or_null<
@@ -103,11 +159,12 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
/*.size = */ 0,
/*.sign = */ false,
},
/*.chunkInfo = */ llvm::None,
};
}
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensor != nullptr) {
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance,
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance, chunkInfo,
tensor.getElementType());
if (auto err = gate.takeError()) {
return std::move(err);
@@ -126,7 +183,8 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext fheContext,
llvm::StringRef functionName, mlir::ModuleOp module,
int bitsOfSecurity) {
int bitsOfSecurity,
llvm::Optional<ChunkInfo> chunkInfo) {
const auto v0Curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat);
if (v0Curve == nullptr) {
@@ -216,7 +274,8 @@ createClientParametersForV0(V0FHEContext fheContext,
auto inputs = funcType.getInputs();
auto gateFromType = [&](mlir::Type ty) {
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance, ty);
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance,
chunkInfo, ty);
};
for (auto inType : inputs) {
auto gate = gateFromType(inType);

View File

@@ -206,6 +206,12 @@ llvm::cl::list<uint64_t>
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore,
llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::opt<bool>
chunkIntegers("chunk-integers",
llvm::cl::desc("Whether to decompose integer into chunks or "
"not, default is false (to not chunk)"),
llvm::cl::init<bool>(false));
llvm::cl::opt<unsigned int> chunkSize(
"chunk-size",
llvm::cl::desc(
@@ -350,6 +356,7 @@ cmdlineCompilationOptions() {
cmdline::unrollLoopsWithSDFGConvertibleOps;
options.optimizeConcrete = cmdline::optimizeConcrete;
options.emitGPUOps = cmdline::emitGPUOps;
options.chunkIntegers = cmdline::chunkIntegers;
options.chunkSize = cmdline::chunkSize;
options.chunkWidth = cmdline::chunkWidth;

View File

@@ -45,11 +45,13 @@ TEST(Support, client_parameters_json_serde) {
/*.encryption = */ {
{clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}, false}}},
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4, false},
/*.chunkInfo = */ llvm::None,
},
{
/*.encryption = */ {
{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false},
/*.chunkInfo = */ llvm::None,
},
};
params0.outputs = {
@@ -57,6 +59,7 @@ TEST(Support, client_parameters_json_serde) {
/*.encryption = */ {
{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false},
/*.chunkInfo = */ llvm::None,
},
};
auto json = clientlib::toJSON(params0);