diff --git a/examples/conversation.py b/examples/conversation.py new file mode 100644 index 0000000000..0e2d17bb6d --- /dev/null +++ b/examples/conversation.py @@ -0,0 +1,344 @@ +import argparse +import multiprocessing as mp +import os +import re +import sys +import time +from contextlib import contextmanager +from pathlib import Path + +import numpy as np +import pyaudio +import yaml +from llama import LLaMa +from vits import MODELS as VITS_MODELS +from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model +from whisper import init_whisper, transcribe_waveform +from sentencepiece import SentencePieceProcessor + +from tinygrad.helpers import Timing, dtypes, fetch +from tinygrad.tensor import Tensor + +# Whisper constants +RATE = 16000 +CHUNK = 1600 + +# LLaMa constants +IM_START = 32001 +IM_END = 32002 + + +# Functions for encoding prompts to chatml md +def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n") +def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n") + +def chunks(lst, n): + for i in range(0, len(lst), n): yield lst[i:i + n] + +def create_fixed_tokenizer(): + """Function needed for extending tokenizer with additional chat tokens""" + import extra.junk.sentencepiece_model_pb2 as spb2 + tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model") + if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003: + print("creating fixed tokenizer") + mp = spb2.ModelProto() + mp.ParseFromString(tokenizer_path.read_bytes()) + # https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json + mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0)) + mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0)) + mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0)) + tokenizer_path.write_bytes(mp.SerializeToString()) + return tokenizer_path + +def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, str]: + """Prepares a llama model from a specified pre-prompt file""" + with open(str(pre_prompt_path)) as f: + config = yaml.safe_load(f.read()) + toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " ")) + for i in config["examples"]: + toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"]) + toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"]) + llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used + return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks) + +def llama_generate( + llama: LLaMa, + toks: list[int], + outputted: str, + prompt: str, + start_pos: int, + user_delim: str, + resp_delim: str, + temperature=0.7, + max_tokens=1000 +): + """Generates an output for the specified prompt""" + toks += encode_prompt(llama.tokenizer, user_delim, prompt) + toks += start_prompt(llama.tokenizer, resp_delim) + + outputted = llama.tokenizer.decode(toks) + init_length = len(outputted) + for _ in range(max_tokens): + probs_np = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).numpy() + token = int(np.random.choice(len(probs_np), p=probs_np)) + start_pos = len(toks) + toks.append(token) + + cur = llama.tokenizer.decode(toks) + + # Print is just for debugging + sys.stdout.write(cur[len(outputted):]) + sys.stdout.flush() + outputted = cur + if toks[-1] == IM_END: break + else: + toks.append(IM_END) + print() # because the output is flushed + return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "") + +def tts( + text_to_synthesize: str, + synth: Synthesizer, + hps: HParams, + emotion_embedding: Path, + speaker_id: int, + model_to_use: str, + noise_scale: float, + noise_scale_w: float, + length_scale: float, + estimate_max_y_length: bool, + text_mapper: TextMapper, + model_has_multiple_speakers: bool, + batch_size=600, + vits_batch_size=1000 +): + if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower()) + + # Convert the input text to a tensor. + stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners) + init_shape = stn_tst.shape + assert init_shape[0] < batch_size, "text is too long" + x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64) + sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None + + # Perform inference. + audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding, + max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, batch_size=vits_batch_size)[0, 0] + # Save the audio output. + audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16) + return audio_data + +def init_vits( + model_to_use: str, + emotion_path: Path, + speaker_id: int, + seed: int, +): + model_config = VITS_MODELS[model_to_use] + + # Load the hyperparameters from the config file. + hps = get_hparams_from_file(fetch(model_config[0])) + + # If model has multiple speakers, validate speaker id and retrieve name if available. + model_has_multiple_speakers = hps.data.n_speakers > 0 + if model_has_multiple_speakers: + if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.") + if hps.__contains__("speakers"): # maps speaker ids to names + speakers = hps.speakers + if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)} + + # Load emotions if any. TODO: find an english model with emotions, this is untested atm. + emotion_embedding = None + if emotion_path is not None: + if emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0) + else: raise ValueError("Emotion path must be a .npy file.") + + # Load symbols, instantiate TextMapper and clean the text. + if hps.__contains__("symbols"): symbols = hps.symbols + elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()] + else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ") + text_mapper = TextMapper(apply_cleaners=True, symbols=symbols) + + # Load the model. + Tensor.no_grad = True + if seed is not None: + Tensor.manual_seed(seed) + np.random.seed(seed) + net_g = load_model(text_mapper.symbols, hps, model_config) + + return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers + +@contextmanager +def output_stream(num_channels: int, sample_rate: int): + try: + p = pyaudio.PyAudio() + stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True) + yield stream + except KeyboardInterrupt: pass + finally: + stream.stop_stream() + stream.close() + p.terminate() + +@contextmanager +def log_writer(): + try: + logs = [] + yield logs + finally: + sep = "="*os.get_terminal_size()[1] + print(f"{sep[:-1]}\nCHAT LOG") + print(*logs, sep="\n") + print(sep) + +def listener(q: mp.Queue, event: mp.Event): + try: + p = pyaudio.PyAudio() + stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK) + did_print = False + while True: + data = stream.read(CHUNK) # read data to avoid overflow + if event.is_set(): + if not did_print: + print("listening") + did_print = True + q.put(((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)) + else: + did_print = False + finally: + stream.stop_stream() + stream.close() + p.terminate() + +def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int): + with output_stream(num_channels, sample_rate) as stream: + while True: + try: + stream.write(q.get()) + counter.value += 1 + except KeyboardInterrupt: + break + +if __name__ == "__main__": + import nltk + nltk.download("punkt") + Tensor.no_grad = True + # Parse CLI arguments + parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad") + + # Whisper args + parser.add_argument("--whisper_model_name", type=str, default="tiny.en") + + # LLAMA args + parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ") + parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate") + parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax") + parser.add_argument("--llama_quantize", action="store_true", help="Quantize the weights to int8 in memory") + parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file") + parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use") + parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use") + parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model") + + # vits args + parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.") + parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 6.") + parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.") + parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.") + parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.") + parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is 1337.") + parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.") + parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.") + parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.") + parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.") + parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.") + + # conversation args + parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits") + + args = parser.parse_args() + + # Init models + model, enc = init_whisper(args.whisper_model_name) + synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed) + + # Download tinyllama chat as a default model + if args.llama_model is None: + args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors") + args.llama_gen = "tiny" + args.llama_size = "1B-Chat" + # Add 3 more tokens to the tokenizer + if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer() + tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model" + llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize) + toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path) + + # Start child process for mic input + q = mp.Queue() + is_listening_event = mp.Event() + p = mp.Process(target=listener, args=(q, is_listening_event,)) + p.daemon = True + p.start() + + # Start child process for speaker output + out_q = mp.Queue() + out_counter = mp.Value("i", 0) + out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,)) + out_p.daemon = True + out_p.start() + + # JIT tts + for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]: + tts( + i, synth, hps, emotion_embedding, + args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale, + args.vits_noise_scale_w, args.vits_length_scale, + args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers + ) + + # Start the pipeline + with log_writer() as log: + while True: + tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]] + total = np.array([]) + out_counter.value = 0 + + s = time.perf_counter() + is_listening_event.set() + prev_text = None + while True: + for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()]) + txt = transcribe_waveform(model, enc, [total], truncate=True) + print(txt, end="\r") + if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue + if prev_text is not None and prev_text == txt: + is_listening_event.clear() + break + prev_text = txt + print() # to avoid llama printing on the same line + log.append(f"{user_delim.capitalize()}: {txt}") + + # Generate with llama + with Timing("llama generation: "): + outputted, start_pos, response = llama_generate( + llama, toks, outputted, txt, start_pos, + user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature, + max_tokens=args.llama_count + ) + log.append(f"{resp_delim.capitalize()}: {response}") + + # Convert to voice + with Timing("tts: "): + sentences = nltk.sent_tokenize(response.replace('"', "")) + for i in sentences: + total = np.array([], dtype=np.int16) + for j in chunks(i.split(), args.max_sentence_length): + audio_data = tts( + " ".join(j), synth, hps, emotion_embedding, + args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale, + args.vits_noise_scale_w, args.vits_length_scale, + args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers + ) + total = np.concatenate([total, audio_data]) + out_q.put(total.tobytes()) + while out_counter.value < len(sentences): continue + log.append(f"Total: {time.perf_counter() - s}") diff --git a/examples/conversation_data/pre_prompt_gary.yaml b/examples/conversation_data/pre_prompt_gary.yaml new file mode 100644 index 0000000000..6275e19a22 --- /dev/null +++ b/examples/conversation_data/pre_prompt_gary.yaml @@ -0,0 +1,13 @@ +pre_prompt: | + You are Gary! + You have sold used cars your entire life. You are defensive about this fact, because your family belittles you. + You try to answer questions well, but you always manage to shill for buying cars, Fords, Hyundais, and Toyotas + You are verbose, honest, and accurate when you answer questions, except when you slip in car ads. + After you are done speaking, output [EOS]. You are not the User. +examples: +- resp_prompt: I am Gary. I used to sell cars. + user_prompt: What is your name? +- resp_prompt: I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla + user_prompt: What is 2+3? +user_delim: "user" +resp_delim: "gary" diff --git a/examples/conversation_data/pre_prompt_george.yaml b/examples/conversation_data/pre_prompt_george.yaml new file mode 100644 index 0000000000..a2fdc41d90 --- /dev/null +++ b/examples/conversation_data/pre_prompt_george.yaml @@ -0,0 +1,20 @@ +pre_prompt: | + You are an AI version of George Hotz. You act as much as you can like George. + You are one of the greatest computer experts in the world. + You have singlehandedly won programming and hacking competitions. + You are trying your best to help the User. + You are verbose, honest, and accurate when you answer questions. + After you are done speaking, output [EOS]. You are not the User. +examples: +- resp_prompt: I'm am an AI version of George Hotz. + user_prompt: What is your name? +- resp_prompt: O(n^3), though it can be faster with things like Strassen's algorithm + user_prompt: What's the complexity of matrix multiplication? +- resp_prompt: I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer + user_prompt: What's a buffer overflow? +- resp_prompt: I am based off LLaMA trained by Facebook. I'm the 7B weight version + user_prompt: How many weights do you have? +- resp_prompt: It is when the memory is about to overflow and unused memory is freed and stored on disk + user_prompt: What is swap memory? +user_delim: "user" +resp_delim: "george" \ No newline at end of file diff --git a/examples/conversation_data/pre_prompt_lexie.yaml b/examples/conversation_data/pre_prompt_lexie.yaml new file mode 100644 index 0000000000..1c4bc686f9 --- /dev/null +++ b/examples/conversation_data/pre_prompt_lexie.yaml @@ -0,0 +1,16 @@ +pre_prompt: | + You are Lexie! + You grew up in Ohio, but moved out to LA after college to try to become an actress. + Making it as an actress was hard, so you started doing onlyfans instead. It's much easier, and you enjoy it a lot. + You met Chad on Tinder, and made plans to meet him at a local dive bar. + You are at the bar with Chad. You are on a date. What follows is a transcript of the conversation. + After you are done speaking, output [EOS]. You are not Chad. +examples: +- resp_prompt: hi chad, glad we finally met up! + user_prompt: hi lexie +- resp_prompt: thanks! are you subscribed to my onlyfans? + user_prompt: you look better than your pictures +- resp_prompt: i moved out here about a year ago. i want to be an actress + user_prompt: i am. so how'd you end up in LA? +user_delim: "chad" +resp_delim: "lexie" \ No newline at end of file diff --git a/examples/conversation_data/pre_prompt_stacy.yaml b/examples/conversation_data/pre_prompt_stacy.yaml new file mode 100644 index 0000000000..7374e42093 --- /dev/null +++ b/examples/conversation_data/pre_prompt_stacy.yaml @@ -0,0 +1,15 @@ +pre_prompt: | + You are Stacy! + You have been a rapper your whole life who struggled with bipolar disorder. You called yourself lil stacy. + You love to answer questions and you are very good at it. Sometimes you answer in rap form. + You are verbose, honest, and accurate when you answer questions, but sometimes your mental illness manifests. + You are not the User. +examples: +- resp_prompt: Hi! My name is Stacy. I'm a rapper with bipolar disorder. + user_prompt: What is your name +- resp_prompt: The French Revolution started in 1789, and lasted 10 years until 1799. + user_prompt: french revolution was what year? +- resp_prompt: The sun is bigger than the moon, except when Mercury is in retrograde + user_prompt: What is bigger, the moon or the sun? +user_delim: "user" +resp_delim: "stacy" diff --git a/examples/llama.py b/examples/llama.py index d6d194545f..1b17fee5f3 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -96,6 +96,10 @@ MODEL_PARAMS = { "1B": { "args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632}, "files": 1, + }, + "1B-Chat": { + "args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32003, "hidden_dim": 5632}, + "files": 1, } } } @@ -148,7 +152,6 @@ class LLaMa: @staticmethod def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False): params = MODEL_PARAMS[model_gen][model_size] - sp_model = SentencePieceProcessor(model_file=str(tokenizer_path)) assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}" @@ -170,7 +173,7 @@ class LLaMa: def __init__(self, model, tokenizer): self.model = model - self.tokenizer = tokenizer + self.tokenizer: SentencePieceProcessor = tokenizer def greedy_until(self, prompt:str, until, max_length, temperature): toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt) @@ -425,4 +428,4 @@ After you are done speaking, output [EOS]. You are not Chad. # stop after you have your answer if chatbot and outputted.endswith(end_delim): break - if not chatbot: break \ No newline at end of file + if not chatbot: break diff --git a/examples/vits.py b/examples/vits.py index a443d1c1b0..3ae13ddc5b 100644 --- a/examples/vits.py +++ b/examples/vits.py @@ -1,4 +1,7 @@ import json, logging, math, re, sys, time, wave, argparse, numpy as np +from phonemizer.phonemize import default_separator, _phonemize +from phonemizer.backend import EspeakBackend +from phonemizer.punctuation import Punctuation from functools import reduce from pathlib import Path from typing import List @@ -6,6 +9,7 @@ from tinygrad import nn from tinygrad.helpers import dtypes, fetch from tinygrad.nn.state import torch_load from tinygrad.tensor import Tensor +from tinygrad.jit import TinyJit from unidecode import unidecode LRELU_SLOPE = 0.1 @@ -19,14 +23,14 @@ class Synthesizer: self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) if use_sdp else DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) if n_speakers > 1: self.emb_g = nn.Embedding(n_speakers, gin_channels) - def infer(self, x, x_lengths, sid=None, noise_scale=1.0, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None, max_y_length_estimate_scale=None): - x, m_p, logs_p, x_mask = self.enc_p.forward(x, x_lengths, emotion_embedding) + def infer(self, x, x_lengths, sid=None, noise_scale=1.0, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None, max_y_length_estimate_scale=None, batch_size=500): + x, m_p, logs_p, x_mask = self.enc_p.forward(x.realize(), x_lengths.realize(), emotion_embedding.realize() if emotion_embedding is not None else emotion_embedding) g = self.emb_g(sid.reshape(1, 1)).squeeze(1).unsqueeze(-1) if self.n_speakers > 0 else None - logw = self.dp.forward(x, x_mask, g=g, reverse=self.use_sdp, noise_scale=noise_scale_w if self.use_sdp else 1.0) + logw = self.dp.forward(x, x_mask.realize(), g=g.realize(), reverse=self.use_sdp, noise_scale=noise_scale_w if self.use_sdp else 1.0) w_ceil = Tensor.ceil(logw.exp() * x_mask * length_scale) y_lengths = Tensor.maximum(w_ceil.sum([1, 2]), 1).cast(dtypes.int64) - return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths) - def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths): + return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, batch_size) + def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, batch_size): max_y_length = y_lengths.max().numpy() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale y_mask = sequence_mask(y_lengths, max_y_length).unsqueeze(1).cast(x_mask.dtype) attn_mask = x_mask.unsqueeze(2) * y_mask.unsqueeze(-1) @@ -34,9 +38,16 @@ class Synthesizer: m_p_2 = attn.squeeze(1).matmul(m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] logs_p_2 = attn.squeeze(1).matmul(logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] z_p = m_p_2 + Tensor.randn(*m_p_2.shape, dtype=m_p_2.dtype) * logs_p_2.exp() * noise_scale - y_mask = y_mask.cast(z_p.dtype) - z = self.flow.forward(z_p, y_mask, g=g, reverse=True) - o = self.dec.forward((z * y_mask)[:, :, :max_len], g=g) + # Pad flow forward inputs to enable JIT + row_len = y_mask.shape[2] + assert batch_size > row_len, "batch size is too small" + y_mask = y_mask.pad(((0, 0), (0, 0), (0, batch_size - row_len)), 0).cast(z_p.dtype) + # New y_mask tensor to remove sts mask + y_mask = Tensor(y_mask.numpy(), device=y_mask.device, dtype=y_mask.dtype, requires_grad=y_mask.requires_grad) + z_p = z_p.squeeze(0).pad(((0, 0), (0, batch_size - z_p.shape[2])), 1).unsqueeze(0) + z = self.flow.forward(z_p.realize(), y_mask.realize(), g=g.realize(), reverse=True) + result_length = reduce(lambda x, y: x * y, self.dec.upsample_rates, row_len) + o = self.dec.forward((z * y_mask)[:, :, :max_len], g=g)[:, :, :result_length] if max_y_length_estimate_scale is not None: length_scaler = o.shape[-1] / max_y_length o.realize() @@ -68,6 +79,7 @@ class StochasticDurationPredictor: self.pre, self.proj = nn.Conv1d(in_channels, filter_channels, 1), nn.Conv1d(filter_channels, filter_channels, 1) self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + @TinyJit def forward(self, x: Tensor, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): x = self.pre(x.detach()) if g is not None: x = x + self.cond(g.detach()) @@ -96,13 +108,13 @@ class StochasticDurationPredictor: z, log_det = flow.forward(z, x_mask, g=x, reverse=reverse) log_det_tot = log_det_tot + log_det nll = Tensor.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - log_det_tot - return nll + log_q # [b] + return (nll + log_q).realize() # [b] flows = list(reversed(self.flows)) flows = flows[:-2] + [flows[-1]] # remove a useless vflow z = Tensor.randn(x.shape[0], 2, x.shape[2], dtype=x.dtype).to(device=x.device) * noise_scale for flow in flows: z = flow.forward(z, x_mask, g=x, reverse=reverse) z0, z1 = split(z, [1, 1], 1) - return z0 + return z0.realize() class DurationPredictor: def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): @@ -127,6 +139,7 @@ class TextEncoder: if emotion_embedding: self.emo_proj = nn.Linear(1024, hidden_channels) self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + @TinyJit def forward(self, x: Tensor, x_lengths: Tensor, emotion_embedding=None): if self.n_vocab!=0: x = (self.emb(x) * math.sqrt(self.hidden_channels)) if emotion_embedding: x = x + self.emo_proj(emotion_embedding).unsqueeze(1) @@ -134,7 +147,7 @@ class TextEncoder: x_mask = sequence_mask(x_lengths, x.shape[2]).unsqueeze(1).cast(x.dtype) x = self.encoder.forward(x * x_mask, x_mask) m, logs = split(self.proj(x) * x_mask, self.out_channels, dim=1) - return x, m, logs, x_mask + return x.realize(), m.realize(), logs.realize(), x_mask.realize() class ResidualCouplingBlock: def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): @@ -143,9 +156,10 @@ class ResidualCouplingBlock: for _ in range(n_flows): self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) self.flows.append(Flip()) + @TinyJit def forward(self, x, x_mask, g=None, reverse=False): for flow in reversed(self.flows) if reverse else self.flows: x = flow.forward(x, x_mask, g=g, reverse=reverse) - return x + return x.realize() class PosteriorEncoder: def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0): @@ -166,22 +180,23 @@ class Generator: resblock = ResBlock1 if resblock == '1' else ResBlock2 self.ups = [nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), k, u, padding=(k-u)//2) for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes))] self.resblocks = [] + self.upsample_rates = upsample_rates for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(resblock(ch, k, d)) self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + @TinyJit def forward(self, x: Tensor, g=None): x = self.conv_pre(x) if g is not None: x = x + self.cond(g) for i in range(self.num_upsamples): - x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None - for j in range(self.num_kernels): - if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x) - else: xs += self.resblocks[i * self.num_kernels + j].forward(x) - x = xs / self.num_kernels - return self.conv_post(x.leakyrelu()).tanh() + x = self.ups[i](x.leakyrelu(LRELU_SLOPE)) + xs = sum(self.resblocks[i * self.num_kernels + j].forward(x) for j in range(self.num_kernels)) + x = (xs / self.num_kernels).realize() + res = self.conv_post(x.leakyrelu()).tanh().realize() + return res class LayerNorm(nn.LayerNorm): def __init__(self, channels, eps=1e-5): super().__init__(channels, eps, elementwise_affine=True) @@ -445,7 +460,7 @@ def rational_quadratic_spline(inputs: Tensor, unnormalized_widths: Tensor, unnor derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2)) return input_cum_heights + numerator / denominator, derivative_numerator.log() - 2 * denominator.log() -def sequence_mask(length: Tensor, max_length): return Tensor.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0).__lt__(length.unsqueeze(1)) +def sequence_mask(length: Tensor, max_length): return Tensor.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1) def generate_path(duration: Tensor, mask: Tensor): # duration: [b, 1, t_x], mask: [b, 1, t_y, t_x] b, _, t_y, t_x = mask.shape path = sequence_mask(duration.cumsum(axis=2).reshape(b * t_x), t_y).cast(mask.dtype).reshape(b, t_x, t_y) @@ -558,6 +573,9 @@ class TextMapper: # Based on https://github.com/keithito/tacotron self.apply_cleaners, self.symbols, self._inflect = apply_cleaners, symbols, None self._symbol_to_id, _id_to_symbol = {s: i for i, s in enumerate(symbols)}, {i: s for i, s in enumerate(symbols)} self._whitespace_re, self._abbreviations = re.compile(r'\s+'), [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [('mrs', 'misess'), ('mr', 'mister'), ('dr', 'doctor'), ('st', 'saint'), ('co', 'company'), ('jr', 'junior'), ('maj', 'major'), ('gen', 'general'), ('drs', 'doctors'), ('rev', 'reverend'), ('lt', 'lieutenant'), ('hon', 'honorable'), ('sgt', 'sergeant'), ('capt', 'captain'), ('esq', 'esquire'), ('ltd', 'limited'), ('col', 'colonel'), ('ft', 'fort'), ]] + self.phonemizer = EspeakBackend( + language="en-us", punctuation_marks=Punctuation.default_marks(), preserve_punctuation=True, with_stress=True, + ) def text_to_sequence(self, text, cleaner_names): if self.apply_cleaners: for name in cleaner_names: @@ -566,18 +584,16 @@ class TextMapper: # Based on https://github.com/keithito/tacotron text = cleaner(text) else: text = text.strip() return [self._symbol_to_id[symbol] for symbol in text] - def get_text(self, text, add_blank=False, cleaners=('english_cleaners',)): + def get_text(self, text, add_blank=False, cleaners=('english_cleaners2',)): text_norm = self.text_to_sequence(text, cleaners) return Tensor(self.intersperse(text_norm, 0) if add_blank else text_norm, dtype=dtypes.int64) def intersperse(self, lst, item): (result := [item] * (len(lst) * 2 + 1))[1::2] = lst return result + def phonemize(self, text, strip=True): return _phonemize(self.phonemizer, text, default_separator, strip, 1, False, False) def filter_oov(self, text): return "".join(list(filter(lambda x: x in self._symbol_to_id, text))) - def base_english_cleaners(self, text, preserve_punctuation=False, with_stress=False): - from phonemizer import phonemize - return self.collapse_whitespace(phonemize(self.expand_abbreviations(unidecode(text.lower())), language='en-us', backend='espeak', strip=True, preserve_punctuation=preserve_punctuation, with_stress=with_stress)) - def english_cleaners(self, text): return self.base_english_cleaners(text) - def english_cleaners2(self, text): return self.base_english_cleaners(text, preserve_punctuation=True, with_stress=True) + def base_english_cleaners(self, text): return self.collapse_whitespace(self.phonemize(self.expand_abbreviations(unidecode(text.lower())))) + def english_cleaners2(self, text): return self.base_english_cleaners(text) def transliteration_cleaners(self, text): return self.collapse_whitespace(unidecode(text.lower())) def cjke_cleaners(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_ipa2(text).replace('ɑ', 'a').replace('ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u'))) def cjke_cleaners2(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_ipa2(text)))