From 01757fa6d550637862d7da7687fdf21c5d10fa29 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 11 Oct 2021 10:37:00 +0100 Subject: [PATCH] fix: forward errors instead of creating new ones LLVM errors should be handled/consumed. Creating a new one and leaving the previous one alive will crash the compiler. Whenever we don't want a crash (e.g. logging the error is enough), but still wanna continue the execution, we can just consume it. --- compiler/lib/Support/CompilerEngine.cpp | 9 ++-- compiler/lib/Support/Jit.cpp | 4 ++ compiler/tests/python/test_compiler_engine.py | 45 +++++++++++++++++++ compiler/tests/python/test_hlfhe_dialect.py | 9 ++++ compiler/tests/python/test_round_trip.py | 4 +- 5 files changed, 63 insertions(+), 8 deletions(-) diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index a3c27430f..5a41c7e46 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -91,14 +91,12 @@ llvm::Error CompilerEngine::compile( auto clientParameter = mlir::zamalang::createClientParametersForV0( fheContext, "main", module_ref.get()); if (auto err = clientParameter.takeError()) { - return llvm::make_error( - "cannot generate client parameters", llvm::inconvertibleErrorCode()); + return std::move(err); } auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0); if (auto err = maybeKeySet.takeError()) { - return llvm::make_error("cannot generate keyset", - llvm::inconvertibleErrorCode()); + return std::move(err); } keySet = std::move(maybeKeySet.get()); @@ -148,8 +146,7 @@ llvm::Expected CompilerEngine::run(std::vector args) { auto arguments = std::move(maybeArgument.get()); for (auto i = 0; i < args.size(); i++) { if (auto err = arguments->setArg(i, args[i])) { - return llvm::make_error( - "cannot push argument", llvm::inconvertibleErrorCode()); + return std::move(err); } } // Invoke the lambda diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 779e5bfdb..95be53411 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -32,6 +32,7 @@ runJit(mlir::ModuleOp module, llvm::StringRef func, if (auto err = maybeArguments.takeError()) { ::mlir::zamalang::log_error() << "Cannot create lambda arguments: " << err << "\n"; + llvm::consumeError(std::move(err)); return mlir::failure(); } @@ -41,17 +42,20 @@ runJit(mlir::ModuleOp module, llvm::StringRef func, if (auto err = arguments->setArg(i, funcArgs[i])) { ::mlir::zamalang::log_error() << "Cannot push argument " << i << ": " << err << "\n"; + llvm::consumeError(std::move(err)); return mlir::failure(); } } // Invoke the lambda if (auto err = lambda->invoke(*arguments)) { ::mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n"; + llvm::consumeError(std::move(err)); return mlir::failure(); } uint64_t res = 0; if (auto err = arguments->getResult(0, res)) { ::mlir::zamalang::log_error() << "Cannot get result : " << err << "\n"; + llvm::consumeError(std::move(err)); return mlir::failure(); } llvm::errs() << res << "\n"; diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 81e275f35..6e8bfb33a 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -37,6 +37,29 @@ def test_compile_and_run(mlir_input, args, expected_result): assert engine.run(*args) == expected_result + +@pytest.mark.parametrize( + "mlir_input, args", + [ + pytest.param( + """ + func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> { + %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> + } + """, + (5, 7, 8), + id="add_eint_int_invalid_arg_number" + ), + ], +) +def test_compile_and_run_invalid_arg_number(mlir_input, args): + engine = CompilerEngine() + engine.compile_fhe(mlir_input) + with pytest.raises(RuntimeError, match=r"failed pushing integer argument"): + engine.run(*args) + + @pytest.mark.parametrize( "mlir_input, args, expected_result, tab_size", [ @@ -59,3 +82,25 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size): engine = CompilerEngine() engine.compile_fhe(mlir_input) assert abs(engine.run(*args) - expected_result) / tab_size < 0.1 + + +@pytest.mark.parametrize( + "mlir_input", + [ + pytest.param( + """ + func @test(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7> + { + %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) : + (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7> + return %ret : !HLFHE.eint<7> + } + """, + id="not @main" + ), + ], +) +def test_compile_invalid(mlir_input): + engine = CompilerEngine() + with pytest.raises(RuntimeError, match=r"failed compiling"): + engine.compile_fhe(mlir_input) diff --git a/compiler/tests/python/test_hlfhe_dialect.py b/compiler/tests/python/test_hlfhe_dialect.py index a95e6265b..3f359d21d 100644 --- a/compiler/tests/python/test_hlfhe_dialect.py +++ b/compiler/tests/python/test_hlfhe_dialect.py @@ -1,3 +1,4 @@ +import pytest from mlir.ir import Context from zamalang import register_dialects from zamalang.dialects import hlfhe @@ -8,3 +9,11 @@ def test_eint(): register_dialects(ctx) eint = hlfhe.EncryptedIntegerType.get(ctx, 6) assert eint.__str__() == "!HLFHE.eint<6>" + + +# FIXME: need to handle error on call to hlfhe.EncryptedIntegerType.get and throw an exception to python +# def test_invalid_eint(): +# ctx = Context() +# register_dialects(ctx) +# with pytest.raises(RuntimeError, match=r"mlir parsing failed"): +# eint = hlfhe.EncryptedIntegerType.get(ctx, 16) diff --git a/compiler/tests/python/test_round_trip.py b/compiler/tests/python/test_round_trip.py index e523ec17a..6aa1d645a 100644 --- a/compiler/tests/python/test_round_trip.py +++ b/compiler/tests/python/test_round_trip.py @@ -47,14 +47,14 @@ VALID_INPUTS = [ ] INVALID_INPUTS = [ - pytest.param("nothing really mlir", id="add_eint_int_cst"), + pytest.param("nothing really mlir", id="english sentence"), pytest.param( """ func @test(%arg0: !HLFHE.eint<0>) { return } """, - id="add_eint_int_cst", + id="eint<0>", ), ]