diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/FHEModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/FHEModule.cpp index a83e51866..b0640e276 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/FHEModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/FHEModule.cpp @@ -34,4 +34,17 @@ void mlir::concretelang::python::populateDialectFHESubmodule( } return cls(typeOrError.type); }); + + mlir_type_subclass(m, "EncryptedSignedIntegerType", + fheTypeIsAnEncryptedSignedIntegerType) + .def_classmethod( + "get", [](pybind11::object cls, MlirContext ctx, unsigned width) { + MlirTypeOrError typeOrError = + fheEncryptedSignedIntegerTypeGetChecked(ctx, width); + if (typeOrError.isError) { + throw std::invalid_argument( + "can't create esint with the given width"); + } + return cls(typeOrError.type); + }); } diff --git a/compilers/concrete-compiler/compiler/tests/python/test_fhe_dialect.py b/compilers/concrete-compiler/compiler/tests/python/test_fhe_dialect.py index a98f4843e..14bc63ba2 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_fhe_dialect.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_fhe_dialect.py @@ -12,6 +12,14 @@ def test_eint(width): assert eint.__str__() == f"!FHE.eint<{width}>" +@pytest.mark.parametrize("width", list(range(1, 8))) +def test_esint(width): + ctx = Context() + register_dialects(ctx) + eint = fhe.EncryptedSignedIntegerType.get(ctx, width) + assert eint.__str__() == f"!FHE.esint<{width}>" + + @pytest.mark.parametrize("shape", [(1,), (2,), (1, 1), (1, 2), (2, 1), (3, 3, 3)]) def test_eint_tensor(shape): with Context() as ctx, Location.unknown(context=ctx): @@ -21,9 +29,26 @@ def test_eint_tensor(shape): assert tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>" +@pytest.mark.parametrize("shape", [(1,), (2,), (1, 1), (1, 2), (2, 1), (3, 3, 3)]) +def test_esint_tensor(shape): + with Context() as ctx, Location.unknown(context=ctx): + register_dialects(ctx) + eint = fhe.EncryptedSignedIntegerType.get(ctx, 3) + tensor = RankedTensorType.get(shape, eint) + assert tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.esint<{3}>>" + + @pytest.mark.parametrize("width", [0]) def test_invalid_eint(width): ctx = Context() register_dialects(ctx) with pytest.raises(ValueError, match=r"can't create eint with the given width"): eint = fhe.EncryptedIntegerType.get(ctx, width) + + +@pytest.mark.parametrize("width", [0]) +def test_invalid_esint(width): + ctx = Context() + register_dialects(ctx) + with pytest.raises(ValueError, match=r"can't create esint with the given width"): + eint = fhe.EncryptedSignedIntegerType.get(ctx, width)