enhance(compiler): Add verifier of HLFHE operators

This commit is contained in:
Quentin Bourgerie
2021-07-15 15:56:24 +02:00
parent cb635f8a55
commit 143a2384fd
9 changed files with 138 additions and 8 deletions

View File

@@ -27,6 +27,10 @@ def AddEintIntOp : HLFHE_Op<"add_eint_int"> {
build($_builder, $_state, a.getType(), a, b);
}]>
];
let verifier = [{
return ::mlir::zamalang::HLFHE::verifyAddEintIntOp(*this);
}];
}
def AddEintOp : HLFHE_Op<"add_eint"> {
@@ -38,6 +42,10 @@ def AddEintOp : HLFHE_Op<"add_eint"> {
build($_builder, $_state, a.getType(), a, b);
}]>
];
let verifier = [{
return ::mlir::zamalang::HLFHE::verifyAddEintOp(*this);
}];
}
def MulEintIntOp : HLFHE_Op<"mul_eint_int"> {
@@ -49,6 +57,10 @@ def MulEintIntOp : HLFHE_Op<"mul_eint_int"> {
build($_builder, $_state, a.getType(), a, b);
}]>
];
let verifier = [{
return ::mlir::zamalang::HLFHE::verifyMulEintIntOp(*this);
}];
}
def ApplyLookupTable : HLFHE_Op<"apply_lookup_table"> {

View File

@@ -8,9 +8,81 @@ namespace mlir {
namespace zamalang {
namespace HLFHE {
using mlir::zamalang::HLFHE::AddEintOp;
using mlir::zamalang::HLFHE::ApplyLookupTable;
using mlir::zamalang::HLFHE::EncryptedIntegerType;
bool verifyEncryptedIntegerInputAndResultConsistency(
::mlir::OpState &op, EncryptedIntegerType &input,
EncryptedIntegerType &result) {
if (input.getWidth() != result.getWidth()) {
op.emitOpError(
" should have the width of encrypted inputs and result equals");
return false;
}
return true;
}
bool verifyEncryptedIntegerAndIntegerInputsConsistency(::mlir::OpState &op,
EncryptedIntegerType &a,
IntegerType &b) {
if (a.getWidth() + 1 != b.getWidth()) {
op.emitOpError(" should have the width of plain input equals to width of "
"encrypted input + 1");
return false;
}
return true;
}
bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
EncryptedIntegerType &a,
EncryptedIntegerType &b) {
if (a.getWidth() != b.getWidth()) {
op.emitOpError(" should have the width of encrypted inputs equals");
return false;
}
return true;
}
::mlir::LogicalResult verifyAddEintIntOp(AddEintIntOp &op) {
auto a = op.a().getType().cast<EncryptedIntegerType>();
auto b = op.b().getType().cast<IntegerType>();
auto out = op.getResult().getType().cast<EncryptedIntegerType>();
if (!verifyEncryptedIntegerInputAndResultConsistency(op, a, out)) {
return ::mlir::failure();
}
if (!verifyEncryptedIntegerAndIntegerInputsConsistency(op, a, b)) {
return ::mlir::failure();
}
return ::mlir::success();
}
::mlir::LogicalResult verifyAddEintOp(AddEintOp &op) {
auto a = op.a().getType().cast<EncryptedIntegerType>();
auto b = op.b().getType().cast<EncryptedIntegerType>();
auto out = op.getResult().getType().cast<EncryptedIntegerType>();
if (!verifyEncryptedIntegerInputAndResultConsistency(op, a, out)) {
return ::mlir::failure();
}
if (!verifyEncryptedIntegerInputsConsistency(op, a, b)) {
return ::mlir::failure();
}
return ::mlir::success();
}
::mlir::LogicalResult verifyMulEintIntOp(MulEintIntOp &op) {
auto a = op.a().getType().cast<EncryptedIntegerType>();
auto b = op.b().getType().cast<IntegerType>();
auto out = op.getResult().getType().cast<EncryptedIntegerType>();
if (!verifyEncryptedIntegerInputAndResultConsistency(op, a, out)) {
return ::mlir::failure();
}
if (!verifyEncryptedIntegerAndIntegerInputsConsistency(op, a, b)) {
return ::mlir::failure();
}
return ::mlir::success();
}
::mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) {
auto ct = op.ct().getType().cast<EncryptedIntegerType>();
auto l_cst = op.l_cst().getType().cast<MemRefType>();

View File

@@ -0,0 +1,7 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<3>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}

View File

@@ -0,0 +1,7 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<2>) -> (!HLFHE.eint<3>)
return %1: !HLFHE.eint<3>
}

View File

@@ -0,0 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
%0 = constant 1 : i4
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}

View File

@@ -0,0 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
%0 = constant 1 : i2
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>)
return %1: !HLFHE.eint<3>
}

View File

@@ -0,0 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
%0 = constant 1 : i4
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}

View File

@@ -0,0 +1,8 @@
// RUN: not zamacompiler %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
%0 = constant 1 : i2
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>)
return %1: !HLFHE.eint<3>
}

View File

@@ -2,23 +2,23 @@
// CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2>
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i3
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.add_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2>
%0 = constant 1 : i32
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>)
%0 = constant 1 : i3
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
// CHECK-LABEL: func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2>
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i32) -> !HLFHE.eint<2>
// CHECK-NEXT: %[[V1:.*]] = constant 1 : i3
// CHECK-NEXT: %[[V2:.*]] = "HLFHE.mul_eint_int"(%arg0, %[[V1]]) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
// CHECK-NEXT: return %[[V2]] : !HLFHE.eint<2>
%0 = constant 1 : i32
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i32) -> (!HLFHE.eint<2>)
%0 = constant 1 : i3
%1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}