fix(SDFG): add output size attribute to KS/BS.

This commit is contained in:
Antoniu Pop
2023-01-18 19:16:51 +00:00
committed by Quentin Bourgerie
parent c8c969773e
commit e42d7bbe64
4 changed files with 33 additions and 11 deletions

View File

@@ -35,11 +35,12 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
void *sout);
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim, void *context);
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
void *context);
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t precision, void *context);
uint32_t precision, uint32_t output_size, void *context);
void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
void stream_emulator_put_uint64(void *stream, uint64_t e);

View File

@@ -196,6 +196,10 @@ struct LowerSDFGMakeProcess
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("lwe_dim_out")));
// output_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("output_size")));
// context
operands.push_back(getContextArgument(mpOp));
break;
@@ -222,6 +226,10 @@ struct LowerSDFGMakeProcess
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("outPrecision")));
// output_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("output_size")));
// context
operands.push_back(getContextArgument(mpOp));
break;

View File

@@ -37,8 +37,13 @@ struct ReplaceWithProcessSDFGConversionInterface
*symbolizeProcessKind(processName), dfg, streams);
if (copyAttributes) {
auto outType =
op->getResult(0).getType().dyn_cast_or_null<mlir::TensorType>();
auto outSize = outType.getDimSize(outType.getRank() - 1);
auto attrList = mlir::NamedAttrList(op->getAttrs());
attrList.append("output_size", builder.getI32IntegerAttr(outSize));
llvm::SmallVector<mlir::NamedAttribute> combinedAttrs =
llvm::to_vector(op->getAttrs());
llvm::to_vector(attrList);
for (mlir::NamedAttribute attr : process->getAttrs()) {
combinedAttrs.push_back(attr);

View File

@@ -69,6 +69,7 @@ struct Process {
Param poly_size;
Param glwe_dim;
Param precision;
Param output_size;
Context ctx;
void (*fun)(Process *);
};
@@ -91,10 +92,12 @@ struct DFGraph {
void memref_keyswitch_lwe_u64_process(Process *p) {
while (!p->terminate_p) {
MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get();
MemRefDescriptor<1> out = ct0;
out.allocated = out.aligned =
(uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t));
MemRefDescriptor<1> out;
out.sizes[0] = p->output_size.val;
out.strides[0] = 1;
out.offset = 0;
out.allocated = out.aligned =
(uint64_t *)malloc(out.sizes[0] * sizeof(uint64_t));
memref_keyswitch_lwe_u64(
out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
@@ -109,10 +112,12 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
while (!p->terminate_p) {
MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get();
MemRefDescriptor<1> tlu = (p->input_streams[1]).memref_stream->get();
MemRefDescriptor<1> out = ct0;
out.allocated = out.aligned =
(uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t));
MemRefDescriptor<1> out;
out.sizes[0] = p->output_size.val;
out.strides[0] = 1;
out.offset = 0;
out.allocated = out.aligned =
(uint64_t *)malloc(out.sizes[0] * sizeof(uint64_t));
memref_bootstrap_lwe_u64(
out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
@@ -272,7 +277,8 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim, void *context) {
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
void *context) {
mlir::concretelang::stream_emulator::Process *p =
new mlir::concretelang::stream_emulator::Process;
p->input_streams.push_back(
@@ -285,6 +291,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
p->base_log.val = base_log;
p->input_lwe_dim.val = input_lwe_dim;
p->output_lwe_dim.val = output_lwe_dim;
p->output_size.val = output_size;
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
p->fun =
mlir::concretelang::stream_emulator::memref_keyswitch_lwe_u64_process;
@@ -295,7 +302,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t precision, void *context) {
uint32_t precision, uint32_t output_size, void *context) {
mlir::concretelang::stream_emulator::Process *p =
new mlir::concretelang::stream_emulator::Process;
p->input_streams.push_back(
@@ -313,6 +320,7 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
p->base_log.val = base_log;
p->glwe_dim.val = glwe_dim;
p->precision.val = precision;
p->output_size.val = output_size;
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
p->fun =
mlir::concretelang::stream_emulator::memref_bootstrap_lwe_u64_process;