mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-18 08:31:31 -05:00
- added --compress-input compiler option which forces the use of seeded bootstrap keys and keyswitch keys - replaced the concrete-cpu FHE implementation with tfhe-rs Co-authored-by: Nikita Frolov <nf@mkmks.org>
956 lines
35 KiB
C++
956 lines
35 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
|
// for license information.
|
|
|
|
#include "concretelang/Common/Transformers.h"
|
|
#include "capnp/any.h"
|
|
#include "concrete-cpu.h"
|
|
#include "concrete-protocol.capnp.h"
|
|
#include "concretelang/Common/CRT.h"
|
|
#include "concretelang/Common/Error.h"
|
|
#include "concretelang/Common/Keysets.h"
|
|
#include "concretelang/Common/Values.h"
|
|
#include "concretelang/Runtime/simulation.h"
|
|
#include <memory>
|
|
#include <stdlib.h>
|
|
#include <string>
|
|
|
|
using concretelang::error::Result;
|
|
using concretelang::keysets::ClientKeyset;
|
|
using concretelang::values::getCorrespondingPrecision;
|
|
using concretelang::values::Tensor;
|
|
using concretelang::values::TransportValue;
|
|
using concretelang::values::Value;
|
|
|
|
namespace concretelang {
|
|
namespace transformers {
|
|
|
|
/// A private type for value verifiers.
|
|
typedef std::function<Result<void>(const Value &)> ValueVerifier;
|
|
|
|
/// A private type for transport value verifiers.
|
|
typedef std::function<Result<void>(const TransportValue &)>
|
|
TransportValueVerifier;
|
|
|
|
/// A private type for transformers working purely on values.
|
|
typedef std::function<Value(Value)> Transformer;
|
|
|
|
/// Builds a verifier for values fed to an index-typed gate.
///
/// Returns an error immediately when `gateInfo` carries no index type info.
/// The produced verifier rejects a value when its shape, or its integer
/// precision, does not agree with the index type recorded in `gateInfo`.
Result<ValueVerifier> getIndexInputValueVerifier(
    const Message<concreteprotocol::GateInfo> &gateInfo) {
  const bool describesIndex = gateInfo.asReader().getTypeInfo().hasIndex();
  if (!describesIndex) {
    return StringError("Tried to get index input value verifier for gate info "
                       "without proper type info.");
  }

  // The gate info is copied into the closure so the verifier can be used
  // independently of the caller's message object.
  return [gateInfo](const Value &candidate) -> Result<void> {
    auto indexType = gateInfo.asReader().getTypeInfo().getIndex();

    const bool shapeMatches =
        candidate.isCompatibleWithShape(indexType.getShape());
    if (!shapeMatches) {
      return StringError(
          "Tried to transform index value with incompatible shape.");
    }

    const bool precisionMatches =
        candidate.getIntegerPrecision() == indexType.getIntegerPrecision();
    if (!precisionMatches) {
      return StringError(
          "Tried to transform index value with incompatible integer "
          "precision.");
    }

    return outcome::success();
  };
}
|
|
|
|
/// Builds a verifier that accepts every value: the produced callable ignores
/// its argument and always reports success.
Result<ValueVerifier> getObliviousValueVerifier() {
  return [](const Value &) -> Result<void> { return outcome::success(); };
}
|
|
|
|
/// Builds a verifier for values fed to a plaintext-typed gate.
///
/// Returns an error immediately when `gateInfo` carries no plaintext type
/// info. The produced verifier rejects a value when its shape or its integer
/// precision differs from the plaintext type recorded in `gateInfo`; the
/// precision error reports both the received and the expected precision.
Result<ValueVerifier> getPlaintextInputValueVerifier(
    const Message<concreteprotocol::GateInfo> &gateInfo) {
  const bool describesPlaintext =
      gateInfo.asReader().getTypeInfo().hasPlaintext();
  if (!describesPlaintext) {
    return StringError("Tried to get plaintext input value verifier for gate "
                       "info without proper type info.");
  }

  return [gateInfo](const Value &candidate) -> Result<void> {
    auto plaintextType = gateInfo.asReader().getTypeInfo().getPlaintext();

    if (!candidate.isCompatibleWithShape(plaintextType.getShape())) {
      return StringError(
          "Tried to transform plaintext value with incompatible shape.");
    }

    const auto received = candidate.getIntegerPrecision();
    const auto expected = plaintextType.getIntegerPrecision();
    if (received != expected) {
      return StringError(
          "Tried to transform plaintext value with incompatible integer "
          "precision. Got " +
          std::to_string(received) + " expected " + std::to_string(expected));
    }

    return outcome::success();
  };
}
|
|
|
|
/// Builds a verifier for clear values about to be turned into LWE ciphertexts.
///
/// Returns an error when `gateInfo` carries no lwe-ciphertext type info, or
/// when its encoding is neither boolean nor integer. Otherwise the produced
/// verifier checks, in order: compatibility with the abstract shape, a 64-bit
/// integer precision, and signedness (boolean encodings require unsigned
/// values; integer encodings require the signedness stored in the encoding).
Result<ValueVerifier> getLweCiphertextInputValueVerifier(
    const Message<concreteprotocol::GateInfo> &gateInfo) {
  if (!gateInfo.asReader().getTypeInfo().hasLweCiphertext()) {
    return StringError("Tried to get ciphertext input value verifier for gate "
                       "info without proper type info.");
  }

  // Local view used only to choose the branch; the closures below capture
  // their own copy of `gateInfo` and re-read what they need from it.
  auto encoding =
      gateInfo.asReader().getTypeInfo().getLweCiphertext().getEncoding();

  if (encoding.hasBoolean()) {
    return [gateInfo](const Value &candidate) -> Result<void> {
      auto cipherType = gateInfo.asReader().getTypeInfo().getLweCiphertext();

      if (!candidate.isCompatibleWithShape(cipherType.getAbstractShape())) {
        return StringError("Tried to transform ciphertext input value with "
                           "incompatible shape.");
      }

      const auto received = candidate.getIntegerPrecision();
      if (received != 64) {
        return StringError("Tried to transform ciphertext input value "
                           "(boolean) with incompatible integer "
                           "precision. Got " +
                           std::to_string(received) + " expected 64");
      }

      if (candidate.isSigned()) {
        return StringError("Tried to transform ciphertext input value "
                           "(boolean) with incompatible signedness.");
      }

      return outcome::success();
    };
  }

  if (encoding.hasInteger()) {
    return [gateInfo](const Value &candidate) -> Result<void> {
      auto cipherType = gateInfo.asReader().getTypeInfo().getLweCiphertext();

      if (!candidate.isCompatibleWithShape(cipherType.getAbstractShape())) {
        return StringError("Tried to transform ciphertext input value with "
                           "incompatible shape.");
      }

      const auto received = candidate.getIntegerPrecision();
      if (received != 64) {
        return StringError("Tried to transform ciphertext input value with "
                           "incompatible integer "
                           "precision. Got " +
                           std::to_string(received) + " expected 64.");
      }

      const bool expectSigned =
          cipherType.getEncoding().getInteger().getIsSigned();
      if (candidate.isSigned() != expectSigned) {
        return StringError("Tried to transform ciphertext input value with "
                           "incompatible signedness.");
      }

      return outcome::success();
    };
  }

  return StringError(
      "Tried to get lwe ciphertext input verifier for wrongly defined gate.");
}
|
|
|
|
/// Builds a verifier for already-encrypted values leaving an LWE ciphertext
/// gate.
///
/// Returns an error when `gateInfo` carries no lwe-ciphertext type info. The
/// produced verifier checks, in order: compatibility with the concrete shape,
/// a 64-bit integer precision, and that the value is unsigned.
Result<ValueVerifier> getLweCiphertextOutputValueVerifier(
    const Message<concreteprotocol::GateInfo> &gateInfo) {
  const bool describesCiphertext =
      gateInfo.asReader().getTypeInfo().hasLweCiphertext();
  if (!describesCiphertext) {
    return StringError("Tried to get ciphertext output value verifier for gate "
                       "info without proper type info.");
  }

  return [gateInfo](const Value &candidate) -> Result<void> {
    auto cipherType = gateInfo.asReader().getTypeInfo().getLweCiphertext();

    if (!candidate.isCompatibleWithShape(cipherType.getConcreteShape())) {
      return StringError("Tried to transform ciphertext output value with "
                         "incompatible shape.");
    }

    const auto received = candidate.getIntegerPrecision();
    if (received != 64) {
      return StringError("Tried to transform ciphertext output value with "
                         "incompatible integer "
                         "precision. Got " +
                         std::to_string(received) + " expected 64");
    }

    if (candidate.isSigned()) {
      return StringError("Tried to transform ciphertext output value with "
                         "incompatible signedness (signed).");
    }

    return outcome::success();
  };
}
|
|
|
|
/// Builds a transport-value verifier that accepts everything: the produced
/// callable ignores its argument and always reports success.
Result<TransportValueVerifier> getObliviousTransportValueVerifier() {
  return [](const TransportValue &) -> Result<void> {
    return outcome::success();
  };
}
|
|
|
|
/// Builds a verifier for transport values associated with the gate described
/// by `gateInfo`.
///
/// The produced callable checks, in this order:
///  1. a payload is present;
///  2. raw info is present;
///  3. the raw info equals the gate's raw info;
///  4. the summed size of the payload blobs equals
///     `integerPrecision / 8` multiplied by every dimension of the shape;
///  5. the type info is an index, a plaintext or an lwe ciphertext;
///  6. the type info equals the gate's type info.
/// The first failing check determines the returned error.
Result<TransportValueVerifier>
getTransportValueVerifier(const Message<concreteprotocol::GateInfo> &gateInfo) {
  return [gateInfo](const TransportValue &transportVal) -> Result<void> {
    auto received = transportVal.asReader();
    auto gate = gateInfo.asReader();

    if (!received.hasPayload()) {
      return StringError(
          "Tried to transform a transport value without payload.");
    }
    if (!received.hasRawInfo()) {
      return StringError(
          "Tried to transform a transport value without raw infos.");
    }

    // Structural comparison of the two raw-info structs.
    const bool rawInfoDiffers =
        !(static_cast<capnp::AnyStruct::Reader>(gate.getRawInfo()) ==
          static_cast<capnp::AnyStruct::Reader>(received.getRawInfo()));
    if (rawInfoDiffers) {
      std::string expected = gate.getRawInfo().toString().flatten().cStr();
      std::string actual = received.getRawInfo().toString().flatten().cStr();
      return StringError("Tried to transform transport value with incompatible "
                         "raw info.\nExpected: " +
                         expected + "\nActual: " + actual);
    }

    // Bytes per element times the number of elements.
    size_t expectedPayloadSize = received.getRawInfo().getIntegerPrecision() / 8;
    for (auto dim : received.getRawInfo().getShape().getDimensions()) {
      expectedPayloadSize *= dim;
    }

    // The payload may be split over several blobs; only the total matters.
    size_t actualPayloadSize = 0;
    for (auto blob : received.getPayload().getData()) {
      actualPayloadSize += blob.size();
    }

    if (actualPayloadSize != expectedPayloadSize) {
      return StringError("Tried to transform a transport value with "
                         "incompatible payload size.");
    }

    const bool hasKnownType = received.getTypeInfo().hasIndex() ||
                              received.getTypeInfo().hasPlaintext() ||
                              received.getTypeInfo().hasLweCiphertext();
    if (!hasKnownType) {
      return StringError(
          "Tried to transform a transport value without type infos.");
    }

    const bool typeInfoDiffers =
        static_cast<capnp::AnyStruct::Reader>(gate.getTypeInfo()) !=
        static_cast<capnp::AnyStruct::Reader>(received.getTypeInfo());
    if (typeInfoDiffers) {
      std::string expected = gate.getTypeInfo().toString().flatten().cStr();
      std::string actual = received.getTypeInfo().toString().flatten().cStr();
      return StringError("Tried to transform transport value with incompatible "
                         "type info.\nExpected: " +
                         expected + "\nActual: " + actual);
    }

    return outcome::success();
  };
}
|
|
|
|
/// Builds the encoder for boolean values: every element of the `uint64_t`
/// input tensor is shifted left by 61 bits. Dimensions are left untouched.
Result<Transformer> getBooleanEncodingTransformer() {
  return [](Value input) {
    // Work on a private copy of the tensor and rewrite it element by element.
    auto encoded = input.getTensor<uint64_t>().value();
    for (auto &element : encoded.values) {
      element = element << 61;
    }
    return Value{encoded};
  };
}
|
|
|
|
/// Builds the encoder for integers stored in native mode.
///
/// Every element is shifted left by `64 - (width + 1)` bits, `width` being
/// read from `info`. When the encoding is signed the input is read as an
/// `int64_t` tensor and converted to `uint64_t` before shifting; otherwise it
/// is read directly as `uint64_t`. Dimensions are left untouched.
Result<Transformer> getNativeModeIntegerEncodingTransformer(
    const Message<concreteprotocol::IntegerCiphertextEncodingInfo> &info) {
  const auto width = info.asReader().getWidth();
  const auto isSigned = info.asReader().getIsSigned();
  const auto shift = 64 - (width + 1);

  return [shift, isSigned](Value input) {
    Tensor<uint64_t> encoded;
    if (isSigned) {
      encoded =
          static_cast<Tensor<uint64_t>>(input.getTensor<int64_t>().value());
    } else {
      encoded = input.getTensor<uint64_t>().value();
    }

    for (auto &element : encoded.values) {
      element = element << shift;
    }
    return Value{encoded};
  };
}
|
|
|
|
/// Builds the decoder for integers stored in native mode.
///
/// For every element: keep the top `precision + 2` bits, round to nearest by
/// adding the lowest kept bit after halving, then reduce modulo
/// `2^(precision + 1)`. For signed encodings, results at or above
/// `2^(precision - 1)` get all bits from position `precision` upwards set to
/// one, and the whole tensor is finally converted to `int64_t`.
Result<Transformer> getNativeModeIntegerDecodingTransformer(
    const Message<concreteprotocol::IntegerCiphertextEncodingInfo> &info) {
  const auto precision = info.asReader().getWidth();
  const auto isSigned = info.asReader().getIsSigned();

  return [precision, isSigned](Value input) {
    auto decoded = input.getTensor<uint64_t>().value();

    for (auto &element : decoded.values) {
      // Unsigned decoding: drop the low bits, keeping one extra bit that
      // serves as the rounding bit.
      uint64_t rounded = element >> (64 - precision - 2);
      const auto roundingBit = rounded % 2;
      const uint64_t modulus = static_cast<uint64_t>(1) << (precision + 1);
      rounded = ((rounded >> 1) + roundingBit) % modulus;

      // Signed decoding: values in the upper half are negative.
      if (isSigned) {
        const uint64_t firstNegative = static_cast<uint64_t>(1)
                                       << (precision - 1);
        if (rounded >= firstNegative) {
          // Fill the bits above the payload with ones so that the later
          // conversion to int64 yields the intended negative number.
          rounded |= UINT64_MAX << precision;
        }
      }

      element = rounded;
    }

    if (isSigned) {
      return Value{static_cast<Tensor<int64_t>>(decoded)};
    }
    return Value{decoded};
  };
}
|
|
|
|
/// Builds the encoder for integers stored in chunked mode.
///
/// Every input element is cut, starting from its least significant bits, into
/// `size` chunks of `chunkWidth` bits (both read from `info`). Each chunk is
/// shifted left by `64 - (chunkWidth + 1)` bits. The output tensor gains a
/// trailing dimension of extent `size`; the chunks of element `i` are stored
/// at indices `i * size .. i * size + size - 1`, least significant first.
/// Signed inputs are read as `int64_t` tensors and converted to `uint64_t`
/// before chunking.
Result<Transformer> getChunkedModeIntegerEncodingTransformer(
    const Message<concreteprotocol::IntegerCiphertextEncodingInfo> &info) {
  auto size = info.asReader().getMode().getChunked().getSize();
  auto chunkWidth = info.asReader().getMode().getChunked().getWidth();
  auto isSigned = info.asReader().getIsSigned();
  // The mask must be built with a 64-bit one: shifting the `int` literal `1`
  // overflows as soon as chunkWidth reaches 31.
  const uint64_t mask = (static_cast<uint64_t>(1) << chunkWidth) - 1;

  return [=](Value input) {
    Tensor<uint64_t> inputTensor;
    if (isSigned) {
      inputTensor =
          static_cast<Tensor<uint64_t>>(input.getTensor<int64_t>().value());
    } else {
      inputTensor = input.getTensor<uint64_t>().value();
    }

    // Start from a copy so the leading dimensions are inherited, then make
    // room for `size` chunks per input element.
    auto outputTensor = Tensor<uint64_t>(inputTensor);
    outputTensor.dimensions.push_back(size);
    outputTensor.values.resize(outputTensor.values.size() * size);

    for (size_t i = 0; i < inputTensor.values.size(); i++) {
      auto remaining = inputTensor.values[i];
      for (size_t j = 0; j < size; j++) {
        const uint64_t chunk = remaining & mask;
        outputTensor.values[i * size + j] = chunk << (64 - (chunkWidth + 1));
        remaining >>= chunkWidth;
      }
    }

    return Value{outputTensor};
  };
}
|
|
|
|
/// Builds the decoder for integers stored in chunked mode.
///
/// The input tensor's trailing dimension (of extent `chunkSize`) is removed.
/// Every group of `chunkSize` consecutive elements is decoded chunk by chunk
/// (keep the top `chunkWidth + 2` bits, round to nearest, reduce modulo
/// `2^(chunkWidth + 1)`, keep the low `chunkWidth` bits) and the chunks are
/// recombined, least significant first, into a single integer. For signed
/// encodings the resulting tensor is converted to `int64_t`.
Result<Transformer> getChunkedModeIntegerDecodingTransformer(
    const Message<concreteprotocol::IntegerCiphertextEncodingInfo> &info) {
  auto chunkSize = info.asReader().getMode().getChunked().getSize();
  auto chunkWidth = info.asReader().getMode().getChunked().getWidth();
  auto isSigned = info.asReader().getIsSigned();
  // The mask must be built with a 64-bit one: shifting the `int` literal `1`
  // overflows as soon as chunkWidth reaches 31.
  const uint64_t mask = (static_cast<uint64_t>(1) << chunkWidth) - 1;

  return [=](Value input) {
    auto inputTensor = input.getTensor<uint64_t>().value();
    auto outputTensor = Tensor<uint64_t>(inputTensor);
    outputTensor.dimensions.pop_back();
    outputTensor.values.resize(outputTensor.values.size() / chunkSize);

    for (size_t i = 0; i < outputTensor.values.size(); i++) {
      uint64_t recombined = 0;
      for (size_t j = 0; j < chunkSize; j++) {
        const auto encodedChunk = inputTensor.values[i * chunkSize + j];

        // Drop the low bits, keeping one extra bit used for rounding.
        uint64_t chunkOutput = encodedChunk >> (64 - chunkWidth - 2);
        const auto carry = chunkOutput % 2;
        const uint64_t mod = static_cast<uint64_t>(1) << (chunkWidth + 1);
        chunkOutput = ((chunkOutput >> 1) + carry) % mod;

        // A per-chunk sign-extension step used to sit here. It only ever set
        // bits at positions >= chunkWidth, which the mask below clears again,
        // so it never influenced the result and has been removed.
        // NOTE(review): as a consequence a signed result narrower than 64 bits
        // is not sign-extended by this decoder - confirm this is intended.
        chunkOutput &= mask;
        recombined += chunkOutput << (chunkWidth * j);
      }
      outputTensor.values[i] = recombined;
    }

    if (isSigned) {
      return Value{static_cast<Tensor<int64_t>>(outputTensor)};
    }
    return Value{outputTensor};
  };
}
|
|
|
|
/// Builds the encoder for integers stored in CRT mode.
///
/// The moduli are read from `info`. Every input element is encoded once per
/// modulus with `concretelang::crt::encode`, in the order the moduli are
/// listed. The output tensor gains a trailing dimension whose extent is the
/// number of moduli. Signed inputs are read as `int64_t` tensors and converted
/// to `uint64_t` first.
Result<Transformer> getCrtModeIntegerEncodingTransformer(
    const Message<concreteprotocol::IntegerCiphertextEncodingInfo> &info) {
  auto crtInfo = info.asReader().getMode().getCrt();

  std::vector<int64_t> moduli;
  for (auto modulus : crtInfo.getModuli()) {
    moduli.push_back(modulus);
  }
  auto size = crtInfo.getModuli().size();
  auto productOfModuli = concretelang::crt::productOfModuli(moduli);
  auto isSigned = info.asReader().getIsSigned();

  return [moduli, size, productOfModuli, isSigned](Value input) {
    Tensor<uint64_t> plainTensor;
    if (isSigned) {
      plainTensor =
          static_cast<Tensor<uint64_t>>(input.getTensor<int64_t>().value());
    } else {
      plainTensor = input.getTensor<uint64_t>().value();
    }

    // Start from a copy to inherit the leading dimensions, then grow the
    // storage so each plain element owns one slot per modulus.
    auto encodedTensor = Tensor<uint64_t>(plainTensor);
    encodedTensor.dimensions.push_back(size);
    encodedTensor.values.resize(encodedTensor.values.size() * size);

    size_t cursor = 0;
    for (auto plain : plainTensor.values) {
      for (auto modulus : moduli) {
        encodedTensor.values[cursor] =
            concretelang::crt::encode(plain, modulus, productOfModuli);
        ++cursor;
      }
    }

    return Value{encodedTensor};
  };
}
|
|
|
|
/// Builds the decoder for integers stored in CRT mode.
///
/// The input tensor's trailing dimension (one entry per modulus) is removed.
/// For every group, each entry is decoded against its modulus with
/// `crt::decode`, and the remainders are recombined with `crt::iCrt`. For
/// signed encodings, results at or above half the product of the moduli are
/// shifted down by twice that half, and the tensor is converted to `int64_t`.
Result<Transformer> getCrtModeIntegerDecodingTransformer(
    const Message<concreteprotocol::IntegerCiphertextEncodingInfo> info) {
  std::vector<int64_t> moduli;
  for (auto modulus : info.asReader().getMode().getCrt().getModuli()) {
    moduli.push_back(modulus);
  }
  auto size = info.asReader().getMode().getCrt().getModuli().size();
  auto isSigned = info.asReader().getIsSigned();

  // Threshold separating non-negative from negative values in signed mode.
  // It depends only on the moduli, so it is computed once here rather than
  // once per decoded element.
  uint64_t maxPos = 1;
  for (auto prime : moduli) {
    maxPos *= prime;
  }
  maxPos /= 2;

  return [=](Value input) {
    auto inputTensor = input.getTensor<uint64_t>().value();
    auto outputTensor = Tensor<uint64_t>(inputTensor);
    outputTensor.dimensions.pop_back();
    outputTensor.values.resize(outputTensor.values.size() / size);

    // Scratch buffer, fully overwritten for every decoded element. Keeping it
    // local means the closure holds no mutable state between calls.
    std::vector<int64_t> remainders(moduli.size());

    for (size_t i = 0; i < outputTensor.values.size(); i++) {
      for (size_t j = 0; j < (size_t)size; j++) {
        remainders[j] =
            crt::decode(inputTensor.values[i * size + j], moduli[j]);
      }

      // Compute the inverse crt
      uint64_t decoded = crt::iCrt(moduli, remainders);

      // Further decode signed integers: the unsigned subtraction wraps, which
      // produces the two's complement pattern of the negative value.
      if (isSigned && decoded >= maxPos) {
        decoded -= maxPos * 2;
      }
      outputTensor.values[i] = decoded;
    }

    if (isSigned) {
      return Value{static_cast<Tensor<int64_t>>(outputTensor)};
    }
    return Value{outputTensor};
  };
}
|
|
|
|
/// Builds the transformer that encrypts every element of a `uint64_t` tensor
/// into an LWE ciphertext.
///
/// The secret key is selected from `keyset` by the key id stored in `info`;
/// an error is returned when that id is not available in the keyset. The LWE
/// dimension and the variance are read from `info` as well. The produced
/// tensor gains a trailing dimension of extent `lweDimension + 1`, and each
/// input element is encrypted with `concrete_cpu_encrypt_lwe_ciphertext_u64`
/// into its own slice, using `csprng`.
Result<Transformer> getEncryptionTransformer(
    ClientKeyset keyset,
    const Message<concreteprotocol::LweCiphertextEncryptionInfo> &info,
    std::shared_ptr<csprng::EncryptionCSPRNG> csprng) {

  // Guard the lookup: indexing past the end of the key list is undefined
  // behaviour, and this function can report the problem through its result.
  const auto keyId = info.asReader().getKeyId();
  if (keyId >= keyset.lweSecretKeys.size()) {
    return StringError("Tried to generate encryption transformer with "
                       "key id unavailable");
  }

  auto key = keyset.lweSecretKeys[keyId];
  auto lweDimension = info.asReader().getLweDimension();
  auto lweSize = lweDimension + 1;
  auto variance = info.asReader().getVariance();

  return [=](Value input) {
    auto inputTensor = input.getTensor<uint64_t>().value();
    auto outputTensor = Tensor<uint64_t>(inputTensor);
    outputTensor.dimensions.push_back(lweSize);
    outputTensor.values.resize(outputTensor.values.size() * lweSize);

    for (size_t i = 0; i < inputTensor.values.size(); i++) {
      concrete_cpu_encrypt_lwe_ciphertext_u64(
          key.getRawPtr(), &outputTensor.values[i * lweSize],
          inputTensor.values[i], lweDimension, variance, csprng->ptr);
    }

    return Value{outputTensor};
  };
}
|
|
|
|
/// Builds the simulation counterpart of the encryption transformer.
///
/// Every element of the `uint64_t` input tensor is replaced, in index order,
/// by the result of `sim_encrypt_lwe_u64` called with the LWE dimension read
/// from `info` and the pointer held by `csprng`. Dimensions are left
/// untouched.
Result<Transformer> getEncryptionSimulationTransformer(
    const Message<concreteprotocol::LweCiphertextEncryptionInfo> &info,
    std::shared_ptr<csprng::EncryptionCSPRNG> csprng) {

  const auto lweDimension = info.asReader().getLweDimension();

  return [lweDimension, csprng](Value input) {
    auto simulated = input.getTensor<uint64_t>().value();

    for (auto &element : simulated.values) {
      element =
          sim_encrypt_lwe_u64(element, lweDimension, (void *)csprng->ptr);
    }

    return Value{simulated};
  };
}
|
|
|
|
/// Builds the transformer that decrypts a tensor of LWE ciphertexts.
///
/// The secret key is selected from `keyset` by the key id stored in `info`;
/// an error is returned when that id is not available in the keyset. The
/// input tensor's trailing dimension (of extent `lweDimension + 1`) is
/// removed, and each ciphertext slice is decrypted with
/// `concrete_cpu_decrypt_lwe_ciphertext_u64` into one output element.
Result<Transformer> getDecryptionTransformer(
    ClientKeyset keyset,
    const Message<concreteprotocol::LweCiphertextEncryptionInfo> &info) {

  // Guard the lookup: indexing past the end of the key list is undefined
  // behaviour, and this function can report the problem through its result.
  const auto keyId = info.asReader().getKeyId();
  if (keyId >= keyset.lweSecretKeys.size()) {
    return StringError("Tried to generate decryption transformer with "
                       "key id unavailable");
  }

  auto key = keyset.lweSecretKeys[keyId];
  auto lweDimension = info.asReader().getLweDimension();
  auto lweSize = lweDimension + 1;

  return [=](Value input) {
    auto inputTensor = input.getTensor<uint64_t>().value();
    auto outputTensor = Tensor<uint64_t>(inputTensor);
    outputTensor.dimensions.pop_back();
    outputTensor.values.resize(outputTensor.values.size() / lweSize);

    for (size_t i = 0; i < outputTensor.values.size(); i++) {
      concrete_cpu_decrypt_lwe_ciphertext_u64(
          key.getRawPtr(), &inputTensor.values[i * lweSize], lweDimension,
          &outputTensor.values[i]);
    }

    return Value{outputTensor};
  };
}
|
|
|
|
/// Builds the simulation counterpart of the decryption transformer: the
/// produced callable hands its argument back unchanged.
Result<Transformer> getDecryptionSimulationTransformer() {
  return [](Value input) { return input; };
}
|
|
|
|
/// Builds the transformer used when no compression is applied: the produced
/// callable hands its argument back unchanged.
Result<Transformer> getNoneCompressionTransformer() {
  return [](Value input) { return input; };
}
|
|
|
|
/// Builds the transformer used when no decompression is needed: the produced
/// callable hands its argument back unchanged.
Result<Transformer> getNoneDecompressionTransformer() {
  return [](Value input) { return input; };
}
|
|
|
|
/// Builds the decoder for boolean values.
///
/// For every element of the `uint64_t` input tensor: keep the top four bits
/// (shift right by 60), halve while rounding to nearest using the lowest kept
/// bit, then reduce modulo 8. Dimensions are left untouched.
Result<Transformer> getBooleanDecodingTransformer() {
  return [](Value input) {
    auto decoded = input.getTensor<uint64_t>().value();

    for (auto &element : decoded.values) {
      uint64_t rounded = element >> 60;
      const uint64_t roundingBit = rounded % 2;
      const uint64_t modulus = 1 << 3;
      rounded = ((rounded >> 1) + roundingBit) % modulus;
      element = rounded;
    }

    return Value{decoded};
  };
}
|
|
|
|
/// Selects the integer encoder matching the mode stored in `info`: native,
/// chunked or CRT (tested in that order). Returns an error when none of the
/// three modes is set.
Result<Transformer> getIntegerEncodingTransformer(
    const Message<concreteprotocol::IntegerCiphertextEncodingInfo> &info) {
  auto mode = info.asReader().getMode();

  if (mode.hasNative()) {
    return getNativeModeIntegerEncodingTransformer(info);
  }
  if (mode.hasChunked()) {
    return getChunkedModeIntegerEncodingTransformer(info);
  }
  if (mode.hasCrt()) {
    return getCrtModeIntegerEncodingTransformer(info);
  }
  return StringError(
      "Tried to construct integer encoding transformer without mode.");
}
|
|
|
|
/// Selects the integer decoder matching the mode stored in `info`: native,
/// chunked or CRT (tested in that order). Returns an error when none of the
/// three modes is set.
Result<Transformer> getIntegerDecodingTransformer(
    const Message<concreteprotocol::IntegerCiphertextEncodingInfo> &info) {
  auto mode = info.asReader().getMode();

  if (mode.hasNative()) {
    return getNativeModeIntegerDecodingTransformer(info);
  }
  if (mode.hasChunked()) {
    return getChunkedModeIntegerDecodingTransformer(info);
  }
  if (mode.hasCrt()) {
    return getCrtModeIntegerDecodingTransformer(info);
  }
  return StringError(
      "Tried to construct integer decoding transformer without mode.");
}
|
|
|
|
/// Builds the input transformer for an index-typed gate.
///
/// Returns an error when `gateInfo` is not index-typed. The produced callable
/// verifies the value, converts it to unsigned when it is signed, turns it
/// into a raw transport value and stamps the gate's index type info on it.
Result<InputTransformer> TransformerFactory::getIndexInputTransformer(
    Message<concreteprotocol::GateInfo> gateInfo) {
  const bool describesIndex = gateInfo.asReader().getTypeInfo().hasIndex();
  if (!describesIndex) {
    return StringError(
        "Tried to get index input transformer from non-index gate info.");
  }

  OUTCOME_TRY(auto verify, getIndexInputValueVerifier(gateInfo));

  return [verify, gateInfo](Value val) -> Result<TransportValue> {
    OUTCOME_TRYV(verify(val));

    if (val.isSigned()) {
      val = val.toUnsigned();
    }

    auto transport = val.intoRawTransportValue();
    auto typeInfo = transport.asBuilder().initTypeInfo();
    typeInfo.setIndex(gateInfo.asReader().getTypeInfo().getIndex());
    return transport;
  };
}
|
|
|
|
/// Builds the output transformer for an index-typed gate.
///
/// Returns an error when `gateInfo` is not index-typed. The produced callable
/// verifies the transport value against the gate and converts it back into a
/// `Value`.
Result<OutputTransformer> TransformerFactory::getIndexOutputTransformer(
    Message<concreteprotocol::GateInfo> gateInfo) {
  const bool describesIndex = gateInfo.asReader().getTypeInfo().hasIndex();
  if (!describesIndex) {
    return StringError(
        "Tried to get index output transformer from non-index gate info.");
  }

  OUTCOME_TRY(auto verify, getTransportValueVerifier(gateInfo));

  return [verify](TransportValue transportVal) -> Result<Value> {
    OUTCOME_TRYV(verify(transportVal));
    return Value::fromRawTransportValue(transportVal);
  };
}
|
|
|
|
/// Builds the argument transformer for an index-typed gate.
///
/// Returns an error when `gateInfo` is not index-typed; otherwise delegates
/// to the index output transformer.
Result<ArgTransformer> TransformerFactory::getIndexArgTransformer(
    Message<concreteprotocol::GateInfo> gateInfo) {
  const bool describesIndex = gateInfo.asReader().getTypeInfo().hasIndex();
  if (!describesIndex) {
    return StringError(
        "Tried to get index arg transformer from non-index gate info.");
  }

  // For index gates, an argument is handled exactly like an output.
  return getIndexOutputTransformer(std::move(gateInfo));
}
|
|
|
|
/// Builds the return transformer for an index-typed gate.
///
/// Returns an error when `gateInfo` is not index-typed; otherwise delegates
/// to the index input transformer.
Result<ReturnTransformer> TransformerFactory::getIndexReturnTransformer(
    Message<concreteprotocol::GateInfo> gateInfo) {
  const bool describesIndex = gateInfo.asReader().getTypeInfo().hasIndex();
  if (!describesIndex) {
    return StringError(
        "Tried to get index return transformer from non-index gate info.");
  }

  // For index gates, a return value is handled exactly like an input.
  return getIndexInputTransformer(std::move(gateInfo));
}
|
|
|
|
/// Builds the input transformer for a plaintext-typed gate.
///
/// Returns an error when `gateInfo` is not plaintext-typed. The produced
/// callable verifies the value, converts it to unsigned when it is signed,
/// turns it into a raw transport value and stamps the gate's plaintext type
/// info on it.
Result<InputTransformer> TransformerFactory::getPlaintextInputTransformer(
    Message<concreteprotocol::GateInfo> gateInfo) {
  const bool describesPlaintext =
      gateInfo.asReader().getTypeInfo().hasPlaintext();
  if (!describesPlaintext) {
    return StringError("Tried to get plaintext input transformer from "
                       "non-plaintext gate info.");
  }

  OUTCOME_TRY(auto verify, getPlaintextInputValueVerifier(gateInfo));

  return [verify, gateInfo](Value val) -> Result<TransportValue> {
    OUTCOME_TRYV(verify(val));

    if (val.isSigned()) {
      val = val.toUnsigned();
    }

    auto transport = val.intoRawTransportValue();
    auto typeInfo = transport.asBuilder().initTypeInfo();
    typeInfo.setPlaintext(gateInfo.asReader().getTypeInfo().getPlaintext());
    return transport;
  };
}
|
|
|
|
/// Builds the output transformer for a plaintext-typed gate.
///
/// Returns an error when `gateInfo` is not plaintext-typed. The produced
/// callable verifies the transport value against the gate and converts it
/// back into a `Value`.
Result<OutputTransformer> TransformerFactory::getPlaintextOutputTransformer(
    Message<concreteprotocol::GateInfo> gateInfo) {
  const bool describesPlaintext =
      gateInfo.asReader().getTypeInfo().hasPlaintext();
  if (!describesPlaintext) {
    return StringError("Tried to get plaintext output transformer from "
                       "non-plaintext gate info.");
  }

  OUTCOME_TRY(auto verify, getTransportValueVerifier(gateInfo));

  return [verify](TransportValue transportVal) -> Result<Value> {
    OUTCOME_TRYV(verify(transportVal));
    return Value::fromRawTransportValue(transportVal);
  };
}
|
|
|
|
/// Builds the argument transformer for a plaintext-typed gate.
///
/// Returns an error when `gateInfo` is not plaintext-typed; otherwise
/// delegates to the plaintext output transformer.
Result<ArgTransformer> TransformerFactory::getPlaintextArgTransformer(
    Message<concreteprotocol::GateInfo> gateInfo) {
  const bool describesPlaintext =
      gateInfo.asReader().getTypeInfo().hasPlaintext();
  if (!describesPlaintext) {
    return StringError("Tried to get plaintext arg transformer from "
                       "non-plaintext gate info.");
  }

  // For plaintext gates, an argument is handled exactly like an output.
  return getPlaintextOutputTransformer(std::move(gateInfo));
}
|
|
|
|
/// Builds the return transformer for a plaintext-typed gate.
///
/// Returns an error when `gateInfo` is not plaintext-typed; otherwise
/// delegates to the plaintext input transformer.
Result<ReturnTransformer> TransformerFactory::getPlaintextReturnTransformer(
    Message<concreteprotocol::GateInfo> gateInfo) {
  const bool describesPlaintext =
      gateInfo.asReader().getTypeInfo().hasPlaintext();
  if (!describesPlaintext) {
    return StringError("Tried to get plaintext return transformer from "
                       "non-plaintext gate info.");
  }

  // For plaintext gates, a return value is handled exactly like an input.
  return getPlaintextInputTransformer(std::move(gateInfo));
}
|
|
|
|
/// Builds the input transformer for an lwe-ciphertext gate: clear value in,
/// encrypted transport value out.
///
/// Errors are reported, in this order, when: `gateInfo` is not
/// ciphertext-typed; the key id is not available in `keyset` (checked only
/// when not simulating); the encoding is neither boolean nor integer; a
/// sub-transformer cannot be built; the compression is not `NONE`.
///
/// The produced callable verifies the value, then applies encoding,
/// encryption (real or simulated, depending on `useSimulation`) and
/// compression, converts the result to a raw transport value and stamps the
/// gate's lwe-ciphertext type info on it.
Result<InputTransformer> TransformerFactory::getLweCiphertextInputTransformer(
    ClientKeyset keyset, Message<concreteprotocol::GateInfo> gateInfo,
    std::shared_ptr<csprng::EncryptionCSPRNG> csprng, bool useSimulation) {
  if (!gateInfo.asReader().getTypeInfo().hasLweCiphertext()) {
    return StringError("Tried to get lwe ciphertext input transformer from "
                       "non-ciphertext gate info.");
  }

  // Shorthand for the ciphertext description used throughout this function.
  auto lweInfo = gateInfo.asReader().getTypeInfo().getLweCiphertext();

  if (!useSimulation &&
      lweInfo.getEncryption().getKeyId() >= keyset.lweSecretKeys.size()) {
    return StringError(
        "Tried to generate lwe ciphertext input transformer with "
        "key id unavailable");
  }

  /// Generating the encoding transformer.
  Transformer encodingTransformer;
  if (lweInfo.getEncoding().hasBoolean()) {
    OUTCOME_TRY(encodingTransformer, getBooleanEncodingTransformer());
  } else if (lweInfo.getEncoding().hasInteger()) {
    OUTCOME_TRY(
        encodingTransformer,
        getIntegerEncodingTransformer(lweInfo.getEncoding().getInteger()));
  } else {
    return StringError("Malformed gate info");
  }

  /// Generating the encryption transformer.
  Transformer encryptionTransformer;
  if (useSimulation) {
    OUTCOME_TRY(
        encryptionTransformer,
        getEncryptionSimulationTransformer(lweInfo.getEncryption(), csprng));
  } else {
    OUTCOME_TRY(
        encryptionTransformer,
        getEncryptionTransformer(keyset, lweInfo.getEncryption(), csprng));
  }

  /// Generating the compression transformer.
  if (lweInfo.getCompression() != concreteprotocol::Compression::NONE) {
    return StringError(
        "Only none compression is currently supported for lwe ciphertext "
        "currently.");
  }
  Transformer compressionTransformer;
  OUTCOME_TRY(compressionTransformer, getNoneCompressionTransformer());

  OUTCOME_TRY(auto verify, getLweCiphertextInputValueVerifier(gateInfo));

  return [verify, encodingTransformer, encryptionTransformer,
          compressionTransformer, gateInfo](Value val) -> Result<TransportValue> {
    OUTCOME_TRYV(verify(val));

    // Pipeline: encode -> encrypt -> compress -> transport.
    auto encoded = encodingTransformer(val);
    auto encrypted = encryptionTransformer(std::move(encoded));
    auto compressed = compressionTransformer(std::move(encrypted));
    auto transport = compressed.intoRawTransportValue();

    auto typeInfo = transport.asBuilder().initTypeInfo();
    typeInfo.setLweCiphertext(
        gateInfo.asReader().getTypeInfo().getLweCiphertext());
    return transport;
  };
}
|
|
|
|
/// Builds the argument transformer for an lwe-ciphertext gate: encrypted
/// transport value in, encrypted value out.
///
/// Returns an error when `gateInfo` is not ciphertext-typed or when its
/// compression is not `NONE`. In simulation the transport value is accepted
/// without checks; otherwise it is verified against the gate. The produced
/// callable then converts the transport value to a `Value` and applies the
/// decompression transformer.
Result<ArgTransformer> TransformerFactory::getLweCiphertextArgTransformer(
    Message<concreteprotocol::GateInfo> gateInfo, bool useSimulation) {
  if (!gateInfo.asReader().getTypeInfo().hasLweCiphertext()) {
    return StringError("Tried to get lwe ciphertext arg transformer from "
                       "non-ciphertext gate info.");
  }

  /// Generating the decompression transformer.
  const auto compression =
      gateInfo.asReader().getTypeInfo().getLweCiphertext().getCompression();
  if (compression != concreteprotocol::Compression::NONE) {
    return StringError(
        "Only none compression is currently supported for lwe ciphertext "
        "currently.");
  }
  Transformer decompressionTransformer;
  OUTCOME_TRY(decompressionTransformer, getNoneDecompressionTransformer());

  // Generating the verifier.
  TransportValueVerifier verify;
  if (useSimulation) {
    OUTCOME_TRY(verify, getObliviousTransportValueVerifier());
  } else {
    OUTCOME_TRY(verify, getTransportValueVerifier(gateInfo));
  }

  return [verify, decompressionTransformer](
             TransportValue transportVal) -> Result<Value> {
    OUTCOME_TRYV(verify(transportVal));
    return decompressionTransformer(Value::fromRawTransportValue(transportVal));
  };
}
|
|
|
|
/// Builds the return transformer for an lwe-ciphertext gate: encrypted value
/// in, encrypted transport value out.
///
/// Returns an error when `gateInfo` is not ciphertext-typed or when its
/// compression is not `NONE`. In simulation the value is accepted without
/// checks; otherwise it is verified with the ciphertext output verifier. The
/// produced callable then applies the compression transformer, converts the
/// result to a raw transport value and stamps the gate's lwe-ciphertext type
/// info on it.
Result<ReturnTransformer> TransformerFactory::getLweCiphertextReturnTransformer(
    Message<concreteprotocol::GateInfo> gateInfo, bool useSimulation) {
  if (!gateInfo.asReader().getTypeInfo().hasLweCiphertext()) {
    return StringError("Tried to get lwe ciphertext return transformer from "
                       "non-ciphertext gate info.");
  }

  /// Generating the compression transformer.
  const auto compression =
      gateInfo.asReader().getTypeInfo().getLweCiphertext().getCompression();
  if (compression != concreteprotocol::Compression::NONE) {
    return StringError(
        "Only none compression is currently supported for lwe ciphertext "
        "currently.");
  }
  Transformer compressionTransformer;
  OUTCOME_TRY(compressionTransformer, getNoneCompressionTransformer());

  // Generating the verifier.
  ValueVerifier verify;
  if (useSimulation) {
    OUTCOME_TRY(verify, getObliviousValueVerifier());
  } else {
    OUTCOME_TRY(verify, getLweCiphertextOutputValueVerifier(gateInfo));
  }

  return [verify, compressionTransformer,
          gateInfo](Value val) -> Result<TransportValue> {
    OUTCOME_TRYV(verify(val));

    auto transport = compressionTransformer(val).intoRawTransportValue();
    auto typeInfo = transport.asBuilder().initTypeInfo();
    typeInfo.setLweCiphertext(
        gateInfo.asReader().getTypeInfo().getLweCiphertext());
    return transport;
  };
}
|
|
|
|
/// Builds the output transformer for an lwe-ciphertext gate: encrypted
/// transport value in, clear value out.
///
/// Errors are reported, in this order, when: `gateInfo` is not
/// ciphertext-typed; the key id is not available in `keyset` (checked only
/// when not simulating); the compression is not `NONE`; a sub-transformer
/// cannot be built; the encoding is neither boolean nor integer.
///
/// The produced callable verifies the transport value (no checks in
/// simulation), converts it to a `Value`, then applies decompression,
/// decryption (real or simulated) and decoding.
Result<OutputTransformer> TransformerFactory::getLweCiphertextOutputTransformer(
    ClientKeyset keyset, Message<concreteprotocol::GateInfo> gateInfo,
    bool useSimulation) {
  if (!gateInfo.asReader().getTypeInfo().hasLweCiphertext()) {
    return StringError("Tried to get lwe ciphertext output transformer from "
                       "non-ciphertext gate info.");
  }

  // Shorthand for the ciphertext description used throughout this function.
  auto lweInfo = gateInfo.asReader().getTypeInfo().getLweCiphertext();

  if (!useSimulation &&
      lweInfo.getEncryption().getKeyId() >= keyset.lweSecretKeys.size()) {
    return StringError(
        "Tried to generate lwe ciphertext output transformer with "
        "key id unavailable");
  }

  /// Generating the decompression transformer.
  if (lweInfo.getCompression() != concreteprotocol::Compression::NONE) {
    return StringError(
        "Only none compression is currently supported for lwe ciphertext "
        "currently.");
  }
  Transformer decompressionTransformer;
  OUTCOME_TRY(decompressionTransformer, getNoneDecompressionTransformer());

  /// Generating the decryption transformer.
  Transformer decryptionTransformer;
  if (useSimulation) {
    OUTCOME_TRY(decryptionTransformer, getDecryptionSimulationTransformer());
  } else {
    OUTCOME_TRY(decryptionTransformer,
                getDecryptionTransformer(keyset, lweInfo.getEncryption()));
  }

  /// Generating the decoding transformer.
  Transformer decodingTransformer;
  if (lweInfo.getEncoding().hasBoolean()) {
    OUTCOME_TRY(decodingTransformer, getBooleanDecodingTransformer());
  } else if (lweInfo.getEncoding().hasInteger()) {
    OUTCOME_TRY(
        decodingTransformer,
        getIntegerDecodingTransformer(lweInfo.getEncoding().getInteger()));
  } else {
    return StringError("Malformed gate info");
  }

  // Generating the verifier.
  TransportValueVerifier verify;
  if (useSimulation) {
    OUTCOME_TRY(verify, getObliviousTransportValueVerifier());
  } else {
    OUTCOME_TRY(verify, getTransportValueVerifier(gateInfo));
  }

  return [verify, decompressionTransformer, decryptionTransformer,
          decodingTransformer](TransportValue transportVal) -> Result<Value> {
    OUTCOME_TRYV(verify(transportVal));

    // Pipeline: transport -> decompress -> decrypt -> decode.
    auto decompressed =
        decompressionTransformer(Value::fromRawTransportValue(transportVal));
    auto decrypted = decryptionTransformer(std::move(decompressed));
    return decodingTransformer(std::move(decrypted));
  };
}
|
|
|
|
} // namespace transformers
|
|
} // namespace concretelang
|