Mirror of https://github.com/SYSTRAN/faster-whisper.git (synced 2026-01-09 13:38:01 -05:00)

Commit: Remove torch dependency, Faster numpy Feature extraction (#1106)
faster_whisper/audio.py

@@ -14,7 +14,6 @@ from typing import BinaryIO, Union

 import av
 import numpy as np
-import torch


 def decode_audio(
@@ -72,9 +71,9 @@ def decode_audio(
     if split_stereo:
         left_channel = audio[0::2]
         right_channel = audio[1::2]
-        return torch.from_numpy(left_channel), torch.from_numpy(right_channel)
+        return left_channel, right_channel

-    return torch.from_numpy(audio)
+    return audio


 def _ignore_invalid_frames(frames):
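Note: a minimal usage sketch, not part of the diff ("example.wav" is a hypothetical input file). After this change decode_audio hands back plain numpy arrays, and split_stereo returns two arrays instead of two torch tensors.

import numpy as np

from faster_whisper.audio import decode_audio

# Mono decoding yields a numpy array resampled to 16 kHz.
audio = decode_audio("example.wav", sampling_rate=16000)
assert isinstance(audio, np.ndarray)

# split_stereo=True now returns the two channels as numpy arrays.
left, right = decode_audio("example.wav", sampling_rate=16000, split_stereo=True)
print(audio.shape, left.shape, right.shape)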
@@ -113,20 +112,12 @@ def pad_or_trim(array, length: int = 3000, *, axis: int = -1):
     """
     Pad or trim the Mel features array to 3000, as expected by the encoder.
     """
-    axis = axis % array.ndim
     if array.shape[axis] > length:
-        idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
-        return array[idx]
+        array = array.take(indices=range(length), axis=axis)

     if array.shape[axis] < length:
-        pad_widths = (
-            [
-                0,
-            ]
-            * array.ndim
-            * 2
-        )
-        pad_widths[2 * axis] = length - array.shape[axis]
-        array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
+        pad_widths = [(0, 0)] * array.ndim
+        pad_widths[axis] = (0, length - array.shape[axis])
+        array = np.pad(array, pad_widths)

     return array
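Note: an illustrative sketch, not part of the diff, of what the numpy pad_or_trim above does; the import path from faster_whisper.audio is assumed. Features longer than length are truncated along axis, shorter ones are zero-padded on the right.

import numpy as np

from faster_whisper.audio import pad_or_trim  # import path assumed

mel = np.random.rand(80, 1234).astype(np.float32)  # (n_mels, frames)

padded = pad_or_trim(mel)         # right-pads the last axis with zeros up to 3000
trimmed = pad_or_trim(mel, 1000)  # truncates the last axis to 1000 frames

print(padded.shape, trimmed.shape)  # (80, 3000) (80, 1000)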
faster_whisper/feature_extractor.py

@@ -1,21 +1,15 @@
-import torch
+import numpy as np


 # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py  # noqa: E501
 class FeatureExtractor:
     def __init__(
         self,
-        device: str = "auto",
         feature_size=80,
         sampling_rate=16000,
         hop_length=160,
         chunk_length=30,
         n_fft=400,
     ):
-        if device == "auto":
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        else:
-            self.device = device
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.chunk_length = chunk_length
@@ -25,24 +19,21 @@ class FeatureExtractor:
         self.sampling_rate = sampling_rate
         self.mel_filters = self.get_mel_filters(
             sampling_rate, n_fft, n_mels=feature_size
-        )
+        ).astype("float32")

     @staticmethod
     def get_mel_filters(sr, n_fft, n_mels=128):
         """
         Implementation of librosa.filters.mel in Pytorch
         """
         # Initialize the weights
         n_mels = int(n_mels)

         # Center freqs of each FFT bin
-        fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
+        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

         # 'Center freqs' of mel bands - uniformly spaced between limits
         min_mel = 0.0
         max_mel = 45.245640471924965

-        mels = torch.linspace(min_mel, max_mel, n_mels + 2)
+        mels = np.linspace(min_mel, max_mel, n_mels + 2)

         # Fill in the linear scale
         f_min = 0.0
@@ -52,30 +43,159 @@ class FeatureExtractor:
         # And now the nonlinear scale
         min_log_hz = 1000.0  # beginning of log region (Hz)
         min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = torch.log(torch.tensor(6.4)) / 27.0  # step size for log region
+        logstep = np.log(6.4) / 27.0  # step size for log region

         # If we have vector data, vectorize
         log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
+        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))

-        mel_f = freqs
+        fdiff = np.diff(freqs)
+        ramps = freqs.reshape(-1, 1) - fftfreqs.reshape(1, -1)

-        fdiff = torch.diff(mel_f)
-        ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)

-        lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
-        upper = ramps[2:] / fdiff[1:].unsqueeze(1)
+        lower = -ramps[:-2] / np.expand_dims(fdiff[:-1], axis=1)
+        upper = ramps[2:] / np.expand_dims(fdiff[1:], axis=1)

         # Intersect them with each other and zero, vectorized across all i
-        weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
+        weights = np.maximum(np.zeros_like(lower), np.minimum(lower, upper))

         # Slaney-style mel is scaled to be approx constant energy per channel
-        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
-        weights *= enorm.unsqueeze(1)
+        enorm = 2.0 / (freqs[2 : n_mels + 2] - freqs[:n_mels])
+        weights *= np.expand_dims(enorm, axis=1)

         return weights

-    def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
+    @staticmethod
+    def stft(
+        input_array: np.ndarray,
+        n_fft: int,
+        hop_length: int = None,
+        win_length: int = None,
+        window: np.ndarray = None,
+        center: bool = True,
+        mode: str = "reflect",
+        normalized: bool = False,
+        onesided: bool = None,
+        return_complex: bool = None,
+    ):
+        # Default initialization for hop_length and win_length
+        hop_length = hop_length if hop_length is not None else n_fft // 4
+        win_length = win_length if win_length is not None else n_fft
+        input_is_complex = np.iscomplexobj(input_array)
+
+        # Determine if the output should be complex
+        return_complex = (
+            return_complex
+            if return_complex is not None
+            else (input_is_complex or (window is not None and np.iscomplexobj(window)))
+        )
+
+        if not return_complex and return_complex is None:
+            raise ValueError(
+                "stft requires the return_complex parameter for real inputs."
+            )
+
+        # Input checks
+        if not np.issubdtype(input_array.dtype, np.floating) and not input_is_complex:
+            raise ValueError(
+                "stft: expected an array of floating point or complex values,"
+                f" got {input_array.dtype}"
+            )
+
+        if input_array.ndim > 2 or input_array.ndim < 1:
+            raise ValueError(
+                f"stft: expected a 1D or 2D array, but got {input_array.ndim}D array"
+            )
+
+        # Handle 1D input
+        if input_array.ndim == 1:
+            input_array = np.expand_dims(input_array, axis=0)
+            input_array_1d = True
+        else:
+            input_array_1d = False
+
+        # Center padding if required
+        if center:
+            pad_amount = n_fft // 2
+            input_array = np.pad(
+                input_array, ((0, 0), (pad_amount, pad_amount)), mode=mode
+            )
+
+        batch, length = input_array.shape
+
+        # Additional input checks
+        if n_fft <= 0 or n_fft > length:
+            raise ValueError(
+                f"stft: expected 0 < n_fft <= {length}, but got n_fft={n_fft}"
+            )
+
+        if hop_length <= 0:
+            raise ValueError(
+                f"stft: expected hop_length > 0, but got hop_length={hop_length}"
+            )
+
+        if win_length <= 0 or win_length > n_fft:
+            raise ValueError(
+                f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}"
+            )
+
+        if window is not None:
+            if window.ndim != 1 or window.shape[0] != win_length:
+                raise ValueError(
+                    f"stft: expected a 1D window array of size equal to win_length={win_length}, "
+                    f"but got window with size {window.shape}"
+                )
+
+        # Handle padding of the window if necessary
+        if win_length < n_fft:
+            left = (n_fft - win_length) // 2
+            window_ = np.zeros(n_fft, dtype=window.dtype)
+            window_[left : left + win_length] = window
+        else:
+            window_ = window
+
+        # Calculate the number of frames
+        n_frames = 1 + (length - n_fft) // hop_length
+
+        # Time to columns
+        input_array = np.lib.stride_tricks.as_strided(
+            input_array,
+            (batch, n_frames, n_fft),
+            (
+                input_array.strides[0],
+                hop_length * input_array.strides[1],
+                input_array.strides[1],
+            ),
+        )
+
+        if window_ is not None:
+            input_array = input_array * window_
+
+        # FFT and transpose
+        complex_fft = input_is_complex
+        onesided = onesided if onesided is not None else not complex_fft
+
+        if normalized:
+            norm = "ortho"
+        else:
+            norm = None
+
+        if complex_fft:
+            if onesided:
+                raise ValueError(
+                    "Cannot have onesided output if window or input is complex"
+                )
+            output = np.fft.fft(input_array, n=n_fft, axis=-1, norm=norm)
+        else:
+            output = np.fft.rfft(input_array, n=n_fft, axis=-1, norm=norm)
+
+        output = output.transpose((0, 2, 1))
+
+        if input_array_1d:
+            output = output.squeeze(0)
+
+        return output if return_complex else np.real(output)
+
+    def __call__(self, waveform: np.ndarray, padding=160, chunk_length=None):
         """
         Compute the log-Mel spectrogram of the provided audio.
         """
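Note: a small shape check, not part of the diff, for the numpy stft helper added above. With center=True the input is reflect-padded by n_fft // 2 on each side, so a 1-D signal of N samples should produce 1 + N // hop_length frames and n_fft // 2 + 1 frequency bins, matching what torch.stft(..., return_complex=True) produced before.

import numpy as np

from faster_whisper.feature_extractor import FeatureExtractor

n_fft, hop_length = 400, 160
wave = np.random.randn(16000).astype(np.float32)        # 1 s of noise at 16 kHz
window = np.hanning(n_fft + 1)[:-1].astype(np.float32)  # periodic Hann, as in __call__

spec = FeatureExtractor.stft(
    wave, n_fft, hop_length, window=window, return_complex=True
)
print(spec.shape)  # (201, 101) == (n_fft // 2 + 1, 1 + len(wave) // hop_length)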
@@ -84,31 +204,27 @@ class FeatureExtractor:
             self.n_samples = chunk_length * self.sampling_rate
             self.nb_max_frames = self.n_samples // self.hop_length

-        if waveform.dtype is not torch.float32:
-            waveform = waveform.to(torch.float32)
-
-        waveform = (
-            waveform.to(self.device)
-            if self.device == "cuda" and not waveform.is_cuda
-            else waveform
-        )
+        if waveform.dtype is not np.float32:
+            waveform = waveform.astype(np.float32)

         if padding:
-            waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))
+            waveform = np.pad(waveform, (0, padding))

-        window = torch.hann_window(self.n_fft).to(waveform.device)
+        window = np.hanning(self.n_fft + 1)[:-1].astype("float32")

-        stft = torch.stft(
-            waveform, self.n_fft, self.hop_length, window=window, return_complex=True
-        )
-        magnitudes = stft[..., :-1].abs() ** 2
+        stft = self.stft(
+            waveform,
+            self.n_fft,
+            self.hop_length,
+            window=window,
+            return_complex=True,
+        ).astype("complex64")
+        magnitudes = np.abs(stft[..., :-1]) ** 2

-        mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
+        mel_spec = self.mel_filters @ magnitudes

-        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+        log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
+        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0

-        # When the model is running on multiple GPUs, the output should be moved
-        # to the CPU since we don't know which GPU will handle the next job.
-        return log_spec.cpu() if to_cpu else log_spec
+        return log_spec
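Note: an end-to-end sketch, not part of the diff, of the new numpy-only feature extraction; the zero waveform is just a placeholder input.

import numpy as np

from faster_whisper.feature_extractor import FeatureExtractor

extractor = FeatureExtractor()                    # 80 Mel bins, 16 kHz, hop 160, n_fft 400
waveform = np.zeros(16000 * 5, dtype=np.float32)  # 5 s placeholder input

features = extractor(waveform)                    # log-Mel spectrogram as a numpy array
print(features.dtype, features.shape)             # float32 (80, num_frames)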
faster_whisper/transcribe.py

@@ -15,7 +15,6 @@ from warnings import warn
 import ctranslate2
 import numpy as np
 import tokenizers
-import torch

 from tqdm import tqdm

@@ -228,7 +227,7 @@ class BatchedInferencePipeline:

     def transcribe(
         self,
-        audio: Union[str, BinaryIO, torch.Tensor, np.ndarray],
+        audio: Union[str, BinaryIO, np.ndarray],
         language: Optional[str] = None,
         task: str = None,
         log_progress: bool = False,
@@ -357,9 +356,7 @@ class BatchedInferencePipeline:

         sampling_rate = self.model.feature_extractor.sampling_rate

-        if isinstance(audio, np.ndarray):
-            audio = torch.from_numpy(audio)
-        elif not isinstance(audio, torch.Tensor):
+        if not isinstance(audio, np.ndarray):
             audio = decode_audio(audio, sampling_rate=sampling_rate)
         duration = audio.shape[0] / sampling_rate

@@ -457,14 +454,11 @@ class BatchedInferencePipeline:
             )

         audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
-        to_cpu = (
-            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
-        )
         features = (
-            torch.stack(
+            np.stack(
                 [
                     pad_or_trim(
-                        self.model.feature_extractor(chunk, to_cpu=to_cpu)[
+                        self.model.feature_extractor(chunk)[
                             ...,
                             : chunk.shape[0] // self.model.feature_extractor.hop_length,
                         ]
@@ -610,9 +604,7 @@ class WhisperModel:
             "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
         )
         self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
-        self.feature_extractor = FeatureExtractor(
-            **self.feat_kwargs, device=self.device
-        )
+        self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
         self.input_stride = 2
         self.num_samples_per_token = (
             self.feature_extractor.hop_length * self.input_stride
@@ -651,7 +643,7 @@ class WhisperModel:

     def transcribe(
         self,
-        audio: Union[str, BinaryIO, torch.Tensor, np.ndarray],
+        audio: Union[str, BinaryIO, np.ndarray],
         language: Optional[str] = None,
         task: str = "transcribe",
         beam_size: int = 5,
@@ -779,9 +771,7 @@ class WhisperModel:

         sampling_rate = self.feature_extractor.sampling_rate

-        if isinstance(audio, np.ndarray):
-            audio = torch.from_numpy(audio)
-        elif not isinstance(audio, torch.Tensor):
+        if not isinstance(audio, np.ndarray):
             audio = decode_audio(audio, sampling_rate=sampling_rate)

         duration = audio.shape[0] / sampling_rate
@@ -798,7 +788,7 @@ class WhisperModel:
                 vad_parameters = VadOptions(**vad_parameters)
             speech_chunks = get_speech_timestamps(audio, vad_parameters)
             audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
-            audio = torch.cat(audio_chunks, dim=0)
+            audio = np.concatenate(audio_chunks, axis=0)
             duration_after_vad = audio.shape[0] / sampling_rate

             self.logger.info(
@@ -822,10 +812,7 @@ class WhisperModel:
         else:
             speech_chunks = None

-        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        features = self.feature_extractor(
-            audio, chunk_length=chunk_length, to_cpu=to_cpu
-        )
+        features = self.feature_extractor(audio, chunk_length=chunk_length)

         encoder_output = None
         all_language_probs = None
@@ -853,9 +840,7 @@ class WhisperModel:
                 if isinstance(clip_timestamps, str)
                 else clip_timestamps[0]
             )
-            content_frames = (
-                features.shape[-1] - self.feature_extractor.nb_max_frames
-            )
+            content_frames = features.shape[-1] - 1
             seek = (
                 int(start_timestamp * self.frames_per_second)
                 if start_timestamp * self.frames_per_second < content_frames
@@ -1053,12 +1038,12 @@ class WhisperModel:

     def generate_segments(
         self,
-        features: torch.Tensor,
+        features: np.ndarray,
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
-        content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        content_frames = features.shape[-1] - 1
         content_duration = float(content_frames * self.feature_extractor.time_per_frame)

         if isinstance(options.clip_timestamps, str):
@@ -1356,13 +1341,13 @@ class WhisperModel:

             prompt_reset_since = len(all_tokens)

-    def encode(self, features: torch.Tensor) -> ctranslate2.StorageView:
+    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
         # When the model is running on multiple GPUs, the encoder output should be moved
         # to the CPU since we don't know which GPU will handle the next job.
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1

         if features.ndim == 2:
-            features = features.unsqueeze(0)
+            features = np.expand_dims(features, 0)
         features = get_ctranslate2_storage(features)

         return self.model.encode(features, to_cpu=to_cpu)
@@ -1733,7 +1718,7 @@ class WhisperModel:

     def generate_segment_batched(
         self,
-        features: torch.Tensor,
+        features: np.ndarray,
         tokenizer: Tokenizer,
         options: dict,
     ):
@@ -1782,9 +1767,8 @@ class WhisperModel:

         return encoder_output, output

-    def detect_language(self, audio: torch.Tensor):
-        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
+    def detect_language(self, audio: np.ndarray):
+        segment = self.feature_extractor(audio)[
             :, : self.feature_extractor.nb_max_frames
         ]
         encoder_output = self.encode(pad_or_trim(segment))
@@ -1798,7 +1782,7 @@ class WhisperModel:
         return language, language_probability, all_language_probs

     def detect_language_multi_segment(
-        self, audio: Union[str, BinaryIO, torch.Tensor], params: Optional[dict] = None
+        self, audio: Union[str, BinaryIO, np.ndarray], params: Optional[dict] = None
     ):
         """
         Detect language based on N highly-confident segments of a language.
@@ -1834,8 +1818,8 @@ class WhisperModel:

         # decode audio if it is not decoded already
         sampling_rate = self.feature_extractor.sampling_rate
-        if not isinstance(audio, torch.Tensor):
-            audio: torch.Tensor = decode_audio(audio, sampling_rate=sampling_rate)
+        if not isinstance(audio, np.ndarray):
+            audio: np.ndarray = decode_audio(audio, sampling_rate=sampling_rate)

         # calculate duration of audio as number of seconds
         # audio.shape[0] is the number of samples in the audio
@@ -1850,7 +1834,7 @@ class WhisperModel:
             speech_chunks = get_speech_timestamps(audio, vad_params)
             # merge chunks of audio that contain speech into a single array
             audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
-            audio = torch.cat(audio_chunks, dim=0)
+            audio = np.concatenate(audio_chunks, axis=0)

             # calculate new duration of audio without silence
             duration_vad = audio.shape[0] / sampling_rate
@@ -1874,8 +1858,7 @@ class WhisperModel:
         nb_max_frames = self.feature_extractor.nb_max_frames

         # extract features from audio with padding (default)
-        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        features = self.feature_extractor(audio, to_cpu=to_cpu)
+        features = self.feature_extractor(audio)

         # number of segments in the audio
         num_segments = features.shape[-1] // nb_max_frames
@@ -1987,8 +1970,8 @@ class WhisperModel:
         dc_offset = audio.mean()
         audio_minus_dc_offset = audio - dc_offset
         is_silent = (
-            torch.all(audio.abs() < 0.01)
-            or torch.sqrt(torch.mean(audio_minus_dc_offset**2)) < 0.01
+            all(np.abs(audio) < 0.1)
+            or np.sqrt(np.mean(audio_minus_dc_offset**2)) < 0.01
         )

         if is_silent:
@@ -2032,12 +2015,9 @@ def restore_speech_timestamps(
             yield segment


-def get_ctranslate2_storage(segment: torch.Tensor) -> ctranslate2.StorageView:
-    segment = segment.contiguous()
-    segment = ctranslate2.StorageView.from_array(
-        segment if segment.is_cuda else segment.numpy()
-    )  # torch cpu tensors don't implement __array_interface__
-    # https://github.com/pytorch/pytorch/issues/51156
+def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
+    segment = np.ascontiguousarray(segment)
+    segment = ctranslate2.StorageView.from_array(segment)
     return segment


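Note: a sketch, not part of the diff, of why the numpy version is simpler: ctranslate2.StorageView.from_array accepts any object exposing the array interface, so a contiguous numpy array can be wrapped directly, without the torch CPU/CUDA branching that was needed before.

import ctranslate2
import numpy as np

features = np.random.rand(1, 80, 3000).astype(np.float32)  # (batch, n_mels, frames)
storage = ctranslate2.StorageView.from_array(np.ascontiguousarray(features))
print(storage.shape, storage.dtype)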
faster_whisper/vad.py

@@ -6,7 +6,6 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

 import numpy as np
-import torch

 from faster_whisper.utils import get_assets_path

@@ -44,7 +43,7 @@ class VadOptions:


 def get_speech_timestamps(
-    audio: torch.Tensor,
+    audio: np.ndarray,
     vad_options: Optional[VadOptions] = None,
     sampling_rate: int = 16000,
     **kwargs,
@@ -84,7 +83,7 @@ def get_speech_timestamps(
     model = get_vad_model()

     padded_audio = np.pad(
-        audio.numpy(), (0, window_size_samples - audio.shape[0] % window_size_samples)
+        audio, (0, window_size_samples - audio.shape[0] % window_size_samples)
     )
     speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0)

@@ -183,15 +182,15 @@ def get_speech_timestamps(


 def collect_chunks(
-    audio: torch.Tensor, chunks: List[dict], sampling_rate: int = 16000
-) -> Tuple[List[torch.Tensor], List[Dict[str, int]]]:
+    audio: np.ndarray, chunks: List[dict], sampling_rate: int = 16000
+) -> Tuple[List[np.ndarray], List[Dict[str, int]]]:
     """Collects audio chunks."""
     if not chunks:
         chunk_metadata = {
             "start_time": 0,
             "end_time": 0,
         }
-        return [torch.tensor([], dtype=torch.float32)], [chunk_metadata]
+        return [np.array([], dtype=np.float32)], [chunk_metadata]

     audio_chunks = []
     chunks_metadata = []
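Note: a usage sketch, not part of the diff ("example.wav" is hypothetical), showing how the numpy-based VAD helpers compose: collect_chunks now returns numpy arrays, which transcribe() merges with np.concatenate instead of torch.cat.

import numpy as np

from faster_whisper.audio import decode_audio
from faster_whisper.vad import VadOptions, collect_chunks, get_speech_timestamps

audio = decode_audio("example.wav", sampling_rate=16000)
speech_chunks = get_speech_timestamps(audio, VadOptions())
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)

speech_only = np.concatenate(audio_chunks, axis=0)  # speech-only signal
print(len(audio_chunks), speech_only.shape)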
@@ -281,7 +280,7 @@ class SileroVADModel:
     ):
         assert (
             audio.ndim == 2
-        ), "Input should be a 2D tensor with size (batch_size, num_samples)"
+        ), "Input should be a 2D array with size (batch_size, num_samples)"
         assert (
             audio.shape[1] % num_samples == 0
         ), "Input size should be a multiple of num_samples"