Remove Silence in Batched transcription (#1297)

Author: Mahmoud Ashraf
Date: 2025-08-06 03:30:59 +03:00
Committed by: GitHub
Parent: fbeb1ba731
Commit: a0c3cb9802
3 changed files with 63 additions and 64 deletions

View File

@@ -25,7 +25,6 @@ from faster_whisper.vad import (
     VadOptions,
     collect_chunks,
     get_speech_timestamps,
-    merge_segments,
 )
@@ -125,7 +124,7 @@ class BatchedInferencePipeline:
         segmented_outputs = []
         segment_sizes = []
         for chunk_metadata, output in zip(chunks_metadata, outputs):
-            duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
+            duration = chunk_metadata["duration"]
             segment_size = int(ceil(duration) * self.model.frames_per_second)
             segment_sizes.append(segment_size)
             (
@@ -135,7 +134,7 @@ class BatchedInferencePipeline:
             ) = self.model._split_segments_by_timestamps(
                 tokenizer=tokenizer,
                 tokens=output["tokens"],
-                time_offset=chunk_metadata["start_time"],
+                time_offset=chunk_metadata["offset"],
                 segment_size=segment_size,
                 segment_duration=duration,
                 seek=0,
@@ -153,7 +152,7 @@ class BatchedInferencePipeline:
                         tokenizer.decode(subsegment["tokens"])
                     ),
                     seek=int(
-                        chunk_metadata["start_time"] * self.model.frames_per_second
+                        chunk_metadata["offset"] * self.model.frames_per_second
                     ),
                 )
                 for subsegment in subsegments
@@ -409,8 +408,7 @@ class BatchedInferencePipeline:
                         **vad_parameters, max_speech_duration_s=chunk_length
                     )
-                active_segments = get_speech_timestamps(audio, vad_parameters)
-                clip_timestamps = merge_segments(active_segments, vad_parameters)
+                clip_timestamps = get_speech_timestamps(audio, vad_parameters)
             # run the audio if it is less than 30 sec even without clip_timestamps
             elif duration < chunk_length:
                 clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
@@ -419,6 +417,15 @@ class BatchedInferencePipeline:
"No clip timestamps found. " "No clip timestamps found. "
"Set 'vad_filter' to True or provide 'clip_timestamps'." "Set 'vad_filter' to True or provide 'clip_timestamps'."
) )
else:
clip_timestamps = [
{k: int(v * sampling_rate) for k, v in segment.items()}
for segment in clip_timestamps
]
audio_chunks, chunks_metadata = collect_chunks(
audio, clip_timestamps, max_duration=chunk_length
)
duration_after_vad = ( duration_after_vad = (
sum((segment["end"] - segment["start"]) for segment in clip_timestamps) sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
@@ -430,7 +437,6 @@ class BatchedInferencePipeline:
                 format_timestamp(duration - duration_after_vad),
             )
 
-        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
         features = (
             [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
             if duration_after_vad
@@ -541,6 +547,7 @@ class BatchedInferencePipeline:
             options,
             log_progress,
         )
+        segments = restore_speech_timestamps(segments, clip_timestamps, sampling_rate)
 
         return segments, info
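A minimal sketch, not part of the commit, of how the numbers in the hunks above fit together. It assumes Whisper's usual 100 log-mel frames per second for `frames_per_second` and uses a hypothetical `chunk_metadata` value shaped like the new `collect_chunks` output: user-supplied clip timestamps in seconds are scaled to sample indices, and each merged chunk's encoder `segment_size` and `seek` are derived from its `duration` and `offset` metadata.

```python
from math import ceil

SAMPLING_RATE = 16000
FRAMES_PER_SECOND = 100  # assumption: Whisper feature frames per second (3000 frames per 30 s window)

# User-provided clip timestamps in seconds, as accepted by transcribe().
clip_timestamps_s = [{"start": 0.0, "end": 12.4}, {"start": 15.0, "end": 31.2}]

# The diff converts them to sample indices before calling collect_chunks().
clip_timestamps = [
    {k: int(v * SAMPLING_RATE) for k, v in segment.items()}
    for segment in clip_timestamps_s
]

# Hypothetical metadata for one merged chunk, in the shape the new collect_chunks returns.
chunk_metadata = {"offset": 0.0, "duration": 28.6, "segments": clip_timestamps}

# Per-chunk values used by the pipeline above.
segment_size = int(ceil(chunk_metadata["duration"]) * FRAMES_PER_SECOND)
seek = int(chunk_metadata["offset"] * FRAMES_PER_SECOND)
print(clip_timestamps, segment_size, seek)
```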

View File

@@ -184,25 +184,62 @@ def get_speech_timestamps(
 def collect_chunks(
-    audio: np.ndarray, chunks: List[dict], sampling_rate: int = 16000
-) -> Tuple[List[np.ndarray], List[Dict[str, int]]]:
-    """Collects audio chunks."""
+    audio: np.ndarray,
+    chunks: List[dict],
+    sampling_rate: int = 16000,
+    max_duration: float = float("inf"),
+) -> Tuple[List[np.ndarray], List[Dict[str, float]]]:
+    """This function merges the chunks of audio into chunks of max_duration (s) length."""
     if not chunks:
         chunk_metadata = {
-            "start_time": 0,
-            "end_time": 0,
+            "offset": 0,
+            "duration": 0,
+            "segments": [],
         }
         return [np.array([], dtype=np.float32)], [chunk_metadata]
 
     audio_chunks = []
     chunks_metadata = []
+    current_segments = []
+    current_duration = 0
+    total_duration = 0
+    current_audio = np.array([], dtype=np.float32)
     for chunk in chunks:
-        chunk_metadata = {
-            "start_time": chunk["start"] / sampling_rate,
-            "end_time": chunk["end"] / sampling_rate,
-        }
-        audio_chunks.append(audio[chunk["start"] : chunk["end"]])
-        chunks_metadata.append(chunk_metadata)
+        if (
+            current_duration + chunk["end"] - chunk["start"]
+            > max_duration * sampling_rate
+        ):
+            audio_chunks.append(current_audio)
+            chunk_metadata = {
+                "offset": total_duration / sampling_rate,
+                "duration": current_duration / sampling_rate,
+                "segments": current_segments,
+            }
+            total_duration += current_duration
+            chunks_metadata.append(chunk_metadata)
+            current_segments = []
+            current_audio = audio[chunk["start"] : chunk["end"]]
+            current_duration = chunk["end"] - chunk["start"]
+        else:
+            current_segments.append(chunk)
+            current_audio = np.concatenate(
+                (current_audio, audio[chunk["start"] : chunk["end"]])
+            )
+            current_duration += chunk["end"] - chunk["start"]
+
+    audio_chunks.append(current_audio)
+    chunk_metadata = {
+        "offset": total_duration / sampling_rate,
+        "duration": current_duration / sampling_rate,
+        "segments": current_segments,
+    }
+    chunks_metadata.append(chunk_metadata)
     return audio_chunks, chunks_metadata
@@ -329,48 +366,3 @@ class SileroVADModel:
         out = np.stack(decoder_outputs, axis=1).squeeze(-1)
 
         return out
-
-
-def merge_segments(segments_list, vad_options: VadOptions, sampling_rate: int = 16000):
-    if not segments_list:
-        return []
-
-    curr_end = 0
-    seg_idxs = []
-    merged_segments = []
-    edge_padding = vad_options.speech_pad_ms * sampling_rate // 1000
-    chunk_length = vad_options.max_speech_duration_s * sampling_rate
-
-    curr_start = segments_list[0]["start"]
-
-    for idx, seg in enumerate(segments_list):
-        # if any segment start timing is less than previous segment end timing,
-        # reset the edge padding. Similarly for end timing.
-        if idx > 0:
-            if seg["start"] < segments_list[idx - 1]["end"]:
-                seg["start"] += edge_padding
-        if idx < len(segments_list) - 1:
-            if seg["end"] > segments_list[idx + 1]["start"]:
-                seg["end"] -= edge_padding
-        if seg["end"] - curr_start > chunk_length and curr_end - curr_start > 0:
-            merged_segments.append(
-                {
-                    "start": curr_start,
-                    "end": curr_end,
-                    "segments": seg_idxs,
-                }
-            )
-            curr_start = seg["start"]
-            seg_idxs = []
-        curr_end = seg["end"]
-        seg_idxs.append((seg["start"], seg["end"]))
-    # add final
-    merged_segments.append(
-        {
-            "start": curr_start,
-            "end": curr_end,
-            "segments": seg_idxs,
-        }
-    )
-
-    return merged_segments
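A minimal standalone sketch of the greedy packing the new `collect_chunks` hunk above performs, replacing the removed `merge_segments`: speech chunks are concatenated until adding the next one would exceed `max_duration` seconds, and each emitted chunk records its `offset`, `duration`, and source `segments`. The `pack_chunks` name and the synthetic timestamps are illustrative, not from the library.

```python
import numpy as np


def pack_chunks(audio, chunks, sampling_rate=16000, max_duration=30.0):
    """Greedy packing, restating the new collect_chunks loop: concatenate speech
    chunks until the next one would push the running duration past max_duration."""
    audio_chunks, chunks_metadata = [], []
    current_segments, current_duration, total_duration = [], 0, 0
    current_audio = np.array([], dtype=np.float32)
    for chunk in chunks:
        length = chunk["end"] - chunk["start"]
        if current_duration + length > max_duration * sampling_rate:
            # Flush the current chunk and start a new one with this speech region.
            audio_chunks.append(current_audio)
            chunks_metadata.append(
                {
                    "offset": total_duration / sampling_rate,
                    "duration": current_duration / sampling_rate,
                    "segments": current_segments,
                }
            )
            total_duration += current_duration
            current_segments = []
            current_audio = audio[chunk["start"] : chunk["end"]]
            current_duration = length
        else:
            current_segments.append(chunk)
            current_audio = np.concatenate(
                (current_audio, audio[chunk["start"] : chunk["end"]])
            )
            current_duration += length
    # Flush the last chunk.
    audio_chunks.append(current_audio)
    chunks_metadata.append(
        {
            "offset": total_duration / sampling_rate,
            "duration": current_duration / sampling_rate,
            "segments": current_segments,
        }
    )
    return audio_chunks, chunks_metadata


# Synthetic example: 60 s of silence with three speech regions.
sr = 16000
audio = np.zeros(60 * sr, dtype=np.float32)
timestamps = [
    {"start": 0, "end": 20 * sr},
    {"start": 25 * sr, "end": 45 * sr},
    {"start": 50 * sr, "end": 55 * sr},
]
chunks, metadata = pack_chunks(audio, timestamps, sr, max_duration=30.0)
print([len(c) / sr for c in chunks])          # [20.0, 25.0]
print([m["duration"] for m in metadata])      # [20.0, 25.0]
```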

View File

@@ -71,7 +71,7 @@ def test_batched_transcribe(physcisworks_path):
{"start": segment.start, "end": segment.end, "text": segment.text} {"start": segment.start, "end": segment.end, "text": segment.text}
) )
# number of near 30 sec segments # number of near 30 sec segments
assert len(segments) == 7 assert len(segments) == 6
result, info = batched_model.transcribe( result, info = batched_model.transcribe(
physcisworks_path, physcisworks_path,