From e42d7bbe64e2d0ce17061612f5a5924cb4679581 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Wed, 18 Jan 2023 19:16:51 +0000 Subject: [PATCH] fix(SDFG): add output size attribute to KS/BS. --- .../Runtime/stream_emulator_api.h | 5 ++-- .../SDFGToStreamEmulator.cpp | 8 +++++++ .../SDFGConvertibleOpInterfaceImpl.cpp | 7 +++++- .../compiler/lib/Runtime/StreamEmulator.cpp | 24 ++++++++++++------- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h index da5d9cf48..58b33f57c 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h @@ -35,11 +35,12 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg, void *sout); void stream_emulator_make_memref_keyswitch_lwe_u64_process( void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log, - uint32_t input_lwe_dim, uint32_t output_lwe_dim, void *context); + uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size, + void *context); void stream_emulator_make_memref_bootstrap_lwe_u64_process( void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim, - uint32_t precision, void *context); + uint32_t precision, uint32_t output_size, void *context); void *stream_emulator_make_uint64_stream(const char *name, stream_type stype); void stream_emulator_put_uint64(void *stream, uint64_t e); diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp index fd3b86d23..f290f24b9 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp @@ -196,6 +196,10 @@ struct LowerSDFGMakeProcess operands.push_back(rewriter.create( mpOp.getLoc(), mpOp->getAttrOfType("lwe_dim_out"))); + // output_size + operands.push_back(rewriter.create( + mpOp.getLoc(), + mpOp->getAttrOfType("output_size"))); // context operands.push_back(getContextArgument(mpOp)); break; @@ -222,6 +226,10 @@ struct LowerSDFGMakeProcess operands.push_back(rewriter.create( mpOp.getLoc(), mpOp->getAttrOfType("outPrecision"))); + // output_size + operands.push_back(rewriter.create( + mpOp.getLoc(), + mpOp->getAttrOfType("output_size"))); // context operands.push_back(getContextArgument(mpOp)); break; diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp index 848f2ed5d..6a7d5e418 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp @@ -37,8 +37,13 @@ struct ReplaceWithProcessSDFGConversionInterface *symbolizeProcessKind(processName), dfg, streams); if (copyAttributes) { + auto outType = + op->getResult(0).getType().dyn_cast_or_null(); + auto outSize = outType.getDimSize(outType.getRank() - 1); + auto attrList = mlir::NamedAttrList(op->getAttrs()); + attrList.append("output_size", builder.getI32IntegerAttr(outSize)); llvm::SmallVector combinedAttrs = - llvm::to_vector(op->getAttrs()); + llvm::to_vector(attrList); for (mlir::NamedAttribute attr : process->getAttrs()) { combinedAttrs.push_back(attr); diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/StreamEmulator.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/StreamEmulator.cpp index 17c612752..6be946a7b 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/StreamEmulator.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/StreamEmulator.cpp @@ -69,6 +69,7 @@ struct Process { Param poly_size; Param glwe_dim; Param precision; + Param output_size; Context ctx; void (*fun)(Process *); }; @@ -91,10 +92,12 @@ struct DFGraph { void memref_keyswitch_lwe_u64_process(Process *p) { while (!p->terminate_p) { MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get(); - MemRefDescriptor<1> out = ct0; - out.allocated = out.aligned = - (uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t)); + MemRefDescriptor<1> out; + out.sizes[0] = p->output_size.val; + out.strides[0] = 1; out.offset = 0; + out.allocated = out.aligned = + (uint64_t *)malloc(out.sizes[0] * sizeof(uint64_t)); memref_keyswitch_lwe_u64( out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0], ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0], @@ -109,10 +112,12 @@ void memref_bootstrap_lwe_u64_process(Process *p) { while (!p->terminate_p) { MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get(); MemRefDescriptor<1> tlu = (p->input_streams[1]).memref_stream->get(); - MemRefDescriptor<1> out = ct0; - out.allocated = out.aligned = - (uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t)); + MemRefDescriptor<1> out; + out.sizes[0] = p->output_size.val; + out.strides[0] = 1; out.offset = 0; + out.allocated = out.aligned = + (uint64_t *)malloc(out.sizes[0] * sizeof(uint64_t)); memref_bootstrap_lwe_u64( out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0], ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0], @@ -272,7 +277,8 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg, void stream_emulator_make_memref_keyswitch_lwe_u64_process( void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log, - uint32_t input_lwe_dim, uint32_t output_lwe_dim, void *context) { + uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size, + void *context) { mlir::concretelang::stream_emulator::Process *p = new mlir::concretelang::stream_emulator::Process; p->input_streams.push_back( @@ -285,6 +291,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process( p->base_log.val = base_log; p->input_lwe_dim.val = input_lwe_dim; p->output_lwe_dim.val = output_lwe_dim; + p->output_size.val = output_size; p->ctx.val = (mlir::concretelang::RuntimeContext *)context; p->fun = mlir::concretelang::stream_emulator::memref_keyswitch_lwe_u64_process; @@ -295,7 +302,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process( void stream_emulator_make_memref_bootstrap_lwe_u64_process( void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim, - uint32_t precision, void *context) { + uint32_t precision, uint32_t output_size, void *context) { mlir::concretelang::stream_emulator::Process *p = new mlir::concretelang::stream_emulator::Process; p->input_streams.push_back( @@ -313,6 +320,7 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process( p->base_log.val = base_log; p->glwe_dim.val = glwe_dim; p->precision.val = precision; + p->output_size.val = output_size; p->ctx.val = (mlir::concretelang::RuntimeContext *)context; p->fun = mlir::concretelang::stream_emulator::memref_bootstrap_lwe_u64_process;