diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index b2ff3e6d6..e2f293794 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -37,15 +37,15 @@ typedef uint64_t GlweDimension; typedef std::string LweSecretKeyID; struct LweSecretKeyParam { - LweDimension size; + LweDimension dimension; void hash(size_t &seed); - inline uint64_t lweDimension() { return size; } - inline uint64_t lweSize() { return size + 1; } + inline uint64_t lweDimension() { return dimension; } + inline uint64_t lweSize() { return dimension + 1; } }; static bool operator==(const LweSecretKeyParam &lhs, const LweSecretKeyParam &rhs) { - return lhs.size == rhs.size; + return lhs.dimension == rhs.dimension; } typedef std::string BootstrapKeyID; diff --git a/compiler/lib/ClientLib/ClientParameters.cpp b/compiler/lib/ClientLib/ClientParameters.cpp index 8f675f94e..8caf9c4e6 100644 --- a/compiler/lib/ClientLib/ClientParameters.cpp +++ b/compiler/lib/ClientLib/ClientParameters.cpp @@ -25,7 +25,7 @@ static inline void hash_(std::size_t &seed, const T &v, Rest... rest) { hash_(seed, rest...); } -void LweSecretKeyParam::hash(size_t &seed) { hash_(seed, size); } +void LweSecretKeyParam::hash(size_t &seed) { hash_(seed, dimension); } void BootstrapKeyParam::hash(size_t &seed) { hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog, @@ -59,7 +59,7 @@ LweSecretKeyParam ClientParameters::lweSecretKeyParam(CircuitGate gate) { llvm::json::Value toJSON(const LweSecretKeyParam &v) { llvm::json::Object object{ - {"size", v.size}, + {"dimension", v.dimension}, }; return object; } @@ -71,12 +71,12 @@ bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v, p.report("should be an object"); return false; } - auto size = obj->getInteger("size"); - if (!size.hasValue()) { + auto dimension = obj->getInteger("dimension"); + if (!dimension.hasValue()) { p.report("missing size field"); return false; } - v.size = *size; + v.dimension = *dimension; return true; } diff --git a/compiler/lib/ClientLib/EncryptedArgs.cpp b/compiler/lib/ClientLib/EncryptedArgs.cpp index 1a21c9462..1c37ffdbc 100644 --- a/compiler/lib/ClientLib/EncryptedArgs.cpp +++ b/compiler/lib/ClientLib/EncryptedArgs.cpp @@ -43,7 +43,7 @@ EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr keySet) { } ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty encrypted_scalars_and_sizes_t &values_and_sizes = ciphertextBuffers.back(); - auto lweSize = keySet->getInputLweSecretKeyParam(pos).size + 1; + auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize(); values_and_sizes.sizes.push_back(lweSize); values_and_sizes.values.resize(lweSize); @@ -106,7 +106,7 @@ EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef shape, } } if (input.encryption.hasValue()) { - auto lweSize = keySet->getInputLweSecretKeyParam(pos).size + 1; + auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize(); values_and_sizes.sizes.push_back(lweSize); // Encrypted tensor: for now we support only 8 bits for encrypted tensor diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index c4b1aff0b..23d4c0674 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -139,7 +139,7 @@ outcome::checked KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, SecretRandomGenerator *generator) { LweSecretKey_u64 *sk; - sk = allocate_lwe_secret_key_u64({param.size}); + sk = allocate_lwe_secret_key_u64({param.dimension}); fill_lwe_secret_key_u64(sk, generator); @@ -163,7 +163,7 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param, // Allocate the bootstrap key LweBootstrapKey_u64 *bsk; - uint64_t total_dimension = outputSk->second.first.size; + uint64_t total_dimension = outputSk->second.first.dimension; assert(total_dimension % param.glweDimension == 0); @@ -171,7 +171,7 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param, bsk = allocate_lwe_bootstrap_key_u64( {param.level}, {param.baseLog}, {param.glweDimension}, - {inputSk->second.first.size}, {polynomialSize}); + {inputSk->second.first.dimension}, {polynomialSize}); // Store the bootstrap key bootstrapKeys[id] = {param, bsk}; @@ -208,8 +208,8 @@ KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param, LweKeyswitchKey_u64 *ksk; ksk = allocate_lwe_keyswitch_key_u64({param.level}, {param.baseLog}, - {inputSk->second.first.size}, - {outputSk->second.first.size}); + {inputSk->second.first.dimension}, + {outputSk->second.first.dimension}); // Store the keyswitch key keyswitchKeys[id] = {param, ksk}; @@ -228,7 +228,7 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) { } auto inputSk = inputs[argPos]; - size = std::get<1>(inputSk).size + 1; + size = std::get<1>(inputSk).lweSize(); *ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size); return outcome::success(); } diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 2d687e7b9..2293068a6 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -299,7 +299,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, // Allocate a buffer for ciphertexts, the size of the buffer is the number // of elements of the tensor * the size of the lwe ciphertext - auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1; + auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize(); uint64_t *ctBuffer = (uint64_t *)malloc(info.shape.size * lweSize * sizeof(uint64_t)); ciphertextBuffers.push_back(ctBuffer); @@ -337,7 +337,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, } // If encrypted +1 for the lwe size rank if (keySet.isInputEncrypted(pos)) { - inputs[offset] = (void *)(keySet.getInputLweSecretKeyParam(pos).size + 1); + inputs[offset] = (void *)(keySet.getInputLweSecretKeyParam(pos).lweSize()); rawArg[offset] = &inputs[offset]; offset++; } @@ -349,7 +349,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, if (keySet.isInputEncrypted(pos)) { inputs[offset + shape.size()] = (void *)stride; rawArg[offset + shape.size()] = &inputs[offset]; - stride *= keySet.getInputLweSecretKeyParam(pos).size + 1; + stride *= keySet.getInputLweSecretKeyParam(pos).lweSize(); } for (ssize_t i = shape.size() - 1; i >= 0; i--) { inputs[offset + i] = (void *)stride; @@ -493,7 +493,7 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, void *res, } } else { // decrypt and fill the result buffer - auto lweSize = keySet.getOutputLweSecretKeyParam(pos).size + 1; + auto lweSize = keySet.getOutputLweSecretKeyParam(pos).lweSize(); for (size_t i = 0, o = 0; i < numElements; i++, o += lweSize) { uint64_t *ct = ((uint64_t *)alignedBytes) + o;