mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
288 lines
8.0 KiB
Rust
288 lines
8.0 KiB
Rust
use crate::parser::u8_to_char;
|
|
use log::{debug, trace};
|
|
use std::collections::HashMap;
|
|
use std::rc::Rc;
|
|
use tfhe::integer::{IntegerCiphertext, RadixCiphertext, ServerKey};
|
|
|
|
/// Symbolic description of an expression evaluated by `Execution`.
///
/// Every ciphertext result is paired with one of these, and it is the key of the
/// evaluation cache, so structurally equal expressions map to the same cached ciphertext.
#[derive(Clone, PartialEq, Eq, Hash)]
pub(crate) enum Executed {
    /// A plaintext byte `c` (turned into a trivial ciphertext by `Execution::ct_constant`).
    Constant { c: u8 },
    /// A ciphertext identified by its position `at` (presumably an index into the
    /// encrypted input — TODO confirm against the caller).
    CtPos { at: usize },
    /// Conjunction of two sub-expressions.
    And {
        a: Box<Self>,
        b: Box<Self>,
    },
    /// Disjunction of two sub-expressions.
    Or {
        a: Box<Self>,
        b: Box<Self>,
    },
    /// Equality test `a == b`.
    Equal {
        a: Box<Self>,
        b: Box<Self>,
    },
    /// Non-strict comparison `a >= b`.
    GreaterOrEqual {
        a: Box<Self>,
        b: Box<Self>,
    },
    /// Non-strict comparison `a <= b`.
    LessOrEqual {
        a: Box<Self>,
        b: Box<Self>,
    },
    /// Negation of a sub-expression.
    Not {
        a: Box<Self>,
    },
}
|
|
/// A ciphertext paired with the symbolic expression ([`Executed`]) that produced it.
type ExecutedResult = (RadixCiphertext, Executed);
|
|
|
|
impl Executed {
|
|
pub(crate) fn ct_pos(at: usize) -> Self {
|
|
Executed::CtPos { at }
|
|
}
|
|
|
|
fn get_trivial_constant(&self) -> Option<u8> {
|
|
match self {
|
|
Self::Constant { c } => Some(*c),
|
|
_ => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Plaintext value that encodes boolean `false` (rendered as `f` by `Executed`'s `Debug`).
const CT_FALSE: u8 = 0b0;
/// Plaintext value that encodes boolean `true` (rendered as `t` by `Executed`'s `Debug`).
const CT_TRUE: u8 = 0b1;
|
|
|
|
pub(crate) struct Execution {
|
|
sk: ServerKey,
|
|
cache: HashMap<Executed, RadixCiphertext>,
|
|
|
|
ct_ops: usize,
|
|
cache_hits: usize,
|
|
}
|
|
/// Deferred computation of an [`ExecutedResult`] against an [`Execution`].
///
/// `Execution::with_cache` invokes it only when the result is not already cached.
pub(crate) type LazyExecution = Rc<dyn Fn(&mut Execution) -> ExecutedResult>;
|
|
|
|
impl Execution {
|
|
pub(crate) fn new(sk: ServerKey) -> Self {
|
|
Self {
|
|
sk,
|
|
cache: HashMap::new(),
|
|
ct_ops: 0,
|
|
cache_hits: 0,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn ct_operations_count(&self) -> usize {
|
|
self.ct_ops
|
|
}
|
|
|
|
pub(crate) fn cache_hits(&self) -> usize {
|
|
self.cache_hits
|
|
}
|
|
|
|
pub(crate) fn ct_eq(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
|
let ctx = Executed::Equal {
|
|
a: Box::new(a.1.clone()),
|
|
b: Box::new(b.1.clone()),
|
|
};
|
|
self.with_cache(
|
|
ctx.clone(),
|
|
Rc::new(move |exec: &mut Execution| {
|
|
exec.ct_ops += 1;
|
|
|
|
let mut ct_a = a.0.clone();
|
|
let mut ct_b = b.0.clone();
|
|
(
|
|
exec.sk
|
|
.smart_eq(&mut ct_a, &mut ct_b)
|
|
.into_radix(ct_a.blocks().len(), &exec.sk),
|
|
ctx.clone(),
|
|
)
|
|
}),
|
|
)
|
|
}
|
|
|
|
pub(crate) fn ct_ge(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
|
let ctx = Executed::GreaterOrEqual {
|
|
a: Box::new(a.1.clone()),
|
|
b: Box::new(b.1.clone()),
|
|
};
|
|
self.with_cache(
|
|
ctx.clone(),
|
|
Rc::new(move |exec| {
|
|
exec.ct_ops += 1;
|
|
|
|
let mut ct_a = a.0.clone();
|
|
let mut ct_b = b.0.clone();
|
|
(
|
|
exec.sk
|
|
.smart_gt(&mut ct_a, &mut ct_b)
|
|
.into_radix(ct_a.blocks().len(), &exec.sk),
|
|
ctx.clone(),
|
|
)
|
|
}),
|
|
)
|
|
}
|
|
|
|
pub(crate) fn ct_le(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
|
let ctx = Executed::LessOrEqual {
|
|
a: Box::new(a.1.clone()),
|
|
b: Box::new(b.1.clone()),
|
|
};
|
|
self.with_cache(
|
|
ctx.clone(),
|
|
Rc::new(move |exec| {
|
|
exec.ct_ops += 1;
|
|
|
|
let mut ct_a = a.0.clone();
|
|
let mut ct_b = b.0.clone();
|
|
(
|
|
exec.sk
|
|
.smart_le(&mut ct_a, &mut ct_b)
|
|
.into_radix(ct_a.blocks().len(), &exec.sk),
|
|
ctx.clone(),
|
|
)
|
|
}),
|
|
)
|
|
}
|
|
|
|
pub(crate) fn ct_and(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
|
let ctx = Executed::And {
|
|
a: Box::new(a.1.clone()),
|
|
b: Box::new(b.1.clone()),
|
|
};
|
|
|
|
let c_a = a.1.get_trivial_constant();
|
|
let c_b = b.1.get_trivial_constant();
|
|
if c_a == Some(CT_TRUE) {
|
|
return (b.0, ctx);
|
|
}
|
|
if c_a == Some(CT_FALSE) {
|
|
return (a.0, ctx);
|
|
}
|
|
if c_b == Some(CT_TRUE) {
|
|
return (a.0, ctx);
|
|
}
|
|
if c_b == Some(CT_FALSE) {
|
|
return (b.0, ctx);
|
|
}
|
|
|
|
self.with_cache(
|
|
ctx.clone(),
|
|
Rc::new(move |exec| {
|
|
exec.ct_ops += 1;
|
|
|
|
let mut ct_a = a.0.clone();
|
|
let mut ct_b = b.0.clone();
|
|
(exec.sk.smart_bitand(&mut ct_a, &mut ct_b), ctx.clone())
|
|
}),
|
|
)
|
|
}
|
|
|
|
pub(crate) fn ct_or(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
|
let ctx = Executed::Or {
|
|
a: Box::new(a.1.clone()),
|
|
b: Box::new(b.1.clone()),
|
|
};
|
|
|
|
let c_a = a.1.get_trivial_constant();
|
|
let c_b = b.1.get_trivial_constant();
|
|
if c_a == Some(CT_TRUE) {
|
|
return (a.0, ctx);
|
|
}
|
|
if c_b == Some(CT_TRUE) {
|
|
return (b.0, ctx);
|
|
}
|
|
if c_a == Some(CT_FALSE) && c_b == Some(CT_FALSE) {
|
|
return (a.0, ctx);
|
|
}
|
|
|
|
self.with_cache(
|
|
ctx.clone(),
|
|
Rc::new(move |exec| {
|
|
exec.ct_ops += 1;
|
|
|
|
let mut ct_a = a.0.clone();
|
|
let mut ct_b = b.0.clone();
|
|
(exec.sk.smart_bitor(&mut ct_a, &mut ct_b), ctx.clone())
|
|
}),
|
|
)
|
|
}
|
|
|
|
pub(crate) fn ct_not(&mut self, a: ExecutedResult) -> ExecutedResult {
|
|
let ctx = Executed::Not {
|
|
a: Box::new(a.1.clone()),
|
|
};
|
|
self.with_cache(
|
|
ctx.clone(),
|
|
Rc::new(move |exec| {
|
|
exec.ct_ops += 1;
|
|
|
|
let mut ct_a = a.0.clone();
|
|
let mut ct_b = exec.ct_constant(1).0;
|
|
(exec.sk.smart_bitxor(&mut ct_a, &mut ct_b), ctx.clone())
|
|
}),
|
|
)
|
|
}
|
|
|
|
pub(crate) fn ct_false(&self) -> ExecutedResult {
|
|
self.ct_constant(CT_FALSE)
|
|
}
|
|
|
|
pub(crate) fn ct_true(&self) -> ExecutedResult {
|
|
self.ct_constant(CT_TRUE)
|
|
}
|
|
|
|
pub(crate) fn ct_constant(&self, c: u8) -> ExecutedResult {
|
|
(
|
|
self.sk.create_trivial_radix(c as u64, 4),
|
|
Executed::Constant { c },
|
|
)
|
|
}
|
|
|
|
fn with_cache(&mut self, ctx: Executed, f: LazyExecution) -> ExecutedResult {
|
|
if let Some(res) = self.cache.get(&ctx) {
|
|
trace!("cache hit: {:?}", &ctx);
|
|
self.cache_hits += 1;
|
|
return (res.clone(), ctx);
|
|
}
|
|
debug!("evaluation for: {:?}", &ctx);
|
|
let res = f(self);
|
|
self.cache.insert(ctx, res.0.clone());
|
|
res
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Debug for Executed {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
match self {
|
|
Self::Constant { c } => match c {
|
|
0 => write!(f, "f"),
|
|
1 => write!(f, "t"),
|
|
_ => write!(f, "{}", u8_to_char(*c)),
|
|
},
|
|
Self::CtPos { at } => write!(f, "ct_{at}"),
|
|
Self::And { a, b } => {
|
|
write!(f, "(")?;
|
|
a.fmt(f)?;
|
|
write!(f, "/\\")?;
|
|
b.fmt(f)?;
|
|
write!(f, ")")
|
|
}
|
|
Self::Or { a, b } => {
|
|
write!(f, "(")?;
|
|
a.fmt(f)?;
|
|
write!(f, "\\/")?;
|
|
b.fmt(f)?;
|
|
write!(f, ")")
|
|
}
|
|
Self::Equal { a, b } => {
|
|
write!(f, "(")?;
|
|
a.fmt(f)?;
|
|
write!(f, "==")?;
|
|
b.fmt(f)?;
|
|
write!(f, ")")
|
|
}
|
|
Self::GreaterOrEqual { a, b } => {
|
|
write!(f, "(")?;
|
|
a.fmt(f)?;
|
|
write!(f, ">=")?;
|
|
b.fmt(f)?;
|
|
write!(f, ")")
|
|
}
|
|
Self::LessOrEqual { a, b } => {
|
|
write!(f, "(")?;
|
|
a.fmt(f)?;
|
|
write!(f, "<=")?;
|
|
b.fmt(f)?;
|
|
write!(f, ")")
|
|
}
|
|
Self::Not { a } => {
|
|
write!(f, "(!")?;
|
|
a.fmt(f)?;
|
|
write!(f, ")")
|
|
}
|
|
}
|
|
}
|
|
}
|