mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
fix(SDFG): add output size attribute to KS/BS.
This commit is contained in:
committed by
Quentin Bourgerie
parent
c8c969773e
commit
e42d7bbe64
@@ -35,11 +35,12 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
|
||||
void *sout);
|
||||
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
|
||||
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim, void *context);
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
|
||||
void *context);
|
||||
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
|
||||
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
|
||||
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t precision, void *context);
|
||||
uint32_t precision, uint32_t output_size, void *context);
|
||||
|
||||
void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
|
||||
void stream_emulator_put_uint64(void *stream, uint64_t e);
|
||||
|
||||
@@ -196,6 +196,10 @@ struct LowerSDFGMakeProcess
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("lwe_dim_out")));
|
||||
// output_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("output_size")));
|
||||
// context
|
||||
operands.push_back(getContextArgument(mpOp));
|
||||
break;
|
||||
@@ -222,6 +226,10 @@ struct LowerSDFGMakeProcess
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("outPrecision")));
|
||||
// output_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("output_size")));
|
||||
// context
|
||||
operands.push_back(getContextArgument(mpOp));
|
||||
break;
|
||||
|
||||
@@ -37,8 +37,13 @@ struct ReplaceWithProcessSDFGConversionInterface
|
||||
*symbolizeProcessKind(processName), dfg, streams);
|
||||
|
||||
if (copyAttributes) {
|
||||
auto outType =
|
||||
op->getResult(0).getType().dyn_cast_or_null<mlir::TensorType>();
|
||||
auto outSize = outType.getDimSize(outType.getRank() - 1);
|
||||
auto attrList = mlir::NamedAttrList(op->getAttrs());
|
||||
attrList.append("output_size", builder.getI32IntegerAttr(outSize));
|
||||
llvm::SmallVector<mlir::NamedAttribute> combinedAttrs =
|
||||
llvm::to_vector(op->getAttrs());
|
||||
llvm::to_vector(attrList);
|
||||
|
||||
for (mlir::NamedAttribute attr : process->getAttrs()) {
|
||||
combinedAttrs.push_back(attr);
|
||||
|
||||
@@ -69,6 +69,7 @@ struct Process {
|
||||
Param poly_size;
|
||||
Param glwe_dim;
|
||||
Param precision;
|
||||
Param output_size;
|
||||
Context ctx;
|
||||
void (*fun)(Process *);
|
||||
};
|
||||
@@ -91,10 +92,12 @@ struct DFGraph {
|
||||
void memref_keyswitch_lwe_u64_process(Process *p) {
|
||||
while (!p->terminate_p) {
|
||||
MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get();
|
||||
MemRefDescriptor<1> out = ct0;
|
||||
out.allocated = out.aligned =
|
||||
(uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t));
|
||||
MemRefDescriptor<1> out;
|
||||
out.sizes[0] = p->output_size.val;
|
||||
out.strides[0] = 1;
|
||||
out.offset = 0;
|
||||
out.allocated = out.aligned =
|
||||
(uint64_t *)malloc(out.sizes[0] * sizeof(uint64_t));
|
||||
memref_keyswitch_lwe_u64(
|
||||
out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
|
||||
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
|
||||
@@ -109,10 +112,12 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
|
||||
while (!p->terminate_p) {
|
||||
MemRefDescriptor<1> ct0 = (p->input_streams[0]).memref_stream->get();
|
||||
MemRefDescriptor<1> tlu = (p->input_streams[1]).memref_stream->get();
|
||||
MemRefDescriptor<1> out = ct0;
|
||||
out.allocated = out.aligned =
|
||||
(uint64_t *)malloc(ct0.sizes[0] * sizeof(uint64_t));
|
||||
MemRefDescriptor<1> out;
|
||||
out.sizes[0] = p->output_size.val;
|
||||
out.strides[0] = 1;
|
||||
out.offset = 0;
|
||||
out.allocated = out.aligned =
|
||||
(uint64_t *)malloc(out.sizes[0] * sizeof(uint64_t));
|
||||
memref_bootstrap_lwe_u64(
|
||||
out.allocated, out.aligned, out.offset, out.sizes[0], out.strides[0],
|
||||
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
|
||||
@@ -272,7 +277,8 @@ void stream_emulator_make_memref_negate_lwe_ciphertext_u64_process(void *dfg,
|
||||
|
||||
void stream_emulator_make_memref_keyswitch_lwe_u64_process(
|
||||
void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim, void *context) {
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size,
|
||||
void *context) {
|
||||
mlir::concretelang::stream_emulator::Process *p =
|
||||
new mlir::concretelang::stream_emulator::Process;
|
||||
p->input_streams.push_back(
|
||||
@@ -285,6 +291,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
|
||||
p->base_log.val = base_log;
|
||||
p->input_lwe_dim.val = input_lwe_dim;
|
||||
p->output_lwe_dim.val = output_lwe_dim;
|
||||
p->output_size.val = output_size;
|
||||
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
|
||||
p->fun =
|
||||
mlir::concretelang::stream_emulator::memref_keyswitch_lwe_u64_process;
|
||||
@@ -295,7 +302,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
|
||||
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
|
||||
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
|
||||
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
|
||||
uint32_t precision, void *context) {
|
||||
uint32_t precision, uint32_t output_size, void *context) {
|
||||
mlir::concretelang::stream_emulator::Process *p =
|
||||
new mlir::concretelang::stream_emulator::Process;
|
||||
p->input_streams.push_back(
|
||||
@@ -313,6 +320,7 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
|
||||
p->base_log.val = base_log;
|
||||
p->glwe_dim.val = glwe_dim;
|
||||
p->precision.val = precision;
|
||||
p->output_size.val = output_size;
|
||||
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
|
||||
p->fun =
|
||||
mlir::concretelang::stream_emulator::memref_bootstrap_lwe_u64_process;
|
||||
|
||||
Reference in New Issue
Block a user