mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
enable_testing()
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
|
||||
add_executable(
|
||||
hello_test
|
||||
hello_test.cc
|
||||
@@ -7,6 +10,7 @@ add_executable(
|
||||
target_link_libraries(
|
||||
hello_test
|
||||
gtest_main
|
||||
ZamalangSupport
|
||||
)
|
||||
|
||||
include(GoogleTest)
|
||||
|
||||
@@ -1,9 +1,347 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
// Demonstrate some basic assertions.
|
||||
TEST(HelloTest, BasicAssertions) {
|
||||
// Expect two strings not to be equal.
|
||||
EXPECT_STRNE("hello", "world");
|
||||
// Expect equality.
|
||||
EXPECT_EQ(7 * 6, 42);
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
|
||||
#define ASSERT_LLVM_ERROR(err) \
|
||||
if (err) { \
|
||||
llvm::errs() << "error: " << std::move(err) << "\n"; \
|
||||
ASSERT_TRUE(false); \
|
||||
}
|
||||
|
||||
TEST(CompileAndRunHLFHE, add_eint) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_FALSE(engine.compile(mlirStr));
|
||||
auto maybeResult = engine.run({1, 2});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
uint64_t result = maybeResult.get();
|
||||
ASSERT_EQ(result, 3);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_64) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi64>, %i: index) -> i64{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi64>
|
||||
return %c : i64
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_32) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi32>, %i: index) -> i32{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi32>
|
||||
return %c : i32
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90,
|
||||
197864, 698735, 72132, 87474, 42};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_16) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi16>, %i: index) -> i16{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi16>
|
||||
return %c : i16
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227,
|
||||
63269, 36435, 52380, 7401, 13313};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_8) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi8>, %i: index) -> i8{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi8>
|
||||
return %c : i8
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi5>, %i: index) -> i5{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi5>
|
||||
return %c : i5
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_1) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10xi1>, %i: index) -> i1{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi1>
|
||||
return %c : i1
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, extract_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
|
||||
%c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
|
||||
return %c : !HLFHE.eint<5>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{
|
||||
%ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
|
||||
%tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>>
|
||||
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5>
|
||||
return %c : !HLFHE.eint<5>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Set the %j argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(2, j));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i] + t_arg[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, dim_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{
|
||||
%c0 = constant 0 : index
|
||||
%c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>>
|
||||
return %c : index
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, size);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, from_elements_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
|
||||
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
|
||||
return %t: tensor<1x!HLFHE.eint<5>>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, 10));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
size_t size_res = 1;
|
||||
uint64_t t_res[size_res];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
|
||||
ASSERT_EQ(t_res[0], 10);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
|
||||
%c_0 = constant 0 : index
|
||||
%c_1 = constant 1 : index
|
||||
%a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>>
|
||||
%b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>>
|
||||
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
|
||||
return %out: tensor<3x!HLFHE.eint<5>>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the argument
|
||||
const size_t in_size = 2;
|
||||
uint8_t in[in_size] = {2, 16};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
const size_t size_res = 3;
|
||||
uint64_t t_res[size_res];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
|
||||
ASSERT_EQ(t_res[0], in[0] + in[0]);
|
||||
ASSERT_EQ(t_res[1], in[0] + in[1]);
|
||||
ASSERT_EQ(t_res[2], in[1] + in[1]);
|
||||
}
|
||||
Reference in New Issue
Block a user