# Mirror of https://github.com/AtHeartEngineering/local_transcription.git
# (synced 2026-01-09 15:37:59 -05:00)
import gc

import whisperx

# Run-time configuration for the transcription pipeline below.
device = "cuda"  # device used by every model in this script
audio_file = "audio.mp3"  # input audio file to transcribe
batch_size = 16  # batch size passed to model.transcribe()
language = "en"  # transcription language code
compute_type = "float16"  # precision used when loading the Whisper model

# Hugging Face access token, passed to the diarization pipeline as
# use_auth_token. strip() removes the trailing newline that text editors
# usually append to the file, so it is not sent as part of the token.
with open("hf_token.txt", "r", encoding="utf-8") as f:
    HF_TOKEN = f.read().strip()

# 1. Transcribe with original whisper (batched)
# save model to local path (optional): download_root points at model_dir
model_dir = "./models/"
model = whisperx.load_model(
    "large-v2",
    device,
    language=language,  # use the configured language rather than a second hard-coded "en"
    compute_type=compute_type,
    download_root=model_dir,
)

audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"])  # before alignment

# Delete the model if low on GPU resources. The last reference must be dropped
# *before* collecting; otherwise gc / the CUDA cache have nothing to release.
# (Also requires `import torch`.)
# del model; gc.collect(); torch.cuda.empty_cache()

# 2. Align whisper output
# Load an alignment model for the language reported by the transcription step,
# then align the transcribed segments (character-level alignments disabled).
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(
    result["segments"], model_a, metadata, audio, device, return_char_alignments=False
)

print(result["segments"])  # after alignment

# Delete the alignment model if low on GPU resources. Drop the reference
# *before* collecting so the memory can actually be released.
# (Also requires `import torch`.)
# del model_a; gc.collect(); torch.cuda.empty_cache()

# 3. Assign speaker labels
# Imported from its defining submodule, which works whether or not the
# installed whisperx re-exports DiarizationPipeline at package level.
from whisperx.diarize import DiarizationPipeline

diarize_model = DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)

# Set these if the number of speakers is known.
# NOTE(review): None is assumed to be the pipeline's default for both, so
# leaving them as None behaves like omitting them - confirm for your whisperx.
min_speakers = None
max_speakers = None
diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"])  # segments are now assigned speaker IDs