feat(hpu): Adding a massively parallel multiplier operation

Helder Campos
2025-06-17 15:48:27 +01:00
committed by Pierre Gardrat
parent eeccace7b3
commit 827a6e912c
2 changed files with 227 additions and 13 deletions


@@ -189,7 +189,7 @@ pub fn iop_mul(prog: &mut Program) {
// Add Comment header
prog.push_comment("MUL Operand::Dst Operand::Src Operand::Src".to_string());
// Deferred implementation to generic mulx function
iop_mulx(prog, dst, src_a, src_b).add_to_prog(prog);
}
@@ -205,7 +205,7 @@ pub fn iop_muls(prog: &mut Program) {
// Add Comment header
prog.push_comment("MULS Operand::Dst Operand::Src Operand::Immediat".to_string());
// Deferred implementation to generic mulx function
iop_mulx(prog, dst, src_a, src_b).add_to_prog(prog);
}
@@ -313,11 +313,176 @@ fn iop_subx(
.add_to_prog(prog);
}
/// Generic mul operation for massively parallel HPUs
#[instrument(level = "trace", skip(prog))]
pub fn iop_mulx_par(
prog: &mut Program,
dst: Vec<metavar::MetaVarCell>,
src_a: Vec<metavar::MetaVarCell>,
src_b: Vec<metavar::MetaVarCell>,
) -> Rtl {
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
let blk_w = props.blk_w();
// Transform metavars into RTL vars
let mut dst = VarCell::from_vec(dst);
let src_a = VarCell::from_vec(src_a);
let src_b = VarCell::from_vec(src_b);
let max_deg = VarDeg {
deg: props.max_val(),
nu: props.nu,
};
let pbs_mul_lsb = pbs_by_name!("MultCarryMsgLsb");
let pbs_mul_msb = pbs_by_name!("MultCarryMsgMsb");
let max_carry = (props.max_msg() * props.max_msg()) >> props.msg_w;
let max_msg = props.max_msg();
let mut mul_map: HashMap<usize, Vec<VarCellDeg>> = HashMap::new();
itertools::iproduct!(0..blk_w, 0..blk_w).for_each(|(i, j)| {
let pp = src_a[i].mac(tfhe_params.msg_range(), &src_b[j]);
let lsb = pp.single_pbs(&pbs_mul_lsb);
let msb = pp.single_pbs(&pbs_mul_msb);
mul_map
.entry(i + j)
.or_default()
.push(VarCellDeg::new(max_msg, lsb));
mul_map
.entry(i + j + 1)
.or_default()
.push(VarCellDeg::new(max_carry, msb));
});
let mut pp: Vec<VecVarCellDeg> = (0..dst.len())
.map(|i| mul_map.remove(&i).unwrap().into())
.collect();
// Reduce the partial-product columns, Dadda-tree style
while pp.iter().any(|x| x.len() > 1) {
trace!(
target: "llt::mul",
"pp length: {:?}",
pp.iter().map(|x| x.len()).collect::<Vec<_>>()
);
for c in (0..dst.len()).rev() {
let mut col_len = pp[c].len();
let mut reduced = Vec::new();
let mut chunks = pp[c].deg_chunks(&max_deg).peekable();
let max_col = if c == (dst.len() - 1) {
0
} else {
dst.len() - 1
};
while chunks.peek().is_some() && col_len > pp[max_col].len() {
let mut chunk = chunks.next().unwrap();
let chunk_len = chunk.len();
col_len -= chunk.len();
// sum the chunk
while chunk.len() > 1 {
chunk = chunk
.chunks(2)
.map(|chunk| match chunk.len() {
1 => chunk[0].clone(),
2 => &chunk[0] + &chunk[1],
_ => panic!("Invalid chunk size"),
})
.collect()
}
// And bootstrap if needed
let element = chunk
.into_iter()
.next()
.map(|sum| {
assert!(sum.deg.nu <= props.nu);
if sum.deg == max_deg || chunk_len == 1 {
let (data, carry) = sum.bootstrap(&props);
if let (Some(carry), Some(elm)) = (carry, pp.get_mut(c + 1)) {
elm.push(carry);
}
data
} else {
sum
}
})
.unwrap();
reduced.push(element);
}
pp[c] = reduced
.into_iter()
.chain(chunks.flatten())
.collect::<Vec<_>>()
.into();
}
}
trace!(
target: "llt::mul",
"final pp: {:?}", pp
);
// Extract carry and message and do carry propagation
let mut a: Vec<Option<VarCell>> = (0..dst.len() + 1).map(|_| None).collect();
let mut b: Vec<Option<VarCell>> = (0..dst.len() + 1).map(|_| None).collect();
pp.into_iter().enumerate().for_each(|(i, pp)| {
assert!(pp.len() == 1);
let vardeg = pp.first().unwrap();
let (msg, carry) = vardeg.bootstrap(&props);
a[i] = Some(msg.var);
if let Some(carry) = carry {
b[i + 1] = Some(carry.var);
}
});
let cs: Vec<_> = a
.into_iter()
.take(dst.len())
.zip(b.into_iter())
.map(|(a, b)| match (a, b) {
(Some(a), Some(b)) => &a + &b,
(Some(a), None) => a,
(None, Some(b)) => b,
_ => panic!("Fix your code"),
})
.collect();
// Do fully parallel carry propagation
kogge::propagate_carry(prog, dst.as_mut_slice(), cs.as_slice(), &None);
Rtl::from(dst)
}
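The function above computes all blk_w x blk_w digit products up front: each mac packs a[i] and b[j] into one ciphertext, a MultCarryMsgLsb PBS drops the low digit into column i+j while MultCarryMsgMsb drops the high digit into column i+j+1, the columns are then reduced Dadda-style under the degree bound, and kogge::propagate_carry resolves the remaining carries in parallel. The sketch below is a plain-integer model of that dataflow, handy for checking the column bookkeeping; mulx_par_model and the 2-bit digit width are assumptions for illustration, and the real code manipulates encrypted VarCells, not u64 digits. Note that with 2-bit digits, max_carry = (3 * 3) >> 2 = 2, so the msb entries start with a small degree.

const MSG_W: u32 = 2; // assumed digit width, i.e. max_msg = 3
const MSG_MASK: u64 = (1 << MSG_W) - 1;

// Schoolbook partial products, wrapping at blk_w digits like the HPU version:
// the lsb of a[i] * b[j] lands in column i+j, the msb in column i+j+1.
fn mulx_par_model(a: &[u64], b: &[u64]) -> Vec<u64> {
    let blk_w = a.len();
    let mut cols: Vec<Vec<u64>> = vec![Vec::new(); blk_w];
    for i in 0..blk_w {
        for j in 0..blk_w {
            let pp = a[i] * b[j];
            if i + j < blk_w {
                cols[i + j].push(pp & MSG_MASK); // stands in for MultCarryMsgLsb
            }
            if i + j + 1 < blk_w {
                cols[i + j + 1].push(pp >> MSG_W); // stands in for MultCarryMsgMsb
            }
        }
    }
    // Column reduction plus carry propagation; the HPU does this tree-wise,
    // bounded by max_deg, with bootstraps standing in for the & and >>.
    let mut out = vec![0u64; blk_w];
    let mut carry = 0;
    for (c, col) in cols.iter().enumerate() {
        let sum: u64 = col.iter().sum::<u64>() + carry;
        out[c] = sum & MSG_MASK;
        carry = sum >> MSG_W;
    }
    out
}

fn main() {
    // 7 * 5 = 35; with two 2-bit digits the result wraps mod 16 to 3.
    assert_eq!(mulx_par_model(&[3, 1], &[1, 1]), vec![3, 0]);
}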
/// Multiplier wrapper, choosing between the parallel and serial implementations
#[instrument(level = "trace", skip(prog))]
pub fn iop_mulx(
prog: &mut Program,
dst: Vec<metavar::MetaVarCell>,
src_a: Vec<metavar::MetaVarCell>,
src_b: Vec<metavar::MetaVarCell>,
) -> Rtl {
// When the batch size is big enough to run a full reduction stage in
// parallel, use the parallel mul.
// Note: the true break-even point might be elsewhere, but finding it is
// unimportant since the number of PBS per batch will grow immensely from
// FPGA to ASIC.
if prog.params().pbs_batch_w >= dst.len() {
iop_mulx_par(prog, dst, src_a, src_b)
} else {
iop_mulx_ser(prog, dst, src_a, src_b)
}
}
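Concretely, the dispatch just compares the PBS batch capacity against the result width. A tiny illustration with assumed, purely illustrative batch sizes (picks_parallel is a hypothetical stand-in for the predicate above):

fn picks_parallel(pbs_batch_w: usize, dst_len: usize) -> bool {
    pbs_batch_w >= dst_len
}

fn main() {
    assert!(!picks_parallel(8, 16)); // FPGA-sized batch, 16-digit result: serial
    assert!(picks_parallel(32, 16)); // ASIC-sized batch: parallel
}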
/// Generic serial mul operation
/// One destination and two sources
/// Sources can be an Operand or an Immediat
#[instrument(level = "trace", skip(prog))]
-pub fn iop_mulx(
+pub fn iop_mulx_ser(
prog: &mut Program,
dst: Vec<metavar::MetaVarCell>,
src_a: Vec<metavar::MetaVarCell>,
@@ -367,7 +532,10 @@ pub fn iop_mulx(
sum.var.single_pbs(&pbs_carry),
));
}
-VarCellDeg::new(props.max_msg(), sum.var.single_pbs(&pbs_msg))
+VarCellDeg::new(
+sum.deg.deg.min(props.max_msg()),
+sum.var.single_pbs(&pbs_msg),
+)
};
while to_sum.len() > 1 {


@@ -1,4 +1,6 @@
use super::rtl::VarCell;
use super::*;
use crate::pbs_by_name;
use tracing::trace;
#[derive(Clone, Eq, Default, Debug)]
@@ -48,6 +50,48 @@ pub struct VarCellDeg {
pub deg: VarDeg,
}
impl VarCellDeg {
pub fn bootstrap(&self, props: &FwParameters) -> (VarCellDeg, Option<VarCellDeg>) {
trace!(target: "vardeg::VarCellDeg::bootstrap", "bootstrap: {:?}", self);
let pbs_many_carry = pbs_by_name!("ManyCarryMsg");
let pbs_carry = pbs_by_name!("CarryInMsg");
let pbs_msg = pbs_by_name!("MsgOnly");
if self.deg.deg <= props.max_msg() {
match self.deg.nu {
1 => (self.clone(), None),
_ => (
VarCellDeg::new(self.deg.deg, self.var.single_pbs(&pbs_msg)),
None,
),
}
// If we still have a bit available to do manyLUT
} else if self.deg.deg > props.max_msg() && self.deg.deg <= (props.max_val() >> 1) {
let mut pbs = self.var.pbs(&pbs_many_carry).into_iter();
(
VarCellDeg::new(props.max_msg().min(self.deg.deg), pbs.next().unwrap()),
Some(VarCellDeg::new(
self.deg.deg >> props.carry_w,
pbs.next().unwrap(),
)),
)
// Otherwise, we'll have to use two independent PBSs
} else {
(
VarCellDeg::new(
self.deg.deg.min(props.max_msg()),
self.var.single_pbs(&pbs_msg),
),
Some(VarCellDeg::new(
self.deg.deg >> props.carry_w,
self.var.single_pbs(&pbs_carry),
)),
)
}
}
}
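bootstrap picks one of three strategies from the tracked degree. A plain-integer sketch of the branch selection, assuming the usual 2-bit message / 2-bit carry parameters (max_msg = 3, max_val = 15); the manyLUT branch is only available while the degree leaves the padding bit free, i.e. up to max_val >> 1 = 7. The names regime/Regime are hypothetical, for illustration only:

#[derive(Debug, PartialEq)]
enum Regime {
    NoCarry, // degree fits in the message: at most a cleaning MsgOnly PBS
    ManyLut, // one ManyCarryMsg PBS returns both message and carry
    TwoPbs,  // MsgOnly + CarryInMsg as two independent PBS
}

fn regime(deg: u64) -> Regime {
    const MAX_MSG: u64 = 3; // 2^msg_w - 1, assumed
    const MAX_VAL: u64 = 15; // 2^(msg_w + carry_w) - 1, assumed
    if deg <= MAX_MSG {
        Regime::NoCarry
    } else if deg <= MAX_VAL >> 1 {
        Regime::ManyLut
    } else {
        Regime::TwoPbs
    }
}

fn main() {
    assert_eq!(regime(3), Regime::NoCarry);
    assert_eq!(regime(6), Regime::ManyLut); // 3 < 6 <= 7: padding bit still free
    assert_eq!(regime(9), Regime::TwoPbs); // past the manyLUT window
}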
impl PartialOrd for VarCellDeg {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
@@ -105,25 +149,23 @@ impl std::fmt::Debug for VarCellDeg {
}
impl VecVarCellDeg {
-pub fn deg_chunks(
-mut self,
-max_deg: &VarDeg,
-) -> <Vec<Vec<VarCellDeg>> as IntoIterator>::IntoIter {
+pub fn deg_chunks(&self, max_deg: &VarDeg) -> <Vec<Vec<VarCellDeg>> as IntoIterator>::IntoIter {
trace!(target: "llt:deg_chunks", "len: {:?}, {:?}", self.len(), self.0);
let mut res: Vec<Vec<VarCellDeg>> = Vec::new();
let mut acc: VarDeg = VarDeg::default();
let mut chunk: Vec<VarCellDeg> = Vec::new();
let mut copy = self.0.clone();
// There are many ways to combine the whole vector in chunks up to
// max_deg. We'll be greedy and sum up the elements by maximum degree
// first.
-self.0.sort();
+copy.sort();
-while !self.is_empty() {
-let sum = &acc + &self.0.last().unwrap().deg;
+while !copy.is_empty() {
+let sum = &acc + &copy.last().unwrap().deg;
if sum <= *max_deg {
-chunk.push(self.0.pop().unwrap());
+chunk.push(copy.pop().unwrap());
acc = sum;
} else {
res.push(chunk);
@@ -131,7 +173,7 @@ impl VecVarCellDeg {
chunk = Vec::new();
}
trace!(target: "llt:deg_chunks:loop", "len: {:?}, {:?}, chunk: {:?}, acc: {:?}",
-self.len(), self.0, chunk, acc);
+self.len(), copy, chunk, acc);
}
// Any remaining chunk is appended
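The greedy strategy is easiest to see on bare integers: sort ascending, repeatedly take the largest remaining degree while the running sum stays within max_deg, then start a new chunk. A minimal model under stated assumptions (deg_chunks_model is a hypothetical name, it tracks degrees only where the real version also tracks nu and carries the ciphertext handles along, and it assumes every individual degree already fits in max_deg):

fn deg_chunks_model(mut degs: Vec<u64>, max_deg: u64) -> Vec<Vec<u64>> {
    degs.sort();
    let mut res = Vec::new();
    let mut chunk = Vec::new();
    let mut acc = 0u64;
    // Greedy: always try the largest remaining degree first.
    while let Some(&d) = degs.last() {
        if acc + d <= max_deg {
            chunk.push(degs.pop().unwrap());
            acc += d;
        } else {
            res.push(std::mem::take(&mut chunk));
            acc = 0;
        }
    }
    if !chunk.is_empty() {
        res.push(chunk); // any remaining chunk is appended
    }
    res
}

fn main() {
    // max_deg = 15: [9, 6] saturates first, then [3, 3, 3, 2] fills up.
    assert_eq!(
        deg_chunks_model(vec![3, 9, 2, 6, 3, 3], 15),
        vec![vec![9, 6], vec![3, 3, 3, 2]]
    );
}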
@@ -159,4 +201,8 @@ impl VecVarCellDeg {
pub fn is_empty(&self) -> bool {
self.0.len() == 0
}
pub fn push(&mut self, item: VarCellDeg) {
self.0.push(item)
}
}