chore(concrete-compiler): remove unnecessary precision parameter in bootstrap signature

This commit is contained in:
aPere3
2023-02-22 09:28:51 +01:00
committed by Quentin Bourgerie
parent d2bfa03104
commit 5c1a15c514
9 changed files with 33 additions and 44 deletions

View File

@@ -192,8 +192,7 @@ def Concrete_BootstrapLweTensorOp : Concrete_Op<"bootstrap_lwe_tensor", [Pure]>
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension,
I32Attr:$outPrecision
I32Attr:$glweDimension
);
let results = (outs Concrete_LweTensor:$result);
@@ -231,8 +230,7 @@ def Concrete_BootstrapLweBufferOp : Concrete_Op<"bootstrap_lwe_buffer"> {
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension,
I32Attr:$outPrecision
I32Attr:$glweDimension
);
}
@@ -246,8 +244,7 @@ def Concrete_BatchedBootstrapLweTensorOp : Concrete_Op<"batched_bootstrap_lwe_te
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension,
I32Attr:$outPrecision
I32Attr:$glweDimension
);
let results = (outs Concrete_BatchLweTensor:$result);
}
@@ -263,8 +260,7 @@ def Concrete_BatchedBootstrapLweBufferOp : Concrete_Op<"batched_bootstrap_lwe_bu
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension,
I32Attr:$outPrecision
I32Attr:$glweDimension
);
}

View File

@@ -40,7 +40,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t precision, uint32_t output_size, void *context);
uint32_t output_size, void *context);
void *stream_emulator_make_uint64_stream(const char *name, stream_type stype);
void stream_emulator_put_uint64(void *stream, uint64_t e);

View File

@@ -99,15 +99,17 @@ void *memref_keyswitch_async_lwe_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context);
void memref_bootstrap_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
mlir::concretelang::RuntimeContext *context);
void memref_bootstrap_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride,
uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context);
void memref_batched_bootstrap_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
@@ -117,7 +119,7 @@ void memref_batched_bootstrap_lwe_u64(
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context);
void *memref_bootstrap_async_lwe_u64(
@@ -127,7 +129,7 @@ void *memref_bootstrap_async_lwe_u64(
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context);
void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
@@ -195,7 +197,7 @@ void memref_bootstrap_lwe_cuda_u64(
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context);
// Batched CUDA function //////////////////////////////////////////////////////
@@ -217,7 +219,7 @@ void memref_batched_bootstrap_lwe_cuda_u64(
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context);
// Tracing ////////////////////////////////////////////////////////////////////

View File

@@ -115,7 +115,7 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
i32Type, i32Type, contextType},
{});
} else if (funcName == memref_keyswitch_async_lwe_u64) {
funcType = mlir::FunctionType::get(
@@ -125,7 +125,7 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
i32Type, i32Type, contextType},
{futureType});
} else if (funcName == memref_batched_keyswitch_lwe_u64 ||
funcName == memref_batched_keyswitch_lwe_cuda_u64) {
@@ -138,7 +138,7 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
i32Type, i32Type, contextType},
{});
} else if (funcName == memref_await_future) {
funcType = mlir::FunctionType::get(
@@ -296,9 +296,6 @@ void bootstrapAddOperands(BootstrapOp op,
// glwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getGlweDimensionAttr()));
// out_precision
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.getOutPrecisionAttr()));
// context
operands.push_back(getContextArgument(op));
}

View File

@@ -222,10 +222,6 @@ struct LowerSDFGMakeProcess
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("glweDimension")));
// out_precision
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("outPrecision")));
// output_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),

View File

@@ -128,7 +128,7 @@ struct BootstrapGLWEOpPattern
bsOp, this->getTypeConverter()->convertType(resultType),
adaptor.getCiphertext(), adaptor.getLookupTable(),
inputType.getDimension(), adaptor.getPolySize(), adaptor.getLevel(),
adaptor.getBaseLog(), adaptor.getGlweDimension(), resultType.getP());
adaptor.getBaseLog(), adaptor.getGlweDimension());
return mlir::success();
}

View File

@@ -123,7 +123,7 @@ void memref_bootstrap_lwe_u64_process(Process *p) {
ct0.allocated, ct0.aligned, ct0.offset, ct0.sizes[0], ct0.strides[0],
tlu.allocated, tlu.aligned, tlu.offset, tlu.sizes[0], tlu.strides[0],
p->input_lwe_dim.val, p->poly_size.val, p->level.val, p->base_log.val,
p->glwe_dim.val, p->precision.val, p->ctx.val);
p->glwe_dim.val, p->ctx.val);
(p->output_streams[0]).memref_stream->put(out);
}
delete p;
@@ -302,7 +302,7 @@ void stream_emulator_make_memref_keyswitch_lwe_u64_process(
void stream_emulator_make_memref_bootstrap_lwe_u64_process(
void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim,
uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim,
uint32_t precision, uint32_t output_size, void *context) {
uint32_t output_size, void *context) {
mlir::concretelang::stream_emulator::Process *p =
new mlir::concretelang::stream_emulator::Process;
p->input_streams.push_back(
@@ -319,7 +319,6 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process(
p->level.val = level;
p->base_log.val = base_log;
p->glwe_dim.val = glwe_dim;
p->precision.val = precision;
p->output_size.val = output_size;
p->ctx.val = (mlir::concretelang::RuntimeContext *)context;
p->fun =

View File

@@ -88,7 +88,7 @@ void memref_bootstrap_lwe_cuda_u64(
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context) {
memref_batched_bootstrap_lwe_cuda_u64(
// Output 1D memref as 2D memref
@@ -98,7 +98,7 @@ void memref_bootstrap_lwe_cuda_u64(
// Table lookup memref
tlu_allocated, tlu_aligned, tlu_offset, tlu_size, tlu_stride,
// Bootstrap additional arguments
input_lwe_dim, poly_size, level, base_log, glwe_dim, precision, context);
input_lwe_dim, poly_size, level, base_log, glwe_dim, context);
}
// Batched CUDA function //////////////////////////////////////////////////////
@@ -154,7 +154,7 @@ void memref_batched_bootstrap_lwe_cuda_u64(
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context) {
assert(out_size0 == ct0_size0);
assert(out_size1 == glwe_dim * poly_size + 1);
@@ -537,8 +537,7 @@ void memref_bootstrap_lwe_u64(
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dimension, uint32_t polynomial_size,
uint32_t decomposition_level_count, uint32_t decomposition_base_log,
uint32_t glwe_dimension, uint32_t precision,
mlir::concretelang::RuntimeContext *context) {
uint32_t glwe_dimension, mlir::concretelang::RuntimeContext *context) {
uint64_t glwe_ct_size = polynomial_size * (glwe_dimension + 1);
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
@@ -583,7 +582,7 @@ void memref_batched_bootstrap_lwe_u64(
uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size,
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
uint32_t level, uint32_t base_log, uint32_t glwe_dim,
mlir::concretelang::RuntimeContext *context) {
for (size_t i = 0; i < out_size0; i++) {
@@ -592,7 +591,7 @@ void memref_batched_bootstrap_lwe_u64(
out_size1, out_stride1, ct0_allocated, ct0_aligned + i * ct0_size1,
ct0_offset, ct0_size1, ct0_stride1, tlu_allocated, tlu_aligned,
tlu_offset, tlu_size, tlu_stride, input_lwe_dim, poly_size, level,
base_log, glwe_dim, precision, context);
base_log, glwe_dim, context);
}
}

View File

@@ -2,7 +2,7 @@
//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<601xi64>) -> tensor<1025xi64> {
//CHECK: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000
000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<128xi64>) -> tensor<1025xi64>
//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe_tensor"(%arg0, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<128xi64>) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @bootstrap_lwe(%ciphertext: !TFHE.glwe<{600,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{4}> {