mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
feat(compiler): support woppbs in simulation
This commit is contained in:
@@ -76,8 +76,7 @@ void sim_wop_pbs_crt(
|
||||
// Additional crypto parameters
|
||||
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
|
||||
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
|
||||
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
|
||||
uint32_t polynomial_size);
|
||||
uint32_t bsk_base_log, uint32_t polynomial_size, uint32_t glwe_dim);
|
||||
|
||||
void sim_encode_expand_lut_for_boostrap(
|
||||
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,
|
||||
|
||||
@@ -205,6 +205,16 @@ struct EncodeLutForCrtWopPBSOpPattern
|
||||
encodeOp.getResult().getType().cast<mlir::RankedTensorType>(),
|
||||
mlir::ValueRange{});
|
||||
|
||||
auto dynamicResultType =
|
||||
toDynamicTensorType(encodeOp.getResult().getType());
|
||||
auto dynamicLutType =
|
||||
toDynamicTensorType(encodeOp.getInputLookupTable().getType());
|
||||
|
||||
mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
|
||||
encodeOp.getLoc(), dynamicResultType, outputBuffer);
|
||||
mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
|
||||
encodeOp.getLoc(), dynamicLutType, adaptor.getInputLookupTable());
|
||||
|
||||
auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
|
||||
rewriter, encodeOp.getLoc(), encodeOp.getCrtDecompositionAttr());
|
||||
auto crtBitsValue = mlir::concretelang::globalMemrefFromArrayAttr(
|
||||
@@ -213,10 +223,9 @@ struct EncodeLutForCrtWopPBSOpPattern
|
||||
if (insertForwardDeclaration(
|
||||
encodeOp, rewriter, funcName,
|
||||
rewriter.getFunctionType(
|
||||
{encodeOp.getResult().getType(),
|
||||
encodeOp.getInputLookupTable().getType(),
|
||||
crtDecompValue.getType(), crtBitsValue.getType(),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
|
||||
{dynamicResultType, dynamicLutType, crtDecompValue.getType(),
|
||||
crtBitsValue.getType(), rewriter.getIntegerType(32),
|
||||
rewriter.getIntegerType(1)},
|
||||
{}))
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
@@ -224,9 +233,8 @@ struct EncodeLutForCrtWopPBSOpPattern
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(
|
||||
encodeOp.getLoc(), funcName, mlir::TypeRange{},
|
||||
mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(),
|
||||
crtDecompValue, crtBitsValue, modulusProductCst,
|
||||
isSignedCst}));
|
||||
mlir::ValueRange({castedOutputBuffer, castedLUT, crtDecompValue,
|
||||
crtBitsValue, modulusProductCst, isSignedCst}));
|
||||
|
||||
rewriter.replaceOp(encodeOp, outputBuffer);
|
||||
|
||||
@@ -259,13 +267,18 @@ struct EncodePlaintextWithCrtOpPattern
|
||||
epOp.getResult().getType().cast<mlir::RankedTensorType>(),
|
||||
mlir::ValueRange{});
|
||||
|
||||
auto dynamicResultType = toDynamicTensorType(epOp.getResult().getType());
|
||||
|
||||
mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
|
||||
epOp.getLoc(), dynamicResultType, outputBuffer);
|
||||
|
||||
auto ModsValue = mlir::concretelang::globalMemrefFromArrayAttr(
|
||||
rewriter, epOp.getLoc(), epOp.getModsAttr());
|
||||
|
||||
if (insertForwardDeclaration(
|
||||
epOp, rewriter, funcName,
|
||||
rewriter.getFunctionType(
|
||||
{epOp.getResult().getType(), epOp.getInput().getType(),
|
||||
{dynamicResultType, epOp.getInput().getType(),
|
||||
ModsValue.getType(), rewriter.getI64Type()},
|
||||
{}))
|
||||
.failed()) {
|
||||
@@ -274,8 +287,8 @@ struct EncodePlaintextWithCrtOpPattern
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(
|
||||
epOp.getLoc(), funcName, mlir::TypeRange{},
|
||||
mlir::ValueRange(
|
||||
{outputBuffer, adaptor.getInput(), ModsValue, modsProductCst}));
|
||||
mlir::ValueRange({castedOutputBuffer, adaptor.getInput(), ModsValue,
|
||||
modsProductCst}));
|
||||
|
||||
rewriter.replaceOp(epOp, outputBuffer);
|
||||
|
||||
@@ -311,6 +324,22 @@ struct WopPBSGLWEOpPattern
|
||||
.cast<mlir::RankedTensorType>(),
|
||||
mlir::ValueRange{});
|
||||
|
||||
auto dynamicResultType = toDynamicTensorType(this->getTypeConverter()
|
||||
->convertType(resultType)
|
||||
.cast<mlir::TensorType>());
|
||||
auto dynamicInputType = toDynamicTensorType(this->getTypeConverter()
|
||||
->convertType(inputType)
|
||||
.cast<mlir::TensorType>());
|
||||
auto dynamicLutType =
|
||||
toDynamicTensorType(wopPbs.getLookupTable().getType());
|
||||
|
||||
mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
|
||||
wopPbs.getLoc(), dynamicResultType, outputBuffer);
|
||||
mlir::Value castedCiphertexts = rewriter.create<mlir::tensor::CastOp>(
|
||||
wopPbs.getLoc(), dynamicInputType, adaptor.getCiphertexts());
|
||||
mlir::Value castedLut = rewriter.create<mlir::tensor::CastOp>(
|
||||
wopPbs.getLoc(), dynamicLutType, adaptor.getLookupTable());
|
||||
|
||||
auto lweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getPksk().getInputLweDim(), 32);
|
||||
auto cbsLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
@@ -325,12 +354,10 @@ struct WopPBSGLWEOpPattern
|
||||
wopPbs.getLoc(), adaptor.getBsk().getLevels(), 32);
|
||||
auto bskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getBsk().getBaseLog(), 32);
|
||||
auto fpkskLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getPksk().getLevels(), 32);
|
||||
auto fpkskBaseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getPksk().getBaseLog(), 32);
|
||||
auto polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getPksk().getOutputPolySize(), 32);
|
||||
auto glweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
wopPbs.getLoc(), adaptor.getBsk().getGlweDim(), 32);
|
||||
|
||||
auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
|
||||
rewriter, wopPbs.getLoc(), wopPbs.getCrtDecompositionAttr());
|
||||
@@ -338,10 +365,8 @@ struct WopPBSGLWEOpPattern
|
||||
if (insertForwardDeclaration(
|
||||
wopPbs, rewriter, funcName,
|
||||
rewriter.getFunctionType(
|
||||
{this->getTypeConverter()->convertType(resultType),
|
||||
this->getTypeConverter()->convertType(inputType),
|
||||
wopPbs.getLookupTable().getType(), crtDecompValue.getType(),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
|
||||
{dynamicResultType, dynamicInputType, dynamicLutType,
|
||||
crtDecompValue.getType(), rewriter.getIntegerType(32),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
|
||||
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
|
||||
@@ -353,11 +378,11 @@ struct WopPBSGLWEOpPattern
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(
|
||||
wopPbs.getLoc(), funcName, mlir::TypeRange{},
|
||||
mlir::ValueRange({outputBuffer, adaptor.getCiphertexts(),
|
||||
adaptor.getLookupTable(), crtDecompValue, lweDimCst,
|
||||
cbsLevelCountCst, cbsBaseLogCst, kskLevelCountCst,
|
||||
kskBaseLogCst, bskLevelCountCst, bskBaseLogCst,
|
||||
fpkskLevelCountCst, fpkskBaseLogCst, polySizeCst}));
|
||||
mlir::ValueRange({castedOutputBuffer, castedCiphertexts, castedLut,
|
||||
crtDecompValue, lweDimCst, cbsLevelCountCst,
|
||||
cbsBaseLogCst, kskLevelCountCst, kskBaseLogCst,
|
||||
bskLevelCountCst, bskBaseLogCst, polySizeCst,
|
||||
glweDimCst}));
|
||||
|
||||
rewriter.replaceOp(wopPbs, outputBuffer);
|
||||
|
||||
@@ -542,7 +567,8 @@ void SimulateTFHEPass::runOnOperation() {
|
||||
|
||||
target.addLegalDialect<mlir::arith::ArithDialect>();
|
||||
target.addLegalOp<mlir::func::CallOp, mlir::memref::GetGlobalOp,
|
||||
mlir::bufferization::AllocTensorOp, mlir::tensor::CastOp>();
|
||||
mlir::memref::CastOp, mlir::bufferization::AllocTensorOp,
|
||||
mlir::tensor::CastOp>();
|
||||
// Make sure that no ops from `TFHE` remain after the lowering
|
||||
target.addIllegalDialect<TFHE::TFHEDialect>();
|
||||
|
||||
|
||||
@@ -111,9 +111,69 @@ void sim_wop_pbs_crt(
|
||||
// Additional crypto parameters
|
||||
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
|
||||
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
|
||||
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
|
||||
uint32_t polynomial_size) {
|
||||
// TODO
|
||||
uint32_t bsk_base_log, uint32_t polynomial_size, uint32_t glwe_dim) {
|
||||
|
||||
// Check number of blocks
|
||||
assert(out_size == in_size && out_size == crt_decomp_size);
|
||||
|
||||
uint64_t log_poly_size =
|
||||
static_cast<uint64_t>(ceil(log2(static_cast<double>(polynomial_size))));
|
||||
|
||||
// Compute the number of bits to extract for each block, and their total.
|
||||
uint64_t total_number_of_bits_per_block = 0;
|
||||
auto number_of_bits_per_block = new uint64_t[crt_decomp_size]();
|
||||
for (uint64_t i = 0; i < crt_decomp_size; i++) {
|
||||
uint64_t modulus = crt_decomp_aligned[i + crt_decomp_offset];
|
||||
uint64_t nb_bit_to_extract =
|
||||
static_cast<uint64_t>(ceil(log2(static_cast<double>(modulus))));
|
||||
number_of_bits_per_block[i] = nb_bit_to_extract;
|
||||
|
||||
total_number_of_bits_per_block += nb_bit_to_extract;
|
||||
}
|
||||
|
||||
// Create the buffer of ciphertexts for storing the total number of bits to
|
||||
// extract.
|
||||
// The extracted bits should be in the following order:
|
||||
//
|
||||
// [msb(m%crt[n-1])..lsb(m%crt[n-1])...msb(m%crt[0])..lsb(m%crt[0])] where n
|
||||
// is the size of the crt decomposition
|
||||
auto extract_bits_output_buffer =
|
||||
new uint64_t[total_number_of_bits_per_block]{0};
|
||||
|
||||
// Extraction of each bit for each block
|
||||
for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0;
|
||||
extract_bits_output_offset += number_of_bits_per_block[i--]) {
|
||||
auto nb_bits_to_extract = number_of_bits_per_block[i];
|
||||
|
||||
size_t delta_log = 64 - nb_bits_to_extract;
|
||||
|
||||
auto in_block = in_aligned[in_offset + i];
|
||||
|
||||
// trick ( ct - delta/2 + delta/2^5 ), with delta = 2^delta_log
|
||||
uint64_t sub = (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 1)) -
|
||||
(uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 5));
|
||||
in_block -= sub;
|
||||
|
||||
simulation_extract_bit_lwe_ciphertext_u64(
|
||||
&extract_bits_output_buffer[extract_bits_output_offset], in_block,
|
||||
delta_log, nb_bits_to_extract, log_poly_size, glwe_dim, lwe_small_dim,
|
||||
ksk_base_log, ksk_level_count, bsk_base_log, bsk_level_count, 64, 128);
|
||||
}
|
||||
|
||||
size_t ct_in_count = total_number_of_bits_per_block;
|
||||
size_t lut_size = 1 << ct_in_count;
|
||||
size_t ct_out_count = out_size;
|
||||
size_t lut_count = ct_out_count;
|
||||
|
||||
assert(lut_ct_size0 == lut_count);
|
||||
assert(lut_ct_size1 == lut_size);
|
||||
|
||||
// Vertical packing
|
||||
simulation_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
|
||||
extract_bits_output_buffer, out_aligned + out_offset, ct_in_count,
|
||||
ct_out_count, lut_size, lut_count, lut_ct_aligned + lut_ct_offset,
|
||||
glwe_dim, log_poly_size, lwe_small_dim, cbs_level_count, cbs_base_log, 64,
|
||||
128);
|
||||
}
|
||||
|
||||
uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; }
|
||||
|
||||
@@ -170,6 +170,22 @@ end_to_end_fixture = [
|
||||
).reshape((4, 4)),
|
||||
id="matul_chain_with_crt",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func.func @main(%arg0: !FHE.eint<14>, %arg1: tensor<16384xi64>) -> !FHE.eint<14> {
|
||||
%cst = arith.constant 15 : i15
|
||||
%v = "FHE.add_eint_int"(%arg0, %cst): (!FHE.eint<14>, i15) -> (!FHE.eint<14>)
|
||||
%1 = "FHE.apply_lookup_table"(%v, %arg1): (!FHE.eint<14>, tensor<16384xi64>) -> (!FHE.eint<14>)
|
||||
return %1: !FHE.eint<14>
|
||||
}
|
||||
""",
|
||||
(
|
||||
81,
|
||||
np.array(range(16384), dtype=np.uint64),
|
||||
),
|
||||
96,
|
||||
id="add_lut_crt",
|
||||
),
|
||||
]
|
||||
|
||||
end_to_end_parallel_fixture = [
|
||||
|
||||
Reference in New Issue
Block a user