feat(hpu): LLT ROT/SHIFT IOPs

This commit is contained in:
Helder Campos
2025-10-03 18:28:20 +01:00
committed by Pierre Gardrat
parent b4b6275ca5
commit 7b621e57b0
4 changed files with 237 additions and 9 deletions

View File

@@ -22,6 +22,7 @@
]
heap_size = 16384
lut_mem = 256
lut_pc = {Hbm={pc=34}}
@@ -70,6 +71,7 @@
trace_depth = 32 # In MB
[firmware]
#implementation = "Ilp"
implementation = "Llt"
integer_w=[2,4,6,8,10,12,14,16,32,64,128]
min_batch_size = 12

View File

@@ -750,14 +750,14 @@ pub fn iop_mulx(
}
#[derive(Debug, Clone, Copy)]
enum ShiftKind {
pub(super) enum ShiftKind {
ShiftRight,
ShiftLeft,
RotRight,
RotLeft,
}
#[derive(Debug, Clone, Copy)]
enum CondPos {
pub(super) enum CondPos {
Pos0,
Pos1,
}

View File

@@ -12,7 +12,8 @@ use crate::asm::{self, OperandKind, Pbs};
use crate::fw::metavar::MetaVarCell;
use crate::fw::program::Program;
use crate::pbs_by_name;
use itertools::{EitherOrBoth, Itertools};
use fw_impl::ilp::{CondPos, ShiftKind};
use itertools::{EitherOrBoth, Itertools, Position};
use std::collections::HashMap;
use tracing::{instrument, trace};
@@ -29,12 +30,10 @@ crate::impl_fw!("Llt" [
OVF_SUB => fw_impl::ilp::iop_overflow_sub;
OVF_MUL => fw_impl::ilp::iop_overflow_mul;
// NB: fallback to ilp
// TODO: Add dedicated llt implementation
ROT_R => fw_impl::ilp::iop_rotate_right;
ROT_L => fw_impl::ilp::iop_rotate_left;
SHIFT_R => fw_impl::ilp::iop_shift_right;
SHIFT_L => fw_impl::ilp::iop_shift_left;
ROT_R => fw_impl::llt::iop_rotate_right;
ROT_L => fw_impl::llt::iop_rotate_left;
SHIFT_R => fw_impl::llt::iop_shift_right;
SHIFT_L => fw_impl::llt::iop_shift_left;
ADDS => fw_impl::llt::iop_adds;
SUBS => fw_impl::llt::iop_subs;
@@ -238,6 +237,70 @@ pub fn iop_erc_20_simd(prog: &mut Program) {
simd(prog, crate::asm::iop::SIMD_N, fw_impl::llt::iop_erc_20_rtl);
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_shift_right(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let dst = VarCell::from_vec(prog.iop_template_var(OperandKind::Dst, 0));
// Src -> Operand
let src = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 0));
// Amount -> Operand
let amount = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 1));
// Add Comment header
prog.push_comment("SHIFT_R Operand::Dst Operand::Src Operand::Src".to_string());
// Deferred implementation to generic rotx function
iop_shiftrotx(prog, ShiftKind::ShiftRight, dst, src, amount).add_to_prog(prog);
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_shift_left(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let dst = VarCell::from_vec(prog.iop_template_var(OperandKind::Dst, 0));
// Src -> Operand
let src = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 0));
// ShiftAmount -> Operand
let amount = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 1));
// Add Comment header
prog.push_comment("SHIFT_L Operand::Dst Operand::Src Operand::Src".to_string());
// Deferred implementation to generic rotx function
iop_shiftrotx(prog, ShiftKind::ShiftLeft, dst, src, amount).add_to_prog(prog);
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_rotate_right(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let dst = VarCell::from_vec(prog.iop_template_var(OperandKind::Dst, 0));
// Src -> Operand
let src = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 0));
// ShiftAmount -> Operand
let rot_amount = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 1));
// Add Comment header
prog.push_comment("ROT_R Operand::Dst Operand::Src Operand::Src".to_string());
// Deferred implementation to generic rotx function
iop_shiftrotx(prog, ShiftKind::RotRight, dst, src, rot_amount).add_to_prog(prog);
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_rotate_left(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let dst = VarCell::from_vec(prog.iop_template_var(OperandKind::Dst, 0));
// Src -> Operand
let src = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 0));
// ShiftAmount -> Operand
let rot_amount = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 1));
// Add Comment header
prog.push_comment("ROT_L Operand::Dst Operand::Src Operand::Src".to_string());
// Deferred implementation to generic rotx function
iop_shiftrotx(prog, ShiftKind::RotLeft, dst, src, rot_amount).add_to_prog(prog);
}
// ----------------------------------------------------------------------------
// Helper Functions
// ----------------------------------------------------------------------------
@@ -782,3 +845,165 @@ where
.sum::<Rtl>()
.add_to_prog(prog);
}
// Comupute inner-shift
// input:
// * src: clean ciphertext with only message
// * amount: ciphertext encoding amount to Shift/Rotate. Only Lsb of msg will be considered
// output:
// Tuple of msg and msg_next.
// msg_next is the contribution of next ct block in the shift direction
fn inner_shift(
prog: &Program,
dir: ShiftKind,
src: &VarCell,
amount: &VarCell,
) -> (VarCell, VarCell) {
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
let (pbs_msg, pbs_msg_next) = match dir {
ShiftKind::ShiftRight | ShiftKind::RotRight => (
pbs_by_name!("ShiftRightByCarryPos0Msg"),
pbs_by_name!("ShiftRightByCarryPos0MsgNext"),
),
ShiftKind::ShiftLeft | ShiftKind::RotLeft => (
pbs_by_name!("ShiftLeftByCarryPos0Msg"),
pbs_by_name!("ShiftLeftByCarryPos0MsgNext"),
),
};
let pack = src.mac(tfhe_params.msg_range(), amount);
let msg = pack.single_pbs(&pbs_msg);
let msg_next = pack.single_pbs(&pbs_msg_next);
(msg, msg_next)
}
fn block_swap(
prog: &Program,
src_orig: &VarCell,
src_swap: Option<&VarCell>,
cond: &VarCell,
cond_mask: CondPos,
) -> VarCell {
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
let (pbs_orig, pbs_swap) = match cond_mask {
CondPos::Pos0 => (
pbs_by_name!("IfPos0TrueZeroed"),
pbs_by_name!("IfPos0FalseZeroed"),
),
CondPos::Pos1 => (
pbs_by_name!("IfPos1TrueZeroed"),
pbs_by_name!("IfPos1FalseZeroed"),
),
};
let pack_orig = src_orig.mac(tfhe_params.msg_range(), cond);
if let Some(swap) = src_swap {
let pack_swap = swap.mac(tfhe_params.msg_range(), cond);
&pack_orig.single_pbs(&pbs_orig) + &pack_swap.single_pbs(&pbs_swap)
} else {
pack_orig.single_pbs(&pbs_orig)
}
}
/// Generic shift function operation
#[instrument(level = "trace", skip(prog))]
fn iop_shiftrotx(
prog: &Program,
kind: ShiftKind,
mut dst: Vec<VarCell>,
src: Vec<VarCell>,
amount: Vec<VarCell>,
) -> Rtl {
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
let blk_w = props.blk_w();
// First apply inner shift
let (shiftrot, shiftrot_next): (Vec<_>, Vec<_>) = src
.iter()
.map(|ct| inner_shift(prog, kind, ct, &amount[0]))
.unzip();
// Fuse msg and next msg based on direction/kind
let mut merge_shiftrot = shiftrot
.into_iter()
.enumerate()
.with_position()
.map(|(pos, (i, ct))| match kind {
ShiftKind::ShiftRight => {
if !matches!(pos, Position::Last | Position::Only) {
&ct + &shiftrot_next[i + 1]
} else {
ct
}
}
ShiftKind::ShiftLeft => {
if !matches!(pos, Position::First | Position::Only) {
&ct + &shiftrot_next[i - 1]
} else {
ct
}
}
ShiftKind::RotRight => {
let rot_idx = (i + 1) % shiftrot_next.len();
&ct + &shiftrot_next[rot_idx]
}
ShiftKind::RotLeft => {
let rot_idx = ((i + shiftrot_next.len()) - 1) % shiftrot_next.len();
&ct + &shiftrot_next[rot_idx]
}
})
.collect::<Vec<_>>();
// Second apply block swap
// Block swapping done with successive buterflies with log2 stages
// NB: each block encode msg_w bits thus:
// * First shiftrot is already done with inner_shiftrot
// * Two swap is done for each amount blk
for stg in 1..(2 * blk_w).ilog2() as usize {
merge_shiftrot = (0..blk_w)
.map(|i| {
let stride = 1 << (stg - 1);
let swap = match kind {
ShiftKind::ShiftRight => merge_shiftrot.get(i + stride),
ShiftKind::ShiftLeft => {
if i >= stride {
merge_shiftrot.get(i - stride)
} else {
None
}
}
ShiftKind::RotRight => {
let swap_idx = (i + stride) % merge_shiftrot.len();
merge_shiftrot.get(swap_idx)
}
ShiftKind::RotLeft => {
let swap_idx = (i + merge_shiftrot.len() - stride) % merge_shiftrot.len();
merge_shiftrot.get(swap_idx)
}
};
// Based on stage index shiftrot condition is in amount msg at pos0 or pos1
block_swap(
prog,
&merge_shiftrot[i],
swap,
&amount[stg / tfhe_params.msg_w],
if stg % 2 == 1 {
CondPos::Pos1
} else {
CondPos::Pos0
},
)
})
.collect::<Vec<_>>();
}
dst.iter_mut()
.zip(merge_shiftrot.iter())
.for_each(|(d, r)| {
*d <<= r;
});
dst.into()
}

View File

@@ -1770,6 +1770,7 @@ impl<'a> dot2::Labeller<'a> for Graph {
n.copy_name(),
n.borrow()
.load_stats()
.clone()
.and_then(|l| Some(format!("{:?}", l)))
.unwrap_or(String::from("None")),
n.copy_uid(),