mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 20:55:02 -05:00
feat(compiler): Handle batched operators for gpu codegen
This commit is contained in:
@@ -34,6 +34,10 @@ char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
|
||||
char memref_await_future[] = "memref_await_future";
|
||||
char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64";
|
||||
char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64";
|
||||
char memref_batched_keyswitch_lwe_cuda_u64[] =
|
||||
"memref_batched_keyswitch_lwe_cuda_u64";
|
||||
char memref_batched_bootstrap_lwe_cuda_u64[] =
|
||||
"memref_batched_bootstrap_lwe_cuda_u64";
|
||||
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64";
|
||||
|
||||
@@ -116,12 +120,14 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{futureType});
|
||||
} else if (funcName == memref_batched_keyswitch_lwe_u64) {
|
||||
} else if (funcName == memref_batched_keyswitch_lwe_u64 ||
|
||||
funcName == memref_batched_keyswitch_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_batched_bootstrap_lwe_u64) {
|
||||
} else if (funcName == memref_batched_bootstrap_lwe_u64 ||
|
||||
funcName == memref_batched_bootstrap_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
@@ -335,6 +341,16 @@ struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
|
||||
patterns.add<BConcreteToCAPICallPattern<BConcrete::BootstrapLweBufferOp,
|
||||
memref_bootstrap_lwe_cuda_u64>>(
|
||||
&getContext(), bootstrapAddOperands<BConcrete::BootstrapLweBufferOp>);
|
||||
patterns.add<
|
||||
BConcreteToCAPICallPattern<BConcrete::BatchedKeySwitchLweBufferOp,
|
||||
memref_batched_keyswitch_lwe_cuda_u64>>(
|
||||
&getContext(),
|
||||
keyswitchAddOperands<BConcrete::BatchedKeySwitchLweBufferOp>);
|
||||
patterns.add<
|
||||
BConcreteToCAPICallPattern<BConcrete::BatchedBootstrapLweBufferOp,
|
||||
memref_batched_bootstrap_lwe_cuda_u64>>(
|
||||
&getContext(),
|
||||
bootstrapAddOperands<BConcrete::BatchedBootstrapLweBufferOp>);
|
||||
} else {
|
||||
patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
|
||||
memref_keyswitch_lwe_u64>>(
|
||||
|
||||
Reference in New Issue
Block a user