fix: call getChecked to better handle failure

This commit is contained in:
youben11
2021-10-12 15:45:13 +01:00
committed by Quentin Bourgerie
parent 33d75a92f4
commit 6204f93878
4 changed files with 30 additions and 17 deletions

View File

@@ -3,6 +3,8 @@
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
#ifdef __cplusplus
extern "C" {
@@ -11,8 +13,9 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe);
/// Creates an encrypted integer type of `width` bits
MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGet(MlirContext context,
unsigned width);
MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGetChecked(
MlirContext context, unsigned width,
mlir::function_ref<mlir::InFlightDiagnostic()> emitError);
/// If the type is an EncryptedInteger
MLIR_CAPI_EXPORTED bool hlfheTypeIsAnEncryptedIntegerType(MlirType);

View File

@@ -4,6 +4,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"
@@ -21,8 +22,14 @@ void mlir::zamalang::python::populateDialectHLFHESubmodule(
mlir_type_subclass(m, "EncryptedIntegerType",
hlfheTypeIsAnEncryptedIntegerType)
.def_classmethod(
"get", [](pybind11::object cls, MlirContext ctx, unsigned width) {
return cls(hlfheEncryptedIntegerTypeGet(ctx, width));
});
.def_classmethod("get", [](pybind11::object cls, MlirContext ctx,
unsigned width) {
// We want the user to receive a python exception for not being able to
// create the eint
auto emitException = []() -> mlir::InFlightDiagnostic {
throw std::invalid_argument("can't create eint with the given width");
};
return cls(
hlfheEncryptedIntegerTypeGetChecked(ctx, width, emitException));
});
}

View File

@@ -22,6 +22,8 @@ bool hlfheTypeIsAnEncryptedIntegerType(MlirType type) {
return unwrap(type).isa<EncryptedIntegerType>();
}
MlirType hlfheEncryptedIntegerTypeGet(MlirContext ctx, unsigned width) {
return wrap(EncryptedIntegerType::get(unwrap(ctx), width));
MlirType hlfheEncryptedIntegerTypeGetChecked(
MlirContext ctx, unsigned width,
mlir::function_ref<mlir::InFlightDiagnostic()> emitError) {
return wrap(EncryptedIntegerType::getChecked(emitError, unwrap(ctx), width));
}

View File

@@ -4,16 +4,17 @@ from zamalang import register_dialects
from zamalang.dialects import hlfhe
def test_eint():
@pytest.mark.parametrize("width", list(range(1, 8)))
def test_eint(width):
ctx = Context()
register_dialects(ctx)
eint = hlfhe.EncryptedIntegerType.get(ctx, 6)
assert eint.__str__() == "!HLFHE.eint<6>"
eint = hlfhe.EncryptedIntegerType.get(ctx, width)
assert eint.__str__() == f"!HLFHE.eint<{width}>"
# FIXME: need to handle error on call to hlfhe.EncryptedIntegerType.get and throw an exception to python
# def test_invalid_eint():
# ctx = Context()
# register_dialects(ctx)
# with pytest.raises(RuntimeError, match=r"mlir parsing failed"):
# eint = hlfhe.EncryptedIntegerType.get(ctx, 16)
@pytest.mark.parametrize("width", [0, 8, 10, 12])
def test_invalid_eint(width):
ctx = Context()
register_dialects(ctx)
with pytest.raises(ValueError, match=r"can't create eint with the given width"):
eint = hlfhe.EncryptedIntegerType.get(ctx, width)