mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(hpu): Made two SIMD IOPs, ADD and ERC20.
This commit is contained in:
committed by
Pierre Gardrat
parent
827a6e912c
commit
3b48ef301e
@@ -31,14 +31,22 @@ class LD(BaseInstruction):
|
||||
self.__dict__ = d
|
||||
|
||||
def args(self):
|
||||
return f'R{self.rid} @{hex(self.slot["Addr"])}'
|
||||
try:
|
||||
return f'R{self.rid} @{hex(self.slot["Addr"])}'
|
||||
except:
|
||||
# It can happen that an IOP is not translated by the FW
|
||||
return f'R{self.rid} @{self.slot}'
|
||||
|
||||
class ST(BaseInstruction):
|
||||
def __init__(self, d):
|
||||
self.__dict__ = d
|
||||
|
||||
def args(self):
|
||||
return f'@{hex(self.slot["Addr"])} R{self.rid}'
|
||||
try:
|
||||
return f'@{hex(self.slot["Addr"])} R{self.rid}'
|
||||
except:
|
||||
# It can happen that an IOP is not translated by the FW
|
||||
return f'@{self.slot} R{self.rid}'
|
||||
|
||||
class MAC(BaseInstruction):
|
||||
def __init__(self, d):
|
||||
|
||||
@@ -176,6 +176,18 @@ pub const IOP_2CT_F_CT_SCALAR: ConstIOpProto<2, 1> = ConstIOpProto {
|
||||
imm: 1,
|
||||
};
|
||||
|
||||
pub const SIMD_N: usize = 12; //TODO: We need to come up with a way to have this dynamic
|
||||
pub const IOP_NCT_F_2NCT: ConstIOpProto<{ 1 * SIMD_N }, { 2 * SIMD_N }> = ConstIOpProto {
|
||||
dst: [VarMode::Native; 1 * SIMD_N],
|
||||
src: [VarMode::Native; 2 * SIMD_N],
|
||||
imm: 0,
|
||||
};
|
||||
pub const IOP_2NCT_F_3NCT: ConstIOpProto<{ 2 * SIMD_N }, { 3 * SIMD_N }> = ConstIOpProto {
|
||||
dst: [VarMode::Native; 2 * SIMD_N],
|
||||
src: [VarMode::Native; 3 * SIMD_N],
|
||||
imm: 0,
|
||||
};
|
||||
|
||||
use crate::iop;
|
||||
use arg::IOpFormat;
|
||||
use lazy_static::lazy_static;
|
||||
@@ -227,4 +239,6 @@ iop!(
|
||||
[IOP_CT_F_CT -> "LEAD1", opcode::LEAD1],
|
||||
[IOP_CT_F_CT -> "TRAIL0", opcode::TRAIL0],
|
||||
[IOP_CT_F_CT -> "TRAIL1", opcode::TRAIL1],
|
||||
[IOP_NCT_F_2NCT -> "ADD_SIMD", opcode::ADD_SIMD],
|
||||
[IOP_2NCT_F_3NCT -> "ERC_20_SIMD", opcode::ERC_20_SIMD],
|
||||
);
|
||||
|
||||
@@ -87,6 +87,10 @@ pub const LEAD1: u8 = 0x85;
|
||||
pub const TRAIL0: u8 = 0x86;
|
||||
pub const TRAIL1: u8 = 0x87;
|
||||
|
||||
// SIMD for maximum throughput
|
||||
pub const ADD_SIMD: u8 = 0xF0;
|
||||
pub const ERC_20_SIMD: u8 = 0xF1;
|
||||
//
|
||||
// Utility operations
|
||||
// Used to handle real clone of ciphertext already uploaded in the Hpu memory
|
||||
pub const MEMCPY: u8 = 0xFF;
|
||||
|
||||
@@ -72,6 +72,9 @@ crate::impl_fw!("Ilp" [
|
||||
LEAD1 => fw_impl::ilp_log::iop_lead1;
|
||||
TRAIL0 => fw_impl::ilp_log::iop_trail0;
|
||||
TRAIL1 => fw_impl::ilp_log::iop_trail1;
|
||||
// SIMD Implementations
|
||||
ADD_SIMD => fw_impl::llt::iop_add_simd;
|
||||
ERC_20_SIMD => fw_impl::llt::iop_erc_20_simd;
|
||||
]);
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
|
||||
@@ -57,16 +57,16 @@ crate::impl_fw!("Llt" [
|
||||
OVF_SSUB => fw_impl::ilp::iop_overflow_ssub;
|
||||
OVF_MULS => fw_impl::ilp::iop_overflow_muls;
|
||||
|
||||
BW_AND => (|prog| {fw_impl::ilp::iop_bw(prog, asm::dop::PbsBwAnd::default().into())});
|
||||
BW_OR => (|prog| {fw_impl::ilp::iop_bw(prog, asm::dop::PbsBwOr::default().into())});
|
||||
BW_XOR => (|prog| {fw_impl::ilp::iop_bw(prog, asm::dop::PbsBwXor::default().into())});
|
||||
BW_AND => (|prog| {fw_impl::ilp::iop_bw(prog, asm::dop::PbsBwAnd::default().into())});
|
||||
BW_OR => (|prog| {fw_impl::ilp::iop_bw(prog, asm::dop::PbsBwOr::default().into())});
|
||||
BW_XOR => (|prog| {fw_impl::ilp::iop_bw(prog, asm::dop::PbsBwXor::default().into())});
|
||||
|
||||
CMP_GT => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpGtMrg"), pbs_by_name!("CmpGt"))});
|
||||
CMP_GTE => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpGteMrg"), pbs_by_name!("CmpGte"))});
|
||||
CMP_LT => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpLtMrg"), pbs_by_name!("CmpLt"))});
|
||||
CMP_LTE => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpLteMrg"), pbs_by_name!("CmpLte"))});
|
||||
CMP_EQ => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpEqMrg"), pbs_by_name!("CmpEq"))});
|
||||
CMP_NEQ => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpNeqMrg"), pbs_by_name!("CmpNeq"))});
|
||||
CMP_GT => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpGtMrg"), pbs_by_name!("CmpGt"))});
|
||||
CMP_GTE => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpGteMrg"), pbs_by_name!("CmpGte"))});
|
||||
CMP_LT => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpLtMrg"), pbs_by_name!("CmpLt"))});
|
||||
CMP_LTE => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpLteMrg"), pbs_by_name!("CmpLte"))});
|
||||
CMP_EQ => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpEqMrg"), pbs_by_name!("CmpEq"))});
|
||||
CMP_NEQ => (|prog| {fw_impl::llt::iop_cmp(prog, pbs_by_name!("CmpNeqMrg"), pbs_by_name!("CmpNeq"))});
|
||||
|
||||
IF_THEN_ZERO => fw_impl::ilp::iop_if_then_zero;
|
||||
IF_THEN_ELSE => fw_impl::ilp::iop_if_then_else;
|
||||
@@ -81,6 +81,10 @@ crate::impl_fw!("Llt" [
|
||||
LEAD1 => fw_impl::ilp_log::iop_lead1;
|
||||
TRAIL0 => fw_impl::ilp_log::iop_trail0;
|
||||
TRAIL1 => fw_impl::ilp_log::iop_trail1;
|
||||
|
||||
// SIMD Implementations
|
||||
ADD_SIMD => fw_impl::llt::iop_add_simd;
|
||||
ERC_20_SIMD => fw_impl::llt::iop_erc_20_simd;
|
||||
]);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -102,6 +106,17 @@ pub fn iop_add(prog: &mut Program) {
|
||||
iop_addx(prog, dst, src_a, src_b);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_add_simd(prog: &mut Program) {
|
||||
// Add Comment header
|
||||
prog.push_comment("ADD_SIMD Operand::Dst Operand::Src Operand::Src".to_string());
|
||||
simd(
|
||||
prog,
|
||||
crate::asm::iop::SIMD_N,
|
||||
fw_impl::llt::iop_add_ripple_rtl,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn iop_adds(prog: &mut Program) {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
@@ -209,25 +224,46 @@ pub fn iop_muls(prog: &mut Program) {
|
||||
iop_mulx(prog, dst, src_a, src_b).add_to_prog(prog);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_erc_20(prog: &mut Program) {
|
||||
// Add Comment header
|
||||
prog.push_comment("ERC_20 (new_from, new_to) <- (from, to, amount)".to_string());
|
||||
iop_erc_20_rtl(prog, 0).add_to_prog(prog);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_erc_20_simd(prog: &mut Program) {
|
||||
// Add Comment header
|
||||
prog.push_comment("ERC_20_SIMD (new_from, new_to) <- (from, to, amount)".to_string());
|
||||
simd(prog, crate::asm::iop::SIMD_N, fw_impl::llt::iop_erc_20_rtl);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Helper Functions
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
/// Implement erc_20 fund xfer
|
||||
/// Targeted algorithm is as follows:
|
||||
/// 1. Check that from has enough funds
|
||||
/// 2. Compute real_amount to xfer (i.e. amount or 0)
|
||||
/// 3. Compute new amounts (from - real_amount, to + real_amount)
|
||||
///
|
||||
/// The input operands are:
|
||||
/// (from[0], to[0], amount[0], ..., from[N-1], to[N-1], amount[N-1])
|
||||
/// The output operands are:
|
||||
/// (dst_from[0], dst_to[0], ..., dst_from[N-1], dst_to[N-1])
|
||||
/// Where N is the batch size
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_erc_20(prog: &mut Program) {
|
||||
pub fn iop_erc_20_rtl(prog: &mut Program, batch_index: u8) -> Rtl {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let dst_from = prog.iop_template_var(OperandKind::Dst, 0);
|
||||
let dst_to = prog.iop_template_var(OperandKind::Dst, 1);
|
||||
let dst_from = prog.iop_template_var(OperandKind::Dst, 2 * batch_index);
|
||||
let dst_to = prog.iop_template_var(OperandKind::Dst, 2 * batch_index + 1);
|
||||
// Src -> Operand
|
||||
let src_from = prog.iop_template_var(OperandKind::Src, 0);
|
||||
let src_to = prog.iop_template_var(OperandKind::Src, 1);
|
||||
let src_from = prog.iop_template_var(OperandKind::Src, 3 * batch_index);
|
||||
let src_to = prog.iop_template_var(OperandKind::Src, 3 * batch_index + 1);
|
||||
// Src Amount -> Operand
|
||||
let src_amount = prog.iop_template_var(OperandKind::Src, 2);
|
||||
|
||||
// Add Comment header
|
||||
prog.push_comment("ERC_20 (new_from, new_to) <- (from, to, amount)".to_string());
|
||||
let src_amount = prog.iop_template_var(OperandKind::Src, 3 * batch_index + 2);
|
||||
|
||||
// TODO: Make this a parameter or sweep this
|
||||
// All these little parameters would be very handy to write an
|
||||
@@ -236,7 +272,7 @@ pub fn iop_erc_20(prog: &mut Program) {
|
||||
let kogge_blk_w = 10;
|
||||
let ripple = true;
|
||||
|
||||
let tree = {
|
||||
{
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
let lut = pbs_by_name!("IfFalseZeroed");
|
||||
@@ -273,13 +309,26 @@ pub fn iop_erc_20(prog: &mut Program) {
|
||||
kogge::add(prog, dst_to, src_to, src_amount.clone(), None, kogge_blk_w)
|
||||
+ kogge::sub(prog, dst_from, src_from, src_amount, kogge_blk_w)
|
||||
}
|
||||
};
|
||||
tree.add_to_prog(prog);
|
||||
}
|
||||
}
|
||||
|
||||
/// A SIMD implementation of add for maximum throughput
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_add_ripple_rtl(prog: &mut Program, i: u8) -> Rtl {
|
||||
// Allocate metavariables:
|
||||
let dst = prog.iop_template_var(OperandKind::Dst, i);
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 2 * i);
|
||||
let src_b = prog.iop_template_var(OperandKind::Src, 2 * i + 1);
|
||||
|
||||
// Convert MetaVarCell in VarCell for Rtl analysis
|
||||
let a = VarCell::from_vec(src_a);
|
||||
let b = VarCell::from_vec(src_b);
|
||||
let d = VarCell::from_vec(dst);
|
||||
|
||||
// Do a + b with the ripple carry adder
|
||||
kogge::ripple_add(d, a, b, None)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Helper Functions
|
||||
// ----------------------------------------------------------------------------
|
||||
fn iop_addx(
|
||||
prog: &mut Program,
|
||||
dst: Vec<MetaVarCell>,
|
||||
@@ -471,7 +520,12 @@ pub fn iop_mulx(
|
||||
// Note: The break-even point might not be this one, but choosing the right
|
||||
// point is unimportant since we'll leap immensely the number of batches from
|
||||
// FPGA to ASIC.
|
||||
if prog.params().pbs_batch_w >= dst.len() {
|
||||
let parallel = prog
|
||||
.op_cfg()
|
||||
.parallel
|
||||
.unwrap_or_else(|| prog.params().pbs_batch_w >= dst.len());
|
||||
|
||||
if parallel {
|
||||
iop_mulx_par(prog, dst, src_a, src_b)
|
||||
} else {
|
||||
iop_mulx_ser(prog, dst, src_a, src_b)
|
||||
@@ -708,3 +762,24 @@ fn bw_inv(prog: &mut Program, b: Vec<VarCell>) -> Vec<VarCell> {
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
/// Creates a SIMD version of the closure
|
||||
/// Make sure that the closure is a PBS optimized version of the operation
|
||||
/// The closure receives as inputs the program and the batch index.
|
||||
/// How the ASM operands are actually organized is defined by the closure
|
||||
/// itself.
|
||||
///
|
||||
/// Maybe this should go into a SIMD firmware implementation... At some point we
|
||||
/// would need a mechanism to choose between implementations on the fly to make
|
||||
/// really good use of all of this.
|
||||
|
||||
fn simd<F>(prog: &mut Program, batch_size: usize, rtl_closure: F)
|
||||
where
|
||||
F: Fn(&mut Program, u8) -> Rtl,
|
||||
{
|
||||
(0..batch_size)
|
||||
.map(|i| i as u8)
|
||||
.map(|i| rtl_closure(prog, i))
|
||||
.sum::<Rtl>()
|
||||
.add_to_prog(prog);
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ pub struct OpCfg {
|
||||
pub flush: bool,
|
||||
/// Whether to use latency tiers when scheduling
|
||||
pub use_tiers: bool,
|
||||
/// Whether to use a massively parallel implementation
|
||||
pub parallel: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
|
||||
|
||||
@@ -1616,6 +1616,12 @@ impl Rtl {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Rtl {
|
||||
fn default() -> Self {
|
||||
Rtl(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add<Rtl> for Rtl {
|
||||
type Output = Rtl;
|
||||
fn add(self, rhs: Rtl) -> Self::Output {
|
||||
@@ -1623,6 +1629,12 @@ impl std::ops::Add<Rtl> for Rtl {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::iter::Sum<Rtl> for Rtl {
|
||||
fn sum<I: Iterator<Item = Rtl>>(iter: I) -> Self {
|
||||
iter.fold(Rtl::default(), |acc, x| acc + x)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Rtl {
|
||||
fn drop(&mut self) {
|
||||
self.unload();
|
||||
|
||||
@@ -64,13 +64,13 @@ pub struct IscPoolState {
|
||||
pub(super) vld: bool,
|
||||
pub(super) wr_lock: u32,
|
||||
pub(super) rd_lock: u32,
|
||||
//pub(super) issue_lock: u32,
|
||||
pub(super) issue_lock: u32,
|
||||
pub(super) sync_id: u32,
|
||||
}
|
||||
|
||||
impl Len for IscPoolState {
|
||||
fn len() -> usize {
|
||||
21
|
||||
28
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,8 +85,8 @@ where
|
||||
vld: *(slice.get(2).ok_or(NoMoreBits)?),
|
||||
wr_lock: slice.get(3..10).ok_or(NoMoreBits)?.load::<u32>(),
|
||||
rd_lock: slice.get(10..17).ok_or(NoMoreBits)?.load::<u32>(),
|
||||
//issue_lock: slice.get(17..24).ok_or(NoMoreBits)?.load::<u32>(),
|
||||
sync_id: slice.get(17..21).ok_or(NoMoreBits)?.load::<u32>(),
|
||||
issue_lock: slice.get(17..24).ok_or(NoMoreBits)?.load::<u32>(),
|
||||
sync_id: slice.get(24..28).ok_or(NoMoreBits)?.load::<u32>(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user