diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index b077d8b..421cefc 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -767,15 +767,14 @@ def restore_speech_timestamps(
 
     for segment in segments:
         if segment.words:
-            words = []
-            for word in segment.words:
-                # Ensure the word start and end times are resolved to the same chunk.
-                chunk_index = ts_map.get_chunk_index(word.start)
-                word = word._replace(
-                    start=ts_map.get_original_time(word.start, chunk_index),
-                    end=ts_map.get_original_time(word.end, chunk_index),
-                )
-                words.append(word)
+            timestamps = ts_map.fit_words_timestamps(
+                [{"start": word.start, "end": word.end} for word in segment.words]
+            )
+
+            words = [
+                word._replace(start=timestamp["start"], end=timestamp["end"])
+                for word, timestamp in zip(segment.words, timestamps)
+            ]
 
             segment = segment._replace(
                 start=words[0].start,
diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py
index 080795d..d62591a 100644
--- a/faster_whisper/vad.py
+++ b/faster_whisper/vad.py
@@ -1,5 +1,7 @@
 import bisect
+import collections
 import functools
+import itertools
 import os
 import warnings
 
@@ -144,6 +146,9 @@ def get_speech_timestamps(
         speeches.append(current_speech)
 
     for i, speech in enumerate(speeches):
+        speech["start_no_pad"] = speech["start"]
+        speech["end_no_pad"] = speech["end"]
+
         if i == 0:
             speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
         if i != len(speeches) - 1:
@@ -184,6 +189,7 @@ class SpeechTimestampsMap:
         self.time_precision = time_precision
         self.chunk_end_sample = []
         self.total_silence_before = []
+        self.speech_chunks = []
 
         previous_end = 0
         silent_samples = 0
@@ -194,6 +200,12 @@ class SpeechTimestampsMap:
 
             self.chunk_end_sample.append(chunk["end"] - silent_samples)
             self.total_silence_before.append(silent_samples / sampling_rate)
+            self.speech_chunks.append(
+                {
+                    "start": chunk["start_no_pad"] / sampling_rate,
+                    "end": chunk["end_no_pad"] / sampling_rate,
+                }
+            )
 
     def get_original_time(
         self,
@@ -213,6 +225,43 @@ class SpeechTimestampsMap:
             len(self.chunk_end_sample) - 1,
         )
 
+    def fit_words_timestamps(self, timestamps: List[dict]) -> List[dict]:
+        chunk_timestamps = collections.OrderedDict()
+
+        for timestamp in timestamps:
+            # Ensure the word start and end times are resolved to the same chunk.
+            chunk_index = self.get_chunk_index(timestamp["start"])
+            timestamp["start"] = self.get_original_time(timestamp["start"], chunk_index)
+            timestamp["end"] = self.get_original_time(timestamp["end"], chunk_index)
+
+            if chunk_index not in chunk_timestamps:
+                chunk_timestamps[chunk_index] = [timestamp]
+            else:
+                chunk_timestamps[chunk_index].append(timestamp)
+
+        for chunk_index, timestamps in chunk_timestamps.items():
+            speech_chunk = self.speech_chunks[chunk_index]
+
+            speech_offset = speech_chunk["start"]
+            speech_duration = speech_chunk["end"] - speech_offset
+
+            current_offset = timestamps[0]["start"]
+            current_duration = timestamps[-1]["end"] - current_offset
+
+            scale = speech_duration / current_duration
+
+            for timestamp in timestamps:
+                timestamp["start"] = round(
+                    (timestamp["start"] - current_offset) * scale + speech_offset,
+                    self.time_precision,
+                )
+                timestamp["end"] = round(
+                    (timestamp["end"] - current_offset) * scale + speech_offset,
+                    self.time_precision,
+                )
+
+        return list(itertools.chain.from_iterable(chunk_timestamps.values()))
+
 
 @functools.lru_cache
 def get_vad_model():
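
For context, a minimal sketch of what the new fit_words_timestamps pass does. The 16 kHz sampling rate, the chunk sample indices, and the word timings below are made up for illustration; the expected output was worked out by hand from the patched logic. The words are grouped per chunk, mapped back to original time, then linearly rescaled so the group spans exactly the unpadded ("no pad") speech bounds instead of drifting into the padding.

    from faster_whisper.vad import SpeechTimestampsMap

    # Hypothetical VAD output at 16 kHz: one padded speech chunk whose
    # unpadded bounds (the new start_no_pad/end_no_pad fields from this
    # patch) are 0.2 s .. 1.8 s in the original audio.
    chunks = [
        {"start": 0, "end": 32000, "start_no_pad": 3200, "end_no_pad": 28800},
    ]

    ts_map = SpeechTimestampsMap(chunks, 16000)

    # Word timings (seconds) as measured on the silence-stripped audio.
    words = [{"start": 0.1, "end": 0.5}, {"start": 0.6, "end": 1.2}]

    print(ts_map.fit_words_timestamps(words))
    # [{'start': 0.2, 'end': 0.78}, {'start': 0.93, 'end': 1.8}]

After fitting, the first word starts exactly at the unpadded chunk start (0.2 s) and the last word ends exactly at the unpadded chunk end (1.8 s), with the gap between words stretched proportionally (scale = 1.6 / 1.1 here), rather than every word being shifted by the same silence offset as before.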