mirror of
https://github.com/AtHeartEngineering/local_transcription.git
synced 2026-01-07 22:54:01 -05:00
66 lines
2.3 KiB
Python
66 lines
2.3 KiB
Python
from flask import Flask, request, jsonify
|
|
import whisperx
|
|
import gc
|
|
import torch
|
|
from werkzeug.utils import secure_filename
|
|
import os
|
|
import toml
|
|
from format import format_srt
|
|
|
|
# Upload extensions accepted by allowed_file(): lower-case, without the dot.
ALLOWED_EXTENSIONS = {'mp3', 'wav', 'm4a'}

# Parse the service configuration. The path is relative, so it is resolved
# against the process's current working directory.
settings = toml.load('settings.toml')

# Flask application object; uploaded audio is written to the folder named by
# the `upload_folder` key of the [settings] table.
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = settings['settings']['upload_folder']
|
|
|
|
def allowed_file(filename):
    """Report whether *filename* ends in an accepted upload extension.

    The text after the last dot is lower-cased and looked up in
    ALLOWED_EXTENSIONS. A name that contains no dot at all is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[-1]
    return extension.lower() in ALLOWED_EXTENSIONS
|
|
|
|
# Inference device. Use the GPU when PyTorch can see one; otherwise fall back
# to the CPU so the service can still start on hosts without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Forwarded as `batch_size` to model.transcribe().
batch_size = 16

# Language code forwarded to whisperx.load_model().
language = "en"

# Numeric precision forwarded to whisperx.load_model(). float16 is kept for
# GPU runs; the CPU fallback uses int8 because float16 inference is generally
# not available on CPU backends.
compute_type = "float16" if device == "cuda" else "int8"

# Directory handed to whisperx.load_model() as `download_root`.
model_dir = settings['settings']['model_folder']

# Token handed to whisperx.DiarizationPipeline() as `use_auth_token`.
HF_TOKEN = settings['settings']['hf_token']
|
|
|
|
@app.route('/transcribe', methods=['POST'])
def transcribe_audio():
    """Handle ``POST /transcribe``: transcribe an uploaded audio file.

    The request must carry the audio in a multipart form field named
    ``file``. The upload is saved under ``app.config['UPLOAD_FOLDER']``,
    run through the WhisperX pipeline, and deleted again before the
    response is sent.

    Returns:
        On success, the JSON-encoded value produced by ``format_srt`` for
        the resulting segments. On a rejected request, a JSON object with
        an ``error`` message and HTTP status 400:

        * ``"No file part"`` - the form has no ``file`` field.
        * ``"No selected file"`` - the field is present but has an empty
          filename.
        * ``"Invalid file type"`` - the filename does not pass
          ``allowed_file``.

    Raises:
        Any exception raised while saving or processing the audio is
        propagated unchanged; the temporary upload is still removed.
    """
    if 'file' not in request.files:
        return jsonify(error="No file part"), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify(error="No selected file"), 400
    if not file.filename or not allowed_file(file.filename):
        return jsonify(error="Invalid file type"), 400

    # Strip path components and unsafe characters from the client-supplied name
    # before using it to build a path on disk.
    filename = secure_filename(file.filename)
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)

    try:
        file.save(file_path)
        result = _run_whisperx_pipeline(file_path)
    finally:
        # Always delete the upload, including when saving or processing fails,
        # so failed requests do not accumulate files in the upload folder.
        if os.path.isfile(file_path):
            os.remove(file_path)

    srt = format_srt(result['segments'])
    return jsonify(srt)


def _run_whisperx_pipeline(file_path):
    """Transcribe, align and speaker-label the audio file at *file_path*.

    Returns the value of ``whisperx.assign_word_speakers``; the caller reads
    its ``'segments'`` entry. The models are loaded on every call and released
    again before returning, whether or not a step fails.
    """
    # Pre-bind the names so the cleanup below is valid even if loading fails
    # part-way through.
    model = model_a = diarize_model = None
    try:
        # Transcribe with Whisper
        model = whisperx.load_model(
            "large-v2",
            device,
            language=language,
            compute_type=compute_type,
            download_root=model_dir,
        )
        audio = whisperx.load_audio(file_path)
        result = model.transcribe(audio, batch_size=batch_size)

        # Align Whisper output
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            device,
            return_char_alignments=False,
        )

        # Assign speaker labels
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=HF_TOKEN, device=device
        )
        diarize_segments = diarize_model(audio)
        return whisperx.assign_word_speakers(diarize_segments, result)
    finally:
        # Drop the model references *before* collecting: while the names are
        # still bound, neither the garbage collector nor the CUDA cache release
        # can reclaim the memory the models occupy.
        del model, model_a, diarize_model
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
if __name__ == '__main__':
    # Start Flask's built-in server on the address given in the [host] table
    # of settings.toml.
    host_settings = settings['host']

    # Debug mode turns on the Werkzeug interactive debugger, which lets anyone
    # who can reach the server execute arbitrary Python. It is therefore off
    # unless the [host] table explicitly sets `debug = true`.
    app.run(
        debug=bool(host_settings.get('debug', False)),
        host=host_settings['ip'],
        port=host_settings['port'],
    )
|