feat(compiler): support woppbs in simulation

This commit is contained in:
youben11
2023-09-25 10:42:01 +01:00
committed by Ayoub Benaissa
parent 0caa659244
commit e4835bd002
4 changed files with 130 additions and 29 deletions

View File

@@ -76,8 +76,7 @@ void sim_wop_pbs_crt(
// Additional crypto parameters
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size);
uint32_t bsk_base_log, uint32_t polynomial_size, uint32_t glwe_dim);
void sim_encode_expand_lut_for_boostrap(
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,

View File

@@ -205,6 +205,16 @@ struct EncodeLutForCrtWopPBSOpPattern
encodeOp.getResult().getType().cast<mlir::RankedTensorType>(),
mlir::ValueRange{});
auto dynamicResultType =
toDynamicTensorType(encodeOp.getResult().getType());
auto dynamicLutType =
toDynamicTensorType(encodeOp.getInputLookupTable().getType());
mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
encodeOp.getLoc(), dynamicResultType, outputBuffer);
mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
encodeOp.getLoc(), dynamicLutType, adaptor.getInputLookupTable());
auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, encodeOp.getLoc(), encodeOp.getCrtDecompositionAttr());
auto crtBitsValue = mlir::concretelang::globalMemrefFromArrayAttr(
@@ -213,10 +223,9 @@ struct EncodeLutForCrtWopPBSOpPattern
if (insertForwardDeclaration(
encodeOp, rewriter, funcName,
rewriter.getFunctionType(
{encodeOp.getResult().getType(),
encodeOp.getInputLookupTable().getType(),
crtDecompValue.getType(), crtBitsValue.getType(),
rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
{dynamicResultType, dynamicLutType, crtDecompValue.getType(),
crtBitsValue.getType(), rewriter.getIntegerType(32),
rewriter.getIntegerType(1)},
{}))
.failed()) {
return mlir::failure();
@@ -224,9 +233,8 @@ struct EncodeLutForCrtWopPBSOpPattern
rewriter.create<mlir::func::CallOp>(
encodeOp.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(),
crtDecompValue, crtBitsValue, modulusProductCst,
isSignedCst}));
mlir::ValueRange({castedOutputBuffer, castedLUT, crtDecompValue,
crtBitsValue, modulusProductCst, isSignedCst}));
rewriter.replaceOp(encodeOp, outputBuffer);
@@ -259,13 +267,18 @@ struct EncodePlaintextWithCrtOpPattern
epOp.getResult().getType().cast<mlir::RankedTensorType>(),
mlir::ValueRange{});
auto dynamicResultType = toDynamicTensorType(epOp.getResult().getType());
mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
epOp.getLoc(), dynamicResultType, outputBuffer);
auto ModsValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, epOp.getLoc(), epOp.getModsAttr());
if (insertForwardDeclaration(
epOp, rewriter, funcName,
rewriter.getFunctionType(
{epOp.getResult().getType(), epOp.getInput().getType(),
{dynamicResultType, epOp.getInput().getType(),
ModsValue.getType(), rewriter.getI64Type()},
{}))
.failed()) {
@@ -274,8 +287,8 @@ struct EncodePlaintextWithCrtOpPattern
rewriter.create<mlir::func::CallOp>(
epOp.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange(
{outputBuffer, adaptor.getInput(), ModsValue, modsProductCst}));
mlir::ValueRange({castedOutputBuffer, adaptor.getInput(), ModsValue,
modsProductCst}));
rewriter.replaceOp(epOp, outputBuffer);
@@ -311,6 +324,22 @@ struct WopPBSGLWEOpPattern
.cast<mlir::RankedTensorType>(),
mlir::ValueRange{});
auto dynamicResultType = toDynamicTensorType(this->getTypeConverter()
->convertType(resultType)
.cast<mlir::TensorType>());
auto dynamicInputType = toDynamicTensorType(this->getTypeConverter()
->convertType(inputType)
.cast<mlir::TensorType>());
auto dynamicLutType =
toDynamicTensorType(wopPbs.getLookupTable().getType());
mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
wopPbs.getLoc(), dynamicResultType, outputBuffer);
mlir::Value castedCiphertexts = rewriter.create<mlir::tensor::CastOp>(
wopPbs.getLoc(), dynamicInputType, adaptor.getCiphertexts());
mlir::Value castedLut = rewriter.create<mlir::tensor::CastOp>(
wopPbs.getLoc(), dynamicLutType, adaptor.getLookupTable());
auto lweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getInputLweDim(), 32);
auto cbsLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
@@ -325,12 +354,10 @@ struct WopPBSGLWEOpPattern
wopPbs.getLoc(), adaptor.getBsk().getLevels(), 32);
auto bskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getBsk().getBaseLog(), 32);
auto fpkskLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getLevels(), 32);
auto fpkskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getBaseLog(), 32);
auto polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getOutputPolySize(), 32);
auto glweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getBsk().getGlweDim(), 32);
auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, wopPbs.getLoc(), wopPbs.getCrtDecompositionAttr());
@@ -338,10 +365,8 @@ struct WopPBSGLWEOpPattern
if (insertForwardDeclaration(
wopPbs, rewriter, funcName,
rewriter.getFunctionType(
{this->getTypeConverter()->convertType(resultType),
this->getTypeConverter()->convertType(inputType),
wopPbs.getLookupTable().getType(), crtDecompValue.getType(),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
{dynamicResultType, dynamicInputType, dynamicLutType,
crtDecompValue.getType(), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
@@ -353,11 +378,11 @@ struct WopPBSGLWEOpPattern
rewriter.create<mlir::func::CallOp>(
wopPbs.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange({outputBuffer, adaptor.getCiphertexts(),
adaptor.getLookupTable(), crtDecompValue, lweDimCst,
cbsLevelCountCst, cbsBaseLogCst, kskLevelCountCst,
kskBaseLogCst, bskLevelCountCst, bskBaseLogCst,
fpkskLevelCountCst, fpkskBaseLogCst, polySizeCst}));
mlir::ValueRange({castedOutputBuffer, castedCiphertexts, castedLut,
crtDecompValue, lweDimCst, cbsLevelCountCst,
cbsBaseLogCst, kskLevelCountCst, kskBaseLogCst,
bskLevelCountCst, bskBaseLogCst, polySizeCst,
glweDimCst}));
rewriter.replaceOp(wopPbs, outputBuffer);
@@ -542,7 +567,8 @@ void SimulateTFHEPass::runOnOperation() {
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::func::CallOp, mlir::memref::GetGlobalOp,
mlir::bufferization::AllocTensorOp, mlir::tensor::CastOp>();
mlir::memref::CastOp, mlir::bufferization::AllocTensorOp,
mlir::tensor::CastOp>();
// Make sure that no ops from `TFHE` remain after the lowering
target.addIllegalDialect<TFHE::TFHEDialect>();

View File

@@ -111,9 +111,69 @@ void sim_wop_pbs_crt(
// Additional crypto parameters
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size) {
// TODO
uint32_t bsk_base_log, uint32_t polynomial_size, uint32_t glwe_dim) {
// Check number of blocks
assert(out_size == in_size && out_size == crt_decomp_size);
uint64_t log_poly_size =
static_cast<uint64_t>(ceil(log2(static_cast<double>(polynomial_size))));
// Compute the numbers of bits to extract for each block and the total one.
uint64_t total_number_of_bits_per_block = 0;
auto number_of_bits_per_block = new uint64_t[crt_decomp_size]();
for (uint64_t i = 0; i < crt_decomp_size; i++) {
uint64_t modulus = crt_decomp_aligned[i + crt_decomp_offset];
uint64_t nb_bit_to_extract =
static_cast<uint64_t>(ceil(log2(static_cast<double>(modulus))));
number_of_bits_per_block[i] = nb_bit_to_extract;
total_number_of_bits_per_block += nb_bit_to_extract;
}
// Create the buffer of ciphertexts for storing the total number of bits to
// extract.
// The extracted bit should be in the following order:
//
// [msb(m%crt[n-1])..lsb(m%crt[n-1])...msb(m%crt[0])..lsb(m%crt[0])] where n
// is the size of the crt decomposition
auto extract_bits_output_buffer =
new uint64_t[total_number_of_bits_per_block]{0};
// Extraction of each bit for each block
for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0;
extract_bits_output_offset += number_of_bits_per_block[i--]) {
auto nb_bits_to_extract = number_of_bits_per_block[i];
size_t delta_log = 64 - nb_bits_to_extract;
auto in_block = in_aligned[in_offset + i];
// trick ( ct - delta/2 + delta/2^4 )
uint64_t sub = (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 1)) -
(uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 5));
in_block -= sub;
simulation_extract_bit_lwe_ciphertext_u64(
&extract_bits_output_buffer[extract_bits_output_offset], in_block,
delta_log, nb_bits_to_extract, log_poly_size, glwe_dim, lwe_small_dim,
ksk_base_log, ksk_level_count, bsk_base_log, bsk_level_count, 64, 128);
}
size_t ct_in_count = total_number_of_bits_per_block;
size_t lut_size = 1 << ct_in_count;
size_t ct_out_count = out_size;
size_t lut_count = ct_out_count;
assert(lut_ct_size0 == lut_count);
assert(lut_ct_size1 == lut_size);
// Vertical packing
simulation_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
extract_bits_output_buffer, out_aligned + out_offset, ct_in_count,
ct_out_count, lut_size, lut_count, lut_ct_aligned + lut_ct_offset,
glwe_dim, log_poly_size, lwe_small_dim, cbs_level_count, cbs_base_log, 64,
128);
}
// Returns the additive inverse of `plaintext` modulo 2^64, i.e. the value `r`
// such that `plaintext + r` wraps around to zero.
uint64_t sim_neg_lwe_u64(uint64_t plaintext) {
  // Unsigned subtraction from zero wraps modulo 2^64, which yields exactly the
  // two's-complement negation (invert every bit, then add one).
  const uint64_t zero = 0;
  return zero - plaintext;
}

View File

@@ -170,6 +170,22 @@ end_to_end_fixture = [
).reshape((4, 4)),
id="matul_chain_with_crt",
),
pytest.param(
"""
func.func @main(%arg0: !FHE.eint<14>, %arg1: tensor<16384xi64>) -> !FHE.eint<14> {
%cst = arith.constant 15 : i15
%v = "FHE.add_eint_int"(%arg0, %cst): (!FHE.eint<14>, i15) -> (!FHE.eint<14>)
%1 = "FHE.apply_lookup_table"(%v, %arg1): (!FHE.eint<14>, tensor<16384xi64>) -> (!FHE.eint<14>)
return %1: !FHE.eint<14>
}
""",
(
81,
np.array(range(16384), dtype=np.uint64),
),
96,
id="add_lut_crt",
),
]
end_to_end_parallel_fixture = [