mirror of
https://github.com/DrewThomasson/ebook2audiobook.git
lib/classes/tts_engines/.template.py | 232 (new file)
@@ -0,0 +1,232 @@
import hashlib
import math
import os
import shutil
import subprocess
import tempfile
import threading
import uuid

import numpy as np
import regex as re
import soundfile as sf
import torch
import torchaudio

from huggingface_hub import hf_hub_download
from pathlib import Path
from pprint import pprint

from lib import *
from lib.classes.tts_engines.common.utils import unload_tts, append_sentence2vtt
from lib.classes.tts_engines.common.audio_filters import detect_gender, trim_audio, normalize_audio, is_audio_data_valid

#import logging
#logging.basicConfig(level=logging.DEBUG)

lock = threading.Lock()

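# Template for wiring a new TTS engine into ebook2audiobook. Every NEW_TTS.*
# reference below is a placeholder for the real engine's API. Globals such as
# tts_dir, loaded_tts, models, TTS_ENGINES, TTS_SML, default_vc_model and
# default_audio_proc_format are expected to be provided by `from lib import *`.
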
class Coqui:

    def __init__(self, session):
        try:
            self.session = session
            self.cache_dir = tts_dir
            self.speakers_path = None
            self.tts_key = f"{self.session['tts_engine']}-{self.session['fine_tuned']}"
            self.tts_vc_key = default_vc_model.rsplit('/', 1)[-1]
            self.is_bf16 = self.session['device'] == 'cuda' and torch.cuda.is_bf16_supported()
            self.npz_path = None
            self.npz_data = None
            self.sentences_total_time = 0.0
            self.sentence_idx = 1
            self.params = {TTS_ENGINES['NEW_TTS']: {}}
            self.params[self.session['tts_engine']]['samplerate'] = models[self.session['tts_engine']][self.session['fine_tuned']]['samplerate']
            self.vtt_path = os.path.join(self.session['process_dir'], os.path.splitext(self.session['final_name'])[0] + '.vtt')
            self.resampler_cache = {}
            self.audio_segments = []
            self._build()
        except Exception as e:
            error = f'__init__() error: {e}'
            print(error)

    def _build(self):
        try:
            tts = (loaded_tts.get(self.tts_key) or {}).get('engine', False)
            if not tts:
                if self.session['tts_engine'] == TTS_ENGINES['NEW_TTS']:
                    if self.session['custom_model'] is not None:
                        msg = f"{self.session['tts_engine']} custom model not implemented yet!"
                        print(msg)
                        return False
                    else:
                        model_path = models[self.session['tts_engine']][self.session['fine_tuned']]['repo']
                        tts = self._load_api(self.tts_key, model_path, self.session['device'])
            return (loaded_tts.get(self.tts_key) or {}).get('engine', False)
        except Exception as e:
            error = f'_build() error: {e}'
            print(error)
            return False

    def _load_api(self, key, model_path, device):
        global lock
        try:
            if key in loaded_tts:
                return loaded_tts[key]['engine']
            unload_tts(device, [self.tts_key, self.tts_vc_key])
            with lock:
                tts = NEW_TTS(model_path)  # placeholder: instantiate the real engine here
                if tts:
                    if device == 'cuda':
                        NEW_TTS.WITH_CUDA  # placeholder: move the model to GPU
                    else:
                        NEW_TTS.WITHOUT_CUDA  # placeholder: keep the model on CPU
                    loaded_tts[key] = {"engine": tts, "config": None}
                    msg = f'{model_path} Loaded!'
                    print(msg)
                    return tts
                else:
                    error = 'TTS engine could not be created!'
                    print(error)
        except Exception as e:
            error = f'_load_api() error: {e}'
            print(error)
        return False

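    # For orientation only: with a Coqui TTS-style engine, the placeholders in
    # _load_api() would map roughly as follows (an assumption, not part of the
    # template):
    #     tts = TTS(model_path)    # NEW_TTS(model_path)
    #     tts.to('cuda')           # NEW_TTS.WITH_CUDA
    #     tts.to('cpu')            # NEW_TTS.WITHOUT_CUDA
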
    def _load_checkpoint(self, **kwargs):
        global lock
        try:
            key = kwargs.get('key')
            if key in loaded_tts:
                return loaded_tts[key]['engine']
            tts_engine = kwargs.get('tts_engine')
            device = kwargs.get('device')
            unload_tts(device, [self.tts_key])
            with lock:
                checkpoint_dir = kwargs.get('checkpoint_dir')
                config = kwargs.get('config')  # engine-specific config object (placeholder)
                tts = NEW_TTS.LOAD_CHECKPOINT(  # placeholder: load the fine-tuned checkpoint
                    config,
                    checkpoint_dir=checkpoint_dir,
                    eval=True
                )
                if tts:
                    if device == 'cuda':
                        NEW_TTS.WITH_CUDA  # placeholder: move the model to GPU
                    else:
                        NEW_TTS.WITHOUT_CUDA  # placeholder: keep the model on CPU
                    loaded_tts[key] = {"engine": tts, "config": config}
                    msg = f'{tts_engine} Loaded!'
                    print(msg)
                    return tts
                else:
                    error = 'TTS engine could not be created!'
                    print(error)
        except Exception as e:
            error = f'_load_checkpoint() error: {e}'
            print(error)
        return False

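    # A sketch of the same flow for a Coqui XTTS-style checkpoint (illustrative
    # assumption; the template itself stays engine-agnostic):
    #     config = XttsConfig()
    #     config.load_json(os.path.join(checkpoint_dir, 'config.json'))
    #     tts = Xtts.init_from_config(config)
    #     tts.load_checkpoint(config, checkpoint_dir=checkpoint_dir, eval=True)
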
    def _tensor_type(self, audio_data):
        if isinstance(audio_data, torch.Tensor):
            return audio_data
        elif isinstance(audio_data, np.ndarray):
            return torch.from_numpy(audio_data).float()
        elif isinstance(audio_data, list):
            return torch.tensor(audio_data, dtype=torch.float32)
        else:
            raise TypeError(f"Unsupported type for audio_data: {type(audio_data)}")

    def _get_resampler(self, orig_sr, target_sr):
        key = (orig_sr, target_sr)
        if key not in self.resampler_cache:
            self.resampler_cache[key] = torchaudio.transforms.Resample(
                orig_freq=orig_sr, new_freq=target_sr
            )
        return self.resampler_cache[key]

    def _resample_wav(self, wav_path, expected_sr):
        waveform, orig_sr = torchaudio.load(wav_path)
        if orig_sr == expected_sr and waveform.size(0) == 1:
            return wav_path
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if orig_sr != expected_sr:
            resampler = self._get_resampler(orig_sr, expected_sr)
            waveform = resampler(waveform)
        wav_tensor = waveform.squeeze(0)
        wav_numpy = wav_tensor.cpu().numpy()
        tmp_fh = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp_path = tmp_fh.name
        tmp_fh.close()
        sf.write(tmp_path, wav_numpy, expected_sr, subtype="PCM_16")
        return tmp_path

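    # Typical call site (hypothetical): align a reference voice file with the
    # engine's expected sample rate before inference:
    #     ref_wav = self._resample_wav(settings['voice_path'], settings['samplerate'])
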
    def convert(self, sentence_number, sentence):
        global xtts_builtin_speakers_list
        try:
            speaker = None
            audio_data = False
            trim_audio_buffer = 0.004
            settings = self.params[self.session['tts_engine']]
            final_sentence_file = os.path.join(self.session['chapters_dir_sentences'], f'{sentence_number}.{default_audio_proc_format}')
            sentence = sentence.strip()
            settings['voice_path'] = (
                self.session['voice'] if self.session['voice'] is not None
                else os.path.join(self.session['custom_model_dir'], self.session['tts_engine'], self.session['custom_model'], 'ref.wav') if self.session['custom_model'] is not None
                else models[self.session['tts_engine']][self.session['fine_tuned']]['voice']
            )
            if settings['voice_path'] is not None:
                speaker = re.sub(r'\.wav$', '', os.path.basename(settings['voice_path']))
            tts = (loaded_tts.get(self.tts_key) or {}).get('engine', False)
            if tts:
                if sentence and sentence[-1].isalnum():
                    sentence = f'{sentence} —'
                if sentence == TTS_SML['break']:
                    break_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(0.3, 0.6) * 100) / 100)))  # 0.3 to 0.6 seconds
                    self.audio_segments.append(break_tensor.clone())
                    return True
                elif sentence == TTS_SML['pause']:
                    pause_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(1.0, 1.8) * 100) / 100)))  # 1.0 to 1.8 seconds
                    self.audio_segments.append(pause_tensor.clone())
                    return True
                else:
                    if self.session['tts_engine'] == TTS_ENGINES['NEW_TTS']:
                        audio_sentence = NEW_TTS.CONVERT()  # placeholder: must return a torch.Tensor, np.ndarray, list or tuple
                        if is_audio_data_valid(audio_sentence):
                            sourceTensor = self._tensor_type(audio_sentence)
                            audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
                            if sentence[-1].isalnum() or sentence[-1] == '—':
                                audio_tensor = trim_audio(audio_tensor.squeeze(), settings['samplerate'], 0.003, trim_audio_buffer).unsqueeze(0)
                            self.audio_segments.append(audio_tensor)
                            if not re.search(r'\w$', sentence, flags=re.UNICODE):
                                break_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(0.3, 0.6) * 100) / 100)))
                                self.audio_segments.append(break_tensor.clone())
                            if self.audio_segments:
                                audio_tensor = torch.cat(self.audio_segments, dim=-1)
                                start_time = self.sentences_total_time
                                duration = audio_tensor.shape[-1] / settings['samplerate']
                                end_time = start_time + duration
                                self.sentences_total_time = end_time
                                sentence_obj = {
                                    "start": start_time,
                                    "end": end_time,
                                    "text": sentence,
                                    "resume_check": self.sentence_idx
                                }
                                self.sentence_idx = append_sentence2vtt(sentence_obj, self.vtt_path)
                            if self.sentence_idx:
                                torchaudio.save(final_sentence_file, audio_tensor, settings['samplerate'], format=default_audio_proc_format)
                                del audio_tensor
                            self.audio_segments = []
                            if os.path.exists(final_sentence_file):
                                return True
                            else:
                                error = f"Cannot create {final_sentence_file}"
                                print(error)
            else:
                error = f"convert() error: {self.session['tts_engine']} is None"
                print(error)
        except Exception as e:
            error = f'Coqui.convert(): {e}'
            print(error)
            raise ValueError(e)
        return False

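A minimal usage sketch for an engine built from this template (assumes a
populated `session` dict as constructed elsewhere in ebook2audiobook and a
`sentences` list; both names are illustrative):

    engine = Coqui(session)
    for num, text in enumerate(sentences, start=1):
        if not engine.convert(num, text):
            break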