Mirror of https://github.com/AtHeartEngineering/local_transcription.git, synced 2026-01-08 20:07:59 -05:00.

Commit: init
.gitignore (vendored, new file)
@@ -0,0 +1,5 @@
settings.toml
audio.mp3
uploads/*
models/*
venv
README.md
@@ -1 +1,22 @@
# local_transcription

## Settings

Make sure to set the IP, port, and HuggingFace API key in the `settings.toml` file. The HuggingFace API key is only used to fetch the models; everything else runs locally.

```toml
[host]
ip="192.168.0.99"
port=5063

[settings]
upload_folder="uploads"
model_folder="models"
hf_token="HUGGINGFACE API KEY"
```

## Usage

`python host.py` on the server

`python client_upload.py audio.mp3` on the client
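
The server replies with the transcribed segments as JSON. A minimal sketch of the response shape (illustrative values; the exact fields depend on the WhisperX version):

```json
[
  {
    "start": 0.03,
    "end": 2.51,
    "text": " Hello there.",
    "speaker": "SPEAKER_00"
  }
]
```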
client_upload.py (new file)
@@ -0,0 +1,29 @@
import requests
import sys
import toml

settings = toml.load('settings.toml')

def upload_audio_for_transcription(api_url, file_path):
    """
    Uploads an audio file to the transcription API and returns the response.

    Parameters:
    - api_url: The URL of the transcription API endpoint.
    - file_path: The path to the audio file to be uploaded.

    Returns:
    - A dictionary with the transcription result or an error message.
    """
    # Use a context manager so the file handle is closed after the upload.
    with open(file_path, 'rb') as f:
        response = requests.post(api_url, files={'file': f})
    if response.status_code == 200:
        return response.json()
    else:
        return {"error": f"Failed to transcribe audio. Status code: {response.status_code}, Message: {response.text}"}

# Example usage
api_url = f"http://{settings['host']['ip']}:{settings['host']['port']}/transcribe"
file_path = sys.argv[1]
result = upload_audio_for_transcription(api_url, file_path)
print(result)
host.py (new file)
@@ -0,0 +1,66 @@
from flask import Flask, request, jsonify
import whisperx
import gc
import torch
from werkzeug.utils import secure_filename
import os
import toml

settings = toml.load('settings.toml')

app = Flask(__name__)
ALLOWED_EXTENSIONS = {'mp3', 'wav', 'm4a'}
app.config['UPLOAD_FOLDER'] = settings['settings']['upload_folder']

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

device = "cuda"
batch_size = 16
language = "en"
compute_type = "float16"
model_dir = settings['settings']['model_folder']
HF_TOKEN = settings['settings']['hf_token']

@app.route('/transcribe', methods=['POST'])
def transcribe_audio():
    if 'file' not in request.files:
        return jsonify(error="No file part"), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify(error="No selected file"), 400
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)

        # Transcribe with Whisper (the model is loaded on every request)
        model = whisperx.load_model("large-v2", device, language=language, compute_type=compute_type, download_root=model_dir)
        audio = whisperx.load_audio(file_path)
        result = model.transcribe(audio, batch_size=batch_size)
        print(result["segments"])  # before alignment

        # Align Whisper output
        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
        result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
        print(result["segments"])  # after alignment

        # Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)

        # Clean up: drop the model references before freeing GPU memory
        del model, model_a, diarize_model
        gc.collect()
        torch.cuda.empty_cache()

        os.remove(file_path)  # Remove the uploaded file after processing

        return jsonify(result["segments"])

    return jsonify(error="Invalid file type"), 400

if __name__ == '__main__':
    app.run(debug=True, host=settings['host']['ip'], port=settings['host']['port'])
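host.py loads, aligns, and diarizes inside the request handler, so every upload pays the full model-load cost in exchange for freeing GPU memory between requests. A minimal sketch of the opposite trade-off (a hypothetical refactor, not part of this commit): load the model once at startup and reuse it across requests, keeping GPU memory allocated for the lifetime of the server.

```python
# Hypothetical sketch, not part of this commit: cache the WhisperX model at
# startup instead of reloading it inside every request. Validation, alignment,
# and diarization are omitted for brevity.
import toml
import whisperx
from flask import Flask, jsonify, request

settings = toml.load('settings.toml')
app = Flask(__name__)

device = "cuda"
# Loaded once per process and shared by all requests.
MODEL = whisperx.load_model("large-v2", device, language="en",
                            compute_type="float16",
                            download_root=settings['settings']['model_folder'])

@app.route('/transcribe', methods=['POST'])
def transcribe_audio():
    file = request.files['file']
    file_path = f"{settings['settings']['upload_folder']}/{file.filename}"
    file.save(file_path)
    audio = whisperx.load_audio(file_path)
    result = MODEL.transcribe(audio, batch_size=16)  # reuses the cached model
    return jsonify(result["segments"])

if __name__ == '__main__':
    app.run(host=settings['host']['ip'], port=settings['host']['port'])
```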
requirements.txt (new file)
@@ -0,0 +1,5 @@
torch
torchvision
torchaudio
git+https://github.com/m-bain/whisperx.git
flask
settings.toml.example (new file)
@@ -0,0 +1,8 @@
[host]
ip="192.168.0.99"
port=5063

[settings]
upload_folder="uploads"
model_folder="models"
hf_token="HUGGINGFACE API KEY"
whisperx_test.py (new file)
@@ -0,0 +1,43 @@
import whisperx
import gc

device = "cuda"
audio_file = "audio.mp3"
batch_size = 16
language = "en"
compute_type = "float16"

with open("hf_token.txt", "r") as f:
    HF_TOKEN = f.read().strip()  # strip the trailing newline from the token file

# 1. Transcribe with original whisper (batched)
# save model to local path (optional)
model_dir = "./models/"
model = whisperx.load_model("large-v2", device, language="en", compute_type=compute_type, download_root=model_dir)

audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"])  # before alignment

# delete model if low on GPU resources
# import gc; gc.collect(); torch.cuda.empty_cache(); del model

# 2. Align whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

print(result["segments"])  # after alignment

# delete model if low on GPU resources
# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a

# 3. Assign speaker labels
diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)

# add min/max number of speakers if known
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"])  # segments are now assigned speaker IDs