diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index 8d8011403..8242b1307 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -65,6 +65,16 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned, uint32_t level, uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim, mlir::concretelang::RuntimeContext *context); + +void memref_batched_keyswitch_lwe_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0, + uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned, + uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1, + uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level, + uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim, + mlir::concretelang::RuntimeContext *context); + void *memref_keyswitch_async_lwe_u64( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, @@ -81,6 +91,17 @@ void memref_bootstrap_lwe_u64( uint32_t base_log, uint32_t glwe_dim, uint32_t precision, mlir::concretelang::RuntimeContext *context); +void memref_batched_bootstrap_lwe_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0, + uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned, + uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1, + uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated, + uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size, + uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size, + uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision, + mlir::concretelang::RuntimeContext *context); + void *memref_bootstrap_async_lwe_u64( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, diff --git a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp index bf3c9a6c4..372bc6a7a 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp @@ -73,7 +73,9 @@ char memref_mul_cleartext_lwe_ciphertext_u64[] = "memref_mul_cleartext_lwe_ciphertext_u64"; char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64"; char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64"; +char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64"; char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64"; +char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64"; char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64"; char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64"; char memref_await_future[] = "memref_await_future"; @@ -117,6 +119,11 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( {memref1DType, memref1DType, i32Type, i32Type, i32Type, i32Type, contextType}, {}); + } else if (funcName == memref_batched_keyswitch_lwe_u64) { + funcType = mlir::FunctionType::get(rewriter.getContext(), + {memref2DType, memref2DType, i32Type, + i32Type, i32Type, i32Type, contextType}, + {}); } else if (funcName == memref_bootstrap_lwe_u64 || funcName == memref_bootstrap_lwe_cuda_u64) { funcType = mlir::FunctionType::get(rewriter.getContext(), @@ -124,6 +131,12 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( memref1DType, i32Type, i32Type, i32Type, i32Type, i32Type, i32Type, contextType}, {}); + } else if (funcName == memref_batched_bootstrap_lwe_u64) { + funcType = mlir::FunctionType::get(rewriter.getContext(), + {memref2DType, memref2DType, + memref1DType, i32Type, i32Type, i32Type, + i32Type, i32Type, i32Type, contextType}, + {}); } else if (funcName == memref_keyswitch_async_lwe_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, contextType}, @@ -431,6 +444,26 @@ void pushAdditionalArgs(BConcrete::KeySwitchLweBufferOp op, operands.push_back(getContextArgument(op)); } +template <> +void pushAdditionalArgs(BConcrete::BatchedKeySwitchLweBufferOp op, + mlir::SmallVector &operands, + RewriterBase &rewriter) { + // level + operands.push_back( + rewriter.create(op.getLoc(), op.levelAttr())); + // base_log + operands.push_back( + rewriter.create(op.getLoc(), op.baseLogAttr())); + // lwe_dim_in + operands.push_back(rewriter.create( + op.getLoc(), op.lwe_dim_inAttr())); + // lwe_dim_out + operands.push_back(rewriter.create( + op.getLoc(), op.lwe_dim_outAttr())); + // context + operands.push_back(getContextArgument(op)); +} + template <> void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op, mlir::SmallVector &operands, @@ -457,6 +490,32 @@ void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op, operands.push_back(getContextArgument(op)); } +template <> +void pushAdditionalArgs(BConcrete::BatchedBootstrapLweBufferOp op, + mlir::SmallVector &operands, + RewriterBase &rewriter) { + // input_lwe_dim + operands.push_back(rewriter.create( + op.getLoc(), op.inputLweDimAttr())); + // poly_size + operands.push_back( + rewriter.create(op.getLoc(), op.polySizeAttr())); + // level + operands.push_back( + rewriter.create(op.getLoc(), op.levelAttr())); + // base_log + operands.push_back( + rewriter.create(op.getLoc(), op.baseLogAttr())); + // glwe_dim + operands.push_back(rewriter.create( + op.getLoc(), op.glweDimensionAttr())); + // out_precision + operands.push_back(rewriter.create( + op.getLoc(), op.outPrecisionAttr())); + // context + operands.push_back(getContextArgument(op)); +} + template <> void pushAdditionalArgs(BConcrete::KeySwitchLweBufferAsyncOffloadOp op, mlir::SmallVector &operands, @@ -556,9 +615,17 @@ void mlir::concretelang::BConcrete:: BConcrete::KeySwitchLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); + BConcrete::BatchedKeySwitchLweBufferOp::attachInterface< + BufferizableWithCallOpInterface< + BConcrete::BatchedKeySwitchLweBufferOp, + memref_batched_keyswitch_lwe_u64>>(*ctx); BConcrete::BootstrapLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); + BConcrete::BatchedBootstrapLweBufferOp::attachInterface< + BufferizableWithCallOpInterface< + BConcrete::BatchedBootstrapLweBufferOp, + memref_batched_bootstrap_lwe_u64>>(*ctx); } BConcrete::WopPBSCRTLweBufferOp::attachInterface< BufferizableWithCallOpInterface