fix: rename LweSecretKeyParam.size dimension

This commit is contained in:
Mayeul@Zama
2022-02-24 17:22:35 +01:00
committed by mayeul-zama
parent 0d7c3570cb
commit cee07d2440
5 changed files with 21 additions and 21 deletions

View File

@@ -37,15 +37,15 @@ typedef uint64_t GlweDimension;
typedef std::string LweSecretKeyID;
struct LweSecretKeyParam {
LweDimension size;
LweDimension dimension;
void hash(size_t &seed);
inline uint64_t lweDimension() { return size; }
inline uint64_t lweSize() { return size + 1; }
inline uint64_t lweDimension() { return dimension; }
inline uint64_t lweSize() { return dimension + 1; }
};
static bool operator==(const LweSecretKeyParam &lhs,
const LweSecretKeyParam &rhs) {
return lhs.size == rhs.size;
return lhs.dimension == rhs.dimension;
}
typedef std::string BootstrapKeyID;

View File

@@ -25,7 +25,7 @@ static inline void hash_(std::size_t &seed, const T &v, Rest... rest) {
hash_(seed, rest...);
}
void LweSecretKeyParam::hash(size_t &seed) { hash_(seed, size); }
void LweSecretKeyParam::hash(size_t &seed) { hash_(seed, dimension); }
void BootstrapKeyParam::hash(size_t &seed) {
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
@@ -59,7 +59,7 @@ LweSecretKeyParam ClientParameters::lweSecretKeyParam(CircuitGate gate) {
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
llvm::json::Object object{
{"size", v.size},
{"dimension", v.dimension},
};
return object;
}
@@ -71,12 +71,12 @@ bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v,
p.report("should be an object");
return false;
}
auto size = obj->getInteger("size");
if (!size.hasValue()) {
auto dimension = obj->getInteger("dimension");
if (!dimension.hasValue()) {
p.report("missing size field");
return false;
}
v.size = *size;
v.dimension = *dimension;
return true;
}

View File

@@ -43,7 +43,7 @@ EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
}
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
encrypted_scalars_and_sizes_t &values_and_sizes = ciphertextBuffers.back();
auto lweSize = keySet->getInputLweSecretKeyParam(pos).size + 1;
auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize();
values_and_sizes.sizes.push_back(lweSize);
values_and_sizes.values.resize(lweSize);
@@ -106,7 +106,7 @@ EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape,
}
}
if (input.encryption.hasValue()) {
auto lweSize = keySet->getInputLweSecretKeyParam(pos).size + 1;
auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize();
values_and_sizes.sizes.push_back(lweSize);
// Encrypted tensor: for now we support only 8 bits for encrypted tensor

View File

@@ -139,7 +139,7 @@ outcome::checked<void, StringError>
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator) {
LweSecretKey_u64 *sk;
sk = allocate_lwe_secret_key_u64({param.size});
sk = allocate_lwe_secret_key_u64({param.dimension});
fill_lwe_secret_key_u64(sk, generator);
@@ -163,7 +163,7 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
// Allocate the bootstrap key
LweBootstrapKey_u64 *bsk;
uint64_t total_dimension = outputSk->second.first.size;
uint64_t total_dimension = outputSk->second.first.dimension;
assert(total_dimension % param.glweDimension == 0);
@@ -171,7 +171,7 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
bsk = allocate_lwe_bootstrap_key_u64(
{param.level}, {param.baseLog}, {param.glweDimension},
{inputSk->second.first.size}, {polynomialSize});
{inputSk->second.first.dimension}, {polynomialSize});
// Store the bootstrap key
bootstrapKeys[id] = {param, bsk};
@@ -208,8 +208,8 @@ KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
LweKeyswitchKey_u64 *ksk;
ksk = allocate_lwe_keyswitch_key_u64({param.level}, {param.baseLog},
{inputSk->second.first.size},
{outputSk->second.first.size});
{inputSk->second.first.dimension},
{outputSk->second.first.dimension});
// Store the keyswitch key
keyswitchKeys[id] = {param, ksk};
@@ -228,7 +228,7 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
}
auto inputSk = inputs[argPos];
size = std::get<1>(inputSk).size + 1;
size = std::get<1>(inputSk).lweSize();
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size);
return outcome::success();
}

View File

@@ -299,7 +299,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
// Allocate a buffer for ciphertexts, the size of the buffer is the number
// of elements of the tensor * the size of the lwe ciphertext
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
uint64_t *ctBuffer =
(uint64_t *)malloc(info.shape.size * lweSize * sizeof(uint64_t));
ciphertextBuffers.push_back(ctBuffer);
@@ -337,7 +337,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
}
// If encrypted +1 for the lwe size rank
if (keySet.isInputEncrypted(pos)) {
inputs[offset] = (void *)(keySet.getInputLweSecretKeyParam(pos).size + 1);
inputs[offset] = (void *)(keySet.getInputLweSecretKeyParam(pos).lweSize());
rawArg[offset] = &inputs[offset];
offset++;
}
@@ -349,7 +349,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
if (keySet.isInputEncrypted(pos)) {
inputs[offset + shape.size()] = (void *)stride;
rawArg[offset + shape.size()] = &inputs[offset];
stride *= keySet.getInputLweSecretKeyParam(pos).size + 1;
stride *= keySet.getInputLweSecretKeyParam(pos).lweSize();
}
for (ssize_t i = shape.size() - 1; i >= 0; i--) {
inputs[offset + i] = (void *)stride;
@@ -493,7 +493,7 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, void *res,
}
} else {
// decrypt and fill the result buffer
auto lweSize = keySet.getOutputLweSecretKeyParam(pos).size + 1;
auto lweSize = keySet.getOutputLweSecretKeyParam(pos).lweSize();
for (size_t i = 0, o = 0; i < numElements; i++, o += lweSize) {
uint64_t *ct = ((uint64_t *)alignedBytes) + o;