feat(compiler): Add verifiers to TFHE bootstrap operations

This commit is contained in:
Andi Drebes
2024-01-29 11:25:21 +01:00
parent a133407035
commit ea07239732
2 changed files with 75 additions and 0 deletions

View File

@@ -382,6 +382,8 @@ def TFHE_BatchedBootstrapGLWEOp : TFHE_Op<"batched_bootstrap_glwe", [Pure]> {
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
let hasVerifier = 1;
}
def TFHE_BatchedMappedBootstrapGLWEOp : TFHE_Op<"batched_mapped_bootstrap_glwe", [Pure]> {
@@ -394,6 +396,8 @@ def TFHE_BatchedMappedBootstrapGLWEOp : TFHE_Op<"batched_mapped_bootstrap_glwe",
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
let hasVerifier = 1;
}
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure, BatchableOpInterface]> {
@@ -408,6 +412,8 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure, BatchableOpInterface
let results = (outs TFHE_GLWECipherTextType : $result);
let hasVerifier = 1;
let extraClassDeclaration = [{
struct BatchingVariant {
static const unsigned CIPHERTEXT_BATCHING = 0;

View File

@@ -14,6 +14,8 @@ namespace mlir {
namespace concretelang {
namespace TFHE {
static const int64_t kUndefined = -1;
void emitOpErrorForKeyMismatch(mlir::OpState &op) {
op.emitOpError() << "should have the same GLWE Secret Key";
}
@@ -115,6 +117,73 @@ mlir::LogicalResult MulGLWEIntOp::verify() {
*this);
}
template <typename BootstrapOpT>
mlir::LogicalResult verifyBootstrapSingleLUTConstraints(BootstrapOpT &op) {
GLWEBootstrapKeyAttr keyAttr = op.getKeyAttr();
if (keyAttr) {
mlir::RankedTensorType rtt =
op.getLookupTable().getType().template cast<mlir::RankedTensorType>();
assert(rtt.getShape().size() == 1);
// Do not fail on unparametrized ops
if (keyAttr.getPolySize() == kUndefined)
return mlir::success();
if (rtt.getShape()[0] != keyAttr.getPolySize()) {
op.emitError("Size of the lookup table of ")
<< rtt.getShape()[0] << " does not match the size of the polynom of "
<< keyAttr.getPolySize();
return mlir::failure();
}
}
return mlir::success();
}
mlir::LogicalResult BootstrapGLWEOp::verify() {
return verifyBootstrapSingleLUTConstraints(*this);
}
mlir::LogicalResult BatchedBootstrapGLWEOp::verify() {
return verifyBootstrapSingleLUTConstraints(*this);
}
mlir::LogicalResult BatchedMappedBootstrapGLWEOp::verify() {
GLWEBootstrapKeyAttr keyAttr = this->getKeyAttr();
if (keyAttr) {
mlir::RankedTensorType lutRtt =
this->getLookupTable().getType().cast<mlir::RankedTensorType>();
mlir::RankedTensorType ciphertextsRtt =
this->getCiphertexts().getType().cast<mlir::RankedTensorType>();
assert(lutRtt.getShape().size() == 2);
if (lutRtt.getShape()[1] != keyAttr.getPolySize()) {
this->emitError("Size of the lookup table of ")
<< lutRtt.getShape()[1]
<< " does not match the size of the polynom of "
<< keyAttr.getPolySize();
return mlir::failure();
}
if (lutRtt.getShape()[0] != ciphertextsRtt.getShape()[0]) {
this->emitError("Number of lookup tables of ")
<< lutRtt.getShape()[0] << " does not match number of ciphertexts of "
<< ciphertextsRtt.getShape()[0];
return mlir::failure();
}
}
return mlir::success();
}
} // namespace TFHE
} // namespace concretelang
} // namespace mlir