feat: expose !FHE.esint<p> in Python bindings

This commit is contained in:
Umut
2023-03-03 10:56:39 +01:00
committed by Quentin Bourgerie
parent e949e7e2a7
commit 72005be78d
2 changed files with 38 additions and 0 deletions

View File

@@ -34,4 +34,17 @@ void mlir::concretelang::python::populateDialectFHESubmodule(
}
return cls(typeOrError.type);
});
mlir_type_subclass(m, "EncryptedSignedIntegerType",
fheTypeIsAnEncryptedSignedIntegerType)
.def_classmethod(
"get", [](pybind11::object cls, MlirContext ctx, unsigned width) {
MlirTypeOrError typeOrError =
fheEncryptedSignedIntegerTypeGetChecked(ctx, width);
if (typeOrError.isError) {
throw std::invalid_argument(
"can't create esint with the given width");
}
return cls(typeOrError.type);
});
}

View File

@@ -12,6 +12,14 @@ def test_eint(width):
assert eint.__str__() == f"!FHE.eint<{width}>"
@pytest.mark.parametrize("width", list(range(1, 8)))
def test_esint(width):
ctx = Context()
register_dialects(ctx)
eint = fhe.EncryptedSignedIntegerType.get(ctx, width)
assert eint.__str__() == f"!FHE.esint<{width}>"
@pytest.mark.parametrize("shape", [(1,), (2,), (1, 1), (1, 2), (2, 1), (3, 3, 3)])
def test_eint_tensor(shape):
with Context() as ctx, Location.unknown(context=ctx):
@@ -21,9 +29,26 @@ def test_eint_tensor(shape):
assert tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.eint<{3}>>"
@pytest.mark.parametrize("shape", [(1,), (2,), (1, 1), (1, 2), (2, 1), (3, 3, 3)])
def test_esint_tensor(shape):
with Context() as ctx, Location.unknown(context=ctx):
register_dialects(ctx)
eint = fhe.EncryptedSignedIntegerType.get(ctx, 3)
tensor = RankedTensorType.get(shape, eint)
assert tensor.__str__() == f"tensor<{'x'.join(map(str, shape))}x!FHE.esint<{3}>>"
@pytest.mark.parametrize("width", [0])
def test_invalid_eint(width):
ctx = Context()
register_dialects(ctx)
with pytest.raises(ValueError, match=r"can't create eint with the given width"):
eint = fhe.EncryptedIntegerType.get(ctx, width)
@pytest.mark.parametrize("width", [0])
def test_invalid_esint(width):
ctx = Context()
register_dialects(ctx)
with pytest.raises(ValueError, match=r"can't create esint with the given width"):
eint = fhe.EncryptedSignedIntegerType.get(ctx, width)