feat(hpu): Add ILOG2/COUNT0/COUNT1/LEAD0/LEAD1/TRAIL0/TRAIL1 IOp.

Those IOp are tested within new bitcnt category
This commit is contained in:
JJ-hw
2025-06-05 11:13:09 +02:00
committed by B. Roux
parent 71e86f0522
commit a20c90b090
10 changed files with 1101 additions and 108 deletions

View File

@@ -901,7 +901,6 @@ pbs!(
|_params: &DigitParameters, _deg| 3;
}
]],
["IfPos1FalseZeroed" => 55 [ // Ct must contain CondCt in Carry bit 1 and ValueCt in Msg. If condition it's *FALSE*, value ct is forced to 0
@0 =>{
|params: &DigitParameters, val | {
@@ -1000,6 +999,433 @@ pbs!(
}
]],
// NB: Lut IfPos1FalseZeroed already defined earlier
["ManyInv1CarryMsg" => 64 [ // Proceed Inv - ct
// Extract message and carry using many LUT.
@0 =>{
|params: &DigitParameters, val | {
let inv = 1;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value & params.msg_mask()
}
};
|params: &DigitParameters, _deg| params.msg_mask();
},
@1 =>{
|params: &DigitParameters, val | {
let inv = 1;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value >> params.msg_w
}
};
|_params: &DigitParameters, _deg| 1;
},
]],
["ManyInv2CarryMsg" => 65 [ // Proceed Inv - ct
// Extract message and carry using many LUT.
@0 =>{
|params: &DigitParameters, val | {
let inv = 2;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value & params.msg_mask()
}
};
|params: &DigitParameters, _deg| params.msg_mask();
},
@1 =>{
|params: &DigitParameters, val | {
let inv = 2;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value >> params.msg_w
}
};
|_params: &DigitParameters, _deg| 1;
},
]],
["ManyInv3CarryMsg" => 66 [ // Proceed Inv - ct
// Extract message and carry using many LUT.
@0 =>{
|params: &DigitParameters, val | {
let inv = 3;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value & params.msg_mask()
}
};
|params: &DigitParameters, _deg| params.msg_mask();
},
@1 =>{
|params: &DigitParameters, val | {
let inv = 3;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value >> params.msg_w
}
};
|_params: &DigitParameters, _deg| 1;
},
]],
["ManyInv4CarryMsg" => 67 [ // Proceed Inv - ct
// Extract message and carry using many LUT.
@0 =>{
|params: &DigitParameters, val | {
let inv = 4;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value & params.msg_mask()
}
};
|params: &DigitParameters, _deg| params.msg_mask();
},
@1 =>{
|params: &DigitParameters, val | {
let inv = 4;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value >> params.msg_w
}
};
|_params: &DigitParameters, _deg| 1;
},
]],
["ManyInv5CarryMsg" => 68 [ // Proceed Inv - ct
// Extract message and carry using many LUT.
@0 =>{
|params: &DigitParameters, val | {
let inv = 5;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value & params.msg_mask()
}
};
|params: &DigitParameters, _deg| params.msg_mask();
},
@1 =>{
|params: &DigitParameters, val | {
let inv = 5;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value >> params.msg_w
}
};
|_params: &DigitParameters, _deg| 1;
},
]],
["ManyInv6CarryMsg" => 69 [ // Proceed Inv - ct
// Extract message and carry using many LUT.
@0 =>{
|params: &DigitParameters, val | {
let inv = 6;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value & params.msg_mask()
}
};
|params: &DigitParameters, _deg| params.msg_mask();
},
@1 =>{
|params: &DigitParameters, val | {
let inv = 6;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value >> params.msg_w
}
};
|_params: &DigitParameters, _deg| 1;
},
]],
["ManyInv7CarryMsg" => 70 [ // Proceed Inv - ct
// Extract message and carry using many LUT.
@0 =>{
|params: &DigitParameters, val | {
let inv = 7;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value & params.msg_mask()
}
};
|params: &DigitParameters, _deg| params.msg_mask();
},
@1 =>{
|params: &DigitParameters, val | {
let inv = 7;
let mut value = val & params.data_mask();
if value > inv {
0
} else {
value = inv - value;
value >> params.msg_w
}
};
|_params: &DigitParameters, _deg| 1;
},
]],
["ManyMsgSplit" => 71 [ // Use manyLUT : split msg in halves
@0 =>{
|params: &DigitParameters, val| {
let lsb_size = params.msg_w.div_ceil(2);
val & ((1 << lsb_size)-1) // msg_lsb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
(1 << lsb_size)-1
};
},
@1 =>{
|params: &DigitParameters, val| {
let lsb_size = params.msg_w.div_ceil(2);
(val & params.msg_mask()) >> lsb_size // msg_msb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
let msb_size = params.msg_w - lsb_size;
(1 << msb_size)-1
};
}
]],
["Manym2lPropBit1MsgSplit" => 72 [ // Use ManyLut
// In carry part, contains the info if neighbor has a bit=1 (not null)
// or not (null).
// Propagate bits equal to 1 from msb to lsb.
// Split resulting message part into 2. Put both in lsb.
@0 =>{
|params: &DigitParameters, val| {
let mut c = val & params.carry_mask();
let mut m = val & params.msg_mask();
let mut exp = 0;
// Expand from msb to lsb
for idx in (0..params.msg_w).rev() {
let mut b = (m >> idx) & 1;
m &= (1 << idx)-1;
if c > 0 {b = 1;} // propagate to lsb
if b == 1 {c = 1;}
exp += b << idx;
}
let lsb_size = params.msg_w.div_ceil(2);
exp & ((1 << lsb_size)-1) // msg_lsb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
(1 << lsb_size)-1
};
},
@1 =>{
|params: &DigitParameters, val| {
let mut c = val & params.carry_mask();
let mut m = val & params.msg_mask();
let mut exp = 0;
// Expand from msb to lsb
for idx in (0..params.msg_w).rev() {
let mut b = (m >> idx) & 1;
m &= (1 << idx)-1;
if c > 0 {b = 1;} // propagate to lsb
if b == 1 {c = 1;}
exp += b << idx;
}
let lsb_size = params.msg_w.div_ceil(2);
(exp & params.msg_mask()) >> lsb_size // msg_msb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
let msb_size = params.msg_w - lsb_size;
(1 << msb_size)-1
};
}
]],
["Manym2lPropBit0MsgSplit" => 73 [ // Use ManyLut
// In carry part, contains the info if neighbor has a bit=0 (not null)
// or not (null).
// Propagate bits equal to 0 from msb to lsb.
// Split resulting message part into 2. Put both in lsb.
@0 =>{
|params: &DigitParameters, val| {
let mut c = val & params.carry_mask();
let mut m = val & params.msg_mask();
let mut exp = 0;
// Expand from msb to lsb
for idx in (0..(params.msg_w)).rev() {
let mut b = (m >> idx) & 1;
m &= (1 << idx)-1;
if c > 0 {b = 0;} // propagate to lsb
if b == 0 {c = 1;}
exp += b << idx;
}
let lsb_size = params.msg_w.div_ceil(2);
exp & ((1 << lsb_size)-1) // msg_lsb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
(1 << lsb_size)-1
};
},
@1 =>{
|params: &DigitParameters, val| {
let mut c = val & params.carry_mask();
let mut m = val & params.msg_mask();
let mut exp = 0;
// Expand from msb to lsb
for idx in (0..(params.msg_w)).rev() {
let mut b = (m >> idx) & 1;
m &= (1 << idx)-1;
if c > 0 {b = 0;} // propagate to lsb
if b == 0 {c = 1;}
exp += b << idx;
}
let lsb_size = params.msg_w.div_ceil(2);
(exp & params.msg_mask()) >> lsb_size // msg_msb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
let msb_size = params.msg_w - lsb_size;
(1 << msb_size)-1
};
}
]],
["Manyl2mPropBit1MsgSplit" => 74 [ // Use ManyLut
// In carry part, contains the info if neighbor has a bit=1 (not null)
// or not (null).
// Propagate bits equal to 1 from lsb to msb.
// Split resulting message part into 2. Put both in lsb.
@0 =>{
|params: &DigitParameters, val| {
let mut c = val & params.carry_mask();
let mut m = val & params.msg_mask();
let mut exp = 0;
// Expand from lsb to msb
for idx in 0..(params.msg_w) {
let mut b = m & 1;
m >>= 1;
if c > 0 {b = 1;} // propagate to msb
if b == 1 {c = 1;}
exp += b << idx;
}
let lsb_size = params.msg_w.div_ceil(2);
exp & ((1 << lsb_size)-1) // msg_lsb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
(1 << lsb_size)-1
};
},
@1 =>{
|params: &DigitParameters, val| {
let mut c = val & params.carry_mask();
let mut m = val & params.msg_mask();
let mut exp = 0;
// Expand from lsb to msb
for idx in 0..(params.msg_w) {
let mut b = m & 1;
m >>= 1;
if c > 0 {b = 1;} // propagate to msb
if b == 1 {c = 1;}
exp += b << idx;
}
let lsb_size = params.msg_w.div_ceil(2);
(exp & params.msg_mask()) >> lsb_size // msg_msb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
let msb_size = params.msg_w - lsb_size;
(1 << msb_size)-1
};
}
]],
["Manyl2mPropBit0MsgSplit" => 75 [ // Use ManyLut
// In carry part, contains the info if neighbor has a bit=0 (not null)
// or not (null).
// Propagate bits equal to 0 from lsb to msb.
// Split resulting message part into 2. Put both in lsb.
@0 =>{
|params: &DigitParameters, val| {
let mut c = val & params.carry_mask();
let mut m = val & params.msg_mask();
let mut exp = 0;
// Expand from lsb to msb
for idx in 0..(params.msg_w) {
let mut b = m & 1;
m >>= 1;
if c > 0 {b = 0;} // propagate to msb
if b == 0 {c = 1;}
exp += b << idx;
}
let lsb_size = params.msg_w.div_ceil(2);
exp & ((1 << lsb_size)-1) // msg_lsb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
(1 << lsb_size)-1
};
},
@1 =>{
|params: &DigitParameters, val| {
let mut c = val & params.carry_mask();
let mut m = val & params.msg_mask();
let mut exp = 0;
// Expand from lsb to msb
for idx in 0..(params.msg_w) {
let mut b = m & 1;
m >>= 1;
if c > 0 {b = 0;} // propagate to msb
if b == 0 {c = 1;}
exp += b << idx;
}
let lsb_size = params.msg_w.div_ceil(2);
(exp & params.msg_mask()) >> lsb_size // msg_msb
};
|params: &DigitParameters, _deg| {
let lsb_size = params.msg_w.div_ceil(2);
let msb_size = params.msg_w - lsb_size;
(1 << msb_size)-1
};
}
]],
);
pub(crate) fn ceil_ilog2(value: &u8) -> u8 {

View File

@@ -220,4 +220,11 @@ iop!(
[IOP_CT_F_2CT_BOOL -> "IF_THEN_ELSE", opcode::IF_THEN_ELSE],
[IOP_2CT_F_3CT -> "ERC_20", opcode::ERC_20],
[IOP_CT_F_CT -> "MEMCPY", opcode::MEMCPY],
[IOP_CT_F_CT -> "ILOG2", opcode::ILOG2],
[IOP_CT_F_CT -> "COUNT0", opcode::COUNT0],
[IOP_CT_F_CT -> "COUNT1", opcode::COUNT1],
[IOP_CT_F_CT -> "LEAD0", opcode::LEAD0],
[IOP_CT_F_CT -> "LEAD1", opcode::LEAD1],
[IOP_CT_F_CT -> "TRAIL0", opcode::TRAIL0],
[IOP_CT_F_CT -> "TRAIL1", opcode::TRAIL1],
);

View File

@@ -8,11 +8,6 @@
//! | ---------- | ------------------------- |
//! | 0x00.. 0x7f| User custom operations |
//! | 0x80.. 0xff| Fw generated operations |
//! | 0b1xyz_0000| x: Ct x Ct Operation |
//! | | !x: Ct x Imm Operation |
//! | | y!z: ARITH operations |
//! | | !yz: BW operations |
//! | | !y!z: CMP operations |
//! | ---------- | ------------------------- |
pub const USER_RANGE_LB: u8 = 0x0;
@@ -83,6 +78,15 @@ pub const IF_THEN_ELSE: u8 = 0xCB;
// 2Ct <- func(3Ct)
pub const ERC_20: u8 = 0x80;
// Count bits
pub const COUNT0: u8 = 0x81;
pub const COUNT1: u8 = 0x82;
pub const ILOG2: u8 = 0x83;
pub const LEAD0: u8 = 0x84;
pub const LEAD1: u8 = 0x85;
pub const TRAIL0: u8 = 0x86;
pub const TRAIL1: u8 = 0x87;
// Utility operations
// Used to handle real clone of ciphertext already uploaded in the Hpu memory
pub const MEMCPY: u8 = 0xFF;

View File

@@ -36,6 +36,15 @@ crate::impl_fw!("Demo" [
CMP_GT => cmp_gt;
CMP_GTE => cmp_gte;
CMP_LT => cmp_lt;
COUNT0 => fw_impl::ilp_log::iop_count0;
COUNT1 => fw_impl::ilp_log::iop_count1;
ILOG2 => fw_impl::ilp_log::iop_ilog2;
LEAD0 => fw_impl::ilp_log::iop_lead0;
LEAD1 => fw_impl::ilp_log::iop_lead1;
TRAIL0 => fw_impl::ilp_log::iop_trail0;
TRAIL1 => fw_impl::ilp_log::iop_trail1;
]);
// Recursive {{{1

View File

@@ -64,6 +64,14 @@ crate::impl_fw!("Ilp" [
ERC_20 => fw_impl::ilp::iop_erc_20;
MEMCPY => fw_impl::ilp::iop_memcpy;
COUNT0 => fw_impl::ilp_log::iop_count0;
COUNT1 => fw_impl::ilp_log::iop_count1;
ILOG2 => fw_impl::ilp_log::iop_ilog2;
LEAD0 => fw_impl::ilp_log::iop_lead0;
LEAD1 => fw_impl::ilp_log::iop_lead1;
TRAIL0 => fw_impl::ilp_log::iop_trail0;
TRAIL1 => fw_impl::ilp_log::iop_trail1;
]);
#[instrument(level = "trace", skip(prog))]

View File

@@ -1,5 +1,5 @@
//!
//! Implementation of Ilp firmware
//! Implementation of Ilp firmware for division, and modulo
//!
//! In this version of the Fw focus is done on Instruction Level Parallelism
use std::cmp::Ordering;
@@ -7,6 +7,7 @@ use std::collections::VecDeque;
use super::*;
use crate::asm::{self, OperandKind};
use crate::fw::fw_impl::ilp_log;
use crate::fw::program::Program;
use tracing::{instrument, warn};
@@ -44,12 +45,8 @@ pub fn iop_divs(prog: &mut Program) {
// Deferred implementation to generic divx function
// TODO: do computation on immediate directly for more efficiency.
// Workaround: transform immediate into ct.
let mut src_imm: Vec<metavar::MetaVarCell> = Vec::new();
let cst_0 = &src_a[src_a.len() - 1] - &src_a[src_a.len() - 1];
for cst in src_b.iter() {
let ct = cst + &cst_0;
src_imm.push(ct);
}
let cst_0 = prog.new_cst(0);
let src_imm: Vec<metavar::MetaVarCell> = src_b.iter().map(|imm| imm + &cst_0).collect();
iop_divx(prog, &mut dst_quotient, &mut dst_remain, &src_a, &src_imm);
}
@@ -106,12 +103,8 @@ pub fn iop_mods(prog: &mut Program) {
// Deferred implementation to generic modx function
// TODO: do computation on immediate directly for more efficiency.
// Workaround: transform immediate into ct.
let mut src_imm: Vec<metavar::MetaVarCell> = Vec::new();
let cst_0 = &src_a[src_a.len() - 1] - &src_a[src_a.len() - 1];
for cst in src_b.iter() {
let ct = cst + &cst_0;
src_imm.push(ct);
}
let cst_0 = prog.new_cst(0);
let src_imm: Vec<metavar::MetaVarCell> = src_b.iter().map(|imm| imm + &cst_0).collect();
iop_modx(prog, &mut dst_remain, &src_a, &src_imm);
}
@@ -342,88 +335,6 @@ pub fn iop_add_hillissteel_v(
res
}
#[instrument(level = "trace", skip(prog))]
/// Outputs a list of booleans,
/// each indicating the null status from the input
/// block <i> to the msb of the input.
/// 'true' means is not null.
/// 'false' means is null.
pub fn iop_is_not_null_vector_v(
prog: &mut Program,
src_a: &[metavar::MetaVarCell],
) -> Vec<metavar::MetaVarCell> {
let props = prog.params();
//let tfhe_params: asm::DigitParameters = props.clone().into();
let pbs_not_null = new_pbs!(prog, "NotNull");
// TODO: TOREVIEW
let op_nb = props.nu;
// clog2(op_nb)
let op_nb_bool = 1 << ((op_nb as f32).log2().ceil() as usize);
// First step
// Work within each group of op_nb blocks.
// For <i> get a boolean not null status of current block and the MSB ones.
// within this group.
let mut g_a: Vec<metavar::MetaVarCell> = Vec::new();
for (c_id, c) in src_a.chunks(op_nb).enumerate() {
c.iter().rev().fold(None, |acc, elt| {
let is_not_null;
let tmp;
if let Some(x) = acc {
tmp = &x + elt;
is_not_null = tmp.pbs(&pbs_not_null, false);
} else {
is_not_null = elt.pbs(&pbs_not_null, false);
//tmp = elt.clone();
tmp = elt.clone();
};
g_a.insert(c_id * op_nb, is_not_null); // Reverse insertion per chunk
Some(tmp)
});
}
// Second step
// Proparate the not null status from MSB to LSB, with stride of
// (op_nb_bool**k)*op_nb
//assert_eq!(g_a.len(),props.blk_w());
let grp_nb = g_a.len().div_ceil(op_nb);
let mut stride_size: usize = 1; // in group unit
while stride_size < grp_nb {
for chk in g_a.chunks_mut(op_nb_bool * stride_size * op_nb) {
chk.chunks_mut(stride_size * op_nb)
.rev()
.fold(None, |acc, sub_chk| {
if let Some(x) = acc {
let tmp = &x + &sub_chk[0];
sub_chk[0] = tmp.pbs(&pbs_not_null, false);
Some(tmp)
} else {
Some(sub_chk[0].clone())
}
});
}
stride_size *= op_nb_bool;
}
// Third step
// Apply
g_a.chunks_mut(op_nb).rev().fold(None, |acc, chk| {
if let Some(x) = acc {
for v in chk.iter_mut().skip(1) {
// [0] is already complete.
*v = &*v + x;
*v = v.pbs(&pbs_not_null, false);
}
}
Some(&chk[0])
});
g_a
}
#[instrument(level = "trace", skip(prog))]
/// Outputs a tuple corresponding to (src x2, src x3)
pub fn iop_x2_x3v(
@@ -518,9 +429,27 @@ pub fn iop_div_initv(prog: &mut Program, div_x1_a: &[metavar::MetaVarCell]) -> I
// Note that div_x2 and div_x3 has an additional ct in msb
let (div_x2_a, div_x3_a) = iop_x2_x3v(prog, div_x1_a);
let div_x1_is_not_null_a = iop_is_not_null_vector_v(prog, div_x1_a);
let div_x2_is_not_null_a = iop_is_not_null_vector_v(prog, &div_x2_a);
let div_x3_is_not_null_a = iop_is_not_null_vector_v(prog, &div_x3_a);
let div_x1_is_not_null_a = ilp_log::iop_propagate_msb_to_lsb_blockv(
prog,
div_x1_a,
&Some(ilp_log::BitType::One),
&Some(false),
&Some(false),
);
let div_x2_is_not_null_a = ilp_log::iop_propagate_msb_to_lsb_blockv(
prog,
&div_x2_a,
&Some(ilp_log::BitType::One),
&Some(false),
&Some(false),
);
let div_x3_is_not_null_a = ilp_log::iop_propagate_msb_to_lsb_blockv(
prog,
&div_x3_a,
&Some(ilp_log::BitType::One),
&Some(false),
&Some(false),
);
// If the divider is null set quotient to 0
let keep_div = div_x1_is_not_null_a[0].clone();
@@ -637,10 +566,7 @@ pub fn iop_div_corev(
// Loop
let mut quotient_a: Vec<metavar::MetaVarCell> = Vec::new();
let mut remain_a: Vec<metavar::MetaVarCell> = Vec::new();
// !! Workaround
let cst_sign_tmp = prog.new_imm(tfhe_params.msg_range() - 1);
let mut cst_sign = &src_a[num_a.len() - 1] - &src_a[num_a.len() - 1];
cst_sign = &cst_sign + &cst_sign_tmp;
let cst_sign = prog.new_cst(tfhe_params.msg_range() - 1);
for loop_idx in 0..num_a.len() {
let block_nb = loop_idx + 1;

View File

@@ -0,0 +1,533 @@
//!
//! Implementation of Ilp firmware for bit count (log2, trailing, leading bit)
//!
//! In this version of the Fw focus is done on Instruction Level Parallelism
use super::*;
use crate::asm::{self, OperandKind};
use crate::fw::program::Program;
use tracing::{instrument, warn};
use crate::new_pbs;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BitType {
One,
Zero,
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_count0(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
// SrcA -> Operand
let src_a = prog.iop_template_var(OperandKind::Src, 0);
// Add Comment header
prog.push_comment("COUNT0 Operand::Dst Operand::Src Operand::Src".to_string());
// Deferred implementation to generic countx function
iop_countx(prog, &mut dst, &src_a, &Some(BitType::Zero));
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_count1(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
// SrcA -> Operand
let src_a = prog.iop_template_var(OperandKind::Src, 0);
// Add Comment header
prog.push_comment("COUNT1 Operand::Dst Operand::Src Operand::Src".to_string());
// Deferred implementation to generic countx function
iop_countx(prog, &mut dst, &src_a, &Some(BitType::One));
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_ilog2(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
// SrcA -> Operand
let src_a = prog.iop_template_var(OperandKind::Src, 0);
// Add Comment header
prog.push_comment("ILOG2 Operand::Dst Operand::Src Operand::Src".to_string());
let props = &prog.params();
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::One), &Some(false));
let count_a = iop_countv(prog, &prop_a[1..], &Some(BitType::One));
count_a.iter().enumerate().for_each(|(idx, c)| {
c.reg_alloc_mv();
dst[idx].mv_assign(c);
});
let cst_0 = prog.new_cst(0);
(count_a.len()..props.blk_w()).for_each(|blk| {
dst[blk].mv_assign(&cst_0);
});
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_lead0(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
// SrcA -> Operand
let src_a = prog.iop_template_var(OperandKind::Src, 0);
// Add Comment header
prog.push_comment("LEAD0 Operand::Dst Operand::Src Operand::Src".to_string());
let props = &prog.params();
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::One), &Some(false));
let count_a = iop_countv(prog, &prop_a, &Some(BitType::Zero));
count_a.iter().enumerate().for_each(|(idx, c)| {
c.reg_alloc_mv();
dst[idx].mv_assign(c);
});
let cst_0 = prog.new_cst(0);
(count_a.len()..props.blk_w()).for_each(|blk| {
dst[blk].mv_assign(&cst_0);
});
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_lead1(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
// SrcA -> Operand
let src_a = prog.iop_template_var(OperandKind::Src, 0);
// Add Comment header
prog.push_comment("LEAD1 Operand::Dst Operand::Src Operand::Src".to_string());
let props = &prog.params();
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::Zero), &Some(false));
let count_a = iop_countv(prog, &prop_a, &Some(BitType::One));
count_a.iter().enumerate().for_each(|(idx, c)| {
c.reg_alloc_mv();
dst[idx].mv_assign(c);
});
let cst_0 = prog.new_cst(0);
(count_a.len()..props.blk_w()).for_each(|blk| {
dst[blk].mv_assign(&cst_0);
});
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_trail0(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
// SrcA -> Operand
let src_a = prog.iop_template_var(OperandKind::Src, 0);
// Add Comment header
prog.push_comment("TRAIL0 Operand::Dst Operand::Src Operand::Src".to_string());
let props = &prog.params();
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::One), &Some(true));
let count_a = iop_countv(prog, &prop_a, &Some(BitType::Zero));
count_a.iter().enumerate().for_each(|(idx, c)| {
c.reg_alloc_mv();
dst[idx].mv_assign(c);
});
let cst_0 = prog.new_cst(0);
(count_a.len()..props.blk_w()).for_each(|blk| {
dst[blk].mv_assign(&cst_0);
});
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_trail1(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
// SrcA -> Operand
let src_a = prog.iop_template_var(OperandKind::Src, 0);
// Add Comment header
prog.push_comment("TRAIL1 Operand::Dst Operand::Src Operand::Src".to_string());
let props = &prog.params();
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::Zero), &Some(true));
let count_a = iop_countv(prog, &prop_a, &Some(BitType::One));
count_a.iter().enumerate().for_each(|(idx, c)| {
c.reg_alloc_mv();
dst[idx].mv_assign(c);
});
let cst_0 = prog.new_cst(0);
(count_a.len()..props.blk_w()).for_each(|blk| {
dst[blk].mv_assign(&cst_0);
});
}
/// Generic count bit operation
/// One destination and one source operation
/// Source is Operand
pub fn iop_countx(
prog: &mut Program,
dst: &mut [metavar::MetaVarCell],
src_a: &[metavar::MetaVarCell],
bit_type: &Option<BitType>,
) {
let props = prog.params();
//let tfhe_params: asm::DigitParameters = props.clone().into();
let pbs_many_msg_split = new_pbs!(prog, "ManyMsgSplit");
let mut bit_a: Vec<metavar::MetaVarCell> = Vec::new();
for (idx, ct) in src_a.iter().enumerate() {
let do_flush = idx == src_a.len() - 1;
let v = &ct.pbs_many(&pbs_many_msg_split, do_flush)[..];
bit_a.push(v[0].clone());
bit_a.push(v[1].clone());
}
let count_a = iop_countv(prog, &bit_a, bit_type);
count_a.iter().enumerate().for_each(|(idx, c)| {
c.reg_alloc_mv();
dst[idx].mv_assign(c);
});
let cst_0 = prog.new_cst(0);
(count_a.len()..props.blk_w()).for_each(|blk| {
dst[blk].mv_assign(&cst_0);
});
}
// Do an iteration only if there are columns that need
// to be reduced, i.e. with more than 1 element.
fn need_iter(v: &[Vec<metavar::MetaVarCell>]) -> bool {
v.iter()
.filter(|l| l.len() > 1)
.fold(false, |_acc, _l| true)
}
/// From a source composed blocks containing each
/// a single significant bit at position 0, count the number of
/// bits equal to 0 or 1, according to bit_type.
/// The source is assumed to be "clean".
pub fn iop_countv(
prog: &mut Program,
src_a: &[metavar::MetaVarCell],
bit_type: &Option<BitType>,
) -> Vec<metavar::MetaVarCell> {
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
let pbs_msg = new_pbs!(prog, "MsgOnly");
let pbs_carry = new_pbs!(prog, "CarryInMsg");
let pbs_many_carrymsg = new_pbs!(prog, "ManyCarryMsg");
let pbs_many_inv1_carrymsg = new_pbs!(prog, "ManyInv1CarryMsg");
let pbs_many_inv2_carrymsg = new_pbs!(prog, "ManyInv2CarryMsg");
let pbs_many_inv3_carrymsg = new_pbs!(prog, "ManyInv3CarryMsg");
let pbs_many_inv4_carrymsg = new_pbs!(prog, "ManyInv4CarryMsg");
let pbs_many_inv5_carrymsg = new_pbs!(prog, "ManyInv5CarryMsg");
let pbs_many_inv6_carrymsg = new_pbs!(prog, "ManyInv6CarryMsg");
let pbs_many_inv7_carrymsg = new_pbs!(prog, "ManyInv7CarryMsg");
let pbs_many_inv_carrymsg = [
pbs_many_carrymsg.clone(), // place holder
pbs_many_inv1_carrymsg,
pbs_many_inv2_carrymsg,
pbs_many_inv3_carrymsg,
pbs_many_inv4_carrymsg,
pbs_many_inv5_carrymsg,
pbs_many_inv6_carrymsg,
pbs_many_inv7_carrymsg,
];
// TODO: TOREVIEW
let op_nb = props.nu;
// clog2(op_nb)
let op_nb_bool = 1 << ((op_nb as f32).log2().ceil() as usize);
let op_nb_single = op_nb_bool - 1;
// Number of block to store the results.
let block_nb =
(((src_a.len() * tfhe_params.msg_w + 1) as f32).log2().ceil() as usize).div_ceil(2);
// During the process, the current MSB column will be composed of
// blocks of single bit.
// The others are composed of blocks of msg_w bits.
// Single bit column is summed op_nb_single blocks at a time. Therefore
// leaving a free bit for the manyLut extraction.
// Full msg column is summed op_nb blocks at a time. Therefore
// 2 PBS are used for the extraction.
let mut sum_v: Vec<Vec<metavar::MetaVarCell>> = vec![src_a.to_vec()];
let mut iter_idx = 0;
while need_iter(&sum_v) {
let empty_col_nb = sum_v
.iter()
.filter(|col| col.is_empty())
.fold(0, |acc, _col| acc + 1);
let mut next_v: Vec<Vec<metavar::MetaVarCell>> = Vec::new();
next_v.push(Vec::new()); // For the msg
for (c_idx, col) in sum_v.iter().enumerate() {
next_v.push(Vec::new()); // For the carry
let next_len = &next_v.len();
let is_last_nonempty_col = c_idx == (&sum_v.len() - 1 - empty_col_nb);
if col.len() == 1 {
// Single element, do not need to process
next_v[next_len - 2].push(col[0].clone());
} else if c_idx == sum_v.len() - 1 {
// Last column contains only bits
let chunk_nb = col.len().div_ceil(op_nb_single);
for (chk_idx, chk) in col.chunks(op_nb_single).enumerate() {
let do_flush = (chk_idx == (chunk_nb - 1)) && is_last_nonempty_col;
let cst_0 = prog.new_imm(0);
let (s, nb) = chk
.iter()
.fold((cst_0, 0), |(acc, elt_nb), ct| (ct + &acc, elt_nb + 1));
let m: metavar::MetaVarCell;
let c: metavar::MetaVarCell;
if bit_type.unwrap_or(BitType::One) == BitType::Zero && iter_idx == 0 {
let v = s.pbs_many(&pbs_many_inv_carrymsg[nb], do_flush);
m = v[0].clone();
c = v[1].clone();
} else {
let v = s.pbs_many(&pbs_many_carrymsg, do_flush);
m = v[0].clone();
c = v[1].clone();
}
if nb >= tfhe_params.msg_range() && c_idx < block_nb {
// Do not compute after
// the number of needed blocks.
next_v[next_len - 1].push(c);
}
next_v[next_len - 2].push(m);
}
} else {
// Regular column. Sum by op_nb elements
let chunk_nb = col.len().div_ceil(op_nb);
for (chk_idx, chk) in col.chunks(op_nb).enumerate() {
let do_flush = (chk_idx == (chunk_nb - 1)) && is_last_nonempty_col;
let cst_0 = prog.new_imm(0);
let (s, nb) = chk
.iter()
.fold((cst_0, 0), |(acc, elt_nb), ct| (ct + &acc, elt_nb + 1));
let m: metavar::MetaVarCell;
let c: metavar::MetaVarCell;
if nb > 2 {
m = s.pbs(&pbs_msg, false);
c = s.pbs(&pbs_carry, do_flush);
} else {
// Free bit to used manyLut
let v = s.pbs_many(&pbs_many_carrymsg, do_flush);
m = v[0].clone();
c = v[1].clone();
}
if c_idx < block_nb {
// Do not compute after
// the number of needed blocks.
next_v[next_len - 1].push(c);
}
next_v[next_len - 2].push(m);
}
}
} // For c_idx, col
iter_idx += 1;
sum_v = next_v;
} // while
// let mut res : Vec<metavar::MetaVarCell> = Vec::new();
sum_v
.iter()
.filter(|v| !v.is_empty())
.map(|v| v[0].clone())
.collect()
}
/// Propagate bit from msb to lsb.
pub fn iop_propagate_msb_to_lsbv(
prog: &mut Program,
src_a: &[metavar::MetaVarCell],
bit_type: &Option<BitType>,
inverse_propagation: &Option<bool>, // default false
) -> Vec<metavar::MetaVarCell> {
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
let pbs_many_m2l_prop_bit1_msg_split = new_pbs!(prog, "Manym2lPropBit1MsgSplit");
let pbs_many_m2l_prop_bit0_msg_split = new_pbs!(prog, "Manym2lPropBit0MsgSplit");
let pbs_many_l2m_prop_bit1_msg_split = new_pbs!(prog, "Manyl2mPropBit1MsgSplit");
let pbs_many_l2m_prop_bit0_msg_split = new_pbs!(prog, "Manyl2mPropBit0MsgSplit");
let propagate_block =
iop_propagate_msb_to_lsb_blockv(prog, src_a, bit_type, &Some(false), inverse_propagation);
let mut res_v = Vec::new();
for (idx, ct) in src_a.iter().enumerate() {
// propagation start point
let start_idx = if inverse_propagation.unwrap_or(false) {
0
} else {
src_a.len() - 1
};
let do_flush = idx == (src_a.len() - 1);
let m = if idx == start_idx {
ct.clone()
} else {
let neigh_idx = if inverse_propagation.unwrap_or(false) {
idx - 1
} else {
idx + 1
};
propagate_block[neigh_idx].mac(tfhe_params.msg_range() as u8, ct)
};
let v = if bit_type.unwrap_or(BitType::One) == BitType::One {
if inverse_propagation.unwrap_or(false) {
m.pbs_many(&pbs_many_l2m_prop_bit1_msg_split, do_flush)
} else {
m.pbs_many(&pbs_many_m2l_prop_bit1_msg_split, do_flush)
}
} else if inverse_propagation.unwrap_or(false) {
m.pbs_many(&pbs_many_l2m_prop_bit0_msg_split, do_flush)
} else {
m.pbs_many(&pbs_many_m2l_prop_bit0_msg_split, do_flush)
};
res_v.push(v[0].clone());
res_v.push(v[1].clone());
}
res_v
}
#[instrument(level = "trace", skip(prog))]
/// Propagate bit value given by bit_type from msb to lsb,
/// on block basis.
/// If inverse_output = false
/// From MSB to LSB:
/// * bit_type = 1 if block <i> contains the first bit equal to 1, from MSB then the Noutput block
/// <i> and below are set to 1, the output block <i+1> and up are set to 0.
/// * bit_type = 0 if block <i> contains the first bit equal to 0, from MSB then the output block
/// <i> and below are set to 1, the output block <i+1> and up are set to 0.
/// If inverse_output = true, the output bits described above are
/// negated.
pub fn iop_propagate_msb_to_lsb_blockv(
prog: &mut Program,
src_a: &[metavar::MetaVarCell],
bit_type: &Option<BitType>, // Default One
inverse_output: &Option<bool>, // Default false
inverse_direction: &Option<bool>, // Default false
) -> Vec<metavar::MetaVarCell> {
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
let pbs_not_null = new_pbs!(prog, "NotNull");
let pbs_is_null = new_pbs!(prog, "IsNull");
// TODO: TOREVIEW
let op_nb = props.nu;
// clog2(op_nb)
let op_nb_bool = 1 << ((op_nb as f32).log2().ceil() as usize);
let mut proc_nb = op_nb;
let mut src = if bit_type.unwrap_or(BitType::One) == BitType::One {
src_a.to_vec()
} else {
// Bitwise not
// Do not clean the ct, but reduce the nb of sequential operations
// in next step (reducing proc_nb).
proc_nb -= 1;
let cst_msg_max = prog.new_imm(tfhe_params.msg_mask());
src_a.iter().map(|ct| &cst_msg_max - ct).collect()
};
if inverse_direction.unwrap_or(false) {
src.reverse();
}
// First step
// Work within each group of proc_nb blocks.
// For <i> get a boolean not null status of current block and the MSB ones.
// within this group.
let mut g_a: Vec<metavar::MetaVarCell> = Vec::new();
for (c_id, c) in src.chunks(proc_nb).enumerate() {
c.iter().rev().fold(None, |acc, elt| {
let is_not_null;
let tmp;
if let Some(x) = acc {
tmp = &x + elt;
is_not_null = tmp.pbs(&pbs_not_null, false);
} else {
tmp = elt.clone();
is_not_null = elt.pbs(&pbs_not_null, false);
};
g_a.insert(c_id * proc_nb, is_not_null); // Reverse insertion per chunk
Some(tmp)
});
}
// Second step
// Proparate the not null status from MSB to LSB, with stride of
// (op_nb_bool**k)*proc_nb
//assert_eq!(g_a.len(),props.blk_w());
let grp_nb = g_a.len().div_ceil(proc_nb);
let mut stride_size: usize = 1; // in group unit
while stride_size < grp_nb {
for chk in g_a.chunks_mut(op_nb_bool * stride_size * proc_nb) {
chk.chunks_mut(stride_size * proc_nb)
.rev()
.fold(None, |acc, sub_chk| {
if let Some(x) = acc {
let tmp = &x + &sub_chk[0];
sub_chk[0] = tmp.pbs(&pbs_not_null, false);
Some(tmp)
} else {
Some(sub_chk[0].clone())
}
});
}
stride_size *= op_nb_bool;
}
// Third step
// Apply
g_a.chunks_mut(proc_nb).rev().fold(None, |acc, chk| {
if let Some(x) = acc {
for (idx, v) in chk.iter_mut().enumerate() {
if idx == 0 {
// [0] is already complete.
// Need to inverse it for 0 if needed
if inverse_output.unwrap_or(false) {
*v = v.pbs(&pbs_is_null, false);
}
} else {
*v = &*v + x;
if inverse_output.unwrap_or(false) {
*v = v.pbs(&pbs_is_null, false);
} else {
*v = v.pbs(&pbs_not_null, false);
}
}
}
}
Some(&chk[0])
});
if inverse_direction.unwrap_or(false) {
g_a.reverse();
}
g_a
}

View File

@@ -73,6 +73,14 @@ crate::impl_fw!("Llt" [
ERC_20 => fw_impl::llt::iop_erc_20;
MEMCPY => fw_impl::ilp::iop_memcpy;
COUNT0 => fw_impl::ilp_log::iop_count0;
COUNT1 => fw_impl::ilp_log::iop_count1;
ILOG2 => fw_impl::ilp_log::iop_ilog2;
LEAD0 => fw_impl::ilp_log::iop_lead0;
LEAD1 => fw_impl::ilp_log::iop_lead1;
TRAIL0 => fw_impl::ilp_log::iop_trail0;
TRAIL1 => fw_impl::ilp_log::iop_trail1;
]);
// ----------------------------------------------------------------------------

View File

@@ -4,6 +4,7 @@ use crate::asm::{AsmIOpcode, DOp, IOpcode};
pub mod demo;
pub mod ilp;
pub mod ilp_div;
pub mod ilp_log;
pub mod llt;
/// Utility macro to define new FW implementation

View File

@@ -368,6 +368,22 @@ mod hpu_test {
}
});
// Bit count IOp
hpu_testcase!("COUNT0" => [u8, u16, u32, u64, u128]
|ct, imm| [ct[0].count_zeros()]);
hpu_testcase!("COUNT1" => [u8, u16, u32, u64, u128]
|ct, imm| [ct[0].count_ones()]);
hpu_testcase!("ILOG2" => [u8, u16, u32, u64, u128]
|ct, imm| [ct[0].ilog2()]);
hpu_testcase!("LEAD0" => [u8, u16, u32, u64, u128]
|ct, imm| [ct[0].leading_zeros()]);
hpu_testcase!("LEAD1" => [u8, u16, u32, u64, u128]
|ct, imm| [ct[0].leading_ones()]);
hpu_testcase!("TRAIL0" => [u8, u16, u32, u64, u128]
|ct, imm| [ct[0].trailing_zeros()]);
hpu_testcase!("TRAIL1" => [u8, u16, u32, u64, u128]
|ct, imm| [ct[0].trailing_ones()]);
// Define a set of test bundle for various size
// 8bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
@@ -453,6 +469,17 @@ mod hpu_test {
"erc_20"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::8 => [
"count0",
"count1",
"ilog2",
"lead0",
"lead1",
"trail0",
"trail1"
]);
// 16bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::16 => [
@@ -537,6 +564,17 @@ mod hpu_test {
"erc_20"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::16 => [
"count0",
"count1",
"ilog2",
"lead0",
"lead1",
"trail0",
"trail1"
]);
// 32bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::32 => [
@@ -621,6 +659,17 @@ mod hpu_test {
"erc_20"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::32 => [
"count0",
"count1",
"ilog2",
"lead0",
"lead1",
"trail0",
"trail1"
]);
// 64bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::64 => [
@@ -705,6 +754,17 @@ mod hpu_test {
"erc_20"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::64 => [
"count0",
"count1",
"ilog2",
"lead0",
"lead1",
"trail0",
"trail1"
]);
// 128bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::128 => [
@@ -789,6 +849,17 @@ mod hpu_test {
"erc_20"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::128 => [
"count0",
"count1",
"ilog2",
"lead0",
"lead1",
"trail0",
"trail1"
]);
/// Simple test dedicated to check entities conversion from/to Cpu
#[cfg(feature = "hpu")]
#[test]