feat(hpu): Add IOP_IF_THEN_ZERO and IOP_ERC_20

IF_THEN_ZERO is an altered version of IF_THEN_ELSE than take 0 as default value.
ERC_20 is a custom iop dedicated to erc_20 computation. Its a first attempt and mainly a placeholder for future work.
It will be use to test various way to call custom iop from HighLevelApi.

Change test macro to support multi-output IOp correctly.
This commit is contained in:
Baptiste Roux
2025-03-12 12:48:14 +01:00
parent f44a5f1baf
commit a7b3032a1c
4 changed files with 288 additions and 46 deletions

View File

@@ -51,18 +51,23 @@ impl<const D: usize, const S: usize> From<ConstIOpProto<D, S>> for IOpProto {
}
// Define some common iop format
pub const IOP_CT_CT: ConstIOpProto<1, 2> = ConstIOpProto {
pub const IOP_CT_F_2CT: ConstIOpProto<1, 2> = ConstIOpProto {
dst: [VarMode::Native; 1],
src: [VarMode::Native; 2],
imm: 0,
};
pub const IOP_CT_CT_BOOL: ConstIOpProto<1, 3> = ConstIOpProto {
pub const IOP_CT_F_2CT_BOOL: ConstIOpProto<1, 3> = ConstIOpProto {
dst: [VarMode::Native; 1],
src: [VarMode::Native, VarMode::Native, VarMode::Bool],
imm: 0,
};
pub const IOP_CT_F_CT_BOOL: ConstIOpProto<1, 2> = ConstIOpProto {
dst: [VarMode::Native; 1],
src: [VarMode::Native, VarMode::Bool],
imm: 0,
};
pub const IOP_CT_SCALAR: ConstIOpProto<1, 1> = ConstIOpProto {
pub const IOP_CT_F_CT_SCALAR: ConstIOpProto<1, 1> = ConstIOpProto {
dst: [VarMode::Native; 1],
src: [VarMode::Native; 1],
imm: 1,
@@ -74,30 +79,38 @@ pub const IOP_CMP: ConstIOpProto<1, 2> = ConstIOpProto {
imm: 0,
};
pub const IOP_2CT_F_3CT: ConstIOpProto<2, 3> = ConstIOpProto {
dst: [VarMode::Native; 2],
src: [VarMode::Native; 3],
imm: 0,
};
use crate::iop;
use arg::IOpFormat;
use lazy_static::lazy_static;
use std::collections::HashMap;
iop!(
[IOP_CT_SCALAR -> "ADDS", opcode::ADDS],
[IOP_CT_SCALAR -> "SUBS", opcode::SUBS],
[IOP_CT_SCALAR -> "SSUB", opcode::SSUB],
[IOP_CT_SCALAR -> "MULS", opcode::MULS],
[IOP_CT_SCALAR -> "MULSL", opcode::MULSL],
[IOP_CT_CT -> "ADD", opcode::ADD],
[IOP_CT_CT -> "ADDK", opcode::ADDK],
[IOP_CT_CT -> "SUB", opcode::SUB],
[IOP_CT_CT -> "SUBK", opcode::SUBK],
[IOP_CT_CT -> "MUL", opcode::MUL],
[IOP_CT_CT -> "MULL", opcode::MULL],
[IOP_CT_CT -> "BW_AND", opcode::BW_AND],
[IOP_CT_CT -> "BW_OR", opcode::BW_OR],
[IOP_CT_CT -> "BW_XOR", opcode::BW_XOR],
[IOP_CT_F_CT_SCALAR -> "ADDS", opcode::ADDS],
[IOP_CT_F_CT_SCALAR -> "SUBS", opcode::SUBS],
[IOP_CT_F_CT_SCALAR -> "SSUB", opcode::SSUB],
[IOP_CT_F_CT_SCALAR -> "MULS", opcode::MULS],
[IOP_CT_F_CT_SCALAR -> "MULSF", opcode::MULSF],
[IOP_CT_F_2CT -> "ADD", opcode::ADD],
[IOP_CT_F_2CT -> "ADDK", opcode::ADDK],
[IOP_CT_F_2CT -> "SUB", opcode::SUB],
[IOP_CT_F_2CT -> "SUBK", opcode::SUBK],
[IOP_CT_F_2CT -> "MUL", opcode::MUL],
[IOP_CT_F_2CT -> "MULF", opcode::MULF],
[IOP_CT_F_2CT -> "BW_AND", opcode::BW_AND],
[IOP_CT_F_2CT -> "BW_OR", opcode::BW_OR],
[IOP_CT_F_2CT -> "BW_XOR", opcode::BW_XOR],
[IOP_CMP -> "CMP_GT", opcode::CMP_GT],
[IOP_CMP -> "CMP_GTE", opcode::CMP_GTE],
[IOP_CMP -> "CMP_LT", opcode::CMP_LT],
[IOP_CMP -> "CMP_LTE", opcode::CMP_LTE],
[IOP_CMP -> "CMP_EQ", opcode::CMP_EQ],
[IOP_CMP -> "CMP_NEQ", opcode::CMP_NEQ],
[IOP_CT_CT_BOOL -> "IF_THEN_ELSE", opcode::IF_THEN_ELSE],
[IOP_CT_F_CT_BOOL -> "IF_THEN_ZERO", opcode::IF_THEN_ZERO],
[IOP_CT_F_2CT_BOOL -> "IF_THEN_ELSE", opcode::IF_THEN_ELSE],
[IOP_2CT_F_3CT -> "ERC_20", opcode::ERC_20],
);

View File

@@ -47,6 +47,15 @@ pub const CMP_LTE: u8 = 0xC3;
pub const CMP_EQ: u8 = 0xC4;
pub const CMP_NEQ: u8 = 0xC5;
// Ternary operations
// IfThenZero -> Select or force to 0
// Take 1Ct and a Boolean Ct as input
pub const IF_THEN_ZERO: u8 = 0xCA;
// IfThenElse -> Select operation
// Take 2Ct and a Boolean Ct as input
pub const IF_THEN_ELSE: u8 = 0xCA;
pub const IF_THEN_ELSE: u8 = 0xCB;
// Custom algorithm
// ERC20 -> Found xfer algorithm
// 2Ct <- func(3Ct)
pub const ERC_20: u8 = 0x80;

View File

@@ -42,8 +42,11 @@ crate::impl_fw!("Ilp" [
CMP_EQ => (|prog| {fw_impl::ilp::iop_cmp(prog, asm::dop::PbsCmpEq::default().into())});
CMP_NEQ => (|prog| {fw_impl::ilp::iop_cmp(prog, asm::dop::PbsCmpNeq::default().into())});
IF_THEN_ZERO => fw_impl::ilp::iop_if_then_zero;
IF_THEN_ELSE => fw_impl::ilp::iop_if_then_else;
ERC_20 => fw_impl::ilp::iop_erc_20;
]);
#[instrument(level = "trace", skip(prog))]
@@ -744,6 +747,21 @@ pub fn iop_cmp(prog: &mut Program, cmp_op: Pbs) {
"CMP_{cmp_op} Operand::Dst Operand::Src Operand::Src"
));
// Deferred implementation to generic cmpx function
iop_cmpx(prog, &mut dst[0], &src_a, &src_b, cmp_op);
}
/// Generic Cmp operation
/// One destination block and two sources operands
/// Source could be Operand or Immediat
#[instrument(level = "trace", skip(prog))]
pub fn iop_cmpx(
prog: &mut Program,
dst: &mut metavar::MetaVarCell,
src_a: &[metavar::MetaVarCell],
src_b: &[metavar::MetaVarCell],
cmp_op: Pbs,
) {
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
@@ -754,7 +772,7 @@ pub fn iop_cmp(prog: &mut Program, cmp_op: Pbs) {
let cmp_reduce = new_pbs!(prog, "CmpReduce");
// Pack A and B elements by pairs
let packed = std::iter::zip(src_a.as_slice().chunks(2), src_b.as_slice().chunks(2))
let packed = std::iter::zip(src_a.chunks(2), src_b.chunks(2))
.map(|(a, b)| {
let pack_a = if a.len() > 1 {
// Reset noise for future block merge through sub
@@ -807,7 +825,7 @@ pub fn iop_cmp(prog: &mut Program, cmp_op: Pbs) {
// interprete reduce with expected cmp
let cmp = reduce.unwrap().pbs(&cmp_op, false);
dst[0] <<= cmp;
dst.mv_assign(&cmp);
}
// For the kogge stone add/sub
@@ -1335,6 +1353,47 @@ impl VecVarCellDeg {
}
#[instrument(level = "trace", skip(prog))]
pub fn iop_if_then_zero(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let dst = prog.iop_template_var(OperandKind::Dst, 0);
// SrcA -> Operand
let src = prog.iop_template_var(OperandKind::Src, 0);
// Cond -> Operand
// second operand must be a FheBool and have only one blk
let cond = {
let mut cond_blk = prog.iop_template_var(OperandKind::Src, 1);
cond_blk.truncate(1);
cond_blk.pop().unwrap()
};
// Add Comment header
prog.push_comment("IF_THEN_ZERO Operand::Dst Operand::Src Operand::Src[Condition]".to_string());
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
// Wrapped required lookup table in MetaVar
let pbs_if_false_zeroed = new_pbs!(prog, "IfFalseZeroed");
itertools::izip!(dst, src)
.chunks(props.pbs_batch_w)
.into_iter()
.for_each(|chunk| {
// Pack (cond, src)
let chunk_pack = chunk
.into_iter()
.map(|(d, src)| (d, cond.mac(tfhe_params.msg_range() as u8, &src)))
.collect::<Vec<_>>();
chunk_pack.into_iter().for_each(|(mut d, mut cond_src)| {
cond_src.pbs_assign(&pbs_if_false_zeroed, false);
d <<= cond_src;
});
});
}
#[instrument(level = "info", skip(prog))]
pub fn iop_if_then_else(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
@@ -1387,3 +1446,120 @@ pub fn iop_if_then_else(prog: &mut Program) {
});
});
}
/// Implement erc_20 fund xfer
/// Targeted algorithm is as follow:
/// 1. Check that from has enough funds
/// 2. Compute real_amount to xfer (i.e. amount or 0)
/// 3. Compute new amount (from - new_amount, to + new_amount)
#[instrument(level = "info", skip(prog))]
pub fn iop_erc_20(prog: &mut Program) {
// Allocate metavariables:
// Dest -> Operand
let mut dst_from = prog.iop_template_var(OperandKind::Dst, 0);
let mut dst_to = prog.iop_template_var(OperandKind::Dst, 1);
// Src -> Operand
let src_from = prog.iop_template_var(OperandKind::Src, 0);
let src_to = prog.iop_template_var(OperandKind::Src, 1);
// Src Amount -> Operand
let src_amount = prog.iop_template_var(OperandKind::Src, 2);
// Add Comment header
prog.push_comment("ERC_20 (new_from, new_to) <- (from, to, amount)".to_string());
let props = prog.params();
let tfhe_params: asm::DigitParameters = props.clone().into();
// Wrapped required lookup table in MetaVar
let pbs_msg = new_pbs!(prog, "MsgOnly");
let pbs_carry = new_pbs!(prog, "CarryInMsg");
let pbs_if_false_zeroed = new_pbs!(prog, "IfFalseZeroed");
// Check if from has enough funds
let enough_fund = {
let mut dst = prog.new_var();
iop_cmpx(
prog,
&mut dst,
&src_from,
&src_amount,
asm::dop::PbsCmpGte::default().into(),
);
dst
};
// Fuse real_amount computation and new_from, new_to
// First compute a batch of real_amount in advance
let mut real_amount_work = (0..props.blk_w()).map(|x| x);
prog.push_comment(format!(" ==> Compute some real_amount in advance"));
let mut real_amount = real_amount_work
.by_ref()
.take(props.pbs_batch_w)
.map(|blk| {
let mut val_cond = enough_fund.mac(tfhe_params.msg_range() as u8, &src_amount[blk]);
val_cond.pbs_assign(&pbs_if_false_zeroed, false);
val_cond
})
.collect::<VecDeque<_>>();
let mut add_carry: Option<metavar::MetaVarCell> = None;
let mut sub_z_cor: Option<usize> = None;
let mut sub_carry: Option<metavar::MetaVarCell> = None;
(0..prog.params().blk_w()).for_each(|blk| {
prog.push_comment(format!(" ==> Work on output block {blk}"));
// Compte next real_amount if any
if let Some(work) = real_amount_work.next() {
let mut val_cond = enough_fund.mac(tfhe_params.msg_range() as u8, &src_amount[work]);
val_cond.pbs_assign(&pbs_if_false_zeroed, false);
real_amount.push_back(val_cond);
}
let amount_blk = real_amount.pop_front().unwrap();
// Add
let mut add_msg = &src_to[blk] + &amount_blk;
if let Some(cin) = &add_carry {
add_msg += cin.clone();
}
if blk < (props.blk_w() - 1) {
add_carry = Some(add_msg.pbs(&pbs_carry, false));
}
// Force allocation of new reg to allow carry/msg pbs to run in //
let add_msg = add_msg.pbs(&pbs_msg, false);
// Sub
// Compute -b
let neg_from = if let Some(z) = &sub_z_cor {
prog.new_imm(tfhe_params.msg_range() - *z)
} else {
prog.new_imm(tfhe_params.msg_range())
};
let amount_neg = &neg_from - &amount_blk;
sub_z_cor = Some(
amount_blk
.get_degree()
.div_ceil(tfhe_params.msg_range())
.max(1),
);
// Compute a + (-b)
let mut sub_msg = &src_from[blk] + &amount_neg;
// Handle input/output carry and extract msg
if let Some(cin) = &sub_carry {
sub_msg += cin.clone();
}
if blk < (props.blk_w() - 1) {
sub_carry = Some(sub_msg.pbs(&pbs_carry, false));
}
// Force allocation of new reg to allow carry/msg pbs to run in //
let sub_msg = sub_msg.pbs(&pbs_msg, false);
// Store result
dst_to[blk] <<= add_msg;
dst_from[blk] <<= sub_msg;
});
}

View File

@@ -132,20 +132,18 @@ macro_rules! hpu_testcase {
let res_fhe = res_hpu
.iter()
.map(|x| x.to_radix_ciphertext()).collect::<Vec<_>>();
let res_vec = res_fhe
let res = res_fhe
.iter()
.map(|x| cks.decrypt_radix(x))
.collect::<Vec<$user_type>>();
let res = res_vec[0];
let exp_res = {
let $ct = &srcs_clear;
let $imm = imms.iter().map(|x| *x as $user_type).collect::<Vec<_>>();
($behav as $user_type)
($behav.iter().map(|x| *x as $user_type).collect::<Vec<_>>())
};
println!("{:>8} <{:>8x?}> <{:>8x?}> => {:<8x} [exp {:<8x}] {{Delta: 0b {:b} }}", iop, srcs_clear, imms, res, exp_res, res ^ exp_res);
res == exp_res
println!("{:>8} <{:>8x?}> <{:>8x?}> => {:<8x?} [exp {:<8x?}] {{Delta: {:x?} }}", iop, srcs_clear, imms, res, exp_res, std::iter::zip(res.iter(), exp_res.iter()).map(|(x,y)| x ^y).collect::<Vec<_>>());
std::iter::zip(res.iter(), exp_res.iter()).map(|(x,y)| x== y).fold(true, |acc, val| acc & val)
}).fold(true, |acc, val| acc & val)
}
)*
@@ -156,51 +154,67 @@ macro_rules! hpu_testcase {
// Define testcase implementation for all supported IOp
// Alu IOp with Ct x Imm
hpu_testcase!("ADDS" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0].wrapping_add(imm[0])));
|ct, imm| vec![ct[0].wrapping_add(imm[0])]);
hpu_testcase!("SUBS" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0].wrapping_sub(imm[0])));
|ct, imm| vec![ct[0].wrapping_sub(imm[0])]);
hpu_testcase!("SSUB" => [u8, u16, u32, u64, u128]
|ct, imm| (imm[0].wrapping_sub(ct[0])));
|ct, imm| vec![imm[0].wrapping_sub(ct[0])]);
hpu_testcase!("MULS" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0].wrapping_mul(imm[0])));
|ct, imm| vec![ct[0].wrapping_mul(imm[0])]);
// Alu IOp with Ct x Ct
hpu_testcase!("ADD" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0].wrapping_add(ct[1])));
|ct, imm| vec![ct[0].wrapping_add(ct[1])]);
hpu_testcase!("ADDK" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0].wrapping_add(ct[1])));
|ct, imm| vec![ct[0].wrapping_add(ct[1])]);
hpu_testcase!("SUB" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0].wrapping_sub(ct[1])));
|ct, imm| vec![ct[0].wrapping_sub(ct[1])]);
hpu_testcase!("SUBK" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0].wrapping_sub(ct[1])));
|ct, imm| vec![ct[0].wrapping_sub(ct[1])]);
hpu_testcase!("MUL" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0].wrapping_mul(ct[1])));
|ct, imm| vec![ct[0].wrapping_mul(ct[1])]);
// Bitwise IOp
hpu_testcase!("BW_AND" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] & ct[1]));
|ct, imm| vec![ct[0] & ct[1]]);
hpu_testcase!("BW_OR" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] | ct[1]));
|ct, imm| vec![ct[0] | ct[1]]);
hpu_testcase!("BW_XOR" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] ^ ct[1]));
|ct, imm| vec![ct[0] ^ ct[1]]);
// Comparaison IOp
hpu_testcase!("CMP_GT" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] > ct[1]));
|ct, imm| vec![ct[0] > ct[1]]);
hpu_testcase!("CMP_GTE" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] >= ct[1]));
|ct, imm| vec![ct[0] >= ct[1]]);
hpu_testcase!("CMP_LT" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] < ct[1]));
|ct, imm| vec![ct[0] < ct[1]]);
hpu_testcase!("CMP_LTE" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] <= ct[1]));
|ct, imm| vec![ct[0] <= ct[1]]);
hpu_testcase!("CMP_EQ" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] == ct[1]));
|ct, imm| vec![ct[0] == ct[1]]);
hpu_testcase!("CMP_NEQ" => [u8, u16, u32, u64, u128]
|ct, imm| (ct[0] != ct[1]));
|ct, imm| vec![ct[0] != ct[1]]);
// Ternary IOp
hpu_testcase!("IF_THEN_ZERO" => [u8, u16, u32, u64, u128]
|ct, imm| vec![if ct[1] != 0 {ct[0]} else { 0}]);
hpu_testcase!("IF_THEN_ELSE" => [u8, u16, u32, u64, u128]
|ct, imm| if ct[2] != 0 {ct[0]} else { ct[1]});
|ct, imm| vec![if ct[2] != 0 {ct[0]} else { ct[1]}]);
// ERC 20 found xfer
hpu_testcase!("ERC_20" => [u8, u16, u32, u64, u128]
|ct, imm| {
let from = ct[0];
let to = ct[1];
let amount = ct[2];
// TODO enhance this to prevent overflow
if from >= amount {
vec![from - amount, to.wrapping_add(amount)]
} else {
vec![from, to]
}
});
// Define a set of test bundle for various size
// 8bit ciphertext -----------------------------------------
@@ -240,9 +254,15 @@ crate::hpu_testbundle!("cmp"::8 => [
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::8 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::8 => [
"erc_20"
]);
// 16bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::16 => [
@@ -280,9 +300,15 @@ crate::hpu_testbundle!("cmp"::16 => [
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::16 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::16 => [
"erc_20"
]);
// 32bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::32 => [
@@ -320,9 +346,15 @@ crate::hpu_testbundle!("cmp"::32 => [
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::32 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::32 => [
"erc_20"
]);
// 64bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::64 => [
@@ -360,9 +392,15 @@ crate::hpu_testbundle!("cmp"::64 => [
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::64 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::64 => [
"erc_20"
]);
// 128bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::128 => [
@@ -400,5 +438,11 @@ crate::hpu_testbundle!("cmp"::128 => [
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::128 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::128 => [
"erc_20"
]);