mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix: call getChecked to better handle failure
This commit is contained in:
committed by
Quentin Bourgerie
parent
33d75a92f4
commit
6204f93878
@@ -3,6 +3,8 @@
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@@ -11,8 +13,9 @@ extern "C" {
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe);
|
||||
|
||||
/// Creates an encrypted integer type of `width` bits
|
||||
MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGet(MlirContext context,
|
||||
unsigned width);
|
||||
MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGetChecked(
|
||||
MlirContext context, unsigned width,
|
||||
mlir::function_ref<mlir::InFlightDiagnostic()> emitError);
|
||||
|
||||
/// If the type is an EncryptedInteger
|
||||
MLIR_CAPI_EXPORTED bool hlfheTypeIsAnEncryptedIntegerType(MlirType);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
@@ -21,8 +22,14 @@ void mlir::zamalang::python::populateDialectHLFHESubmodule(
|
||||
|
||||
mlir_type_subclass(m, "EncryptedIntegerType",
|
||||
hlfheTypeIsAnEncryptedIntegerType)
|
||||
.def_classmethod(
|
||||
"get", [](pybind11::object cls, MlirContext ctx, unsigned width) {
|
||||
return cls(hlfheEncryptedIntegerTypeGet(ctx, width));
|
||||
});
|
||||
.def_classmethod("get", [](pybind11::object cls, MlirContext ctx,
|
||||
unsigned width) {
|
||||
// We want the user to receive a python exception for not being able to
|
||||
// create the eint
|
||||
auto emitException = []() -> mlir::InFlightDiagnostic {
|
||||
throw std::invalid_argument("can't create eint with the given width");
|
||||
};
|
||||
return cls(
|
||||
hlfheEncryptedIntegerTypeGetChecked(ctx, width, emitException));
|
||||
});
|
||||
}
|
||||
@@ -22,6 +22,8 @@ bool hlfheTypeIsAnEncryptedIntegerType(MlirType type) {
|
||||
return unwrap(type).isa<EncryptedIntegerType>();
|
||||
}
|
||||
|
||||
MlirType hlfheEncryptedIntegerTypeGet(MlirContext ctx, unsigned width) {
|
||||
return wrap(EncryptedIntegerType::get(unwrap(ctx), width));
|
||||
MlirType hlfheEncryptedIntegerTypeGetChecked(
|
||||
MlirContext ctx, unsigned width,
|
||||
mlir::function_ref<mlir::InFlightDiagnostic()> emitError) {
|
||||
return wrap(EncryptedIntegerType::getChecked(emitError, unwrap(ctx), width));
|
||||
}
|
||||
|
||||
@@ -4,16 +4,17 @@ from zamalang import register_dialects
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
|
||||
def test_eint():
|
||||
@pytest.mark.parametrize("width", list(range(1, 8)))
|
||||
def test_eint(width):
|
||||
ctx = Context()
|
||||
register_dialects(ctx)
|
||||
eint = hlfhe.EncryptedIntegerType.get(ctx, 6)
|
||||
assert eint.__str__() == "!HLFHE.eint<6>"
|
||||
eint = hlfhe.EncryptedIntegerType.get(ctx, width)
|
||||
assert eint.__str__() == f"!HLFHE.eint<{width}>"
|
||||
|
||||
|
||||
# FIXME: need to handle error on call to hlfhe.EncryptedIntegerType.get and throw an exception to python
|
||||
# def test_invalid_eint():
|
||||
# ctx = Context()
|
||||
# register_dialects(ctx)
|
||||
# with pytest.raises(RuntimeError, match=r"mlir parsing failed"):
|
||||
# eint = hlfhe.EncryptedIntegerType.get(ctx, 16)
|
||||
@pytest.mark.parametrize("width", [0, 8, 10, 12])
|
||||
def test_invalid_eint(width):
|
||||
ctx = Context()
|
||||
register_dialects(ctx)
|
||||
with pytest.raises(ValueError, match=r"can't create eint with the given width"):
|
||||
eint = hlfhe.EncryptedIntegerType.get(ctx, width)
|
||||
|
||||
Reference in New Issue
Block a user