refactor(fhe_strings): add len function

This commit is contained in:
Mayeul@Zama
2024-10-17 13:34:25 +02:00
committed by mayeul-zama
parent aebc2619b2
commit 27e34a835c
10 changed files with 36 additions and 32 deletions

View File

@@ -170,6 +170,10 @@ impl FheString {
self.padded = true;
}
pub fn len(&self) -> usize {
self.chars().len()
}
pub fn empty() -> FheString {
FheString {
enc_string: vec![],

View File

@@ -4,8 +4,8 @@ use tfhe::integer::BooleanBlock;
impl ServerKey {
fn eq_length_checks(&self, lhs: &FheString, rhs: &FheString) -> Option<BooleanBlock> {
let lhs_len = lhs.chars().len();
let rhs_len = rhs.chars().len();
let lhs_len = lhs.len();
let rhs_len = rhs.len();
// If lhs is empty, rhs must also be empty in order to be equal (the case where lhs is
// empty with > 1 padding zeros is handled next)
@@ -28,14 +28,14 @@ impl ServerKey {
}
// Two strings without padding that have different lengths cannot be equal
if (!lhs.is_padded() && !rhs.is_padded()) && (lhs.chars().len() != rhs.chars().len()) {
if (!lhs.is_padded() && !rhs.is_padded()) && (lhs.len() != rhs.len()) {
return Some(self.key.create_trivial_boolean_block(false));
}
// A string without padding cannot be equal to a string with padding that has the same or
// lower length
if (!lhs.is_padded() && rhs.is_padded()) && (rhs.chars().len() <= lhs.chars().len())
|| (!rhs.is_padded() && lhs.is_padded()) && (lhs.chars().len() <= rhs.chars().len())
if (!lhs.is_padded() && rhs.is_padded()) && (rhs.len() <= lhs.len())
|| (!rhs.is_padded() && lhs.is_padded()) && (lhs.len() <= rhs.len())
{
return Some(self.key.create_trivial_boolean_block(false));
}

View File

@@ -260,7 +260,7 @@ impl ServerKey {
// If the shifting amount is >= than the str length we get zero i.e. all chars are out of
// range (instead of wrapping, which is the behavior of Rust and tfhe-rs)
let bit_len = (str.chars().len() * 8) as u32;
let bit_len = (str.len() * 8) as u32;
let shift_ge_than_str = self.key.scalar_ge_parallelized(&shift_bits, bit_len);
let result = self.key.if_then_else_parallelized(
@@ -283,7 +283,7 @@ impl ServerKey {
// If the shifting amount is >= than the str length we get zero i.e. all chars are out of
// range (instead of wrapping, which is the behavior of Rust and tfhe-rs)
let bit_len = (str.chars().len() * 8) as u32;
let bit_len = (str.len() * 8) as u32;
let shift_ge_than_str = self.key.scalar_ge_parallelized(&shift_bits, bit_len);
let result = self.key.if_then_else_parallelized(

View File

@@ -59,7 +59,7 @@ impl ServerKey {
FheStringLen::Padding(len)
} else {
FheStringLen::NoPadding(str.chars().len())
FheStringLen::NoPadding(str.len())
}
}
@@ -101,7 +101,7 @@ impl ServerKey {
/// ```
pub fn is_empty(&self, str: &FheString) -> FheStringIsEmpty {
if str.is_padded() {
if str.chars().len() == 1 {
if str.len() == 1 {
return FheStringIsEmpty::Padding(self.key.create_trivial_boolean_block(true));
}
@@ -110,7 +110,7 @@ impl ServerKey {
FheStringIsEmpty::Padding(result)
} else {
FheStringIsEmpty::NoPadding(str.chars().is_empty())
FheStringIsEmpty::NoPadding(str.len() == 0)
}
}
@@ -295,7 +295,7 @@ impl ServerKey {
// If lhs is padded we can shift it right such that all nulls move to the start, then
// we append the rhs and shift it left again to move the nulls to the new end
FheStringLen::Padding(len) => {
let padded_len = self.key.create_trivial_radix(lhs.chars().len() as u32, 16);
let padded_len = self.key.create_trivial_radix(lhs.len() as u32, 16);
let number_of_nulls = self.key.sub_parallelized(&padded_len, &len);
result = self.right_shift_chars(&result, &number_of_nulls);
@@ -348,7 +348,7 @@ impl ServerKey {
return FheString::empty();
}
let str_len = str.chars().len();
let str_len = str.len();
if str_len == 0 || (str.is_padded() && str_len == 1) {
return FheString::empty();
}

View File

@@ -186,8 +186,8 @@ impl ServerKey {
};
}
let str_len = str.chars().len();
let pat_len = trivial_or_enc_pat.chars().len();
let str_len = str.len();
let pat_len = trivial_or_enc_pat.len();
// In the padded pattern case we can remove the last char (as it's always null)
let pat_chars = &trivial_or_enc_pat.chars()[..pat_len - 1];

View File

@@ -224,7 +224,7 @@ impl ServerKey {
let ignore_pat_pad = trivial_or_enc_pat.is_padded();
let str_len = str.chars().len();
let str_len = str.len();
let (null, ext_iter) = if !str.is_padded() && trivial_or_enc_pat.is_padded() {
(Some(FheAsciiChar::null(self)), Some(0..str_len + 1))
} else {

View File

@@ -21,8 +21,8 @@ enum IsMatch {
// methods below contain logic for the different cases
impl ServerKey {
fn length_checks(&self, str: &FheString, pat: &FheString) -> IsMatch {
let pat_len = pat.chars().len();
let str_len = str.chars().len();
let pat_len = pat.len();
let str_len = str.len();
// If the pattern is empty it will match any string, this is the behavior of core::str
// Note that this doesn't handle the case where pattern is empty and has > 1 padding zeros
@@ -62,8 +62,8 @@ impl ServerKey {
pat: &'a FheString,
null: Option<&'a FheAsciiChar>,
) -> (CharIter<'a>, CharIter<'a>, Range<usize>) {
let pat_len = pat.chars().len();
let str_len = str.chars().len();
let pat_len = pat.len();
let str_len = str.len();
match (str.is_padded(), pat.is_padded()) {
// If neither has padding we just check if pat matches the `pat_len` last chars or str
@@ -140,7 +140,7 @@ impl ServerKey {
pat: &str,
) -> (CharIter<'a>, String, Range<usize>) {
let pat_len = pat.len();
let str_len = str.chars().len();
let str_len = str.len();
if str.is_padded() {
let str_chars = str.chars()[..str_len - 1].iter();
@@ -166,8 +166,8 @@ impl ServerKey {
pat: &'a FheString,
null: Option<&'a FheAsciiChar>,
) -> (CharIter<'a>, CharIter<'a>, Range<usize>) {
let pat_len = pat.chars().len();
let str_len = str.chars().len();
let pat_len = pat.len();
let str_len = str.len();
match (str.is_padded(), pat.is_padded()) {
(_, false) => {

View File

@@ -30,7 +30,7 @@ impl ServerKey {
let (mut replaced, rhs) = rayon::join(
|| {
let str_len = self.key.create_trivial_radix(str.chars().len() as u32, 16);
let str_len = self.key.create_trivial_radix(str.len() as u32, 16);
// Get the [lhs] shifting right by [from, rhs].len()
let shift_right = self.key.sub_parallelized(&str_len, find_index);
@@ -200,7 +200,7 @@ impl ServerKey {
}
fn max_matches(&self, str: &FheString, pat: &FheString) -> u16 {
let str_len = str.chars().len() - if str.is_padded() { 1 } else { 0 };
let str_len = str.len() - if str.is_padded() { 1 } else { 0 };
// Max number of matches is str_len + 1 when pattern is empty
let mut max: u16 = (str_len + 1).try_into().expect("str should be shorter");
@@ -209,7 +209,7 @@ impl ServerKey {
// str_len - pat_len + 1. For instance "xx" matches "xxxx" at most 4 - 2 + 1 = 3 times.
// This works as long as str_len >= pat_len (guaranteed due to the outer length checks)
if !pat.is_padded() {
let pat_len = pat.chars().len() as u16;
let pat_len = pat.len() as u16;
max = str_len as u16 - pat_len + 1;
}
@@ -282,7 +282,7 @@ impl ServerKey {
IsMatch::Clear(true) => {
// If `from` is empty and str too, there's only one match and one replacement
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
if let UIntArg::Clear(_) = count {
return to.clone();
}
@@ -390,7 +390,7 @@ impl ServerKey {
IsMatch::Clear(false) => return result,
IsMatch::Clear(true) => {
// If `from` is empty and str too, there's only one match and one replacement
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
return to.clone();
}
}

View File

@@ -13,7 +13,7 @@ impl ServerKey {
index: &RadixCiphertext,
inclusive: bool,
) -> (FheString, FheString) {
let str_len = self.key.create_trivial_radix(str.chars().len() as u32, 16);
let str_len = self.key.create_trivial_radix(str.len() as u32, 16);
let trivial_or_enc_pat = match pat {
GenericPattern::Clear(pat) => FheString::trivial(self, pat.str()),
GenericPattern::Enc(pat) => pat.clone(),

View File

@@ -11,7 +11,7 @@ pub struct SplitAsciiWhitespace {
impl FheStringIterator for SplitAsciiWhitespace {
fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) {
let str_len = self.state.chars().len();
let str_len = self.state.len();
if str_len == 0 || (self.state.is_padded() && str_len == 1) {
return (
@@ -206,7 +206,7 @@ impl ServerKey {
pub fn trim_start(&self, str: &FheString) -> FheString {
let mut result = str.clone();
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
return result;
}
@@ -262,7 +262,7 @@ impl ServerKey {
pub fn trim_end(&self, str: &FheString) -> FheString {
let mut result = str.clone();
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
return result;
}
@@ -300,7 +300,7 @@ impl ServerKey {
/// assert_eq!(trimmed, "hello world"); // Whitespace at both ends is removed
/// ```
pub fn trim(&self, str: &FheString) -> FheString {
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
return str.clone();
}