From 203dddb047fd2c3ed2a520fe1416467a527e0f37 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf Date: Tue, 5 Nov 2024 11:32:20 +0200 Subject: [PATCH] replace `NamedTuple` with `dataclass` (#1105) * replace `NamedTuple` with `dataclass` * add deprecation warnings --- faster_whisper/transcribe.py | 75 +++++++++++++++++++++--------------- faster_whisper/vad.py | 6 ++- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index a8db571..199bb09 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -6,9 +6,11 @@ import random import zlib from collections import Counter, defaultdict +from dataclasses import asdict, dataclass from inspect import signature from math import ceil -from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union +from typing import BinaryIO, Iterable, List, Optional, Tuple, Union +from warnings import warn import ctranslate2 import numpy as np @@ -30,14 +32,24 @@ from faster_whisper.vad import ( ) -class Word(NamedTuple): +@dataclass +class Word: start: float end: float word: str probability: float + def _asdict(self): + warn( + "Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead", + DeprecationWarning, + 2, + ) + return asdict(self) -class Segment(NamedTuple): + +@dataclass +class Segment: id: int seek: int start: float @@ -50,9 +62,18 @@ class Segment(NamedTuple): words: Optional[List[Word]] temperature: Optional[float] = 1.0 + def _asdict(self): + warn( + "Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead", + DeprecationWarning, + 2, + ) + return asdict(self) + # Added additional parameters for multilingual videos and fixes below -class TranscriptionOptions(NamedTuple): +@dataclass +class TranscriptionOptions: beam_size: int best_of: int patience: float @@ -83,7 +104,8 @@ class TranscriptionOptions(NamedTuple): hotwords: Optional[str] -class TranscriptionInfo(NamedTuple): +@dataclass +class TranscriptionInfo: language: str language_probability: float duration: float @@ -108,7 +130,7 @@ class BatchedInferencePipeline: def __init__( self, model, - options: Optional[NamedTuple] = None, + options: Optional[TranscriptionOptions] = None, tokenizer=None, language: Optional[str] = None, ): @@ -473,7 +495,7 @@ class BatchedInferencePipeline: results = self.forward( features[i : i + batch_size], chunks_metadata[i : i + batch_size], - **options._asdict(), + **asdict(options), ) for result in results: @@ -1043,16 +1065,15 @@ class WhisperModel: content_duration = float(content_frames * self.feature_extractor.time_per_frame) if isinstance(options.clip_timestamps, str): - options = options._replace( - clip_timestamps=[ - float(ts) - for ts in ( - options.clip_timestamps.split(",") - if options.clip_timestamps - else [] - ) - ] - ) + options.clip_timestamps = [ + float(ts) + for ts in ( + options.clip_timestamps.split(",") + if options.clip_timestamps + else [] + ) + ] + seek_points: List[int] = [ round(ts * self.frames_per_second) for ts in options.clip_timestamps ] @@ -1999,23 +2020,17 @@ def restore_speech_timestamps( # Ensure the word start and end times are resolved to the same chunk. middle = (word.start + word.end) / 2 chunk_index = ts_map.get_chunk_index(middle) - word = word._replace( - start=ts_map.get_original_time(word.start, chunk_index), - end=ts_map.get_original_time(word.end, chunk_index), - ) + word.start = ts_map.get_original_time(word.start, chunk_index) + word.end = ts_map.get_original_time(word.end, chunk_index) words.append(word) - segment = segment._replace( - start=words[0].start, - end=words[-1].end, - words=words, - ) + segment.start = words[0].start + segment.end = words[-1].end + segment.words = words else: - segment = segment._replace( - start=ts_map.get_original_time(segment.start), - end=ts_map.get_original_time(segment.end), - ) + segment.start = ts_map.get_original_time(segment.start) + segment.end = ts_map.get_original_time(segment.end) yield segment diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py index d448f5b..c94d3d5 100644 --- a/faster_whisper/vad.py +++ b/faster_whisper/vad.py @@ -2,7 +2,8 @@ import bisect import functools import os -from typing import Dict, List, NamedTuple, Optional, Tuple +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple import numpy as np import torch @@ -11,7 +12,8 @@ from faster_whisper.utils import get_assets_path # The code below is adapted from https://github.com/snakers4/silero-vad. -class VadOptions(NamedTuple): +@dataclass +class VadOptions: """VAD options. Attributes: