mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
cleanup(compiler/jit): Removing dead code since the preparation of arguments has been factorized thanks the EncryptedArguments
This commit is contained in:
@@ -24,92 +24,6 @@ namespace clientlib = ::concretelang::clientlib;
|
||||
/// of the module.
|
||||
class JITLambda {
|
||||
public:
|
||||
class Argument {
|
||||
public:
|
||||
Argument(KeySet &keySet);
|
||||
~Argument();
|
||||
|
||||
// Create lambda Argument that use the given KeySet to perform encryption
|
||||
// and decryption operations.
|
||||
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
|
||||
|
||||
// Set a scalar argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint64_t arg);
|
||||
|
||||
// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T>
|
||||
llvm::Error setArg(size_t pos, const T *data, int64_t dim1) {
|
||||
return setArg<T>(pos, data, llvm::ArrayRef<int64_t>(&dim1, 1));
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
llvm::Error setArg(size_t pos, const T *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
return setArg(pos, 8 * sizeof(T), static_cast<const void *>(data), shape);
|
||||
}
|
||||
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
llvm::Error getResult(size_t pos, uint64_t &res);
|
||||
|
||||
// Specifies the type of a result
|
||||
enum ResultType { SCALAR, TENSOR };
|
||||
|
||||
// Returns the result type at position `pos`. If pos is invalid,
|
||||
// an error is returned.
|
||||
llvm::Expected<enum ResultType> getResultType(size_t pos);
|
||||
|
||||
// Get a result for tensors, fill the `res` buffer with the value of the
|
||||
// tensor result.
|
||||
// Returns an error:
|
||||
// - if the result is a scalar
|
||||
// - or the size of the `res` buffser doesn't match the size of the tensor.
|
||||
template <typename T>
|
||||
llvm::Error getResult(size_t pos, T *res, size_t size) {
|
||||
return std::move(this->getResult(pos, res, sizeof(T), size));
|
||||
}
|
||||
|
||||
llvm::Error getResult(size_t pos, void *res, size_t elementSize,
|
||||
size_t numElements);
|
||||
|
||||
// Returns the number of elements of the result vector at position
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> getResultVectorSize(size_t pos);
|
||||
|
||||
// Returns the width of the result scalar at position `pos` or the
|
||||
// width of the scalar values of a vector if the result at
|
||||
// position `pos` is a tensor.
|
||||
llvm::Expected<size_t> getResultWidth(size_t pos);
|
||||
|
||||
// Returns the dimensions of the result tensor at position `pos` or
|
||||
// an error if the result is a scalar value
|
||||
llvm::Expected<std::vector<int64_t>> getResultDimensions(size_t pos);
|
||||
|
||||
private:
|
||||
// Verify if lambda can accept a n-th argument.
|
||||
llvm::Error emitErrorIfTooManyArgs(size_t n);
|
||||
llvm::Error setArg(size_t pos, size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
friend JITLambda;
|
||||
// Store the pointer on inputs values and outputs values
|
||||
std::vector<void *> rawArg;
|
||||
// Store the values of inputs
|
||||
std::vector<const void *> inputs;
|
||||
// Store the values of outputs
|
||||
std::vector<void *> outputs;
|
||||
// Store the input gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> inputGates;
|
||||
// Store the outputs gates description and the offset of the argument.
|
||||
std::vector<std::tuple<CircuitGate, size_t /*offset*/>> outputGates;
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<uint64_t *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<uint64_t *> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
RuntimeContext context;
|
||||
};
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
: type(type), name(name){};
|
||||
|
||||
|
||||
@@ -341,38 +341,6 @@ public:
|
||||
}
|
||||
|
||||
protected:
|
||||
template <int pos>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs) {
|
||||
// base case -- nothing to do
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
// Recursive case for scalars: extract first scalar argument from
|
||||
// parameter pack and forward rest
|
||||
template <int pos, typename ArgT, typename... Ts>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg,
|
||||
Ts... remainder) {
|
||||
if (auto err = jitArgs->setArg(pos, arg)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push scalar argument " << pos << ": " << err;
|
||||
}
|
||||
|
||||
return this->addArgs<pos + 1>(jitArgs, remainder...);
|
||||
}
|
||||
|
||||
// Recursive case for tensors: extract pointer and size from
|
||||
// parameter pack and forward rest
|
||||
template <int pos, typename ArgT, typename... Ts>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg,
|
||||
size_t size, Ts... remainder) {
|
||||
if (auto err = jitArgs->setArg(pos, arg, size)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push tensor argument " << pos << ": " << err;
|
||||
}
|
||||
|
||||
return this->addArgs<pos + 1>(jitArgs, remainder...);
|
||||
}
|
||||
|
||||
std::unique_ptr<JITLambda> innerLambda;
|
||||
std::unique_ptr<KeySet> keySet;
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
|
||||
@@ -6,7 +6,6 @@ add_subdirectory(Runtime)
|
||||
add_subdirectory(ClientLib)
|
||||
add_subdirectory(Bindings)
|
||||
add_subdirectory(ServerLib)
|
||||
add_subdirectory(Common)
|
||||
|
||||
# CAPI needed only for python bindings
|
||||
if (CONCRETELANG_BINDINGS_PYTHON_ENABLED)
|
||||
|
||||
@@ -23,7 +23,6 @@ add_mlir_library(
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
|
||||
|
||||
ConcretelangCommon
|
||||
ConcretelangRuntime
|
||||
ConcretelangSupportLib
|
||||
LINK_LIBS PUBLIC
|
||||
|
||||
@@ -11,6 +11,17 @@ namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
|
||||
size_t previousWidth = 0;
|
||||
for (auto currentWidth : sortedWordBitWidths) {
|
||||
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
|
||||
return currentWidth;
|
||||
}
|
||||
}
|
||||
return exactBitWidth;
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext) {
|
||||
@@ -72,7 +83,7 @@ EncryptedArguments::pushArg(size_t width, const void *data,
|
||||
return StringError("argument #")
|
||||
<< pos << " width > 64 bits is not supported";
|
||||
}
|
||||
auto roundedSize = concretelang::common::bitWidthAsWord(input.shape.width);
|
||||
auto roundedSize = bitWidthAsWord(input.shape.width);
|
||||
if (width != roundedSize) {
|
||||
return StringError("argument #") << pos << "width mismatch, got " << width
|
||||
<< " expected " << roundedSize;
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace common {
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
|
||||
size_t previousWidth = 0;
|
||||
for (auto currentWidth : sortedWordBitWidths) {
|
||||
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
|
||||
return currentWidth;
|
||||
}
|
||||
}
|
||||
return exactBitWidth;
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace concretelang
|
||||
@@ -1,19 +0,0 @@
|
||||
add_compile_options( -Werror )
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# using Clang
|
||||
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
# using GCC
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_mlir_library(
|
||||
ConcretelangCommon
|
||||
BitsSize.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Common
|
||||
)
|
||||
@@ -30,7 +30,6 @@ add_mlir_library(ConcretelangSupport
|
||||
MLIRExecutionEngine
|
||||
${LLVM_PTHREAD_LIB}
|
||||
|
||||
ConcretelangCommon
|
||||
ConcretelangRuntime
|
||||
ConcretelangClientLib
|
||||
)
|
||||
|
||||
@@ -151,426 +151,5 @@ JITLambda::call(clientlib::PublicArguments &args) {
|
||||
return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
|
||||
}
|
||||
|
||||
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
// Setting the inputs
|
||||
auto numInputs = 0;
|
||||
{
|
||||
for (size_t i = 0; i < keySet.numInputs(); i++) {
|
||||
auto offset = numInputs;
|
||||
auto gate = keySet.inputGate(i);
|
||||
inputGates.push_back({gate, offset});
|
||||
if (gate.shape.dimensions.empty()) {
|
||||
// scalar gate
|
||||
if (gate.encryption.hasValue()) {
|
||||
// encrypted is a memref<lweSizexi64>
|
||||
numInputs = numInputs + numArgOfRankedMemrefCallingConvention(1);
|
||||
} else {
|
||||
numInputs = numInputs + 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// memref gate, as we follow the standard calling convention
|
||||
auto rank = keySet.inputGate(i).shape.dimensions.size() +
|
||||
(keySet.isInputEncrypted(i) ? 1 /* for lwe size */ : 0);
|
||||
numInputs = numInputs + numArgOfRankedMemrefCallingConvention(rank);
|
||||
}
|
||||
// Reserve for the context argument
|
||||
numInputs = numInputs + 1;
|
||||
inputs = std::vector<const void *>(numInputs);
|
||||
}
|
||||
|
||||
// Setting the outputs
|
||||
{
|
||||
auto numOutputs = 0;
|
||||
for (size_t i = 0; i < keySet.numOutputs(); i++) {
|
||||
auto offset = numOutputs;
|
||||
auto gate = keySet.outputGate(i);
|
||||
outputGates.push_back({gate, offset});
|
||||
if (gate.shape.dimensions.empty()) {
|
||||
// scalar gate
|
||||
if (gate.encryption.hasValue()) {
|
||||
// encrypted is a memref<lweSizexi64>
|
||||
numOutputs = numOutputs + numArgOfRankedMemrefCallingConvention(1);
|
||||
} else {
|
||||
numOutputs = numOutputs + 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// memref gate, as we follow the standard calling convention
|
||||
auto rank = keySet.outputGate(i).shape.dimensions.size() +
|
||||
(keySet.isOutputEncrypted(i) ? 1 /* for lwe size */ : 0);
|
||||
numOutputs = numOutputs + numArgOfRankedMemrefCallingConvention(rank);
|
||||
}
|
||||
outputs = std::vector<void *>(numOutputs);
|
||||
}
|
||||
// The raw argument contains pointers to inputs and pointers to store the
|
||||
// results
|
||||
rawArg = std::vector<void *>(inputs.size() + outputs.size(), nullptr);
|
||||
// Set the pointer on outputs on rawArg
|
||||
for (auto i = inputs.size(); i < rawArg.size(); i++) {
|
||||
rawArg[i] = &outputs[i - inputs.size()];
|
||||
}
|
||||
|
||||
// Set the context argument
|
||||
keySet.setRuntimeContext(context);
|
||||
inputs[numInputs - 1] = &context;
|
||||
rawArg[numInputs - 1] = &inputs[numInputs - 1];
|
||||
}
|
||||
|
||||
JITLambda::Argument::~Argument() {
|
||||
for (auto ct : allocatedCiphertexts) {
|
||||
free(ct);
|
||||
}
|
||||
for (auto buffer : ciphertextBuffers) {
|
||||
free(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
|
||||
JITLambda::Argument::create(KeySet &keySet) {
|
||||
auto args = std::make_unique<JITLambda::Argument>(keySet);
|
||||
return std::move(args);
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::emitErrorIfTooManyArgs(size_t pos) {
|
||||
size_t arity = inputGates.size();
|
||||
if (pos < arity) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
return StreamStringError("The function has arity ")
|
||||
<< arity << " but is applied to too many arguments";
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
if (auto error = emitErrorIfTooManyArgs(pos)) {
|
||||
return error;
|
||||
}
|
||||
auto gate = inputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (!info.shape.dimensions.empty()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// If argument is not encrypted, just save.
|
||||
if (!info.encryption.hasValue()) {
|
||||
inputs[offset] = (void *)arg;
|
||||
rawArg[offset] = &inputs[offset];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// Else if is encryted, allocate ciphertext and encrypt.
|
||||
uint64_t *ctArg;
|
||||
uint64_t ctSize;
|
||||
auto check = this->keySet.allocate_lwe(pos, &ctArg, ctSize);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctArg);
|
||||
check = this->keySet.encrypt_lwe(pos, ctArg, arg);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
// memref calling convention
|
||||
// allocated
|
||||
inputs[offset] = nullptr;
|
||||
// aligned
|
||||
inputs[offset + 1] = ctArg;
|
||||
// offset
|
||||
inputs[offset + 2] = (void *)0;
|
||||
// size
|
||||
inputs[offset + 3] = (void *)ctSize;
|
||||
// stride
|
||||
inputs[offset + 4] = (void *)1;
|
||||
rawArg[offset] = &inputs[offset];
|
||||
rawArg[offset + 1] = &inputs[offset + 1];
|
||||
rawArg[offset + 2] = &inputs[offset + 2];
|
||||
rawArg[offset + 3] = &inputs[offset + 3];
|
||||
rawArg[offset + 4] = &inputs[offset + 4];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
|
||||
const void *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
if (auto error = emitErrorIfTooManyArgs(pos)) {
|
||||
return error;
|
||||
}
|
||||
auto gate = inputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
// Check if the width is compatible
|
||||
// TODO - I found this rules empirically, they are a spec somewhere?
|
||||
if (info.shape.width > 64) {
|
||||
auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : a width of " +
|
||||
llvm::Twine(info.shape.width) +
|
||||
"bits > 64 is not supported: pos=" + llvm::Twine(pos);
|
||||
return llvm::make_error<llvm::StringError>(msg,
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto roundedSize = ::concretelang::common::bitWidthAsWord(info.shape.width);
|
||||
if (width != roundedSize) {
|
||||
auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : expected " +
|
||||
llvm::Twine(roundedSize) + "bits" + " but received " +
|
||||
llvm::Twine(width) + "bits (rounded from " +
|
||||
llvm::Twine(info.shape.width) + ")";
|
||||
return llvm::make_error<llvm::StringError>(msg,
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Check the size
|
||||
if (info.shape.dimensions.empty()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (shape.size() != info.shape.dimensions.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("tensor argument #")
|
||||
.concat(llvm::Twine(pos))
|
||||
.concat(" has not the expected number of dimension, got ")
|
||||
.concat(llvm::Twine(shape.size()))
|
||||
.concat(" expected ")
|
||||
.concat(llvm::Twine(info.shape.dimensions.size())),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
if (shape[i] != info.shape.dimensions[i]) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("tensor argument #")
|
||||
.concat(llvm::Twine(pos))
|
||||
.concat(" has not the expected dimension #")
|
||||
.concat(llvm::Twine(i))
|
||||
.concat(" , got ")
|
||||
.concat(llvm::Twine(shape[i]))
|
||||
.concat(" expected ")
|
||||
.concat(llvm::Twine(info.shape.dimensions[i])),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
}
|
||||
// If argument is not encrypted, just save with the right calling convention.
|
||||
if (info.encryption.hasValue()) {
|
||||
// Else if is encrypted
|
||||
// For moment we support only 8 bits inputs
|
||||
const uint8_t *data8 = (const uint8_t *)data;
|
||||
if (width != 8) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine(
|
||||
"argument width > 8 for encrypted gates are not supported: pos=")
|
||||
.concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// Allocate a buffer for ciphertexts, the size of the buffer is the number
|
||||
// of elements of the tensor * the size of the lwe ciphertext
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
|
||||
uint64_t *ctBuffer =
|
||||
(uint64_t *)malloc(info.shape.size * lweSize * sizeof(uint64_t));
|
||||
ciphertextBuffers.push_back(ctBuffer);
|
||||
// Encrypt ciphertexts
|
||||
for (size_t i = 0, offset = 0; i < info.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
|
||||
auto check = this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i]);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
}
|
||||
// Replace the data by the buffer to ciphertext
|
||||
data = (void *)ctBuffer;
|
||||
}
|
||||
// Set the buffer as the memref calling convention expect.
|
||||
// allocated
|
||||
inputs[offset] =
|
||||
(void *)0; // Indicates that it's not allocated by the MLIR program
|
||||
rawArg[offset] = &inputs[offset];
|
||||
offset++;
|
||||
// aligned
|
||||
inputs[offset] = data;
|
||||
rawArg[offset] = &inputs[offset];
|
||||
offset++;
|
||||
// offset
|
||||
inputs[offset] = (void *)0;
|
||||
rawArg[offset] = &inputs[offset];
|
||||
offset++;
|
||||
// sizes is an array of size equals to numDim
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
inputs[offset] = (void *)shape[i];
|
||||
rawArg[offset] = &inputs[offset];
|
||||
offset++;
|
||||
}
|
||||
// If encrypted +1 for the lwe size rank
|
||||
if (keySet.isInputEncrypted(pos)) {
|
||||
inputs[offset] = (void *)(keySet.getInputLweSecretKeyParam(pos).lweSize());
|
||||
rawArg[offset] = &inputs[offset];
|
||||
offset++;
|
||||
}
|
||||
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = 1;
|
||||
// If encrypted +1 set the stride for the lwe size rank
|
||||
if (keySet.isInputEncrypted(pos)) {
|
||||
inputs[offset + shape.size()] = (void *)stride;
|
||||
rawArg[offset + shape.size()] = &inputs[offset];
|
||||
stride *= keySet.getInputLweSecretKeyParam(pos).lweSize();
|
||||
}
|
||||
for (ssize_t i = shape.size() - 1; i >= 0; i--) {
|
||||
inputs[offset + i] = (void *)stride;
|
||||
rawArg[offset + i] = &inputs[offset + i];
|
||||
stride *= shape[i];
|
||||
}
|
||||
offset += shape.size();
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.size != 0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// If result is not encrypted, just set the result
|
||||
if (!info.encryption.hasValue()) {
|
||||
res = (uint64_t)(outputs[offset]);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// Else if is encryted, decrypt
|
||||
uint64_t *ct = (uint64_t *)(outputs[offset + 1]);
|
||||
auto check = this->keySet.decrypt_lwe(pos, ct, res);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
// Returns the number of elements of the result vector at position
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> JITLambda::Argument::getResultVectorSize(size_t pos) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Result at pos %zu is not a tensor", pos);
|
||||
}
|
||||
|
||||
return info.shape.size;
|
||||
}
|
||||
|
||||
// Returns the dimensions of the result tensor at position `pos` or
|
||||
// an error if the result is a scalar value
|
||||
llvm::Expected<std::vector<int64_t>>
|
||||
JITLambda::Argument::getResultDimensions(size_t pos) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Result at pos %zu is not a tensor", pos);
|
||||
}
|
||||
|
||||
return info.shape.dimensions;
|
||||
}
|
||||
|
||||
llvm::Expected<enum JITLambda::Argument::ResultType>
|
||||
JITLambda::Argument::getResultType(size_t pos) {
|
||||
if (pos >= outputGates.size()) {
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Requesting type for result at index %zu, "
|
||||
"but lambda only generates %zu results",
|
||||
pos, outputGates.size());
|
||||
}
|
||||
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
if (info.shape.size == 0) {
|
||||
return ResultType::SCALAR;
|
||||
} else {
|
||||
return ResultType::TENSOR;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<size_t> JITLambda::Argument::getResultWidth(size_t pos) {
|
||||
if (pos >= outputGates.size()) {
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Requesting width for result at index %zu, "
|
||||
"but lambda only generates %zu results",
|
||||
pos, outputGates.size());
|
||||
}
|
||||
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
// Encrypted values are always returned as 64-bit values for now
|
||||
if (info.encryption.hasValue())
|
||||
return 64;
|
||||
else
|
||||
return info.shape.width;
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, void *res,
|
||||
size_t elementSize,
|
||||
size_t numElements) {
|
||||
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.dimensions.empty()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Check is the argument is a scalar
|
||||
if (info.shape.size != numElements) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("result #")
|
||||
.concat(llvm::Twine(pos))
|
||||
.concat(" has not the expected size, got ")
|
||||
.concat(llvm::Twine(numElements))
|
||||
.concat(" expect ")
|
||||
.concat(llvm::Twine(info.shape.size)),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// Get the values as the memref calling convention expect.
|
||||
// aligned
|
||||
uint8_t *alignedBytes = static_cast<uint8_t *>(outputs[offset + 1]);
|
||||
uint8_t *resBytes = static_cast<uint8_t *>(res);
|
||||
if (!info.encryption.hasValue()) {
|
||||
// just copy values
|
||||
for (size_t i = 0; i < numElements; i++) {
|
||||
for (size_t j = 0; j < elementSize; j++) {
|
||||
*resBytes = *alignedBytes;
|
||||
resBytes++;
|
||||
alignedBytes++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// decrypt and fill the result buffer
|
||||
auto lweSize = keySet.getOutputLweSecretKeyParam(pos).lweSize();
|
||||
|
||||
for (size_t i = 0, o = 0; i < numElements; i++, o += lweSize) {
|
||||
uint64_t *ct = ((uint64_t *)alignedBytes) + o;
|
||||
auto check = this->keySet.decrypt_lwe(pos, ct, ((uint64_t *)res)[i]);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
}
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -15,7 +15,6 @@ set_source_files_properties(
|
||||
|
||||
target_link_libraries(
|
||||
testlib_unit_test
|
||||
ConcretelangCommon
|
||||
ConcretelangRuntime
|
||||
ConcretelangSupport
|
||||
ConcretelangClientLib
|
||||
|
||||
Reference in New Issue
Block a user