mirror of https://github.com/SYSTRAN/faster-whisper.git
Remove Silence in Batched transcription (#1297)
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -25,7 +25,6 @@ from faster_whisper.vad import (
     VadOptions,
     collect_chunks,
     get_speech_timestamps,
-    merge_segments,
 )
@@ -125,7 +124,7 @@ class BatchedInferencePipeline:
         segmented_outputs = []
         segment_sizes = []
         for chunk_metadata, output in zip(chunks_metadata, outputs):
-            duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
+            duration = chunk_metadata["duration"]
             segment_size = int(ceil(duration) * self.model.frames_per_second)
             segment_sizes.append(segment_size)
             (
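A quick worked example of the segment_size computation in the hunk above. It is only a sketch: the feature rate of 100 frames per second (16 kHz audio with a 160-sample hop) is the usual Whisper default and is assumed here, not shown in this diff.

from math import ceil

# Illustrative values, not taken from the diff.
duration = 17.3          # seconds of packed speech reported for one chunk
frames_per_second = 100  # assumed default feature rate

segment_size = int(ceil(duration) * frames_per_second)
print(segment_size)  # 1800 -> the chunk is padded out to 18 s worth of mel frames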
@@ -135,7 +134,7 @@ class BatchedInferencePipeline:
             ) = self.model._split_segments_by_timestamps(
                 tokenizer=tokenizer,
                 tokens=output["tokens"],
-                time_offset=chunk_metadata["start_time"],
+                time_offset=chunk_metadata["offset"],
                 segment_size=segment_size,
                 segment_duration=duration,
                 seek=0,
@@ -153,7 +152,7 @@ class BatchedInferencePipeline:
                         tokenizer.decode(subsegment["tokens"])
                     ),
                     seek=int(
-                        chunk_metadata["start_time"] * self.model.frames_per_second
+                        chunk_metadata["offset"] * self.model.frames_per_second
                    ),
                )
                for subsegment in subsegments
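For reference, the chunk_metadata consumed in the three hunks above is produced by the reworked collect_chunks further down in this diff (vad.py hunk). A small sketch of its shape, with made-up values: "offset" and "duration" are seconds on the silence-stripped timeline, while "segments" keeps the original VAD segments as sample indices into the source audio.

# Shape of one chunks_metadata entry after this change (values are illustrative).
chunk_metadata = {
    "offset": 60.0,
    "duration": 27.4,
    "segments": [
        {"start": 1104000, "end": 1332800},
        {"start": 1364800, "end": 1574400},
    ],
}

duration = chunk_metadata["duration"]   # was: end_time - start_time
time_offset = chunk_metadata["offset"]  # was: start_time
print(duration, time_offset)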
@@ -409,8 +408,7 @@ class BatchedInferencePipeline:
                     **vad_parameters, max_speech_duration_s=chunk_length
                 )

-                active_segments = get_speech_timestamps(audio, vad_parameters)
-                clip_timestamps = merge_segments(active_segments, vad_parameters)
+                clip_timestamps = get_speech_timestamps(audio, vad_parameters)
             # run the audio if it is less than 30 sec even without clip_timestamps
             elif duration < chunk_length:
                 clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
@@ -419,6 +417,15 @@ class BatchedInferencePipeline:
                     "No clip timestamps found. "
                     "Set 'vad_filter' to True or provide 'clip_timestamps'."
                 )
+        else:
+            clip_timestamps = [
+                {k: int(v * sampling_rate) for k, v in segment.items()}
+                for segment in clip_timestamps
+            ]
+
+        audio_chunks, chunks_metadata = collect_chunks(
+            audio, clip_timestamps, max_duration=chunk_length
+        )

         duration_after_vad = (
             sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
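The new else branch converts caller-supplied clip_timestamps from seconds into samples before the chunks are collected. A minimal caller-side sketch of that conversion (the times are illustrative):

# clip_timestamps are given in seconds and converted to samples internally via
# {k: int(v * sampling_rate) for k, v in segment.items()}.
sampling_rate = 16000
clip_timestamps = [
    {"start": 0.0, "end": 12.5},
    {"start": 20.0, "end": 47.3},
]

clip_timestamps_samples = [
    {k: int(v * sampling_rate) for k, v in segment.items()} for segment in clip_timestamps
]
print(clip_timestamps_samples)
# [{'start': 0, 'end': 200000}, {'start': 320000, 'end': 756800}]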
@@ -430,7 +437,6 @@ class BatchedInferencePipeline:
                 format_timestamp(duration - duration_after_vad),
             )

-        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
         features = (
             [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
             if duration_after_vad
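Putting the hunks above together, the VAD path now hands raw speech timestamps straight to collect_chunks, which both strips the silence and packs speech into at most chunk_length seconds per batch item. A standalone sketch of that flow; the input file name and the option values are placeholders, not part of the diff.

from faster_whisper.audio import decode_audio
from faster_whisper.vad import VadOptions, collect_chunks, get_speech_timestamps

sampling_rate = 16000
chunk_length = 30  # seconds; the model's feature window

audio = decode_audio("audio.wav", sampling_rate=sampling_rate)  # placeholder input

# Raw speech segments from the Silero VAD, as sample indices into `audio`.
vad_parameters = VadOptions(max_speech_duration_s=chunk_length)
clip_timestamps = get_speech_timestamps(audio, vad_parameters)

# Silence is dropped and speech is packed into chunks of at most chunk_length seconds.
audio_chunks, chunks_metadata = collect_chunks(
    audio, clip_timestamps, max_duration=chunk_length
)

for chunk, metadata in zip(audio_chunks, chunks_metadata):
    print(len(chunk) / sampling_rate, metadata["offset"], metadata["duration"])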
@@ -541,6 +547,7 @@ class BatchedInferencePipeline:
             options,
             log_progress,
         )
+        segments = restore_speech_timestamps(segments, clip_timestamps, sampling_rate)

         return segments, info
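With restore_speech_timestamps applied at the end, segment timestamps refer back to the original, silence-included audio rather than to the packed chunks. A minimal end-to-end sketch of the batched path; the model size, compute type, file name and batch_size are placeholder choices.

from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("small", device="cpu", compute_type="int8")  # placeholder config
batched_model = BatchedInferencePipeline(model)

# vad_filter=True triggers the get_speech_timestamps + collect_chunks path above.
segments, info = batched_model.transcribe("audio.wav", batch_size=8, vad_filter=True)

for segment in segments:
    # start/end are positions in the original audio, not in the silence-stripped chunks.
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")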
--- a/faster_whisper/vad.py
+++ b/faster_whisper/vad.py
@@ -184,25 +184,62 @@ def get_speech_timestamps(


 def collect_chunks(
-    audio: np.ndarray, chunks: List[dict], sampling_rate: int = 16000
-) -> Tuple[List[np.ndarray], List[Dict[str, int]]]:
-    """Collects audio chunks."""
+    audio: np.ndarray,
+    chunks: List[dict],
+    sampling_rate: int = 16000,
+    max_duration: float = float("inf"),
+) -> Tuple[List[np.ndarray], List[Dict[str, float]]]:
+    """This function merges the chunks of audio into chunks of max_duration (s) length."""
     if not chunks:
         chunk_metadata = {
-            "start_time": 0,
-            "end_time": 0,
+            "offset": 0,
+            "duration": 0,
+            "segments": [],
         }
         return [np.array([], dtype=np.float32)], [chunk_metadata]

     audio_chunks = []
     chunks_metadata = []
+
+    current_segments = []
+    current_duration = 0
+    total_duration = 0
+    current_audio = np.array([], dtype=np.float32)
+
     for chunk in chunks:
-        chunk_metadata = {
-            "start_time": chunk["start"] / sampling_rate,
-            "end_time": chunk["end"] / sampling_rate,
-        }
-        audio_chunks.append(audio[chunk["start"] : chunk["end"]])
-        chunks_metadata.append(chunk_metadata)
+        if (
+            current_duration + chunk["end"] - chunk["start"]
+            > max_duration * sampling_rate
+        ):
+            audio_chunks.append(current_audio)
+            chunk_metadata = {
+                "offset": total_duration / sampling_rate,
+                "duration": current_duration / sampling_rate,
+                "segments": current_segments,
+            }
+            total_duration += current_duration
+            chunks_metadata.append(chunk_metadata)
+
+            current_segments = []
+
+            current_audio = audio[chunk["start"] : chunk["end"]]
+            current_duration = chunk["end"] - chunk["start"]
+        else:
+            current_segments.append(chunk)
+            current_audio = np.concatenate(
+                (current_audio, audio[chunk["start"] : chunk["end"]])
+            )
+
+            current_duration += chunk["end"] - chunk["start"]
+
+    audio_chunks.append(current_audio)
+
+    chunk_metadata = {
+        "offset": total_duration / sampling_rate,
+        "duration": current_duration / sampling_rate,
+        "segments": current_segments,
+    }
+    chunks_metadata.append(chunk_metadata)
     return audio_chunks, chunks_metadata
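A small, self-contained usage sketch of the rewritten collect_chunks, using synthetic sample ranges to show how segments are packed until max_duration would be exceeded (the audio and the segment boundaries are made up for illustration):

import numpy as np

from faster_whisper.vad import collect_chunks

sampling_rate = 16000
audio = np.zeros(60 * sampling_rate, dtype=np.float32)  # 60 s of synthetic audio

# Three speech segments of 20 s, 15 s and 10 s, given in samples.
chunks = [
    {"start": 0 * sampling_rate, "end": 20 * sampling_rate},
    {"start": 25 * sampling_rate, "end": 40 * sampling_rate},
    {"start": 45 * sampling_rate, "end": 55 * sampling_rate},
]

audio_chunks, chunks_metadata = collect_chunks(
    audio, chunks, sampling_rate=sampling_rate, max_duration=30
)

# The 20 s segment stays alone (adding 15 s would exceed 30 s); the 15 s and 10 s
# segments are concatenated into a second, 25 s chunk.
print([len(c) / sampling_rate for c in audio_chunks])  # [20.0, 25.0]
print([m["duration"] for m in chunks_metadata])        # [20.0, 25.0]
print([m["offset"] for m in chunks_metadata])          # [0.0, 20.0]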
@@ -329,48 +366,3 @@ class SileroVADModel:

         out = np.stack(decoder_outputs, axis=1).squeeze(-1)
         return out
-
-
-def merge_segments(segments_list, vad_options: VadOptions, sampling_rate: int = 16000):
-    if not segments_list:
-        return []
-
-    curr_end = 0
-    seg_idxs = []
-    merged_segments = []
-    edge_padding = vad_options.speech_pad_ms * sampling_rate // 1000
-    chunk_length = vad_options.max_speech_duration_s * sampling_rate
-
-    curr_start = segments_list[0]["start"]
-
-    for idx, seg in enumerate(segments_list):
-        # if any segment start timing is less than previous segment end timing,
-        # reset the edge padding. Similarly for end timing.
-        if idx > 0:
-            if seg["start"] < segments_list[idx - 1]["end"]:
-                seg["start"] += edge_padding
-        if idx < len(segments_list) - 1:
-            if seg["end"] > segments_list[idx + 1]["start"]:
-                seg["end"] -= edge_padding
-
-        if seg["end"] - curr_start > chunk_length and curr_end - curr_start > 0:
-            merged_segments.append(
-                {
-                    "start": curr_start,
-                    "end": curr_end,
-                    "segments": seg_idxs,
-                }
-            )
-            curr_start = seg["start"]
-            seg_idxs = []
-            curr_end = seg["end"]
-        seg_idxs.append((seg["start"], seg["end"]))
-    # add final
-    merged_segments.append(
-        {
-            "start": curr_start,
-            "end": curr_end,
-            "segments": seg_idxs,
-        }
-    )
-    return merged_segments
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -71,7 +71,7 @@ def test_batched_transcribe(physcisworks_path):
             {"start": segment.start, "end": segment.end, "text": segment.text}
         )
     # number of near 30 sec segments
-    assert len(segments) == 7
+    assert len(segments) == 6

     result, info = batched_model.transcribe(
         physcisworks_path,
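The expected count drops from 7 to 6 because the silence between speech segments is no longer carried inside the ~30 s windows, so the same amount of speech packs into fewer chunks. A rough, purely hypothetical back-of-the-envelope check (the real numbers depend on the physcisworks fixture and the VAD output):

from math import ceil

chunk_length = 30  # seconds per batched chunk

def packed_chunks(speech_seconds: float) -> int:
    # With silence stripped, speech is packed back-to-back up to chunk_length,
    # so the chunk count is roughly ceil(total_speech / chunk_length).
    return max(1, ceil(speech_seconds / chunk_length))

# Hypothetical illustration: ~170 s of detected speech now fits in 6 chunks,
# while windows that also kept the interleaved silence needed 7.
print(packed_chunks(170.0))  # -> 6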