diff --git a/lib/classes/tts_engines/.template.py b/lib/classes/tts_engines/.template.py
new file mode 100644
index 00000000..49fe4fc0
--- /dev/null
+++ b/lib/classes/tts_engines/.template.py
@@ -0,0 +1,232 @@
+import hashlib
+import math
+import os
+import shutil
+import subprocess
+import tempfile
+import threading
+import uuid
+
+import numpy as np
+import regex as re
+import soundfile as sf
+import torch
+import torchaudio
+
+from huggingface_hub import hf_hub_download
+from pathlib import Path
+from pprint import pprint
+
+from lib import *
+from lib.classes.tts_engines.common.utils import unload_tts, append_sentence2vtt
+from lib.classes.tts_engines.common.audio_filters import detect_gender, trim_audio, normalize_audio, is_audio_data_valid
+
+#import logging
+#logging.basicConfig(level=logging.DEBUG)
+
+lock = threading.Lock()
+
+class Coqui:
+
+    def __init__(self, session):
+        try:
+            self.session = session
+            self.cache_dir = tts_dir
+            self.speakers_path = None
+            self.tts_key = f"{self.session['tts_engine']}-{self.session['fine_tuned']}"
+            self.tts_vc_key = default_vc_model.rsplit('/', 1)[-1]
+            self.is_bf16 = self.session['device'] == 'cuda' and torch.cuda.is_bf16_supported()
+            self.npz_path = None
+            self.npz_data = None
+            self.sentences_total_time = 0.0
+            self.sentence_idx = 1
+            self.params = {TTS_ENGINES['NEW_TTS']: {}}
+            self.params[self.session['tts_engine']]['samplerate'] = models[self.session['tts_engine']][self.session['fine_tuned']]['samplerate']
+            self.vtt_path = os.path.join(self.session['process_dir'], os.path.splitext(self.session['final_name'])[0] + '.vtt')
+            self.resampler_cache = {}
+            self.audio_segments = []
+            self._build()
+        except Exception as e:
+            error = f'__init__() error: {e}'
+            print(error)
+            return None
+
+    def _build(self):
+        try:
+            tts = (loaded_tts.get(self.tts_key) or {}).get('engine', False)
+            if not tts:
+                if self.session['tts_engine'] == TTS_ENGINES['NEW_TTS']:
+                    if self.session['custom_model'] is not None:
+                        msg = f"{self.session['tts_engine']} custom model not implemented yet!"
+                        print(msg)
+                        return False
+                    else:
+                        model_path = models[self.session['tts_engine']][self.session['fine_tuned']]['repo']
+                        tts = self._load_api(self.tts_key, model_path, self.session['device'])
+            return (loaded_tts.get(self.tts_key) or {}).get('engine', False)
+        except Exception as e:
+            error = f'_build() error: {e}'
+            print(error)
+            return False
+
+    def _load_api(self, key, model_path, device):
+        global lock
+        try:
+            if key in loaded_tts.keys():
+                return loaded_tts[key]['engine']
+            unload_tts(device, [self.tts_key, self.tts_vc_key])
+            with lock:
+                # NEW_TTS, WITH_CUDA and WITHOUT_CUDA are template placeholders
+                # to be replaced by the real engine's class and device calls.
+                tts = NEW_TTS(model_path)
+                if tts:
+                    if device == 'cuda':
+                        NEW_TTS.WITH_CUDA
+                    else:
+                        NEW_TTS.WITHOUT_CUDA
+                    loaded_tts[key] = {"engine": tts, "config": None}
+                    msg = f'{model_path} Loaded!'
+                    print(msg)
+                    return tts
+                else:
+                    error = 'TTS engine could not be created!'
+                    print(error)
+        except Exception as e:
+            error = f'_load_api() error: {e}'
+            print(error)
+        return False
+
+    def _load_checkpoint(self, **kwargs):
+        global lock
+        try:
+            key = kwargs.get('key')
+            if key in loaded_tts.keys():
+                return loaded_tts[key]['engine']
+            tts_engine = kwargs.get('tts_engine')
+            device = kwargs.get('device')
+            unload_tts(device, [self.tts_key])
+            with lock:
+                checkpoint_dir = kwargs.get('checkpoint_dir')
+                config = kwargs.get('config')
+                # LOAD_CHECKPOINT is a template placeholder; its result must be
+                # assigned so the validity check below has an engine to test.
+                tts = NEW_TTS.LOAD_CHECKPOINT(
+                    config,
+                    checkpoint_dir=checkpoint_dir,
+                    eval=True
+                )
+                if tts:
+                    if device == 'cuda':
+                        NEW_TTS.WITH_CUDA
+                    else:
+                        NEW_TTS.WITHOUT_CUDA
+                    loaded_tts[key] = {"engine": tts, "config": config}
+                    msg = f'{tts_engine} Loaded!'
+                    print(msg)
+                    return tts
+                else:
+                    error = 'TTS engine could not be created!'
+                    print(error)
+        except Exception as e:
+            error = f'_load_checkpoint() error: {e}'
+            print(error)
+        return False
+
+    def _tensor_type(self, audio_data):
+        if isinstance(audio_data, torch.Tensor):
+            return audio_data
+        elif isinstance(audio_data, np.ndarray):
+            return torch.from_numpy(audio_data).float()
+        elif isinstance(audio_data, list):
+            return torch.tensor(audio_data, dtype=torch.float32)
+        else:
+            raise TypeError(f"Unsupported type for audio_data: {type(audio_data)}")
+
+    def _get_resampler(self, orig_sr, target_sr):
+        # Cache one Resample transform per (orig_sr, target_sr) pair.
+        key = (orig_sr, target_sr)
+        if key not in self.resampler_cache:
+            self.resampler_cache[key] = torchaudio.transforms.Resample(
+                orig_freq=orig_sr, new_freq=target_sr
+            )
+        return self.resampler_cache[key]
+
+    def _resample_wav(self, wav_path, expected_sr):
+        waveform, orig_sr = torchaudio.load(wav_path)
+        if orig_sr == expected_sr and waveform.size(0) == 1:
+            return wav_path
+        if waveform.size(0) > 1:
+            # Downmix multi-channel audio to mono before resampling.
+            waveform = waveform.mean(dim=0, keepdim=True)
+        if orig_sr != expected_sr:
+            resampler = self._get_resampler(orig_sr, expected_sr)
+            waveform = resampler(waveform)
+        wav_tensor = waveform.squeeze(0)
+        wav_numpy = wav_tensor.cpu().numpy()
+        tmp_fh = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+        tmp_path = tmp_fh.name
+        tmp_fh.close()
+        sf.write(tmp_path, wav_numpy, expected_sr, subtype="PCM_16")
+        return tmp_path
+
+    def convert(self, sentence_number, sentence):
+        global xtts_builtin_speakers_list
+        try:
+            speaker = None
+            audio_sentence = False
+            trim_audio_buffer = 0.004
+            settings = self.params[self.session['tts_engine']]
+            final_sentence_file = os.path.join(self.session['chapters_dir_sentences'], f'{sentence_number}.{default_audio_proc_format}')
+            sentence = sentence.strip()
+            settings['voice_path'] = (
+                self.session['voice'] if self.session['voice'] is not None
+                else os.path.join(self.session['custom_model_dir'], self.session['tts_engine'], self.session['custom_model'], 'ref.wav') if self.session['custom_model'] is not None
+                else models[self.session['tts_engine']][self.session['fine_tuned']]['voice']
+            )
+            if settings['voice_path'] is not None:
+                speaker = re.sub(r'\.wav$', '', os.path.basename(settings['voice_path']))
+            tts = (loaded_tts.get(self.tts_key) or {}).get('engine', False)
+            if tts:
+                if sentence and sentence[-1].isalnum():
+                    sentence = f'{sentence} —'
+                if sentence == TTS_SML['break']:
+                    break_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(0.3, 0.6) * 100) / 100)))  # 0.3 to 0.6 seconds of silence
+                    self.audio_segments.append(break_tensor.clone())
+                    return True
+                elif sentence == TTS_SML['pause']:
+                    pause_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(1.0, 1.8) * 100) / 100)))  # 1.0 to 1.8 seconds of silence
+                    self.audio_segments.append(pause_tensor.clone())
+                    return True
+                else:
+                    if self.session['tts_engine'] == TTS_ENGINES['NEW_TTS']:
+                        audio_sentence = NEW_TTS.CONVERT()  # must return a torch.Tensor, np.ndarray, list or tuple
+                    if is_audio_data_valid(audio_sentence):
+                        sourceTensor = self._tensor_type(audio_sentence)
+                        audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
+                        if sentence and (sentence[-1].isalnum() or sentence[-1] == '—'):
+                            audio_tensor = trim_audio(audio_tensor.squeeze(), settings['samplerate'], 0.003, trim_audio_buffer).unsqueeze(0)
+                        self.audio_segments.append(audio_tensor)
+                        if not re.search(r'\w$', sentence, flags=re.UNICODE):
+                            # Sentence ends with punctuation: append a short random break.
+                            break_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(0.3, 0.6) * 100) / 100)))
+                            self.audio_segments.append(break_tensor.clone())
+                        if self.audio_segments:
+                            audio_tensor = torch.cat(self.audio_segments, dim=-1)
+                            start_time = self.sentences_total_time
+                            duration = audio_tensor.shape[-1] / settings['samplerate']
+                            end_time = start_time + duration
+                            self.sentences_total_time = end_time
+                            sentence_obj = {
+                                "start": start_time,
+                                "end": end_time,
+                                "text": sentence,
+                                "resume_check": self.sentence_idx
+                            }
+                            self.sentence_idx = append_sentence2vtt(sentence_obj, self.vtt_path)
+                        if self.sentence_idx:
+                            torchaudio.save(final_sentence_file, audio_tensor, settings['samplerate'], format=default_audio_proc_format)
+                        del audio_tensor
+                        self.audio_segments = []
+                        if os.path.exists(final_sentence_file):
+                            return True
+                        else:
+                            error = f"Cannot create {final_sentence_file}"
+                            print(error)
+            else:
+                error = f"convert() error: {self.session['tts_engine']} is None"
+                print(error)
+        except Exception as e:
+            error = f'convert() error: {e}'
+            print(error)
+            raise ValueError(e)
+        return False
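+
+# --- Illustrative usage sketch (hypothetical names, not a real API) ---
+# A minimal sketch of how the NEW_TTS placeholders above would be filled in
+# for a concrete engine. 'my_tts_lib', 'MyTTS', 'to()' and 'synthesize()' are
+# assumed names for illustration only; substitute the actual engine's API.
+#
+#   from my_tts_lib import MyTTS                  # hypothetical import
+#
+#   tts = MyTTS(model_path)                       # replaces NEW_TTS(model_path)
+#   tts.to(device)                                # replaces NEW_TTS.WITH_CUDA / NEW_TTS.WITHOUT_CUDA
+#   audio_sentence = tts.synthesize(              # replaces NEW_TTS.CONVERT()
+#       text=sentence,
+#       speaker_wav=settings['voice_path'],
+#   )                                             # must return torch.Tensor, np.ndarray, list or tuple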