mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Update VITS to use fetch helper (#2422)
* use fetch helper on vits * remove duplicate weight loading
This commit is contained in:
@@ -2,9 +2,8 @@ import json, logging, math, re, sys, time, wave, argparse, numpy as np
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from extra.utils import download_file
|
||||
from tinygrad import nn
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.helpers import dtypes, fetch
|
||||
from tinygrad.nn.state import torch_load
|
||||
from tinygrad.tensor import Tensor
|
||||
from unidecode import unidecode
|
||||
@@ -516,10 +515,8 @@ class HParams:
|
||||
|
||||
# MODEL LOADING
|
||||
def load_model(symbols, hps, model) -> Synthesizer:
|
||||
weights_path = model[1]
|
||||
download_if_not_present(weights_path, model[3])
|
||||
net_g = Synthesizer(len(symbols), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers = hps.data.n_speakers, **hps.model)
|
||||
_ = load_checkpoint(weights_path, net_g, None)
|
||||
_ = load_checkpoint(fetch(model[1]), net_g, None)
|
||||
return net_g
|
||||
def load_checkpoint(checkpoint_path, model: Synthesizer, optimizer=None, skip_list=[]):
|
||||
assert Path(checkpoint_path).is_file()
|
||||
@@ -555,12 +552,6 @@ def load_checkpoint(checkpoint_path, model: Synthesizer, optimizer=None, skip_li
|
||||
logging.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration}) in {time.time() - start_time:.4f}s")
|
||||
return model, optimizer, learning_rate, iteration
|
||||
|
||||
def download_if_not_present(file_path: Path, url: str):
|
||||
if not file_path.is_file():
|
||||
logging.info(f"Did not find {file_path.as_posix()}, downloading...")
|
||||
download_file(url, file_path)
|
||||
return file_path
|
||||
|
||||
# Used for cleaning input text and mapping to symbols
|
||||
class TextMapper: # Based on https://github.com/keithito/tacotron
|
||||
def __init__(self, symbols, apply_cleaners=True):
|
||||
@@ -650,13 +641,13 @@ class TextMapper: # Based on https://github.com/keithito/tacotron
|
||||
# anime lady 2 | --model_to_use uma_trilingual --speaker_id 121
|
||||
#########################################################################################
|
||||
VITS_PATH = Path(__file__).parents[1] / "weights/VITS/"
|
||||
MODELS = { # config_path, weights_path, config_url, weights_url
|
||||
"ljs": (VITS_PATH / "config_ljs.json", VITS_PATH / "pretrained_ljs.pth", "https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/ljs_base.json", "https://drive.google.com/uc?export=download&id=1q86w74Ygw2hNzYP9cWkeClGT5X25PvBT&confirm=t"),
|
||||
"vctk": (VITS_PATH / "config_vctk.json", VITS_PATH / "pretrained_vctk.pth", "https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/vctk_base.json", "https://drive.google.com/uc?export=download&id=11aHOlhnxzjpdWDpsz1vFDCzbeEfoIxru&confirm=t"),
|
||||
"mmts-tts": (VITS_PATH / "config_mmts-tts.json", VITS_PATH / "pretrained_mmts-tts.pth", "https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/config.json", "https://huggingface.co/facebook/mms-tts/resolve/main/full_models/eng/G_100000.pth"),
|
||||
"uma_trilingual": (VITS_PATH / "config_uma_trilingual.json", VITS_PATH / "pretrained_uma_trilingual.pth", "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/raw/main/configs/uma_trilingual.json", "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth"),
|
||||
"cjks": (VITS_PATH / "config_cjks.json", VITS_PATH / "pretrained_cjks.pth", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/model.pth"),
|
||||
"voistock": (VITS_PATH / "config_voistock.json", VITS_PATH / "pretrained_voistock.pth", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/model.pth"),
|
||||
MODELS = { # config_url, weights_url
|
||||
"ljs": ("https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/ljs_base.json", "https://drive.google.com/uc?export=download&id=1q86w74Ygw2hNzYP9cWkeClGT5X25PvBT&confirm=t"),
|
||||
"vctk": ("https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/vctk_base.json", "https://drive.google.com/uc?export=download&id=11aHOlhnxzjpdWDpsz1vFDCzbeEfoIxru&confirm=t"),
|
||||
"mmts-tts": ("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/config.json", "https://huggingface.co/facebook/mms-tts/resolve/main/full_models/eng/G_100000.pth"),
|
||||
"uma_trilingual": ("https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/raw/main/configs/uma_trilingual.json", "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth"),
|
||||
"cjks": ("https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/model.pth"),
|
||||
"voistock": ("https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/model.pth"),
|
||||
}
|
||||
Y_LENGTH_ESTIMATE_SCALARS = {"ljs": 2.8, "vctk": 1.74, "mmts-tts": 1.9, "uma_trilingual": 2.3, "cjks": 3.3, "voistock": 3.1}
|
||||
if __name__ == '__main__':
|
||||
@@ -681,9 +672,7 @@ if __name__ == '__main__':
|
||||
model_config = MODELS[args.model_to_use]
|
||||
|
||||
# Load the hyperparameters from the config file.
|
||||
config_path = model_config[0]
|
||||
download_if_not_present(config_path, model_config[2])
|
||||
hps = get_hparams_from_file(config_path)
|
||||
hps = get_hparams_from_file(fetch(model_config[0]))
|
||||
|
||||
# If model has multiple speakers, validate speaker id and retrieve name if available.
|
||||
model_has_multiple_speakers = hps.data.n_speakers > 0
|
||||
@@ -705,7 +694,7 @@ if __name__ == '__main__':
|
||||
|
||||
# Load symbols, instantiate TextMapper and clean the text.
|
||||
if hps.__contains__("symbols"): symbols = hps.symbols
|
||||
elif args.model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in open(download_if_not_present(VITS_PATH / "vocab_mmts-tts.txt", "https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt"), encoding="utf-8").readlines()]
|
||||
elif args.model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
|
||||
else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ")
|
||||
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user