feat(compiler): Handle batched operators for gpu codegen

This commit is contained in:
Quentin Bourgerie
2022-11-29 14:36:24 +01:00
parent 312c9281eb
commit 3c616af622
3 changed files with 194 additions and 206 deletions

View File

@@ -34,6 +34,10 @@ char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
char memref_await_future[] = "memref_await_future";
char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64";
char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64";
char memref_batched_keyswitch_lwe_cuda_u64[] =
"memref_batched_keyswitch_lwe_cuda_u64";
char memref_batched_bootstrap_lwe_cuda_u64[] =
"memref_batched_bootstrap_lwe_cuda_u64";
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
"memref_expand_lut_in_trivial_glwe_ct_u64";
@@ -116,12 +120,14 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{futureType});
} else if (funcName == memref_batched_keyswitch_lwe_u64) {
} else if (funcName == memref_batched_keyswitch_lwe_u64 ||
funcName == memref_batched_keyswitch_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_batched_bootstrap_lwe_u64) {
} else if (funcName == memref_batched_bootstrap_lwe_u64 ||
funcName == memref_batched_bootstrap_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType,
memref1DType, i32Type, i32Type, i32Type,
@@ -335,6 +341,16 @@ struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
patterns.add<BConcreteToCAPICallPattern<BConcrete::BootstrapLweBufferOp,
memref_bootstrap_lwe_cuda_u64>>(
&getContext(), bootstrapAddOperands<BConcrete::BootstrapLweBufferOp>);
patterns.add<
BConcreteToCAPICallPattern<BConcrete::BatchedKeySwitchLweBufferOp,
memref_batched_keyswitch_lwe_cuda_u64>>(
&getContext(),
keyswitchAddOperands<BConcrete::BatchedKeySwitchLweBufferOp>);
patterns.add<
BConcreteToCAPICallPattern<BConcrete::BatchedBootstrapLweBufferOp,
memref_batched_bootstrap_lwe_cuda_u64>>(
&getContext(),
bootstrapAddOperands<BConcrete::BatchedBootstrapLweBufferOp>);
} else {
patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
memref_keyswitch_lwe_u64>>(