mirror of
https://github.com/m-bain/whisperX.git
synced 2026-01-09 12:38:08 -05:00
feat: add centralized logging to replace ad-hoc print statements (#1254)
* feat: add logging utility functions * feat: add logging setup and log level argument to CLI * feat: integrate logging across modules
This commit is contained in:
@@ -29,3 +29,29 @@ def load_audio(*args, **kwargs):
|
||||
def assign_word_speakers(*args, **kwargs):
|
||||
diarize = _lazy_import("diarize")
|
||||
return diarize.assign_word_speakers(*args, **kwargs)
|
||||
|
||||
|
||||
def setup_logging(*args, **kwargs):
|
||||
"""
|
||||
Configure logging for WhisperX.
|
||||
|
||||
Args:
|
||||
level: Logging level (debug, info, warning, error, critical). Default: warning
|
||||
log_file: Optional path to log file. If None, logs only to console.
|
||||
"""
|
||||
logging_module = _lazy_import("log_utils")
|
||||
return logging_module.setup_logging(*args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(*args, **kwargs):
|
||||
"""
|
||||
Get a logger instance for the given module.
|
||||
|
||||
Args:
|
||||
name: Logger name (typically __name__ from calling module)
|
||||
|
||||
Returns:
|
||||
Logger instance configured with WhisperX settings
|
||||
"""
|
||||
logging_module = _lazy_import("log_utils")
|
||||
return logging_module.get_logger(*args, **kwargs)
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
|
||||
optional_int, str2bool)
|
||||
from whisperx.log_utils import setup_logging
|
||||
|
||||
|
||||
def cli():
|
||||
@@ -23,6 +24,7 @@ def cli():
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
parser.add_argument("--log-level", type=str, default=None, choices=["debug", "info", "warning", "error", "critical"], help="logging level (overrides --verbose if set)")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
@@ -80,6 +82,16 @@ def cli():
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
|
||||
log_level = args.get("log_level")
|
||||
verbose = args.get("verbose")
|
||||
|
||||
if log_level is not None:
|
||||
setup_logging(level=log_level)
|
||||
elif verbose:
|
||||
setup_logging(level="info")
|
||||
else:
|
||||
setup_logging(level="warning")
|
||||
|
||||
from whisperx.transcribe import transcribe_task
|
||||
|
||||
transcribe_task(args, parser)
|
||||
|
||||
@@ -24,6 +24,9 @@ from whisperx.schema import (
|
||||
)
|
||||
import nltk
|
||||
from nltk.data import load as nltk_load
|
||||
from whisperx.log_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
|
||||
|
||||
@@ -81,8 +84,9 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str]
|
||||
elif language_code in DEFAULT_ALIGN_MODELS_HF:
|
||||
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
|
||||
else:
|
||||
print(f"There is no default alignment model set for this language ({language_code}).\
|
||||
Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
|
||||
logger.error(f"No default alignment model for language: {language_code}. "
|
||||
f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
|
||||
f"then pass the model name via --align_model [MODEL_NAME]")
|
||||
raise ValueError(f"No default align-model for language: {language_code}")
|
||||
|
||||
if model_name in torchaudio.pipelines.__all__:
|
||||
@@ -223,12 +227,12 @@ def align(
|
||||
|
||||
# check we can align
|
||||
if len(segment_data[sdx]["clean_char"]) == 0:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||
logger.warning(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
if t1 >= MAX_DURATION:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
|
||||
logger.warning(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
@@ -270,7 +274,7 @@ def align(
|
||||
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
|
||||
|
||||
if path is None:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
|
||||
logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ from transformers.pipelines.pt_utils import PipelineIterator
|
||||
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from whisperx.schema import SingleSegment, TranscriptionResult
|
||||
from whisperx.vads import Vad, Silero, Pyannote
|
||||
from whisperx.log_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
|
||||
@@ -247,7 +250,7 @@ class FasterWhisperPipeline(Pipeline):
|
||||
if self.suppress_numerals:
|
||||
previous_suppress_tokens = self.options.suppress_tokens
|
||||
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
|
||||
print(f"Suppressing numeral and symbol tokens")
|
||||
logger.info("Suppressing numeral and symbol tokens")
|
||||
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
|
||||
new_suppressed_tokens = list(set(new_suppressed_tokens))
|
||||
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
|
||||
@@ -285,7 +288,7 @@ class FasterWhisperPipeline(Pipeline):
|
||||
|
||||
def detect_language(self, audio: np.ndarray) -> str:
|
||||
if audio.shape[0] < N_SAMPLES:
|
||||
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
||||
logger.warning("Audio is shorter than 30s, language detection may be inaccurate")
|
||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||
@@ -294,7 +297,7 @@ class FasterWhisperPipeline(Pipeline):
|
||||
results = self.model.model.detect_language(encoder_output)
|
||||
language_token, language_probability = results[0][0]
|
||||
language = language_token[2:-2]
|
||||
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
||||
logger.info(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio")
|
||||
return language
|
||||
|
||||
|
||||
@@ -344,7 +347,7 @@ def load_model(
|
||||
if language is not None:
|
||||
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
|
||||
else:
|
||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
||||
logger.info("No language specified, language will be detected for each audio file (increases inference time)")
|
||||
tokenizer = None
|
||||
|
||||
default_asr_options = {
|
||||
|
||||
@@ -6,6 +6,9 @@ import torch
|
||||
|
||||
from whisperx.audio import load_audio, SAMPLE_RATE
|
||||
from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult
|
||||
from whisperx.log_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DiarizationPipeline:
|
||||
@@ -18,6 +21,7 @@ class DiarizationPipeline:
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
model_config = model_name or "pyannote/speaker-diarization-3.1"
|
||||
logger.info(f"Loading diarization model: {model_config}")
|
||||
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
|
||||
|
||||
def __call__(
|
||||
|
||||
67
whisperx/log_utils.py
Normal file
67
whisperx/log_utils.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
|
||||
def setup_logging(
|
||||
level: str = "info",
|
||||
log_file: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Configure logging for WhisperX.
|
||||
|
||||
Args:
|
||||
level: Logging level (debug, info, warning, error, critical). Default: info
|
||||
log_file: Optional path to log file. If None, logs only to console.
|
||||
"""
|
||||
logger = logging.getLogger("whisperx")
|
||||
|
||||
logger.handlers.clear()
|
||||
|
||||
try:
|
||||
log_level = getattr(logging, level.upper())
|
||||
except AttributeError:
|
||||
log_level = logging.WARNING
|
||||
logger.setLevel(log_level)
|
||||
|
||||
formatter = logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT)
|
||||
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(log_level)
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
if log_file:
|
||||
try:
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setLevel(log_level)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
except (OSError) as e:
|
||||
logger.warning(f"Failed to create log file '{log_file}': {e}")
|
||||
logger.warning("Continuing with console logging only")
|
||||
|
||||
# Don't propagate to root logger to avoid duplicate messages
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger instance for the given module.
|
||||
|
||||
Args:
|
||||
name: Logger name (typically __name__ from calling module)
|
||||
|
||||
Returns:
|
||||
Logger instance configured with WhisperX settings
|
||||
"""
|
||||
whisperx_logger = logging.getLogger("whisperx")
|
||||
if not whisperx_logger.handlers:
|
||||
setup_logging()
|
||||
|
||||
logger_name = "whisperx" if name == "__main__" else name
|
||||
return logging.getLogger(logger_name)
|
||||
@@ -12,6 +12,9 @@ from whisperx.audio import load_audio
|
||||
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
|
||||
from whisperx.schema import AlignedTranscriptionResult, TranscriptionResult
|
||||
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
|
||||
from whisperx.log_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||
@@ -142,7 +145,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||
for audio_path in args.pop("audio"):
|
||||
audio = load_audio(audio_path)
|
||||
# >> VAD & ASR
|
||||
print(">>Performing transcription...")
|
||||
logger.info("Performing transcription...")
|
||||
result: TranscriptionResult = model.transcribe(
|
||||
audio,
|
||||
batch_size=batch_size,
|
||||
@@ -175,13 +178,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||
if align_model is not None and len(result["segments"]) > 0:
|
||||
if result.get("language", "en") != align_metadata["language"]:
|
||||
# load new language
|
||||
print(
|
||||
logger.info(
|
||||
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
|
||||
)
|
||||
align_model, align_metadata = load_align_model(
|
||||
result["language"], device
|
||||
)
|
||||
print(">>Performing alignment...")
|
||||
logger.info("Performing alignment...")
|
||||
result: AlignedTranscriptionResult = align(
|
||||
result["segments"],
|
||||
align_model,
|
||||
@@ -203,12 +206,12 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||
# >> Diarize
|
||||
if diarize:
|
||||
if hf_token is None:
|
||||
print(
|
||||
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
|
||||
logger.warning(
|
||||
"No --hf_token provided, needs to be saved in environment variable, otherwise will throw error loading diarization model"
|
||||
)
|
||||
tmp_results = results
|
||||
print(">>Performing diarization...")
|
||||
print(">>Using model:", diarize_model_name)
|
||||
logger.info("Performing diarization...")
|
||||
logger.info(f"Using model: {diarize_model_name}")
|
||||
results = []
|
||||
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
|
||||
for result, input_audio_path in tmp_results:
|
||||
|
||||
@@ -13,6 +13,9 @@ from pyannote.core import Segment
|
||||
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
from whisperx.log_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||
@@ -232,7 +235,7 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
|
||||
class Pyannote(Vad):
|
||||
|
||||
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
|
||||
print(">>Performing voice activity detection using Pyannote...")
|
||||
logger.info("Performing voice activity detection using Pyannote...")
|
||||
super().__init__(kwargs['vad_onset'])
|
||||
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
|
||||
|
||||
@@ -257,7 +260,7 @@ class Pyannote(Vad):
|
||||
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
logger.warning("No active speech found in audio")
|
||||
return []
|
||||
assert segments_list, "segments_list is empty."
|
||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
||||
|
||||
@@ -8,6 +8,9 @@ import torch
|
||||
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
from whisperx.log_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
AudioFile = Union[Text, Path, IOBase, Mapping]
|
||||
|
||||
@@ -15,7 +18,7 @@ AudioFile = Union[Text, Path, IOBase, Mapping]
|
||||
class Silero(Vad):
|
||||
# check again default values
|
||||
def __init__(self, **kwargs):
|
||||
print(">>Performing voice activity detection using Silero...")
|
||||
logger.info("Performing voice activity detection using Silero...")
|
||||
super().__init__(kwargs['vad_onset'])
|
||||
|
||||
self.vad_onset = kwargs['vad_onset']
|
||||
@@ -60,7 +63,7 @@ class Silero(Vad):
|
||||
):
|
||||
assert chunk_size > 0
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
logger.warning("No active speech found in audio")
|
||||
return []
|
||||
assert segments_list, "segments_list is empty."
|
||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
||||
|
||||
Reference in New Issue
Block a user