Fit words timestamps in VAD speech chunks

This commit is contained in:
Guillaume Klein
2023-04-07 10:51:53 +02:00
parent e9a082dcf2
commit 2f6790a6f5
2 changed files with 57 additions and 9 deletions

View File

@@ -767,15 +767,14 @@ def restore_speech_timestamps(
for segment in segments: for segment in segments:
if segment.words: if segment.words:
words = [] timestamps = ts_map.fit_words_timestamps(
for word in segment.words: [{"start": word.start, "end": word.end} for word in segment.words]
# Ensure the word start and end times are resolved to the same chunk. )
chunk_index = ts_map.get_chunk_index(word.start)
word = word._replace( words = [
start=ts_map.get_original_time(word.start, chunk_index), word._replace(start=timestamp["start"], end=timestamp["end"])
end=ts_map.get_original_time(word.end, chunk_index), for word, timestamp in zip(segment.words, timestamps)
) ]
words.append(word)
segment = segment._replace( segment = segment._replace(
start=words[0].start, start=words[0].start,

View File

@@ -1,5 +1,7 @@
import bisect import bisect
import collections
import functools import functools
import itertools
import os import os
import warnings import warnings
@@ -144,6 +146,9 @@ def get_speech_timestamps(
speeches.append(current_speech) speeches.append(current_speech)
for i, speech in enumerate(speeches): for i, speech in enumerate(speeches):
speech["start_no_pad"] = speech["start"]
speech["end_no_pad"] = speech["end"]
if i == 0: if i == 0:
speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
if i != len(speeches) - 1: if i != len(speeches) - 1:
@@ -184,6 +189,7 @@ class SpeechTimestampsMap:
self.time_precision = time_precision self.time_precision = time_precision
self.chunk_end_sample = [] self.chunk_end_sample = []
self.total_silence_before = [] self.total_silence_before = []
self.speech_chunks = []
previous_end = 0 previous_end = 0
silent_samples = 0 silent_samples = 0
@@ -194,6 +200,12 @@ class SpeechTimestampsMap:
self.chunk_end_sample.append(chunk["end"] - silent_samples) self.chunk_end_sample.append(chunk["end"] - silent_samples)
self.total_silence_before.append(silent_samples / sampling_rate) self.total_silence_before.append(silent_samples / sampling_rate)
self.speech_chunks.append(
{
"start": chunk["start_no_pad"] / sampling_rate,
"end": chunk["end_no_pad"] / sampling_rate,
}
)
def get_original_time( def get_original_time(
self, self,
@@ -213,6 +225,43 @@ class SpeechTimestampsMap:
len(self.chunk_end_sample) - 1, len(self.chunk_end_sample) - 1,
) )
def fit_words_timestamps(self, timestamps: List[dict]) -> List[dict]:
chunk_timestamps = collections.OrderedDict()
for timestamp in timestamps:
# Ensure the word start and end times are resolved to the same chunk.
chunk_index = self.get_chunk_index(timestamp["start"])
timestamp["start"] = self.get_original_time(timestamp["start"], chunk_index)
timestamp["end"] = self.get_original_time(timestamp["end"], chunk_index)
if chunk_index not in chunk_timestamps:
chunk_timestamps[chunk_index] = [timestamp]
else:
chunk_timestamps[chunk_index].append(timestamp)
for chunk_index, timestamps in chunk_timestamps.items():
speech_chunk = self.speech_chunks[chunk_index]
speech_offset = speech_chunk["start"]
speech_duration = speech_chunk["end"] - speech_offset
current_offset = timestamps[0]["start"]
current_duration = timestamps[-1]["end"] - current_offset
scale = speech_duration / current_duration
for timestamp in timestamps:
timestamp["start"] = round(
(timestamp["start"] - current_offset) * scale + speech_offset,
self.time_precision,
)
timestamp["end"] = round(
(timestamp["end"] - current_offset) * scale + speech_offset,
self.time_precision,
)
return list(itertools.chain.from_iterable(chunk_timestamps.values()))
@functools.lru_cache @functools.lru_cache
def get_vad_model(): def get_vad_model():