diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py
index 1f1970a..e7e225a 100644
--- a/faster_whisper/audio.py
+++ b/faster_whisper/audio.py
@@ -109,9 +109,9 @@ def _resample_frames(frames, resampler):
         yield from resampler.resample(frame)
 
 
-def pad_or_trim(array, length: int, *, axis: int = -1):
+def pad_or_trim(array, length: int = 3000, *, axis: int = -1):
     """
-    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+    Pad or trim the Mel features array to 3000, as expected by the encoder.
     """
     axis = axis % array.ndim
     if array.shape[axis] > length:
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 7611237..da23d50 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -441,9 +441,12 @@ class BatchedInferencePipeline:
         features = (
             torch.stack(
                 [
-                    self.model.feature_extractor(chunk, to_cpu=to_cpu)[
-                        ..., : self.model.feature_extractor.nb_max_frames
-                    ]
+                    pad_or_trim(
+                        self.model.feature_extractor(chunk, to_cpu=to_cpu)[
+                            ...,
+                            : chunk.shape[0] // self.model.feature_extractor.hop_length,
+                        ]
+                    )
                     for chunk in audio_chunks
                 ]
             )
@@ -847,7 +850,7 @@ class WhisperModel:
                     segment = features[
                         :, seek : seek + self.feature_extractor.nb_max_frames
                     ]
-                    encoder_output = self.encode(segment)
+                    encoder_output = self.encode(pad_or_trim(segment))
                     # results is a list of tuple[str, float] with language names and
                     # probabilities.
                     results = self.model.detect_language(encoder_output)[0]
@@ -1105,7 +1108,7 @@ class WhisperModel:
             )
             segment = features[:, seek : seek + segment_size]
             segment_duration = segment_size * self.feature_extractor.time_per_frame
-            segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
+            segment = pad_or_trim(segment)
 
             if self.logger.isEnabledFor(logging.DEBUG):
                 self.logger.debug(
@@ -1766,7 +1769,7 @@ class WhisperModel:
            segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
                :, : self.feature_extractor.nb_max_frames
            ]
-           encoder_output = self.encode(segment)
+           encoder_output = self.encode(pad_or_trim(segment))
            results = self.model.detect_language(encoder_output)
            language_token, language_probability = results[0][0]
            language = language_token[2:-2]
@@ -1895,7 +1898,7 @@ class WhisperModel:
         for i in indices:
             segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
             try:
-                encoder_output = self.encode(segment_features)
+                encoder_output = self.encode(pad_or_trim(segment_features))
                 results = self.model.detect_language(encoder_output)[0]
             except ValueError as e:  # or RuntimeError
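
For context, a minimal sketch (not the library code) of what the new default in pad_or_trim amounts to at the call sites above: mel-feature segments shorter than 3000 frames are zero-padded on the right and longer ones are trimmed, so everything handed to self.encode(...) has the fixed width the Whisper encoder expects. The helper name and tensor shapes below are illustrative assumptions, not part of the patch.

import torch


def pad_or_trim_sketch(features: torch.Tensor, length: int = 3000) -> torch.Tensor:
    # Emulate the patched pad_or_trim default along the last (frame) axis.
    n_frames = features.shape[-1]
    if n_frames > length:
        return features[..., :length]  # trim extra frames
    if n_frames < length:
        # zero-pad on the right of the last dimension
        return torch.nn.functional.pad(features, (0, length - n_frames))
    return features


# Illustrative only: a short chunk yielding ~1000 mel frames (80 mel bins assumed)
# is padded up to the 3000-frame window before being passed to the encoder.
mel = torch.randn(1, 80, 1000)
assert pad_or_trim_sketch(mel).shape == (1, 80, 3000)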