From 20394368bfd90651c439cf70a9987723402c9a64 Mon Sep 17 00:00:00 2001
From: Antoniu Pop
Date: Sun, 16 Apr 2023 20:38:44 +0100
Subject: [PATCH] feat(compiler): add lowering of batched mapped bootstrap
 operations to wrappers and SDFG, with support in the runtime.

---
 .../Dialect/Concrete/IR/ConcreteOps.td        | 34 ++++++++
 .../concretelang/Dialect/SDFG/IR/SDFGOps.td   |  4 +-
 .../Runtime/stream_emulator_api.h             | 16 ++--
 .../ConcreteToCAPI/ConcreteToCAPI.cpp         | 21 +++++
 .../SDFGToStreamEmulator.cpp                  |  7 ++
 .../TFHEToConcrete/TFHEToConcrete.cpp         | 40 +++++++++-
 .../BufferizableOpInterfaceImpl.cpp           |  5 ++
 .../compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp  |  2 +
 .../SDFGConvertibleOpInterfaceImpl.cpp        |  5 ++
 .../compiler/lib/Runtime/GPUDFG.cpp           | 77 +++++++++++++------
 .../compiler/lib/Runtime/wrappers.cpp         |  6 +-
 .../concretelang/SDFG/SDFG_unit_tests.cpp     | 30 ++++++++
 12 files changed, 212 insertions(+), 35 deletions(-)

diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td
index 56e9f9201..c93d9dc7d 100644
--- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td
+++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td
@@ -27,6 +27,7 @@ def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
 def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
 def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
 def Concrete_BatchPlaintextTensor : 1DTensorOf<[I64]>;
+def Concrete_BatchLutTensor : 2DTensorOf<[I64]>;
 
 def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
 def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
@@ -35,6 +36,7 @@ def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
 def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
 def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
 def Concrete_BatchPlaintextBuffer : MemRefRankOf<[I64], [1]>;
+def Concrete_BatchLutBuffer : MemRefRankOf<[I64], [2]>;
 
 class Concrete_Op<string mnemonic, list<Trait> traits = []> :
     Op<Concrete_Dialect, mnemonic, traits>;
@@ -359,6 +361,38 @@ def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_bu
   );
 }
 
+def Concrete_BatchedMappedBootstrapLweTensorOp : Concrete_Op<"batched_mapped_bootstrap_lwe_tensor", [Pure]> {
+  let summary = "Batched, mapped version of BootstrapLweOp: bootstraps each ciphertext in the batch with its own lookup table";
+
+  let arguments = (ins
+    Concrete_BatchLweTensor:$input_ciphertext,
+    Concrete_BatchLutTensor:$lookup_table_vector,
+    I32Attr:$inputLweDim,
+    I32Attr:$polySize,
+    I32Attr:$level,
+    I32Attr:$baseLog,
+    I32Attr:$glweDimension,
+    I32Attr:$bskIndex
+  );
+  let results = (outs Concrete_BatchLweTensor:$result);
+}
+
+def Concrete_BatchedMappedBootstrapLweBufferOp : Concrete_Op<"batched_mapped_bootstrap_lwe_buffer"> {
+  let summary = "Batched, mapped version of BootstrapLweOp: bootstraps each ciphertext in the batch with its own lookup table";
+
+  let arguments = (ins
+    Concrete_BatchLweBuffer:$result,
+    Concrete_BatchLweBuffer:$input_ciphertext,
+    Concrete_BatchLutBuffer:$lookup_table_vector,
+    I32Attr:$inputLweDim,
+    I32Attr:$polySize,
+    I32Attr:$level,
+    I32Attr:$baseLog,
+    I32Attr:$glweDimension,
+    I32Attr:$bskIndex
+  );
+}
+
 def Concrete_KeySwitchLweTensorOp : Concrete_Op<"keyswitch_lwe_tensor", [Pure]> {
   let summary = "Keyswitches an LWE ciphertext";
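Note: the two new ops differ from their non-mapped counterparts only in the LUT
operand. batched_bootstrap_lwe applies one shared table to every ciphertext;
the mapped variant applies row i of a 2-D LUT tensor to ciphertext i. A
C++-style sketch of the intended semantics (the helper bootstrap_one is
hypothetical, for illustration only):

    // Shapes: input_ciphertext[N][input_lwe_dim + 1],
    //         lookup_table_vector[N][poly_size],
    //         result[N][glwe_dim * poly_size + 1].
    for (size_t i = 0; i < N; ++i)
      result[i] = bootstrap_one(input_ciphertext[i], lookup_table_vector[i]);
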
diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td
index 44904e492..de7317c4b 100644
--- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td
+++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td
@@ -97,6 +97,7 @@ def ProcessKindBatchMulEintIntCst : I32EnumAttrCase<"batched_mul_eint_int_cst",
 def ProcessKindBatchNegEint : I32EnumAttrCase<"batched_neg_eint", 11>;
 def ProcessKindBatchKeyswitch : I32EnumAttrCase<"batched_keyswitch", 12>;
 def ProcessKindBatchBootstrap : I32EnumAttrCase<"batched_bootstrap", 13>;
+def ProcessKindBatchMapBootstrap : I32EnumAttrCase<"batched_mapped_bootstrap", 14>;
 
 def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind",
     [ProcessKindAddEint, ProcessKindAddEintInt, ProcessKindMulEintInt,
@@ -104,7 +105,8 @@ def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind",
      ProcessKindBatchAddEint, ProcessKindBatchAddEintInt,
      ProcessKindBatchAddEintIntCst, ProcessKindBatchMulEintInt,
      ProcessKindBatchMulEintIntCst, ProcessKindBatchNegEint,
-     ProcessKindBatchKeyswitch, ProcessKindBatchBootstrap]> {
+     ProcessKindBatchKeyswitch, ProcessKindBatchBootstrap,
+     ProcessKindBatchMapBootstrap]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::concretelang::SDFG";
 }
diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h
index ce9693e85..04ad646e8 100644
--- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h
+++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h
@@ -36,12 +36,12 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
                                                                    void *sout);
 void stream_emulator_make_memref_keyswitch_lwe_u64_process(
     void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
-    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
-    uint32_t ksk_index, void *context);
+    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
+    uint32_t output_size, void *context);
 void stream_emulator_make_memref_bootstrap_lwe_u64_process(
     void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
     uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
-    uint32_t output_size, uint32_t bsk_index, void *context);
+    uint32_t bsk_index, uint32_t output_size, void *context);
 
 void stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process(
     void *dfg, void *sin1, void *sin2, void *sout);
@@ -57,12 +57,16 @@ void stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process(
     void *dfg, void *sin1, void *sout);
 void stream_emulator_make_memref_batched_keyswitch_lwe_u64_process(
     void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
-    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
-    void *context);
+    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
+    uint32_t output_size, void *context);
 void stream_emulator_make_memref_batched_bootstrap_lwe_u64_process(
     void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
     uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
-    uint32_t output_size, void *context);
+    uint32_t bsk_index, uint32_t output_size, void *context);
+void stream_emulator_make_memref_batched_mapped_bootstrap_lwe_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
+    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
+    uint32_t bsk_index, uint32_t output_size, void *context);
 
 void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
 void stream_emulator_put_uint64(void *stream, uint64_t e);
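Note the revised argument order above: the key index (ksk_index/bsk_index) now
precedes output_size in every process-maker signature. A hedged usage sketch,
where the dfg, stream and context handles are assumed to come from the
surrounding SDFG setup and the parameter values are placeholders (output_size
must equal glwe_dim * poly_size + 1, per the runtime assertion below):

    void build_mapped_pbs_process(void *dfg, void *ct_in, void *luts_in,
                                  void *ct_out, void *context) {
      stream_emulator_make_memref_batched_mapped_bootstrap_lwe_u64_process(
          dfg, ct_in, luts_in, ct_out, /*input_lwe_dim=*/600,
          /*poly_size=*/1024, /*level=*/3, /*base_log=*/4, /*glwe_dim=*/1,
          /*bsk_index=*/0, /*output_size=*/1025, context);
    }
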
diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp
index f56ff4f3c..5d02485f2 100644
--- a/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp
@@ -41,6 +41,8 @@ char memref_batched_negate_lwe_ciphertext_u64[] =
     "memref_batched_negate_lwe_ciphertext_u64";
 char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
 char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
+char memref_batched_mapped_bootstrap_lwe_u64[] =
+    "memref_batched_mapped_bootstrap_lwe_u64";
 char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
 char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
@@ -51,6 +53,8 @@ char memref_batched_keyswitch_lwe_cuda_u64[] =
     "memref_batched_keyswitch_lwe_cuda_u64";
 char memref_batched_bootstrap_lwe_cuda_u64[] =
     "memref_batched_bootstrap_lwe_cuda_u64";
+char memref_batched_mapped_bootstrap_lwe_cuda_u64[] =
+    "memref_batched_mapped_bootstrap_lwe_cuda_u64";
 char memref_expand_lut_in_trivial_glwe_ct_u64[] =
     "memref_expand_lut_in_trivial_glwe_ct_u64";
@@ -175,6 +179,13 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
                     memref1DType, i32Type, i32Type, i32Type, i32Type, i32Type,
                     i32Type, contextType},
         {});
+  } else if (funcName == memref_batched_mapped_bootstrap_lwe_u64 ||
+             funcName == memref_batched_mapped_bootstrap_lwe_cuda_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref2DType, memref2DType,
+                                        memref2DType, i32Type, i32Type, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
   } else if (funcName == memref_await_future) {
     funcType = mlir::FunctionType::get(
         rewriter.getContext(),
@@ -584,6 +595,11 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
                                  memref_batched_bootstrap_lwe_cuda_u64>>(
           &getContext(), bootstrapAddOperands);
+      patterns.add<ConcreteToCAPICallPattern<
+          Concrete::BatchedMappedBootstrapLweBufferOp,
+          memref_batched_mapped_bootstrap_lwe_cuda_u64>>(
+          &getContext(),
+          bootstrapAddOperands);
     } else {
       patterns.add<ConcreteToCAPICallPattern<Concrete::KeySwitchLweBufferOp,
                                              memref_keyswitch_lwe_u64>>(
           &getContext(),
@@ -601,6 +617,11 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
                                  memref_batched_bootstrap_lwe_u64>>(
           &getContext(), bootstrapAddOperands);
+      patterns.add<
+          ConcreteToCAPICallPattern<Concrete::BatchedMappedBootstrapLweBufferOp,
+                                    memref_batched_mapped_bootstrap_lwe_u64>>(
+          &getContext(),
+          bootstrapAddOperands);
     }
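The forward declaration registered above describes a C entry point with three
rank-2 memrefs (result, input ciphertexts, LUTs), six i32 scalars and the
runtime context. Under MLIR's standard memref-descriptor expansion, each rank-2
memref becomes seven scalar arguments. An assumed expansion follows; parameter
names are illustrative and the authoritative declaration lives with the runtime
wrappers:

    void memref_batched_mapped_bootstrap_lwe_u64(
        uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
        uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
        uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
        uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
        uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
        uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size0,
        uint64_t tlu_size1, uint64_t tlu_stride0, uint64_t tlu_stride1,
        uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
        uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
        mlir::concretelang::RuntimeContext *context);
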
diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp
--- a/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp
@@ ... @@
+struct BatchedMappedBootstrapGLWEOpPattern
+    : public mlir::OpConversionPattern<TFHE::BatchedMappedBootstrapGLWEOp> {
+
+  BatchedMappedBootstrapGLWEOpPattern(mlir::MLIRContext *context,
+                                      mlir::TypeConverter &typeConverter)
+      : mlir::OpConversionPattern<TFHE::BatchedMappedBootstrapGLWEOp>(
+            typeConverter, context,
+            mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(TFHE::BatchedMappedBootstrapGLWEOp bmbsOp,
+                  TFHE::BatchedMappedBootstrapGLWEOp::Adaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    TFHE::GLWECipherTextType inputElementType =
+        bmbsOp.getCiphertexts()
+            .getType()
+            .cast<mlir::RankedTensorType>()
+            .getElementType()
+            .cast<TFHE::GLWECipherTextType>();
+
+    auto polySize = adaptor.getKey().getPolySize();
+    auto glweDimension = adaptor.getKey().getGlweDim();
+    auto levels = adaptor.getKey().getLevels();
+    auto baseLog = adaptor.getKey().getBaseLog();
+    auto inputLweDimension =
+        inputElementType.getKey().getNormalized().value().dimension;
+    auto bskIndex = bmbsOp.getKeyAttr().getIndex();
+
+    rewriter.replaceOpWithNewOp<Concrete::BatchedMappedBootstrapLweTensorOp>(
+        bmbsOp, this->getTypeConverter()->convertType(bmbsOp.getType()),
+        adaptor.getCiphertexts(), adaptor.getLookupTable(), inputLweDimension,
+        polySize, levels, baseLog, glweDimension, bskIndex);
+
+    return mlir::success();
+  }
+};
+
 struct KeySwitchGLWEOpPattern
     : public mlir::OpConversionPattern<TFHE::KeySwitchGLWEOp> {
@@ -811,7 +848,8 @@ void TFHEToConcretePass::runOnOperation() {
   patterns.insert<ZeroOpPattern<TFHE::ZeroGLWEOp>,
                   ZeroOpPattern<TFHE::ZeroTensorGLWEOp>, SubIntGLWEOpPattern,
                   BootstrapGLWEOpPattern,
-                  BatchedBootstrapGLWEOpPattern, KeySwitchGLWEOpPattern,
+                  BatchedBootstrapGLWEOpPattern,
+                  BatchedMappedBootstrapGLWEOpPattern, KeySwitchGLWEOpPattern,
                   BatchedKeySwitchGLWEOpPattern, WopPBSGLWEOpPattern>(
       &getContext(), converter);
diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp
index 9e54bbdd4..e5c4a7446 100644
--- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -159,6 +159,11 @@ void mlir::concretelang::Concrete::
     Concrete::BatchedBootstrapLweTensorOp::attachInterface<
         TensorToMemrefOp<Concrete::BatchedBootstrapLweTensorOp,
                          Concrete::BatchedBootstrapLweBufferOp>>(*ctx);
+    // batched_mapped_bootstrap_lwe_tensor =>
+    // batched_mapped_bootstrap_lwe_buffer
+    Concrete::BatchedMappedBootstrapLweTensorOp::attachInterface<
+        TensorToMemrefOp<Concrete::BatchedMappedBootstrapLweTensorOp,
+                         Concrete::BatchedMappedBootstrapLweBufferOp>>(*ctx);
     // wop_pbs_crt_lwe_tensor => wop_pbs_crt_lwe_buffer
     Concrete::WopPBSCRTLweTensorOp::attachInterface<TensorToMemrefOp<
         Concrete::WopPBSCRTLweTensorOp, Concrete::WopPBSCRTLweBufferOp>>(*ctx);
diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp
index 2ec8b3164..1df90b477 100644
--- a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp
@@ -98,6 +98,8 @@ mlir::LogicalResult MakeProcess::verify() {
     return checkStreams(1, 1);
   case ProcessKind::batched_bootstrap:
     return checkStreams(2, 1);
+  case ProcessKind::batched_mapped_bootstrap:
+    return checkStreams(2, 1);
   }
 
   return mlir::failure();
diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp
index dbbb25dbe..0b1e38da9 100644
--- a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp
@@ -29,6 +29,7 @@ char batched_mul_eint_int_cst[] = "batched_mul_eint_int_cst";
 char batched_neg_eint[] = "batched_neg_eint";
 char batched_keyswitch[] = "batched_keyswitch";
 char batched_bootstrap[] = "batched_bootstrap";
+char batched_mapped_bootstrap[] = "batched_mapped_bootstrap";
 } // namespace
 
@@ -129,6 +130,10 @@ void registerSDFGConvertibleOpInterfaceExternalModels(
         ReplaceWithProcessSDFGConversionInterface<
             mlir::concretelang::Concrete::BatchedBootstrapLweTensorOp,
             batched_bootstrap, true>>(*ctx);
+    mlir::concretelang::Concrete::BatchedMappedBootstrapLweTensorOp::
+        attachInterface<ReplaceWithProcessSDFGConversionInterface<
+            mlir::concretelang::Concrete::BatchedMappedBootstrapLweTensorOp,
+            batched_mapped_bootstrap, true>>(*ctx);
   });
 }
 } // namespace SDFG
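For reference when reading the runtime changes below: the GPU bootstrap process
now trivially encrypts all LUTs into one contiguous buffer of GLWE
accumulators. A standalone sketch of that layout, with a hypothetical helper
name and the buffer sized as in the diff (num_lut_vectors * (glwe_dim + 1) *
poly_size words):

    // One accumulator per LUT: glwe_dim * poly_size zero mask words, then
    // poly_size body words holding the entries of LUT l.
    void pack_luts_into_glwe(uint64_t *glwe_ct, const uint64_t *tlu,
                             size_t num_lut_vectors, size_t glwe_dim,
                             size_t poly_size) {
      size_t pos = 0, postlu = 0;
      for (size_t l = 0; l < num_lut_vectors; ++l) {
        for (size_t i = 0; i < poly_size * glwe_dim; i++)
          glwe_ct[pos++] = 0;   // mask part (trivial encryption)
        for (size_t i = 0; i < poly_size; i++)
          glwe_ct[pos++] = tlu[postlu++];   // body part (LUT entries)
      }
    }
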
diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/GPUDFG.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/GPUDFG.cpp
index ee65a3fcc..835477b07 100644
--- a/compilers/concrete-compiler/compiler/lib/Runtime/GPUDFG.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Runtime/GPUDFG.cpp
@@ -184,13 +184,18 @@ struct Process {
   Param output_lwe_dim;
   Param poly_size;
   Param glwe_dim;
+  Param sk_index;
   Param output_size;
   Context ctx;
   void (*fun)(Process *);
   char name[80];
 };
 
-static inline void schedule_kernel(Process *p) { p->fun(p); }
+static inline void schedule_kernel(Process *p) {
+  std::cout << " Scheduling a " << p->name << " on GPU " << p->dfg->gpu_idx
+            << "\n";
+  p->fun(p);
+}
 
 struct Stream {
   stream_type type;
@@ -368,7 +373,7 @@ sdfg_gpu_debug_compare_memref(MemRef2 &a, MemRef2 &b, char const *msg) {
       a.strides[0] != b.strides[0] || a.strides[1] != b.strides[1])
     return false;
   size_t data_size = memref_get_data_size(a);
-  for (int i = 0; i < data_size / sizeof(uint64_t); ++i)
+  for (size_t i = 0; i < data_size / sizeof(uint64_t); ++i)
     if ((a.aligned + a.offset)[i] != (b.aligned + b.offset)[i]) {
       std::cout << msg << " - memrefs differ at position " << i << " "
                 << (a.aligned + a.offset)[i] << " " << (b.aligned + b.offset)[i]
@@ -380,6 +385,7 @@ sdfg_gpu_debug_compare_memref(MemRef2 &a, MemRef2 &b, char const *msg) {
 
 // Stream emulator processes
 void memref_keyswitch_lwe_u64_process(Process *p) {
+  assert(p->sk_index.val == 0 && "multiple ksk is not yet implemented on GPU");
   Dependence *idep = p->input_streams[0]->get(p->dfg->gpu_idx);
   uint64_t num_samples = idep->host_data.sizes[0];
   MemRef2 out = {
@@ -402,6 +408,7 @@ void memref_keyswitch_lwe_u64_process(Process *p) {
 }
 
 void memref_bootstrap_lwe_u64_process(Process *p) {
+  assert(p->sk_index.val == 0 && "multiple bsk is not yet implemented on GPU");
   assert(p->output_size.val == p->glwe_dim.val * p->poly_size.val + 1);
   void *fbsk_gpu = p->ctx.val->get_bsk_gpu(
       p->input_lwe_dim.val, p->poly_size.val, p->level.val, p->glwe_dim.val,
@@ -409,18 +416,23 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
   Dependence *idep0 = p->input_streams[0]->get(p->dfg->gpu_idx);
   void *ct0_gpu = idep0->device_data;
-  uint64_t glwe_ct_len = p->poly_size.val * (p->glwe_dim.val + 1);
-  uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
-  uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
   Dependence *idep1 = p->input_streams[1]->get(host_location);
   MemRef2 &mtlu = idep1->host_data;
+  uint32_t num_lut_vectors = mtlu.sizes[0];
+  uint64_t glwe_ct_len =
+      p->poly_size.val * (p->glwe_dim.val + 1) * num_lut_vectors;
+  uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
+  uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
   auto tlu = mtlu.aligned + mtlu.offset;
   // Glwe trivial encryption
-  for (size_t i = 0; i < p->poly_size.val * p->glwe_dim.val; i++) {
-    glwe_ct[i] = 0;
-  }
-  for (size_t i = 0; i < p->poly_size.val; i++) {
-    glwe_ct[p->poly_size.val * p->glwe_dim.val + i] = tlu[i];
+  size_t pos = 0, postlu = 0;
+  for (size_t l = 0; l < num_lut_vectors; ++l) {
+    for (size_t i = 0; i < p->poly_size.val * p->glwe_dim.val; i++) {
+      glwe_ct[pos++] = 0;
+    }
+    for (size_t i = 0; i < p->poly_size.val; i++) {
+      glwe_ct[pos++] = tlu[postlu++];
+    }
   }
   void *glwe_ct_gpu = cuda_malloc_async(
       glwe_ct_size, (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx);
@@ -434,15 +446,21 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
   void *out_gpu = cuda_malloc_async(
       data_size, (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx);
   cudaMemsetAsync(out_gpu, 0, data_size, *(cudaStream_t *)p->dfg->gpu_stream);
+
   // Move test vector indexes to the GPU; with a single shared LUT all indexes are 0
-  uint32_t num_test_vectors = 1, lwe_idx = 0,
-           test_vector_idxes_size = num_samples * sizeof(uint64_t);
-  void *test_vector_idxes = malloc(test_vector_idxes_size);
-  memset(test_vector_idxes, 0, test_vector_idxes_size);
+  uint32_t lwe_idx = 0, test_vector_idxes_size = num_samples * sizeof(uint64_t);
+  uint64_t *test_vector_idxes = (uint64_t *)malloc(test_vector_idxes_size);
+  if (num_lut_vectors == 1) {
+    memset((void *)test_vector_idxes, 0, test_vector_idxes_size);
+  } else {
+    assert(num_lut_vectors == num_samples);
+    for (size_t i = 0; i < num_lut_vectors; ++i)
+      test_vector_idxes[i] = i;
+  }
   void *test_vector_idxes_gpu = cuda_malloc_async(
       test_vector_idxes_size, (cudaStream_t *)p->dfg->gpu_stream,
       p->dfg->gpu_idx);
-  cuda_memcpy_async_to_gpu(test_vector_idxes_gpu, test_vector_idxes,
+  cuda_memcpy_async_to_gpu(test_vector_idxes_gpu, (void *)test_vector_idxes,
                            test_vector_idxes_size,
                            (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx);
   // Schedule the bootstrap kernel on the GPU
@@ -452,7 +470,7 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
       (cudaStream_t *)p->dfg->gpu_stream, p->dfg->gpu_idx, out_gpu, glwe_ct_gpu,
       test_vector_idxes_gpu, ct0_gpu, fbsk_gpu, (int8_t *)pbs_buffer,
       p->input_lwe_dim.val, p->glwe_dim.val, p->poly_size.val, p->base_log.val,
-      p->level.val, num_samples, num_test_vectors, lwe_idx,
+      p->level.val, num_samples, num_lut_vectors, lwe_idx,
       cuda_get_max_shared_memory(p->dfg->gpu_idx));
   cuda_drop_async(test_vector_idxes_gpu, (cudaStream_t *)p->dfg->gpu_stream,
                   p->dfg->gpu_idx);
@@ -573,14 +591,15 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
 
 void stream_emulator_make_memref_keyswitch_lwe_u64_process(
     void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
-    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
-    void *context) {
+    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
+    uint32_t output_size, void *context) {
   Process *p =
       make_process_1_1(dfg, sin1, sout, memref_keyswitch_lwe_u64_process);
   p->level.val = level;
   p->base_log.val = base_log;
   p->input_lwe_dim.val = input_lwe_dim;
   p->output_lwe_dim.val = output_lwe_dim;
+  p->sk_index.val = ksk_index;
   p->output_size.val = output_size;
   p->ctx.val = (RuntimeContext *)context;
   static int count = 0;
@@ -590,7 +609,7 @@
 void stream_emulator_make_memref_bootstrap_lwe_u64_process(
     void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
     uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
-    uint32_t output_size, void *context) {
+    uint32_t bsk_index, uint32_t output_size, void *context) {
   // The TLU does not need to be sent to GPU
   ((Stream *)sin2)->type = TS_STREAM_TYPE_X86_TO_X86_LSAP;
   Process *p =
@@ -600,6 +619,7 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
   p->level.val = level;
   p->base_log.val = base_log;
   p->glwe_dim.val = glwe_dim;
+  p->sk_index.val = bsk_index;
   p->output_size.val = output_size;
   p->ctx.val = (RuntimeContext *)context;
   static int count = 0;
@@ -642,20 +662,29 @@ void stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process(
 
 void stream_emulator_make_memref_batched_keyswitch_lwe_u64_process(
     void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
-    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
-    void *context) {
+    uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t ksk_index,
+    uint32_t output_size, void *context) {
   stream_emulator_make_memref_keyswitch_lwe_u64_process(
       dfg, sin1, sout, level, base_log, input_lwe_dim, output_lwe_dim,
-      output_size, context);
+      ksk_index, output_size, context);
 }
 
 void stream_emulator_make_memref_batched_bootstrap_lwe_u64_process(
     void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
     uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
-    uint32_t output_size, void *context) {
+    uint32_t bsk_index, uint32_t output_size, void *context) {
   stream_emulator_make_memref_bootstrap_lwe_u64_process(
       dfg, sin1, sin2, sout, input_lwe_dim, poly_size, level, base_log,
-      glwe_dim, output_size, context);
+      glwe_dim, bsk_index, output_size, context);
+}
+
+void stream_emulator_make_memref_batched_mapped_bootstrap_lwe_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
+    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
+    uint32_t bsk_index, uint32_t output_size, void *context) {
+  stream_emulator_make_memref_bootstrap_lwe_u64_process(
+      dfg, sin1, sin2, sout, input_lwe_dim, poly_size, level, base_log,
+      glwe_dim, bsk_index, output_size, context);
 }
 
 void *stream_emulator_make_uint64_stream(const char *name, stream_type stype) {
diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/wrappers.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/wrappers.cpp
index 6058d7dc0..559f25043 100644
--- a/compilers/concrete-compiler/compiler/lib/Runtime/wrappers.cpp
+++ b/compilers/concrete-compiler/compiler/lib/Runtime/wrappers.cpp
@@ -247,7 +247,8 @@ void memref_batched_mapped_bootstrap_lwe_cuda_u64(
   assert(bsk_index == 0 && "multiple bsk is not yet implemented on GPU");
   assert(out_size0 == ct0_size0);
   assert(out_size1 == glwe_dim * poly_size + 1);
-  assert((out_size0 == tlu_size0 || tlu_size0 == 1) && "Number of LUTs does not match batch size");
+  assert((out_size0 == tlu_size0 || tlu_size0 == 1) &&
+         "Number of LUTs does not match batch size");
   // TODO: Multi GPU
   uint32_t gpu_idx = 0;
   uint32_t num_samples = out_size0;
@@ -291,8 +292,7 @@ void memref_batched_mapped_bootstrap_lwe_cuda_u64(
       glwe_ct, 0, glwe_ct_size, gpu_idx, (cudaStream_t *)stream);
 
   // Move test vector indexes to the GPU; with a single shared LUT all indexes are 0
-  uint32_t lwe_idx = 0,
-           test_vector_idxes_size = num_samples * sizeof(uint64_t);
+  uint32_t lwe_idx = 0, test_vector_idxes_size = num_samples * sizeof(uint64_t);
   uint64_t *test_vector_idxes = (uint64_t *)malloc(test_vector_idxes_size);
   if (num_lut_vectors == 1) {
     memset((void *)test_vector_idxes, 0, test_vector_idxes_size);
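Both the GPU DFG process and the CUDA wrapper select LUTs the same way: with a
single LUT every test vector index is 0 (the pre-existing batched behaviour),
and with one LUT per sample, index i selects LUT i. In isolation, as a sketch
(num_samples and num_lut_vectors are assumed from the surrounding code):

    // Per-sample LUT selection as implemented in this patch.
    std::vector<uint64_t> test_vector_idxes(num_samples, 0);
    if (num_lut_vectors != 1) {
      assert(num_lut_vectors == num_samples &&
             "Number of LUTs does not match batch size");
      for (uint64_t i = 0; i < num_samples; ++i)
        test_vector_idxes[i] = i;
    }
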
diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp
index 542ec0abd..a8ac8eaaa 100644
--- a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp
+++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/SDFG/SDFG_unit_tests.cpp
@@ -262,3 +262,33 @@ TEST(SDFG_unit_tests, batched_tree) {
   ASSERT_TRUE(res);
   ASSERT_EQ_OUTCOME(res, expected);
 }
+
+TEST(SDFG_unit_tests, batched_tree_mapped_tlu) {
+  std::string source = R"(
+func.func @main(%t: tensor<3x3x!FHE.eint<3>>, %a1: tensor<3x3xi4>, %a2: tensor<3x3xi4>) -> tensor<3x3x!FHE.eint<4>> {
+  %lut_vec = arith.constant dense<[[1,3,5,7,9,11,13,15],
+                                   [2,4,6,8,10,12,14,0],
+                                   [3,6,9,12,15,2,5,8],
+                                   [4,8,12,0,4,8,12,0]]> : tensor<4x8xi64>
+  %map = arith.constant dense<[[0, 1, 2], [3, 2, 1], [1, 2, 3]]> : tensor<3x3xindex>
+  %b1 = "FHELinalg.add_eint_int"(%t, %a1) : (tensor<3x3x!FHE.eint<3>>, tensor<3x3xi4>) -> tensor<3x3x!FHE.eint<3>>
+  %b2 = "FHELinalg.add_eint_int"(%t, %a2) : (tensor<3x3x!FHE.eint<3>>, tensor<3x3xi4>) -> tensor<3x3x!FHE.eint<3>>
+  %c = "FHELinalg.add_eint"(%b1, %b2) : (tensor<3x3x!FHE.eint<3>>, tensor<3x3x!FHE.eint<3>>) -> tensor<3x3x!FHE.eint<3>>
+  %res = "FHELinalg.apply_mapped_lookup_table"(%c, %lut_vec, %map) : (tensor<3x3x!FHE.eint<3>>, tensor<4x8xi64>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<4>>
+  return %res : tensor<3x3x!FHE.eint<4>>
+}
+)";
+  using tensor2_in = std::array<std::array<uint8_t, 3>, 3>;
+  using tensor2_out = std::array<std::array<uint64_t, 3>, 3>;
+  std::string outputLib = outputLibFromThis(this->test_info_);
+  auto compiled = compile(outputLib, source);
+  auto lambda =
+      load<tensor2_out, tensor2_in, tensor2_in, tensor2_in>(outputLib);
+  tensor2_in t = {{{0, 1, 2}, {3, 0, 1}, {2, 3, 0}}};
+  tensor2_in a1 = {{{0, 1, 0}, {0, 1, 0}, {0, 1, 0}}};
+  tensor2_in a2 = {{{1, 0, 1}, {1, 0, 1}, {1, 0, 1}}};
+  tensor2_out expected = {{{3, 8, 2}, {0, 6, 8}, {12, 8, 8}}};
+  auto res = lambda.call(t, a1, a2);
+  ASSERT_TRUE(res);
+  ASSERT_EQ_OUTCOME(res, expected);
+}
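The expected values follow from c = (t + a1) + (t + a2) = 2t + a1 + a2 (mod 8
for eint<3>), with map(i,j) selecting the LUT row. For instance, at row 0,
column 1: c = 2*1 + 1 + 0 = 3, map = 1 selects [2,4,6,8,10,12,14,0], and entry
3 is 8. A self-contained reference computation (a standalone check, not part of
the patch):

    #include <array>
    #include <cassert>
    #include <cstdint>

    int main() {
      std::array<std::array<uint64_t, 8>, 4> lut = {
          {{1, 3, 5, 7, 9, 11, 13, 15},
           {2, 4, 6, 8, 10, 12, 14, 0},
           {3, 6, 9, 12, 15, 2, 5, 8},
           {4, 8, 12, 0, 4, 8, 12, 0}}};
      std::array<std::array<uint64_t, 3>, 3>
          t = {{{0, 1, 2}, {3, 0, 1}, {2, 3, 0}}},
          a1 = {{{0, 1, 0}, {0, 1, 0}, {0, 1, 0}}},
          a2 = {{{1, 0, 1}, {1, 0, 1}, {1, 0, 1}}},
          map = {{{0, 1, 2}, {3, 2, 1}, {1, 2, 3}}},
          expected = {{{3, 8, 2}, {0, 6, 8}, {12, 8, 8}}};
      for (int i = 0; i < 3; ++i)
        for (int j = 0; j < 3; ++j) {
          // Mapped lookup: LUT row map[i][j], index c = 2t + a1 + a2 mod 8.
          uint64_t c = (2 * t[i][j] + a1[i][j] + a2[i][j]) % 8;
          assert(lut[map[i][j]][c] == expected[i][j]);
        }
      return 0;
    }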