mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
feat(compiler): add batched operations for all levelled ops.
This commit is contained in:
@@ -26,6 +26,7 @@ def Concrete_CrtLutsTensor : 2DTensorOf<[I64]>;
|
||||
def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
|
||||
def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
|
||||
def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
|
||||
def Concrete_BatchPlaintextTensor : 1DTensorOf<[I64]>;
|
||||
|
||||
def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
|
||||
@@ -33,6 +34,7 @@ def Concrete_CrtLutsBuffer : MemRefRankOf<[I64], [2]>;
|
||||
def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
|
||||
def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
|
||||
def Concrete_BatchPlaintextBuffer : MemRefRankOf<[I64], [1]>;
|
||||
|
||||
class Concrete_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Concrete_Dialect, mnemonic, traits>;
|
||||
@@ -58,6 +60,26 @@ def Concrete_AddLweBufferOp : Concrete_Op<"add_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedAddLweTensorOp : Concrete_Op<"batched_add_lwe_tensor", [Pure]> {
|
||||
let summary = "Batched version of AddLweTensorOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweTensor:$lhs,
|
||||
Concrete_BatchLweTensor:$rhs
|
||||
);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_BatchedAddLweBufferOp : Concrete_Op<"batched_add_lwe_buffer"> {
|
||||
let summary = "Batched version of AddLweBufferOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweBuffer:$result,
|
||||
Concrete_BatchLweBuffer:$lhs,
|
||||
Concrete_BatchLweBuffer:$rhs
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_AddPlaintextLweTensorOp : Concrete_Op<"add_plaintext_lwe_tensor", [Pure]> {
|
||||
let summary = "Returns the sum of a clear integer and an lwe ciphertext";
|
||||
|
||||
@@ -75,6 +97,40 @@ def Concrete_AddPlaintextLweBufferOp : Concrete_Op<"add_plaintext_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedAddPlaintextLweTensorOp : Concrete_Op<"batched_add_plaintext_lwe_tensor", [Pure]> {
|
||||
let summary = "Batched version of AddPlaintextLweTensorOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins Concrete_BatchLweTensor:$lhs, Concrete_BatchPlaintextTensor:$rhs);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_BatchedAddPlaintextLweBufferOp : Concrete_Op<"batched_add_plaintext_lwe_buffer"> {
|
||||
let summary = "Batched version of AddPlaintextLweBufferOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweBuffer:$result,
|
||||
Concrete_BatchLweBuffer:$lhs,
|
||||
Concrete_BatchPlaintextBuffer:$rhs
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedAddPlaintextCstLweTensorOp : Concrete_Op<"batched_add_plaintext_cst_lwe_tensor", [Pure]> {
|
||||
let summary = "Batched version of AddPlaintextLweTensorOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins Concrete_BatchLweTensor:$lhs, I64:$rhs);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_BatchedAddPlaintextCstLweBufferOp : Concrete_Op<"batched_add_plaintext_cst_lwe_buffer"> {
|
||||
let summary = "Batched version of AddPlaintextLweBufferOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweBuffer:$result,
|
||||
Concrete_BatchLweBuffer:$lhs,
|
||||
I64:$rhs
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_MulCleartextLweTensorOp : Concrete_Op<"mul_cleartext_lwe_tensor", [Pure]> {
|
||||
let summary = "Returns the product of a clear integer and a lwe ciphertext";
|
||||
|
||||
@@ -92,6 +148,40 @@ def Concrete_MulCleartextLweBufferOp : Concrete_Op<"mul_cleartext_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedMulCleartextLweTensorOp : Concrete_Op<"batched_mul_cleartext_lwe_tensor", [Pure]> {
|
||||
let summary = "Batched version of MulCleartextLweTensorOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins Concrete_BatchLweTensor:$lhs, Concrete_BatchPlaintextTensor:$rhs);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_BatchedMulCleartextLweBufferOp : Concrete_Op<"batched_mul_cleartext_lwe_buffer"> {
|
||||
let summary = "Batched version of MulCleartextLweBufferOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweBuffer:$result,
|
||||
Concrete_BatchLweBuffer:$lhs,
|
||||
Concrete_BatchPlaintextBuffer:$rhs
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedMulCleartextCstLweTensorOp : Concrete_Op<"batched_mul_cleartext_cst_lwe_tensor", [Pure]> {
|
||||
let summary = "Batched version of MulCleartextLweTensorOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins Concrete_BatchLweTensor:$lhs, I64:$rhs);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_BatchedMulCleartextCstLweBufferOp : Concrete_Op<"batched_mul_cleartext_cst_lwe_buffer"> {
|
||||
let summary = "Batched version of MulCleartextLweBufferOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweBuffer:$result,
|
||||
Concrete_BatchLweBuffer:$lhs,
|
||||
I64:$rhs
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_NegateLweTensorOp : Concrete_Op<"negate_lwe_tensor", [Pure]> {
|
||||
let summary = "Negates a lwe ciphertext";
|
||||
|
||||
@@ -108,6 +198,22 @@ def Concrete_NegateLweBufferOp : Concrete_Op<"negate_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_BatchedNegateLweTensorOp : Concrete_Op<"batched_negate_lwe_tensor", [Pure]> {
|
||||
let summary = "Batched version of NegateLweTensorOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins Concrete_BatchLweTensor:$ciphertext);
|
||||
let results = (outs Concrete_BatchLweTensor:$result);
|
||||
}
|
||||
|
||||
def Concrete_BatchedNegateLweBufferOp : Concrete_Op<"batched_negate_lwe_buffer"> {
|
||||
let summary = "Batched version of NegateLweBufferOp, which performs the same operation on multiple elements";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_BatchLweBuffer:$result,
|
||||
Concrete_BatchLweBuffer:$ciphertext
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_EncodeExpandLutForBootstrapTensorOp : Concrete_Op<"encode_expand_lut_for_bootstrap_tensor", [Pure]> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a bootstrap";
|
||||
|
||||
@@ -77,22 +77,103 @@ def TFHE_ZeroTensorGLWEOp : TFHE_Op<"zero_tensor", [Pure]> {
|
||||
let results = (outs Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>:$tensor);
|
||||
}
|
||||
|
||||
def TFHE_AddGLWEIntOp : TFHE_Op<"add_glwe_int", [Pure]> {
|
||||
def TFHE_ABatchedAddGLWEIntOp : TFHE_Op<"batched_add_glwe_int", [Pure]> {
|
||||
let summary = "Batched version of AddGLWEIntOp";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
|
||||
1DTensorOf<[AnyInteger]> : $plaintexts
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_ABatchedAddGLWEIntCstOp : TFHE_Op<"batched_add_glwe_int_cst", [Pure]> {
|
||||
let summary = "Batched version of AddGLWEIntOp";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
|
||||
AnyInteger : $plaintext
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_AddGLWEIntOp : TFHE_Op<"add_glwe_int", [Pure, BatchableOpInterface]> {
|
||||
let summary = "Returns the sum of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins TFHE_GLWECipherTextType : $a, AnyInteger : $b);
|
||||
let results = (outs TFHE_GLWECipherTextType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
|
||||
return getOperation()->getOpOperands().take_front(2);
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::ValueRange batchedOperands,
|
||||
::mlir::ValueRange hoistedNonBatchableOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
::llvm::SmallVector<::mlir::Value> operands(batchedOperands);
|
||||
if (hoistedNonBatchableOperands.empty()) {
|
||||
return builder.create<ABatchedAddGLWEIntOp>(
|
||||
mlir::TypeRange{resType},
|
||||
operands,
|
||||
getOperation()->getAttrs());
|
||||
} else {
|
||||
operands.append(hoistedNonBatchableOperands.begin(),
|
||||
hoistedNonBatchableOperands.end());
|
||||
return builder.create<ABatchedAddGLWEIntCstOp>(
|
||||
mlir::TypeRange{resType},
|
||||
operands,
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TFHE_AddGLWEOp : TFHE_Op<"add_glwe", [Pure]> {
|
||||
def TFHE_ABatchedAddGLWEOp : TFHE_Op<"batched_add_glwe", [Pure]> {
|
||||
let summary = "Batched version of AddGLWEOp";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts_a,
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts_b
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_AddGLWEOp : TFHE_Op<"add_glwe", [Pure, BatchableOpInterface]> {
|
||||
let summary = "Returns the sum of 2 lwe ciphertexts";
|
||||
|
||||
let arguments = (ins TFHE_GLWECipherTextType : $a, TFHE_GLWECipherTextType : $b);
|
||||
let results = (outs TFHE_GLWECipherTextType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
|
||||
return getOperation()->getOpOperands().take_front(2);
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::ValueRange batchedOperands,
|
||||
::mlir::ValueRange hoistedNonBatchableOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
return builder.create<ABatchedAddGLWEOp>(
|
||||
mlir::TypeRange{resType},
|
||||
mlir::ValueRange{batchedOperands},
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TFHE_SubGLWEIntOp : TFHE_Op<"sub_int_glwe", [Pure]> {
|
||||
@@ -104,22 +185,102 @@ def TFHE_SubGLWEIntOp : TFHE_Op<"sub_int_glwe", [Pure]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TFHE_NegGLWEOp : TFHE_Op<"neg_glwe", [Pure]> {
|
||||
def TFHE_BatchedNegGLWEOp : TFHE_Op<"batched_neg_glwe", [Pure]> {
|
||||
let summary = "Batched version of NegGLWEOp";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_NegGLWEOp : TFHE_Op<"neg_glwe", [Pure, BatchableOpInterface]> {
|
||||
let summary = "Negates a glwe ciphertext";
|
||||
|
||||
let arguments = (ins TFHE_GLWECipherTextType : $a);
|
||||
let results = (outs TFHE_GLWECipherTextType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
|
||||
return getOperation()->getOpOperands().take_front();
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::ValueRange batchedOperands,
|
||||
::mlir::ValueRange hoistedNonBatchableOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
return builder.create<BatchedNegGLWEOp>(
|
||||
mlir::TypeRange{resType},
|
||||
mlir::ValueRange{batchedOperands},
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TFHE_MulGLWEIntOp : TFHE_Op<"mul_glwe_int", [Pure]> {
|
||||
def TFHE_BatchedMulGLWEIntOp : TFHE_Op<"batched_mul_glwe_int", [Pure]> {
|
||||
let summary = "Batched version of MulGLWEIntOp";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
|
||||
1DTensorOf<[AnyInteger]> : $cleartexts
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_BatchedMulGLWEIntCstOp : TFHE_Op<"batched_mul_glwe_int_cst", [Pure]> {
|
||||
let summary = "Batched version of MulGLWEIntOp";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[TFHE_GLWECipherTextType]> : $ciphertexts,
|
||||
AnyInteger: $cleartext
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[TFHE_GLWECipherTextType]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_MulGLWEIntOp : TFHE_Op<"mul_glwe_int", [Pure, BatchableOpInterface]> {
|
||||
let summary = "Returns the product of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins TFHE_GLWECipherTextType : $a, AnyInteger : $b);
|
||||
let results = (outs TFHE_GLWECipherTextType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands() {
|
||||
return getOperation()->getOpOperands().take_front(2);
|
||||
}
|
||||
|
||||
::mlir::Value createBatchedOperation(::mlir::ImplicitLocOpBuilder& builder,
|
||||
::mlir::ValueRange batchedOperands,
|
||||
::mlir::ValueRange hoistedNonBatchableOperands) {
|
||||
::mlir::RankedTensorType resType = ::mlir::RankedTensorType::get(
|
||||
batchedOperands[0].getType().cast<::mlir::RankedTensorType>().getShape(),
|
||||
getResult().getType());
|
||||
|
||||
::llvm::SmallVector<::mlir::Value> operands(batchedOperands);
|
||||
if (hoistedNonBatchableOperands.empty()) {
|
||||
return builder.create<BatchedMulGLWEIntOp>(
|
||||
mlir::TypeRange{resType},
|
||||
operands,
|
||||
getOperation()->getAttrs());
|
||||
} else {
|
||||
operands.append(hoistedNonBatchableOperands.begin(),
|
||||
hoistedNonBatchableOperands.end());
|
||||
return builder.create<BatchedMulGLWEIntCstOp>(
|
||||
mlir::TypeRange{resType},
|
||||
operands,
|
||||
getOperation()->getAttrs());
|
||||
}
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TFHE_BatchedKeySwitchGLWEOp : TFHE_Op<"batched_keyswitch_glwe", [Pure]> {
|
||||
|
||||
@@ -85,6 +85,54 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint32_t ksk_index,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_batched_add_lwe_ciphertexts_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
|
||||
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size0,
|
||||
uint64_t ct1_size1, uint64_t ct1_stride0, uint64_t ct1_stride1);
|
||||
|
||||
void memref_batched_add_plaintext_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
|
||||
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
|
||||
uint64_t ct1_stride);
|
||||
|
||||
void memref_batched_add_plaintext_cst_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t plaintext);
|
||||
|
||||
void memref_batched_mul_cleartext_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
|
||||
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
|
||||
uint64_t ct1_stride);
|
||||
|
||||
void memref_batched_mul_cleartext_cst_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t cleartext);
|
||||
|
||||
void memref_batched_negate_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1);
|
||||
|
||||
void memref_batched_keyswitch_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
|
||||
@@ -27,6 +27,18 @@ char memref_mul_cleartext_lwe_ciphertext_u64[] =
|
||||
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
|
||||
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
|
||||
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
|
||||
char memref_batched_add_lwe_ciphertexts_u64[] =
|
||||
"memref_batched_add_lwe_ciphertexts_u64";
|
||||
char memref_batched_add_plaintext_lwe_ciphertext_u64[] =
|
||||
"memref_batched_add_plaintext_lwe_ciphertext_u64";
|
||||
char memref_batched_add_plaintext_cst_lwe_ciphertext_u64[] =
|
||||
"memref_batched_add_plaintext_cst_lwe_ciphertext_u64";
|
||||
char memref_batched_mul_cleartext_lwe_ciphertext_u64[] =
|
||||
"memref_batched_mul_cleartext_lwe_ciphertext_u64";
|
||||
char memref_batched_mul_cleartext_cst_lwe_ciphertext_u64[] =
|
||||
"memref_batched_mul_cleartext_cst_lwe_ciphertext_u64";
|
||||
char memref_batched_negate_lwe_ciphertext_u64[] =
|
||||
"memref_batched_negate_lwe_ciphertext_u64";
|
||||
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
|
||||
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
|
||||
|
||||
@@ -129,6 +141,26 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{futureType});
|
||||
} else if (funcName == memref_batched_add_lwe_ciphertexts_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref2DType, memref2DType, memref2DType}, {});
|
||||
} else if (funcName == memref_batched_add_plaintext_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref2DType, memref2DType, memref1DType}, {});
|
||||
} else if (funcName == memref_batched_add_plaintext_cst_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref2DType, memref2DType, rewriter.getI64Type()}, {});
|
||||
} else if (funcName == memref_batched_mul_cleartext_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref2DType, memref2DType, memref1DType}, {});
|
||||
} else if (funcName == memref_batched_mul_cleartext_cst_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref2DType, memref2DType, rewriter.getI64Type()}, {});
|
||||
} else if (funcName == memref_batched_negate_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType}, {});
|
||||
} else if (funcName == memref_batched_keyswitch_lwe_u64 ||
|
||||
funcName == memref_batched_keyswitch_lwe_cuda_u64) {
|
||||
funcType =
|
||||
@@ -515,6 +547,26 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
|
||||
.add<ConcreteToCAPICallPattern<Concrete::EncodeLutForCrtWopPBSBufferOp,
|
||||
memref_encode_lut_for_crt_woppbs>>(
|
||||
&getContext(), encodeLutForWopPBSAddOperands);
|
||||
patterns
|
||||
.add<ConcreteToCAPICallPattern<Concrete::BatchedAddLweBufferOp,
|
||||
memref_batched_add_lwe_ciphertexts_u64>>(
|
||||
&getContext());
|
||||
patterns.add<ConcreteToCAPICallPattern<
|
||||
Concrete::BatchedAddPlaintextLweBufferOp,
|
||||
memref_batched_add_plaintext_lwe_ciphertext_u64>>(&getContext());
|
||||
patterns.add<ConcreteToCAPICallPattern<
|
||||
Concrete::BatchedAddPlaintextCstLweBufferOp,
|
||||
memref_batched_add_plaintext_cst_lwe_ciphertext_u64>>(&getContext());
|
||||
patterns.add<ConcreteToCAPICallPattern<
|
||||
Concrete::BatchedMulCleartextLweBufferOp,
|
||||
memref_batched_mul_cleartext_lwe_ciphertext_u64>>(&getContext());
|
||||
patterns.add<ConcreteToCAPICallPattern<
|
||||
Concrete::BatchedMulCleartextCstLweBufferOp,
|
||||
memref_batched_mul_cleartext_cst_lwe_ciphertext_u64>>(&getContext());
|
||||
patterns.add<
|
||||
ConcreteToCAPICallPattern<Concrete::BatchedNegateLweBufferOp,
|
||||
memref_batched_negate_lwe_ciphertext_u64>>(
|
||||
&getContext());
|
||||
if (gpu) {
|
||||
patterns.add<ConcreteToCAPICallPattern<Concrete::KeySwitchLweBufferOp,
|
||||
memref_keyswitch_lwe_cuda_u64>>(
|
||||
|
||||
@@ -784,8 +784,28 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
mlir::concretelang::Concrete::EncodeLutForCrtWopPBSTensorOp, true>,
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::EncodePlaintextWithCrtOp,
|
||||
mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>>(
|
||||
&getContext(), converter);
|
||||
mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>,
|
||||
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::ABatchedAddGLWEIntOp,
|
||||
mlir::concretelang::Concrete::BatchedAddPlaintextLweTensorOp>,
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::ABatchedAddGLWEIntCstOp,
|
||||
mlir::concretelang::Concrete::BatchedAddPlaintextCstLweTensorOp>,
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::ABatchedAddGLWEOp,
|
||||
mlir::concretelang::Concrete::BatchedAddLweTensorOp>,
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::BatchedMulGLWEIntOp,
|
||||
mlir::concretelang::Concrete::BatchedMulCleartextLweTensorOp>,
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::BatchedMulGLWEIntCstOp,
|
||||
mlir::concretelang::Concrete::BatchedMulCleartextCstLweTensorOp>,
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::BatchedNegGLWEOp,
|
||||
mlir::concretelang::Concrete::BatchedNegateLweTensorOp>
|
||||
|
||||
>(&getContext(), converter);
|
||||
// pattern of remaining TFHE ops
|
||||
|
||||
patterns.insert<ZeroOpPattern<mlir::concretelang::TFHE::ZeroGLWEOp>,
|
||||
|
||||
@@ -123,6 +123,34 @@ void mlir::concretelang::Concrete::
|
||||
// bootstrap_lwe_tensor => bootstrap_lwe_buffer
|
||||
Concrete::BootstrapLweTensorOp::attachInterface<TensorToMemrefOp<
|
||||
Concrete::BootstrapLweTensorOp, Concrete::BootstrapLweBufferOp>>(*ctx);
|
||||
|
||||
// batched_add_lwe_tensor => batched_add_lwe_buffer
|
||||
Concrete::BatchedAddLweTensorOp::attachInterface<TensorToMemrefOp<
|
||||
Concrete::BatchedAddLweTensorOp, Concrete::BatchedAddLweBufferOp>>(
|
||||
*ctx);
|
||||
// batched_add_plaintext_lwe_tensor => batched_add_plaintext_lwe_buffer
|
||||
Concrete::BatchedAddPlaintextLweTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::BatchedAddPlaintextLweTensorOp,
|
||||
Concrete::BatchedAddPlaintextLweBufferOp>>(*ctx);
|
||||
// batched_add_plaintext_cst_lwe_tensor =>
|
||||
// batched_add_plaintext_cst_lwe_buffer
|
||||
Concrete::BatchedAddPlaintextCstLweTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::BatchedAddPlaintextCstLweTensorOp,
|
||||
Concrete::BatchedAddPlaintextCstLweBufferOp>>(*ctx);
|
||||
// batched_mul_cleartext_lwe_tensor => batched_mul_cleartext_lwe_buffer
|
||||
Concrete::BatchedMulCleartextLweTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::BatchedMulCleartextLweTensorOp,
|
||||
Concrete::BatchedMulCleartextLweBufferOp>>(*ctx);
|
||||
// batched_mul_cleartext_cst_lwe_tensor =>
|
||||
// batched_mul_cleartext_cst_lwe_buffer
|
||||
Concrete::BatchedMulCleartextCstLweTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::BatchedMulCleartextCstLweTensorOp,
|
||||
Concrete::BatchedMulCleartextCstLweBufferOp>>(*ctx);
|
||||
// batched_negate_lwe_tensor => batched_negate_lwe_buffer
|
||||
Concrete::BatchedNegateLweTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::BatchedNegateLweTensorOp,
|
||||
Concrete::BatchedNegateLweBufferOp>>(*ctx);
|
||||
|
||||
// batched_keyswitch_lwe_tensor => batched_keyswitch_lwe_buffer
|
||||
Concrete::BatchedKeySwitchLweTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::BatchedKeySwitchLweTensorOp,
|
||||
|
||||
@@ -517,6 +517,102 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
output_dimension);
|
||||
}
|
||||
|
||||
void memref_batched_add_lwe_ciphertexts_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
|
||||
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size0,
|
||||
uint64_t ct1_size1, uint64_t ct1_stride0, uint64_t ct1_stride1) {
|
||||
for (size_t i = 0; i < ct0_size0; i++) {
|
||||
memref_add_lwe_ciphertexts_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
|
||||
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1,
|
||||
ct1_allocated + i * ct1_size1, ct1_aligned + i * ct1_size1, ct1_offset,
|
||||
ct1_size1, ct1_stride1);
|
||||
}
|
||||
}
|
||||
|
||||
void memref_batched_add_plaintext_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
|
||||
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
|
||||
uint64_t ct1_stride) {
|
||||
for (size_t i = 0; i < ct0_size0; i++) {
|
||||
memref_add_plaintext_lwe_ciphertext_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
|
||||
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1,
|
||||
*(ct1_aligned + ct1_offset + i * ct1_stride));
|
||||
}
|
||||
}
|
||||
|
||||
void memref_batched_add_plaintext_cst_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t plaintext) {
|
||||
for (size_t i = 0; i < ct0_size0; i++) {
|
||||
memref_add_plaintext_lwe_ciphertext_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
|
||||
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1,
|
||||
plaintext);
|
||||
}
|
||||
}
|
||||
|
||||
void memref_batched_mul_cleartext_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *ct1_allocated,
|
||||
uint64_t *ct1_aligned, uint64_t ct1_offset, uint64_t ct1_size,
|
||||
uint64_t ct1_stride) {
|
||||
for (size_t i = 0; i < ct0_size0; i++) {
|
||||
memref_mul_cleartext_lwe_ciphertext_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
|
||||
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1,
|
||||
*(ct1_aligned + ct1_offset + i * ct1_stride));
|
||||
}
|
||||
}
|
||||
|
||||
void memref_batched_mul_cleartext_cst_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t cleartext) {
|
||||
for (size_t i = 0; i < ct0_size0; i++) {
|
||||
memref_mul_cleartext_lwe_ciphertext_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
|
||||
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1,
|
||||
cleartext);
|
||||
}
|
||||
}
|
||||
|
||||
void memref_batched_negate_lwe_ciphertext_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1) {
|
||||
for (size_t i = 0; i < ct0_size0; i++) {
|
||||
memref_negate_lwe_ciphertext_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
|
||||
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1);
|
||||
}
|
||||
}
|
||||
|
||||
void memref_batched_keyswitch_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
|
||||
Reference in New Issue
Block a user