mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(rust): support execution with tensor args
This commit is contained in:
@@ -7,9 +7,11 @@
|
||||
#include "concretelang/CAPI/Wrappers.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/LambdaArgument.h"
|
||||
#include "concretelang/Support/LambdaSupport.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include <numeric>
|
||||
|
||||
#define C_STRUCT_CLEANER(c_struct) \
|
||||
auto *cpp = unwrap(c_struct); \
|
||||
@@ -299,8 +301,58 @@ LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
|
||||
return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
|
||||
}
|
||||
|
||||
// LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
|
||||
// size_t rank);
|
||||
int64_t getSizeFromRankAndDims(size_t rank, int64_t *dims) {
|
||||
if (rank == 0) // not a tensor
|
||||
return 1;
|
||||
auto size = dims[0];
|
||||
for (size_t i = 1; i < rank; i++)
|
||||
size *= dims[i];
|
||||
return size;
|
||||
}
|
||||
|
||||
LambdaArgument lambdaArgumentFromTensorU8(uint8_t *data, int64_t *dims,
|
||||
size_t rank) {
|
||||
|
||||
std::vector<uint8_t> data_vector(data,
|
||||
data + getSizeFromRankAndDims(rank, dims));
|
||||
std::vector<int64_t> dims_vector(dims, dims + rank);
|
||||
return wrap(new mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>(data_vector,
|
||||
dims_vector));
|
||||
}
|
||||
|
||||
LambdaArgument lambdaArgumentFromTensorU16(uint16_t *data, int64_t *dims,
|
||||
size_t rank) {
|
||||
|
||||
std::vector<uint16_t> data_vector(data,
|
||||
data + getSizeFromRankAndDims(rank, dims));
|
||||
std::vector<int64_t> dims_vector(dims, dims + rank);
|
||||
return wrap(new mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>(data_vector,
|
||||
dims_vector));
|
||||
}
|
||||
|
||||
LambdaArgument lambdaArgumentFromTensorU32(uint32_t *data, int64_t *dims,
|
||||
size_t rank) {
|
||||
|
||||
std::vector<uint32_t> data_vector(data,
|
||||
data + getSizeFromRankAndDims(rank, dims));
|
||||
std::vector<int64_t> dims_vector(dims, dims + rank);
|
||||
return wrap(new mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>(data_vector,
|
||||
dims_vector));
|
||||
}
|
||||
|
||||
LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
|
||||
size_t rank) {
|
||||
|
||||
std::vector<uint64_t> data_vector(data,
|
||||
data + getSizeFromRankAndDims(rank, dims));
|
||||
std::vector<int64_t> dims_vector(dims, dims + rank);
|
||||
return wrap(new mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>(data_vector,
|
||||
dims_vector));
|
||||
}
|
||||
|
||||
bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) {
|
||||
return unwrap(lambdaArg)
|
||||
@@ -330,9 +382,114 @@ bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) {
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
|
||||
}
|
||||
|
||||
// uint64_t *lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
|
||||
// size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
|
||||
// int64_t *lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
|
||||
template <typename T>
|
||||
bool copyTensorDataToBuffer(
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<T>> *tensor,
|
||||
uint64_t *buffer) {
|
||||
auto *data = tensor->getValue();
|
||||
auto sizeOrError = tensor->getNumElements();
|
||||
if (!sizeOrError) {
|
||||
llvm::errs() << llvm::toString(sizeOrError.takeError());
|
||||
return false;
|
||||
}
|
||||
auto size = sizeOrError.get();
|
||||
for (size_t i = 0; i < size; i++)
|
||||
buffer[i] = data[i];
|
||||
return true;
|
||||
}
|
||||
|
||||
bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg, uint64_t *buffer) {
|
||||
auto arg = unwrap(lambdaArg);
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
return copyTensorDataToBuffer(tensor, buffer);
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
return copyTensorDataToBuffer(tensor, buffer);
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
return copyTensorDataToBuffer(tensor, buffer);
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
return copyTensorDataToBuffer(tensor, buffer);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg) {
|
||||
auto arg = unwrap(lambdaArg);
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
return tensor->getDimensions().size();
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
return tensor->getDimensions().size();
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
return tensor->getDimensions().size();
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
return tensor->getDimensions().size();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg) {
|
||||
auto arg = unwrap(lambdaArg);
|
||||
std::vector<int64_t> dims;
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
return std::accumulate(std::begin(dims), std::end(dims), 1,
|
||||
std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg, int64_t *buffer) {
|
||||
auto arg = unwrap(lambdaArg);
|
||||
std::vector<int64_t> dims;
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
memcpy(buffer, dims.data(), sizeof(int64_t) * dims.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
|
||||
size_t argNumber, ClientParameters params,
|
||||
|
||||
Reference in New Issue
Block a user