diff --git a/lib/classes/tts_engines/bark.py b/lib/classes/tts_engines/bark.py index b8f153d1..c8147fab 100644 --- a/lib/classes/tts_engines/bark.py +++ b/lib/classes/tts_engines/bark.py @@ -9,7 +9,6 @@ class Bark(TTSUtils, TTSRegistry, name='bark'): self.cache_dir = tts_dir self.speakers_path = None self.tts_key = self.session['model_cache'] - self.tts_zs_key = default_vc_model.rsplit('/',1)[-1] self.pth_voice_file = None self.sentences_total_time = 0.0 self.sentence_idx = 1 @@ -30,7 +29,6 @@ class Bark(TTSUtils, TTSRegistry, name='bark'): self._apply_cuda_policy(using_gpu=using_gpu, enough_vram=enough_vram, seed=seed) self.xtts_speakers = self._load_xtts_builtin_list() self.engine = self._load_engine() - self.engine_zs = self._load_engine_zs() except Exception as e: error = f'__init__() error: {e}' raise ValueError(error) @@ -123,7 +121,6 @@ class Bark(TTSUtils, TTSRegistry, name='bark'): print(msg) return False if self.engine: - device = devices['CUDA']['proc'] if self.session['device'] in ['cuda', 'jetson'] else self.session['device'] final_sentence_file = os.path.join(self.session['chapters_dir_sentences'], f'{sentence_index}.{default_audio_proc_format}') if sentence == TTS_SML['break']: silence_time = int(np.random.uniform(0.3, 0.6) * 100) / 100 @@ -186,6 +183,7 @@ class Bark(TTSUtils, TTSRegistry, name='bark'): **fine_tuned_params ) """ + device = devices['CUDA']['proc'] if self.session['device'] in ['cuda', 'jetson'] else self.session['device'] self.engine.to(device) audio_sentence = self.engine.tts( text=sentence, @@ -194,6 +192,7 @@ class Bark(TTSUtils, TTSRegistry, name='bark'): **tts_dyn_params, **fine_tuned_params ) + self.engine.to('cpu') if is_audio_data_valid(audio_sentence): src_tensor = self._tensor_type(audio_sentence) audio_tensor = src_tensor.clone().detach().unsqueeze(0).cpu() diff --git a/lib/classes/tts_engines/common/utils.py b/lib/classes/tts_engines/common/utils.py index 50a9c1b4..99ae8e6a 100644 --- a/lib/classes/tts_engines/common/utils.py +++ b/lib/classes/tts_engines/common/utils.py @@ -203,6 +203,8 @@ class TTSUtils: if self.session.get(key) is not None } with torch.no_grad(): + device = devices['CUDA']['proc'] if self.session['device'] in ['cuda', 'jetson'] else self.session['device'] + engine.to(device) result = engine.inference( text=default_text.strip(), language=self.session['language_iso1'], @@ -210,6 +212,7 @@ class TTSUtils: speaker_embedding=speaker_embedding, **fine_tuned_params, ) + engine.to('cpu') audio_sentence = result.get('wav') if is_audio_data_valid(audio_sentence): sourceTensor = self._tensor_type(audio_sentence) diff --git a/lib/classes/tts_engines/fairseq.py b/lib/classes/tts_engines/fairseq.py index 2c5f5d97..279d3b38 100644 --- a/lib/classes/tts_engines/fairseq.py +++ b/lib/classes/tts_engines/fairseq.py @@ -107,6 +107,7 @@ class Fairseq(TTSUtils, TTSRegistry, name='fairseq'): file_path=tmp_in_wav, **speaker_argument ) + self.engine.to('cpu') if self.params['voice_path'] in self.params['semitones'].keys(): semitones = self.params['semitones'][self.params['voice_path']] else: @@ -150,6 +151,7 @@ class Fairseq(TTSUtils, TTSRegistry, name='fairseq'): source_wav=source_wav, target_wav=target_wav ) + self.engine_zs.to('cpu') else: error = f'Engine {self.tts_zs_key} is None' print(error) @@ -167,6 +169,7 @@ class Fairseq(TTSUtils, TTSRegistry, name='fairseq'): text=re.sub(not_supported_punc_pattern, ' ', sentence), **speaker_argument ) + self.engine.to('cpu') if is_audio_data_valid(audio_sentence): src_tensor = self._tensor_type(audio_sentence) audio_tensor = src_tensor.clone().detach().unsqueeze(0).cpu() diff --git a/lib/classes/tts_engines/tacotron.py b/lib/classes/tts_engines/tacotron.py index 0347c8e6..6095a59c 100644 --- a/lib/classes/tts_engines/tacotron.py +++ b/lib/classes/tts_engines/tacotron.py @@ -135,6 +135,7 @@ class Tacotron2(TTSUtils, TTSRegistry, name='tacotron'): file_path=tmp_in_wav, **speaker_argument ) + self.engine.to('cpu') if self.params['voice_path'] in self.params['semitones'].keys(): semitones = self.params['semitones'][self.params['voice_path']] else: @@ -178,6 +179,7 @@ class Tacotron2(TTSUtils, TTSRegistry, name='tacotron'): source_wav=source_wav, target_wav=target_wav ) + self.engine_zs.to('cpu') else: error = f'Engine {self.tts_zs_key} is None' print(error) @@ -195,6 +197,7 @@ class Tacotron2(TTSUtils, TTSRegistry, name='tacotron'): text=re.sub(not_supported_punc_pattern, ' ', sentence), **speaker_argument ) + self.engine.to('cpu') if is_audio_data_valid(audio_sentence): src_tensor = self._tensor_type(audio_sentence) audio_tensor = src_tensor.clone().detach().unsqueeze(0).cpu() diff --git a/lib/classes/tts_engines/vits.py b/lib/classes/tts_engines/vits.py index e04be172..40c3c534 100644 --- a/lib/classes/tts_engines/vits.py +++ b/lib/classes/tts_engines/vits.py @@ -120,6 +120,7 @@ class Vits(TTSUtils, TTSRegistry, name='vits'): file_path=tmp_in_wav, **speaker_argument ) + self.engine.to('cpu') if self.params['voice_path'] in self.params['semitones'].keys(): semitones = self.params['semitones'][self.params['voice_path']] else: @@ -163,6 +164,7 @@ class Vits(TTSUtils, TTSRegistry, name='vits'): source_wav=source_wav, target_wav=target_wav ) + self.engine_zs.to('cpu') else: error = f'Engine {self.tts_zs_key} is None' print(error) @@ -180,6 +182,7 @@ class Vits(TTSUtils, TTSRegistry, name='vits'): text=sentence, **speaker_argument ) + self.engine.to('cpu') if is_audio_data_valid(audio_sentence): src_tensor = self._tensor_type(audio_sentence) audio_tensor = src_tensor.clone().detach().unsqueeze(0).cpu() diff --git a/lib/classes/tts_engines/xtts.py b/lib/classes/tts_engines/xtts.py index 42d8738e..57b50ce3 100644 --- a/lib/classes/tts_engines/xtts.py +++ b/lib/classes/tts_engines/xtts.py @@ -84,7 +84,6 @@ class XTTSv2(TTSUtils, TTSRegistry, name='xtts'): print(msg) return False if self.engine: - device = devices['CUDA']['proc'] if self.session['device'] in ['cuda', 'jetson'] else self.session['device'] final_sentence_file = os.path.join(self.session['chapters_dir_sentences'], f'{sentence_index}.{default_audio_proc_format}') if sentence == TTS_SML['break']: silence_time = int(np.random.uniform(0.3, 0.6) * 100) / 100 @@ -131,6 +130,7 @@ class XTTSv2(TTSUtils, TTSRegistry, name='xtts'): if self.session.get(key) is not None } with torch.no_grad(): + device = devices['CUDA']['proc'] if self.session['device'] in ['cuda', 'jetson'] else self.session['device'] self.engine.to(device) result = self.engine.inference( text=sentence, @@ -139,6 +139,7 @@ class XTTSv2(TTSUtils, TTSRegistry, name='xtts'): speaker_embedding=self.params['speaker_embedding'], **fine_tuned_params ) + self.engine.to('cpu') audio_sentence = result.get('wav') if is_audio_data_valid(audio_sentence): src_tensor = self._tensor_type(audio_sentence) diff --git a/lib/classes/tts_engines/yourtts.py b/lib/classes/tts_engines/yourtts.py index 9f03793d..18db70e3 100644 --- a/lib/classes/tts_engines/yourtts.py +++ b/lib/classes/tts_engines/yourtts.py @@ -9,7 +9,6 @@ class YourTTS(TTSUtils, TTSRegistry, name='yourtts'): self.cache_dir = tts_dir self.speakers_path = None self.tts_key = self.session['model_cache'] - self.tts_zs_key = default_vc_model.rsplit('/',1)[-1] self.pth_voice_file = None self.sentences_total_time = 0.0 self.sentence_idx = 1 @@ -30,7 +29,6 @@ class YourTTS(TTSUtils, TTSRegistry, name='yourtts'): self._apply_cuda_policy(using_gpu=using_gpu, enough_vram=enough_vram, seed=seed) self.xtts_speakers = self._load_xtts_builtin_list() self.engine = self._load_engine() - self.engine_zs = self._load_engine_zs() except Exception as e: error = f'__init__() error: {e}' raise ValueError(error) @@ -75,7 +73,6 @@ class YourTTS(TTSUtils, TTSRegistry, name='yourtts'): print(msg) return False if self.engine: - device = devices['CUDA']['proc'] if self.session['device'] in ['cuda', 'jetson'] else self.session['device'] final_sentence_file = os.path.join(self.session['chapters_dir_sentences'], f'{sentence_index}.{default_audio_proc_format}') if sentence == TTS_SML['break']: silence_time = int(np.random.uniform(0.3, 0.6) * 100) / 100 @@ -102,12 +99,14 @@ class YourTTS(TTSUtils, TTSRegistry, name='yourtts'): voice_key = default_engine_settings[self.session['tts_engine']]['voices']['ElectroMale-2'] speaker_argument = {"speaker": voice_key} with torch.no_grad(): + device = devices['CUDA']['proc'] if self.session['device'] in ['cuda', 'jetson'] else self.session['device'] self.engine.to(device) audio_sentence = self.engine.tts( text=re.sub(not_supported_punc_pattern, ' ', sentence), language=language, **speaker_argument ) + self.engine.to('cpu') if is_audio_data_valid(audio_sentence): src_tensor = self._tensor_type(audio_sentence) audio_tensor = src_tensor.clone().detach().unsqueeze(0).cpu()