diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 6349279ad..ca97c32cb 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -31,48 +31,16 @@ public: // Set a scalar argument at the given pos as a uint64_t. llvm::Error setArg(size_t pos, uint64_t arg); - // Set a argument at the given pos as a 1D tensor of int64. - llvm::Error setArg(size_t pos, uint64_t *data, size_t dim1) { - return setArg(pos, 64, (void *)data, 1, &dim1); + // Set a argument at the given pos as a 1D tensor of T. + template + llvm::Error setArg(size_t pos, T *data, int64_t dim1) { + return setArg(pos, data, llvm::ArrayRef(&dim1, 1)); } - // Set a argument at the given pos as a 1D tensor of int32. - llvm::Error setArg(size_t pos, uint32_t *data, size_t dim1) { - return setArg(pos, 32, (void *)data, 1, &dim1); - } - - // Set a argument at the given pos as a 1D tensor of int16. - llvm::Error setArg(size_t pos, uint16_t *data, size_t dim1) { - return setArg(pos, 16, (void *)data, 1, &dim1); - } - - // Set a argument at the given pos as a 1D tensor of int8. - llvm::Error setArg(size_t pos, uint8_t *data, size_t dim1) { - return setArg(pos, 8, (void *)data, 1, &dim1); - } - - // Set a argument at the given pos as a tensor of int64. - llvm::Error setArg(size_t pos, uint64_t *data, size_t numDim, - const size_t *dims) { - return setArg(pos, 64, (void *)data, numDim, dims); - } - - // Set a argument at the given pos as a tensor of int32. - llvm::Error setArg(size_t pos, uint32_t *data, size_t numDim, - const size_t *dims) { - return setArg(pos, 32, (void *)data, numDim, dims); - } - - // Set a argument at the given pos as a tensor of int32. - llvm::Error setArg(size_t pos, uint16_t *data, size_t numDim, - const size_t *dims) { - return setArg(pos, 16, (void *)data, numDim, dims); - } - - // Set a tensor argument at the given pos as a uint64_t. - llvm::Error setArg(size_t pos, uint8_t *data, size_t numDim, - const size_t *dims) { - return setArg(pos, 8, (void *)data, numDim, dims); + // Set a argument at the given pos as a tensor of T. + template + llvm::Error setArg(size_t pos, T *data, llvm::ArrayRef shape) { + return setArg(pos, 8 * sizeof(T), static_cast(data), shape); } // Get the result at the given pos as an uint64_t. @@ -86,8 +54,8 @@ public: llvm::Error getResult(size_t pos, uint64_t *res, size_t size); private: - llvm::Error setArg(size_t pos, size_t width, void *data, size_t numDim, - const size_t *dims); + llvm::Error setArg(size_t pos, size_t width, void *data, + llvm::ArrayRef shape); friend JITLambda; // Store the pointer on inputs values and outputs values diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 50a679946..614bd752c 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -230,7 +230,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { } llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, - size_t numDim, const size_t *dims) { + llvm::ArrayRef shape) { auto gate = inputGates[pos]; auto info = std::get<0>(gate); auto offset = std::get<1>(gate); @@ -272,25 +272,25 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)), llvm::inconvertibleErrorCode()); } - if (numDim != info.shape.dimensions.size()) { + if (shape.size() != info.shape.dimensions.size()) { return llvm::make_error( llvm::Twine("tensor argument #") .concat(llvm::Twine(pos)) .concat(" has not the expected number of dimension, got ") - .concat(llvm::Twine(numDim)) + .concat(llvm::Twine(shape.size())) .concat(" expected ") .concat(llvm::Twine(info.shape.dimensions.size())), llvm::inconvertibleErrorCode()); } - for (size_t i = 0; i < numDim; i++) { - if (dims[i] != info.shape.dimensions[i]) { + for (size_t i = 0; i < shape.size(); i++) { + if (shape[i] != info.shape.dimensions[i]) { return llvm::make_error( llvm::Twine("tensor argument #") .concat(llvm::Twine(pos)) .concat(" has not the expected dimension #") .concat(llvm::Twine(i)) .concat(" , got ") - .concat(llvm::Twine(dims[i])) + .concat(llvm::Twine(shape[i])) .concat(" expected ") .concat(llvm::Twine(info.shape.dimensions[i])), llvm::inconvertibleErrorCode()); @@ -341,13 +341,13 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, rawArg[offset] = &inputs[offset]; offset++; // sizes is an array of size equals to numDim - for (size_t i = 0; i < numDim; i++) { - inputs[offset] = (void *)dims[i]; + for (size_t i = 0; i < shape.size(); i++) { + inputs[offset] = (void *)shape[i]; rawArg[offset] = &inputs[offset]; offset++; } // strides is an array of size equals to numDim - for (size_t i = 0; i < numDim; i++) { + for (size_t i = 0; i < shape.size(); i++) { inputs[offset] = (void *)0; rawArg[offset] = &inputs[offset]; offset++; @@ -411,12 +411,12 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, if (!info.encryption.hasValue()) { // just copy values for (size_t i = 0; i < size; i++) { - res[i] = ((uint64_t *)(aligned))[i]; + res[i] = ((uint64_t *)aligned)[i]; } } else { // decrypt and fill the result buffer for (size_t i = 0; i < size; i++) { - LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i]; + LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)aligned)[i]; if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) { return std::move(err); } diff --git a/compiler/tests/unittest/end_to_end_jit_clear_tensor.cc b/compiler/tests/unittest/end_to_end_jit_clear_tensor.cc index aabecf557..749948964 100644 --- a/compiler/tests/unittest/end_to_end_jit_clear_tensor.cc +++ b/compiler/tests/unittest/end_to_end_jit_clear_tensor.cc @@ -222,14 +222,15 @@ func @main(%t: tensor<10xi1>, %i: index) -> i1{ /////////////////////////////////////////////////////////////////////////////// const size_t numDim = 2; -const size_t dim0 = 2; -const size_t dim1 = 10; -const size_t dims[numDim]{dim0, dim1}; +const int64_t dim0 = 2; +const int64_t dim1 = 10; +const int64_t dims[numDim]{dim0, dim1}; const uint64_t tensor2D[dim0][dim1]{ {0xFFFFFFFFFFFFFFFF, 0, 8978, 2587490, 90, 197864, 698735, 72132, 87474, 42}, {986, 1873, 298493, 34939, 443, 59874, 43, 743, 8409, 9433}, }; +const llvm::ArrayRef shape2D(dims, numDim); TEST(End2EndJit_ClearTensor_2D, identity) { mlir::zamalang::CompilerEngine engine; @@ -244,7 +245,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<2x10xi64> { ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D)); // Invoke the function ASSERT_LLVM_ERROR(engine.invoke(*argument)); // Get and assert the result @@ -272,7 +273,7 @@ func @main(%t: tensor<2x10xi64>, %i: index, %j: index) -> i64 { ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D)); for (size_t i = 0; i < dims[0]; i++) { for (size_t j = 0; j < dims[1]; j++) { // Set %i, %j @@ -302,7 +303,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> { ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D)); // Invoke the function ASSERT_LLVM_ERROR(engine.invoke(*argument)); // Get and assert the result @@ -331,7 +332,7 @@ func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> { ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D)); // Invoke the function ASSERT_LLVM_ERROR(engine.invoke(*argument)); // Get and assert the result @@ -360,11 +361,12 @@ func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> { ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t0 argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D)); // Set the %t1 argument - uint64_t t1_dim[2] = {2, 2}; + int64_t t1_dim[2] = {2, 2}; uint64_t t1[2][2]{{6, 9}, {4, 0}}; - ASSERT_LLVM_ERROR(argument->setArg(1, (uint64_t *)t1, 2, t1_dim)); + ASSERT_LLVM_ERROR( + argument->setArg(1, (uint64_t *)t1, llvm::ArrayRef(t1_dim, 2))); // Invoke the function ASSERT_LLVM_ERROR(engine.invoke(*argument)); // Get and assert the result diff --git a/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc b/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc index 20ea56ef1..c10774ba8 100644 --- a/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc +++ b/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc @@ -5,13 +5,14 @@ /////////////////////////////////////////////////////////////////////////////// const size_t numDim = 2; -const size_t dim0 = 2; -const size_t dim1 = 10; -const size_t dims[numDim]{dim0, dim1}; +const int64_t dim0 = 2; +const int64_t dim1 = 10; +const int64_t dims[numDim]{dim0, dim1}; const uint8_t tensor2D[dim0][dim1]{ {63, 12, 7, 43, 52, 9, 26, 34, 22, 0}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, }; +const llvm::ArrayRef shape2D(dims, numDim); TEST(End2EndJit_EncryptedTensor_2D, identity) { mlir::zamalang::CompilerEngine engine; @@ -26,7 +27,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<2x10x!HLFHE.eint<6>> { ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D)); // Invoke the function ASSERT_LLVM_ERROR(engine.invoke(*argument)); // Get and assert the result @@ -54,7 +55,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>, %i: index, %j: index) -> !HLFHE.eint ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D)); for (size_t i = 0; i < dims[0]; i++) { for (size_t j = 0; j < dims[1]; j++) { // Set %i, %j @@ -84,7 +85,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> { ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D)); // Invoke the function ASSERT_LLVM_ERROR(engine.invoke(*argument)); // Get and assert the result @@ -113,7 +114,7 @@ func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> { ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D)); // Invoke the function ASSERT_LLVM_ERROR(engine.invoke(*argument)); // Get and assert the result @@ -142,11 +143,12 @@ func @main(%t0: tensor<2x10x!HLFHE.eint<6>>, %t1: tensor<2x2x!HLFHE.eint<6>>) -> ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); // Set the %t0 argument - ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, numDim, dims)); + ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D)); // Set the %t1 argument - uint64_t t1_dim[2] = {2, 2}; + int64_t t1_dim[2] = {2, 2}; uint8_t t1[2][2]{{6, 9}, {4, 0}}; - ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)t1, 2, t1_dim)); + ASSERT_LLVM_ERROR( + argument->setArg(1, (uint8_t *)t1, llvm::ArrayRef(t1_dim, 2))); // Invoke the function ASSERT_LLVM_ERROR(engine.invoke(*argument)); // Get and assert the result