mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
feat(compiler/bindings): create bindings for value management
This commit is contained in:
@@ -151,6 +151,12 @@ keySetUnserialize(const std::string &buffer);
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
keySetSerialize(concretelang::clientlib::KeySet &keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::SharedScalarOrTensorData
|
||||
valueUnserialize(const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value);
|
||||
|
||||
/// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
|
||||
@@ -112,17 +112,15 @@ private:
|
||||
class PublicArguments {
|
||||
public:
|
||||
PublicArguments(const ClientParameters &clientParameters,
|
||||
std::vector<ScalarOrTensorData> &&ciphertextBuffers);
|
||||
std::vector<clientlib::SharedScalarOrTensorData> &buffers);
|
||||
~PublicArguments();
|
||||
PublicArguments(PublicArguments &other) = delete;
|
||||
PublicArguments(PublicArguments &&other) = delete;
|
||||
|
||||
static outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
unserialize(ClientParameters &expectedParams, std::istream &istream);
|
||||
unserialize(const ClientParameters &expectedParams, std::istream &istream);
|
||||
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
std::vector<ScalarOrTensorData> &getArguments() { return arguments; }
|
||||
std::vector<SharedScalarOrTensorData> &getArguments() { return arguments; }
|
||||
ClientParameters &getClientParameters() { return clientParameters; }
|
||||
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
@@ -133,7 +131,7 @@ private:
|
||||
|
||||
ClientParameters clientParameters;
|
||||
/// Store buffers of ciphertexts
|
||||
std::vector<ScalarOrTensorData> arguments;
|
||||
std::vector<SharedScalarOrTensorData> arguments;
|
||||
};
|
||||
|
||||
/// PublicResult is a result of a ServerLambda call which contains encrypted
|
||||
@@ -141,7 +139,7 @@ private:
|
||||
struct PublicResult {
|
||||
|
||||
PublicResult(const ClientParameters &clientParameters,
|
||||
std::vector<ScalarOrTensorData> &&buffers = {})
|
||||
std::vector<SharedScalarOrTensorData> &&buffers = {})
|
||||
: clientParameters(clientParameters), buffers(std::move(buffers)){};
|
||||
|
||||
PublicResult(PublicResult &) = delete;
|
||||
@@ -150,17 +148,18 @@ struct PublicResult {
|
||||
/// @param argPos The position of the value in the PublicResult
|
||||
/// @return Either the value or an error if there are no value at this
|
||||
/// position
|
||||
outcome::checked<ScalarOrTensorData, StringError> getValue(size_t argPos) {
|
||||
outcome::checked<SharedScalarOrTensorData, StringError>
|
||||
getValue(size_t argPos) {
|
||||
if (argPos >= buffers.size()) {
|
||||
return StringError("result #") << argPos << " does not exists";
|
||||
}
|
||||
return std::move(buffers[argPos]);
|
||||
return buffers[argPos];
|
||||
}
|
||||
|
||||
/// Create a public result from buffers.
|
||||
static std::unique_ptr<PublicResult>
|
||||
fromBuffers(const ClientParameters &clientParameters,
|
||||
std::vector<ScalarOrTensorData> &&buffers) {
|
||||
std::vector<SharedScalarOrTensorData> &&buffers) {
|
||||
return std::make_unique<PublicResult>(clientParameters, std::move(buffers));
|
||||
}
|
||||
|
||||
@@ -182,7 +181,7 @@ struct PublicResult {
|
||||
outcome::checked<T, StringError> asClearTextScalar(KeySet &keySet,
|
||||
size_t pos) {
|
||||
ValueDecrypter decrypter(keySet, clientParameters);
|
||||
auto &data = buffers[pos];
|
||||
auto &data = buffers[pos].get();
|
||||
return decrypter.template decrypt<T>(data, pos);
|
||||
}
|
||||
|
||||
@@ -192,7 +191,7 @@ struct PublicResult {
|
||||
outcome::checked<std::vector<T>, StringError>
|
||||
asClearTextVector(KeySet &keySet, size_t pos) {
|
||||
ValueDecrypter decrypter(keySet, clientParameters);
|
||||
return decrypter.template decryptTensor<T>(buffers[pos], pos);
|
||||
return decrypter.template decryptTensor<T>(buffers[pos].get(), pos);
|
||||
}
|
||||
|
||||
/// Return the shape of the clear tensor of a result.
|
||||
@@ -205,7 +204,7 @@ struct PublicResult {
|
||||
// private: TODO tmp
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
ClientParameters clientParameters;
|
||||
std::vector<ScalarOrTensorData> buffers;
|
||||
std::vector<SharedScalarOrTensorData> buffers;
|
||||
};
|
||||
|
||||
/// Helper function to convert from MemRefDescriptor to
|
||||
|
||||
@@ -96,10 +96,9 @@ std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
unserializeScalarOrTensorData(std::istream &istream);
|
||||
|
||||
std::ostream &
|
||||
serializeVectorOfScalarOrTensorData(const std::vector<ScalarOrTensorData> &sotd,
|
||||
std::ostream &ostream);
|
||||
outcome::checked<std::vector<ScalarOrTensorData>, StringError>
|
||||
std::ostream &serializeVectorOfScalarOrTensorData(
|
||||
const std::vector<SharedScalarOrTensorData> &sotd, std::ostream &ostream);
|
||||
outcome::checked<std::vector<SharedScalarOrTensorData>, StringError>
|
||||
unserializeVectorOfScalarOrTensorData(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
|
||||
|
||||
@@ -689,7 +689,7 @@ public:
|
||||
|
||||
// Returns a void pointer to the first element of a flat
|
||||
// representation of the tensor
|
||||
void *getValuesAsOpaquePointer() {
|
||||
void *getValuesAsOpaquePointer() const {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
return static_cast<void *>(values.u64->data());
|
||||
@@ -879,6 +879,19 @@ public:
|
||||
return *tensor;
|
||||
}
|
||||
};
|
||||
|
||||
struct SharedScalarOrTensorData {
|
||||
std::shared_ptr<ScalarOrTensorData> inner;
|
||||
|
||||
SharedScalarOrTensorData(std::shared_ptr<ScalarOrTensorData> inner)
|
||||
: inner{inner} {}
|
||||
|
||||
SharedScalarOrTensorData(ScalarOrTensorData &&inner)
|
||||
: inner{std::make_shared<ScalarOrTensorData>(std::move(inner))} {}
|
||||
|
||||
ScalarOrTensorData &get() const { return *this->inner; }
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
|
||||
@@ -75,17 +75,21 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
|
||||
}
|
||||
|
||||
// Store the result to the PublicResult
|
||||
std::vector<clientlib::ScalarOrTensorData> buffers;
|
||||
std::vector<clientlib::SharedScalarOrTensorData> buffers;
|
||||
{
|
||||
size_t outputOffset = 0;
|
||||
for (auto &output : clientParameters.outputs) {
|
||||
auto shape = clientParameters.bufferShape(output);
|
||||
if (shape.size() == 0) {
|
||||
// scalar scalar
|
||||
buffers.push_back(concretelang::clientlib::ScalarOrTensorData(
|
||||
auto value = concretelang::clientlib::ScalarOrTensorData(
|
||||
concretelang::clientlib::ScalarData(outputs[outputOffset++],
|
||||
output.shape.sign,
|
||||
output.shape.width)));
|
||||
output.shape.width));
|
||||
auto sharedValue =
|
||||
clientlib::SharedScalarOrTensorData(std::move(value));
|
||||
|
||||
buffers.push_back(sharedValue);
|
||||
} else {
|
||||
// buffer gate
|
||||
auto rank = shape.size();
|
||||
@@ -102,14 +106,18 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
|
||||
: output.shape.width;
|
||||
|
||||
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
|
||||
concretelang::clientlib::TensorData td =
|
||||
|
||||
auto value = concretelang::clientlib::ScalarOrTensorData(
|
||||
clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
|
||||
aligned, offset, sizes, strides);
|
||||
buffers.push_back(
|
||||
concretelang::clientlib::ScalarOrTensorData(std::move(td)));
|
||||
aligned, offset, sizes, strides));
|
||||
auto sharedValue =
|
||||
clientlib::SharedScalarOrTensorData(std::move(value));
|
||||
|
||||
buffers.push_back(sharedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clientlib::PublicResult::fromBuffers(clientParameters,
|
||||
std::move(buffers));
|
||||
}
|
||||
@@ -120,7 +128,8 @@ invokeRawOnLambda(Lambda *lambda, clientlib::PublicArguments &arguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
// Prepare arguments with the right calling convention
|
||||
std::vector<void *> preparedArgs;
|
||||
for (auto &arg : arguments.getArguments()) {
|
||||
for (auto &sharedArg : arguments.getArguments()) {
|
||||
clientlib::ScalarOrTensorData &arg = sharedArg.get();
|
||||
if (arg.isScalar()) {
|
||||
auto scalar = arg.getScalar().getValueAsU64();
|
||||
preparedArgs.push_back((void *)scalar);
|
||||
|
||||
@@ -60,6 +60,9 @@ declare_mlir_python_sources(
|
||||
concrete/compiler/public_result.py
|
||||
concrete/compiler/evaluation_keys.py
|
||||
concrete/compiler/utils.py
|
||||
concrete/compiler/value.py
|
||||
concrete/compiler/value_decrypter.py
|
||||
concrete/compiler/value_exporter.py
|
||||
concrete/compiler/wrapper.py
|
||||
concrete/__init__.py
|
||||
concrete/lang/__init__.py
|
||||
|
||||
@@ -285,9 +285,98 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def("get_evaluation_keys",
|
||||
[](clientlib::KeySet &keySet) { return keySet.evaluationKeys(); });
|
||||
|
||||
pybind11::class_<clientlib::SharedScalarOrTensorData>(m, "Value")
|
||||
.def_static("deserialize",
|
||||
[](const pybind11::bytes &buffer) {
|
||||
return valueUnserialize(buffer);
|
||||
})
|
||||
.def("serialize", [](const clientlib::SharedScalarOrTensorData &value) {
|
||||
return pybind11::bytes(valueSerialize(value));
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::ValueExporter>(m, "ValueExporter")
|
||||
.def_static("create",
|
||||
[](clientlib::KeySet &keySet,
|
||||
mlir::concretelang::ClientParameters &clientParameters) {
|
||||
return clientlib::ValueExporter(keySet, clientParameters);
|
||||
})
|
||||
.def("export_scalar",
|
||||
[](clientlib::ValueExporter &exporter, size_t position,
|
||||
int64_t value) {
|
||||
outcome::checked<clientlib::ScalarOrTensorData, StringError>
|
||||
result = exporter.exportValue(value, position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return clientlib::SharedScalarOrTensorData(
|
||||
std::move(result.value()));
|
||||
})
|
||||
.def("export_tensor", [](clientlib::ValueExporter &exporter,
|
||||
size_t position, std::vector<int64_t> values,
|
||||
std::vector<int64_t> shape) {
|
||||
outcome::checked<clientlib::ScalarOrTensorData, StringError> result =
|
||||
exporter.exportValue(values.data(), shape, position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return clientlib::SharedScalarOrTensorData(std::move(result.value()));
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::ValueDecrypter>(m, "ValueDecrypter")
|
||||
.def_static("create",
|
||||
[](clientlib::KeySet &keySet,
|
||||
mlir::concretelang::ClientParameters &clientParameters) {
|
||||
return clientlib::ValueDecrypter(keySet, clientParameters);
|
||||
})
|
||||
.def("get_shape",
|
||||
[](clientlib::ValueDecrypter &decrypter, size_t position) {
|
||||
outcome::checked<std::vector<int64_t>, StringError> result =
|
||||
decrypter.getShape(position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
})
|
||||
.def("decrypt_scalar",
|
||||
[](clientlib::ValueDecrypter &decrypter, size_t position,
|
||||
clientlib::SharedScalarOrTensorData &value) {
|
||||
outcome::checked<int64_t, StringError> result =
|
||||
decrypter.decrypt<int64_t>(value.get(), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
})
|
||||
.def("decrypt_tensor",
|
||||
[](clientlib::ValueDecrypter &decrypter, size_t position,
|
||||
clientlib::SharedScalarOrTensorData &value) {
|
||||
outcome::checked<std::vector<int64_t>, StringError> result =
|
||||
decrypter.decryptTensor<int64_t>(value.get(), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::PublicArguments,
|
||||
std::unique_ptr<clientlib::PublicArguments>>(
|
||||
m, "PublicArguments")
|
||||
.def_static(
|
||||
"create",
|
||||
[](const mlir::concretelang::ClientParameters &clientParameters,
|
||||
std::vector<clientlib::SharedScalarOrTensorData> &buffers) {
|
||||
return clientlib::PublicArguments(clientParameters, buffers);
|
||||
})
|
||||
.def_static("deserialize",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters,
|
||||
const pybind11::bytes &buffer) {
|
||||
@@ -302,9 +391,25 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
const pybind11::bytes &buffer) {
|
||||
return publicResultUnserialize(clientParameters, buffer);
|
||||
})
|
||||
.def("serialize", [](clientlib::PublicResult &publicResult) {
|
||||
return pybind11::bytes(publicResultSerialize(publicResult));
|
||||
});
|
||||
.def("serialize",
|
||||
[](clientlib::PublicResult &publicResult) {
|
||||
return pybind11::bytes(publicResultSerialize(publicResult));
|
||||
})
|
||||
.def("n_values",
|
||||
[](const clientlib::PublicResult &publicResult) {
|
||||
return publicResult.buffers.size();
|
||||
})
|
||||
.def("get_value",
|
||||
[](clientlib::PublicResult &publicResult, size_t position) {
|
||||
outcome::checked<clientlib::SharedScalarOrTensorData, StringError>
|
||||
result = publicResult.getValue(position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::EvaluationKeys>(m, "EvaluationKeys")
|
||||
.def_static("deserialize",
|
||||
|
||||
@@ -257,6 +257,26 @@ keySetSerialize(concretelang::clientlib::KeySet &keySet) {
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::SharedScalarOrTensorData
|
||||
valueUnserialize(const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
|
||||
auto value = concretelang::clientlib::unserializeScalarOrTensorData(istream);
|
||||
if (istream.fail() || value.has_error()) {
|
||||
throw std::runtime_error("Cannot read data");
|
||||
}
|
||||
|
||||
return concretelang::clientlib::SharedScalarOrTensorData(
|
||||
std::move(value.value()));
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) {
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
serializeScalarOrTensorData(value.get(), buffer);
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
clientParametersUnserialize(const std::string &json) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
|
||||
@@ -29,6 +29,9 @@ from .client_support import ClientSupport
|
||||
from .jit_support import JITSupport
|
||||
from .library_support import LibrarySupport
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
from .value import Value
|
||||
from .value_decrypter import ValueDecrypter
|
||||
from .value_exporter import ValueExporter
|
||||
|
||||
|
||||
def init_dfr():
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
"""PublicArguments."""
|
||||
|
||||
from typing import List
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
PublicArguments as _PublicArguments,
|
||||
@@ -10,6 +12,7 @@ from mlir._mlir_libs._concretelang._compiler import (
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .client_parameters import ClientParameters
|
||||
from .value import Value
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
@@ -35,6 +38,20 @@ class PublicArguments(WrapperCpp):
|
||||
)
|
||||
super().__init__(public_arguments)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
client_parameters: ClientParameters, values: List[Value]
|
||||
) -> "PublicArguments":
|
||||
"""
|
||||
Create public arguments from individual values.
|
||||
"""
|
||||
return PublicArguments(
|
||||
_PublicArguments.create(
|
||||
client_parameters.cpp(),
|
||||
[value.cpp() for value in values],
|
||||
)
|
||||
)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Serialize the PublicArguments.
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from mlir._mlir_libs._concretelang._compiler import (
|
||||
from .client_parameters import ClientParameters
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .value import Value
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
@@ -31,6 +32,18 @@ class PublicResult(WrapperCpp):
|
||||
)
|
||||
super().__init__(public_result)
|
||||
|
||||
def n_values(self) -> int:
|
||||
"""
|
||||
Get number of values in the result.
|
||||
"""
|
||||
return self.cpp().n_values()
|
||||
|
||||
def get_value(self, position: int) -> Value:
|
||||
"""
|
||||
Get a specific value in the result.
|
||||
"""
|
||||
return Value(self.cpp().get_value(position))
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Serialize the PublicResult.
|
||||
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Value."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
Value as _Value,
|
||||
)
|
||||
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
|
||||
class Value(WrapperCpp):
|
||||
"""An encrypted/clear value which can be scalar/tensor."""
|
||||
|
||||
def __init__(self, value: _Value):
|
||||
"""
|
||||
Wrap the native C++ object.
|
||||
|
||||
Args:
|
||||
value (_Value):
|
||||
object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
if `value` is not of type `_Value`
|
||||
"""
|
||||
|
||||
if not isinstance(value, _Value):
|
||||
raise TypeError(f"value must be of type _Value, not {type(value)}")
|
||||
|
||||
super().__init__(value)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
Serialize value into bytes.
|
||||
|
||||
Returns:
|
||||
bytes: serialized value
|
||||
"""
|
||||
|
||||
return self.cpp().serialize()
|
||||
|
||||
@staticmethod
|
||||
def deserialize(serialized_value: bytes) -> "Value":
|
||||
"""
|
||||
Deserialize value from bytes.
|
||||
|
||||
Args:
|
||||
serialized_value (bytes):
|
||||
previously serialized value
|
||||
|
||||
Returns:
|
||||
Value:
|
||||
deserialized value
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
if `serialized_value` is not of type `bytes`
|
||||
"""
|
||||
|
||||
if not isinstance(serialized_value, bytes):
|
||||
raise TypeError(
|
||||
f"serialized_value must be of type bytes, not {type(serialized_value)}"
|
||||
)
|
||||
|
||||
return Value.wrap(_Value.deserialize(serialized_value))
|
||||
@@ -0,0 +1,111 @@
|
||||
"""ValueDecrypter."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
ValueDecrypter as _ValueDecrypter,
|
||||
)
|
||||
|
||||
from .client_parameters import ClientParameters
|
||||
from .key_set import KeySet
|
||||
from .value import Value
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
|
||||
class ValueDecrypter(WrapperCpp):
|
||||
"""A helper class to decrypt `Value`s."""
|
||||
|
||||
def __init__(self, value_decrypter: _ValueDecrypter):
|
||||
"""
|
||||
Wrap the native C++ object.
|
||||
|
||||
Args:
|
||||
value_decrypter (_ValueDecrypter):
|
||||
object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
if `value_decrypter` is not of type `_ValueDecrypter`
|
||||
"""
|
||||
|
||||
if not isinstance(value_decrypter, _ValueDecrypter):
|
||||
raise TypeError(
|
||||
f"value_decrypter must be of type _ValueDecrypter, not {type(value_decrypter)}"
|
||||
)
|
||||
|
||||
super().__init__(value_decrypter)
|
||||
|
||||
@staticmethod
|
||||
def create(keyset: KeySet, client_parameters: ClientParameters):
|
||||
"""
|
||||
Create a value decrypter.
|
||||
"""
|
||||
return ValueDecrypter(
|
||||
_ValueDecrypter.create(keyset.cpp(), client_parameters.cpp())
|
||||
)
|
||||
|
||||
def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]:
|
||||
"""
|
||||
Decrypt value.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
value (Value):
|
||||
value to decrypt
|
||||
|
||||
Returns:
|
||||
Union[int, np.ndarray]:
|
||||
decrypted value
|
||||
"""
|
||||
|
||||
shape = tuple(self.cpp().get_shape(position))
|
||||
|
||||
if len(shape) == 0:
|
||||
return self.decrypt_scalar(position, value)
|
||||
|
||||
return np.array(self.decrypt_tensor(position, value), dtype=np.int64).reshape(
|
||||
shape
|
||||
)
|
||||
|
||||
def decrypt_scalar(self, position: int, value: Value) -> int:
|
||||
"""
|
||||
Decrypt scalar.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
value (Value):
|
||||
scalar value to decrypt
|
||||
|
||||
Returns:
|
||||
int:
|
||||
decrypted scalar
|
||||
"""
|
||||
|
||||
return self.cpp().decrypt_scalar(position, value.cpp())
|
||||
|
||||
def decrypt_tensor(self, position: int, value: Value) -> List[int]:
|
||||
"""
|
||||
Decrypt tensor.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
value (Value):
|
||||
tensor value to decrypt
|
||||
|
||||
Returns:
|
||||
List[int]:
|
||||
decrypted tensor
|
||||
"""
|
||||
|
||||
return self.cpp().decrypt_tensor(position, value.cpp())
|
||||
@@ -0,0 +1,90 @@
|
||||
"""ValueExporter."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
|
||||
from typing import List
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
ValueExporter as _ValueExporter,
|
||||
)
|
||||
|
||||
from .client_parameters import ClientParameters
|
||||
from .key_set import KeySet
|
||||
from .value import Value
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
|
||||
class ValueExporter(WrapperCpp):
|
||||
"""A helper class to create `Value`s."""
|
||||
|
||||
def __init__(self, value_exporter: _ValueExporter):
|
||||
"""
|
||||
Wrap the native C++ object.
|
||||
|
||||
Args:
|
||||
value_exporter (_ValueExporter):
|
||||
object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
if `value_exporter` is not of type `_ValueExporter`
|
||||
"""
|
||||
|
||||
if not isinstance(value_exporter, _ValueExporter):
|
||||
raise TypeError(
|
||||
f"value_exporter must be of type _ValueExporter, not {type(value_exporter)}"
|
||||
)
|
||||
|
||||
super().__init__(value_exporter)
|
||||
|
||||
@staticmethod
|
||||
def create(keyset: KeySet, client_parameters: ClientParameters) -> "ValueExporter":
|
||||
"""
|
||||
Create a value exporter.
|
||||
"""
|
||||
return ValueExporter(
|
||||
_ValueExporter.create(keyset.cpp(), client_parameters.cpp())
|
||||
)
|
||||
|
||||
def export_scalar(self, position: int, value: int) -> Value:
|
||||
"""
|
||||
Export scalar.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
value (int):
|
||||
scalar to export
|
||||
|
||||
Returns:
|
||||
Value:
|
||||
exported scalar
|
||||
"""
|
||||
|
||||
return Value(self.cpp().export_scalar(position, value))
|
||||
|
||||
def export_tensor(
|
||||
self, position: int, values: List[int], shape: List[int]
|
||||
) -> Value:
|
||||
"""
|
||||
Export tensor.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
values (List[int]):
|
||||
tensor elements to export
|
||||
|
||||
shape (List[int]):
|
||||
tensor shape to export
|
||||
|
||||
Returns:
|
||||
Value:
|
||||
exported tensor
|
||||
"""
|
||||
|
||||
return Value(self.cpp().export_tensor(position, values, shape))
|
||||
@@ -13,7 +13,14 @@ using StringError = concretelang::error::StringError;
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
|
||||
return std::make_unique<PublicArguments>(clientParameters, std::move(values));
|
||||
auto sharedValues = std::vector<SharedScalarOrTensorData>();
|
||||
sharedValues.reserve(this->values.size());
|
||||
|
||||
for (auto &&value : this->values) {
|
||||
sharedValues.push_back(SharedScalarOrTensorData(std::move(value)));
|
||||
}
|
||||
|
||||
return std::make_unique<PublicArguments>(clientParameters, sharedValues);
|
||||
}
|
||||
|
||||
/// Split the input integer into `size` chunks of `chunkWidth` bits each
|
||||
|
||||
@@ -15,10 +15,11 @@ namespace clientlib {
|
||||
using concretelang::error::StringError;
|
||||
|
||||
// TODO: optimize the move
|
||||
PublicArguments::PublicArguments(const ClientParameters &clientParameters,
|
||||
std::vector<ScalarOrTensorData> &&arguments_)
|
||||
PublicArguments::PublicArguments(
|
||||
const ClientParameters &clientParameters,
|
||||
std::vector<clientlib::SharedScalarOrTensorData> &buffers)
|
||||
: clientParameters(clientParameters) {
|
||||
arguments = std::move(arguments_);
|
||||
arguments = buffers;
|
||||
}
|
||||
|
||||
PublicArguments::~PublicArguments() {}
|
||||
@@ -44,11 +45,11 @@ PublicArguments::unserializeArgs(std::istream &istream) {
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
PublicArguments::unserialize(ClientParameters &clientParameters,
|
||||
PublicArguments::unserialize(const ClientParameters &expectedParams,
|
||||
std::istream &istream) {
|
||||
std::vector<ScalarOrTensorData> emptyBuffers;
|
||||
auto sArguments = std::make_unique<PublicArguments>(clientParameters,
|
||||
std::move(emptyBuffers));
|
||||
std::vector<SharedScalarOrTensorData> emptyBuffers;
|
||||
auto sArguments =
|
||||
std::make_unique<PublicArguments>(expectedParams, emptyBuffers);
|
||||
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
|
||||
return std::move(sArguments);
|
||||
}
|
||||
|
||||
@@ -398,8 +398,11 @@ unserializeScalarData(std::istream &istream) {
|
||||
template <typename T>
|
||||
static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes,
|
||||
std::istream &istream) {
|
||||
readWords(istream, values_and_sizes.getElementPointer<T>(0),
|
||||
values_and_sizes.getNumElements());
|
||||
// getElementPointer is not valid if the tensor contains no data
|
||||
if (values_and_sizes.getNumElements() > 0) {
|
||||
readWords(istream, values_and_sizes.getElementPointer<T>(0),
|
||||
values_and_sizes.getNumElements());
|
||||
}
|
||||
|
||||
return istream;
|
||||
}
|
||||
@@ -555,26 +558,25 @@ unserializeScalarOrTensorData(std::istream &istream) {
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream &
|
||||
serializeVectorOfScalarOrTensorData(const std::vector<ScalarOrTensorData> &v,
|
||||
std::ostream &ostream) {
|
||||
std::ostream &serializeVectorOfScalarOrTensorData(
|
||||
const std::vector<SharedScalarOrTensorData> &v, std::ostream &ostream) {
|
||||
writeSize(ostream, v.size());
|
||||
for (auto &sotd : v) {
|
||||
serializeScalarOrTensorData(sotd, ostream);
|
||||
serializeScalarOrTensorData(sotd.get(), ostream);
|
||||
if (!ostream.good()) {
|
||||
return ostream;
|
||||
}
|
||||
}
|
||||
return ostream;
|
||||
}
|
||||
outcome::checked<std::vector<ScalarOrTensorData>, StringError>
|
||||
outcome::checked<std::vector<SharedScalarOrTensorData>, StringError>
|
||||
unserializeVectorOfScalarOrTensorData(std::istream &istream) {
|
||||
uint64_t nbElt;
|
||||
readSize(istream, nbElt);
|
||||
std::vector<ScalarOrTensorData> v;
|
||||
std::vector<SharedScalarOrTensorData> v;
|
||||
for (uint64_t i = 0; i < nbElt; i++) {
|
||||
OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream));
|
||||
v.push_back(std::move(elt));
|
||||
v.push_back(SharedScalarOrTensorData(std::move(elt)));
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
if (auto err = invokeRaw(rawArgs)) {
|
||||
return std::move(err);
|
||||
}
|
||||
std::vector<clientlib::ScalarOrTensorData> buffers;
|
||||
std::vector<clientlib::SharedScalarOrTensorData> buffers;
|
||||
return clientlib::PublicResult::fromBuffers(args.clientParameters,
|
||||
std::move(buffers));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user