mirror of
https://github.com/SYSTRAN/faster-whisper.git
synced 2026-01-09 21:48:08 -05:00
replace NamedTuple with dataclass (#1105)
* replace `NamedTuple` with `dataclass` * add deprecation warnings
This commit is contained in:
@@ -6,9 +6,11 @@ import random
|
|||||||
import zlib
|
import zlib
|
||||||
|
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
|
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -30,14 +32,24 @@ from faster_whisper.vad import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Word(NamedTuple):
|
@dataclass
|
||||||
|
class Word:
|
||||||
start: float
|
start: float
|
||||||
end: float
|
end: float
|
||||||
word: str
|
word: str
|
||||||
probability: float
|
probability: float
|
||||||
|
|
||||||
|
def _asdict(self):
|
||||||
|
warn(
|
||||||
|
"Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead",
|
||||||
|
DeprecationWarning,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
class Segment(NamedTuple):
|
|
||||||
|
@dataclass
|
||||||
|
class Segment:
|
||||||
id: int
|
id: int
|
||||||
seek: int
|
seek: int
|
||||||
start: float
|
start: float
|
||||||
@@ -50,9 +62,18 @@ class Segment(NamedTuple):
|
|||||||
words: Optional[List[Word]]
|
words: Optional[List[Word]]
|
||||||
temperature: Optional[float] = 1.0
|
temperature: Optional[float] = 1.0
|
||||||
|
|
||||||
|
def _asdict(self):
|
||||||
|
warn(
|
||||||
|
"Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
|
||||||
|
DeprecationWarning,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
# Added additional parameters for multilingual videos and fixes below
|
# Added additional parameters for multilingual videos and fixes below
|
||||||
class TranscriptionOptions(NamedTuple):
|
@dataclass
|
||||||
|
class TranscriptionOptions:
|
||||||
beam_size: int
|
beam_size: int
|
||||||
best_of: int
|
best_of: int
|
||||||
patience: float
|
patience: float
|
||||||
@@ -83,7 +104,8 @@ class TranscriptionOptions(NamedTuple):
|
|||||||
hotwords: Optional[str]
|
hotwords: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionInfo(NamedTuple):
|
@dataclass
|
||||||
|
class TranscriptionInfo:
|
||||||
language: str
|
language: str
|
||||||
language_probability: float
|
language_probability: float
|
||||||
duration: float
|
duration: float
|
||||||
@@ -108,7 +130,7 @@ class BatchedInferencePipeline:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
options: Optional[NamedTuple] = None,
|
options: Optional[TranscriptionOptions] = None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
):
|
):
|
||||||
@@ -473,7 +495,7 @@ class BatchedInferencePipeline:
|
|||||||
results = self.forward(
|
results = self.forward(
|
||||||
features[i : i + batch_size],
|
features[i : i + batch_size],
|
||||||
chunks_metadata[i : i + batch_size],
|
chunks_metadata[i : i + batch_size],
|
||||||
**options._asdict(),
|
**asdict(options),
|
||||||
)
|
)
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
@@ -1043,16 +1065,15 @@ class WhisperModel:
|
|||||||
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
|
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
|
||||||
|
|
||||||
if isinstance(options.clip_timestamps, str):
|
if isinstance(options.clip_timestamps, str):
|
||||||
options = options._replace(
|
options.clip_timestamps = [
|
||||||
clip_timestamps=[
|
float(ts)
|
||||||
float(ts)
|
for ts in (
|
||||||
for ts in (
|
options.clip_timestamps.split(",")
|
||||||
options.clip_timestamps.split(",")
|
if options.clip_timestamps
|
||||||
if options.clip_timestamps
|
else []
|
||||||
else []
|
)
|
||||||
)
|
]
|
||||||
]
|
|
||||||
)
|
|
||||||
seek_points: List[int] = [
|
seek_points: List[int] = [
|
||||||
round(ts * self.frames_per_second) for ts in options.clip_timestamps
|
round(ts * self.frames_per_second) for ts in options.clip_timestamps
|
||||||
]
|
]
|
||||||
@@ -1999,23 +2020,17 @@ def restore_speech_timestamps(
|
|||||||
# Ensure the word start and end times are resolved to the same chunk.
|
# Ensure the word start and end times are resolved to the same chunk.
|
||||||
middle = (word.start + word.end) / 2
|
middle = (word.start + word.end) / 2
|
||||||
chunk_index = ts_map.get_chunk_index(middle)
|
chunk_index = ts_map.get_chunk_index(middle)
|
||||||
word = word._replace(
|
word.start = ts_map.get_original_time(word.start, chunk_index)
|
||||||
start=ts_map.get_original_time(word.start, chunk_index),
|
word.end = ts_map.get_original_time(word.end, chunk_index)
|
||||||
end=ts_map.get_original_time(word.end, chunk_index),
|
|
||||||
)
|
|
||||||
words.append(word)
|
words.append(word)
|
||||||
|
|
||||||
segment = segment._replace(
|
segment.start = words[0].start
|
||||||
start=words[0].start,
|
segment.end = words[-1].end
|
||||||
end=words[-1].end,
|
segment.words = words
|
||||||
words=words,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
segment = segment._replace(
|
segment.start = ts_map.get_original_time(segment.start)
|
||||||
start=ts_map.get_original_time(segment.start),
|
segment.end = ts_map.get_original_time(segment.end)
|
||||||
end=ts_map.get_original_time(segment.end),
|
|
||||||
)
|
|
||||||
|
|
||||||
yield segment
|
yield segment
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ import bisect
|
|||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from typing import Dict, List, NamedTuple, Optional, Tuple
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -11,7 +12,8 @@ from faster_whisper.utils import get_assets_path
|
|||||||
|
|
||||||
|
|
||||||
# The code below is adapted from https://github.com/snakers4/silero-vad.
|
# The code below is adapted from https://github.com/snakers4/silero-vad.
|
||||||
class VadOptions(NamedTuple):
|
@dataclass
|
||||||
|
class VadOptions:
|
||||||
"""VAD options.
|
"""VAD options.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
|||||||
Reference in New Issue
Block a user