replace NamedTuple with dataclass (#1105)

* replace `NamedTuple` with `dataclass`

* add deprecation warnings
This commit is contained in:
Mahmoud Ashraf
2024-11-05 11:32:20 +02:00
committed by GitHub
parent 814472fdbf
commit 203dddb047
2 changed files with 49 additions and 32 deletions

View File

@@ -6,9 +6,11 @@ import random
import zlib
from collections import Counter, defaultdict
from dataclasses import asdict, dataclass
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
from warnings import warn
import ctranslate2
import numpy as np
@@ -30,14 +32,24 @@ from faster_whisper.vad import (
)
class Word(NamedTuple):
@dataclass
class Word:
    """A single word of a transcription with its timing and confidence.

    This is a mutable dataclass (it replaced an earlier ``NamedTuple``), so
    fields can be assigned in place.

    Attributes:
        start: Start time of the word (presumably seconds -- confirm).
        end: End time of the word (same unit as ``start``).
        word: The text of the word.
        probability: Probability/confidence score attached to the word.
    """

    start: float
    end: float
    word: str
    probability: float

    def _asdict(self):
        """Return the fields as a ``dict`` (deprecated ``NamedTuple`` shim).

        Kept so callers written against the old ``NamedTuple`` API keep
        working; emits a ``DeprecationWarning`` attributed to the caller.

        Returns:
            The result of ``dataclasses.asdict(self)``.
        """
        warn(
            # asdict() must be given an instance; passing the class itself
            # raises TypeError, so the hint says "instance", not "Word".
            "Word._asdict() method is deprecated, "
            "use dataclasses.asdict(instance) instead",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not this shim
        )
        return asdict(self)
class Segment(NamedTuple):
@dataclass
class Segment:
id: int
seek: int
start: float
@@ -50,9 +62,18 @@ class Segment(NamedTuple):
words: Optional[List[Word]]
temperature: Optional[float] = 1.0
def _asdict(self):
    """Return the fields as a ``dict`` (deprecated ``NamedTuple`` shim).

    Kept so callers written against the old ``NamedTuple`` API keep working;
    emits a ``DeprecationWarning`` attributed to the caller.

    NOTE(review): ``dataclasses.asdict`` converts nested dataclass instances
    recursively, whereas ``NamedTuple._asdict`` was shallow -- so nested
    dataclass values (e.g. entries of ``words``) come back as dicts. Confirm
    that downstream callers are fine with that.

    Returns:
        The result of ``dataclasses.asdict(self)``.
    """
    warn(
        # asdict() must be given an instance; passing the class itself
        # raises TypeError, so the hint says "instance", not "Segment".
        "Segment._asdict() method is deprecated, "
        "use dataclasses.asdict(instance) instead",
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not this shim
    )
    return asdict(self)
# Added additional parameters for multilingual videos and fixes below
class TranscriptionOptions(NamedTuple):
@dataclass
class TranscriptionOptions:
beam_size: int
best_of: int
patience: float
@@ -83,7 +104,8 @@ class TranscriptionOptions(NamedTuple):
hotwords: Optional[str]
class TranscriptionInfo(NamedTuple):
@dataclass
class TranscriptionInfo:
language: str
language_probability: float
duration: float
@@ -108,7 +130,7 @@ class BatchedInferencePipeline:
def __init__(
self,
model,
options: Optional[NamedTuple] = None,
options: Optional[TranscriptionOptions] = None,
tokenizer=None,
language: Optional[str] = None,
):
@@ -473,7 +495,7 @@ class BatchedInferencePipeline:
results = self.forward(
features[i : i + batch_size],
chunks_metadata[i : i + batch_size],
**options._asdict(),
**asdict(options),
)
for result in results:
@@ -1043,16 +1065,15 @@ class WhisperModel:
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
if isinstance(options.clip_timestamps, str):
options = options._replace(
clip_timestamps=[
float(ts)
for ts in (
options.clip_timestamps.split(",")
if options.clip_timestamps
else []
)
]
)
options.clip_timestamps = [
float(ts)
for ts in (
options.clip_timestamps.split(",")
if options.clip_timestamps
else []
)
]
seek_points: List[int] = [
round(ts * self.frames_per_second) for ts in options.clip_timestamps
]
@@ -1999,23 +2020,17 @@ def restore_speech_timestamps(
# Ensure the word start and end times are resolved to the same chunk.
middle = (word.start + word.end) / 2
chunk_index = ts_map.get_chunk_index(middle)
word = word._replace(
start=ts_map.get_original_time(word.start, chunk_index),
end=ts_map.get_original_time(word.end, chunk_index),
)
word.start = ts_map.get_original_time(word.start, chunk_index)
word.end = ts_map.get_original_time(word.end, chunk_index)
words.append(word)
segment = segment._replace(
start=words[0].start,
end=words[-1].end,
words=words,
)
segment.start = words[0].start
segment.end = words[-1].end
segment.words = words
else:
segment = segment._replace(
start=ts_map.get_original_time(segment.start),
end=ts_map.get_original_time(segment.end),
)
segment.start = ts_map.get_original_time(segment.start)
segment.end = ts_map.get_original_time(segment.end)
yield segment

View File

@@ -2,7 +2,8 @@ import bisect
import functools
import os
from typing import Dict, List, NamedTuple, Optional, Tuple
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
@@ -11,7 +12,8 @@ from faster_whisper.utils import get_assets_path
# The code below is adapted from https://github.com/snakers4/silero-vad.
class VadOptions(NamedTuple):
@dataclass
class VadOptions:
"""VAD options.
Attributes: