fix p-error in mul_eint_int test

This commit is contained in:
aPere3
2023-01-16 16:04:15 +01:00
committed by Quentin Bourgerie
parent e95c53f2ff
commit a200ce43bd
4 changed files with 9 additions and 5 deletions

View File

@@ -208,8 +208,7 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
// Write the plaintext encoding
mlir::Value encodedInt = writePlaintextShiftEncoding(
op.getLoc(), adaptor.a(),
op.b().getType().cast<FHE::FheIntegerInterface>().getWidth(),
rewriter);
op.b().getType().cast<FHE::FheIntegerInterface>().getWidth(), rewriter);
// Write the new op
rewriter.replaceOpWithNewOp<TFHE::SubGLWEIntOp>(
@@ -331,7 +330,8 @@ struct ApplyLookupTableEintOpPattern
// Insert keyswitch
auto ksOp = rewriter.create<TFHE::KeySwitchGLWEOp>(
op.getLoc(), getTypeConverter()->convertType(adaptor.a().getType()), input, -1, -1);
op.getLoc(), getTypeConverter()->convertType(adaptor.a().getType()),
input, -1, -1);
// Insert bootstrap
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(

View File

@@ -13,8 +13,9 @@ namespace utils {
/// Returns `true` if the given value is a scalar or tensor argument of
/// a function, for which a MANP of 1 can be assumed.
bool isEncryptedValue(mlir::Value value) {
return (value.getType().isa<mlir::concretelang::FHE::FheIntegerInterface>() ||
value.getType().isa<mlir::concretelang::FHE::EncryptedBooleanType>() ||
return (
value.getType().isa<mlir::concretelang::FHE::FheIntegerInterface>() ||
value.getType().isa<mlir::concretelang::FHE::EncryptedBooleanType>() ||
(value.getType().isa<mlir::TensorType>() &&
value.getType()
.cast<mlir::TensorType>()

View File

@@ -93,6 +93,7 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
{
/* .precision = */ width,
/* .crt = */ std::vector<int64_t>(),
/* .sign = */ false,
},
}),
/*.shape = */

View File

@@ -700,6 +700,8 @@ def main():
print(" %1 = \"FHE.mul_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
print(" return %1: !FHE.esint<{0}>".format(p))
print(" }")
if p <= 57:
print(f"p-error: {P_ERROR}")
print("tests:")
print(" - inputs:")
print(" - scalar: {0}".format(0))