Mirror of https://github.com/SYSTRAN/faster-whisper.git (synced 2026-01-09 21:48:08 -05:00)
Use correct features padding for encoder input (#1101)
* pad to 3000 instead of `feature_extractor.nb_max_frames`
* correct trimming for batched features
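For context, the changed helper lives in faster_whisper/audio.py. Below is a minimal sketch of the pad-or-trim behaviour the rest of the diff relies on: zero-pad or truncate the mel-features tensor along the frame axis to a fixed 3000 frames (30 s of audio at Whisper's 10 ms hop). The signature and the trim condition come from the diff; the pad/trim bodies are an illustrative reimplementation, not the exact code from the repository.

```python
import torch


def pad_or_trim(array: torch.Tensor, length: int = 3000, *, axis: int = -1) -> torch.Tensor:
    """Zero-pad or truncate `array` along `axis` so that dimension equals `length`."""
    axis = axis % array.ndim
    if array.shape[axis] > length:
        # Trim: keep only the first `length` frames along `axis`.
        index = [slice(None)] * array.ndim
        index[axis] = slice(0, length)
        return array[tuple(index)]
    if array.shape[axis] < length:
        # Pad with zeros on the right of `axis`; F.pad lists pads for the last dim first.
        pad = [0, 0] * array.ndim
        pad[2 * (array.ndim - 1 - axis) + 1] = length - array.shape[axis]
        return torch.nn.functional.pad(array, pad)
    return array
```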
@@ -109,9 +109,9 @@ def _resample_frames(frames, resampler):
         yield from resampler.resample(frame)


-def pad_or_trim(array, length: int, *, axis: int = -1):
+def pad_or_trim(array, length: int = 3000, *, axis: int = -1):
     """
-    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+    Pad or trim the Mel features array to 3000, as expected by the encoder.
     """
     axis = axis % array.ndim
     if array.shape[axis] > length:
@@ -441,9 +441,12 @@ class BatchedInferencePipeline:
         features = (
             torch.stack(
                 [
-                    self.model.feature_extractor(chunk, to_cpu=to_cpu)[
-                        ..., : self.model.feature_extractor.nb_max_frames
-                    ]
+                    pad_or_trim(
+                        self.model.feature_extractor(chunk, to_cpu=to_cpu)[
+                            ...,
+                            : chunk.shape[0] // self.model.feature_extractor.hop_length,
+                        ]
+                    )
                     for chunk in audio_chunks
                 ]
             )
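The new trim bound in the batched pipeline is worth spelling out: the feature extractor pads every chunk to a full 30 s window, so slicing to `chunk.shape[0] // hop_length` keeps only the mel frames backed by real audio, and `pad_or_trim` then zero-pads back to the 3000 frames the encoder expects. A rough back-of-the-envelope check, assuming Whisper's standard 16 kHz sampling rate and 160-sample hop:

```python
SAMPLING_RATE = 16000  # samples per second expected by Whisper
HOP_LENGTH = 160       # samples per mel frame (10 ms)

chunk_seconds = 17.3                            # e.g. a VAD chunk shorter than 30 s
n_samples = int(chunk_seconds * SAMPLING_RATE)  # 276800 samples

valid_frames = n_samples // HOP_LENGTH          # 1730 frames backed by real audio
padded_frames = 3000 - valid_frames             # 1270 trailing zero frames added back

print(valid_frames, padded_frames)
```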
@@ -847,7 +850,7 @@ class WhisperModel:
                 segment = features[
                     :, seek : seek + self.feature_extractor.nb_max_frames
                 ]
-                encoder_output = self.encode(segment)
+                encoder_output = self.encode(pad_or_trim(segment))
                 # results is a list of tuple[str, float] with language names and
                 # probabilities.
                 results = self.model.detect_language(encoder_output)[0]
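The same wrapping appears at the language-detection call sites in this and the later hunks: near the end of the audio, the slice `features[:, seek : seek + nb_max_frames]` can come up shorter than a full window, and the encoder expects the fixed 3000-frame input (per the updated docstring), so the segment is padded just before encoding. A hypothetical shape trace, reusing the `pad_or_trim` sketch above (the tensor sizes are made up for illustration):

```python
import torch

n_mels, nb_max_frames = 128, 3000           # 128 mel bins, e.g. for large-v3
features = torch.zeros(n_mels, 10_200)      # hypothetical ~102 s of mel features
seek = 9_000                                # only 1200 frames remain past this point

segment = features[:, seek : seek + nb_max_frames]
print(segment.shape)                        # torch.Size([128, 1200]) -- too short for the encoder
print(pad_or_trim(segment).shape)           # torch.Size([128, 3000]) -- fixed-size encoder input
```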
@@ -1105,7 +1108,7 @@ class WhisperModel:
             )
             segment = features[:, seek : seek + segment_size]
             segment_duration = segment_size * self.feature_extractor.time_per_frame
-            segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
+            segment = pad_or_trim(segment)

             if self.logger.isEnabledFor(logging.DEBUG):
                 self.logger.debug(
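Dropping the explicit length argument here is safe because the new default of 3000 is exactly `feature_extractor.nb_max_frames` for the standard configuration (30 s window, 16 kHz, 160-sample hop). A quick sanity check of that arithmetic, assuming those standard values:

```python
chunk_length = 30        # seconds per encoder window
sampling_rate = 16000    # Hz
hop_length = 160         # samples per mel frame

nb_max_frames = chunk_length * sampling_rate // hop_length
assert nb_max_frames == 3000
```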
@@ -1766,7 +1769,7 @@ class WhisperModel:
         segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
             :, : self.feature_extractor.nb_max_frames
         ]
-        encoder_output = self.encode(segment)
+        encoder_output = self.encode(pad_or_trim(segment))
         results = self.model.detect_language(encoder_output)
         language_token, language_probability = results[0][0]
         language = language_token[2:-2]
@@ -1895,7 +1898,7 @@ class WhisperModel:
         for i in indices:
             segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
             try:
-                encoder_output = self.encode(segment_features)
+                encoder_output = self.encode(pad_or_trim(segment_features))
                 results = self.model.detect_language(encoder_output)[0]

             except ValueError as e:  # or RuntimeError