mirror of
https://github.com/SYSTRAN/faster-whisper.git
synced 2026-01-08 13:14:00 -05:00
replace NamedTuple with dataclass (#1105)
* replace `NamedTuple` with `dataclass` * add deprecation warnings
This commit is contained in:
@@ -6,9 +6,11 @@ import random
|
||||
import zlib
|
||||
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from inspect import signature
|
||||
from math import ceil
|
||||
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
|
||||
from warnings import warn
|
||||
|
||||
import ctranslate2
|
||||
import numpy as np
|
||||
@@ -30,14 +32,24 @@ from faster_whisper.vad import (
|
||||
)
|
||||
|
||||
|
||||
class Word(NamedTuple):
|
||||
@dataclass
|
||||
class Word:
|
||||
start: float
|
||||
end: float
|
||||
word: str
|
||||
probability: float
|
||||
|
||||
def _asdict(self):
|
||||
warn(
|
||||
"Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead",
|
||||
DeprecationWarning,
|
||||
2,
|
||||
)
|
||||
return asdict(self)
|
||||
|
||||
class Segment(NamedTuple):
|
||||
|
||||
@dataclass
|
||||
class Segment:
|
||||
id: int
|
||||
seek: int
|
||||
start: float
|
||||
@@ -50,9 +62,18 @@ class Segment(NamedTuple):
|
||||
words: Optional[List[Word]]
|
||||
temperature: Optional[float] = 1.0
|
||||
|
||||
def _asdict(self):
|
||||
warn(
|
||||
"Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
|
||||
DeprecationWarning,
|
||||
2,
|
||||
)
|
||||
return asdict(self)
|
||||
|
||||
|
||||
# Added additional parameters for multilingual videos and fixes below
|
||||
class TranscriptionOptions(NamedTuple):
|
||||
@dataclass
|
||||
class TranscriptionOptions:
|
||||
beam_size: int
|
||||
best_of: int
|
||||
patience: float
|
||||
@@ -83,7 +104,8 @@ class TranscriptionOptions(NamedTuple):
|
||||
hotwords: Optional[str]
|
||||
|
||||
|
||||
class TranscriptionInfo(NamedTuple):
|
||||
@dataclass
|
||||
class TranscriptionInfo:
|
||||
language: str
|
||||
language_probability: float
|
||||
duration: float
|
||||
@@ -108,7 +130,7 @@ class BatchedInferencePipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
options: Optional[NamedTuple] = None,
|
||||
options: Optional[TranscriptionOptions] = None,
|
||||
tokenizer=None,
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
@@ -473,7 +495,7 @@ class BatchedInferencePipeline:
|
||||
results = self.forward(
|
||||
features[i : i + batch_size],
|
||||
chunks_metadata[i : i + batch_size],
|
||||
**options._asdict(),
|
||||
**asdict(options),
|
||||
)
|
||||
|
||||
for result in results:
|
||||
@@ -1043,16 +1065,15 @@ class WhisperModel:
|
||||
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
|
||||
|
||||
if isinstance(options.clip_timestamps, str):
|
||||
options = options._replace(
|
||||
clip_timestamps=[
|
||||
float(ts)
|
||||
for ts in (
|
||||
options.clip_timestamps.split(",")
|
||||
if options.clip_timestamps
|
||||
else []
|
||||
)
|
||||
]
|
||||
)
|
||||
options.clip_timestamps = [
|
||||
float(ts)
|
||||
for ts in (
|
||||
options.clip_timestamps.split(",")
|
||||
if options.clip_timestamps
|
||||
else []
|
||||
)
|
||||
]
|
||||
|
||||
seek_points: List[int] = [
|
||||
round(ts * self.frames_per_second) for ts in options.clip_timestamps
|
||||
]
|
||||
@@ -1999,23 +2020,17 @@ def restore_speech_timestamps(
|
||||
# Ensure the word start and end times are resolved to the same chunk.
|
||||
middle = (word.start + word.end) / 2
|
||||
chunk_index = ts_map.get_chunk_index(middle)
|
||||
word = word._replace(
|
||||
start=ts_map.get_original_time(word.start, chunk_index),
|
||||
end=ts_map.get_original_time(word.end, chunk_index),
|
||||
)
|
||||
word.start = ts_map.get_original_time(word.start, chunk_index)
|
||||
word.end = ts_map.get_original_time(word.end, chunk_index)
|
||||
words.append(word)
|
||||
|
||||
segment = segment._replace(
|
||||
start=words[0].start,
|
||||
end=words[-1].end,
|
||||
words=words,
|
||||
)
|
||||
segment.start = words[0].start
|
||||
segment.end = words[-1].end
|
||||
segment.words = words
|
||||
|
||||
else:
|
||||
segment = segment._replace(
|
||||
start=ts_map.get_original_time(segment.start),
|
||||
end=ts_map.get_original_time(segment.end),
|
||||
)
|
||||
segment.start = ts_map.get_original_time(segment.start)
|
||||
segment.end = ts_map.get_original_time(segment.end)
|
||||
|
||||
yield segment
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@ import bisect
|
||||
import functools
|
||||
import os
|
||||
|
||||
from typing import Dict, List, NamedTuple, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -11,7 +12,8 @@ from faster_whisper.utils import get_assets_path
|
||||
|
||||
|
||||
# The code below is adapted from https://github.com/snakers4/silero-vad.
|
||||
class VadOptions(NamedTuple):
|
||||
@dataclass
|
||||
class VadOptions:
|
||||
"""VAD options.
|
||||
|
||||
Attributes:
|
||||
|
||||
Reference in New Issue
Block a user