From 752f0feb75c62cd030ad99ee057c79c136e5d976 Mon Sep 17 00:00:00 2001
From: Antoniu Pop
Date: Tue, 1 Nov 2022 17:32:11 +0000
Subject: [PATCH] feat(runtime): add a stream emulator.

---
Review notes: this patch was rebuilt from a whitespace-collapsed copy. The
author address, the original #include targets and the blob ids on the "index"
lines could not be recovered or recomputed; restore them from the upstream
commit before feeding this file to `git am`.

 .../Runtime/stream_emulator_api.h             |  77 ++++
 compiler/lib/Runtime/CMakeLists.txt           |   2 +-
 compiler/lib/Runtime/StreamEmulator.cpp       | 444 ++++++++++++++++++
 3 files changed, 522 insertions(+), 1 deletion(-)
 create mode 100644 compiler/include/concretelang/Runtime/stream_emulator_api.h
 create mode 100644 compiler/lib/Runtime/StreamEmulator.cpp

diff --git a/compiler/include/concretelang/Runtime/stream_emulator_api.h b/compiler/include/concretelang/Runtime/stream_emulator_api.h
new file mode 100644
index 000000000..da5d9cf48
--- /dev/null
+++ b/compiler/include/concretelang/Runtime/stream_emulator_api.h
@@ -0,0 +1,77 @@
+// Part of the Concrete Compiler Project, under the BSD3 License with Zama
+// Exceptions. See
+// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
+// for license information.
+
+/// Define the API exposed to the compiler for streaming code generation.
+
+#ifndef CONCRETELANG_STREAM_EMULATOR_API_H
+#define CONCRETELANG_STREAM_EMULATOR_API_H
+// NOTE(review): the three include targets were lost from the reviewed copy of
+// this patch; <cstdint> is required by the declarations below, the other two
+// are a best guess -- confirm against the upstream commit.
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+
+/// Kind of endpoints a stream connects. The emulator currently ignores this
+/// value; the names suggest host (x86) and topology endpoints.
+typedef enum stream_type {
+  TS_STREAM_TYPE_X86_TO_TOPO_LSAP,
+  TS_STREAM_TYPE_TOPO_TO_TOPO_LSAP,
+  TS_STREAM_TYPE_TOPO_TO_X86_LSAP
+} stream_type;
+
+extern "C" {
+/// Creates an empty dataflow graph and returns an opaque handle to it.
+void *stream_emulator_init();
+/// Starts one worker thread per process registered in the graph.
+void stream_emulator_run(void *dfg);
+/// Stops and joins the workers, then releases the graph and its processes.
+void stream_emulator_delete(void *dfg);
+
+/// Process constructors: each registers in `dfg` a process reading from the
+/// `sin*` streams and writing its result to the `sout` stream. Ciphertext
+/// operands travel on memref streams; the plaintext/cleartext operand
+/// (`sin2` of the add_plaintext and mul_cleartext processes) travels on a
+/// uint64 stream.
+void stream_emulator_make_memref_add_lwe_ciphertexts_u64_process(void *dfg,
+                                                                 void *sin1,
+                                                                 void *sin2,
+                                                                 void *sout);
+void stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout);
+void stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout);
+void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
+                                                                   void *sin1,
+                                                                   void *sout);
+void stream_emulator_make_memref_keyswitch_lwe_u64_process(
+    void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
+    uint32_t input_lwe_dim, uint32_t output_lwe_dim, void *context);
+void stream_emulator_make_memref_bootstrap_lwe_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
+    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
+    uint32_t precision, void *context);
+
+/// Streams of 64-bit scalars. `name` and `stype` are currently unused.
+void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
+void stream_emulator_put_uint64(void *stream, uint64_t e);
+/// Blocks until a value is available.
+uint64_t stream_emulator_get_uint64(void *stream);
+
+/// Streams of rank-1 memrefs. `name` and `stype` are currently unused.
+void *stream_emulator_make_memref_stream(const char *name, stream_type stype);
+/// Enqueues a copy of the `size` elements of the memref; the caller keeps
+/// ownership of its buffer.
+void stream_emulator_put_memref(void *stream, uint64_t *allocated,
+                                uint64_t *aligned, uint64_t offset,
+                                uint64_t size, uint64_t stride);
+/// Blocks until a memref is available and copies it into the destination
+/// memref described by the `out_*` arguments.
+void stream_emulator_get_memref(void *stream, uint64_t *out_allocated,
+                                uint64_t *out_aligned, uint64_t out_offset,
+                                uint64_t out_size, uint64_t out_stride);
+}
+
+#endif
diff --git a/compiler/lib/Runtime/CMakeLists.txt b/compiler/lib/Runtime/CMakeLists.txt
index 0c14673b1..093a1cc08 100644
--- a/compiler/lib/Runtime/CMakeLists.txt
+++ b/compiler/lib/Runtime/CMakeLists.txt
@@ -1,4 +1,4 @@
-add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp seeder.cpp)
+add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp seeder.cpp)
 
 if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
   target_link_libraries(ConcretelangRuntime PRIVATE HPX::hpx HPX::iostreams_component)
diff --git a/compiler/lib/Runtime/StreamEmulator.cpp b/compiler/lib/Runtime/StreamEmulator.cpp
new file mode 100644
index 000000000..17c612752
--- /dev/null
+++ b/compiler/lib/Runtime/StreamEmulator.cpp
@@ -0,0 +1,444 @@
+// Part of the Concrete Compiler Project, under the BSD3 License with Zama
+// Exceptions. See
+// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
+// for license information.
+
+// NOTE(review): the include targets were lost from the reviewed copy of this
+// patch. The list below is what the code in this file needs; the project
+// header paths are inferred from the names used -- confirm against upstream.
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include <concretelang/Runtime/stream_emulator_api.h>
+
+#include <concretelang/ClientLib/Types.h>
+#include <concretelang/Runtime/context.h>
+#include <concretelang/Runtime/wrappers.h>
+
+using concretelang::clientlib::MemRefDescriptor;
+
+namespace mlir {
+namespace concretelang {
+namespace stream_emulator {
+namespace {
+
+/// Unbounded FIFO channel between a producer thread and a consumer thread.
+/// Every access to the queue is serialized by `lock`.
+template <typename T> struct StreamBase {
+  /// Appends `e` and wakes up a consumer waiting in get().
+  void put(T e) {
+    {
+      std::lock_guard<std::mutex> guard(lock);
+      q.push(std::move(e));
+    }
+    available.notify_one();
+  }
+
+  /// Blocks until an element is available, then removes and returns it.
+  T get() {
+    std::unique_lock<std::mutex> guard(lock);
+    available.wait(guard, [this] { return !q.empty(); });
+    return pop_front();
+  }
+
+  /// Interruptible variant of get(): stores the next element in `out` and
+  /// returns true, or returns false as soon as `stop` is observed set. `stop`
+  /// is polled because whoever sets it does not know which stream the
+  /// consumer is waiting on.
+  bool get(T &out, const std::atomic<bool> &stop) {
+    std::unique_lock<std::mutex> guard(lock);
+    while (!stop.load()) {
+      if (!q.empty()) {
+        out = pop_front();
+        return true;
+      }
+      available.wait_for(guard, std::chrono::milliseconds(1));
+    }
+    return false;
+  }
+
+  bool empty() {
+    std::lock_guard<std::mutex> guard(lock);
+    return q.empty();
+  }
+
+private:
+  // Requires `lock` to be held and the queue to be non-empty.
+  T pop_front() {
+    T ret = std::move(q.front());
+    q.pop();
+    return ret;
+  }
+
+  std::mutex lock;
+  std::condition_variable available;
+  std::queue<T> q;
+};
+using Uint64Stream = StreamBase<uint64_t>;
+using MemrefStream = StreamBase<MemRefDescriptor<1>>;
+
+/// Handle on one of the two stream kinds; exactly one pointer is non-null.
+struct Stream {
+  Uint64Stream *uint64_stream = nullptr;
+  MemrefStream *memref_stream = nullptr;
+
+  Stream(Uint64Stream *s) : uint64_stream(s) {}
+  Stream(MemrefStream *s) : memref_stream(s) {}
+};
+
+/// A node of the dataflow graph: the streams it is connected to, the
+/// parameters of its operation (only those its `fun` uses are set) and the
+/// function executed by its worker thread.
+struct Process {
+  void terminate() { terminate_p.store(true); }
+  // Written by the thread deleting the graph, read by the worker thread.
+  std::atomic<bool> terminate_p{false};
+  std::vector<Stream> input_streams;
+  std::vector<Stream> output_streams;
+  uint32_t level = 0;
+  uint32_t base_log = 0;
+  uint32_t input_lwe_dim = 0;
+  uint32_t output_lwe_dim = 0;
+  uint32_t poly_size = 0;
+  uint32_t glwe_dim = 0;
+  uint32_t precision = 0;
+  mlir::concretelang::RuntimeContext *ctx = nullptr;
+  void (*fun)(Process *) = nullptr;
+};
+
+/// Owns the processes of a graph and the threads executing them.
+struct DFGraph {
+  DFGraph() = default;
+  DFGraph(const DFGraph &) = delete;
+  DFGraph &operator=(const DFGraph &) = delete;
+
+  /// Asks every worker to stop, waits for them, then frees the processes.
+  /// Joining first guarantees no worker still uses a process being freed.
+  ~DFGraph() {
+    for (Process *p : dfg_processes)
+      p->terminate();
+    for (std::thread &worker : workers)
+      worker.join();
+    for (Process *p : dfg_processes)
+      delete p;
+  }
+
+  /// Starts a worker for every process that is not running yet, so calling
+  /// run() again never executes a process twice.
+  void run() {
+    for (; started < dfg_processes.size(); ++started) {
+      Process *p = dfg_processes[started];
+      workers.emplace_back(p->fun, p);
+    }
+  }
+
+  std::vector<Process *> dfg_processes;
+
+private:
+  std::vector<std::thread> workers;
+  size_t started = 0;
+};
+
+/// Returns a contiguous memref (offset 0, stride 1) over a freshly allocated
+/// buffer of `size` words. Aborts when out of memory, since neither the
+/// worker threads nor the C interface have a way to report the failure.
+MemRefDescriptor<1> make_buffer(uint64_t size) {
+  uint64_t *buf = static_cast<uint64_t *>(malloc(size * sizeof(uint64_t)));
+  if (buf == nullptr && size != 0) {
+    fprintf(stderr, "stream emulator: out of memory\n");
+    abort();
+  }
+  return {buf, buf, 0, {size}, {1}};
+}
+
+/// Waits for the next memref on input `idx`; false when asked to terminate.
+bool next_memref(Process *p, size_t idx, MemRefDescriptor<1> &out) {
+  return p->input_streams[idx].memref_stream->get(out, p->terminate_p);
+}
+
+/// Waits for the next scalar on input `idx`; false when asked to terminate.
+bool next_uint64(Process *p, size_t idx, uint64_t &out) {
+  return p->input_streams[idx].uint64_stream->get(out, p->terminate_p);
+}
+
+/// Allocates a process executing `fun` and registers it with the graph
+/// behind the opaque handle `dfg`, which takes ownership of it.
+Process *add_process(void *dfg, void (*fun)(Process *)) {
+  Process *p = new Process;
+  p->fun = fun;
+  static_cast<DFGraph *>(dfg)->dfg_processes.push_back(p);
+  return p;
+}
+
+// Stream emulator processes
+//
+// Every memref travelling on a stream owns its buffer (see
+// stream_emulator_put_memref), so a process frees its inputs once they are
+// consumed and hands a freshly allocated result over to its consumer.
+
+void memref_keyswitch_lwe_u64_process(Process *p) {
+  MemRefDescriptor<1> ct0;
+  while (next_memref(p, 0, ct0)) {
+    // The result has the output dimension, not the input one: an LWE
+    // ciphertext of dimension n is made of n mask words and one body word.
+    MemRefDescriptor<1> out = make_buffer(uint64_t{p->output_lwe_dim} + 1);
+    memref_keyswitch_lwe_u64(
+        out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
+        ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
+        p->level, p->base_log, p->input_lwe_dim, p->output_lwe_dim, p->ctx);
+    free(ct0.allocated);
+    p->output_streams[0].memref_stream->put(out);
+  }
+}
+
+void memref_bootstrap_lwe_u64_process(Process *p) {
+  MemRefDescriptor<1> ct0;
+  MemRefDescriptor<1> tlu;
+  while (next_memref(p, 0, ct0)) {
+    if (!next_memref(p, 1, tlu)) {
+      free(ct0.allocated);
+      break;
+    }
+    // The result is an LWE ciphertext of dimension glwe_dim * poly_size,
+    // which is not the size of the input ciphertext.
+    MemRefDescriptor<1> out =
+        make_buffer(uint64_t{p->glwe_dim} * p->poly_size + 1);
+    memref_bootstrap_lwe_u64(
+        out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
+        ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
+        tlu.allocated, tlu.aligned, tlu.offset, tlu.sizes[0], tlu.strides[0],
+        p->input_lwe_dim, p->poly_size, p->level, p->base_log, p->glwe_dim,
+        p->precision, p->ctx);
+    free(ct0.allocated);
+    free(tlu.allocated);
+    p->output_streams[0].memref_stream->put(out);
+  }
+}
+
+void memref_add_lwe_ciphertexts_u64_process(Process *p) {
+  MemRefDescriptor<1> ct0;
+  MemRefDescriptor<1> ct1;
+  while (next_memref(p, 0, ct0)) {
+    if (!next_memref(p, 1, ct1)) {
+      free(ct0.allocated);
+      break;
+    }
+    MemRefDescriptor<1> out = make_buffer(ct0.sizes[0]);
+    memref_add_lwe_ciphertexts_u64(
+        out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
+        ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
+        ct1.allocated, ct1.aligned, ct1.offset, ct1.sizes[0], ct1.strides[0]);
+    free(ct0.allocated);
+    free(ct1.allocated);
+    p->output_streams[0].memref_stream->put(out);
+  }
+}
+
+void memref_add_plaintext_lwe_ciphertext_u64_process(Process *p) {
+  MemRefDescriptor<1> ct0;
+  uint64_t plaintext = 0;
+  while (next_memref(p, 0, ct0)) {
+    if (!next_uint64(p, 1, plaintext)) {
+      free(ct0.allocated);
+      break;
+    }
+    MemRefDescriptor<1> out = make_buffer(ct0.sizes[0]);
+    memref_add_plaintext_lwe_ciphertext_u64(
+        out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
+        ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
+        plaintext);
+    free(ct0.allocated);
+    p->output_streams[0].memref_stream->put(out);
+  }
+}
+
+void memref_mul_cleartext_lwe_ciphertext_u64_process(Process *p) {
+  MemRefDescriptor<1> ct0;
+  uint64_t cleartext = 0;
+  while (next_memref(p, 0, ct0)) {
+    if (!next_uint64(p, 1, cleartext)) {
+      free(ct0.allocated);
+      break;
+    }
+    MemRefDescriptor<1> out = make_buffer(ct0.sizes[0]);
+    memref_mul_cleartext_lwe_ciphertext_u64(
+        out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
+        ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
+        cleartext);
+    free(ct0.allocated);
+    p->output_streams[0].memref_stream->put(out);
+  }
+}
+
+void memref_negate_lwe_ciphertext_u64_process(Process *p) {
+  MemRefDescriptor<1> ct0;
+  while (next_memref(p, 0, ct0)) {
+    MemRefDescriptor<1> out = make_buffer(ct0.sizes[0]);
+    memref_negate_lwe_ciphertext_u64(
+        out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
+        ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0]);
+    free(ct0.allocated);
+    p->output_streams[0].memref_stream->put(out);
+  }
+}
+
+} // namespace
+} // namespace stream_emulator
+} // namespace concretelang
+} // namespace mlir
+
+// Code generation interface
+//
+// Opaque handles are cast back with static_cast, the inverse of the
+// pointer-to-void conversion performed when they were handed out.
+namespace se = mlir::concretelang::stream_emulator;
+
+void stream_emulator_make_memref_add_lwe_ciphertexts_u64_process(void *dfg,
+                                                                 void *sin1,
+                                                                 void *sin2,
+                                                                 void *sout) {
+  se::Process *p =
+      se::add_process(dfg, se::memref_add_lwe_ciphertexts_u64_process);
+  p->input_streams.push_back(static_cast<se::MemrefStream *>(sin1));
+  p->input_streams.push_back(static_cast<se::MemrefStream *>(sin2));
+  p->output_streams.push_back(static_cast<se::MemrefStream *>(sout));
+}
+
+void stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout) {
+  se::Process *p = se::add_process(
+      dfg, se::memref_add_plaintext_lwe_ciphertext_u64_process);
+  p->input_streams.push_back(static_cast<se::MemrefStream *>(sin1));
+  p->input_streams.push_back(static_cast<se::Uint64Stream *>(sin2));
+  p->output_streams.push_back(static_cast<se::MemrefStream *>(sout));
+}
+
+void stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout) {
+  se::Process *p = se::add_process(
+      dfg, se::memref_mul_cleartext_lwe_ciphertext_u64_process);
+  p->input_streams.push_back(static_cast<se::MemrefStream *>(sin1));
+  p->input_streams.push_back(static_cast<se::Uint64Stream *>(sin2));
+  p->output_streams.push_back(static_cast<se::MemrefStream *>(sout));
+}
+
+void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
+                                                                   void *sin1,
+                                                                   void *sout) {
+  se::Process *p =
+      se::add_process(dfg, se::memref_negate_lwe_ciphertext_u64_process);
+  p->input_streams.push_back(static_cast<se::MemrefStream *>(sin1));
+  p->output_streams.push_back(static_cast<se::MemrefStream *>(sout));
+}
+
+void stream_emulator_make_memref_keyswitch_lwe_u64_process(
+    void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
+    uint32_t input_lwe_dim, uint32_t output_lwe_dim, void *context) {
+  se::Process *p = se::add_process(dfg, se::memref_keyswitch_lwe_u64_process);
+  p->input_streams.push_back(static_cast<se::MemrefStream *>(sin1));
+  p->output_streams.push_back(static_cast<se::MemrefStream *>(sout));
+  p->level = level;
+  p->base_log = base_log;
+  p->input_lwe_dim = input_lwe_dim;
+  p->output_lwe_dim = output_lwe_dim;
+  p->ctx = static_cast<mlir::concretelang::RuntimeContext *>(context);
+}
+
+void stream_emulator_make_memref_bootstrap_lwe_u64_process(
+    void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
+    uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
+    uint32_t precision, void *context) {
+  se::Process *p = se::add_process(dfg, se::memref_bootstrap_lwe_u64_process);
+  p->input_streams.push_back(static_cast<se::MemrefStream *>(sin1));
+  p->input_streams.push_back(static_cast<se::MemrefStream *>(sin2));
+  p->output_streams.push_back(static_cast<se::MemrefStream *>(sout));
+  p->input_lwe_dim = input_lwe_dim;
+  p->poly_size = poly_size;
+  p->level = level;
+  p->base_log = base_log;
+  p->glwe_dim = glwe_dim;
+  p->precision = precision;
+  p->ctx = static_cast<mlir::concretelang::RuntimeContext *>(context);
+}
+
+// `name` and `stype` are accepted for API compatibility but not used yet.
+void *stream_emulator_make_uint64_stream(const char *name, stream_type stype) {
+  (void)name;
+  (void)stype;
+  return new se::Uint64Stream;
+}
+void stream_emulator_put_uint64(void *stream, uint64_t e) {
+  static_cast<se::Uint64Stream *>(stream)->put(e);
+}
+uint64_t stream_emulator_get_uint64(void *stream) {
+  return static_cast<se::Uint64Stream *>(stream)->get();
+}
+
+// `name` and `stype` are accepted for API compatibility but not used yet.
+void *stream_emulator_make_memref_stream(const char *name, stream_type stype) {
+  (void)name;
+  (void)stype;
+  return new se::MemrefStream;
+}
+
+// Enqueues a private, contiguous copy of the memref: the caller may reuse or
+// release its buffer as soon as this returns, and whoever dequeues the copy
+// can free() it unconditionally.
+void stream_emulator_put_memref(void *stream, uint64_t *allocated,
+                                uint64_t *aligned, uint64_t offset,
+                                uint64_t size, uint64_t stride) {
+  (void)allocated;
+  MemRefDescriptor<1> copy = se::make_buffer(size);
+  for (uint64_t i = 0; i < size; ++i)
+    copy.aligned[i] = aligned[offset + i * stride];
+  static_cast<se::MemrefStream *>(stream)->put(copy);
+}
+
+// Copies the next memref of the stream into the destination described by the
+// `out_*` arguments, then releases the stream-owned buffer.
+void stream_emulator_get_memref(void *stream, uint64_t *out_allocated,
+                                uint64_t *out_aligned, uint64_t out_offset,
+                                uint64_t out_size, uint64_t out_stride) {
+  MemRefDescriptor<1> mref = static_cast<se::MemrefStream *>(stream)->get();
+  memref_copy_one_rank(mref.allocated, mref.aligned, mref.offset, mref.sizes[0],
+                       mref.strides[0], out_allocated, out_aligned, out_offset,
+                       out_size, out_stride);
+  free(mref.allocated);
+}
+
+void *stream_emulator_init() {
+#ifdef CORNAMI_AVAILABLE
+  // TODO: check/update against new info on Cornami API
+  fhestream *pfhestream = new fhestream();
+  pfhestream->initTopology();
+  return (void *)pfhestream;
+#else
+  return new se::DFGraph;
+#endif
+}
+
+void stream_emulator_run(void *dfg) {
+#ifdef CORNAMI_AVAILABLE
+  ((fhestream *)dfg)->FinalizeAndRun();
+#else
+  static_cast<se::DFGraph *>(dfg)->run();
+#endif
+}
+
+void stream_emulator_delete(void *dfg) {
+#ifdef CORNAMI_AVAILABLE
+  delete ((fhestream *)dfg);
+#else
+  delete static_cast<se::DFGraph *>(dfg);
+#endif
+}