From ea0723973298c40e00ac755073b91d8485259fc0 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 29 Jan 2024 11:25:21 +0100 Subject: [PATCH] feat(compiler): Add verifiers to TFHE bootstrap operations --- .../concretelang/Dialect/TFHE/IR/TFHEOps.td | 6 ++ .../compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp | 69 +++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index 1172b0292..3901069c8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -382,6 +382,8 @@ def TFHE_BatchedBootstrapGLWEOp : TFHE_Op<"batched_bootstrap_glwe", [Pure]> { ); let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result); + + let hasVerifier = 1; } def TFHE_BatchedMappedBootstrapGLWEOp : TFHE_Op<"batched_mapped_bootstrap_glwe", [Pure]> { @@ -394,6 +396,8 @@ def TFHE_BatchedMappedBootstrapGLWEOp : TFHE_Op<"batched_mapped_bootstrap_glwe", ); let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result); + + let hasVerifier = 1; } def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure, BatchableOpInterface]> { @@ -408,6 +412,8 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe", [Pure, BatchableOpInterface let results = (outs TFHE_GLWECipherTextType : $result); + let hasVerifier = 1; + let extraClassDeclaration = [{ struct BatchingVariant { static const unsigned CIPHERTEXT_BATCHING = 0; diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp index 8d5e010ce..d1d25ed18 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp @@ -14,6 +14,8 @@ namespace mlir { namespace concretelang { namespace TFHE { +static const int64_t kUndefined = -1; + void emitOpErrorForKeyMismatch(mlir::OpState &op) { op.emitOpError() << "should have the same GLWE Secret Key"; } @@ -115,6 +117,73 @@ mlir::LogicalResult MulGLWEIntOp::verify() { *this); } +template +mlir::LogicalResult verifyBootstrapSingleLUTConstraints(BootstrapOpT &op) { + GLWEBootstrapKeyAttr keyAttr = op.getKeyAttr(); + + if (keyAttr) { + mlir::RankedTensorType rtt = + op.getLookupTable().getType().template cast(); + + assert(rtt.getShape().size() == 1); + + // Do not fail on unparametrized ops + if (keyAttr.getPolySize() == kUndefined) + return mlir::success(); + + if (rtt.getShape()[0] != keyAttr.getPolySize()) { + op.emitError("Size of the lookup table of ") + << rtt.getShape()[0] << " does not match the size of the polynom of " + << keyAttr.getPolySize(); + + return mlir::failure(); + } + } + + return mlir::success(); +} + +mlir::LogicalResult BootstrapGLWEOp::verify() { + return verifyBootstrapSingleLUTConstraints(*this); +} + +mlir::LogicalResult BatchedBootstrapGLWEOp::verify() { + return verifyBootstrapSingleLUTConstraints(*this); +} + +mlir::LogicalResult BatchedMappedBootstrapGLWEOp::verify() { + GLWEBootstrapKeyAttr keyAttr = this->getKeyAttr(); + + if (keyAttr) { + mlir::RankedTensorType lutRtt = + this->getLookupTable().getType().cast(); + + mlir::RankedTensorType ciphertextsRtt = + this->getCiphertexts().getType().cast(); + + assert(lutRtt.getShape().size() == 2); + + if (lutRtt.getShape()[1] != keyAttr.getPolySize()) { + this->emitError("Size of the lookup table of ") + << lutRtt.getShape()[1] + << " does not match the size of the polynom of " + << keyAttr.getPolySize(); + + return mlir::failure(); + } + + if (lutRtt.getShape()[0] != ciphertextsRtt.getShape()[0]) { + this->emitError("Number of lookup tables of ") + << lutRtt.getShape()[0] << " does not match number of ciphertexts of " + << ciphertextsRtt.getShape()[0]; + + return mlir::failure(); + } + } + + return mlir::success(); +} + } // namespace TFHE } // namespace concretelang } // namespace mlir