mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(hpu): Add ILOG2/COUNT0/COUNT1/LEAD0/LEAD1/TRAIL0/TRAIL1 IOp.
These IOps are tested within the new bitcnt category
This commit is contained in:
@@ -901,7 +901,6 @@ pbs!(
|
||||
|_params: &DigitParameters, _deg| 3;
|
||||
}
|
||||
]],
|
||||
|
||||
["IfPos1FalseZeroed" => 55 [ // Ct must contain CondCt in Carry bit 1 and ValueCt in Msg. If condition it's *FALSE*, value ct is forced to 0
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
@@ -1000,6 +999,433 @@ pbs!(
|
||||
}
|
||||
]],
|
||||
// NB: Lut IfPos1FalseZeroed already defined earlier
|
||||
["ManyInv1CarryMsg" => 64 [ // Proceed Inv - ct
|
||||
// Extract message and carry using many LUT.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 1;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value & params.msg_mask()
|
||||
}
|
||||
};
|
||||
|params: &DigitParameters, _deg| params.msg_mask();
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 1;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value >> params.msg_w
|
||||
}
|
||||
};
|
||||
|_params: &DigitParameters, _deg| 1;
|
||||
},
|
||||
]],
|
||||
["ManyInv2CarryMsg" => 65 [ // Proceed Inv - ct
|
||||
// Extract message and carry using many LUT.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 2;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value & params.msg_mask()
|
||||
}
|
||||
};
|
||||
|params: &DigitParameters, _deg| params.msg_mask();
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 2;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value >> params.msg_w
|
||||
}
|
||||
};
|
||||
|_params: &DigitParameters, _deg| 1;
|
||||
},
|
||||
]],
|
||||
|
||||
["ManyInv3CarryMsg" => 66 [ // Proceed Inv - ct
|
||||
// Extract message and carry using many LUT.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 3;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value & params.msg_mask()
|
||||
}
|
||||
};
|
||||
|params: &DigitParameters, _deg| params.msg_mask();
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 3;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value >> params.msg_w
|
||||
}
|
||||
};
|
||||
|_params: &DigitParameters, _deg| 1;
|
||||
},
|
||||
]],
|
||||
|
||||
["ManyInv4CarryMsg" => 67 [ // Proceed Inv - ct
|
||||
// Extract message and carry using many LUT.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 4;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value & params.msg_mask()
|
||||
}
|
||||
};
|
||||
|params: &DigitParameters, _deg| params.msg_mask();
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 4;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value >> params.msg_w
|
||||
}
|
||||
};
|
||||
|_params: &DigitParameters, _deg| 1;
|
||||
},
|
||||
]],
|
||||
|
||||
["ManyInv5CarryMsg" => 68 [ // Proceed Inv - ct
|
||||
// Extract message and carry using many LUT.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 5;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value & params.msg_mask()
|
||||
}
|
||||
};
|
||||
|params: &DigitParameters, _deg| params.msg_mask();
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 5;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value >> params.msg_w
|
||||
}
|
||||
};
|
||||
|_params: &DigitParameters, _deg| 1;
|
||||
},
|
||||
]],
|
||||
|
||||
["ManyInv6CarryMsg" => 69 [ // Proceed Inv - ct
|
||||
// Extract message and carry using many LUT.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 6;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value & params.msg_mask()
|
||||
}
|
||||
};
|
||||
|params: &DigitParameters, _deg| params.msg_mask();
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 6;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value >> params.msg_w
|
||||
}
|
||||
};
|
||||
|_params: &DigitParameters, _deg| 1;
|
||||
},
|
||||
]],
|
||||
|
||||
["ManyInv7CarryMsg" => 70 [ // Proceed Inv - ct
|
||||
// Extract message and carry using many LUT.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 7;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value & params.msg_mask()
|
||||
}
|
||||
};
|
||||
|params: &DigitParameters, _deg| params.msg_mask();
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val | {
|
||||
let inv = 7;
|
||||
let mut value = val & params.data_mask();
|
||||
if value > inv {
|
||||
0
|
||||
} else {
|
||||
value = inv - value;
|
||||
value >> params.msg_w
|
||||
}
|
||||
};
|
||||
|_params: &DigitParameters, _deg| 1;
|
||||
},
|
||||
]],
|
||||
["ManyMsgSplit" => 71 [ // Use manyLUT : split msg in halves
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
val & ((1 << lsb_size)-1) // msg_lsb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(1 << lsb_size)-1
|
||||
};
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(val & params.msg_mask()) >> lsb_size // msg_msb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
let msb_size = params.msg_w - lsb_size;
|
||||
(1 << msb_size)-1
|
||||
};
|
||||
}
|
||||
]],
|
||||
["Manym2lPropBit1MsgSplit" => 72 [ // Use ManyLut
|
||||
// In carry part, contains the info if neighbor has a bit=1 (not null)
|
||||
// or not (null).
|
||||
// Propagate bits equal to 1 from msb to lsb.
|
||||
// Split resulting message part into 2. Put both in lsb.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let mut c = val & params.carry_mask();
|
||||
let mut m = val & params.msg_mask();
|
||||
let mut exp = 0;
|
||||
// Expand from msb to lsb
|
||||
for idx in (0..params.msg_w).rev() {
|
||||
let mut b = (m >> idx) & 1;
|
||||
m &= (1 << idx)-1;
|
||||
if c > 0 {b = 1;} // propagate to lsb
|
||||
if b == 1 {c = 1;}
|
||||
exp += b << idx;
|
||||
}
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
exp & ((1 << lsb_size)-1) // msg_lsb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(1 << lsb_size)-1
|
||||
};
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let mut c = val & params.carry_mask();
|
||||
let mut m = val & params.msg_mask();
|
||||
let mut exp = 0;
|
||||
// Expand from msb to lsb
|
||||
for idx in (0..params.msg_w).rev() {
|
||||
let mut b = (m >> idx) & 1;
|
||||
m &= (1 << idx)-1;
|
||||
if c > 0 {b = 1;} // propagate to lsb
|
||||
if b == 1 {c = 1;}
|
||||
exp += b << idx;
|
||||
}
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(exp & params.msg_mask()) >> lsb_size // msg_msb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
let msb_size = params.msg_w - lsb_size;
|
||||
(1 << msb_size)-1
|
||||
};
|
||||
}
|
||||
]],
|
||||
["Manym2lPropBit0MsgSplit" => 73 [ // Use ManyLut
|
||||
// In carry part, contains the info if neighbor has a bit=0 (not null)
|
||||
// or not (null).
|
||||
// Propagate bits equal to 0 from msb to lsb.
|
||||
// Split resulting message part into 2. Put both in lsb.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let mut c = val & params.carry_mask();
|
||||
let mut m = val & params.msg_mask();
|
||||
let mut exp = 0;
|
||||
// Expand from msb to lsb
|
||||
for idx in (0..(params.msg_w)).rev() {
|
||||
let mut b = (m >> idx) & 1;
|
||||
m &= (1 << idx)-1;
|
||||
if c > 0 {b = 0;} // propagate to lsb
|
||||
if b == 0 {c = 1;}
|
||||
exp += b << idx;
|
||||
}
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
exp & ((1 << lsb_size)-1) // msg_lsb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(1 << lsb_size)-1
|
||||
};
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let mut c = val & params.carry_mask();
|
||||
let mut m = val & params.msg_mask();
|
||||
let mut exp = 0;
|
||||
// Expand from msb to lsb
|
||||
for idx in (0..(params.msg_w)).rev() {
|
||||
let mut b = (m >> idx) & 1;
|
||||
m &= (1 << idx)-1;
|
||||
if c > 0 {b = 0;} // propagate to lsb
|
||||
if b == 0 {c = 1;}
|
||||
exp += b << idx;
|
||||
}
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(exp & params.msg_mask()) >> lsb_size // msg_msb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
let msb_size = params.msg_w - lsb_size;
|
||||
(1 << msb_size)-1
|
||||
};
|
||||
}
|
||||
]],
|
||||
["Manyl2mPropBit1MsgSplit" => 74 [ // Use ManyLut
|
||||
// In carry part, contains the info if neighbor has a bit=1 (not null)
|
||||
// or not (null).
|
||||
// Propagate bits equal to 1 from lsb to msb.
|
||||
// Split resulting message part into 2. Put both in lsb.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let mut c = val & params.carry_mask();
|
||||
let mut m = val & params.msg_mask();
|
||||
let mut exp = 0;
|
||||
// Expand from lsb to msb
|
||||
for idx in 0..(params.msg_w) {
|
||||
let mut b = m & 1;
|
||||
m >>= 1;
|
||||
if c > 0 {b = 1;} // propagate to msb
|
||||
if b == 1 {c = 1;}
|
||||
exp += b << idx;
|
||||
}
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
exp & ((1 << lsb_size)-1) // msg_lsb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(1 << lsb_size)-1
|
||||
};
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let mut c = val & params.carry_mask();
|
||||
let mut m = val & params.msg_mask();
|
||||
let mut exp = 0;
|
||||
// Expand from lsb to msb
|
||||
for idx in 0..(params.msg_w) {
|
||||
let mut b = m & 1;
|
||||
m >>= 1;
|
||||
if c > 0 {b = 1;} // propagate to msb
|
||||
if b == 1 {c = 1;}
|
||||
exp += b << idx;
|
||||
}
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(exp & params.msg_mask()) >> lsb_size // msg_msb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
let msb_size = params.msg_w - lsb_size;
|
||||
(1 << msb_size)-1
|
||||
};
|
||||
}
|
||||
]],
|
||||
["Manyl2mPropBit0MsgSplit" => 75 [ // Use ManyLut
|
||||
// In carry part, contains the info if neighbor has a bit=0 (not null)
|
||||
// or not (null).
|
||||
// Propagate bits equal to 0 from lsb to msb.
|
||||
// Split resulting message part into 2. Put both in lsb.
|
||||
@0 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let mut c = val & params.carry_mask();
|
||||
let mut m = val & params.msg_mask();
|
||||
let mut exp = 0;
|
||||
// Expand from lsb to msb
|
||||
for idx in 0..(params.msg_w) {
|
||||
let mut b = m & 1;
|
||||
m >>= 1;
|
||||
if c > 0 {b = 0;} // propagate to msb
|
||||
if b == 0 {c = 1;}
|
||||
exp += b << idx;
|
||||
}
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
exp & ((1 << lsb_size)-1) // msg_lsb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(1 << lsb_size)-1
|
||||
};
|
||||
},
|
||||
@1 =>{
|
||||
|params: &DigitParameters, val| {
|
||||
let mut c = val & params.carry_mask();
|
||||
let mut m = val & params.msg_mask();
|
||||
let mut exp = 0;
|
||||
// Expand from lsb to msb
|
||||
for idx in 0..(params.msg_w) {
|
||||
let mut b = m & 1;
|
||||
m >>= 1;
|
||||
if c > 0 {b = 0;} // propagate to msb
|
||||
if b == 0 {c = 1;}
|
||||
exp += b << idx;
|
||||
}
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
(exp & params.msg_mask()) >> lsb_size // msg_msb
|
||||
};
|
||||
|params: &DigitParameters, _deg| {
|
||||
let lsb_size = params.msg_w.div_ceil(2);
|
||||
let msb_size = params.msg_w - lsb_size;
|
||||
(1 << msb_size)-1
|
||||
};
|
||||
}
|
||||
]],
|
||||
);
|
||||
|
||||
pub(crate) fn ceil_ilog2(value: &u8) -> u8 {
|
||||
|
||||
@@ -220,4 +220,11 @@ iop!(
|
||||
[IOP_CT_F_2CT_BOOL -> "IF_THEN_ELSE", opcode::IF_THEN_ELSE],
|
||||
[IOP_2CT_F_3CT -> "ERC_20", opcode::ERC_20],
|
||||
[IOP_CT_F_CT -> "MEMCPY", opcode::MEMCPY],
|
||||
[IOP_CT_F_CT -> "ILOG2", opcode::ILOG2],
|
||||
[IOP_CT_F_CT -> "COUNT0", opcode::COUNT0],
|
||||
[IOP_CT_F_CT -> "COUNT1", opcode::COUNT1],
|
||||
[IOP_CT_F_CT -> "LEAD0", opcode::LEAD0],
|
||||
[IOP_CT_F_CT -> "LEAD1", opcode::LEAD1],
|
||||
[IOP_CT_F_CT -> "TRAIL0", opcode::TRAIL0],
|
||||
[IOP_CT_F_CT -> "TRAIL1", opcode::TRAIL1],
|
||||
);
|
||||
|
||||
@@ -8,11 +8,6 @@
|
||||
//! | ---------- | ------------------------- |
|
||||
//! | 0x00.. 0x7f| User custom operations |
|
||||
//! | 0x80.. 0xff| Fw generated operations |
|
||||
//! | 0b1xyz_0000| x: Ct x Ct Operation |
|
||||
//! | | !x: Ct x Imm Operation |
|
||||
//! | | y!z: ARITH operations |
|
||||
//! | | !yz: BW operations |
|
||||
//! | | !y!z: CMP operations |
|
||||
//! | ---------- | ------------------------- |
|
||||
|
||||
pub const USER_RANGE_LB: u8 = 0x0;
|
||||
@@ -83,6 +78,15 @@ pub const IF_THEN_ELSE: u8 = 0xCB;
|
||||
// 2Ct <- func(3Ct)
|
||||
pub const ERC_20: u8 = 0x80;
|
||||
|
||||
// Count bits
|
||||
pub const COUNT0: u8 = 0x81;
|
||||
pub const COUNT1: u8 = 0x82;
|
||||
pub const ILOG2: u8 = 0x83;
|
||||
pub const LEAD0: u8 = 0x84;
|
||||
pub const LEAD1: u8 = 0x85;
|
||||
pub const TRAIL0: u8 = 0x86;
|
||||
pub const TRAIL1: u8 = 0x87;
|
||||
|
||||
// Utility operations
|
||||
// Used to handle real clone of ciphertext already uploaded in the Hpu memory
|
||||
pub const MEMCPY: u8 = 0xFF;
|
||||
|
||||
@@ -36,6 +36,15 @@ crate::impl_fw!("Demo" [
|
||||
CMP_GT => cmp_gt;
|
||||
CMP_GTE => cmp_gte;
|
||||
CMP_LT => cmp_lt;
|
||||
|
||||
COUNT0 => fw_impl::ilp_log::iop_count0;
|
||||
COUNT1 => fw_impl::ilp_log::iop_count1;
|
||||
ILOG2 => fw_impl::ilp_log::iop_ilog2;
|
||||
LEAD0 => fw_impl::ilp_log::iop_lead0;
|
||||
LEAD1 => fw_impl::ilp_log::iop_lead1;
|
||||
TRAIL0 => fw_impl::ilp_log::iop_trail0;
|
||||
TRAIL1 => fw_impl::ilp_log::iop_trail1;
|
||||
|
||||
]);
|
||||
|
||||
// Recursive {{{1
|
||||
|
||||
@@ -64,6 +64,14 @@ crate::impl_fw!("Ilp" [
|
||||
ERC_20 => fw_impl::ilp::iop_erc_20;
|
||||
|
||||
MEMCPY => fw_impl::ilp::iop_memcpy;
|
||||
|
||||
COUNT0 => fw_impl::ilp_log::iop_count0;
|
||||
COUNT1 => fw_impl::ilp_log::iop_count1;
|
||||
ILOG2 => fw_impl::ilp_log::iop_ilog2;
|
||||
LEAD0 => fw_impl::ilp_log::iop_lead0;
|
||||
LEAD1 => fw_impl::ilp_log::iop_lead1;
|
||||
TRAIL0 => fw_impl::ilp_log::iop_trail0;
|
||||
TRAIL1 => fw_impl::ilp_log::iop_trail1;
|
||||
]);
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
//!
|
||||
//! Implementation of Ilp firmware
|
||||
//! Implementation of Ilp firmware for division, and modulo
|
||||
//!
|
||||
//! In this version of the Fw focus is done on Instruction Level Parallelism
|
||||
use std::cmp::Ordering;
|
||||
@@ -7,6 +7,7 @@ use std::collections::VecDeque;
|
||||
|
||||
use super::*;
|
||||
use crate::asm::{self, OperandKind};
|
||||
use crate::fw::fw_impl::ilp_log;
|
||||
use crate::fw::program::Program;
|
||||
use tracing::{instrument, warn};
|
||||
|
||||
@@ -44,12 +45,8 @@ pub fn iop_divs(prog: &mut Program) {
|
||||
// Deferred implementation to generic divx function
|
||||
// TODO: do computation on immediate directly for more efficiency.
|
||||
// Workaround: transform immediate into ct.
|
||||
let mut src_imm: Vec<metavar::MetaVarCell> = Vec::new();
|
||||
let cst_0 = &src_a[src_a.len() - 1] - &src_a[src_a.len() - 1];
|
||||
for cst in src_b.iter() {
|
||||
let ct = cst + &cst_0;
|
||||
src_imm.push(ct);
|
||||
}
|
||||
let cst_0 = prog.new_cst(0);
|
||||
let src_imm: Vec<metavar::MetaVarCell> = src_b.iter().map(|imm| imm + &cst_0).collect();
|
||||
iop_divx(prog, &mut dst_quotient, &mut dst_remain, &src_a, &src_imm);
|
||||
}
|
||||
|
||||
@@ -106,12 +103,8 @@ pub fn iop_mods(prog: &mut Program) {
|
||||
// Deferred implementation to generic modx function
|
||||
// TODO: do computation on immediate directly for more efficiency.
|
||||
// Workaround: transform immediate into ct.
|
||||
let mut src_imm: Vec<metavar::MetaVarCell> = Vec::new();
|
||||
let cst_0 = &src_a[src_a.len() - 1] - &src_a[src_a.len() - 1];
|
||||
for cst in src_b.iter() {
|
||||
let ct = cst + &cst_0;
|
||||
src_imm.push(ct);
|
||||
}
|
||||
let cst_0 = prog.new_cst(0);
|
||||
let src_imm: Vec<metavar::MetaVarCell> = src_b.iter().map(|imm| imm + &cst_0).collect();
|
||||
iop_modx(prog, &mut dst_remain, &src_a, &src_imm);
|
||||
}
|
||||
|
||||
@@ -342,88 +335,6 @@ pub fn iop_add_hillissteel_v(
|
||||
res
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
/// Outputs a list of booleans,
|
||||
/// each indicating the null status from the input
|
||||
/// block <i> to the msb of the input.
|
||||
/// 'true' means is not null.
|
||||
/// 'false' means is null.
|
||||
pub fn iop_is_not_null_vector_v(
|
||||
prog: &mut Program,
|
||||
src_a: &[metavar::MetaVarCell],
|
||||
) -> Vec<metavar::MetaVarCell> {
|
||||
let props = prog.params();
|
||||
//let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
|
||||
let pbs_not_null = new_pbs!(prog, "NotNull");
|
||||
|
||||
// TODO: TOREVIEW
|
||||
let op_nb = props.nu;
|
||||
// clog2(op_nb)
|
||||
let op_nb_bool = 1 << ((op_nb as f32).log2().ceil() as usize);
|
||||
|
||||
// First step
|
||||
// Work within each group of op_nb blocks.
|
||||
// For <i> get a boolean not null status of current block and the MSB ones.
|
||||
// within this group.
|
||||
let mut g_a: Vec<metavar::MetaVarCell> = Vec::new();
|
||||
for (c_id, c) in src_a.chunks(op_nb).enumerate() {
|
||||
c.iter().rev().fold(None, |acc, elt| {
|
||||
let is_not_null;
|
||||
let tmp;
|
||||
if let Some(x) = acc {
|
||||
tmp = &x + elt;
|
||||
is_not_null = tmp.pbs(&pbs_not_null, false);
|
||||
} else {
|
||||
is_not_null = elt.pbs(&pbs_not_null, false);
|
||||
//tmp = elt.clone();
|
||||
tmp = elt.clone();
|
||||
};
|
||||
g_a.insert(c_id * op_nb, is_not_null); // Reverse insertion per chunk
|
||||
Some(tmp)
|
||||
});
|
||||
}
|
||||
|
||||
// Second step
|
||||
// Propagate the not null status from MSB to LSB, with stride of
|
||||
// (op_nb_bool**k)*op_nb
|
||||
//assert_eq!(g_a.len(),props.blk_w());
|
||||
let grp_nb = g_a.len().div_ceil(op_nb);
|
||||
let mut stride_size: usize = 1; // in group unit
|
||||
while stride_size < grp_nb {
|
||||
for chk in g_a.chunks_mut(op_nb_bool * stride_size * op_nb) {
|
||||
chk.chunks_mut(stride_size * op_nb)
|
||||
.rev()
|
||||
.fold(None, |acc, sub_chk| {
|
||||
if let Some(x) = acc {
|
||||
let tmp = &x + &sub_chk[0];
|
||||
sub_chk[0] = tmp.pbs(&pbs_not_null, false);
|
||||
Some(tmp)
|
||||
} else {
|
||||
Some(sub_chk[0].clone())
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
stride_size *= op_nb_bool;
|
||||
}
|
||||
|
||||
// Third step
|
||||
// Apply
|
||||
g_a.chunks_mut(op_nb).rev().fold(None, |acc, chk| {
|
||||
if let Some(x) = acc {
|
||||
for v in chk.iter_mut().skip(1) {
|
||||
// [0] is already complete.
|
||||
*v = &*v + x;
|
||||
*v = v.pbs(&pbs_not_null, false);
|
||||
}
|
||||
}
|
||||
Some(&chk[0])
|
||||
});
|
||||
|
||||
g_a
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
/// Outputs a tuple corresponding to (src x2, src x3)
|
||||
pub fn iop_x2_x3v(
|
||||
@@ -518,9 +429,27 @@ pub fn iop_div_initv(prog: &mut Program, div_x1_a: &[metavar::MetaVarCell]) -> I
|
||||
// Note that div_x2 and div_x3 has an additional ct in msb
|
||||
let (div_x2_a, div_x3_a) = iop_x2_x3v(prog, div_x1_a);
|
||||
|
||||
let div_x1_is_not_null_a = iop_is_not_null_vector_v(prog, div_x1_a);
|
||||
let div_x2_is_not_null_a = iop_is_not_null_vector_v(prog, &div_x2_a);
|
||||
let div_x3_is_not_null_a = iop_is_not_null_vector_v(prog, &div_x3_a);
|
||||
let div_x1_is_not_null_a = ilp_log::iop_propagate_msb_to_lsb_blockv(
|
||||
prog,
|
||||
div_x1_a,
|
||||
&Some(ilp_log::BitType::One),
|
||||
&Some(false),
|
||||
&Some(false),
|
||||
);
|
||||
let div_x2_is_not_null_a = ilp_log::iop_propagate_msb_to_lsb_blockv(
|
||||
prog,
|
||||
&div_x2_a,
|
||||
&Some(ilp_log::BitType::One),
|
||||
&Some(false),
|
||||
&Some(false),
|
||||
);
|
||||
let div_x3_is_not_null_a = ilp_log::iop_propagate_msb_to_lsb_blockv(
|
||||
prog,
|
||||
&div_x3_a,
|
||||
&Some(ilp_log::BitType::One),
|
||||
&Some(false),
|
||||
&Some(false),
|
||||
);
|
||||
|
||||
// If the divider is null set quotient to 0
|
||||
let keep_div = div_x1_is_not_null_a[0].clone();
|
||||
@@ -637,10 +566,7 @@ pub fn iop_div_corev(
|
||||
// Loop
|
||||
let mut quotient_a: Vec<metavar::MetaVarCell> = Vec::new();
|
||||
let mut remain_a: Vec<metavar::MetaVarCell> = Vec::new();
|
||||
// !! Workaround
|
||||
let cst_sign_tmp = prog.new_imm(tfhe_params.msg_range() - 1);
|
||||
let mut cst_sign = &src_a[num_a.len() - 1] - &src_a[num_a.len() - 1];
|
||||
cst_sign = &cst_sign + &cst_sign_tmp;
|
||||
let cst_sign = prog.new_cst(tfhe_params.msg_range() - 1);
|
||||
|
||||
for loop_idx in 0..num_a.len() {
|
||||
let block_nb = loop_idx + 1;
|
||||
|
||||
533
backends/tfhe-hpu-backend/src/fw/fw_impl/ilp_log.rs
Normal file
533
backends/tfhe-hpu-backend/src/fw/fw_impl/ilp_log.rs
Normal file
@@ -0,0 +1,533 @@
|
||||
//!
|
||||
//! Implementation of Ilp firmware for bit count (log2, trailing, leading bit)
|
||||
//!
|
||||
//! In this version of the Fw focus is done on Instruction Level Parallelism
|
||||
use super::*;
|
||||
use crate::asm::{self, OperandKind};
|
||||
use crate::fw::program::Program;
|
||||
use tracing::{instrument, warn};
|
||||
|
||||
use crate::new_pbs;
|
||||
|
||||
/// Bit value targeted by the bit-count helpers of this module
/// (count bits equal to 0 or bits equal to 1).
///
/// Equality on this enum is total, so `Eq` is derived alongside `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BitType {
    /// Target the bits equal to 1.
    One,
    /// Target the bits equal to 0.
    Zero,
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_count0(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
|
||||
// SrcA -> Operand
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 0);
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("COUNT0 Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
// Deferred implementation to generic countx function
|
||||
iop_countx(prog, &mut dst, &src_a, &Some(BitType::Zero));
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_count1(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
|
||||
// SrcA -> Operand
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 0);
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("COUNT1 Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
// Deferred implementation to generic countx function
|
||||
iop_countx(prog, &mut dst, &src_a, &Some(BitType::One));
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_ilog2(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
|
||||
// SrcA -> Operand
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 0);
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("ILOG2 Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
|
||||
let props = &prog.params();
|
||||
|
||||
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::One), &Some(false));
|
||||
let count_a = iop_countv(prog, &prop_a[1..], &Some(BitType::One));
|
||||
|
||||
count_a.iter().enumerate().for_each(|(idx, c)| {
|
||||
c.reg_alloc_mv();
|
||||
dst[idx].mv_assign(c);
|
||||
});
|
||||
let cst_0 = prog.new_cst(0);
|
||||
(count_a.len()..props.blk_w()).for_each(|blk| {
|
||||
dst[blk].mv_assign(&cst_0);
|
||||
});
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_lead0(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
|
||||
// SrcA -> Operand
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 0);
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("LEAD0 Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
|
||||
let props = &prog.params();
|
||||
|
||||
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::One), &Some(false));
|
||||
let count_a = iop_countv(prog, &prop_a, &Some(BitType::Zero));
|
||||
|
||||
count_a.iter().enumerate().for_each(|(idx, c)| {
|
||||
c.reg_alloc_mv();
|
||||
dst[idx].mv_assign(c);
|
||||
});
|
||||
let cst_0 = prog.new_cst(0);
|
||||
(count_a.len()..props.blk_w()).for_each(|blk| {
|
||||
dst[blk].mv_assign(&cst_0);
|
||||
});
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_lead1(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
|
||||
// SrcA -> Operand
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 0);
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("LEAD1 Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
|
||||
let props = &prog.params();
|
||||
|
||||
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::Zero), &Some(false));
|
||||
let count_a = iop_countv(prog, &prop_a, &Some(BitType::One));
|
||||
|
||||
count_a.iter().enumerate().for_each(|(idx, c)| {
|
||||
c.reg_alloc_mv();
|
||||
dst[idx].mv_assign(c);
|
||||
});
|
||||
let cst_0 = prog.new_cst(0);
|
||||
(count_a.len()..props.blk_w()).for_each(|blk| {
|
||||
dst[blk].mv_assign(&cst_0);
|
||||
});
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_trail0(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
|
||||
// SrcA -> Operand
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 0);
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("TRAIL0 Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
|
||||
let props = &prog.params();
|
||||
|
||||
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::One), &Some(true));
|
||||
let count_a = iop_countv(prog, &prop_a, &Some(BitType::Zero));
|
||||
|
||||
count_a.iter().enumerate().for_each(|(idx, c)| {
|
||||
c.reg_alloc_mv();
|
||||
dst[idx].mv_assign(c);
|
||||
});
|
||||
let cst_0 = prog.new_cst(0);
|
||||
(count_a.len()..props.blk_w()).for_each(|blk| {
|
||||
dst[blk].mv_assign(&cst_0);
|
||||
});
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_trail1(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let mut dst = prog.iop_template_var(OperandKind::Dst, 0);
|
||||
// SrcA -> Operand
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 0);
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("TRAIL1 Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
|
||||
let props = &prog.params();
|
||||
|
||||
let prop_a = iop_propagate_msb_to_lsbv(prog, &src_a, &Some(BitType::Zero), &Some(true));
|
||||
let count_a = iop_countv(prog, &prop_a, &Some(BitType::One));
|
||||
|
||||
count_a.iter().enumerate().for_each(|(idx, c)| {
|
||||
c.reg_alloc_mv();
|
||||
dst[idx].mv_assign(c);
|
||||
});
|
||||
let cst_0 = prog.new_cst(0);
|
||||
(count_a.len()..props.blk_w()).for_each(|blk| {
|
||||
dst[blk].mv_assign(&cst_0);
|
||||
});
|
||||
}
|
||||
|
||||
/// Generic count bit operation
|
||||
/// One destination and one source operation
|
||||
/// Source is Operand
|
||||
pub fn iop_countx(
|
||||
prog: &mut Program,
|
||||
dst: &mut [metavar::MetaVarCell],
|
||||
src_a: &[metavar::MetaVarCell],
|
||||
bit_type: &Option<BitType>,
|
||||
) {
|
||||
let props = prog.params();
|
||||
//let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
|
||||
let pbs_many_msg_split = new_pbs!(prog, "ManyMsgSplit");
|
||||
|
||||
let mut bit_a: Vec<metavar::MetaVarCell> = Vec::new();
|
||||
for (idx, ct) in src_a.iter().enumerate() {
|
||||
let do_flush = idx == src_a.len() - 1;
|
||||
let v = &ct.pbs_many(&pbs_many_msg_split, do_flush)[..];
|
||||
bit_a.push(v[0].clone());
|
||||
bit_a.push(v[1].clone());
|
||||
}
|
||||
|
||||
let count_a = iop_countv(prog, &bit_a, bit_type);
|
||||
|
||||
count_a.iter().enumerate().for_each(|(idx, c)| {
|
||||
c.reg_alloc_mv();
|
||||
dst[idx].mv_assign(c);
|
||||
});
|
||||
let cst_0 = prog.new_cst(0);
|
||||
(count_a.len()..props.blk_w()).for_each(|blk| {
|
||||
dst[blk].mv_assign(&cst_0);
|
||||
});
|
||||
}
|
||||
|
||||
// Do an iteration only if there are columns that need
|
||||
// to be reduced, i.e. with more than 1 element.
|
||||
fn need_iter(v: &[Vec<metavar::MetaVarCell>]) -> bool {
|
||||
v.iter()
|
||||
.filter(|l| l.len() > 1)
|
||||
.fold(false, |_acc, _l| true)
|
||||
}
|
||||
|
||||
/// From a source composed of blocks, each containing
|
||||
/// a single significant bit at position 0, count the number of
|
||||
/// bits equal to 0 or 1, according to bit_type.
|
||||
/// The source is assumed to be "clean".
|
||||
pub fn iop_countv(
|
||||
prog: &mut Program,
|
||||
src_a: &[metavar::MetaVarCell],
|
||||
bit_type: &Option<BitType>,
|
||||
) -> Vec<metavar::MetaVarCell> {
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
|
||||
let pbs_msg = new_pbs!(prog, "MsgOnly");
|
||||
let pbs_carry = new_pbs!(prog, "CarryInMsg");
|
||||
let pbs_many_carrymsg = new_pbs!(prog, "ManyCarryMsg");
|
||||
let pbs_many_inv1_carrymsg = new_pbs!(prog, "ManyInv1CarryMsg");
|
||||
let pbs_many_inv2_carrymsg = new_pbs!(prog, "ManyInv2CarryMsg");
|
||||
let pbs_many_inv3_carrymsg = new_pbs!(prog, "ManyInv3CarryMsg");
|
||||
let pbs_many_inv4_carrymsg = new_pbs!(prog, "ManyInv4CarryMsg");
|
||||
let pbs_many_inv5_carrymsg = new_pbs!(prog, "ManyInv5CarryMsg");
|
||||
let pbs_many_inv6_carrymsg = new_pbs!(prog, "ManyInv6CarryMsg");
|
||||
let pbs_many_inv7_carrymsg = new_pbs!(prog, "ManyInv7CarryMsg");
|
||||
|
||||
let pbs_many_inv_carrymsg = [
|
||||
pbs_many_carrymsg.clone(), // place holder
|
||||
pbs_many_inv1_carrymsg,
|
||||
pbs_many_inv2_carrymsg,
|
||||
pbs_many_inv3_carrymsg,
|
||||
pbs_many_inv4_carrymsg,
|
||||
pbs_many_inv5_carrymsg,
|
||||
pbs_many_inv6_carrymsg,
|
||||
pbs_many_inv7_carrymsg,
|
||||
];
|
||||
|
||||
// TODO: TOREVIEW
|
||||
let op_nb = props.nu;
|
||||
// clog2(op_nb)
|
||||
let op_nb_bool = 1 << ((op_nb as f32).log2().ceil() as usize);
|
||||
let op_nb_single = op_nb_bool - 1;
|
||||
|
||||
// Number of block to store the results.
|
||||
let block_nb =
|
||||
(((src_a.len() * tfhe_params.msg_w + 1) as f32).log2().ceil() as usize).div_ceil(2);
|
||||
|
||||
// During the process, the current MSB column will be composed of
|
||||
// blocks of single bit.
|
||||
// The others are composed of blocks of msg_w bits.
|
||||
// Single bit column is summed op_nb_single blocks at a time. Therefore
|
||||
// leaving a free bit for the manyLut extraction.
|
||||
// Full msg column is summed op_nb blocks at a time. Therefore
|
||||
// 2 PBS are used for the extraction.
|
||||
|
||||
let mut sum_v: Vec<Vec<metavar::MetaVarCell>> = vec![src_a.to_vec()];
|
||||
|
||||
let mut iter_idx = 0;
|
||||
while need_iter(&sum_v) {
|
||||
let empty_col_nb = sum_v
|
||||
.iter()
|
||||
.filter(|col| col.is_empty())
|
||||
.fold(0, |acc, _col| acc + 1);
|
||||
|
||||
let mut next_v: Vec<Vec<metavar::MetaVarCell>> = Vec::new();
|
||||
next_v.push(Vec::new()); // For the msg
|
||||
for (c_idx, col) in sum_v.iter().enumerate() {
|
||||
next_v.push(Vec::new()); // For the carry
|
||||
let next_len = &next_v.len();
|
||||
let is_last_nonempty_col = c_idx == (&sum_v.len() - 1 - empty_col_nb);
|
||||
|
||||
if col.len() == 1 {
|
||||
// Single element, do not need to process
|
||||
next_v[next_len - 2].push(col[0].clone());
|
||||
} else if c_idx == sum_v.len() - 1 {
|
||||
// Last column contains only bits
|
||||
let chunk_nb = col.len().div_ceil(op_nb_single);
|
||||
for (chk_idx, chk) in col.chunks(op_nb_single).enumerate() {
|
||||
let do_flush = (chk_idx == (chunk_nb - 1)) && is_last_nonempty_col;
|
||||
|
||||
let cst_0 = prog.new_imm(0);
|
||||
let (s, nb) = chk
|
||||
.iter()
|
||||
.fold((cst_0, 0), |(acc, elt_nb), ct| (ct + &acc, elt_nb + 1));
|
||||
let m: metavar::MetaVarCell;
|
||||
let c: metavar::MetaVarCell;
|
||||
if bit_type.unwrap_or(BitType::One) == BitType::Zero && iter_idx == 0 {
|
||||
let v = s.pbs_many(&pbs_many_inv_carrymsg[nb], do_flush);
|
||||
m = v[0].clone();
|
||||
c = v[1].clone();
|
||||
} else {
|
||||
let v = s.pbs_many(&pbs_many_carrymsg, do_flush);
|
||||
m = v[0].clone();
|
||||
c = v[1].clone();
|
||||
}
|
||||
if nb >= tfhe_params.msg_range() && c_idx < block_nb {
|
||||
// Do not compute after
|
||||
// the number of needed blocks.
|
||||
next_v[next_len - 1].push(c);
|
||||
}
|
||||
next_v[next_len - 2].push(m);
|
||||
}
|
||||
} else {
|
||||
// Regular column. Sum by op_nb elements
|
||||
let chunk_nb = col.len().div_ceil(op_nb);
|
||||
for (chk_idx, chk) in col.chunks(op_nb).enumerate() {
|
||||
let do_flush = (chk_idx == (chunk_nb - 1)) && is_last_nonempty_col;
|
||||
|
||||
let cst_0 = prog.new_imm(0);
|
||||
let (s, nb) = chk
|
||||
.iter()
|
||||
.fold((cst_0, 0), |(acc, elt_nb), ct| (ct + &acc, elt_nb + 1));
|
||||
let m: metavar::MetaVarCell;
|
||||
let c: metavar::MetaVarCell;
|
||||
if nb > 2 {
|
||||
m = s.pbs(&pbs_msg, false);
|
||||
c = s.pbs(&pbs_carry, do_flush);
|
||||
} else {
|
||||
// Free bit to used manyLut
|
||||
let v = s.pbs_many(&pbs_many_carrymsg, do_flush);
|
||||
m = v[0].clone();
|
||||
c = v[1].clone();
|
||||
}
|
||||
if c_idx < block_nb {
|
||||
// Do not compute after
|
||||
// the number of needed blocks.
|
||||
next_v[next_len - 1].push(c);
|
||||
}
|
||||
next_v[next_len - 2].push(m);
|
||||
}
|
||||
}
|
||||
} // For c_idx, col
|
||||
iter_idx += 1;
|
||||
sum_v = next_v;
|
||||
} // while
|
||||
|
||||
// let mut res : Vec<metavar::MetaVarCell> = Vec::new();
|
||||
sum_v
|
||||
.iter()
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(|v| v[0].clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Propagate bit from msb to lsb.
|
||||
pub fn iop_propagate_msb_to_lsbv(
|
||||
prog: &mut Program,
|
||||
src_a: &[metavar::MetaVarCell],
|
||||
bit_type: &Option<BitType>,
|
||||
inverse_propagation: &Option<bool>, // default false
|
||||
) -> Vec<metavar::MetaVarCell> {
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
|
||||
let pbs_many_m2l_prop_bit1_msg_split = new_pbs!(prog, "Manym2lPropBit1MsgSplit");
|
||||
let pbs_many_m2l_prop_bit0_msg_split = new_pbs!(prog, "Manym2lPropBit0MsgSplit");
|
||||
let pbs_many_l2m_prop_bit1_msg_split = new_pbs!(prog, "Manyl2mPropBit1MsgSplit");
|
||||
let pbs_many_l2m_prop_bit0_msg_split = new_pbs!(prog, "Manyl2mPropBit0MsgSplit");
|
||||
|
||||
let propagate_block =
|
||||
iop_propagate_msb_to_lsb_blockv(prog, src_a, bit_type, &Some(false), inverse_propagation);
|
||||
|
||||
let mut res_v = Vec::new();
|
||||
for (idx, ct) in src_a.iter().enumerate() {
|
||||
// propagation start point
|
||||
let start_idx = if inverse_propagation.unwrap_or(false) {
|
||||
0
|
||||
} else {
|
||||
src_a.len() - 1
|
||||
};
|
||||
let do_flush = idx == (src_a.len() - 1);
|
||||
let m = if idx == start_idx {
|
||||
ct.clone()
|
||||
} else {
|
||||
let neigh_idx = if inverse_propagation.unwrap_or(false) {
|
||||
idx - 1
|
||||
} else {
|
||||
idx + 1
|
||||
};
|
||||
propagate_block[neigh_idx].mac(tfhe_params.msg_range() as u8, ct)
|
||||
};
|
||||
let v = if bit_type.unwrap_or(BitType::One) == BitType::One {
|
||||
if inverse_propagation.unwrap_or(false) {
|
||||
m.pbs_many(&pbs_many_l2m_prop_bit1_msg_split, do_flush)
|
||||
} else {
|
||||
m.pbs_many(&pbs_many_m2l_prop_bit1_msg_split, do_flush)
|
||||
}
|
||||
} else if inverse_propagation.unwrap_or(false) {
|
||||
m.pbs_many(&pbs_many_l2m_prop_bit0_msg_split, do_flush)
|
||||
} else {
|
||||
m.pbs_many(&pbs_many_m2l_prop_bit0_msg_split, do_flush)
|
||||
};
|
||||
res_v.push(v[0].clone());
|
||||
res_v.push(v[1].clone());
|
||||
}
|
||||
|
||||
res_v
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
/// Propagate bit value given by bit_type from msb to lsb,
|
||||
/// on block basis.
|
||||
/// If inverse_output = false
|
||||
/// From MSB to LSB:
|
||||
/// * bit_type = 1 if block <i> contains the first bit equal to 1, from MSB then the Noutput block
|
||||
/// <i> and below are set to 1, the output block <i+1> and up are set to 0.
|
||||
/// * bit_type = 0 if block <i> contains the first bit equal to 0, from MSB then the output block
|
||||
/// <i> and below are set to 1, the output block <i+1> and up are set to 0.
|
||||
/// If inverse_output = true, the output bits described above are
|
||||
/// negated.
|
||||
pub fn iop_propagate_msb_to_lsb_blockv(
|
||||
prog: &mut Program,
|
||||
src_a: &[metavar::MetaVarCell],
|
||||
bit_type: &Option<BitType>, // Default One
|
||||
inverse_output: &Option<bool>, // Default false
|
||||
inverse_direction: &Option<bool>, // Default false
|
||||
) -> Vec<metavar::MetaVarCell> {
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
|
||||
let pbs_not_null = new_pbs!(prog, "NotNull");
|
||||
let pbs_is_null = new_pbs!(prog, "IsNull");
|
||||
|
||||
// TODO: TOREVIEW
|
||||
let op_nb = props.nu;
|
||||
// clog2(op_nb)
|
||||
let op_nb_bool = 1 << ((op_nb as f32).log2().ceil() as usize);
|
||||
|
||||
let mut proc_nb = op_nb;
|
||||
let mut src = if bit_type.unwrap_or(BitType::One) == BitType::One {
|
||||
src_a.to_vec()
|
||||
} else {
|
||||
// Bitwise not
|
||||
// Do not clean the ct, but reduce the nb of sequential operations
|
||||
// in next step (reducing proc_nb).
|
||||
proc_nb -= 1;
|
||||
let cst_msg_max = prog.new_imm(tfhe_params.msg_mask());
|
||||
src_a.iter().map(|ct| &cst_msg_max - ct).collect()
|
||||
};
|
||||
|
||||
if inverse_direction.unwrap_or(false) {
|
||||
src.reverse();
|
||||
}
|
||||
|
||||
// First step
|
||||
// Work within each group of proc_nb blocks.
|
||||
// For <i> get a boolean not null status of current block and the MSB ones.
|
||||
// within this group.
|
||||
let mut g_a: Vec<metavar::MetaVarCell> = Vec::new();
|
||||
for (c_id, c) in src.chunks(proc_nb).enumerate() {
|
||||
c.iter().rev().fold(None, |acc, elt| {
|
||||
let is_not_null;
|
||||
let tmp;
|
||||
if let Some(x) = acc {
|
||||
tmp = &x + elt;
|
||||
is_not_null = tmp.pbs(&pbs_not_null, false);
|
||||
} else {
|
||||
tmp = elt.clone();
|
||||
is_not_null = elt.pbs(&pbs_not_null, false);
|
||||
};
|
||||
g_a.insert(c_id * proc_nb, is_not_null); // Reverse insertion per chunk
|
||||
Some(tmp)
|
||||
});
|
||||
}
|
||||
|
||||
// Second step
|
||||
// Proparate the not null status from MSB to LSB, with stride of
|
||||
// (op_nb_bool**k)*proc_nb
|
||||
//assert_eq!(g_a.len(),props.blk_w());
|
||||
let grp_nb = g_a.len().div_ceil(proc_nb);
|
||||
let mut stride_size: usize = 1; // in group unit
|
||||
while stride_size < grp_nb {
|
||||
for chk in g_a.chunks_mut(op_nb_bool * stride_size * proc_nb) {
|
||||
chk.chunks_mut(stride_size * proc_nb)
|
||||
.rev()
|
||||
.fold(None, |acc, sub_chk| {
|
||||
if let Some(x) = acc {
|
||||
let tmp = &x + &sub_chk[0];
|
||||
sub_chk[0] = tmp.pbs(&pbs_not_null, false);
|
||||
Some(tmp)
|
||||
} else {
|
||||
Some(sub_chk[0].clone())
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
stride_size *= op_nb_bool;
|
||||
}
|
||||
|
||||
// Third step
|
||||
// Apply
|
||||
g_a.chunks_mut(proc_nb).rev().fold(None, |acc, chk| {
|
||||
if let Some(x) = acc {
|
||||
for (idx, v) in chk.iter_mut().enumerate() {
|
||||
if idx == 0 {
|
||||
// [0] is already complete.
|
||||
// Need to inverse it for 0 if needed
|
||||
if inverse_output.unwrap_or(false) {
|
||||
*v = v.pbs(&pbs_is_null, false);
|
||||
}
|
||||
} else {
|
||||
*v = &*v + x;
|
||||
if inverse_output.unwrap_or(false) {
|
||||
*v = v.pbs(&pbs_is_null, false);
|
||||
} else {
|
||||
*v = v.pbs(&pbs_not_null, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(&chk[0])
|
||||
});
|
||||
|
||||
if inverse_direction.unwrap_or(false) {
|
||||
g_a.reverse();
|
||||
}
|
||||
|
||||
g_a
|
||||
}
|
||||
@@ -73,6 +73,14 @@ crate::impl_fw!("Llt" [
|
||||
|
||||
ERC_20 => fw_impl::llt::iop_erc_20;
|
||||
MEMCPY => fw_impl::ilp::iop_memcpy;
|
||||
|
||||
COUNT0 => fw_impl::ilp_log::iop_count0;
|
||||
COUNT1 => fw_impl::ilp_log::iop_count1;
|
||||
ILOG2 => fw_impl::ilp_log::iop_ilog2;
|
||||
LEAD0 => fw_impl::ilp_log::iop_lead0;
|
||||
LEAD1 => fw_impl::ilp_log::iop_lead1;
|
||||
TRAIL0 => fw_impl::ilp_log::iop_trail0;
|
||||
TRAIL1 => fw_impl::ilp_log::iop_trail1;
|
||||
]);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::asm::{AsmIOpcode, DOp, IOpcode};
|
||||
pub mod demo;
|
||||
pub mod ilp;
|
||||
pub mod ilp_div;
|
||||
pub mod ilp_log;
|
||||
pub mod llt;
|
||||
|
||||
/// Utility macro to define new FW implementation
|
||||
|
||||
@@ -368,6 +368,22 @@ mod hpu_test {
|
||||
}
|
||||
});
|
||||
|
||||
// Bit count IOp
|
||||
hpu_testcase!("COUNT0" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| [ct[0].count_zeros()]);
|
||||
hpu_testcase!("COUNT1" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| [ct[0].count_ones()]);
|
||||
// NOTE(review): the reference value presumably comes from the standard integer
// `ilog2`, which panics on a zero operand -- confirm that the generated
// inputs exclude 0, or define the expected result for 0.
hpu_testcase!("ILOG2" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| [ct[0].ilog2()]);
|
||||
hpu_testcase!("LEAD0" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| [ct[0].leading_zeros()]);
|
||||
hpu_testcase!("LEAD1" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| [ct[0].leading_ones()]);
|
||||
hpu_testcase!("TRAIL0" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| [ct[0].trailing_zeros()]);
|
||||
hpu_testcase!("TRAIL1" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| [ct[0].trailing_ones()]);
|
||||
|
||||
// Define a set of test bundle for various size
|
||||
// 8bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
@@ -453,6 +469,17 @@ mod hpu_test {
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cntbit"::8 => [
|
||||
"count0",
|
||||
"count1",
|
||||
"ilog2",
|
||||
"lead0",
|
||||
"lead1",
|
||||
"trail0",
|
||||
"trail1"
|
||||
]);
|
||||
|
||||
// 16bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::16 => [
|
||||
@@ -537,6 +564,17 @@ mod hpu_test {
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cntbit"::16 => [
|
||||
"count0",
|
||||
"count1",
|
||||
"ilog2",
|
||||
"lead0",
|
||||
"lead1",
|
||||
"trail0",
|
||||
"trail1"
|
||||
]);
|
||||
|
||||
// 32bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::32 => [
|
||||
@@ -621,6 +659,17 @@ mod hpu_test {
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cntbit"::32 => [
|
||||
"count0",
|
||||
"count1",
|
||||
"ilog2",
|
||||
"lead0",
|
||||
"lead1",
|
||||
"trail0",
|
||||
"trail1"
|
||||
]);
|
||||
|
||||
// 64bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::64 => [
|
||||
@@ -705,6 +754,17 @@ mod hpu_test {
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cntbit"::64 => [
|
||||
"count0",
|
||||
"count1",
|
||||
"ilog2",
|
||||
"lead0",
|
||||
"lead1",
|
||||
"trail0",
|
||||
"trail1"
|
||||
]);
|
||||
|
||||
// 128bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::128 => [
|
||||
@@ -789,6 +849,17 @@ mod hpu_test {
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cntbit"::128 => [
|
||||
"count0",
|
||||
"count1",
|
||||
"ilog2",
|
||||
"lead0",
|
||||
"lead1",
|
||||
"trail0",
|
||||
"trail1"
|
||||
]);
|
||||
|
||||
/// Simple test dedicated to check entities conversion from/to Cpu
|
||||
#[cfg(feature = "hpu")]
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user