From 70fb5fcd8ee5692dcb82c0dbd9c5b8edcb79cf09 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Mon, 16 Aug 2021 17:08:18 +0200 Subject: [PATCH] fix(compiler/midlfhe): Change constraint on operators with integers (just too large integers are forbidden) --- compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp | 13 +++++++------ .../Dialect/MidLFHE/op_add_glwe_int.invalid.mlir | 6 +++--- .../MidLFHE/op_apply_lookup_table.invalid.mlir | 4 ++-- .../Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir | 6 +++--- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp index 4682f4142..52e1e3e04 100644 --- a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp +++ b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp @@ -36,8 +36,10 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op, } // verify consistency of width of inputs - if (a.getP() + 1 != b.getWidth()) { - op.emitOpError() << "should have the width of `b` equals to 'p'+1"; + if (b.getWidth() > a.getP() + 1) { + op.emitOpError() + << "should have the width of `b` equals or less than 'p'+1: " + << b.getWidth() << " <= " << a.getP() << "+ 1"; return mlir::failure(); } return mlir::success(); @@ -123,10 +125,9 @@ mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) { } // Check the witdh of the encrypted integer and the integer of the tabulated // lambda are equals - if (result.getP() != l_cst.getElementType().cast().getWidth()) { - op.emitOpError() - << "should have equals width beetwen the encrypted integer result and " - "integers of the `tabulated_lambda` argument"; + if (result.getP() < l_cst.getElementType().cast().getWidth()) { + op.emitOpError() << "should have the width of the constants less or equals " + "than the precision of the encrypted integer"; return mlir::failure(); } return mlir::success(); diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir index 1f073f22f..4ce221611 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir @@ -32,8 +32,8 @@ func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024 // integer width doesn't match GLWE parameter func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { - %0 = constant 1 : i6 - // expected-error @+1 {{'MidLFHE.add_glwe_int' op should have the width of `b` equals to 'p'+1}} - %1 = "MidLFHE.add_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i6) -> (!MidLFHE.glwe<{1024,12,64}{7}>) + %0 = constant 1 : i9 + // expected-error @+1 {{'MidLFHE.add_glwe_int' op should have the width of `b` equals or less than 'p'+1}} + %1 = "MidLFHE.add_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i9) -> (!MidLFHE.glwe<{1024,12,64}{7}>) return %1: !MidLFHE.glwe<{1024,12,64}{7}> } diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir index 05b664ac6..c1f0759b3 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir @@ -11,7 +11,7 @@ func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4x // Bad dimension of integer in the lookup table func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi3>) -> !MidLFHE.glwe<{512,10,64}{2}> { - // expected-error @+1 {{'MidLFHE.apply_lookup_table' op should have equals width beetwen the encrypted integer result and integers of the `tabulated_lambda` argument}} + // expected-error @+1 {{'MidLFHE.apply_lookup_table' op should have the width of the constants less or equals than the precision of the encrypted integer}} %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32}: (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<128xi3>) -> (!MidLFHE.glwe<{512,10,64}{2}>) return %1: !MidLFHE.glwe<{512,10,64}{2}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir index 0948db10f..f3fe65d0e 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir @@ -32,8 +32,8 @@ func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024 // integer width doesn't match GLWE parameter func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { - %0 = constant 1 : i6 - // expected-error @+1 {{'MidLFHE.mul_glwe_int' op should have the width of `b` equals to 'p'+1}} - %1 = "MidLFHE.mul_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i6) -> (!MidLFHE.glwe<{1024,12,64}{7}>) + %0 = constant 1 : i9 + // expected-error @+1 {{'MidLFHE.mul_glwe_int' op should have the width of `b` equals or less than 'p'+1}} + %1 = "MidLFHE.mul_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i9) -> (!MidLFHE.glwe<{1024,12,64}{7}>) return %1: !MidLFHE.glwe<{1024,12,64}{7}> }