replace NamedTuple with dataclass (#1105)

* replace `NamedTuple` with `dataclass`

* add deprecation warnings
This commit is contained in:
Mahmoud Ashraf
2024-11-05 11:32:20 +02:00
committed by GitHub
parent 814472fdbf
commit 203dddb047
2 changed files with 49 additions and 32 deletions

View File

@@ -6,9 +6,11 @@ import random
import zlib import zlib
from collections import Counter, defaultdict from collections import Counter, defaultdict
from dataclasses import asdict, dataclass
from inspect import signature from inspect import signature
from math import ceil from math import ceil
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
from warnings import warn
import ctranslate2 import ctranslate2
import numpy as np import numpy as np
@@ -30,14 +32,24 @@ from faster_whisper.vad import (
) )
class Word(NamedTuple): @dataclass
class Word:
start: float start: float
end: float end: float
word: str word: str
probability: float probability: float
def _asdict(self):
warn(
"Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead",
DeprecationWarning,
2,
)
return asdict(self)
class Segment(NamedTuple):
@dataclass
class Segment:
id: int id: int
seek: int seek: int
start: float start: float
@@ -50,9 +62,18 @@ class Segment(NamedTuple):
words: Optional[List[Word]] words: Optional[List[Word]]
temperature: Optional[float] = 1.0 temperature: Optional[float] = 1.0
def _asdict(self):
warn(
"Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
DeprecationWarning,
2,
)
return asdict(self)
# Added additional parameters for multilingual videos and fixes below # Added additional parameters for multilingual videos and fixes below
class TranscriptionOptions(NamedTuple): @dataclass
class TranscriptionOptions:
beam_size: int beam_size: int
best_of: int best_of: int
patience: float patience: float
@@ -83,7 +104,8 @@ class TranscriptionOptions(NamedTuple):
hotwords: Optional[str] hotwords: Optional[str]
class TranscriptionInfo(NamedTuple): @dataclass
class TranscriptionInfo:
language: str language: str
language_probability: float language_probability: float
duration: float duration: float
@@ -108,7 +130,7 @@ class BatchedInferencePipeline:
def __init__( def __init__(
self, self,
model, model,
options: Optional[NamedTuple] = None, options: Optional[TranscriptionOptions] = None,
tokenizer=None, tokenizer=None,
language: Optional[str] = None, language: Optional[str] = None,
): ):
@@ -473,7 +495,7 @@ class BatchedInferencePipeline:
results = self.forward( results = self.forward(
features[i : i + batch_size], features[i : i + batch_size],
chunks_metadata[i : i + batch_size], chunks_metadata[i : i + batch_size],
**options._asdict(), **asdict(options),
) )
for result in results: for result in results:
@@ -1043,16 +1065,15 @@ class WhisperModel:
content_duration = float(content_frames * self.feature_extractor.time_per_frame) content_duration = float(content_frames * self.feature_extractor.time_per_frame)
if isinstance(options.clip_timestamps, str): if isinstance(options.clip_timestamps, str):
options = options._replace( options.clip_timestamps = [
clip_timestamps=[ float(ts)
float(ts) for ts in (
for ts in ( options.clip_timestamps.split(",")
options.clip_timestamps.split(",") if options.clip_timestamps
if options.clip_timestamps else []
else [] )
) ]
]
)
seek_points: List[int] = [ seek_points: List[int] = [
round(ts * self.frames_per_second) for ts in options.clip_timestamps round(ts * self.frames_per_second) for ts in options.clip_timestamps
] ]
@@ -1999,23 +2020,17 @@ def restore_speech_timestamps(
# Ensure the word start and end times are resolved to the same chunk. # Ensure the word start and end times are resolved to the same chunk.
middle = (word.start + word.end) / 2 middle = (word.start + word.end) / 2
chunk_index = ts_map.get_chunk_index(middle) chunk_index = ts_map.get_chunk_index(middle)
word = word._replace( word.start = ts_map.get_original_time(word.start, chunk_index)
start=ts_map.get_original_time(word.start, chunk_index), word.end = ts_map.get_original_time(word.end, chunk_index)
end=ts_map.get_original_time(word.end, chunk_index),
)
words.append(word) words.append(word)
segment = segment._replace( segment.start = words[0].start
start=words[0].start, segment.end = words[-1].end
end=words[-1].end, segment.words = words
words=words,
)
else: else:
segment = segment._replace( segment.start = ts_map.get_original_time(segment.start)
start=ts_map.get_original_time(segment.start), segment.end = ts_map.get_original_time(segment.end)
end=ts_map.get_original_time(segment.end),
)
yield segment yield segment

View File

@@ -2,7 +2,8 @@ import bisect
import functools import functools
import os import os
from typing import Dict, List, NamedTuple, Optional, Tuple from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@@ -11,7 +12,8 @@ from faster_whisper.utils import get_assets_path
# The code below is adapted from https://github.com/snakers4/silero-vad. # The code below is adapted from https://github.com/snakers4/silero-vad.
class VadOptions(NamedTuple): @dataclass
class VadOptions:
"""VAD options. """VAD options.
Attributes: Attributes: