mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 08:01:20 -05:00
feat(compiler): Add verifiers to TFHE bootstrap operations
This commit is contained in:
@@ -382,6 +382,8 @@ def TFHE_BatchedBootstrapGLWEOp : TFHE_Op<"batched_bootstrap_glwe", [Pure]> {
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TFHE_BatchedMappedBootstrapGLWEOp : TFHE_Op<"batched_mapped_bootstrap_glwe", [Pure]> {
|
||||
@@ -394,6 +396,8 @@ def TFHE_BatchedMappedBootstrapGLWEOp : TFHE_Op<"batched_mapped_bootstrap_glwe",
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure, BatchableOpInterface]> {
|
||||
@@ -408,6 +412,8 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure, BatchableOpInterface
|
||||
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
struct BatchingVariant {
|
||||
static const unsigned CIPHERTEXT_BATCHING = 0;
|
||||
|
||||
@@ -14,6 +14,8 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace TFHE {
|
||||
|
||||
static const int64_t kUndefined = -1;
|
||||
|
||||
void emitOpErrorForKeyMismatch(mlir::OpState &op) {
|
||||
op.emitOpError() << "should have the same GLWE Secret Key";
|
||||
}
|
||||
@@ -115,6 +117,73 @@ mlir::LogicalResult MulGLWEIntOp::verify() {
|
||||
*this);
|
||||
}
|
||||
|
||||
template <typename BootstrapOpT>
|
||||
mlir::LogicalResult verifyBootstrapSingleLUTConstraints(BootstrapOpT &op) {
|
||||
GLWEBootstrapKeyAttr keyAttr = op.getKeyAttr();
|
||||
|
||||
if (keyAttr) {
|
||||
mlir::RankedTensorType rtt =
|
||||
op.getLookupTable().getType().template cast<mlir::RankedTensorType>();
|
||||
|
||||
assert(rtt.getShape().size() == 1);
|
||||
|
||||
// Do not fail on unparametrized ops
|
||||
if (keyAttr.getPolySize() == kUndefined)
|
||||
return mlir::success();
|
||||
|
||||
if (rtt.getShape()[0] != keyAttr.getPolySize()) {
|
||||
op.emitError("Size of the lookup table of ")
|
||||
<< rtt.getShape()[0] << " does not match the size of the polynom of "
|
||||
<< keyAttr.getPolySize();
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult BootstrapGLWEOp::verify() {
|
||||
return verifyBootstrapSingleLUTConstraints(*this);
|
||||
}
|
||||
|
||||
mlir::LogicalResult BatchedBootstrapGLWEOp::verify() {
|
||||
return verifyBootstrapSingleLUTConstraints(*this);
|
||||
}
|
||||
|
||||
mlir::LogicalResult BatchedMappedBootstrapGLWEOp::verify() {
|
||||
GLWEBootstrapKeyAttr keyAttr = this->getKeyAttr();
|
||||
|
||||
if (keyAttr) {
|
||||
mlir::RankedTensorType lutRtt =
|
||||
this->getLookupTable().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
mlir::RankedTensorType ciphertextsRtt =
|
||||
this->getCiphertexts().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
assert(lutRtt.getShape().size() == 2);
|
||||
|
||||
if (lutRtt.getShape()[1] != keyAttr.getPolySize()) {
|
||||
this->emitError("Size of the lookup table of ")
|
||||
<< lutRtt.getShape()[1]
|
||||
<< " does not match the size of the polynom of "
|
||||
<< keyAttr.getPolySize();
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (lutRtt.getShape()[0] != ciphertextsRtt.getShape()[0]) {
|
||||
this->emitError("Number of lookup tables of ")
|
||||
<< lutRtt.getShape()[0] << " does not match number of ciphertexts of "
|
||||
<< ciphertextsRtt.getShape()[0];
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace TFHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user