feat(compiler): add batched operations for all levelled ops.

This commit is contained in:
Antoniu Pop
2023-03-24 10:44:38 +00:00
committed by Antoniu Pop
parent 799e64e8ab
commit 3f230957cb
7 changed files with 517 additions and 6 deletions

View File

@@ -26,6 +26,7 @@ def Concrete_CrtLutsTensor : 2DTensorOf<[I64]>;
def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
// A batch of LWE ciphertexts: rank-2 i64 tensor, leading dimension indexes
// the batch (one ciphertext per row).
def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
// A batch of plaintexts: rank-1 i64 tensor, one plaintext per ciphertext of
// the matching LWE batch.
def Concrete_BatchPlaintextTensor : 1DTensorOf<[I64]>;
def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
@@ -33,6 +34,7 @@ def Concrete_CrtLutsBuffer : MemRefRankOf<[I64], [2]>;
def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
// Buffer (memref) counterparts of Concrete_BatchLweTensor and
// Concrete_BatchPlaintextTensor, used by the bufferized forms of the ops.
def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
def Concrete_BatchPlaintextBuffer : MemRefRankOf<[I64], [1]>;
class Concrete_Op<string mnemonic, list<Trait> traits = []> :
Op<Concrete_Dialect, mnemonic, traits>;
@@ -58,6 +60,26 @@ def Concrete_AddLweBufferOp : Concrete_Op<"add_lwe_buffer"> {
);
}
// Elementwise addition of two batches of LWE ciphertexts:
// result[i] = lhs[i] + rhs[i] for each row i of the batch.
def Concrete_BatchedAddLweTensorOp : Concrete_Op<"batched_add_lwe_tensor", [Pure]> {
let summary = "Batched version of AddLweTensorOp, which performs the same operation on multiple elements";
let arguments = (ins
Concrete_BatchLweTensor:$lhs,
Concrete_BatchLweTensor:$rhs
);
let results = (outs Concrete_BatchLweTensor:$result);
}
// Buffer (memref) form of BatchedAddLweTensorOp: writes into the `result`
// out-parameter instead of producing a value, hence no [Pure] trait.
def Concrete_BatchedAddLweBufferOp : Concrete_Op<"batched_add_lwe_buffer"> {
let summary = "Batched version of AddLweBufferOp, which performs the same operation on multiple elements";
let arguments = (ins
Concrete_BatchLweBuffer:$result,
Concrete_BatchLweBuffer:$lhs,
Concrete_BatchLweBuffer:$rhs
);
}
def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [Pure]> {
let summary = "Returns the sum of a clear integer and an lwe ciphertext";
@@ -75,6 +97,40 @@ def Concrete_AddPlaintextLweBufferOp : Concrete_Op<"add_plaintext_lwe_buffer"> {
);
}
// Elementwise ciphertext + plaintext addition over a batch:
// result[i] = lhs[i] + rhs[i], with one plaintext per ciphertext.
def Concrete_BatchedAddPlaintextLweTensorOp : Concrete_Op<"batched_add_plaintext_lwe_tensor", [Pure]> {
let summary = "Batched version of AddPlaintextLweTensorOp, which performs the same operation on multiple elements";
let arguments = (ins Concrete_BatchLweTensor:$lhs, Concrete_BatchPlaintextTensor:$rhs);
let results = (outs Concrete_BatchLweTensor:$result);
}
// Buffer (memref) form: writes into the `result` out-parameter instead of
// producing a value, hence no [Pure] trait.
def Concrete_BatchedAddPlaintextLweBufferOp : Concrete_Op<"batched_add_plaintext_lwe_buffer"> {
let summary = "Batched version of AddPlaintextLweBufferOp, which performs the same operation on multiple elements";
let arguments = (ins
Concrete_BatchLweBuffer:$result,
Concrete_BatchLweBuffer:$lhs,
Concrete_BatchPlaintextBuffer:$rhs
);
}
// Constant variant: the same scalar plaintext `rhs` is added to every
// ciphertext of the `lhs` batch (result[i] = lhs[i] + rhs).
def Concrete_BatchedAddPlaintextCstLweTensorOp : Concrete_Op<"batched_add_plaintext_cst_lwe_tensor", [Pure]> {
let summary = "Batched version of AddPlaintextLweTensorOp, which adds the same plaintext constant to multiple ciphertexts";
let arguments = (ins Concrete_BatchLweTensor:$lhs, I64:$rhs);
let results = (outs Concrete_BatchLweTensor:$result);
}
// Buffer (memref) form of the constant variant: writes into the `result`
// out-parameter instead of producing a value, hence no [Pure] trait.
def Concrete_BatchedAddPlaintextCstLweBufferOp : Concrete_Op<"batched_add_plaintext_cst_lwe_buffer"> {
let summary = "Batched version of AddPlaintextLweBufferOp, which adds the same plaintext constant to multiple ciphertexts";
let arguments = (ins
Concrete_BatchLweBuffer:$result,
Concrete_BatchLweBuffer:$lhs,
I64:$rhs
);
}
def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [Pure]> {
let summary = "Returns the product of a clear integer and a lwe ciphertext";
@@ -92,6 +148,40 @@ def Concrete_MulCleartextLweBufferOp : Concrete_Op<"mul_cleartext_lwe_buffer"> {
);
}
// Elementwise ciphertext * cleartext multiplication over a batch:
// result[i] = lhs[i] * rhs[i], with one cleartext per ciphertext.
def Concrete_BatchedMulCleartextLweTensorOp : Concrete_Op<"batched_mul_cleartext_lwe_tensor", [Pure]> {
let summary = "Batched version of MulCleartextLweTensorOp, which performs the same operation on multiple elements";
let arguments = (ins Concrete_BatchLweTensor:$lhs, Concrete_BatchPlaintextTensor:$rhs);
let results = (outs Concrete_BatchLweTensor:$result);
}
// Buffer (memref) form: writes into the `result` out-parameter instead of
// producing a value, hence no [Pure] trait.
def Concrete_BatchedMulCleartextLweBufferOp : Concrete_Op<"batched_mul_cleartext_lwe_buffer"> {
let summary = "Batched version of MulCleartextLweBufferOp, which performs the same operation on multiple elements";
let arguments = (ins
Concrete_BatchLweBuffer:$result,
Concrete_BatchLweBuffer:$lhs,
Concrete_BatchPlaintextBuffer:$rhs
);
}
// Constant variant: every ciphertext of the `lhs` batch is multiplied by the
// same scalar cleartext `rhs` (result[i] = lhs[i] * rhs).
def Concrete_BatchedMulCleartextCstLweTensorOp : Concrete_Op<"batched_mul_cleartext_cst_lwe_tensor", [Pure]> {
let summary = "Batched version of MulCleartextLweTensorOp, which multiplies multiple ciphertexts by the same cleartext constant";
let arguments = (ins Concrete_BatchLweTensor:$lhs, I64:$rhs);
let results = (outs Concrete_BatchLweTensor:$result);
}
// Buffer (memref) form of the constant variant: writes into the `result`
// out-parameter instead of producing a value, hence no [Pure] trait.
def Concrete_BatchedMulCleartextCstLweBufferOp : Concrete_Op<"batched_mul_cleartext_cst_lwe_buffer"> {
let summary = "Batched version of MulCleartextLweBufferOp, which multiplies multiple ciphertexts by the same cleartext constant";
let arguments = (ins
Concrete_BatchLweBuffer:$result,
Concrete_BatchLweBuffer:$lhs,
I64:$rhs
);
}
def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [Pure]> {
let summary = "Negates a lwe ciphertext";
@@ -108,6 +198,22 @@ def Concrete_NegateLweBufferOp : Concrete_Op<"negate_lwe_buffer"> {
);
}
// Elementwise negation of a batch of LWE ciphertexts:
// result[i] = -ciphertext[i].
def Concrete_BatchedNegateLweTensorOp : Concrete_Op<"batched_negate_lwe_tensor", [Pure]> {
let summary = "Batched version of NegateLweTensorOp, which performs the same operation on multiple elements";
let arguments = (ins Concrete_BatchLweTensor:$ciphertext);
let results = (outs Concrete_BatchLweTensor:$result);
}
// Buffer (memref) form: writes into the `result` out-parameter instead of
// producing a value, hence no [Pure] trait.
def Concrete_BatchedNegateLweBufferOp : Concrete_Op<"batched_negate_lwe_buffer"> {
let summary = "Batched version of NegateLweBufferOp, which performs the same operation on multiple elements";
let arguments = (ins
Concrete_BatchLweBuffer:$result,
Concrete_BatchLweBuffer:$ciphertext
);
}
def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [Pure]> {
let summary =
"Encode and expand a lookup table so that it can be used for a bootstrap";

View File

@@ -77,22 +77,103 @@ def TFHE_ZeroTensorGLWEOp : TFHE_Op<"zero_tensor", [Pure]> {
let results = (outs Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>:$tensor);
}
def TFHE_AddGLWEIntOp : TFHE_Op<"add_glwe_int", [Pure]> {
// Batched version of AddGLWEIntOp: result[i] = ciphertexts[i] + plaintexts[i].
// NOTE(review): the stray leading 'A' in "ABatched" looks unintentional, but
// the generated class name is referenced from AddGLWEIntOp's
// extraClassDeclaration and from the TFHE-to-Concrete lowering patterns, so
// renaming requires a coordinated change across those uses.
def TFHE_ABatchedAddGLWEIntOp : TFHE_Op<"batched_add_glwe_int", [Pure]> {
let summary = "Batched version of AddGLWEIntOp";
let arguments = (ins
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
1DTensorOf<[AnyInteger]> : $plaintexts
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
}
// Constant variant: the same integer `plaintext` is added to every
// ciphertext of the batch (result[i] = ciphertexts[i] + plaintext). Emitted
// by AddGLWEIntOp::createBatchedOperation when the integer operand is
// hoisted out of the batched loop.
def TFHE_ABatchedAddGLWEIntCstOp : TFHE_Op<"batched_add_glwe_int_cst", [Pure]> {
let summary = "Batched version of AddGLWEIntOp, adding the same integer constant to each ciphertext";
let arguments = (ins
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
AnyInteger : $plaintext
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
}
// Ciphertext + clear-integer addition. Implements BatchableOpInterface so the
// batching pass can rewrite a loop of add_glwe_int ops into a single batched
// op over tensors of ciphertexts.
def TFHE_AddGLWEIntOp : TFHE_Op<"add_glwe_int", [Pure, BatchableOpInterface]> {
let summary = "Returns the sum of a clear integer and a lwe ciphertext";
let arguments = (ins TFHE_GLWECipherTextType : $a, AnyInteger : $b);
let results = (outs TFHE_GLWECipherTextType);
let hasVerifier = 1;
let extraClassDeclaration = [{
// Both operands ($a and $b) are candidates for batching.
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
return getOperation()->getOpOperands().take_front(2);
}
// Builds the batched replacement. The result type is a ranked tensor of
// the scalar result type with the batch shape of the first batched operand.
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
::mlir::ValueRange batchedOperands,
::mlir::ValueRange hoistedNonBatchableOperands) {
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
getResult().getType());
::llvm::SmallVector<::mlir::Value> operands(batchedOperands);
if (hoistedNonBatchableOperands.empty()) {
// All operands were batched: emit the tensor/tensor form.
return builder.create<ABatchedAddGLWEIntOp>(
mlir::TypeRange{resType},
operands,
getOperation()->getAttrs());
} else {
// The integer operand was hoisted (loop-invariant): append it after
// the batched operands and emit the constant form.
operands.append(hoistedNonBatchableOperands.begin(),
hoistedNonBatchableOperands.end());
return builder.create<ABatchedAddGLWEIntCstOp>(
mlir::TypeRange{resType},
operands,
getOperation()->getAttrs());
}
}
}];
}
def TFHE_AddGLWEOp : TFHE_Op<"add_glwe", [Pure]> {
// Batched version of AddGLWEOp:
// result[i] = ciphertexts_a[i] + ciphertexts_b[i].
// NOTE(review): carries the same stray 'A' prefix as ABatchedAddGLWEIntOp;
// renaming requires updating the lowering patterns that reference it.
def TFHE_ABatchedAddGLWEOp : TFHE_Op<"batched_add_glwe", [Pure]> {
let summary = "Batched version of AddGLWEOp";
let arguments = (ins
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts_a,
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts_b
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
}
// Ciphertext + ciphertext addition. Implements BatchableOpInterface so the
// batching pass can rewrite a loop of add_glwe ops into one batched op.
def TFHE_AddGLWEOp : TFHE_Op<"add_glwe", [Pure, BatchableOpInterface]> {
let summary = "Returns the sum of 2 lwe ciphertexts";
let arguments = (ins TFHE_GLWECipherTextType : $a, TFHE_GLWECipherTextType : $b);
let results = (outs TFHE_GLWECipherTextType);
let hasVerifier = 1;
let extraClassDeclaration = [{
// Both ciphertext operands are candidates for batching.
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
return getOperation()->getOpOperands().take_front(2);
}
// Builds the batched replacement; result type is a tensor of the scalar
// result type shaped like the first batched operand. Both operands are
// ciphertexts, so there is no hoisted-constant variant here.
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
::mlir::ValueRange batchedOperands,
::mlir::ValueRange hoistedNonBatchableOperands) {
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
getResult().getType());
return builder.create<ABatchedAddGLWEOp>(
mlir::TypeRange{resType},
mlir::ValueRange{batchedOperands},
getOperation()->getAttrs());
}
}];
}
def TFHE_SubGLWEIntOp : TFHE_Op<"sub_int_glwe", [Pure]> {
@@ -104,22 +185,102 @@ def TFHE_SubGLWEIntOp : TFHE_Op<"sub_int_glwe", [Pure]> {
let hasVerifier = 1;
}
def TFHE_NegGLWEOp : TFHE_Op<"neg_glwe", [Pure]> {
// Batched version of NegGLWEOp: result[i] = -ciphertexts[i].
def TFHE_BatchedNegGLWEOp : TFHE_Op<"batched_neg_glwe", [Pure]> {
let summary = "Batched version of NegGLWEOp";
let arguments = (ins
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
}
// Ciphertext negation. Implements BatchableOpInterface so the batching pass
// can rewrite a loop of neg_glwe ops into one batched_neg_glwe op.
def TFHE_NegGLWEOp : TFHE_Op<"neg_glwe", [Pure, BatchableOpInterface]> {
let summary = "Negates a glwe ciphertext";
let arguments = (ins TFHE_GLWECipherTextType : $a);
let results = (outs TFHE_GLWECipherTextType);
let hasVerifier = 1;
let extraClassDeclaration = [{
// Only the single ciphertext operand is batchable.
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
return getOperation()->getOpOperands().take_front();
}
// Builds the batched replacement; result type is a tensor of the scalar
// result type shaped like the batched operand.
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
::mlir::ValueRange batchedOperands,
::mlir::ValueRange hoistedNonBatchableOperands) {
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
getResult().getType());
return builder.create<BatchedNegGLWEOp>(
mlir::TypeRange{resType},
mlir::ValueRange{batchedOperands},
getOperation()->getAttrs());
}
}];
}
def TFHE_MulGLWEIntOp : TFHE_Op<"mul_glwe_int", [Pure]> {
// Batched version of MulGLWEIntOp:
// result[i] = ciphertexts[i] * cleartexts[i].
def TFHE_BatchedMulGLWEIntOp : TFHE_Op<"batched_mul_glwe_int", [Pure]> {
let summary = "Batched version of MulGLWEIntOp";
let arguments = (ins
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
1DTensorOf<[AnyInteger]> : $cleartexts
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
}
// Constant variant: every ciphertext of the batch is multiplied by the same
// cleartext (result[i] = ciphertexts[i] * cleartext). Emitted by
// MulGLWEIntOp::createBatchedOperation when the integer operand is hoisted
// out of the batched loop.
def TFHE_BatchedMulGLWEIntCstOp : TFHE_Op<"batched_mul_glwe_int_cst", [Pure]> {
let summary = "Batched version of MulGLWEIntOp, multiplying each ciphertext by the same cleartext constant";
let arguments = (ins
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
AnyInteger: $cleartext
);
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
}
// Ciphertext * clear-integer multiplication. Implements BatchableOpInterface
// so the batching pass can rewrite a loop of mul_glwe_int ops into a single
// batched op over tensors of ciphertexts.
def TFHE_MulGLWEIntOp : TFHE_Op<"mul_glwe_int", [Pure, BatchableOpInterface]> {
let summary = "Returns the product of a clear integer and a lwe ciphertext";
let arguments = (ins TFHE_GLWECipherTextType : $a, AnyInteger : $b);
let results = (outs TFHE_GLWECipherTextType);
let hasVerifier = 1;
let extraClassDeclaration = [{
// Both operands ($a and $b) are candidates for batching.
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
return getOperation()->getOpOperands().take_front(2);
}
// Builds the batched replacement. The result type is a ranked tensor of
// the scalar result type with the batch shape of the first batched operand.
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
::mlir::ValueRange batchedOperands,
::mlir::ValueRange hoistedNonBatchableOperands) {
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
getResult().getType());
::llvm::SmallVector<::mlir::Value> operands(batchedOperands);
if (hoistedNonBatchableOperands.empty()) {
// All operands were batched: emit the tensor/tensor form.
return builder.create<BatchedMulGLWEIntOp>(
mlir::TypeRange{resType},
operands,
getOperation()->getAttrs());
} else {
// The integer operand was hoisted (loop-invariant): append it after
// the batched operands and emit the constant form.
operands.append(hoistedNonBatchableOperands.begin(),
hoistedNonBatchableOperands.end());
return builder.create<BatchedMulGLWEIntCstOp>(
mlir::TypeRange{resType},
operands,
getOperation()->getAttrs());
}
}
}];
}
def TFHE_BatchedKeySwitchGLWEOp : TFHE_Op<"batched_keyswitch_glwe", [Pure]> {

View File

@@ -85,6 +85,54 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint32_t ksk_index,
mlir::concretelang::RuntimeContext *context);
// Batched counterparts of the scalar LWE runtime primitives. Each operates
// row-wise on rank-2 memrefs whose leading dimension indexes the batch,
// passed with the expanded memref calling convention
// (allocated, aligned, offset, sizes..., strides...).

// out[i] = ct0[i] + ct1[i] (ciphertext + ciphertext).
void memref_batched_add_lwe_ciphertexts_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size0,
uint64_t ct1_size1, uint64_t ct1_stride0, uint64_t ct1_stride1);
// out[i] = ct0[i] + ct1[i] (ciphertext + per-element plaintext; ct1 is 1D).
void memref_batched_add_plaintext_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
uint64_t ct1_stride);
// out[i] = ct0[i] + plaintext (same scalar added to every ciphertext).
void memref_batched_add_plaintext_cst_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t plaintext);
// out[i] = ct0[i] * ct1[i] (ciphertext * per-element cleartext; ct1 is 1D).
void memref_batched_mul_cleartext_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
uint64_t ct1_stride);
// out[i] = ct0[i] * cleartext (same scalar applied to every ciphertext).
void memref_batched_mul_cleartext_cst_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t cleartext);
// out[i] = -ct0[i].
void memref_batched_negate_lwe_ciphertext_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
uint64_t ct0_stride0, uint64_t ct0_stride1);
void memref_batched_keyswitch_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,

View File

@@ -27,6 +27,18 @@ char memref_mul_cleartext_lwe_ciphertext_u64[] =
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
// Symbol names of the batched runtime entry points; matched against
// `funcName` when emitting forward declarations of, and calls to, the C API.
char memref_batched_add_lwe_ciphertexts_u64[] =
"memref_batched_add_lwe_ciphertexts_u64";
char memref_batched_add_plaintext_lwe_ciphertext_u64[] =
"memref_batched_add_plaintext_lwe_ciphertext_u64";
char memref_batched_add_plaintext_cst_lwe_ciphertext_u64[] =
"memref_batched_add_plaintext_cst_lwe_ciphertext_u64";
char memref_batched_mul_cleartext_lwe_ciphertext_u64[] =
"memref_batched_mul_cleartext_lwe_ciphertext_u64";
char memref_batched_mul_cleartext_cst_lwe_ciphertext_u64[] =
"memref_batched_mul_cleartext_cst_lwe_ciphertext_u64";
char memref_batched_negate_lwe_ciphertext_u64[] =
"memref_batched_negate_lwe_ciphertext_u64";
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
@@ -129,6 +141,26 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{futureType});
} else if (funcName == memref_batched_add_lwe_ciphertexts_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(), {memref2DType, memref2DType, memref2DType}, {});
} else if (funcName == memref_batched_add_plaintext_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(), {memref2DType, memref2DType, memref1DType}, {});
} else if (funcName == memref_batched_add_plaintext_cst_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref2DType, memref2DType, rewriter.getI64Type()}, {});
} else if (funcName == memref_batched_mul_cleartext_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(), {memref2DType, memref2DType, memref1DType}, {});
} else if (funcName == memref_batched_mul_cleartext_cst_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref2DType, memref2DType, rewriter.getI64Type()}, {});
} else if (funcName == memref_batched_negate_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType}, {});
} else if (funcName == memref_batched_keyswitch_lwe_u64 ||
funcName == memref_batched_keyswitch_lwe_cuda_u64) {
funcType =
@@ -515,6 +547,26 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
.add<ConcreteToCAPICallPattern<Concrete::EncodeLutForCrtWopPBSBufferOp,
memref_encode_lut_for_crt_woppbs>>(
&getContext(), encodeLutForWopPBSAddOperands);
patterns
.add<ConcreteToCAPICallPattern<Concrete::BatchedAddLweBufferOp,
memref_batched_add_lwe_ciphertexts_u64>>(
&getContext());
patterns.add<ConcreteToCAPICallPattern<
Concrete::BatchedAddPlaintextLweBufferOp,
memref_batched_add_plaintext_lwe_ciphertext_u64>>(&getContext());
patterns.add<ConcreteToCAPICallPattern<
Concrete::BatchedAddPlaintextCstLweBufferOp,
memref_batched_add_plaintext_cst_lwe_ciphertext_u64>>(&getContext());
patterns.add<ConcreteToCAPICallPattern<
Concrete::BatchedMulCleartextLweBufferOp,
memref_batched_mul_cleartext_lwe_ciphertext_u64>>(&getContext());
patterns.add<ConcreteToCAPICallPattern<
Concrete::BatchedMulCleartextCstLweBufferOp,
memref_batched_mul_cleartext_cst_lwe_ciphertext_u64>>(&getContext());
patterns.add<
ConcreteToCAPICallPattern<Concrete::BatchedNegateLweBufferOp,
memref_batched_negate_lwe_ciphertext_u64>>(
&getContext());
if (gpu) {
patterns.add<ConcreteToCAPICallPattern<Concrete::KeySwitchLweBufferOp,
memref_keyswitch_lwe_cuda_u64>>(

View File

@@ -784,8 +784,28 @@ void TFHEToConcretePass::runOnOperation() {
mlir::concretelang::Concrete::EncodeLutForCrtWopPBSTensorOp, true>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::EncodePlaintextWithCrtOp,
mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>>(
&getContext(), converter);
mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::ABatchedAddGLWEIntOp,
mlir::concretelang::Concrete::BatchedAddPlaintextLweTensorOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::ABatchedAddGLWEIntCstOp,
mlir::concretelang::Concrete::BatchedAddPlaintextCstLweTensorOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::ABatchedAddGLWEOp,
mlir::concretelang::Concrete::BatchedAddLweTensorOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::BatchedMulGLWEIntOp,
mlir::concretelang::Concrete::BatchedMulCleartextLweTensorOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::BatchedMulGLWEIntCstOp,
mlir::concretelang::Concrete::BatchedMulCleartextCstLweTensorOp>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::BatchedNegGLWEOp,
mlir::concretelang::Concrete::BatchedNegateLweTensorOp>
>(&getContext(), converter);
// pattern of remaining TFHE ops
patterns.insert<ZeroOpPattern<mlir::concretelang::TFHE::ZeroGLWEOp>,

View File

@@ -123,6 +123,34 @@ void mlir::concretelang::Concrete::
// bootstrap_lwe_tensor => bootstrap_lwe_buffer
Concrete::BootstrapLweTensorOp::attachInterface<TensorToMemrefOp<
Concrete::BootstrapLweTensorOp, Concrete::BootstrapLweBufferOp>>(*ctx);
// batched_add_lwe_tensor => batched_add_lwe_buffer
Concrete::BatchedAddLweTensorOp::attachInterface<TensorToMemrefOp<
Concrete::BatchedAddLweTensorOp, Concrete::BatchedAddLweBufferOp>>(
*ctx);
// batched_add_plaintext_lwe_tensor => batched_add_plaintext_lwe_buffer
Concrete::BatchedAddPlaintextLweTensorOp::attachInterface<
TensorToMemrefOp<Concrete::BatchedAddPlaintextLweTensorOp,
Concrete::BatchedAddPlaintextLweBufferOp>>(*ctx);
// batched_add_plaintext_cst_lwe_tensor =>
// batched_add_plaintext_cst_lwe_buffer
Concrete::BatchedAddPlaintextCstLweTensorOp::attachInterface<
TensorToMemrefOp<Concrete::BatchedAddPlaintextCstLweTensorOp,
Concrete::BatchedAddPlaintextCstLweBufferOp>>(*ctx);
// batched_mul_cleartext_lwe_tensor => batched_mul_cleartext_lwe_buffer
Concrete::BatchedMulCleartextLweTensorOp::attachInterface<
TensorToMemrefOp<Concrete::BatchedMulCleartextLweTensorOp,
Concrete::BatchedMulCleartextLweBufferOp>>(*ctx);
// batched_mul_cleartext_cst_lwe_tensor =>
// batched_mul_cleartext_cst_lwe_buffer
Concrete::BatchedMulCleartextCstLweTensorOp::attachInterface<
TensorToMemrefOp<Concrete::BatchedMulCleartextCstLweTensorOp,
Concrete::BatchedMulCleartextCstLweBufferOp>>(*ctx);
// batched_negate_lwe_tensor => batched_negate_lwe_buffer
Concrete::BatchedNegateLweTensorOp::attachInterface<
TensorToMemrefOp<Concrete::BatchedNegateLweTensorOp,
Concrete::BatchedNegateLweBufferOp>>(*ctx);
// batched_keyswitch_lwe_tensor => batched_keyswitch_lwe_buffer
Concrete::BatchedKeySwitchLweTensorOp::attachInterface<
TensorToMemrefOp<Concrete::BatchedKeySwitchLweTensorOp,

View File

@@ -517,6 +517,102 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
output_dimension);
}
// Batched variant of `memref_add_lwe_ciphertexts_u64`: for each row i of the
// batch, adds ciphertext ct0[i] to ct1[i] and writes the result into out[i].
// All three buffers are rank-2 memrefs (batch dimension first) passed with
// the expanded memref calling convention.
void memref_batched_add_lwe_ciphertexts_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
    uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size0,
    uint64_t ct1_size1, uint64_t ct1_stride0, uint64_t ct1_stride1) {
  for (size_t i = 0; i < ct0_size0; i++) {
    // Advance by the row *stride*, not the row *size*: the two coincide only
    // for contiguous memrefs, and using the stride also handles strided
    // views (e.g. subviews) correctly.
    memref_add_lwe_ciphertexts_u64(
        out_allocated + i * out_stride0, out_aligned + i * out_stride0,
        out_offset, out_size1, out_stride1, ct0_allocated + i * ct0_stride0,
        ct0_aligned + i * ct0_stride0, ct0_offset, ct0_size1, ct0_stride1,
        ct1_allocated + i * ct1_stride0, ct1_aligned + i * ct1_stride0,
        ct1_offset, ct1_size1, ct1_stride1);
  }
}
// Batched variant of `memref_add_plaintext_lwe_ciphertext_u64`: for each row
// i of the batch, adds the i-th plaintext of the 1D buffer ct1 to ciphertext
// ct0[i] and writes the result into out[i].
void memref_batched_add_plaintext_lwe_ciphertext_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
    uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
    uint64_t ct1_stride) {
  for (size_t i = 0; i < ct0_size0; i++) {
    // Row pointers advance by the row stride (equals the row size only for
    // contiguous memrefs); the plaintext is read via offset + stride.
    memref_add_plaintext_lwe_ciphertext_u64(
        out_allocated + i * out_stride0, out_aligned + i * out_stride0,
        out_offset, out_size1, out_stride1, ct0_allocated + i * ct0_stride0,
        ct0_aligned + i * ct0_stride0, ct0_offset, ct0_size1, ct0_stride1,
        *(ct1_aligned + ct1_offset + i * ct1_stride));
  }
}
// Batched constant variant: adds the same scalar `plaintext` to every
// ciphertext row of ct0, writing each result into the matching row of out.
void memref_batched_add_plaintext_cst_lwe_ciphertext_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t plaintext) {
  for (size_t i = 0; i < ct0_size0; i++) {
    // Row pointers advance by the row stride (equals the row size only for
    // contiguous memrefs).
    memref_add_plaintext_lwe_ciphertext_u64(
        out_allocated + i * out_stride0, out_aligned + i * out_stride0,
        out_offset, out_size1, out_stride1, ct0_allocated + i * ct0_stride0,
        ct0_aligned + i * ct0_stride0, ct0_offset, ct0_size1, ct0_stride1,
        plaintext);
  }
}
// Batched variant of `memref_mul_cleartext_lwe_ciphertext_u64`: for each row
// i of the batch, multiplies ciphertext ct0[i] by the i-th cleartext of the
// 1D buffer ct1 and writes the result into out[i].
void memref_batched_mul_cleartext_lwe_ciphertext_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
    uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
    uint64_t ct1_stride) {
  for (size_t i = 0; i < ct0_size0; i++) {
    // Row pointers advance by the row stride (equals the row size only for
    // contiguous memrefs); the cleartext is read via offset + stride.
    memref_mul_cleartext_lwe_ciphertext_u64(
        out_allocated + i * out_stride0, out_aligned + i * out_stride0,
        out_offset, out_size1, out_stride1, ct0_allocated + i * ct0_stride0,
        ct0_aligned + i * ct0_stride0, ct0_offset, ct0_size1, ct0_stride1,
        *(ct1_aligned + ct1_offset + i * ct1_stride));
  }
}
// Batched constant variant: multiplies every ciphertext row of ct0 by the
// same scalar `cleartext`, writing each result into the matching row of out.
void memref_batched_mul_cleartext_cst_lwe_ciphertext_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t cleartext) {
  for (size_t i = 0; i < ct0_size0; i++) {
    // Row pointers advance by the row stride (equals the row size only for
    // contiguous memrefs).
    memref_mul_cleartext_lwe_ciphertext_u64(
        out_allocated + i * out_stride0, out_aligned + i * out_stride0,
        out_offset, out_size1, out_stride1, ct0_allocated + i * ct0_stride0,
        ct0_aligned + i * ct0_stride0, ct0_offset, ct0_size1, ct0_stride1,
        cleartext);
  }
}
// Batched variant of `memref_negate_lwe_ciphertext_u64`: negates each
// ciphertext row of ct0, writing each result into the matching row of out.
void memref_batched_negate_lwe_ciphertext_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1) {
  for (size_t i = 0; i < ct0_size0; i++) {
    // Row pointers advance by the row stride (equals the row size only for
    // contiguous memrefs).
    memref_negate_lwe_ciphertext_u64(
        out_allocated + i * out_stride0, out_aligned + i * out_stride0,
        out_offset, out_size1, out_stride1, ct0_allocated + i * ct0_stride0,
        ct0_aligned + i * ct0_stride0, ct0_offset, ct0_size1, ct0_stride1);
  }
}
void memref_batched_keyswitch_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,