mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
refactor(fhe_strings): add len function
This commit is contained in:
@@ -170,6 +170,10 @@ impl FheString {
|
||||
self.padded = true;
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.chars().len()
|
||||
}
|
||||
|
||||
pub fn empty() -> FheString {
|
||||
FheString {
|
||||
enc_string: vec![],
|
||||
|
||||
@@ -4,8 +4,8 @@ use tfhe::integer::BooleanBlock;
|
||||
|
||||
impl ServerKey {
|
||||
fn eq_length_checks(&self, lhs: &FheString, rhs: &FheString) -> Option<BooleanBlock> {
|
||||
let lhs_len = lhs.chars().len();
|
||||
let rhs_len = rhs.chars().len();
|
||||
let lhs_len = lhs.len();
|
||||
let rhs_len = rhs.len();
|
||||
|
||||
// If lhs is empty, rhs must also be empty in order to be equal (the case where lhs is
|
||||
// empty with > 1 padding zeros is handled next)
|
||||
@@ -28,14 +28,14 @@ impl ServerKey {
|
||||
}
|
||||
|
||||
// Two strings without padding that have different lengths cannot be equal
|
||||
if (!lhs.is_padded() && !rhs.is_padded()) && (lhs.chars().len() != rhs.chars().len()) {
|
||||
if (!lhs.is_padded() && !rhs.is_padded()) && (lhs.len() != rhs.len()) {
|
||||
return Some(self.key.create_trivial_boolean_block(false));
|
||||
}
|
||||
|
||||
// A string without padding cannot be equal to a string with padding that has the same or
|
||||
// lower length
|
||||
if (!lhs.is_padded() && rhs.is_padded()) && (rhs.chars().len() <= lhs.chars().len())
|
||||
|| (!rhs.is_padded() && lhs.is_padded()) && (lhs.chars().len() <= rhs.chars().len())
|
||||
if (!lhs.is_padded() && rhs.is_padded()) && (rhs.len() <= lhs.len())
|
||||
|| (!rhs.is_padded() && lhs.is_padded()) && (lhs.len() <= rhs.len())
|
||||
{
|
||||
return Some(self.key.create_trivial_boolean_block(false));
|
||||
}
|
||||
|
||||
@@ -260,7 +260,7 @@ impl ServerKey {
|
||||
|
||||
// If the shifting amount is >= than the str length we get zero i.e. all chars are out of
|
||||
// range (instead of wrapping, which is the behavior of Rust and tfhe-rs)
|
||||
let bit_len = (str.chars().len() * 8) as u32;
|
||||
let bit_len = (str.len() * 8) as u32;
|
||||
let shift_ge_than_str = self.key.scalar_ge_parallelized(&shift_bits, bit_len);
|
||||
|
||||
let result = self.key.if_then_else_parallelized(
|
||||
@@ -283,7 +283,7 @@ impl ServerKey {
|
||||
|
||||
// If the shifting amount is >= than the str length we get zero i.e. all chars are out of
|
||||
// range (instead of wrapping, which is the behavior of Rust and tfhe-rs)
|
||||
let bit_len = (str.chars().len() * 8) as u32;
|
||||
let bit_len = (str.len() * 8) as u32;
|
||||
let shift_ge_than_str = self.key.scalar_ge_parallelized(&shift_bits, bit_len);
|
||||
|
||||
let result = self.key.if_then_else_parallelized(
|
||||
|
||||
@@ -59,7 +59,7 @@ impl ServerKey {
|
||||
|
||||
FheStringLen::Padding(len)
|
||||
} else {
|
||||
FheStringLen::NoPadding(str.chars().len())
|
||||
FheStringLen::NoPadding(str.len())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,7 +101,7 @@ impl ServerKey {
|
||||
/// ```
|
||||
pub fn is_empty(&self, str: &FheString) -> FheStringIsEmpty {
|
||||
if str.is_padded() {
|
||||
if str.chars().len() == 1 {
|
||||
if str.len() == 1 {
|
||||
return FheStringIsEmpty::Padding(self.key.create_trivial_boolean_block(true));
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ impl ServerKey {
|
||||
|
||||
FheStringIsEmpty::Padding(result)
|
||||
} else {
|
||||
FheStringIsEmpty::NoPadding(str.chars().is_empty())
|
||||
FheStringIsEmpty::NoPadding(str.len() == 0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -295,7 +295,7 @@ impl ServerKey {
|
||||
// If lhs is padded we can shift it right such that all nulls move to the start, then
|
||||
// we append the rhs and shift it left again to move the nulls to the new end
|
||||
FheStringLen::Padding(len) => {
|
||||
let padded_len = self.key.create_trivial_radix(lhs.chars().len() as u32, 16);
|
||||
let padded_len = self.key.create_trivial_radix(lhs.len() as u32, 16);
|
||||
let number_of_nulls = self.key.sub_parallelized(&padded_len, &len);
|
||||
|
||||
result = self.right_shift_chars(&result, &number_of_nulls);
|
||||
@@ -348,7 +348,7 @@ impl ServerKey {
|
||||
return FheString::empty();
|
||||
}
|
||||
|
||||
let str_len = str.chars().len();
|
||||
let str_len = str.len();
|
||||
if str_len == 0 || (str.is_padded() && str_len == 1) {
|
||||
return FheString::empty();
|
||||
}
|
||||
|
||||
@@ -186,8 +186,8 @@ impl ServerKey {
|
||||
};
|
||||
}
|
||||
|
||||
let str_len = str.chars().len();
|
||||
let pat_len = trivial_or_enc_pat.chars().len();
|
||||
let str_len = str.len();
|
||||
let pat_len = trivial_or_enc_pat.len();
|
||||
|
||||
// In the padded pattern case we can remove the last char (as it's always null)
|
||||
let pat_chars = &trivial_or_enc_pat.chars()[..pat_len - 1];
|
||||
|
||||
@@ -224,7 +224,7 @@ impl ServerKey {
|
||||
|
||||
let ignore_pat_pad = trivial_or_enc_pat.is_padded();
|
||||
|
||||
let str_len = str.chars().len();
|
||||
let str_len = str.len();
|
||||
let (null, ext_iter) = if !str.is_padded() && trivial_or_enc_pat.is_padded() {
|
||||
(Some(FheAsciiChar::null(self)), Some(0..str_len + 1))
|
||||
} else {
|
||||
|
||||
@@ -21,8 +21,8 @@ enum IsMatch {
|
||||
// methods below contain logic for the different cases
|
||||
impl ServerKey {
|
||||
fn length_checks(&self, str: &FheString, pat: &FheString) -> IsMatch {
|
||||
let pat_len = pat.chars().len();
|
||||
let str_len = str.chars().len();
|
||||
let pat_len = pat.len();
|
||||
let str_len = str.len();
|
||||
|
||||
// If the pattern is empty it will match any string, this is the behavior of core::str
|
||||
// Note that this doesn't handle the case where pattern is empty and has > 1 padding zeros
|
||||
@@ -62,8 +62,8 @@ impl ServerKey {
|
||||
pat: &'a FheString,
|
||||
null: Option<&'a FheAsciiChar>,
|
||||
) -> (CharIter<'a>, CharIter<'a>, Range<usize>) {
|
||||
let pat_len = pat.chars().len();
|
||||
let str_len = str.chars().len();
|
||||
let pat_len = pat.len();
|
||||
let str_len = str.len();
|
||||
|
||||
match (str.is_padded(), pat.is_padded()) {
|
||||
// If neither has padding we just check if pat matches the `pat_len` last chars or str
|
||||
@@ -140,7 +140,7 @@ impl ServerKey {
|
||||
pat: &str,
|
||||
) -> (CharIter<'a>, String, Range<usize>) {
|
||||
let pat_len = pat.len();
|
||||
let str_len = str.chars().len();
|
||||
let str_len = str.len();
|
||||
|
||||
if str.is_padded() {
|
||||
let str_chars = str.chars()[..str_len - 1].iter();
|
||||
@@ -166,8 +166,8 @@ impl ServerKey {
|
||||
pat: &'a FheString,
|
||||
null: Option<&'a FheAsciiChar>,
|
||||
) -> (CharIter<'a>, CharIter<'a>, Range<usize>) {
|
||||
let pat_len = pat.chars().len();
|
||||
let str_len = str.chars().len();
|
||||
let pat_len = pat.len();
|
||||
let str_len = str.len();
|
||||
|
||||
match (str.is_padded(), pat.is_padded()) {
|
||||
(_, false) => {
|
||||
|
||||
@@ -30,7 +30,7 @@ impl ServerKey {
|
||||
|
||||
let (mut replaced, rhs) = rayon::join(
|
||||
|| {
|
||||
let str_len = self.key.create_trivial_radix(str.chars().len() as u32, 16);
|
||||
let str_len = self.key.create_trivial_radix(str.len() as u32, 16);
|
||||
|
||||
// Get the [lhs] shifting right by [from, rhs].len()
|
||||
let shift_right = self.key.sub_parallelized(&str_len, find_index);
|
||||
@@ -200,7 +200,7 @@ impl ServerKey {
|
||||
}
|
||||
|
||||
fn max_matches(&self, str: &FheString, pat: &FheString) -> u16 {
|
||||
let str_len = str.chars().len() - if str.is_padded() { 1 } else { 0 };
|
||||
let str_len = str.len() - if str.is_padded() { 1 } else { 0 };
|
||||
|
||||
// Max number of matches is str_len + 1 when pattern is empty
|
||||
let mut max: u16 = (str_len + 1).try_into().expect("str should be shorter");
|
||||
@@ -209,7 +209,7 @@ impl ServerKey {
|
||||
// str_len - pat_len + 1. For instance "xx" matches "xxxx" at most 4 - 2 + 1 = 3 times.
|
||||
// This works as long as str_len >= pat_len (guaranteed due to the outer length checks)
|
||||
if !pat.is_padded() {
|
||||
let pat_len = pat.chars().len() as u16;
|
||||
let pat_len = pat.len() as u16;
|
||||
max = str_len as u16 - pat_len + 1;
|
||||
}
|
||||
|
||||
@@ -282,7 +282,7 @@ impl ServerKey {
|
||||
|
||||
IsMatch::Clear(true) => {
|
||||
// If `from` is empty and str too, there's only one match and one replacement
|
||||
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
|
||||
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
|
||||
if let UIntArg::Clear(_) = count {
|
||||
return to.clone();
|
||||
}
|
||||
@@ -390,7 +390,7 @@ impl ServerKey {
|
||||
IsMatch::Clear(false) => return result,
|
||||
IsMatch::Clear(true) => {
|
||||
// If `from` is empty and str too, there's only one match and one replacement
|
||||
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
|
||||
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
|
||||
return to.clone();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ impl ServerKey {
|
||||
index: &RadixCiphertext,
|
||||
inclusive: bool,
|
||||
) -> (FheString, FheString) {
|
||||
let str_len = self.key.create_trivial_radix(str.chars().len() as u32, 16);
|
||||
let str_len = self.key.create_trivial_radix(str.len() as u32, 16);
|
||||
let trivial_or_enc_pat = match pat {
|
||||
GenericPattern::Clear(pat) => FheString::trivial(self, pat.str()),
|
||||
GenericPattern::Enc(pat) => pat.clone(),
|
||||
|
||||
@@ -11,7 +11,7 @@ pub struct SplitAsciiWhitespace {
|
||||
|
||||
impl FheStringIterator for SplitAsciiWhitespace {
|
||||
fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock) {
|
||||
let str_len = self.state.chars().len();
|
||||
let str_len = self.state.len();
|
||||
|
||||
if str_len == 0 || (self.state.is_padded() && str_len == 1) {
|
||||
return (
|
||||
@@ -206,7 +206,7 @@ impl ServerKey {
|
||||
pub fn trim_start(&self, str: &FheString) -> FheString {
|
||||
let mut result = str.clone();
|
||||
|
||||
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
|
||||
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -262,7 +262,7 @@ impl ServerKey {
|
||||
pub fn trim_end(&self, str: &FheString) -> FheString {
|
||||
let mut result = str.clone();
|
||||
|
||||
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
|
||||
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ impl ServerKey {
|
||||
/// assert_eq!(trimmed, "hello world"); // Whitespace at both ends is removed
|
||||
/// ```
|
||||
pub fn trim(&self, str: &FheString) -> FheString {
|
||||
if str.chars().is_empty() || (str.is_padded() && str.chars().len() == 1) {
|
||||
if str.len() == 0 || (str.is_padded() && str.len() == 1) {
|
||||
return str.clone();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user