remove log_prob_low_threshold (#1160)

This commit is contained in:
Mahmoud Ashraf
2024-11-20 23:03:21 +02:00
committed by GitHub
parent 9c8ef76c98
commit 08f6900217

View File

@@ -77,7 +77,6 @@ class TranscriptionOptions:
repetition_penalty: float
no_repeat_ngram_size: int
log_prob_threshold: Optional[float]
log_prob_low_threshold: Optional[float]
no_speech_threshold: Optional[float]
compression_ratio_threshold: Optional[float]
condition_on_previous_text: bool
@@ -275,7 +274,6 @@ class BatchedInferencePipeline:
],
compression_ratio_threshold: Optional[float] = 2.4,
log_prob_threshold: Optional[float] = -1.0,
log_prob_low_threshold: Optional[float] = None,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
prompt_reset_on_temperature: float = 0.5,
@@ -356,9 +354,6 @@ class BatchedInferencePipeline:
treat as failed.
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
@@ -490,7 +485,6 @@ class BatchedInferencePipeline:
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
log_prob_threshold=log_prob_threshold,
log_prob_low_threshold=log_prob_low_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
temperatures=(
@@ -636,12 +630,10 @@ class WhisperModel:
local_files_only=local_files_only,
cache_dir=download_root,
)
self.device = device
# set the random seed to make sure consistency across runs
ctranslate2.set_random_seed(42)
self.model = ctranslate2.models.Whisper(
model_path,
device=self.device,
device=device,
device_index=device_index,
compute_type=compute_type,
intra_threads=cpu_threads,
@@ -719,7 +711,6 @@ class WhisperModel:
],
compression_ratio_threshold: Optional[float] = 2.4,
log_prob_threshold: Optional[float] = -1.0,
log_prob_low_threshold: Optional[float] = None,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
prompt_reset_on_temperature: float = 0.5,
@@ -766,9 +757,6 @@ class WhisperModel:
treat as failed.
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
@@ -820,7 +808,6 @@ class WhisperModel:
- a generator over transcribed segments
- an instance of TranscriptionInfo
"""
sampling_rate = self.feature_extractor.sampling_rate
if multilingual and not self.model.is_multilingual:
@@ -933,7 +920,6 @@ class WhisperModel:
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
log_prob_threshold=log_prob_threshold,
log_prob_low_threshold=log_prob_low_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
condition_on_previous_text=condition_on_previous_text,
@@ -977,6 +963,7 @@ class WhisperModel:
vad_options=vad_parameters,
all_language_probs=all_language_probs,
)
return segments, info
def _split_segments_by_timestamps(
@@ -1188,18 +1175,6 @@ class WhisperModel:
options.no_speech_threshold,
)
# Skip if the logprob is very low (below the threshold value),
# despite no_speech_prob being low (ex: Too ambiguous outputs)
if options.log_prob_low_threshold:
if avg_logprob < options.log_prob_low_threshold:
should_skip = True
self.logger.debug(
"log prob low threshold is met (%f > %f)",
avg_logprob,
options.log_prob_low_threshold,
)
if should_skip:
# fast-forward to the next segment boundary
seek += segment_size
continue