diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..721cb1b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+settings.toml
+audio.mp3
+uploads/*
+models/*
+venv
\ No newline at end of file
diff --git a/README.md b/README.md
index 90c4c36..5809929 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,22 @@
-# local_transcription
\ No newline at end of file
+# local_transcription
+
+## Settings
+
+Copy `settings.toml.example` to `settings.toml` and set the IP, port, and Hugging Face API key. The Hugging Face API key is only used to fetch the models; everything else runs locally.
+
+```toml
+[host]
+ip="192.168.0.99"
+port=5063
+
+[settings]
+upload_folder="uploads"
+model_folder="models"
+hf_token="HUGGINGFACE API KEY"
+```
+
+## Usage
+
+`python host.py` on the server
+
+`python client_upload.py audio.mp3` on the client
\ No newline at end of file
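+
+The host exposes a single `POST /transcribe` endpoint that expects a multipart upload in a `file` field (this is what `client_upload.py` does), so any HTTP client should work too, for example:
+
+```sh
+curl -F "file=@audio.mp3" http://192.168.0.99:5063/transcribe
+```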
diff --git a/client_upload.py b/client_upload.py
new file mode 100644
index 0000000..5dfbf92
--- /dev/null
+++ b/client_upload.py
@@ -0,0 +1,29 @@
+import requests
+import sys
+import toml
+
+settings = toml.load('settings.toml')
+
+def upload_audio_for_transcription(api_url, file_path):
+    """
+    Uploads an audio file to the transcription API and returns the response.
+
+    Parameters:
+    - api_url: The URL of the transcription API endpoint.
+    - file_path: The path to the audio file to be uploaded.
+
+    Returns:
+    - A dictionary with the transcription result or error message.
+    """
+    with open(file_path, 'rb') as f:
+        response = requests.post(api_url, files={'file': f})
+    if response.status_code == 200:
+        return response.json()
+    else:
+        return {"error": f"Failed to transcribe audio. Status code: {response.status_code}, Message: {response.text}"}
+
+# Example usage
+api_url = f"http://{settings['host']['ip']}:{settings['host']['port']}/transcribe"
+file_path = sys.argv[1]
+result = upload_audio_for_transcription(api_url, file_path)
+print(result)
diff --git a/host.py b/host.py
new file mode 100644
index 0000000..9b611f2
--- /dev/null
+++ b/host.py
@@ -0,0 +1,66 @@
+from flask import Flask, request, jsonify
+import whisperx
+import gc
+import torch
+from werkzeug.utils import secure_filename
+import os
+import toml
+
+settings = toml.load('settings.toml')
+
+app = Flask(__name__)
+ALLOWED_EXTENSIONS = {'mp3', 'wav', 'm4a'}
+app.config['UPLOAD_FOLDER'] = settings['settings']['upload_folder']
+
+def allowed_file(filename):
+    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+device = "cuda"
+batch_size = 16
+language = "en"
+compute_type = "float16"
+model_dir = settings['settings']['model_folder']
+HF_TOKEN = settings['settings']['hf_token']
+
+@app.route('/transcribe', methods=['POST'])
+def transcribe_audio():
+    if 'file' not in request.files:
+        return jsonify(error="No file part"), 400
+    file = request.files['file']
+    if file.filename == '':
+        return jsonify(error="No selected file"), 400
+    if file and allowed_file(file.filename):
+        if file.filename is not None:
+            filename = secure_filename(file.filename)
+            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+            file.save(file_path)
+
+            # Transcribe with Whisper
+            model = whisperx.load_model("large-v2", device, language=language, compute_type=compute_type, download_root=model_dir)
+            audio = whisperx.load_audio(file_path)
+            result = model.transcribe(audio, batch_size=batch_size)
+            print(result["segments"]) # before alignment
+
+            # Align Whisper output
+            model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+            result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+            print(result["segments"]) # after alignment
+
+            # Assign speaker labels
+            diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
+            diarize_segments = diarize_model(audio)
+            result = whisperx.assign_word_speakers(diarize_segments, result)
+
+            # Clean-up: release the models before clearing the CUDA cache
+            del model, model_a, diarize_model
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            os.remove(file_path) # Remove the uploaded file after processing
+
+            return jsonify(result["segments"])
+
+    return jsonify(error="Invalid file type"), 400
+
+if __name__ == '__main__':
+    app.run(debug=True, host=settings['host']['ip'], port=settings['host']['port'])
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a82019d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+torch
+torchvision
+torchaudio
+git+https://github.com/m-bain/whisperx.git
+flask
\ No newline at end of file
diff --git a/settings.toml.example b/settings.toml.example
new file mode 100644
index 0000000..4bf8dc8
--- /dev/null
+++ b/settings.toml.example
@@ -0,0 +1,8 @@
+[host]
+ip="192.168.0.99"
+port=5063
+
+[settings]
+upload_folder="uploads"
+model_folder="models"
+hf_token="HUGGINGFACE API KEY"
\ No newline at end of file
diff --git a/whisperx_test.py b/whisperx_test.py
new file mode 100644
index 0000000..831488b
--- /dev/null
+++ b/whisperx_test.py
@@ -0,0 +1,43 @@
+import whisperx
+import gc
+
+device = "cuda"
+audio_file = "audio.mp3"
+batch_size = 16
+language = "en"
+compute_type = "float16"
+
+with open("hf_token.txt", "r") as f:
+    HF_TOKEN = f.read()
+
+# 1. Transcribe with original whisper (batched)
+# save model to local path (optional)
+model_dir = "./models/"
+model = whisperx.load_model("large-v2", device, language="en", compute_type=compute_type, download_root=model_dir)
+
+audio = whisperx.load_audio(audio_file)
+result = model.transcribe(audio, batch_size=batch_size)
+print(result["segments"]) # before alignment
+
+# delete model if low on GPU resources
+# import gc; gc.collect(); torch.cuda.empty_cache(); del model
+
+# 2. Align whisper output
+model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+
+print(result["segments"]) # after alignment
+
+# delete model if low on GPU resources
+# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
+
+# 3. Assign speaker labels
+diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
+
+# add min/max number of speakers if known
+diarize_segments = diarize_model(audio)
+# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
+
+result = whisperx.assign_word_speakers(diarize_segments, result)
+print(diarize_segments)
+print(result["segments"]) # segments are now assigned speaker IDs
\ No newline at end of file
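+
+# Optional: print a speaker-labelled transcript. A minimal sketch, assuming each
+# aligned segment dict has a "text" key and, after assign_word_speakers, usually a
+# "speaker" key (falls back to "UNKNOWN" when diarization could not assign one).
+for segment in result["segments"]:
+    print(f"{segment.get('speaker', 'UNKNOWN')}: {segment['text'].strip()}")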