feat(compiler): Add fallback implementations for batched keyswitch and bootstrap

Add default implementations for batched keyswitch and bootstrap, which
simply call the scalar versions of these operations in a loop.
Author: Andi Drebes
Date:   2022-11-15 16:53:06 +01:00
Parent: 9f153d2129
Commit: 46366eec41
3 changed files with 126 additions and 0 deletions
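The fallback is the standard batched-to-scalar pattern: iterate over the rows of a contiguous 2D batch and hand each row to the existing scalar kernel. A minimal self-contained sketch of the idea (toy scalar_op/batched_op names, not the functions added by this commit):

#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in for a scalar operation such as keyswitch or bootstrap:
// processes one ciphertext of `size` words.
static void scalar_op(uint64_t *out, const uint64_t *in, uint64_t size) {
  for (uint64_t j = 0; j < size; j++)
    out[j] = in[j] + 1;
}

// Batched fallback: each of the `batch` densely packed rows is forwarded
// to the scalar version, exactly as the loops in this commit do.
static void batched_op(uint64_t *out, const uint64_t *in, uint64_t batch,
                       uint64_t size) {
  for (uint64_t i = 0; i < batch; i++)
    scalar_op(out + i * size, in + i * size, size);
}

int main() {
  const uint64_t batch = 3, size = 4;
  std::vector<uint64_t> in(batch * size, 41), out(batch * size);
  batched_op(out.data(), in.data(), batch, size);
  std::printf("%llu\n", (unsigned long long)out[5]); // prints 42
  return 0;
}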


@@ -65,6 +65,16 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
                              uint32_t level, uint32_t base_log,
                              uint32_t input_lwe_dim, uint32_t output_lwe_dim,
                              mlir::concretelang::RuntimeContext *context);
void memref_batched_keyswitch_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
    uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
    mlir::concretelang::RuntimeContext *context);
void *memref_keyswitch_async_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
@@ -81,6 +91,17 @@ void memref_bootstrap_lwe_u64(
    uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
    mlir::concretelang::RuntimeContext *context);
void memref_batched_bootstrap_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
    uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
    uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
    uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
    mlir::concretelang::RuntimeContext *context);
void *memref_bootstrap_async_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
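In these C wrappers, every memref argument arrives flattened into individual parameters following MLIR's memref descriptor layout: a 1D memref expands to (allocated, aligned, offset, size, stride) and a 2D memref to (allocated, aligned, offset, size0, size1, stride0, stride1), which is why the batched variants gain the extra size/stride pair. Grouping the flat lists back into structs, purely as an illustration:

#include <cstdint>

// Illustration only: how the flat parameter lists above group together.
// The wrappers receive these fields as individual arguments, not structs.
struct MemRef1D {
  uint64_t *allocated; // pointer returned by the allocator
  uint64_t *aligned;   // pointer actually used for element access
  uint64_t offset;     // element offset of the first element
  uint64_t size;       // number of elements
  uint64_t stride;     // step between consecutive elements
};

struct MemRef2D {
  uint64_t *allocated;
  uint64_t *aligned;
  uint64_t offset;
  uint64_t sizes[2];   // sizes[0] = batch count, sizes[1] = ciphertext length
  uint64_t strides[2]; // strides[0] = row pitch, strides[1] = element step
};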


@@ -73,7 +73,9 @@ char memref_mul_cleartext_lwe_ciphertext_u64[] =
"memref_mul_cleartext_lwe_ciphertext_u64";
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
char memref_await_future[] = "memref_await_future";
@@ -117,6 +119,11 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
                                       {memref1DType, memref1DType, i32Type,
                                        i32Type, i32Type, i32Type, contextType},
                                       {});
  } else if (funcName == memref_batched_keyswitch_lwe_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {memref2DType, memref2DType, i32Type,
                                        i32Type, i32Type, i32Type, contextType},
                                       {});
  } else if (funcName == memref_bootstrap_lwe_u64 ||
             funcName == memref_bootstrap_lwe_cuda_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
@@ -124,6 +131,12 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
                                        memref1DType, i32Type, i32Type, i32Type,
                                        i32Type, i32Type, i32Type, contextType},
                                       {});
  } else if (funcName == memref_batched_bootstrap_lwe_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {memref2DType, memref2DType,
                                        memref1DType, i32Type, i32Type, i32Type,
                                        i32Type, i32Type, i32Type, contextType},
                                       {});
  } else if (funcName == memref_keyswitch_async_lwe_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(), {memref1DType, memref1DType, contextType},
@@ -431,6 +444,26 @@ void pushAdditionalArgs(BConcrete::KeySwitchLweBufferOp op,
  operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::BatchedKeySwitchLweBufferOp op,
                        mlir::SmallVector<mlir::Value> &operands,
                        RewriterBase &rewriter) {
  // level
  operands.push_back(
      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
  // base_log
  operands.push_back(
      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
  // lwe_dim_in
  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
      op.getLoc(), op.lwe_dim_inAttr()));
  // lwe_dim_out
  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
      op.getLoc(), op.lwe_dim_outAttr()));
  // context
  operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op,
                        mlir::SmallVector<mlir::Value> &operands,
@@ -457,6 +490,32 @@ void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op,
  operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::BatchedBootstrapLweBufferOp op,
                        mlir::SmallVector<mlir::Value> &operands,
                        RewriterBase &rewriter) {
  // input_lwe_dim
  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
      op.getLoc(), op.inputLweDimAttr()));
  // poly_size
  operands.push_back(
      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
  // level
  operands.push_back(
      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
  // base_log
  operands.push_back(
      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
  // glwe_dim
  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
      op.getLoc(), op.glweDimensionAttr()));
  // out_precision
  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
      op.getLoc(), op.outPrecisionAttr()));
  // context
  operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::KeySwitchLweBufferAsyncOffloadOp op,
                        mlir::SmallVector<mlir::Value> &operands,
@@ -556,9 +615,17 @@ void mlir::concretelang::BConcrete::
    BConcrete::KeySwitchLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
                                        memref_keyswitch_lwe_u64>>(*ctx);
    BConcrete::BatchedKeySwitchLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<
            BConcrete::BatchedKeySwitchLweBufferOp,
            memref_batched_keyswitch_lwe_u64>>(*ctx);
    BConcrete::BootstrapLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
                                        memref_bootstrap_lwe_u64>>(*ctx);
    BConcrete::BatchedBootstrapLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<
            BConcrete::BatchedBootstrapLweBufferOp,
            memref_batched_bootstrap_lwe_u64>>(*ctx);
  }
  BConcrete::WopPBSCRTLweBufferOp::attachInterface<
      BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
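This file wires the new ops into the existing lowering plumbing: insertForwardDeclarationOfTheCAPI declares the new C symbols with 2D memref operand types, pushAdditionalArgs appends each op's scalar attributes as arith.constant operands after the memrefs, and attachInterface registers the bufferization-to-call rewrite under the matching symbol name. Per-op behavior is selected with explicit template specializations; a self-contained sketch of that dispatch idiom (toy op types and value representation, not the compiler's):

#include <cstdint>
#include <cstdio>
#include <vector>

// Toy op types standing in for the BConcrete buffer ops.
struct KeySwitchOp { uint32_t level, base_log; };
struct BootstrapOp { uint32_t poly_size, glwe_dim; };

// Primary template: one specialization per op type appends the scalar
// attributes that the corresponding C wrapper expects after the memrefs.
template <typename Op>
void pushAdditionalArgs(const Op &op, std::vector<uint64_t> &operands);

template <>
void pushAdditionalArgs(const KeySwitchOp &op,
                        std::vector<uint64_t> &operands) {
  operands.push_back(op.level);
  operands.push_back(op.base_log);
}

template <>
void pushAdditionalArgs(const BootstrapOp &op,
                        std::vector<uint64_t> &operands) {
  operands.push_back(op.poly_size);
  operands.push_back(op.glwe_dim);
}

int main() {
  std::vector<uint64_t> operands;
  pushAdditionalArgs(KeySwitchOp{3, 7}, operands);
  std::printf("%zu operands\n", operands.size()); // prints "2 operands"
  return 0;
}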


@@ -336,6 +336,23 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
      out_aligned + out_offset, ct0_aligned + ct0_offset));
}
void memref_batched_keyswitch_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint32_t level,
    uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim,
    mlir::concretelang::RuntimeContext *context) {
  // Fallback: key-switch each ciphertext of the batch individually. Rows are
  // assumed to be densely packed, so the i-th ciphertext starts size1 words
  // after the previous one.
  for (size_t i = 0; i < ct0_size0; i++) {
    memref_keyswitch_lwe_u64(
        out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
        out_size1, out_stride1, ct0_allocated + i * ct0_size1,
        ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1, level,
        base_log, input_lwe_dim, output_lwe_dim, context);
  }
}
void memref_bootstrap_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
@@ -367,6 +384,27 @@ void memref_bootstrap_lwe_u64(
  free(glwe_ct);
}
void memref_batched_bootstrap_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
    uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
    uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
    uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
    mlir::concretelang::RuntimeContext *context) {
  // Fallback: bootstrap each ciphertext of the batch individually, reusing
  // the same test vector (tlu) for every element. As in the key-switch
  // fallback, rows are assumed to be densely packed.
  for (size_t i = 0; i < out_size0; i++) {
    memref_bootstrap_lwe_u64(
        out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
        out_size1, out_stride1, ct0_allocated + i * ct0_size1,
        ct0_aligned + i * ct0_size1, ct0_offset, ct0_size1, ct0_stride1,
        tlu_allocated, tlu_aligned, tlu_offset, tlu_size, tlu_stride,
        input_lwe_dim, poly_size, level, base_log, glwe_dim, precision,
        context);
  }
}
uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) {
  return concretelang::clientlib::crt::encode(plaintext, modulus, product);
}
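One caveat worth noting about both fallbacks: they compute each row's base address as aligned + i * size1 and never consult stride0, so they implicitly assume densely packed rows (stride0 == size1) with unit inner stride. That presumably holds for the contiguous buffers these lowered ops receive; a fully general variant would advance by the row stride instead. A sketch of the distinction, with hypothetical helper names:

#include <cstdint>

// Hypothetical helpers, not part of the runtime. The fallbacks above
// effectively use row_packed(); the memref offsets are passed through
// separately and applied inside the scalar kernels.
inline uint64_t *row_packed(uint64_t *aligned, uint64_t i, uint64_t size1) {
  return aligned + i * size1; // valid only when stride0 == size1
}

inline uint64_t *row_strided(uint64_t *aligned, uint64_t i, uint64_t stride0) {
  return aligned + i * stride0; // correct for any row pitch
}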