mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
feat(hpu): Adding a massively parallel multiplier operation
This commit is contained in:
committed by
Pierre Gardrat
parent
eeccace7b3
commit
827a6e912c
@@ -189,7 +189,7 @@ pub fn iop_mul(prog: &mut Program) {
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("MUL Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
// Deferred implementation to generic mulx function
|
||||
|
||||
iop_mulx(prog, dst, src_a, src_b).add_to_prog(prog);
|
||||
}
|
||||
|
||||
@@ -205,7 +205,7 @@ pub fn iop_muls(prog: &mut Program) {
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("MULS Operand::Dst Operand::Src Operand::Immediat".to_string());
|
||||
// Deferred implementation to generic mulx function
|
||||
|
||||
iop_mulx(prog, dst, src_a, src_b).add_to_prog(prog);
|
||||
}
|
||||
|
||||
@@ -313,11 +313,176 @@ fn iop_subx(
|
||||
.add_to_prog(prog);
|
||||
}
|
||||
|
||||
/// Generic mul operation for massively parallel HPUs
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_mulx_par(
|
||||
prog: &mut Program,
|
||||
dst: Vec<metavar::MetaVarCell>,
|
||||
src_a: Vec<metavar::MetaVarCell>,
|
||||
src_b: Vec<metavar::MetaVarCell>,
|
||||
) -> Rtl {
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
let blk_w = props.blk_w();
|
||||
|
||||
// Transform metavars into RTL vars
|
||||
let mut dst = VarCell::from_vec(dst);
|
||||
let src_a = VarCell::from_vec(src_a);
|
||||
let src_b = VarCell::from_vec(src_b);
|
||||
let max_deg = VarDeg {
|
||||
deg: props.max_val(),
|
||||
nu: props.nu,
|
||||
};
|
||||
|
||||
let pbs_mul_lsb = pbs_by_name!("MultCarryMsgLsb");
|
||||
let pbs_mul_msb = pbs_by_name!("MultCarryMsgMsb");
|
||||
let max_carry = (props.max_msg() * props.max_msg()) >> props.msg_w;
|
||||
let max_msg = props.max_msg();
|
||||
|
||||
let mut mul_map: HashMap<usize, Vec<VarCellDeg>> = HashMap::new();
|
||||
itertools::iproduct!(0..blk_w, 0..blk_w).for_each(|(i, j)| {
|
||||
let pp = src_a[i].mac(tfhe_params.msg_range(), &src_b[j]);
|
||||
let lsb = pp.single_pbs(&pbs_mul_lsb);
|
||||
let msb = pp.single_pbs(&pbs_mul_msb);
|
||||
mul_map
|
||||
.entry(i + j)
|
||||
.or_default()
|
||||
.push(VarCellDeg::new(max_msg, lsb));
|
||||
mul_map
|
||||
.entry(i + j + 1)
|
||||
.or_default()
|
||||
.push(VarCellDeg::new(max_carry, msb));
|
||||
});
|
||||
|
||||
let mut pp: Vec<VecVarCellDeg> = (0..dst.len())
|
||||
.map(|i| mul_map.remove(&i).unwrap().into())
|
||||
.collect();
|
||||
|
||||
// Reduce dada tree like
|
||||
while pp.iter().any(|x| x.len() > 1) {
|
||||
trace!(
|
||||
target: "llt::mul",
|
||||
"pp length: {:?}",
|
||||
pp.iter().map(|x| x.len()).collect::<Vec<_>>()
|
||||
);
|
||||
for c in (0..dst.len()).rev() {
|
||||
let mut col_len = pp[c].len();
|
||||
let mut reduced = Vec::new();
|
||||
let mut chunks = pp[c].deg_chunks(&max_deg).peekable();
|
||||
let max_col = if c == (dst.len() - 1) {
|
||||
0
|
||||
} else {
|
||||
dst.len() - 1
|
||||
};
|
||||
|
||||
while chunks.peek().is_some() && col_len > pp[max_col].len() {
|
||||
let mut chunk = chunks.next().unwrap();
|
||||
let chunk_len = chunk.len();
|
||||
col_len -= chunk.len();
|
||||
|
||||
// sum the chunk
|
||||
while chunk.len() > 1 {
|
||||
chunk = chunk
|
||||
.chunks(2)
|
||||
.map(|chunk| match chunk.len() {
|
||||
1 => chunk[0].clone(),
|
||||
2 => &chunk[0] + &chunk[1],
|
||||
_ => panic!("Invalid chunk size"),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// And bootstrap if needed
|
||||
let element = chunk
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|sum| {
|
||||
assert!(sum.deg.nu <= props.nu);
|
||||
if sum.deg == max_deg || chunk_len == 1 {
|
||||
let (data, carry) = sum.bootstrap(&props);
|
||||
if let (Some(carry), Some(elm)) = (carry, pp.get_mut(c + 1)) {
|
||||
elm.push(carry);
|
||||
}
|
||||
data
|
||||
} else {
|
||||
sum
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
reduced.push(element);
|
||||
}
|
||||
|
||||
pp[c] = reduced
|
||||
.into_iter()
|
||||
.chain(chunks.flatten())
|
||||
.collect::<Vec<_>>()
|
||||
.into();
|
||||
}
|
||||
}
|
||||
|
||||
trace!(
|
||||
target: "llt::mul",
|
||||
"final pp: {:?}", pp
|
||||
);
|
||||
|
||||
// Extract carry and message and do carry propagation
|
||||
let mut a: Vec<Option<VarCell>> = (0..dst.len() + 1).map(|_| None).collect();
|
||||
let mut b: Vec<Option<VarCell>> = (0..dst.len() + 1).map(|_| None).collect();
|
||||
|
||||
pp.into_iter().enumerate().for_each(|(i, pp)| {
|
||||
assert!(pp.len() == 1);
|
||||
let vardeg = pp.first().unwrap();
|
||||
let (msg, carry) = vardeg.bootstrap(&props);
|
||||
a[i] = Some(msg.var);
|
||||
if let Some(carry) = carry {
|
||||
b[i + 1] = Some(carry.var);
|
||||
}
|
||||
});
|
||||
|
||||
let cs: Vec<_> = a
|
||||
.into_iter()
|
||||
.take(dst.len())
|
||||
.zip(b.into_iter())
|
||||
.map(|(a, b)| match (a, b) {
|
||||
(Some(a), Some(b)) => &a + &b,
|
||||
(Some(a), None) => a,
|
||||
(None, Some(b)) => b,
|
||||
_ => panic!("Fix your code"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Do fully parallel carry propagation
|
||||
kogge::propagate_carry(prog, dst.as_mut_slice(), cs.as_slice(), &None);
|
||||
|
||||
Rtl::from(dst)
|
||||
}
|
||||
|
||||
/// multiplier wrapper, to choose between parallel and serial implementations
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_mulx(
|
||||
prog: &mut Program,
|
||||
dst: Vec<metavar::MetaVarCell>,
|
||||
src_a: Vec<metavar::MetaVarCell>,
|
||||
src_b: Vec<metavar::MetaVarCell>,
|
||||
) -> Rtl {
|
||||
// When the batch size is enough to do a full stage in parallel, do parallel
|
||||
// mul.
|
||||
// Note: The break-even point might not be this one, but choosing the right
|
||||
// point is uninportant since we'll leap imensely the number of batches from
|
||||
// FPGA to ASIC.
|
||||
if prog.params().pbs_batch_w >= dst.len() {
|
||||
iop_mulx_par(prog, dst, src_a, src_b)
|
||||
} else {
|
||||
iop_mulx_ser(prog, dst, src_a, src_b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic mul operation
|
||||
/// One destination and two sources operation
|
||||
/// Source could be Operand or Immediat
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_mulx(
|
||||
pub fn iop_mulx_ser(
|
||||
prog: &mut Program,
|
||||
dst: Vec<metavar::MetaVarCell>,
|
||||
src_a: Vec<metavar::MetaVarCell>,
|
||||
@@ -367,7 +532,10 @@ pub fn iop_mulx(
|
||||
sum.var.single_pbs(&pbs_carry),
|
||||
));
|
||||
}
|
||||
VarCellDeg::new(props.max_msg(), sum.var.single_pbs(&pbs_msg))
|
||||
VarCellDeg::new(
|
||||
sum.deg.deg.min(props.max_msg()),
|
||||
sum.var.single_pbs(&pbs_msg),
|
||||
)
|
||||
};
|
||||
|
||||
while to_sum.len() > 1 {
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use super::rtl::VarCell;
|
||||
use super::*;
|
||||
use crate::pbs_by_name;
|
||||
use tracing::trace;
|
||||
|
||||
#[derive(Clone, Eq, Default, Debug)]
|
||||
@@ -48,6 +50,48 @@ pub struct VarCellDeg {
|
||||
pub deg: VarDeg,
|
||||
}
|
||||
|
||||
impl VarCellDeg {
|
||||
pub fn bootstrap(&self, props: &FwParameters) -> (VarCellDeg, Option<VarCellDeg>) {
|
||||
trace!(target: "vardeg::VarCellDeg::bootstrap", "bootstrap: {:?}", self);
|
||||
|
||||
let pbs_many_carry = pbs_by_name!("ManyCarryMsg");
|
||||
let pbs_carry = pbs_by_name!("CarryInMsg");
|
||||
let pbs_msg = pbs_by_name!("MsgOnly");
|
||||
|
||||
if self.deg.deg <= props.max_msg() {
|
||||
match self.deg.nu {
|
||||
1 => (self.clone(), None),
|
||||
_ => (
|
||||
VarCellDeg::new(self.deg.deg, self.var.single_pbs(&pbs_msg)),
|
||||
None,
|
||||
),
|
||||
}
|
||||
// If we still have a bit available to do manyLUT
|
||||
} else if self.deg.deg > props.max_msg() && self.deg.deg <= (props.max_val() >> 1) {
|
||||
let mut pbs = self.var.pbs(&pbs_many_carry).into_iter();
|
||||
(
|
||||
VarCellDeg::new(props.max_msg().min(self.deg.deg), pbs.next().unwrap()),
|
||||
Some(VarCellDeg::new(
|
||||
self.deg.deg >> props.carry_w,
|
||||
pbs.next().unwrap(),
|
||||
)),
|
||||
)
|
||||
//Otherwise, we'll have to use two independent PBSs
|
||||
} else {
|
||||
(
|
||||
VarCellDeg::new(
|
||||
self.deg.deg.min(props.max_msg()),
|
||||
self.var.single_pbs(&pbs_msg),
|
||||
),
|
||||
Some(VarCellDeg::new(
|
||||
self.deg.deg >> props.carry_w,
|
||||
self.var.single_pbs(&pbs_carry),
|
||||
)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for VarCellDeg {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
@@ -105,25 +149,23 @@ impl std::fmt::Debug for VarCellDeg {
|
||||
}
|
||||
|
||||
impl VecVarCellDeg {
|
||||
pub fn deg_chunks(
|
||||
mut self,
|
||||
max_deg: &VarDeg,
|
||||
) -> <Vec<Vec<VarCellDeg>> as IntoIterator>::IntoIter {
|
||||
pub fn deg_chunks(&self, max_deg: &VarDeg) -> <Vec<Vec<VarCellDeg>> as IntoIterator>::IntoIter {
|
||||
trace!(target: "llt:deg_chunks", "len: {:?}, {:?}", self.len(), self.0);
|
||||
|
||||
let mut res: Vec<Vec<VarCellDeg>> = Vec::new();
|
||||
let mut acc: VarDeg = VarDeg::default();
|
||||
let mut chunk: Vec<VarCellDeg> = Vec::new();
|
||||
let mut copy = self.0.clone();
|
||||
|
||||
// There are many ways to combine the whole vector in chunks up to
|
||||
// max_deg. We'll be greedy and sum up the elements by maximum degree
|
||||
// first.
|
||||
self.0.sort();
|
||||
copy.sort();
|
||||
|
||||
while !self.is_empty() {
|
||||
let sum = &acc + &self.0.last().unwrap().deg;
|
||||
while !copy.is_empty() {
|
||||
let sum = &acc + ©.last().unwrap().deg;
|
||||
if sum <= *max_deg {
|
||||
chunk.push(self.0.pop().unwrap());
|
||||
chunk.push(copy.pop().unwrap());
|
||||
acc = sum;
|
||||
} else {
|
||||
res.push(chunk);
|
||||
@@ -131,7 +173,7 @@ impl VecVarCellDeg {
|
||||
chunk = Vec::new();
|
||||
}
|
||||
trace!(target: "llt:deg_chunks:loop", "len: {:?}, {:?}, chunk: {:?}, acc: {:?}",
|
||||
self.len(), self.0, chunk, acc);
|
||||
self.len(), copy, chunk, acc);
|
||||
}
|
||||
|
||||
// Any remaining chunk is appended
|
||||
@@ -159,4 +201,8 @@ impl VecVarCellDeg {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.len() == 0
|
||||
}
|
||||
|
||||
pub fn push(&mut self, item: VarCellDeg) {
|
||||
self.0.push(item)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user