From eabd8b959dded91958c7ab077e77d17fdfc063a1 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 4 Nov 2022 11:54:56 +0100 Subject: [PATCH] fix(CAPI): remove Cpp code from CAPI this required to have a CAPI that when asked for types, returns a structure that can report if an error was faced during type creation. This is required since a failure at that stage in the compiler would lead to a segfault in the python bindings for example, and we want to be able to handle this scenario gracefully. --- compiler/include/concretelang-c/Dialect/FHE.h | 14 +++++++---- compiler/lib/Bindings/Python/FHEModule.cpp | 11 ++++----- compiler/lib/CAPI/Dialect/FHE/FHE.cpp | 24 +++++++++++++++---- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/compiler/include/concretelang-c/Dialect/FHE.h b/compiler/include/concretelang-c/Dialect/FHE.h index 3500ba8d8..5d2d79770 100644 --- a/compiler/include/concretelang-c/Dialect/FHE.h +++ b/compiler/include/concretelang-c/Dialect/FHE.h @@ -8,19 +8,23 @@ #include "mlir-c/IR.h" #include "mlir-c/Registration.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/Support/LLVM.h" #ifdef __cplusplus extern "C" { #endif +/// \brief structure to return an MlirType or report that there was an error +/// during type creation. +typedef struct { + MlirType type; + bool isError; +} MlirTypeOrError; + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHE, fhe); /// Creates an encrypted integer type of `width` bits -MLIR_CAPI_EXPORTED MlirType fheEncryptedIntegerTypeGetChecked( - MlirContext context, unsigned width, - mlir::function_ref emitError); +MLIR_CAPI_EXPORTED MlirTypeOrError +fheEncryptedIntegerTypeGetChecked(MlirContext context, unsigned width); /// If the type is an EncryptedInteger MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedIntegerType(MlirType); diff --git a/compiler/lib/Bindings/Python/FHEModule.cpp b/compiler/lib/Bindings/Python/FHEModule.cpp index 11655b02f..6729d0987 100644 --- a/compiler/lib/Bindings/Python/FHEModule.cpp +++ b/compiler/lib/Bindings/Python/FHEModule.cpp @@ -28,12 +28,11 @@ void mlir::concretelang::python::populateDialectFHESubmodule( mlir_type_subclass(m, "EncryptedIntegerType", fheTypeIsAnEncryptedIntegerType) .def_classmethod("get", [](pybind11::object cls, MlirContext ctx, unsigned width) { - // We want the user to receive a python exception for not being able to - // create the eint - auto emitException = []() -> mlir::InFlightDiagnostic { + MlirTypeOrError typeOrError = + fheEncryptedIntegerTypeGetChecked(ctx, width); + if (typeOrError.isError) { throw std::invalid_argument("can't create eint with the given width"); - }; - return cls( - fheEncryptedIntegerTypeGetChecked(ctx, width, emitException)); + } + return cls(typeOrError.type); }); } diff --git a/compiler/lib/CAPI/Dialect/FHE/FHE.cpp b/compiler/lib/CAPI/Dialect/FHE/FHE.cpp index 70685ccd3..1ffe0fbfe 100644 --- a/compiler/lib/CAPI/Dialect/FHE/FHE.cpp +++ b/compiler/lib/CAPI/Dialect/FHE/FHE.cpp @@ -10,6 +10,7 @@ #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Registration.h" #include "mlir/CAPI/Support.h" +#include "mlir/IR/StorageUniquerSupport.h" using namespace mlir::concretelang::FHE; @@ -27,8 +28,23 @@ bool fheTypeIsAnEncryptedIntegerType(MlirType type) { return unwrap(type).isa(); } -MlirType fheEncryptedIntegerTypeGetChecked( - MlirContext ctx, unsigned width, - mlir::function_ref emitError) { - return wrap(EncryptedIntegerType::getChecked(emitError, unwrap(ctx), width)); +MlirTypeOrError fheEncryptedIntegerTypeGetChecked(MlirContext ctx, + unsigned width) { + MlirTypeOrError type = {{NULL}, false}; + auto catchError = [&]() -> mlir::InFlightDiagnostic { + type.isError = true; + mlir::DiagnosticEngine &engine = unwrap(ctx)->getDiagEngine(); + // The goal here is to make getChecked working, but we don't want the CAPI + // to stop execution due to an error, and leave the error handling logic to + // the user of the CAPI + return engine.emit(mlir::UnknownLoc::get(unwrap(ctx)), + mlir::DiagnosticSeverity::Warning); + }; + EncryptedIntegerType eint = + EncryptedIntegerType::getChecked(catchError, unwrap(ctx), width); + if (type.isError) { + return type; + } + type.type = wrap(eint); + return type; }