// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include #include #include #include #include #include #include using concretelang::clientlib::MemRefDescriptor; namespace mlir { namespace concretelang { namespace stream_emulator { namespace { template struct StreamBase { void put(T e) { q.push(e); } T get() { while (q.empty()) sched_yield(); T ret = q.front(); q.pop(); return ret; } bool empty() { return q.empty(); } private: std::queue q; }; union Stream { StreamBase *uint64_stream; StreamBase> *memref_stream; Stream(StreamBase *s) : uint64_stream(s) {} Stream(StreamBase> *s) : memref_stream(s) {} }; struct Void {}; union Param { Void _; uint32_t val; }; union Context { Void _; mlir::concretelang::RuntimeContext *val; }; struct Process { void terminate() { terminate_p = true; } bool terminate_p = false; std::vector input_streams; std::vector output_streams; Param level; Param base_log; Param input_lwe_dim; Param output_lwe_dim; Param poly_size; Param glwe_dim; Param precision; Param output_size; Context ctx; void (*fun)(Process *); }; struct DFGraph { ~DFGraph() { for (auto p : dfg_processes) p->terminate(); } void run() { for (auto p : dfg_processes) { std::thread process_thread(p->fun, p); process_thread.detach(); } } std::vector dfg_processes; }; // Stream emulator processes void memref_keyswitch_lwe_u64_process(Process *p) { while (!p->terminate_p) { MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get(); MemRefDescriptor<1> out; out.sizes[0] = p->output_size.val; out.strides[0] = 1; out.offset = 0; out.allocated = out.aligned = (uint64_t *)malloc(out.sizes[0] * sizeof(uint64_t)); memref_keyswitch_lwe_u64( out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0], ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0], p->level.val, p->base_log.val, p->input_lwe_dim.val, p->output_lwe_dim.val, p->ctx.val); (p->output_streams[0]).memref_stream->put(out); } delete p; } void memref_bootstrap_lwe_u64_process(Process *p) { while (!p->terminate_p) { MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get(); MemRefDescriptor<1> tlu = (p->input_streams[1]).memref_stream->get(); MemRefDescriptor<1> out; out.sizes[0] = p->output_size.val; out.strides[0] = 1; out.offset = 0; out.allocated = out.aligned = (uint64_t *)malloc(out.sizes[0] * sizeof(uint64_t)); memref_bootstrap_lwe_u64( out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0], ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0], tlu.allocated, tlu.aligned, tlu.offset, tlu.sizes[0], tlu.strides[0], p->input_lwe_dim.val, p->poly_size.val, p->level.val, p->base_log.val, p->glwe_dim.val, p->precision.val, p->ctx.val); (p->output_streams[0]).memref_stream->put(out); } delete p; } void memref_add_lwe_ciphertexts_u64_process(Process *p) { while (!p->terminate_p) { MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get(); MemRefDescriptor<1> ct1 = (p->input_streams[1]).memref_stream->get(); MemRefDescriptor<1> out = ct0; out.allocated = out.aligned = (uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t)); out.offset = 0; memref_add_lwe_ciphertexts_u64( out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0], ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0], ct1.allocated, ct1.aligned, ct1.offset, ct1.sizes[0], ct1.strides[0]); (p->output_streams[0]).memref_stream->put(out); } delete p; } void memref_add_plaintext_lwe_ciphertext_u64_process(Process *p) { while (!p->terminate_p) { MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get(); uint64_t plaintext = (p->input_streams[1]).uint64_stream->get(); MemRefDescriptor<1> out = ct0; out.allocated = out.aligned = (uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t)); out.offset = 0; memref_add_plaintext_lwe_ciphertext_u64( out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0], ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0], plaintext); (p->output_streams[0]).memref_stream->put(out); } delete p; } void memref_mul_cleartext_lwe_ciphertext_u64_process(Process *p) { while (!p->terminate_p) { MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get(); uint64_t cleartext = (p->input_streams[1]).uint64_stream->get(); MemRefDescriptor<1> out = ct0; out.allocated = out.aligned = (uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t)); out.offset = 0; memref_mul_cleartext_lwe_ciphertext_u64( out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0], ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0], cleartext); (p->output_streams[0]).memref_stream->put(out); } delete p; } void memref_negate_lwe_ciphertext_u64_process(Process *p) { while (!p->terminate_p) { MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get(); MemRefDescriptor<1> out = ct0; out.allocated = out.aligned = (uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t)); out.offset = 0; memref_negate_lwe_ciphertext_u64( out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0], ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0]); (p->output_streams[0]).memref_stream->put(out); } delete p; } } // namespace } // namespace stream_emulator } // namespace concretelang } // namespace mlir // Code generation interface void stream_emulator_make_memref_add_lwe_ciphertexts_u64_process(void *dfg, void *sin1, void *sin2, void *sout) { mlir::concretelang::stream_emulator::Process *p = new mlir::concretelang::stream_emulator::Process; p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sin1); p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sin2); p->output_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sout); p->fun = mlir::concretelang::stream_emulator:: memref_add_lwe_ciphertexts_u64_process; ((mlir::concretelang::stream_emulator::DFGraph *)dfg) ->dfg_processes.push_back(p); } void stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process( void *dfg, void *sin1, void *sin2, void *sout) { mlir::concretelang::stream_emulator::Process *p = new mlir::concretelang::stream_emulator::Process; p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sin1); p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase *)sin2); p->output_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sout); p->fun = mlir::concretelang::stream_emulator:: memref_add_plaintext_lwe_ciphertext_u64_process; ((mlir::concretelang::stream_emulator::DFGraph *)dfg) ->dfg_processes.push_back(p); } void stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process( void *dfg, void *sin1, void *sin2, void *sout) { mlir::concretelang::stream_emulator::Process *p = new mlir::concretelang::stream_emulator::Process; p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sin1); p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase *)sin2); p->output_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sout); p->fun = mlir::concretelang::stream_emulator:: memref_mul_cleartext_lwe_ciphertext_u64_process; ((mlir::concretelang::stream_emulator::DFGraph *)dfg) ->dfg_processes.push_back(p); } void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg, void *sin1, void *sout) { mlir::concretelang::stream_emulator::Process *p = new mlir::concretelang::stream_emulator::Process; p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sin1); p->output_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sout); p->fun = mlir::concretelang::stream_emulator:: memref_negate_lwe_ciphertext_u64_process; ((mlir::concretelang::stream_emulator::DFGraph *)dfg) ->dfg_processes.push_back(p); } void stream_emulator_make_memref_keyswitch_lwe_u64_process( void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log, uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size, void *context) { mlir::concretelang::stream_emulator::Process *p = new mlir::concretelang::stream_emulator::Process; p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sin1); p->output_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sout); p->level.val = level; p->base_log.val = base_log; p->input_lwe_dim.val = input_lwe_dim; p->output_lwe_dim.val = output_lwe_dim; p->output_size.val = output_size; p->ctx.val = (mlir::concretelang::RuntimeContext *)context; p->fun = mlir::concretelang::stream_emulator::memref_keyswitch_lwe_u64_process; ((mlir::concretelang::stream_emulator::DFGraph *)dfg) ->dfg_processes.push_back(p); } void stream_emulator_make_memref_bootstrap_lwe_u64_process( void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision, uint32_t output_size, void *context) { mlir::concretelang::stream_emulator::Process *p = new mlir::concretelang::stream_emulator::Process; p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sin1); p->input_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sin2); p->output_streams.push_back( (mlir::concretelang::stream_emulator::StreamBase> *) sout); p->input_lwe_dim.val = input_lwe_dim; p->poly_size.val = poly_size; p->level.val = level; p->base_log.val = base_log; p->glwe_dim.val = glwe_dim; p->precision.val = precision; p->output_size.val = output_size; p->ctx.val = (mlir::concretelang::RuntimeContext *)context; p->fun = mlir::concretelang::stream_emulator::memref_bootstrap_lwe_u64_process; ((mlir::concretelang::stream_emulator::DFGraph *)dfg) ->dfg_processes.push_back(p); } void *stream_emulator_make_uint64_stream(const char *name, stream_type stype) { return (void *)new mlir::concretelang::stream_emulator::StreamBase; } void stream_emulator_put_uint64(void *stream, uint64_t e) { ((mlir::concretelang::stream_emulator::StreamBase *)stream)->put(e); } uint64_t stream_emulator_get_uint64(void *stream) { return ((mlir::concretelang::stream_emulator::StreamBase *)stream) ->get(); } void *stream_emulator_make_memref_stream(const char *name, stream_type stype) { return (void *)new mlir::concretelang::stream_emulator::StreamBase< MemRefDescriptor<1>>; } void stream_emulator_put_memref(void *stream, uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size, uint64_t stride) { ((mlir::concretelang::stream_emulator::StreamBase> *) stream) ->put({allocated, aligned, offset, {size}, {stride}}); } void stream_emulator_get_memref(void *stream, uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride) { MemRefDescriptor<1> mref = ((mlir::concretelang::stream_emulator::StreamBase> *) stream) ->get(); memref_copy_one_rank(mref.allocated, mref.aligned, mref.offset, mref.sizes[0], mref.strides[0], out_allocated, out_aligned, out_offset, out_size, out_stride); free(mref.allocated); } void *stream_emulator_init() { #ifdef CORNAMI_AVAILABLE // TODO: check/update against new info on Cornami API fhestream *pfhestream = new fhestream(); pfhestream->initTopology(); return (void *)pfhestream; #else return (void *)new mlir::concretelang::stream_emulator::DFGraph; #endif } void stream_emulator_run(void *dfg) { #ifdef CORNAMI_AVAILABLE ((fhestream *)dfg)->FinalizeAndRun(); #else ((mlir::concretelang::stream_emulator::DFGraph *)dfg)->run(); #endif } void stream_emulator_delete(void *dfg) { #ifdef CORNAMI_AVAILABLE delete ((fhestream *)dfg); #else delete ((mlir::concretelang::stream_emulator::DFGraph *)dfg); #endif }