diff --git a/lib/classes/tts_engines/bark.py b/lib/classes/tts_engines/bark.py
index 92c5f294..b8c67b89 100644
--- a/lib/classes/tts_engines/bark.py
+++ b/lib/classes/tts_engines/bark.py
@@ -165,6 +165,7 @@ class Bark(TTSUtils, TTSRegistry, name='bark'):
         """
         pth_voice_dir = os.path.join(bark_dir, speaker)
         pth_voice_file = os.path.join(bark_dir, speaker, f'{speaker}.pth')
+        self.engine.synthesizer.voice_dir = pth_voice_dir
         tts_dyn_params = {}
         if not os.path.exists(pth_voice_file) or speaker not in self.engine.speakers:
             tts_dyn_params['speaker_wav'] = self.params['voice_path']
@@ -197,8 +198,16 @@ class Bark(TTSUtils, TTSRegistry, name='bark'):
         #if is_audio_data_valid(audio_sentence):
         #    audio_sentence = audio_sentence.tolist()
         if is_audio_data_valid(audio_sentence):
-            sourceTensor = self._tensor_type(audio_sentence)
-            audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
+            if isinstance(audio_sentence, torch.Tensor):
+                audio_tensor = audio_sentence.detach().cpu().unsqueeze(0)
+            elif isinstance(audio_sentence, np.ndarray):
+                audio_tensor = torch.from_numpy(audio_sentence).unsqueeze(0)
+            elif isinstance(audio_sentence, (list, tuple)):
+                audio_tensor = torch.tensor(audio_sentence, dtype=torch.float32).unsqueeze(0)
+            else:
+                error = f"Unsupported Bark wav type: {type(audio_sentence)}"
+                print(error)
+                return False
             if sentence[-1].isalnum() or sentence[-1] == '—':
                 audio_tensor = trim_audio(audio_tensor.squeeze(), self.params['samplerate'], 0.001, trim_audio_buffer).unsqueeze(0)
             if audio_tensor is not None and audio_tensor.numel() > 0:
@@ -231,6 +240,10 @@ class Bark(TTSUtils, TTSRegistry, name='bark'):
                     error = f"Cannot create {final_sentence_file}"
                     print(error)
                     return False
+            else:
+                error = f"audio_tensor not valid"
+                print(error)
+                return False
         else:
             error = f"audio_sentence not valid"
             print(error)
diff --git a/lib/classes/tts_engines/common/utils.py b/lib/classes/tts_engines/common/utils.py
index a297ddb6..4fae9aaf 100644
--- a/lib/classes/tts_engines/common/utils.py
+++ b/lib/classes/tts_engines/common/utils.py
@@ -207,7 +207,14 @@ class TTSUtils:
                 speaker_embedding=speaker_embedding,
                 **fine_tuned_params,
             )
             audio_sentence = result.get('wav') if isinstance(result, dict) else None
-            if audio_sentence is not None:
-                audio_sentence = audio_sentence.tolist()
-                sourceTensor = self._tensor_type(audio_sentence)
+            if isinstance(audio_sentence, torch.Tensor):
+                audio_tensor = audio_sentence.detach().cpu().unsqueeze(0)
+            elif isinstance(audio_sentence, np.ndarray):
+                audio_tensor = torch.from_numpy(audio_sentence).unsqueeze(0)
+            elif isinstance(audio_sentence, (list, tuple)):
+                audio_tensor = torch.tensor(audio_sentence, dtype=torch.float32).unsqueeze(0)
+            else:
+                error = f"Unsupported XTTSv2 wav type: {type(audio_sentence)}"
+                print(error)
+                return False
diff --git a/lib/classes/tts_engines/fairseq.py b/lib/classes/tts_engines/fairseq.py
index 91637f42..9ae232d9 100644
--- a/lib/classes/tts_engines/fairseq.py
+++ b/lib/classes/tts_engines/fairseq.py
@@ -166,8 +166,16 @@ class Fairseq(TTSUtils, TTSRegistry, name='fairseq'):
             **speaker_argument
         )
         if is_audio_data_valid(audio_sentence):
-            sourceTensor = self._tensor_type(audio_sentence)
-            audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
+            if isinstance(audio_sentence, torch.Tensor):
+                audio_tensor = audio_sentence.detach().cpu().unsqueeze(0)
+            elif isinstance(audio_sentence, np.ndarray):
+                audio_tensor = torch.from_numpy(audio_sentence).unsqueeze(0)
+            elif isinstance(audio_sentence, (list, tuple)):
+                audio_tensor = torch.tensor(audio_sentence, dtype=torch.float32).unsqueeze(0)
+            else:
+                error = f"Unsupported Fairseq wav type: {type(audio_sentence)}"
+                print(error)
+                return False
             if sentence[-1].isalnum() or sentence[-1] == '—':
                 audio_tensor = trim_audio(audio_tensor.squeeze(), self.params['samplerate'], 0.001, trim_audio_buffer).unsqueeze(0)
             if audio_tensor is not None and audio_tensor.numel() > 0:
@@ -200,6 +208,10 @@ class Fairseq(TTSUtils, TTSRegistry, name='fairseq'):
                     error = f"Cannot create {final_sentence_file}"
                     print(error)
                     return False
+            else:
+                error = f"audio_tensor not valid"
+                print(error)
+                return False
         else:
             error = f"audio_sentence not valid"
             print(error)
diff --git a/lib/classes/tts_engines/tacotron.py b/lib/classes/tts_engines/tacotron.py
index 3238de37..0b9a15ca 100644
--- a/lib/classes/tts_engines/tacotron.py
+++ b/lib/classes/tts_engines/tacotron.py
@@ -194,8 +194,16 @@ class Tacotron2(TTSUtils, TTSRegistry, name='tacotron'):
             **speaker_argument
         )
         if is_audio_data_valid(audio_sentence):
-            sourceTensor = self._tensor_type(audio_sentence)
-            audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
+            if isinstance(audio_sentence, torch.Tensor):
+                audio_tensor = audio_sentence.detach().cpu().unsqueeze(0)
+            elif isinstance(audio_sentence, np.ndarray):
+                audio_tensor = torch.from_numpy(audio_sentence).unsqueeze(0)
+            elif isinstance(audio_sentence, (list, tuple)):
+                audio_tensor = torch.tensor(audio_sentence, dtype=torch.float32).unsqueeze(0)
+            else:
+                error = f"Unsupported Tacotron2 wav type: {type(audio_sentence)}"
+                print(error)
+                return False
             if sentence[-1].isalnum() or sentence[-1] == '—':
                 audio_tensor = trim_audio(audio_tensor.squeeze(), self.params['samplerate'], 0.001, trim_audio_buffer).unsqueeze(0)
             if audio_tensor is not None and audio_tensor.numel() > 0:
@@ -228,6 +236,10 @@ class Tacotron2(TTSUtils, TTSRegistry, name='tacotron'):
                     error = f"Cannot create {final_sentence_file}"
                     print(error)
                     return False
+            else:
+                error = f"audio_tensor not valid"
+                print(error)
+                return False
         else:
             error = f"audio_sentence not valid"
             print(error)
diff --git a/lib/classes/tts_engines/vits.py b/lib/classes/tts_engines/vits.py
index 7d6654b7..e902c525 100644
--- a/lib/classes/tts_engines/vits.py
+++ b/lib/classes/tts_engines/vits.py
@@ -179,8 +179,16 @@ class Vits(TTSUtils, TTSRegistry, name='vits'):
             **speaker_argument
         )
         if is_audio_data_valid(audio_sentence):
-            sourceTensor = self._tensor_type(audio_sentence)
-            audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
+            if isinstance(audio_sentence, torch.Tensor):
+                audio_tensor = audio_sentence.detach().cpu().unsqueeze(0)
+            elif isinstance(audio_sentence, np.ndarray):
+                audio_tensor = torch.from_numpy(audio_sentence).unsqueeze(0)
+            elif isinstance(audio_sentence, (list, tuple)):
+                audio_tensor = torch.tensor(audio_sentence, dtype=torch.float32).unsqueeze(0)
+            else:
+                error = f"Unsupported Vits wav type: {type(audio_sentence)}"
+                print(error)
+                return False
             if sentence[-1].isalnum() or sentence[-1] == '—':
                 audio_tensor = trim_audio(audio_tensor.squeeze(), self.params['samplerate'], 0.001, trim_audio_buffer).unsqueeze(0)
             if audio_tensor is not None and audio_tensor.numel() > 0:
@@ -213,6 +221,10 @@ class Vits(TTSUtils, TTSRegistry, name='vits'):
                     error = f"Cannot create {final_sentence_file}"
                     print(error)
                     return False
+            else:
+                error = f"audio_tensor not valid"
+                print(error)
+                return False
         else:
             error = f"audio_sentence not valid"
             print(error)
diff --git a/lib/classes/tts_engines/xtts.py b/lib/classes/tts_engines/xtts.py
index 0c9beaa7..fc8a4b21 100644
--- a/lib/classes/tts_engines/xtts.py
+++ b/lib/classes/tts_engines/xtts.py
@@ -141,10 +141,16 @@ class XTTSv2(TTSUtils, TTSRegistry, name='xtts'):
         )
         audio_sentence = result.get('wav')
         if is_audio_data_valid(audio_sentence):
-            audio_sentence = audio_sentence.tolist()
-            if is_audio_data_valid(audio_sentence):
-                sourceTensor = self._tensor_type(audio_sentence)
-                audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
+            if isinstance(audio_sentence, torch.Tensor):
+                audio_tensor = audio_sentence.detach().cpu().unsqueeze(0)
+            elif isinstance(audio_sentence, np.ndarray):
+                audio_tensor = torch.from_numpy(audio_sentence).unsqueeze(0)
+            elif isinstance(audio_sentence, (list, tuple)):
+                audio_tensor = torch.tensor(audio_sentence, dtype=torch.float32).unsqueeze(0)
+            else:
+                error = f"Unsupported XTTSv2 wav type: {type(audio_sentence)}"
+                print(error)
+                return False
             if sentence[-1].isalnum() or sentence[-1] == '—':
                 audio_tensor = trim_audio(audio_tensor.squeeze(), self.params['samplerate'], 0.001, trim_audio_buffer).unsqueeze(0)
             if audio_tensor is not None and audio_tensor.numel() > 0:
@@ -177,6 +183,10 @@ class XTTSv2(TTSUtils, TTSRegistry, name='xtts'):
                     error = f"Cannot create {final_sentence_file}"
                     print(error)
                     return False
+            else:
+                error = f"audio_tensor not valid"
+                print(error)
+                return False
         else:
             error = f"audio_sentence not valid"
             print(error)
diff --git a/lib/classes/tts_engines/yourtts.py b/lib/classes/tts_engines/yourtts.py
index e185e5d4..6caaff59 100644
--- a/lib/classes/tts_engines/yourtts.py
+++ b/lib/classes/tts_engines/yourtts.py
@@ -109,8 +109,16 @@ class YourTTS(TTSUtils, TTSRegistry, name='yourtts'):
             **speaker_argument
        )
         if is_audio_data_valid(audio_sentence):
-            sourceTensor = self._tensor_type(audio_sentence)
-            audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
+            if isinstance(audio_sentence, torch.Tensor):
+                audio_tensor = audio_sentence.detach().cpu().unsqueeze(0)
+            elif isinstance(audio_sentence, np.ndarray):
+                audio_tensor = torch.from_numpy(audio_sentence).unsqueeze(0)
+            elif isinstance(audio_sentence, (list, tuple)):
+                audio_tensor = torch.tensor(audio_sentence, dtype=torch.float32).unsqueeze(0)
+            else:
+                error = f"Unsupported YourTTS wav type: {type(audio_sentence)}"
+                print(error)
+                return False
             if sentence[-1].isalnum() or sentence[-1] == '—':
                 audio_tensor = trim_audio(audio_tensor.squeeze(), self.params['samplerate'], 0.001, trim_audio_buffer).unsqueeze(0)
             if audio_tensor is not None and audio_tensor.numel() > 0:
@@ -143,6 +151,10 @@ class YourTTS(TTSUtils, TTSRegistry, name='yourtts'):
                     error = f"Cannot create {final_sentence_file}"
                     print(error)
                     return False
+            else:
+                error = f"audio_tensor not valid"
+                print(error)
+                return False
         else:
             error = f"audio_sentence not valid"
             print(error)
diff --git a/lib/conf.py b/lib/conf.py
index 3d8cf4d7..e61c5a7f 100644
--- a/lib/conf.py
+++ b/lib/conf.py
@@ -45,14 +45,16 @@ os.environ['MPLCONFIGDIR'] = f'{models_dir}/matplotlib'
 os.environ['TESSDATA_PREFIX'] = f'{models_dir}/tessdata'
 os.environ['STANZA_RESOURCES_DIR'] = os.path.join(models_dir, 'stanza')
 os.environ['ARGOS_TRANSLATE_PACKAGE_PATH'] = os.path.join(models_dir, 'argostranslate')
+os.environ['MallocStackLogging'] = '0'
+os.environ['MallocStackLoggingNoCompact'] = '0'
 os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
 os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
 os.environ['TORCH_CUDA_ENABLE_CUDA_GRAPH'] = '0'
 os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,garbage_collection_threshold:0.6,expandable_segments:True'
-os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-os.environ["CUDA_CACHE_MAXSIZE"] = "2147483648"
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+os.environ['CUDA_CACHE_MAXSIZE'] = '2147483648'
 os.environ['SUNO_OFFLOAD_CPU'] = 'False'
 os.environ['SUNO_USE_SMALL_MODELS'] = 'False'
 if platform.system() == 'Windows':