From 6204f93878062c92b3ef0b3fb450bcd760585255 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 12 Oct 2021 15:45:13 +0100 Subject: [PATCH] fix: call `getChecked` to better handle failure --- compiler/include/zamalang-c/Dialect/HLFHE.h | 7 +++++-- compiler/lib/Bindings/Python/HLFHEModule.cpp | 15 +++++++++++---- compiler/lib/CAPI/Dialect/HLFHE.cpp | 6 ++++-- compiler/tests/python/test_hlfhe_dialect.py | 19 ++++++++++--------- 4 files changed, 30 insertions(+), 17 deletions(-) diff --git a/compiler/include/zamalang-c/Dialect/HLFHE.h b/compiler/include/zamalang-c/Dialect/HLFHE.h index 512a6f763..f0761a5ab 100644 --- a/compiler/include/zamalang-c/Dialect/HLFHE.h +++ b/compiler/include/zamalang-c/Dialect/HLFHE.h @@ -3,6 +3,8 @@ #include "mlir-c/IR.h" #include "mlir-c/Registration.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LLVM.h" #ifdef __cplusplus extern "C" { @@ -11,8 +13,9 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe); /// Creates an encrypted integer type of `width` bits -MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGet(MlirContext context, - unsigned width); +MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGetChecked( + MlirContext context, unsigned width, + mlir::function_ref emitError); /// If the type is an EncryptedInteger MLIR_CAPI_EXPORTED bool hlfheTypeIsAnEncryptedIntegerType(MlirType); diff --git a/compiler/lib/Bindings/Python/HLFHEModule.cpp b/compiler/lib/Bindings/Python/HLFHEModule.cpp index 35d49bfde..0b65b3849 100644 --- a/compiler/lib/Bindings/Python/HLFHEModule.cpp +++ b/compiler/lib/Bindings/Python/HLFHEModule.cpp @@ -4,6 +4,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/IR/Diagnostics.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/raw_ostream.h" @@ -21,8 +22,14 @@ void mlir::zamalang::python::populateDialectHLFHESubmodule( mlir_type_subclass(m, "EncryptedIntegerType", hlfheTypeIsAnEncryptedIntegerType) - .def_classmethod( - "get", [](pybind11::object cls, MlirContext ctx, unsigned width) { - return cls(hlfheEncryptedIntegerTypeGet(ctx, width)); - }); + .def_classmethod("get", [](pybind11::object cls, MlirContext ctx, + unsigned width) { + // We want the user to receive a python exception for not being able to + // create the eint + auto emitException = []() -> mlir::InFlightDiagnostic { + throw std::invalid_argument("can't create eint with the given width"); + }; + return cls( + hlfheEncryptedIntegerTypeGetChecked(ctx, width, emitException)); + }); } \ No newline at end of file diff --git a/compiler/lib/CAPI/Dialect/HLFHE.cpp b/compiler/lib/CAPI/Dialect/HLFHE.cpp index 3557646ee..8f7a87239 100644 --- a/compiler/lib/CAPI/Dialect/HLFHE.cpp +++ b/compiler/lib/CAPI/Dialect/HLFHE.cpp @@ -22,6 +22,8 @@ bool hlfheTypeIsAnEncryptedIntegerType(MlirType type) { return unwrap(type).isa(); } -MlirType hlfheEncryptedIntegerTypeGet(MlirContext ctx, unsigned width) { - return wrap(EncryptedIntegerType::get(unwrap(ctx), width)); +MlirType hlfheEncryptedIntegerTypeGetChecked( + MlirContext ctx, unsigned width, + mlir::function_ref emitError) { + return wrap(EncryptedIntegerType::getChecked(emitError, unwrap(ctx), width)); } diff --git a/compiler/tests/python/test_hlfhe_dialect.py b/compiler/tests/python/test_hlfhe_dialect.py index 3f359d21d..311486c28 100644 --- a/compiler/tests/python/test_hlfhe_dialect.py +++ b/compiler/tests/python/test_hlfhe_dialect.py @@ -4,16 +4,17 @@ from zamalang import register_dialects from zamalang.dialects import hlfhe -def test_eint(): +@pytest.mark.parametrize("width", list(range(1, 8))) +def test_eint(width): ctx = Context() register_dialects(ctx) - eint = hlfhe.EncryptedIntegerType.get(ctx, 6) - assert eint.__str__() == "!HLFHE.eint<6>" + eint = hlfhe.EncryptedIntegerType.get(ctx, width) + assert eint.__str__() == f"!HLFHE.eint<{width}>" -# FIXME: need to handle error on call to hlfhe.EncryptedIntegerType.get and throw an exception to python -# def test_invalid_eint(): -# ctx = Context() -# register_dialects(ctx) -# with pytest.raises(RuntimeError, match=r"mlir parsing failed"): -# eint = hlfhe.EncryptedIntegerType.get(ctx, 16) +@pytest.mark.parametrize("width", [0, 8, 10, 12]) +def test_invalid_eint(width): + ctx = Context() + register_dialects(ctx) + with pytest.raises(ValueError, match=r"can't create eint with the given width"): + eint = hlfhe.EncryptedIntegerType.get(ctx, width)