style: format tests

This commit is contained in:
Mayeul@Zama
2022-03-09 17:41:09 +01:00
committed by mayeul-zama
parent a94b6fcabe
commit 0d7c3570cb
7 changed files with 184 additions and 179 deletions

View File

@@ -9,62 +9,55 @@ namespace clientlib = concretelang::clientlib;
TEST(Support, client_parameters_json_serde) {
clientlib::ClientParameters params0;
params0.secretKeys = {
{clientlib::SMALL_KEY, {/*.size = */ 12}},
{clientlib::BIG_KEY, {/*.size = */ 14}},
{clientlib::SMALL_KEY, {/*.size = */ 12}},
{clientlib::BIG_KEY, {/*.size = */ 14}},
};
params0.bootstrapKeys = {
{
"bsk_v0", {
/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
{"bsk_v0",
{/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.glweDimension = */ 3,
/*.variance = */ 0.001
}
},{
"wtf_bsk_v0", {
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ 3,
/*.baseLog = */ 2,
/*.glweDimension = */ 1,
/*.variance = */ 0.0001,
}
},
};
params0.keyswitchKeys = {
{
"ksk_v0", {
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.variance = */ 3,
}
}
/*.variance = */ 0.001}},
{"wtf_bsk_v0",
{
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ 3,
/*.baseLog = */ 2,
/*.glweDimension = */ 1,
/*.variance = */ 0.0001,
}},
};
params0.keyswitchKeys = {{"ksk_v0",
{
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.variance = */ 3,
}}};
params0.inputs = {
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.01, {4}}},
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
},
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.01, {4}}},
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
},
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
};
params0.outputs = {
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
};
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
};
auto json = clientlib::toJSON(params0);
std::string jsonStr;
llvm::raw_string_ostream os(jsonStr);
os << json;
auto parseResult =
llvm::json::parse<clientlib::ClientParameters>(jsonStr);
auto parseResult = llvm::json::parse<clientlib::ClientParameters>(jsonStr);
ASSERT_EXPECTED_VALUE(parseResult, params0);
}

View File

@@ -20,30 +20,22 @@ using namespace concretelang::testlib;
using concretelang::clientlib::scalar_in;
using concretelang::clientlib::scalar_out;
using concretelang::clientlib::tensor1_in;
using concretelang::clientlib::tensor2_in;
using concretelang::clientlib::tensor1_out;
using concretelang::clientlib::tensor2_in;
using concretelang::clientlib::tensor2_out;
using concretelang::clientlib::tensor3_out;
std::vector<uint8_t>
values_3bits() {
return {0, 1, 2, 5, 7};
}
std::vector<uint8_t>
values_6bits() {
return {0, 1, 2, 13, 22, 59, 62, 63};
}
std::vector<uint8_t>
values_7bits() {
return {0, 1, 2, 63, 64, 65, 125, 126};
}
std::vector<uint8_t> values_3bits() { return {0, 1, 2, 5, 7}; }
std::vector<uint8_t> values_6bits() { return {0, 1, 2, 13, 22, 59, 62, 63}; }
std::vector<uint8_t> values_7bits() { return {0, 1, 2, 63, 64, 65, 125, 126}; }
mlir::concretelang::CompilerEngine::Library
compile(std::string outputLib, std::string source, std::string funcname = FUNCNAME) {
compile(std::string outputLib, std::string source,
std::string funcname = FUNCNAME) {
std::vector<std::string> sources = {source};
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
mlir::concretelang::JitCompilerEngine ce {ccx};
mlir::concretelang::JitCompilerEngine ce{ccx};
ce.setClientParametersFuncName(funcname);
auto result = ce.compile(sources, outputLib);
assert(result);
@@ -53,13 +45,11 @@ compile(std::string outputLib, std::string source, std::string funcname = FUNCNA
static const std::string THIS_TEST_DIRECTORY = "tests/TestLib";
static const std::string OUT_DIRECTORY = THIS_TEST_DIRECTORY + "/out";
template<typename Info>
std::string outputLibFromThis(Info *info) {
template <typename Info> std::string outputLibFromThis(Info *info) {
return OUT_DIRECTORY + "/" + std::string(info->name());
}
template<typename Lambda>
Lambda load(std::string outputLib) {
template <typename Lambda> Lambda load(std::string outputLib) {
auto l = Lambda::load(FUNCNAME, outputLib, 0, 0, getTestKeySetCachePtr());
assert(l.has_value());
return l.value();
@@ -102,7 +92,7 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
for(auto a: values_7bits()) {
for (auto a : values_7bits()) {
auto res = lambda.call(a);
ASSERT_EQ_OUTCOME(res, a);
}
@@ -116,14 +106,16 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for(auto a: values_7bits()) for(auto b: values_7bits()) {
if (a > b) {
continue;
auto lambda =
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for (auto a : values_7bits())
for (auto b : values_7bits()) {
if (a > b) {
continue;
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, a);
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, a);
}
}
TEST(CompiledModule, call_2s_1s) {
@@ -135,14 +127,16 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for(auto a: values_7bits()) for(auto b: values_7bits()) {
if (a > b) {
continue;
auto lambda =
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for (auto a : values_7bits())
for (auto b : values_7bits()) {
if (a > b) {
continue;
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, a + b);
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, a + b);
}
}
TEST(CompiledModule, call_1s_1s_bad_call) {
@@ -169,7 +163,7 @@ func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<tensor1_out, scalar_in>>(outputLib);
for(auto a: values_7bits()) {
for (auto a : values_7bits()) {
auto res = lambda.call(a);
EXPECT_TRUE(res);
tensor1_out v = res.value();
@@ -186,9 +180,10 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> tensor<2x!FHE.eint<7>> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<tensor1_out, scalar_in, scalar_in>>(outputLib);
for(auto a : values_7bits()) {
auto res = lambda.call(a, a+1);
auto lambda =
load<TestTypedLambda<tensor1_out, scalar_in, scalar_in>>(outputLib);
for (auto a : values_7bits()) {
auto res = lambda.call(a, a + 1);
EXPECT_TRUE(res);
tensor1_out v = res.value();
EXPECT_EQ(v[0], (scalar_out)a);
@@ -207,7 +202,7 @@ func @main(%arg0: tensor<1x!FHE.eint<7>>) -> !FHE.eint<7> {
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, tensor1_in>>(outputLib);
for(uint8_t a : values_7bits()) {
for (uint8_t a : values_7bits()) {
tensor1_in ta = {a};
auto res = lambda.call(ta);
ASSERT_EQ_OUTCOME(res, a);
@@ -227,7 +222,7 @@ func @main(%arg0: tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> {
auto res = lambda.call(ta);
ASSERT_TRUE(res);
tensor1_out v = res.value();
for(size_t i = 0; i < v.size(); i++) {
for (size_t i = 0; i < v.size(); i++) {
EXPECT_EQ(v[i], ta[i]);
}
}
@@ -244,12 +239,14 @@ func @main(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !FHE
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, tensor1_in, std::array<uint8_t, 3>>>(outputLib);
tensor1_in ta {1, 2, 3};
std::array<uint8_t, 3> tb {5, 7, 9};
auto lambda =
load<TestTypedLambda<scalar_out, tensor1_in, std::array<uint8_t, 3>>>(
outputLib);
tensor1_in ta{1, 2, 3};
std::array<uint8_t, 3> tb{5, 7, 9};
auto res = lambda.call(ta, tb);
auto expected = std::accumulate(ta.begin(), ta.end(), 0u) +
std::accumulate(tb.begin(), tb.end(), 0u);
std::accumulate(tb.begin(), tb.end(), 0u);
ASSERT_EQ_OUTCOME(res, expected);
}
@@ -263,15 +260,12 @@ func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<tensor2_out, tensor2_in>>(outputLib);
tensor2_in ta = {{
{1, 2, 3},
{4, 5, 6}
}};
tensor2_in ta = {{{1, 2, 3}, {4, 5, 6}}};
auto res = lambda.call(ta);
ASSERT_TRUE(res);
tensor2_out v = res.value();
for(size_t i = 0; i < v.size(); i++) {
for(size_t j = 0; j < v.size(); j++) {
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v.size(); j++) {
EXPECT_EQ(v[i][j], ta[i][j]);
}
}
@@ -287,16 +281,13 @@ func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<tensor3_out, tensor3_in>>(outputLib);
tensor3_in ta = {{
{{ {1}, {2}, {3} }},
{{ {4}, {5}, {6} }}
}};
tensor3_in ta = {{{{{1}, {2}, {3}}}, {{{4}, {5}, {6}}}}};
auto res = lambda.call(ta);
ASSERT_TRUE(res);
tensor3_out v = res.value();
for(size_t i = 0; i < v.size(); i++) {
for(size_t j = 0; j < v[i].size(); j++) {
for(size_t k = 0; k < v[i][j].size(); k++) {
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) {
for (size_t k = 0; k < v[i][j].size(); k++) {
EXPECT_EQ(v[i][j][k], ta[i][j][k]);
}
}
@@ -313,17 +304,15 @@ func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>, %arg1: tensor<2x3x1x!FHE.eint<7>>)
using tensor3_in = std::array<std::array<std::array<uint8_t, 1>, 3>, 2>;
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<tensor3_out, tensor3_in, tensor3_in>>(outputLib);
tensor3_in ta = {{
{{ {1}, {2}, {3} }},
{{ {4}, {5}, {6} }}
}};
auto lambda =
load<TestTypedLambda<tensor3_out, tensor3_in, tensor3_in>>(outputLib);
tensor3_in ta = {{{{{1}, {2}, {3}}}, {{{4}, {5}, {6}}}}};
auto res = lambda.call(ta, ta);
ASSERT_TRUE(res);
tensor3_out v = res.value();
for(size_t i = 0; i < v.size(); i++) {
for(size_t j = 0; j < v[i].size(); j++) {
for(size_t k = 0; k < v[i][j].size(); k++) {
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) {
for (size_t k = 0; k < v[i][j].size(); k++) {
EXPECT_EQ(v[i][j][k], 2 * ta[i][j][k]);
}
}
@@ -353,8 +342,8 @@ func @extract(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !
std::string jsonPath = ClientParameters::getClientParametersPath(outputLib);
auto cLambda_ = extract::load(jsonPath);
ASSERT_TRUE(cLambda_);
tensor1_in ta {1, 2, 3};
tensor1_in tb {5, 7, 9};
tensor1_in ta{1, 2, 3};
tensor1_in tb{5, 7, 9};
auto sLambda_ = ServerLambda::load(extract::name, outputLib);
ASSERT_TRUE(sLambda_);
auto cLambda = cLambda_.value();
@@ -365,12 +354,12 @@ func @extract(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !
auto testLambda = TestTypedLambdaFrom(cLambda, sLambda, keySet);
auto res = testLambda.call(ta, tb);
auto expected = std::accumulate(ta.begin(), ta.end(), 0u) +
std::accumulate(tb.begin(), tb.end(), 0u);
std::accumulate(tb.begin(), tb.end(), 0u);
ASSERT_EQ_OUTCOME(res, expected);
EXPECT_EQ(
fileContent(THIS_TEST_DIRECTORY + "/call_2t_1s_with_header-client.h.generated"),
fileContent(OUT_DIRECTORY + "/call_2t_1s_with_header-client.h"));
EXPECT_EQ(fileContent(THIS_TEST_DIRECTORY +
"/call_2t_1s_with_header-client.h.generated"),
fileContent(OUT_DIRECTORY + "/call_2t_1s_with_header-client.h"));
}
TEST(CompiledModule, call_2s_1s_lookup_table) {
@@ -386,9 +375,11 @@ func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<3>) -> !FHE.eint<6> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for(auto a: values_6bits()) for(auto b: values_3bits()) {
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, a + b);
}
auto lambda =
load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for (auto a : values_6bits())
for (auto b : values_3bits()) {
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, a + b);
}
}

View File

@@ -10,7 +10,8 @@
///////////////////////////////////////////////////////////////////////////////
TEST(ParallelizeAndRunFHE, add_eint_tree) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %arg3: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
%2 = "FHE.add_eint"(%arg0, %arg2): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
@@ -54,7 +55,8 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %arg3:
%35 = "FHE.add_eint"(%33, %34): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %35: !FHE.eint<7>
}
)XXX", "main", false, true);
)XXX",
"main", false, true);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64, 3_u64, 4_u64), 150);
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64, 6_u64, 7_u64), 74);

View File

@@ -37,13 +37,14 @@ func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
}
TEST(End2EndJit_ClearTensor_1D, extract_64) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<10xi64>, %i: index) -> i64{
%c = tensor.extract %t[%i] : tensor<10xi64>
return %c : i64
}
)XXX",
"main", true);
"main", true);
uint64_t arg[]{0xFFFFFFFFFFFFFFFF,
0,
@@ -62,13 +63,14 @@ func @main(%t: tensor<10xi64>, %i: index) -> i64{
}
TEST(End2EndJit_ClearTensor_1D, extract_32) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
%c = tensor.extract %t[%i] : tensor<10xi32>
return %c : i32
}
)XXX",
"main", true);
"main", true);
uint32_t arg[]{0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
@@ -80,13 +82,14 @@ func @main(%t: tensor<10xi32>, %i: index) -> i32{
TEST(End2EndJit_ClearTensor_1D, extract_16) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<10xi16>, %i: index) -> i16{
%c = tensor.extract %t[%i] : tensor<10xi16>
return %c : i16
}
)XXX",
"main", true);
"main", true);
uint16_t arg[]{0xFFFF, 0, 59589, 47826, 16227,
63269, 36435, 52380, 7401, 13313};
@@ -98,13 +101,14 @@ func @main(%t: tensor<10xi16>, %i: index) -> i16{
TEST(End2EndJit_ClearTensor_1D, extract_8) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<10xi8>, %i: index) -> i8{
%c = tensor.extract %t[%i] : tensor<10xi8>
return %c : i8
}
)XXX",
"main", true);
"main", true);
uint8_t arg[]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
@@ -115,13 +119,14 @@ func @main(%t: tensor<10xi8>, %i: index) -> i8{
TEST(End2EndJit_ClearTensor_1D, extract_5) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<10xi5>, %i: index) -> i5{
%c = tensor.extract %t[%i] : tensor<10xi5>
return %c : i5
}
)XXX",
"main", true);
"main", true);
uint8_t arg[]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
@@ -132,13 +137,14 @@ func @main(%t: tensor<10xi5>, %i: index) -> i5{
TEST(End2EndJit_ClearTensor_1D, extract_1) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<10xi1>, %i: index) -> i1{
%c = tensor.extract %t[%i] : tensor<10xi1>
return %c : i1
}
)XXX",
"main", true);
"main", true);
uint8_t arg[]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
@@ -185,12 +191,13 @@ const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
TEST(End2EndJit_ClearTensor_2D, identity) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<2x10xi64>) -> tensor<2x10xi64> {
return %t : tensor<2x10xi64>
}
)XXX",
"main", true);
"main", true);
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>
@@ -210,13 +217,14 @@ func @main(%t: tensor<2x10xi64>) -> tensor<2x10xi64> {
TEST(End2EndJit_ClearTensor_2D, extract) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<2x10xi64>, %i: index, %j: index) -> i64 {
%c = tensor.extract %t[%i, %j] : tensor<2x10xi64>
return %c : i64
}
)XXX",
"main", true);
"main", true);
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>
@@ -233,13 +241,14 @@ func @main(%t: tensor<2x10xi64>, %i: index, %j: index) -> i64 {
TEST(End2EndJit_ClearTensor_2D, extract_slice) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
%r = tensor.extract_slice %t[1, 5][1, 5][1, 1] : tensor<2x10xi64> to
tensor<1x5xi64> return %r : tensor<1x5xi64>
}
)XXX",
"main", true);
"main", true);
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>
@@ -261,13 +270,14 @@ func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
TEST(End2EndJit_ClearTensor_2D, extract_slice_stride) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
%r = tensor.extract_slice %t[1, 0][1, 5][1, 2] : tensor<2x10xi64> to
tensor<1x5xi64> return %r : tensor<1x5xi64>
}
)XXX",
"main", true);
"main", true);
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>
@@ -289,13 +299,14 @@ func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
TEST(End2EndJit_ClearTensor_2D, insert_slice) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> {
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2xi64>
into tensor<2x10xi64> return %r : tensor<2x10xi64>
}
)XXX",
"main", true);
"main", true);
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>
@@ -339,10 +350,11 @@ void checkResultTensor(
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>>>());
mlir::concretelang::TensorLambdaArgument<mlir::concretelang::IntLambdaArgument<T>>
&resp = (*res)
->cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>>>();
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>> &resp =
(*res)
->cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>>>();
ASSERT_EQ(resp.getDimensions().size(), (size_t)3);
ASSERT_EQ(resp.getDimensions().at(0), 5);
@@ -374,7 +386,8 @@ TEST_P(ReturnTensorWithPrecision, return_tensor) {
checkedJit(mlirProgram.str(), "main", true);
llvm::Expected<std::unique_ptr<mlir::concretelang::LambdaArgument>> res =
lambda.operator()<std::unique_ptr<mlir::concretelang::LambdaArgument>>({});
lambda.operator()<std::unique_ptr<mlir::concretelang::LambdaArgument>>(
{});
ASSERT_EXPECTED_SUCCESS(res);
bool status;

View File

@@ -8,7 +8,8 @@
const mlir::concretelang::V0FHEConstraint defaultV0Constraints{10, 7};
TEST(CompileAndRunDFR, start_stop) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
func private @_dfr_stop()
func private @_dfr_start()
func @main() -> i64{
@@ -17,12 +18,14 @@ func @main() -> i64{
call @_dfr_stop() : () -> ()
return %1 : i64
}
)XXX", "main", true);
)XXX",
"main", true);
ASSERT_EXPECTED_VALUE(lambda(), 7);
}
TEST(CompileAndRunDFR, 0in1out_task) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
llvm.func @_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>> attributes {sym_visibility = "private"}
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"}
llvm.func @_dfr_stop()
@@ -52,12 +55,14 @@ TEST(CompileAndRunDFR, 0in1out_task) {
llvm.store %2, %arg0 : !llvm.ptr<i64>
llvm.return
}
)XXX", "main", true);
)XXX",
"main", true);
ASSERT_EXPECTED_VALUE(lambda(), 7);
}
TEST(CompileAndRunDFR, 1in1out_task) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
llvm.func @_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>> attributes {sym_visibility = "private"}
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"}
llvm.func @malloc(i64) -> !llvm.ptr<i8>
@@ -96,13 +101,15 @@ TEST(CompileAndRunDFR, 1in1out_task) {
llvm.store %2, %arg1 : !llvm.ptr<i64>
llvm.return
}
)XXX", "main", true);
)XXX",
"main", true);
ASSERT_EXPECTED_VALUE(lambda(5_u64), 7);
}
TEST(CompileAndRunDFR, 2in1out_task) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
llvm.func @_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>> attributes {sym_visibility = "private"}
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"}
llvm.func @malloc(i64) -> !llvm.ptr<i8>
@@ -150,15 +157,15 @@ TEST(CompileAndRunDFR, 2in1out_task) {
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
)XXX", "main", true);
)XXX",
"main", true);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 6_u64), 7);
}
TEST(CompileAndRunDFR, taskgraph) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
mlir::concretelang::JitCompilerEngine::Lambda lambda =
checkedJit(R"XXX(
llvm.func @_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<ptr<i64>> attributes {sym_visibility = "private"}
llvm.func @_dfr_create_async_task(...) attributes {sym_visibility = "private"}
llvm.func @malloc(i64) -> !llvm.ptr<i8>
@@ -340,7 +347,8 @@ TEST(CompileAndRunDFR, taskgraph) {
llvm.store %2, %arg2 : !llvm.ptr<i64>
llvm.return
}
)XXX", "main", true);
)XXX",
"main", true);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64, 3_u64), 54);
ASSERT_EXPECTED_VALUE(lambda(2_u64, 5_u64, 1_u64), 72);

View File

@@ -4,7 +4,6 @@
using Lambda = mlir::concretelang::JitCompilerEngine::Lambda;
TEST(Lambda_check_param, int_to_void_missing_param) {
Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !FHE.eint<1>) {
@@ -68,7 +67,7 @@ TEST(Lambda_check_param, scalar_tensor_to_scalar) {
return %arg0: !FHE.eint<1>
}
)XXX");
uint8_t arg[2] = {1 ,2};
uint8_t arg[2] = {1, 2};
ASSERT_EXPECTED_SUCCESS(lambda(1_u64, arg, ARRAY_SIZE(arg)));
}
@@ -80,8 +79,9 @@ TEST(Lambda_check_param, scalar_tensor_to_scalar_superfluous_param) {
return %arg0: !FHE.eint<1>
}
)XXX");
uint8_t arg[2] = {1 ,2};
ASSERT_EXPECTED_FAILURE(lambda(1_u64, arg, ARRAY_SIZE(arg), arg, ARRAY_SIZE(arg)));
uint8_t arg[2] = {1, 2};
ASSERT_EXPECTED_FAILURE(
lambda(1_u64, arg, ARRAY_SIZE(arg), arg, ARRAY_SIZE(arg)));
}
TEST(Lambda_check_param, scalar_tensor_to_tensor_good_number_param) {
@@ -92,10 +92,9 @@ TEST(Lambda_check_param, scalar_tensor_to_tensor_good_number_param) {
return %arg1: tensor<2x!FHE.eint<1>>
}
)XXX");
uint8_t arg[2] = {1 ,2};
uint8_t arg[2] = {1, 2};
ASSERT_EXPECTED_SUCCESS(
lambda.operator()<std::vector<uint8_t>>(1_u64, arg, ARRAY_SIZE(arg))
);
lambda.operator()<std::vector<uint8_t>>(1_u64, arg, ARRAY_SIZE(arg)));
}
TEST(Lambda_check_param, DISABLED_check_parameters_scalar_too_big) {

View File

@@ -94,13 +94,14 @@ static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
} while (0)
#define ASSERT_EQ_OUTCOME(val, exp) \
if(!val.has_value()) { \
if (!val.has_value()) { \
std::string msg = "ERROR: <" + val.error().mesg + "> \n"; \
GTEST_FATAL_FAILURE_(msg.c_str()); \
}; \
ASSERT_EQ(val.value(), exp);
static inline llvm::Optional<mlir::concretelang::KeySetCache> getTestKeySetCache() {
static inline llvm::Optional<mlir::concretelang::KeySetCache>
getTestKeySetCache() {
llvm::SmallString<0> cachePath;
llvm::sys::path::system_temp_directory(true, cachePath);
@@ -111,20 +112,19 @@ static inline llvm::Optional<mlir::concretelang::KeySetCache> getTestKeySetCache
mlir::concretelang::KeySetCache(cachePathStr));
}
static inline std::shared_ptr<mlir::concretelang::KeySetCache> getTestKeySetCachePtr() {
static inline std::shared_ptr<mlir::concretelang::KeySetCache>
getTestKeySetCachePtr() {
return std::make_shared<mlir::concretelang::KeySetCache>(
getTestKeySetCache().getValue());
getTestKeySetCache().getValue());
}
// Jit-compiles the function specified by `func` from `src` and
// returns the corresponding lambda. Any compilation errors are caught
// and reult in abnormal termination.
template <typename F>
mlir::concretelang::JitCompilerEngine::Lambda
internalCheckedJit(F checkFunc, llvm::StringRef src,
llvm::StringRef func = "main",
bool useDefaultFHEConstraints = false,
bool autoParallelize = false) {
mlir::concretelang::JitCompilerEngine::Lambda internalCheckedJit(
F checkFunc, llvm::StringRef src, llvm::StringRef func = "main",
bool useDefaultFHEConstraints = false, bool autoParallelize = false) {
mlir::concretelang::JitCompilerEngine engine;
@@ -162,9 +162,8 @@ static inline uint64_t operator"" _u64(unsigned long long int v) { return v; }
// caller instead of `internalCheckedJit`.
#define checkedJit(...) \
internalCheckedJit( \
[](llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> &lambda) { \
ASSERT_EXPECTED_SUCCESS(lambda); \
}, \
[](llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> \
&lambda) { ASSERT_EXPECTED_SUCCESS(lambda); }, \
__VA_ARGS__)
#endif