Update VITS to use fetch helper (#2422)

* use fetch helper on vits

* remove duplicate weight loading
This commit is contained in:
Francis Lata
2023-11-24 11:50:03 -05:00
committed by GitHub
parent 857d440ea7
commit 7169de57e2

View File

@@ -2,9 +2,8 @@ import json, logging, math, re, sys, time, wave, argparse, numpy as np
from functools import reduce
from pathlib import Path
from typing import List
from extra.utils import download_file
from tinygrad import nn
from tinygrad.helpers import dtypes
from tinygrad.helpers import dtypes, fetch
from tinygrad.nn.state import torch_load
from tinygrad.tensor import Tensor
from unidecode import unidecode
@@ -516,10 +515,8 @@ class HParams:
# MODEL LOADING
def load_model(symbols, hps, model) -> Synthesizer:
  """Build the VITS synthesizer described by `hps` and load its pretrained weights.

  Args:
    symbols: Symbol table for the model; only its length is used here.
    hps: Hyperparameters. Reads `hps.data.filter_length`, `hps.data.hop_length`,
      `hps.data.n_speakers`, `hps.train.segment_size`, and expands `hps.model`
      as keyword arguments to `Synthesizer`.
    model: A `MODELS` entry of the form (config_url, weights_url); only
      `model[1]` (the weights URL) is read.

  Returns:
    The `Synthesizer` instance after `load_checkpoint` has been applied to it.
  """
  net_g = Synthesizer(len(symbols), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model)
  # Load the checkpoint exactly once. fetch() turns the weights URL into a
  # path that load_checkpoint can check with Path(...).is_file().
  # NOTE(review): presumably fetch() caches the download locally -- confirm.
  load_checkpoint(fetch(model[1]), net_g, None)
  return net_g
def load_checkpoint(checkpoint_path, model: Synthesizer, optimizer=None, skip_list=[]):
assert Path(checkpoint_path).is_file()
@@ -555,12 +552,6 @@ def load_checkpoint(checkpoint_path, model: Synthesizer, optimizer=None, skip_li
logging.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration}) in {time.time() - start_time:.4f}s")
return model, optimizer, learning_rate, iteration
def download_if_not_present(file_path: Path, url: str):
  """Return `file_path`, downloading it from `url` first if it is not already a file.

  Args:
    file_path: Destination path. A `str` is accepted and converted to `Path`.
    url: Location handed to `download_file` when the destination is missing.

  Returns:
    The destination as a `Path`.
  """
  file_path = Path(file_path)  # no-op for Path input; lets callers pass a plain string
  if not file_path.is_file():
    logging.info(f"Did not find {file_path.as_posix()}, downloading...")
    download_file(url, file_path)
  return file_path
# Used for cleaning input text and mapping to symbols
class TextMapper: # Based on https://github.com/keithito/tacotron
def __init__(self, symbols, apply_cleaners=True):
@@ -650,13 +641,13 @@ class TextMapper: # Based on https://github.com/keithito/tacotron
# anime lady 2 | --model_to_use uma_trilingual --speaker_id 121
#########################################################################################
VITS_PATH = Path(__file__).parents[1] / "weights/VITS/"
MODELS = { # config_path, weights_path, config_url, weights_url
"ljs": (VITS_PATH / "config_ljs.json", VITS_PATH / "pretrained_ljs.pth", "https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/ljs_base.json", "https://drive.google.com/uc?export=download&id=1q86w74Ygw2hNzYP9cWkeClGT5X25PvBT&confirm=t"),
"vctk": (VITS_PATH / "config_vctk.json", VITS_PATH / "pretrained_vctk.pth", "https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/vctk_base.json", "https://drive.google.com/uc?export=download&id=11aHOlhnxzjpdWDpsz1vFDCzbeEfoIxru&confirm=t"),
"mmts-tts": (VITS_PATH / "config_mmts-tts.json", VITS_PATH / "pretrained_mmts-tts.pth", "https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/config.json", "https://huggingface.co/facebook/mms-tts/resolve/main/full_models/eng/G_100000.pth"),
"uma_trilingual": (VITS_PATH / "config_uma_trilingual.json", VITS_PATH / "pretrained_uma_trilingual.pth", "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/raw/main/configs/uma_trilingual.json", "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth"),
"cjks": (VITS_PATH / "config_cjks.json", VITS_PATH / "pretrained_cjks.pth", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/model.pth"),
"voistock": (VITS_PATH / "config_voistock.json", VITS_PATH / "pretrained_voistock.pth", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/model.pth"),
# Pretrained VITS models selectable by name (looked up as MODELS[args.model_to_use]).
# Each value is a 2-tuple of URLs:
#   [0] config_url  -- passed to fetch() and then to get_hparams_from_file()
#   [1] weights_url -- passed to fetch() and then to load_checkpoint() in load_model()
# The keys are the same names used by Y_LENGTH_ESTIMATE_SCALARS.
MODELS = { # config_url, weights_url
"ljs": ("https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/ljs_base.json", "https://drive.google.com/uc?export=download&id=1q86w74Ygw2hNzYP9cWkeClGT5X25PvBT&confirm=t"),
"vctk": ("https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/vctk_base.json", "https://drive.google.com/uc?export=download&id=11aHOlhnxzjpdWDpsz1vFDCzbeEfoIxru&confirm=t"),
"mmts-tts": ("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/config.json", "https://huggingface.co/facebook/mms-tts/resolve/main/full_models/eng/G_100000.pth"),
"uma_trilingual": ("https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/raw/main/configs/uma_trilingual.json", "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth"),
"cjks": ("https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/model.pth"),
"voistock": ("https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/model.pth"),
}
Y_LENGTH_ESTIMATE_SCALARS = {"ljs": 2.8, "vctk": 1.74, "mmts-tts": 1.9, "uma_trilingual": 2.3, "cjks": 3.3, "voistock": 3.1}
if __name__ == '__main__':
@@ -681,9 +672,7 @@ if __name__ == '__main__':
model_config = MODELS[args.model_to_use]
# Load the hyperparameters from the config file.
config_path = model_config[0]
download_if_not_present(config_path, model_config[2])
hps = get_hparams_from_file(config_path)
hps = get_hparams_from_file(fetch(model_config[0]))
# If model has multiple speakers, validate speaker id and retrieve name if available.
model_has_multiple_speakers = hps.data.n_speakers > 0
@@ -705,7 +694,7 @@ if __name__ == '__main__':
# Load symbols, instantiate TextMapper and clean the text.
if hps.__contains__("symbols"): symbols = hps.symbols
elif args.model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in open(download_if_not_present(VITS_PATH / "vocab_mmts-tts.txt", "https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt"), encoding="utf-8").readlines()]
elif args.model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'")
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)