feat: add centralized logging to replace ad-hoc print statements (#1254)

* feat: add logging utility functions

* feat: add logging setup and log level argument to CLI

* feat: integrate logging across modules
Author: Barabazs
Date: 2025-10-10 08:41:06 +02:00
Committed by: GitHub
Parent: 3b1b9a8c4d
Commit: a51ae7a81a

9 changed files with 145 additions and 20 deletions

whisperx/__init__.py

@@ -29,3 +29,29 @@ def load_audio(*args, **kwargs):
 def assign_word_speakers(*args, **kwargs):
     diarize = _lazy_import("diarize")
     return diarize.assign_word_speakers(*args, **kwargs)
+
+
+def setup_logging(*args, **kwargs):
+    """
+    Configure logging for WhisperX.
+
+    Args:
+        level: Logging level (debug, info, warning, error, critical). Default: info
+        log_file: Optional path to log file. If None, logs only to console.
+    """
+    logging_module = _lazy_import("log_utils")
+    return logging_module.setup_logging(*args, **kwargs)
+
+
+def get_logger(*args, **kwargs):
+    """
+    Get a logger instance for the given module.
+
+    Args:
+        name: Logger name (typically __name__ from calling module)
+
+    Returns:
+        Logger instance configured with WhisperX settings
+    """
+    logging_module = _lazy_import("log_utils")
+    return logging_module.get_logger(*args, **kwargs)
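
With the wrappers exposed at the package root, callers can configure logging without importing whisperx.log_utils directly. A minimal sketch (the log-file path is a placeholder):

    import whisperx

    whisperx.setup_logging(level="debug", log_file="whisperx_debug.log")  # placeholder path
    logger = whisperx.get_logger(__name__)
    logger.debug("logging configured through the package-level wrappers")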

whisperx/__main__.py

@@ -6,6 +6,7 @@ import torch
 from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
                             optional_int, str2bool)
+from whisperx.log_utils import setup_logging


 def cli():
@@ -23,6 +24,7 @@ def cli():
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--log-level", type=str, default=None, choices=["debug", "info", "warning", "error", "critical"], help="logging level (overrides --verbose if set)")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
@@ -80,6 +82,16 @@ def cli():
     args = parser.parse_args().__dict__

+    log_level = args.get("log_level")
+    verbose = args.get("verbose")
+
+    if log_level is not None:
+        setup_logging(level=log_level)
+    elif verbose:
+        setup_logging(level="info")
+    else:
+        setup_logging(level="warning")
+
     from whisperx.transcribe import transcribe_task

     transcribe_task(args, parser)
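
The precedence is: an explicit --log-level always wins; otherwise the existing --verbose flag (default True) maps to info, and --verbose False falls back to warning. Illustrative invocations (the audio file name is a placeholder):

    whisperx audio.wav --log-level debug     # explicit level, overrides --verbose
    whisperx audio.wav --verbose False       # no --log-level: warning-level logging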

whisperx/alignment.py

@@ -24,6 +24,9 @@ from whisperx.schema import (
 )
 import nltk
 from nltk.data import load as nltk_load
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)

 LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -81,8 +84,9 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str]
     elif language_code in DEFAULT_ALIGN_MODELS_HF:
         model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
     else:
-        print(f"There is no default alignment model set for this language ({language_code}).\
-            Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
+        logger.error(f"No default alignment model for language: {language_code}. "
+                     f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
+                     f"then pass the model name via --align_model [MODEL_NAME]")
         raise ValueError(f"No default align-model for language: {language_code}")

     if model_name in torchaudio.pipelines.__all__:
@@ -223,12 +227,12 @@ def align(
         # check we can align
         if len(segment_data[sdx]["clean_char"]) == 0:
-            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
+            logger.warning(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original')
             aligned_segments.append(aligned_seg)
             continue

         if t1 >= MAX_DURATION:
-            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
+            logger.warning(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping')
             aligned_segments.append(aligned_seg)
             continue
@@ -270,7 +274,7 @@ def align(
             path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

         if path is None:
-            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
+            logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original')
             aligned_segments.append(aligned_seg)
             continue
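
Because these messages now go through a named logger instead of print, callers can tune them per module with the standard logging API. A small sketch, assuming this file is whisperx/alignment.py so its logger is named whisperx.alignment:

    import logging

    # Silence only the per-segment alignment warnings while keeping
    # the rest of WhisperX at the configured level.
    logging.getLogger("whisperx.alignment").setLevel(logging.ERROR)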

whisperx/asr.py

@@ -14,6 +14,9 @@ from transformers.pipelines.pt_utils import PipelineIterator
 from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
 from whisperx.schema import SingleSegment, TranscriptionResult
 from whisperx.vads import Vad, Silero, Pyannote
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)


 def find_numeral_symbol_tokens(tokenizer):
@@ -247,7 +250,7 @@ class FasterWhisperPipeline(Pipeline):
         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
             numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
-            print(f"Suppressing numeral and symbol tokens")
+            logger.info("Suppressing numeral and symbol tokens")
             new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
             new_suppressed_tokens = list(set(new_suppressed_tokens))
             self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
@@ -285,7 +288,7 @@ class FasterWhisperPipeline(Pipeline):
     def detect_language(self, audio: np.ndarray) -> str:
         if audio.shape[0] < N_SAMPLES:
-            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
+            logger.warning("Audio is shorter than 30s, language detection may be inaccurate")
         model_n_mels = self.model.feat_kwargs.get("feature_size")
         segment = log_mel_spectrogram(audio[: N_SAMPLES],
                                       n_mels=model_n_mels if model_n_mels is not None else 80,
@@ -294,7 +297,7 @@ class FasterWhisperPipeline(Pipeline):
         results = self.model.model.detect_language(encoder_output)
         language_token, language_probability = results[0][0]
         language = language_token[2:-2]
-        print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
+        logger.info(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio")
         return language
@@ -344,7 +347,7 @@ def load_model(
     if language is not None:
         tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
-        print("No language specified, language will be first be detected for each audio file (increases inference time).")
+        logger.info("No language specified, language will be detected for each audio file (increases inference time)")
         tokenizer = None

     default_asr_options = {
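
A sketch of where the new ASR messages fire, using the package-level API shown earlier (the model size, device, compute_type, and file name are illustrative, not prescribed by this commit):

    import whisperx

    whisperx.setup_logging(level="info")
    audio = whisperx.load_audio("audio.wav")  # placeholder path
    # With no language given, load_model logs the "No language specified" info
    # message, and detect_language later logs the detected language/probability.
    model = whisperx.load_model("small", device="cpu", compute_type="int8")
    result = model.transcribe(audio, batch_size=8)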

whisperx/diarize.py

@@ -6,6 +6,9 @@ import torch
 from whisperx.audio import load_audio, SAMPLE_RATE
 from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)


 class DiarizationPipeline:
@@ -18,6 +21,7 @@ class DiarizationPipeline:
         if isinstance(device, str):
             device = torch.device(device)
         model_config = model_name or "pyannote/speaker-diarization-3.1"
+        logger.info(f"Loading diarization model: {model_config}")
         self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)

     def __call__(
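
With the added info line, constructing the pipeline reports which checkpoint it loads. A minimal sketch (the token value is a placeholder):

    from whisperx.diarize import DiarizationPipeline
    from whisperx.log_utils import setup_logging

    setup_logging(level="info")
    # Logs: "Loading diarization model: pyannote/speaker-diarization-3.1"
    pipeline = DiarizationPipeline(use_auth_token="hf_...", device="cpu")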

whisperx/log_utils.py (new file, 67 lines)

@@ -0,0 +1,67 @@
import logging
import sys
from typing import Optional

_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


def setup_logging(
    level: str = "info",
    log_file: Optional[str] = None,
) -> None:
    """
    Configure logging for WhisperX.

    Args:
        level: Logging level (debug, info, warning, error, critical). Default: info
        log_file: Optional path to log file. If None, logs only to console.
    """
    logger = logging.getLogger("whisperx")
    logger.handlers.clear()

    try:
        log_level = getattr(logging, level.upper())
    except AttributeError:
        log_level = logging.WARNING
    logger.setLevel(log_level)

    formatter = logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_file:
        try:
            file_handler = logging.FileHandler(log_file)
            file_handler.setLevel(log_level)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        except OSError as e:
            logger.warning(f"Failed to create log file '{log_file}': {e}")
            logger.warning("Continuing with console logging only")

    # Don't propagate to the root logger to avoid duplicate messages
    logger.propagate = False


def get_logger(name: str) -> logging.Logger:
    """
    Get a logger instance for the given module.

    Args:
        name: Logger name (typically __name__ from calling module)

    Returns:
        Logger instance configured with WhisperX settings
    """
    whisperx_logger = logging.getLogger("whisperx")
    if not whisperx_logger.handlers:
        setup_logging()
    logger_name = "whisperx" if name == "__main__" else name
    return logging.getLogger(logger_name)
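
Taken together: one parent logger named "whisperx" owns the handlers, and propagate=False keeps the root logger from duplicating its records. A usage sketch (the file path is a placeholder):

    from whisperx.log_utils import setup_logging, get_logger

    setup_logging(level="debug", log_file="run.log")
    logger = get_logger(__name__)  # "__main__" maps to "whisperx" in a script
    logger.debug("visible on stdout and, if the file could be opened, in run.log")
    # Emits, per _LOG_FORMAT: 2025-10-10 08:41:06 - whisperx - DEBUG - visible on stdout ...

Note that get_logger uses names as-is, so only loggers named whisperx or whisperx.* route through these handlers; any other name propagates to the root logger instead.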

whisperx/transcribe.py

@@ -12,6 +12,9 @@ from whisperx.audio import load_audio
 from whisperx.diarize import DiarizationPipeline, assign_word_speakers
 from whisperx.schema import AlignedTranscriptionResult, TranscriptionResult
 from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)


 def transcribe_task(args: dict, parser: argparse.ArgumentParser):
@@ -142,7 +145,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)

         # >> VAD & ASR
-        print(">>Performing transcription...")
+        logger.info("Performing transcription...")
         result: TranscriptionResult = model.transcribe(
             audio,
             batch_size=batch_size,
@@ -175,13 +178,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
         if align_model is not None and len(result["segments"]) > 0:
             if result.get("language", "en") != align_metadata["language"]:
                 # load new language
-                print(
+                logger.info(
                     f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
                 )
                 align_model, align_metadata = load_align_model(
                     result["language"], device
                 )
-            print(">>Performing alignment...")
+            logger.info("Performing alignment...")
             result: AlignedTranscriptionResult = align(
                 result["segments"],
                 align_model,
@@ -203,12 +206,12 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
     # >> Diarize
     if diarize:
         if hf_token is None:
-            print(
-                "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
+            logger.warning(
+                "No --hf_token provided, needs to be saved in environment variable, otherwise will throw error loading diarization model"
             )
         tmp_results = results
-        print(">>Performing diarization...")
-        print(">>Using model:", diarize_model_name)
+        logger.info("Performing diarization...")
+        logger.info(f"Using model: {diarize_model_name}")
         results = []
        diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
        for result, input_audio_path in tmp_results:
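
End to end, a run with the default verbose setting now reports each stage through the logger. Illustrative output, assuming this file is whisperx/transcribe.py (logger name whisperx.transcribe); timestamps, the audio file name, the token, and the default diarization model name are illustrative:

    $ whisperx audio.wav --diarize --hf_token hf_...
    2025-10-10 08:41:06 - whisperx.transcribe - INFO - Performing transcription...
    2025-10-10 08:41:42 - whisperx.transcribe - INFO - Performing alignment...
    2025-10-10 08:41:55 - whisperx.transcribe - INFO - Performing diarization...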

whisperx/vads/pyannote.py

@@ -13,6 +13,9 @@ from pyannote.core import Segment
 from whisperx.diarize import Segment as SegmentX
 from whisperx.vads.vad import Vad
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)


 def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
@@ -232,7 +235,7 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
 class Pyannote(Vad):

     def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
-        print(">>Performing voice activity detection using Pyannote...")
+        logger.info("Performing voice activity detection using Pyannote...")
         super().__init__(kwargs['vad_onset'])
         self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
@@ -257,7 +260,7 @@ class Pyannote(Vad):
                 segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))

         if len(segments_list) == 0:
-            print("No active speech found in audio")
+            logger.warning("No active speech found in audio")
             return []
         assert segments_list, "segments_list is empty."
         return Vad.merge_chunks(segments_list, chunk_size, onset, offset)

whisperx/vads/silero.py

@@ -8,6 +8,9 @@ import torch
 from whisperx.diarize import Segment as SegmentX
 from whisperx.vads.vad import Vad
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)

 AudioFile = Union[Text, Path, IOBase, Mapping]
@@ -15,7 +18,7 @@ AudioFile = Union[Text, Path, IOBase, Mapping]
 class Silero(Vad):
     # check again default values
     def __init__(self, **kwargs):
-        print(">>Performing voice activity detection using Silero...")
+        logger.info("Performing voice activity detection using Silero...")
         super().__init__(kwargs['vad_onset'])

         self.vad_onset = kwargs['vad_onset']
@@ -60,7 +63,7 @@ class Silero(Vad):
     ):
         assert chunk_size > 0
         if len(segments_list) == 0:
-            print("No active speech found in audio")
+            logger.warning("No active speech found in audio")
             return []
         assert segments_list, "segments_list is empty."
         return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
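
Both VAD backends now announce themselves through the logger when constructed. A direct-construction sketch; the onset/offset values mirror the load_vad_model defaults shown in the Pyannote section and are illustrative:

    from whisperx.vads import Silero
    from whisperx.log_utils import setup_logging

    setup_logging(level="info")
    # Logs: "Performing voice activity detection using Silero..."
    vad = Silero(vad_onset=0.500, vad_offset=0.363)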