AtHeartEngineer committed 2024-03-11 07:53:21 -04:00
commit 7395eb1dc4 (parent df3bd00a93)
7 changed files with 178 additions and 1 deletion

.gitignore (vendored, new file)

@@ -0,0 +1,5 @@
settings.toml
audio.mp3
uploads/*
models/*
venv

README.md (modified)

@@ -1 +1,22 @@
# local_transcription
## Settings
Copy `settings.toml.example` to `settings.toml`, then set the IP, port, and HuggingFace API key. The HuggingFace API key is only used to fetch the models; everything else runs locally.
```toml
[host]
ip="192.168.0.99"
port=5063

[settings]
upload_folder="uploads"
model_folder="models"
hf_token="HUGGINGFACE API KEY"
```
## Usage
Run `python host.py` on the server, then `python client_upload.py audio.mp3` on the client.
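
On success the client prints the decoded JSON reply: a list of aligned, speaker-labelled segments. Illustratively (the field names follow whisperx's segment output; the values here are invented):

```python
# Example shape of a successful response (invented values):
[
    {"start": 0.52, "end": 3.10, "text": " Hello, thanks for calling in today.", "speaker": "SPEAKER_00"},
    {"start": 3.42, "end": 5.87, "text": " Happy to be here.", "speaker": "SPEAKER_01"},
]
```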

client_upload.py (new file)

@@ -0,0 +1,29 @@
import requests
import sys
import toml

settings = toml.load('settings.toml')


def upload_audio_for_transcription(api_url, file_path):
    """
    Uploads an audio file to the transcription API and returns the response.

    Parameters:
    - api_url: The URL of the transcription API endpoint.
    - file_path: The path to the audio file to be uploaded.

    Returns:
    - A dictionary with the transcription result or error message.
    """
    # Open the file in a context manager so the handle is closed after the upload.
    with open(file_path, 'rb') as f:
        response = requests.post(api_url, files={'file': f})
    if response.status_code == 200:
        return response.json()
    return {"error": f"Failed to transcribe audio. Status code: {response.status_code}, Message: {response.text}"}


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit("Usage: python client_upload.py <audio_file>")
    api_url = f"http://{settings['host']['ip']}:{settings['host']['port']}/transcribe"
    result = upload_audio_for_transcription(api_url, sys.argv[1])
    print(result)
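
A small helper, not part of the commit, sketching how the returned segments could be rendered as a readable transcript (it assumes the `start`/`end`/`text`/`speaker` fields shown in the README example):

```python
# Hypothetical post-processing helper (not in the commit): turn the server's
# segment list into a speaker-labelled transcript. Assumes each segment has
# "start", "end", "text", and, after diarization, "speaker".
def format_transcript(segments):
    lines = []
    for seg in segments:
        speaker = seg.get("speaker", "UNKNOWN")
        lines.append(f"[{seg['start']:.2f}-{seg['end']:.2f}] {speaker}: {seg['text'].strip()}")
    return "\n".join(lines)
```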

host.py (new file)

@@ -0,0 +1,66 @@
from flask import Flask, request, jsonify
import whisperx
import gc
import torch
from werkzeug.utils import secure_filename
import os
import toml

settings = toml.load('settings.toml')

app = Flask(__name__)

ALLOWED_EXTENSIONS = {'mp3', 'wav', 'm4a'}
app.config['UPLOAD_FOLDER'] = settings['settings']['upload_folder']
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)  # the folder is gitignored, so create it if missing


def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


device = "cuda"
batch_size = 16
language = "en"
compute_type = "float16"
model_dir = settings['settings']['model_folder']
HF_TOKEN = settings['settings']['hf_token']


@app.route('/transcribe', methods=['POST'])
def transcribe_audio():
    if 'file' not in request.files:
        return jsonify(error="No file part"), 400
    file = request.files['file']
    if not file.filename:
        return jsonify(error="No selected file"), 400
    if not allowed_file(file.filename):
        return jsonify(error="Invalid file type"), 400

    filename = secure_filename(file.filename)
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(file_path)

    # Transcribe with Whisper
    model = whisperx.load_model("large-v2", device, language=language, compute_type=compute_type, download_root=model_dir)
    audio = whisperx.load_audio(file_path)
    result = model.transcribe(audio, batch_size=batch_size)
    print(result["segments"])  # before alignment

    # Align Whisper output
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
    print(result["segments"])  # after alignment

    # Assign speaker labels
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
    diarize_segments = diarize_model(audio)
    result = whisperx.assign_word_speakers(diarize_segments, result)

    # Clean-up: drop the model references first, then collect and empty the CUDA cache
    del model, model_a, diarize_model
    gc.collect()
    torch.cuda.empty_cache()
    os.remove(file_path)  # remove the uploaded file after processing

    return jsonify(result["segments"])


if __name__ == '__main__':
    app.run(debug=True, host=settings['host']['ip'], port=settings['host']['port'])
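
A design note: `transcribe_audio` reloads all three models on every request, which frees GPU memory between calls but adds significant per-request latency. A minimal sketch of the opposite trade-off, caching the ASR model at startup (an assumption, not what the commit does; it reuses the `device`, `language`, `compute_type`, `model_dir`, and `batch_size` globals from `host.py`):

```python
# Sketch (assumption, not in the commit): load the ASR model once at startup so
# repeated requests skip whisperx.load_model. Trades idle GPU memory for speed.
import whisperx

asr_model = whisperx.load_model(
    "large-v2", device, language=language,
    compute_type=compute_type, download_root=model_dir,
)

def transcribe_file(file_path):
    # Reuse the cached model; the align/diarize models could be cached the same way.
    audio = whisperx.load_audio(file_path)
    return asr_model.transcribe(audio, batch_size=batch_size)
```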

requirements.txt (new file)

@@ -0,0 +1,7 @@
torch
torchvision
torchaudio
git+https://github.com/m-bain/whisperx.git
flask
requests
toml

settings.toml.example (new file)

@@ -0,0 +1,8 @@
[host]
ip="192.168.0.99"
port=5063

[settings]
upload_folder="uploads"
model_folder="models"
hf_token="HUGGINGFACE API KEY"

whisperx_test.py (new file)

@@ -0,0 +1,43 @@
import whisperx
import gc

device = "cuda"
audio_file = "audio.mp3"
batch_size = 16
language = "en"
compute_type = "float16"

# Read the HuggingFace token used by the diarization pipeline
with open("hf_token.txt", "r") as f:
    HF_TOKEN = f.read().strip()

# 1. Transcribe with original whisper (batched)
# save model to local path (optional)
model_dir = "./models/"
model = whisperx.load_model("large-v2", device, language=language, compute_type=compute_type, download_root=model_dir)

audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"])  # before alignment

# delete model if low on GPU resources
# import gc; gc.collect(); torch.cuda.empty_cache(); del model

# 2. Align whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
print(result["segments"])  # after alignment

# delete model if low on GPU resources
# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a

# 3. Assign speaker labels
diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
# add min/max number of speakers if known:
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_segments = diarize_model(audio)

result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"])  # segments are now assigned speaker IDs
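
Note that the test script reads the HuggingFace token from `hf_token.txt`, while `host.py` reads it from `settings.toml`. To keep a single source of truth, a one-line swap would work (assuming the same `settings.toml` layout shown above):

```python
# Alternative (assumption, not in the commit): read the token from settings.toml
import toml
HF_TOKEN = toml.load('settings.toml')['settings']['hf_token']
```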