mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(compiler): Add fallback implementations for batched keyswitch and bootstrap
Add default implementations for batched keyswitch and bootstrap, which simply call the scalar versions of these operations in a loop.
This commit is contained in:
@@ -65,6 +65,16 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_batched_keyswitch_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
|
||||
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void *memref_keyswitch_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
@@ -81,6 +91,17 @@ void memref_bootstrap_lwe_u64(
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_batched_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
|
||||
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
|
||||
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void *memref_bootstrap_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
|
||||
@@ -73,7 +73,9 @@ char memref_mul_cleartext_lwe_ciphertext_u64[] =
|
||||
"memref_mul_cleartext_lwe_ciphertext_u64";
|
||||
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
|
||||
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
|
||||
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
|
||||
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
|
||||
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
|
||||
char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
|
||||
char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
|
||||
char memref_await_future[] = "memref_await_future";
|
||||
@@ -117,6 +119,11 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
{memref1DType, memref1DType, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_batched_keyswitch_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_bootstrap_lwe_u64 ||
|
||||
funcName == memref_bootstrap_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
@@ -124,6 +131,12 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_batched_bootstrap_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_keyswitch_async_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref1DType, memref1DType, contextType},
|
||||
@@ -431,6 +444,26 @@ void pushAdditionalArgs(BConcrete::KeySwitchLweBufferOp op,
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::BatchedKeySwitchLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// lwe_dim_in
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.lwe_dim_inAttr()));
|
||||
// lwe_dim_out
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.lwe_dim_outAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
@@ -457,6 +490,32 @@ void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op,
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::BatchedBootstrapLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// input_lwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.inputLweDimAttr()));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.glweDimensionAttr()));
|
||||
// out_precision
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outPrecisionAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::KeySwitchLweBufferAsyncOffloadOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
@@ -556,9 +615,17 @@ void mlir::concretelang::BConcrete::
|
||||
BConcrete::KeySwitchLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
|
||||
memref_keyswitch_lwe_u64>>(*ctx);
|
||||
BConcrete::BatchedKeySwitchLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<
|
||||
BConcrete::BatchedKeySwitchLweBufferOp,
|
||||
memref_batched_keyswitch_lwe_u64>>(*ctx);
|
||||
BConcrete::BootstrapLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
|
||||
memref_bootstrap_lwe_u64>>(*ctx);
|
||||
BConcrete::BatchedBootstrapLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<
|
||||
BConcrete::BatchedBootstrapLweBufferOp,
|
||||
memref_batched_bootstrap_lwe_u64>>(*ctx);
|
||||
}
|
||||
BConcrete::WopPBSCRTLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
|
||||
|
||||
@@ -336,6 +336,23 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset));
|
||||
}
|
||||
|
||||
void memref_batched_keyswitch_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
|
||||
uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
for (size_t i = 0; i < ct0_size0; i++) {
|
||||
memref_keyswitch_lwe_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated + i * ct0_size1,
|
||||
ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1, level,
|
||||
base_log, input_lwe_dim, output_lwe_dim, context);
|
||||
}
|
||||
}
|
||||
|
||||
void memref_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
@@ -367,6 +384,27 @@ void memref_bootstrap_lwe_u64(
|
||||
free(glwe_ct);
|
||||
}
|
||||
|
||||
void memref_batched_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
|
||||
uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
|
||||
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
|
||||
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
|
||||
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
|
||||
for (size_t i = 0; i < out_size0; i++) {
|
||||
memref_bootstrap_lwe_u64(
|
||||
out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
|
||||
out_size1, out_stride1, ct0_allocated, ct0_aligned + i * ct0_size1,
|
||||
ct0_offset, ct0_size1, ct0_stride1, tlu_allocated, tlu_aligned,
|
||||
tlu_offset, tlu_size, tlu_stride, input_lwe_dim, poly_size, level,
|
||||
base_log, glwe_dim, precision, context);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) {
|
||||
return concretelang::clientlib::crt::encode(plaintext, modulus, product);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user