mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(hpu): LLT ROT/SHIFT IOPs
This commit is contained in:
committed by
Pierre Gardrat
parent
b4b6275ca5
commit
7b621e57b0
@@ -22,6 +22,7 @@
|
||||
]
|
||||
heap_size = 16384
|
||||
|
||||
|
||||
lut_mem = 256
|
||||
lut_pc = {Hbm={pc=34}}
|
||||
|
||||
@@ -70,6 +71,7 @@
|
||||
trace_depth = 32 # In MB
|
||||
|
||||
[firmware]
|
||||
#implementation = "Ilp"
|
||||
implementation = "Llt"
|
||||
integer_w=[2,4,6,8,10,12,14,16,32,64,128]
|
||||
min_batch_size = 12
|
||||
|
||||
@@ -750,14 +750,14 @@ pub fn iop_mulx(
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum ShiftKind {
|
||||
pub(super) enum ShiftKind {
|
||||
ShiftRight,
|
||||
ShiftLeft,
|
||||
RotRight,
|
||||
RotLeft,
|
||||
}
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum CondPos {
|
||||
pub(super) enum CondPos {
|
||||
Pos0,
|
||||
Pos1,
|
||||
}
|
||||
|
||||
@@ -12,7 +12,8 @@ use crate::asm::{self, OperandKind, Pbs};
|
||||
use crate::fw::metavar::MetaVarCell;
|
||||
use crate::fw::program::Program;
|
||||
use crate::pbs_by_name;
|
||||
use itertools::{EitherOrBoth, Itertools};
|
||||
use fw_impl::ilp::{CondPos, ShiftKind};
|
||||
use itertools::{EitherOrBoth, Itertools, Position};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{instrument, trace};
|
||||
|
||||
@@ -29,12 +30,10 @@ crate::impl_fw!("Llt" [
|
||||
OVF_SUB => fw_impl::ilp::iop_overflow_sub;
|
||||
OVF_MUL => fw_impl::ilp::iop_overflow_mul;
|
||||
|
||||
// NB: fallback to ilp
|
||||
// TODO: Add dedicated llt implementation
|
||||
ROT_R => fw_impl::ilp::iop_rotate_right;
|
||||
ROT_L => fw_impl::ilp::iop_rotate_left;
|
||||
SHIFT_R => fw_impl::ilp::iop_shift_right;
|
||||
SHIFT_L => fw_impl::ilp::iop_shift_left;
|
||||
ROT_R => fw_impl::llt::iop_rotate_right;
|
||||
ROT_L => fw_impl::llt::iop_rotate_left;
|
||||
SHIFT_R => fw_impl::llt::iop_shift_right;
|
||||
SHIFT_L => fw_impl::llt::iop_shift_left;
|
||||
|
||||
ADDS => fw_impl::llt::iop_adds;
|
||||
SUBS => fw_impl::llt::iop_subs;
|
||||
@@ -238,6 +237,70 @@ pub fn iop_erc_20_simd(prog: &mut Program) {
|
||||
simd(prog, crate::asm::iop::SIMD_N, fw_impl::llt::iop_erc_20_rtl);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_shift_right(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let dst = VarCell::from_vec(prog.iop_template_var(OperandKind::Dst, 0));
|
||||
// Src -> Operand
|
||||
let src = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 0));
|
||||
// Amount -> Operand
|
||||
let amount = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 1));
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("SHIFT_R Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
// Deferred implementation to generic rotx function
|
||||
iop_shiftrotx(prog, ShiftKind::ShiftRight, dst, src, amount).add_to_prog(prog);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_shift_left(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let dst = VarCell::from_vec(prog.iop_template_var(OperandKind::Dst, 0));
|
||||
// Src -> Operand
|
||||
let src = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 0));
|
||||
// ShiftAmount -> Operand
|
||||
let amount = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 1));
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("SHIFT_L Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
// Deferred implementation to generic rotx function
|
||||
iop_shiftrotx(prog, ShiftKind::ShiftLeft, dst, src, amount).add_to_prog(prog);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_rotate_right(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let dst = VarCell::from_vec(prog.iop_template_var(OperandKind::Dst, 0));
|
||||
// Src -> Operand
|
||||
let src = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 0));
|
||||
// ShiftAmount -> Operand
|
||||
let rot_amount = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 1));
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("ROT_R Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
// Deferred implementation to generic rotx function
|
||||
iop_shiftrotx(prog, ShiftKind::RotRight, dst, src, rot_amount).add_to_prog(prog);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_rotate_left(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let dst = VarCell::from_vec(prog.iop_template_var(OperandKind::Dst, 0));
|
||||
// Src -> Operand
|
||||
let src = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 0));
|
||||
// ShiftAmount -> Operand
|
||||
let rot_amount = VarCell::from_vec(prog.iop_template_var(OperandKind::Src, 1));
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("ROT_L Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
// Deferred implementation to generic rotx function
|
||||
iop_shiftrotx(prog, ShiftKind::RotLeft, dst, src, rot_amount).add_to_prog(prog);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Helper Functions
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -782,3 +845,165 @@ where
|
||||
.sum::<Rtl>()
|
||||
.add_to_prog(prog);
|
||||
}
|
||||
|
||||
// Comupute inner-shift
|
||||
// input:
|
||||
// * src: clean ciphertext with only message
|
||||
// * amount: ciphertext encoding amount to Shift/Rotate. Only Lsb of msg will be considered
|
||||
// output:
|
||||
// Tuple of msg and msg_next.
|
||||
// msg_next is the contribution of next ct block in the shift direction
|
||||
fn inner_shift(
|
||||
prog: &Program,
|
||||
dir: ShiftKind,
|
||||
src: &VarCell,
|
||||
amount: &VarCell,
|
||||
) -> (VarCell, VarCell) {
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
|
||||
let (pbs_msg, pbs_msg_next) = match dir {
|
||||
ShiftKind::ShiftRight | ShiftKind::RotRight => (
|
||||
pbs_by_name!("ShiftRightByCarryPos0Msg"),
|
||||
pbs_by_name!("ShiftRightByCarryPos0MsgNext"),
|
||||
),
|
||||
ShiftKind::ShiftLeft | ShiftKind::RotLeft => (
|
||||
pbs_by_name!("ShiftLeftByCarryPos0Msg"),
|
||||
pbs_by_name!("ShiftLeftByCarryPos0MsgNext"),
|
||||
),
|
||||
};
|
||||
|
||||
let pack = src.mac(tfhe_params.msg_range(), amount);
|
||||
let msg = pack.single_pbs(&pbs_msg);
|
||||
let msg_next = pack.single_pbs(&pbs_msg_next);
|
||||
(msg, msg_next)
|
||||
}
|
||||
|
||||
fn block_swap(
|
||||
prog: &Program,
|
||||
src_orig: &VarCell,
|
||||
src_swap: Option<&VarCell>,
|
||||
cond: &VarCell,
|
||||
cond_mask: CondPos,
|
||||
) -> VarCell {
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
|
||||
let (pbs_orig, pbs_swap) = match cond_mask {
|
||||
CondPos::Pos0 => (
|
||||
pbs_by_name!("IfPos0TrueZeroed"),
|
||||
pbs_by_name!("IfPos0FalseZeroed"),
|
||||
),
|
||||
CondPos::Pos1 => (
|
||||
pbs_by_name!("IfPos1TrueZeroed"),
|
||||
pbs_by_name!("IfPos1FalseZeroed"),
|
||||
),
|
||||
};
|
||||
let pack_orig = src_orig.mac(tfhe_params.msg_range(), cond);
|
||||
if let Some(swap) = src_swap {
|
||||
let pack_swap = swap.mac(tfhe_params.msg_range(), cond);
|
||||
&pack_orig.single_pbs(&pbs_orig) + &pack_swap.single_pbs(&pbs_swap)
|
||||
} else {
|
||||
pack_orig.single_pbs(&pbs_orig)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic shift function operation
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
fn iop_shiftrotx(
|
||||
prog: &Program,
|
||||
kind: ShiftKind,
|
||||
mut dst: Vec<VarCell>,
|
||||
src: Vec<VarCell>,
|
||||
amount: Vec<VarCell>,
|
||||
) -> Rtl {
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
let blk_w = props.blk_w();
|
||||
|
||||
// First apply inner shift
|
||||
let (shiftrot, shiftrot_next): (Vec<_>, Vec<_>) = src
|
||||
.iter()
|
||||
.map(|ct| inner_shift(prog, kind, ct, &amount[0]))
|
||||
.unzip();
|
||||
|
||||
// Fuse msg and next msg based on direction/kind
|
||||
let mut merge_shiftrot = shiftrot
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.with_position()
|
||||
.map(|(pos, (i, ct))| match kind {
|
||||
ShiftKind::ShiftRight => {
|
||||
if !matches!(pos, Position::Last | Position::Only) {
|
||||
&ct + &shiftrot_next[i + 1]
|
||||
} else {
|
||||
ct
|
||||
}
|
||||
}
|
||||
ShiftKind::ShiftLeft => {
|
||||
if !matches!(pos, Position::First | Position::Only) {
|
||||
&ct + &shiftrot_next[i - 1]
|
||||
} else {
|
||||
ct
|
||||
}
|
||||
}
|
||||
ShiftKind::RotRight => {
|
||||
let rot_idx = (i + 1) % shiftrot_next.len();
|
||||
&ct + &shiftrot_next[rot_idx]
|
||||
}
|
||||
ShiftKind::RotLeft => {
|
||||
let rot_idx = ((i + shiftrot_next.len()) - 1) % shiftrot_next.len();
|
||||
&ct + &shiftrot_next[rot_idx]
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Second apply block swap
|
||||
// Block swapping done with successive buterflies with log2 stages
|
||||
// NB: each block encode msg_w bits thus:
|
||||
// * First shiftrot is already done with inner_shiftrot
|
||||
// * Two swap is done for each amount blk
|
||||
for stg in 1..(2 * blk_w).ilog2() as usize {
|
||||
merge_shiftrot = (0..blk_w)
|
||||
.map(|i| {
|
||||
let stride = 1 << (stg - 1);
|
||||
let swap = match kind {
|
||||
ShiftKind::ShiftRight => merge_shiftrot.get(i + stride),
|
||||
ShiftKind::ShiftLeft => {
|
||||
if i >= stride {
|
||||
merge_shiftrot.get(i - stride)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ShiftKind::RotRight => {
|
||||
let swap_idx = (i + stride) % merge_shiftrot.len();
|
||||
merge_shiftrot.get(swap_idx)
|
||||
}
|
||||
ShiftKind::RotLeft => {
|
||||
let swap_idx = (i + merge_shiftrot.len() - stride) % merge_shiftrot.len();
|
||||
merge_shiftrot.get(swap_idx)
|
||||
}
|
||||
};
|
||||
// Based on stage index shiftrot condition is in amount msg at pos0 or pos1
|
||||
block_swap(
|
||||
prog,
|
||||
&merge_shiftrot[i],
|
||||
swap,
|
||||
&amount[stg / tfhe_params.msg_w],
|
||||
if stg % 2 == 1 {
|
||||
CondPos::Pos1
|
||||
} else {
|
||||
CondPos::Pos0
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
}
|
||||
dst.iter_mut()
|
||||
.zip(merge_shiftrot.iter())
|
||||
.for_each(|(d, r)| {
|
||||
*d <<= r;
|
||||
});
|
||||
dst.into()
|
||||
}
|
||||
|
||||
@@ -1770,6 +1770,7 @@ impl<'a> dot2::Labeller<'a> for Graph {
|
||||
n.copy_name(),
|
||||
n.borrow()
|
||||
.load_stats()
|
||||
.clone()
|
||||
.and_then(|l| Some(format!("{:?}", l)))
|
||||
.unwrap_or(String::from("None")),
|
||||
n.copy_uid(),
|
||||
|
||||
Reference in New Issue
Block a user