refactor(compiler): Refactor JITLambda::Argument::setArg
@@ -31,48 +31,16 @@ public:
   // Set a scalar argument at the given pos as a uint64_t.
   llvm::Error setArg(size_t pos, uint64_t arg);
 
-  // Set a argument at the given pos as a 1D tensor of int64.
-  llvm::Error setArg(size_t pos, uint64_t *data, size_t dim1) {
-    return setArg(pos, 64, (void *)data, 1, &dim1);
-  }
+  // Set a argument at the given pos as a 1D tensor of T.
+  template <typename T>
+  llvm::Error setArg(size_t pos, T *data, int64_t dim1) {
+    return setArg<T>(pos, data, llvm::ArrayRef<int64_t>(&dim1, 1));
+  }
 
-  // Set a argument at the given pos as a 1D tensor of int32.
-  llvm::Error setArg(size_t pos, uint32_t *data, size_t dim1) {
-    return setArg(pos, 32, (void *)data, 1, &dim1);
-  }
-
-  // Set a argument at the given pos as a 1D tensor of int16.
-  llvm::Error setArg(size_t pos, uint16_t *data, size_t dim1) {
-    return setArg(pos, 16, (void *)data, 1, &dim1);
-  }
-
-  // Set a argument at the given pos as a 1D tensor of int8.
-  llvm::Error setArg(size_t pos, uint8_t *data, size_t dim1) {
-    return setArg(pos, 8, (void *)data, 1, &dim1);
-  }
-
-  // Set a argument at the given pos as a tensor of int64.
-  llvm::Error setArg(size_t pos, uint64_t *data, size_t numDim,
-                     const size_t *dims) {
-    return setArg(pos, 64, (void *)data, numDim, dims);
-  }
-
-  // Set a argument at the given pos as a tensor of int32.
-  llvm::Error setArg(size_t pos, uint32_t *data, size_t numDim,
-                     const size_t *dims) {
-    return setArg(pos, 32, (void *)data, numDim, dims);
-  }
-
-  // Set a argument at the given pos as a tensor of int32.
-  llvm::Error setArg(size_t pos, uint16_t *data, size_t numDim,
-                     const size_t *dims) {
-    return setArg(pos, 16, (void *)data, numDim, dims);
-  }
-
-  // Set a tensor argument at the given pos as a uint64_t.
-  llvm::Error setArg(size_t pos, uint8_t *data, size_t numDim,
-                     const size_t *dims) {
-    return setArg(pos, 8, (void *)data, numDim, dims);
-  }
+  // Set a argument at the given pos as a tensor of T.
+  template <typename T>
+  llvm::Error setArg(size_t pos, T *data, llvm::ArrayRef<int64_t> shape) {
+    return setArg(pos, 8 * sizeof(T), static_cast<void *>(data), shape);
+  }
 
   // Get the result at the given pos as an uint64_t.

@@ -86,8 +54,8 @@ public:
   llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
 
 private:
-  llvm::Error setArg(size_t pos, size_t width, void *data, size_t numDim,
-                     const size_t *dims);
+  llvm::Error setArg(size_t pos, size_t width, void *data,
+                     llvm::ArrayRef<int64_t> shape);
 
   friend JITLambda;
   // Store the pointer on inputs values and outputs values
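After this change the public entry points are the two templates above: callers pass an llvm::ArrayRef<int64_t> shape (or a single int64_t dimension) instead of a bit width plus a numDim/dims pair, and the element width is derived from the pointee type as 8 * sizeof(T). Below is a minimal usage sketch mirroring the updated tests further down; the `argument` variable and the error handling around it are assumptions for illustration (the tests wrap the calls in ASSERT_LLVM_ERROR instead), not code taken from this commit.

  // Sketch only (assumes llvm/ADT/ArrayRef.h, llvm/Support/Error.h and
  // llvm/Support/raw_ostream.h are included, and that `argument` is a
  // JITLambda::Argument built as in the tests below).
  uint64_t tensor[2][10] = {};            // sample 2x10 payload
  int64_t dims[2] = {2, 10};
  llvm::ArrayRef<int64_t> shape(dims, 2);

  // N-D overload: T is deduced as uint64_t, width becomes 8 * sizeof(T) = 64.
  if (llvm::Error err = argument->setArg(0, (uint64_t *)tensor, shape)) {
    llvm::errs() << llvm::toString(std::move(err)) << "\n";
  }

  // 1-D overload: the single int64_t dimension is wrapped into an ArrayRef.
  uint64_t vec[10] = {};
  if (llvm::Error err = argument->setArg(1, vec, (int64_t)10)) {
    llvm::errs() << llvm::toString(std::move(err)) << "\n";
  }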
@@ -230,7 +230,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
 }
 
 llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
-                                        size_t numDim, const size_t *dims) {
+                                        llvm::ArrayRef<int64_t> shape) {
   auto gate = inputGates[pos];
   auto info = std::get<0>(gate);
   auto offset = std::get<1>(gate);

@@ -272,25 +272,25 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
         llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)),
         llvm::inconvertibleErrorCode());
   }
-  if (numDim != info.shape.dimensions.size()) {
+  if (shape.size() != info.shape.dimensions.size()) {
     return llvm::make_error<llvm::StringError>(
         llvm::Twine("tensor argument #")
             .concat(llvm::Twine(pos))
             .concat(" has not the expected number of dimension, got ")
-            .concat(llvm::Twine(numDim))
+            .concat(llvm::Twine(shape.size()))
            .concat(" expected ")
            .concat(llvm::Twine(info.shape.dimensions.size())),
        llvm::inconvertibleErrorCode());
  }
-  for (size_t i = 0; i < numDim; i++) {
-    if (dims[i] != info.shape.dimensions[i]) {
+  for (size_t i = 0; i < shape.size(); i++) {
+    if (shape[i] != info.shape.dimensions[i]) {
       return llvm::make_error<llvm::StringError>(
           llvm::Twine("tensor argument #")
               .concat(llvm::Twine(pos))
               .concat(" has not the expected dimension #")
               .concat(llvm::Twine(i))
               .concat(" , got ")
-              .concat(llvm::Twine(dims[i]))
+              .concat(llvm::Twine(shape[i]))
              .concat(" expected ")
              .concat(llvm::Twine(info.shape.dimensions[i])),
          llvm::inconvertibleErrorCode());

@@ -341,13 +341,13 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
   rawArg[offset] = &inputs[offset];
   offset++;
   // sizes is an array of size equals to numDim
-  for (size_t i = 0; i < numDim; i++) {
-    inputs[offset] = (void *)dims[i];
+  for (size_t i = 0; i < shape.size(); i++) {
+    inputs[offset] = (void *)shape[i];
     rawArg[offset] = &inputs[offset];
     offset++;
   }
   // strides is an array of size equals to numDim
-  for (size_t i = 0; i < numDim; i++) {
+  for (size_t i = 0; i < shape.size(); i++) {
     inputs[offset] = (void *)0;
     rawArg[offset] = &inputs[offset];
     offset++;

@@ -411,12 +411,12 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
   if (!info.encryption.hasValue()) {
     // just copy values
     for (size_t i = 0; i < size; i++) {
-      res[i] = ((uint64_t *)(aligned))[i];
+      res[i] = ((uint64_t *)aligned)[i];
     }
   } else {
     // decrypt and fill the result buffer
     for (size_t i = 0; i < size; i++) {
-      LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i];
+      LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)aligned)[i];
       if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) {
         return std::move(err);
       }
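A note on the size and stride loops in the hunk above: the packing appears to follow an MLIR-style ranked-memref calling convention, where a rank-N tensor argument expands to an allocated pointer, an aligned pointer, an offset, then one size and one stride per dimension (3 + 2*N scalar slots). That convention is an assumption about surrounding code this diff does not show; the helper below is hypothetical and only illustrates the slot order and count, it is not an excerpt from the compiler.

  // Illustrative helper (hypothetical, not part of the commit): lay out the
  // scalar slots a rank-N tensor argument expands to under the assumed
  // memref-style convention.
  #include <cstdint>
  #include <vector>
  #include "llvm/ADT/ArrayRef.h"

  std::vector<uint64_t> packTensorArg(void *data, llvm::ArrayRef<int64_t> shape) {
    std::vector<uint64_t> slots;
    slots.push_back((uint64_t)(uintptr_t)data); // allocated pointer
    slots.push_back((uint64_t)(uintptr_t)data); // aligned pointer
    slots.push_back(0);                         // offset
    for (int64_t dim : shape)                   // one size per dimension
      slots.push_back((uint64_t)dim);
    for (size_t i = 0; i < shape.size(); i++)   // one stride per dimension
      slots.push_back(0);                       // (left 0, as in the code above)
    return slots;                               // 3 + 2 * rank entries
  }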
@@ -222,14 +222,15 @@ func @main(%t: tensor<10xi1>, %i: index) -> i1{
 ///////////////////////////////////////////////////////////////////////////////
 
 const size_t numDim = 2;
-const size_t dim0 = 2;
-const size_t dim1 = 10;
-const size_t dims[numDim]{dim0, dim1};
+const int64_t dim0 = 2;
+const int64_t dim1 = 10;
+const int64_t dims[numDim]{dim0, dim1};
 const uint64_t tensor2D[dim0][dim1]{
     {0xFFFFFFFFFFFFFFFF, 0, 8978, 2587490, 90, 197864, 698735, 72132, 87474,
      42},
     {986, 1873, 298493, 34939, 443, 59874, 43, 743, 8409, 9433},
 };
+const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
 
 TEST(End2EndJit_ClearTensor_2D, identity) {
   mlir::zamalang::CompilerEngine engine;

@@ -244,7 +245,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<2x10xi64> {
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
   // Invoke the function
   ASSERT_LLVM_ERROR(engine.invoke(*argument));
   // Get and assert the result

@@ -272,7 +273,7 @@ func @main(%t: tensor<2x10xi64>, %i: index, %j: index) -> i64 {
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
   for (size_t i = 0; i < dims[0]; i++) {
     for (size_t j = 0; j < dims[1]; j++) {
       // Set %i, %j

@@ -302,7 +303,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
   // Invoke the function
   ASSERT_LLVM_ERROR(engine.invoke(*argument));
   // Get and assert the result

@@ -331,7 +332,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
   // Invoke the function
   ASSERT_LLVM_ERROR(engine.invoke(*argument));
   // Get and assert the result

@@ -360,11 +361,12 @@ func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> {
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t0 argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
   // Set the %t1 argument
-  uint64_t t1_dim[2] = {2, 2};
+  int64_t t1_dim[2] = {2, 2};
   uint64_t t1[2][2]{{6, 9}, {4, 0}};
-  ASSERT_LLVM_ERROR(argument->setArg(1, (uint64_t *)t1, 2, t1_dim));
+  ASSERT_LLVM_ERROR(
+      argument->setArg(1, (uint64_t *)t1, llvm::ArrayRef<int64_t>(t1_dim, 2)));
   // Invoke the function
   ASSERT_LLVM_ERROR(engine.invoke(*argument));
   // Get and assert the result
@@ -5,13 +5,14 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 const size_t numDim = 2;
-const size_t dim0 = 2;
-const size_t dim1 = 10;
-const size_t dims[numDim]{dim0, dim1};
+const int64_t dim0 = 2;
+const int64_t dim1 = 10;
+const int64_t dims[numDim]{dim0, dim1};
 const uint8_t tensor2D[dim0][dim1]{
     {63, 12, 7, 43, 52, 9, 26, 34, 22, 0},
     {0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
 };
+const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
 
 TEST(End2EndJit_EncryptedTensor_2D, identity) {
   mlir::zamalang::CompilerEngine engine;

@@ -26,7 +27,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<2x10x!HLFHE.eint<6>> {
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
   // Invoke the function
   ASSERT_LLVM_ERROR(engine.invoke(*argument));
   // Get and assert the result

@@ -54,7 +55,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>, %i: index, %j: index) -> !HLFHE.eint
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
   for (size_t i = 0; i < dims[0]; i++) {
     for (size_t j = 0; j < dims[1]; j++) {
       // Set %i, %j

@@ -84,7 +85,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> {
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
   // Invoke the function
   ASSERT_LLVM_ERROR(engine.invoke(*argument));
   // Get and assert the result

@@ -113,7 +114,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> {
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
   // Invoke the function
   ASSERT_LLVM_ERROR(engine.invoke(*argument));
   // Get and assert the result

@@ -142,11 +143,12 @@ func @main(%t0: tensor<2x10x!HLFHE.eint<6>>, %t1: tensor<2x2x!HLFHE.eint<6>>) ->
   ASSERT_LLVM_ERROR(maybeArgument.takeError());
   auto argument = std::move(maybeArgument.get());
   // Set the %t0 argument
-  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
   // Set the %t1 argument
-  uint64_t t1_dim[2] = {2, 2};
+  int64_t t1_dim[2] = {2, 2};
   uint8_t t1[2][2]{{6, 9}, {4, 0}};
-  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)t1, 2, t1_dim));
+  ASSERT_LLVM_ERROR(
+      argument->setArg(1, (uint8_t *)t1, llvm::ArrayRef<int64_t>(t1_dim, 2)));
   // Invoke the function
   ASSERT_LLVM_ERROR(engine.invoke(*argument));
   // Get and assert the result