refactor(compiler): Refactor JITLambda::Argument::setArg

This commit is contained in:
Quentin Bourgerie
2021-10-21 14:40:07 +02:00
parent b5f68c20c7
commit 247cc489c5
4 changed files with 45 additions and 73 deletions

View File

@@ -31,48 +31,16 @@ public:
// Set a scalar argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint64_t arg);
// Set a argument at the given pos as a 1D tensor of int64.
llvm::Error setArg(size_t pos, uint64_t *data, size_t dim1) {
return setArg(pos, 64, (void *)data, 1, &dim1);
// Set a argument at the given pos as a 1D tensor of T.
template <typename T>
llvm::Error setArg(size_t pos, T *data, int64_t dim1) {
return setArg<T>(pos, data, llvm::ArrayRef<int64_t>(&dim1, 1));
}
// Set a argument at the given pos as a 1D tensor of int32.
llvm::Error setArg(size_t pos, uint32_t *data, size_t dim1) {
return setArg(pos, 32, (void *)data, 1, &dim1);
}
// Set a argument at the given pos as a 1D tensor of int16.
llvm::Error setArg(size_t pos, uint16_t *data, size_t dim1) {
return setArg(pos, 16, (void *)data, 1, &dim1);
}
// Set a argument at the given pos as a 1D tensor of int8.
llvm::Error setArg(size_t pos, uint8_t *data, size_t dim1) {
return setArg(pos, 8, (void *)data, 1, &dim1);
}
// Set a argument at the given pos as a tensor of int64.
llvm::Error setArg(size_t pos, uint64_t *data, size_t numDim,
const size_t *dims) {
return setArg(pos, 64, (void *)data, numDim, dims);
}
// Set a argument at the given pos as a tensor of int32.
llvm::Error setArg(size_t pos, uint32_t *data, size_t numDim,
const size_t *dims) {
return setArg(pos, 32, (void *)data, numDim, dims);
}
// Set a argument at the given pos as a tensor of int32.
llvm::Error setArg(size_t pos, uint16_t *data, size_t numDim,
const size_t *dims) {
return setArg(pos, 16, (void *)data, numDim, dims);
}
// Set a tensor argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint8_t *data, size_t numDim,
const size_t *dims) {
return setArg(pos, 8, (void *)data, numDim, dims);
// Set a argument at the given pos as a tensor of T.
template <typename T>
llvm::Error setArg(size_t pos, T *data, llvm::ArrayRef<int64_t> shape) {
return setArg(pos, 8 * sizeof(T), static_cast<void *>(data), shape);
}
// Get the result at the given pos as an uint64_t.
@@ -86,8 +54,8 @@ public:
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
private:
llvm::Error setArg(size_t pos, size_t width, void *data, size_t numDim,
const size_t *dims);
llvm::Error setArg(size_t pos, size_t width, void *data,
llvm::ArrayRef<int64_t> shape);
friend JITLambda;
// Store the pointer on inputs values and outputs values

View File

@@ -230,7 +230,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
}
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
size_t numDim, const size_t *dims) {
llvm::ArrayRef<int64_t> shape) {
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
@@ -272,25 +272,25 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
}
if (numDim != info.shape.dimensions.size()) {
if (shape.size() != info.shape.dimensions.size()) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("tensor argument #")
.concat(llvm::Twine(pos))
.concat(" has not the expected number of dimension, got ")
.concat(llvm::Twine(numDim))
.concat(llvm::Twine(shape.size()))
.concat(" expected ")
.concat(llvm::Twine(info.shape.dimensions.size())),
llvm::inconvertibleErrorCode());
}
for (size_t i = 0; i < numDim; i++) {
if (dims[i] != info.shape.dimensions[i]) {
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i] != info.shape.dimensions[i]) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("tensor argument #")
.concat(llvm::Twine(pos))
.concat(" has not the expected dimension #")
.concat(llvm::Twine(i))
.concat(" , got ")
.concat(llvm::Twine(dims[i]))
.concat(llvm::Twine(shape[i]))
.concat(" expected ")
.concat(llvm::Twine(info.shape.dimensions[i])),
llvm::inconvertibleErrorCode());
@@ -341,13 +341,13 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data,
rawArg[offset] = &inputs[offset];
offset++;
// sizes is an array of size equals to numDim
for (size_t i = 0; i < numDim; i++) {
inputs[offset] = (void *)dims[i];
for (size_t i = 0; i < shape.size(); i++) {
inputs[offset] = (void *)shape[i];
rawArg[offset] = &inputs[offset];
offset++;
}
// strides is an array of size equals to numDim
for (size_t i = 0; i < numDim; i++) {
for (size_t i = 0; i < shape.size(); i++) {
inputs[offset] = (void *)0;
rawArg[offset] = &inputs[offset];
offset++;
@@ -411,12 +411,12 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
if (!info.encryption.hasValue()) {
// just copy values
for (size_t i = 0; i < size; i++) {
res[i] = ((uint64_t *)(aligned))[i];
res[i] = ((uint64_t *)aligned)[i];
}
} else {
// decrypt and fill the result buffer
for (size_t i = 0; i < size; i++) {
LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i];
LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)aligned)[i];
if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) {
return std::move(err);
}

View File

@@ -222,14 +222,15 @@ func @main(%t: tensor<10xi1>, %i: index) -> i1{
///////////////////////////////////////////////////////////////////////////////
const size_t numDim = 2;
const size_t dim0 = 2;
const size_t dim1 = 10;
const size_t dims[numDim]{dim0, dim1};
const int64_t dim0 = 2;
const int64_t dim1 = 10;
const int64_t dims[numDim]{dim0, dim1};
const uint64_t tensor2D[dim0][dim1]{
{0xFFFFFFFFFFFFFFFF, 0, 8978, 2587490, 90, 197864, 698735, 72132, 87474,
42},
{986, 1873, 298493, 34939, 443, 59874, 43, 743, 8409, 9433},
};
const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
TEST(End2EndJit_ClearTensor_2D, identity) {
mlir::zamalang::CompilerEngine engine;
@@ -244,7 +245,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<2x10xi64> {
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
@@ -272,7 +273,7 @@ func @main(%t: tensor<2x10xi64>, %i: index, %j: index) -> i64 {
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
for (size_t i = 0; i < dims[0]; i++) {
for (size_t j = 0; j < dims[1]; j++) {
// Set %i, %j
@@ -302,7 +303,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
@@ -331,7 +332,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
@@ -360,11 +361,12 @@ func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> {
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t0 argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
// Set the %t1 argument
uint64_t t1_dim[2] = {2, 2};
int64_t t1_dim[2] = {2, 2};
uint64_t t1[2][2]{{6, 9}, {4, 0}};
ASSERT_LLVM_ERROR(argument->setArg(1, (uint64_t *)t1, 2, t1_dim));
ASSERT_LLVM_ERROR(
argument->setArg(1, (uint64_t *)t1, llvm::ArrayRef<int64_t>(t1_dim, 2)));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result

View File

@@ -5,13 +5,14 @@
///////////////////////////////////////////////////////////////////////////////
const size_t numDim = 2;
const size_t dim0 = 2;
const size_t dim1 = 10;
const size_t dims[numDim]{dim0, dim1};
const int64_t dim0 = 2;
const int64_t dim1 = 10;
const int64_t dims[numDim]{dim0, dim1};
const uint8_t tensor2D[dim0][dim1]{
{63, 12, 7, 43, 52, 9, 26, 34, 22, 0},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
};
const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
TEST(End2EndJit_EncryptedTensor_2D, identity) {
mlir::zamalang::CompilerEngine engine;
@@ -26,7 +27,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<2x10x!HLFHE.eint<6>> {
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
@@ -54,7 +55,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>, %i: index, %j: index) -> !HLFHE.eint
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
for (size_t i = 0; i < dims[0]; i++) {
for (size_t j = 0; j < dims[1]; j++) {
// Set %i, %j
@@ -84,7 +85,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> {
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
@@ -113,7 +114,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> {
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
@@ -142,11 +143,12 @@ func @main(%t0: tensor<2x10x!HLFHE.eint<6>>, %t1: tensor<2x2x!HLFHE.eint<6>>) ->
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t0 argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims));
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
// Set the %t1 argument
uint64_t t1_dim[2] = {2, 2};
int64_t t1_dim[2] = {2, 2};
uint8_t t1[2][2]{{6, 9}, {4, 0}};
ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)t1, 2, t1_dim));
ASSERT_LLVM_ERROR(
argument->setArg(1, (uint8_t *)t1, llvm::ArrayRef<int64_t>(t1_dim, 2)));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result