mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
fix(compiler): Fixing MidLFHE.h_add verifier and adding tests
This commit is contained in:
@@ -101,6 +101,10 @@ bool verifyHAddResultPadding(::mlir::OpState &op, GLWECipherTextType &inA,
|
||||
if (inA.getPaddingBits() == -1 && inB.getPaddingBits() == -1) {
|
||||
return true;
|
||||
}
|
||||
if (inA.getPaddingBits() != inB.getPaddingBits()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "padding");
|
||||
return false;
|
||||
}
|
||||
return verifyAddResultPadding(op, inA, out);
|
||||
}
|
||||
|
||||
@@ -137,25 +141,25 @@ bool verifyHAddSameGLWEParameter(::mlir::OpState &op, GLWECipherTextType &inA,
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "dimension");
|
||||
return false;
|
||||
}
|
||||
if (inA.getPolynomialSize() != inB.getPolynomialSize() &&
|
||||
if (inA.getPolynomialSize() != inB.getPolynomialSize() ||
|
||||
inA.getPolynomialSize() != out.getPolynomialSize()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "polynomialSize");
|
||||
return false;
|
||||
}
|
||||
if (inA.getBits() != inB.getBits() && inA.getBits() != out.getBits()) {
|
||||
if (inA.getBits() != inB.getBits() || inA.getBits() != out.getBits()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "bits");
|
||||
return false;
|
||||
}
|
||||
if (inA.getP() != inB.getP() && inA.getP() != out.getP()) {
|
||||
if (inA.getP() != inB.getP() || inA.getP() != out.getP()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "p");
|
||||
return false;
|
||||
}
|
||||
if (inA.getPhantomBits() != inB.getPhantomBits() &&
|
||||
if (inA.getPhantomBits() != inB.getPhantomBits() ||
|
||||
inA.getPhantomBits() != out.getPhantomBits()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "phantomBits");
|
||||
return false;
|
||||
}
|
||||
if (inA.getScalingFactor() && inB.getScalingFactor() &&
|
||||
if (inA.getScalingFactor() != inB.getScalingFactor() ||
|
||||
inA.getScalingFactor() != out.getScalingFactor()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "scalingFactor");
|
||||
return false;
|
||||
@@ -167,15 +171,15 @@ bool verifyHAddSameGLWEParameter(::mlir::OpState &op, GLWECipherTextType &inA,
|
||||
GLWECipherTextType inA = op.a().getType().cast<GLWECipherTextType>();
|
||||
GLWECipherTextType inB = op.b().getType().cast<GLWECipherTextType>();
|
||||
GLWECipherTextType out = op.getResult().getType().cast<GLWECipherTextType>();
|
||||
if (!verifyHAddSameGLWEParameter(op, inA, inB, out)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (!verifyHAddResultPadding(op, inA, inB, out)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (!verifyHAddResultLog2StdDev(op, inA, inB, out)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (!verifyHAddSameGLWEParameter(op, inA, inB, out)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: should have the same GLWE dimension parameter
|
||||
func @add(%arg0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, %arg1: !MidLFHE.glwe<{2048,12,64}{0,7,0,50,-25}>) -> !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}> {
|
||||
%0 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, !MidLFHE.glwe<{2048,12,64}{0,7,0,50,-25}>) -> (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>)
|
||||
return %0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'MidLFHE.h_add' op has unexpected log2StdDev parameter of its GLWE result, expected:-22
|
||||
func @add_plain(%arg0: !MidLFHE.glwe<{1024,12,64}{0,7,0,57,-25}>, %arg1: !MidLFHE.glwe<{1024,12,64}{0,7,0,57,-23}>) -> !MidLFHE.glwe<{1024,12,64}{0,7,0,57,-29}> {
|
||||
%1 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{0,7,0,57,-25}>, !MidLFHE.glwe<{1024,12,64}{0,7,0,57,-23}>) -> (!MidLFHE.glwe<{1024,12,64}{0,7,0,57,-29}>)
|
||||
return %1: !MidLFHE.glwe<{1024,12,64}{0,7,0,57,-29}>
|
||||
}
|
||||
7
compiler/tests/Dialect/MidLFHE/op_h_add_err_p.mlir
Normal file
7
compiler/tests/Dialect/MidLFHE/op_h_add_err_p.mlir
Normal file
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: should have the same GLWE p parameter
|
||||
func @add(%arg0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, %arg1: !MidLFHE.glwe<{1024,12,64}{0,8,0,50,-25}>) -> !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}> {
|
||||
%0 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, !MidLFHE.glwe<{1024,12,64}{0,8,0,50,-25}>) -> (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>)
|
||||
return %0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: should have the same GLWE padding parameter
|
||||
func @add(%arg0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, %arg1: !MidLFHE.glwe<{1024,12,64}{1,7,0,50,-25}>) -> !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}> {
|
||||
%0 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, !MidLFHE.glwe<{1024,12,64}{1,7,0,50,-25}>) -> (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>)
|
||||
return %0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: the result should have one less padding bit than the input
|
||||
func @add(%arg0: !MidLFHE.glwe<{1024,12,64}{2,7,0,50,-25}>, %arg1: !MidLFHE.glwe<{1024,12,64}{2,7,0,50,-25}>) -> !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}> {
|
||||
%0 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{2,7,0,50,-25}>, !MidLFHE.glwe<{1024,12,64}{2,7,0,50,-25}>) -> (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>)
|
||||
return %0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: the result shoud have 0 paddingBits has input has 0 paddingBits
|
||||
func @add(%arg0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, %arg1: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>) -> !MidLFHE.glwe<{1024,12,64}{1,7,0,50,-24}> {
|
||||
%0 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>) -> (!MidLFHE.glwe<{1024,12,64}{1,7,0,50,-24}>)
|
||||
return %0: !MidLFHE.glwe<{1024,12,64}{1,7,0,50,-24}>
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: should have the same GLWE phantomBits parameter
|
||||
func @add(%arg0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, %arg1: !MidLFHE.glwe<{1024,12,64}{0,7,1,50,-25}>) -> !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}> {
|
||||
%0 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, !MidLFHE.glwe<{1024,12,64}{0,7,1,50,-25}>) -> (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>)
|
||||
return %0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: should have the same GLWE polynomialSize parameter
|
||||
func @add(%arg0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, %arg1: !MidLFHE.glwe<{1024,10,64}{0,7,0,50,-25}>) -> !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}> {
|
||||
%0 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, !MidLFHE.glwe<{1024,10,64}{0,7,0,50,-25}>) -> (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>)
|
||||
return %0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-24}>
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// RUN: not zamacompiler %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: should have the same GLWE scalingFactor parameter
|
||||
func @add(%arg0: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, %arg1: !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>) -> !MidLFHE.glwe<{1024,12,64}{0,7,0,49,-25}> {
|
||||
%1 = "MidLFHE.h_add"(%arg0, %arg1): (!MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>, !MidLFHE.glwe<{1024,12,64}{0,7,0,50,-25}>) -> (!MidLFHE.glwe<{1024,12,64}{0,7,0,49,-25}>)
|
||||
return %1: !MidLFHE.glwe<{1024,12,64}{0,7,0,49,-25}>
|
||||
}
|
||||
Reference in New Issue
Block a user