mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
remove more stale stuff (#13765)
* remove more stale stuff * remove disassemblers/adreno * stale
This commit is contained in:
@@ -1,341 +0,0 @@
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
import yaml
|
||||
from llama import LLaMa
|
||||
from vits import MODELS as VITS_MODELS
|
||||
from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model
|
||||
from whisper import init_whisper, transcribe_waveform
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
from tinygrad.helpers import Timing, fetch
|
||||
from tinygrad import Tensor, dtypes
|
||||
|
||||
# Whisper constants
|
||||
RATE = 16000
|
||||
CHUNK = 1600
|
||||
|
||||
# LLaMa constants
|
||||
IM_START = 32001
|
||||
IM_END = 32002
|
||||
|
||||
|
||||
# Functions for encoding prompts to chatml md
|
||||
def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
|
||||
def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n")
|
||||
|
||||
def chunks(lst, n):
|
||||
for i in range(0, len(lst), n): yield lst[i:i + n]
|
||||
|
||||
def create_fixed_tokenizer():
|
||||
"""Function needed for extending tokenizer with additional chat tokens"""
|
||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||
tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model")
|
||||
if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
|
||||
print("creating fixed tokenizer")
|
||||
mp = spb2.ModelProto()
|
||||
mp.ParseFromString(tokenizer_path.read_bytes())
|
||||
# https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
|
||||
tokenizer_path.write_bytes(mp.SerializeToString())
|
||||
return tokenizer_path
|
||||
|
||||
def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, str]:
|
||||
"""Prepares a llama model from a specified pre-prompt file"""
|
||||
with open(str(pre_prompt_path)) as f:
|
||||
config = yaml.safe_load(f.read())
|
||||
toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " "))
|
||||
for i in config["examples"]:
|
||||
toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
|
||||
toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
|
||||
llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used
|
||||
return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks)
|
||||
|
||||
def llama_generate(
|
||||
llama: LLaMa,
|
||||
toks: list[int],
|
||||
outputted: str,
|
||||
prompt: str,
|
||||
start_pos: int,
|
||||
user_delim: str,
|
||||
resp_delim: str,
|
||||
temperature=0.7,
|
||||
max_tokens=1000
|
||||
):
|
||||
"""Generates an output for the specified prompt"""
|
||||
toks += encode_prompt(llama.tokenizer, user_delim, prompt)
|
||||
toks += start_prompt(llama.tokenizer, resp_delim)
|
||||
|
||||
outputted = llama.tokenizer.decode(toks)
|
||||
init_length = len(outputted)
|
||||
for _ in range(max_tokens):
|
||||
token = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
|
||||
start_pos = len(toks)
|
||||
toks.append(token)
|
||||
|
||||
cur = llama.tokenizer.decode(toks)
|
||||
|
||||
# Print is just for debugging
|
||||
sys.stdout.write(cur[len(outputted):])
|
||||
sys.stdout.flush()
|
||||
outputted = cur
|
||||
if toks[-1] == IM_END: break
|
||||
else:
|
||||
toks.append(IM_END)
|
||||
print() # because the output is flushed
|
||||
return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
|
||||
|
||||
def tts(
|
||||
text_to_synthesize: str,
|
||||
synth: Synthesizer,
|
||||
hps: HParams,
|
||||
emotion_embedding: Path,
|
||||
speaker_id: int,
|
||||
model_to_use: str,
|
||||
noise_scale: float,
|
||||
noise_scale_w: float,
|
||||
length_scale: float,
|
||||
estimate_max_y_length: bool,
|
||||
text_mapper: TextMapper,
|
||||
model_has_multiple_speakers: bool,
|
||||
pad_length=600,
|
||||
vits_pad_length=1000
|
||||
):
|
||||
if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
||||
|
||||
# Convert the input text to a tensor.
|
||||
stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
|
||||
init_shape = stn_tst.shape
|
||||
assert init_shape[0] < pad_length, "text is too long"
|
||||
x_tst, x_tst_lengths = stn_tst.pad(((0, pad_length - init_shape[0]),), value=1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
|
||||
sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
|
||||
|
||||
# Perform inference.
|
||||
audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
|
||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, pad_length=vits_pad_length)[0, 0]
|
||||
# Save the audio output.
|
||||
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
|
||||
return audio_data
|
||||
|
||||
def init_vits(
|
||||
model_to_use: str,
|
||||
emotion_path: Path,
|
||||
speaker_id: int,
|
||||
seed: int,
|
||||
):
|
||||
model_config = VITS_MODELS[model_to_use]
|
||||
|
||||
# Load the hyperparameters from the config file.
|
||||
hps = get_hparams_from_file(fetch(model_config[0]))
|
||||
|
||||
# If model has multiple speakers, validate speaker id and retrieve name if available.
|
||||
model_has_multiple_speakers = hps.data.n_speakers > 0
|
||||
if model_has_multiple_speakers:
|
||||
if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
|
||||
if hps.__contains__("speakers"): # maps speaker ids to names
|
||||
speakers = hps.speakers
|
||||
if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)}
|
||||
|
||||
# Load emotions if any. TODO: find an english model with emotions, this is untested atm.
|
||||
emotion_embedding = None
|
||||
if emotion_path is not None:
|
||||
if emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0)
|
||||
else: raise ValueError("Emotion path must be a .npy file.")
|
||||
|
||||
# Load symbols, instantiate TextMapper and clean the text.
|
||||
if hps.__contains__("symbols"): symbols = hps.symbols
|
||||
elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
|
||||
else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ")
|
||||
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
|
||||
|
||||
# Load the model.
|
||||
if seed is not None:
|
||||
Tensor.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
net_g = load_model(text_mapper.symbols, hps, model_config)
|
||||
|
||||
return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
|
||||
|
||||
@contextmanager
|
||||
def output_stream(num_channels: int, sample_rate: int):
|
||||
try:
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True)
|
||||
yield stream
|
||||
except KeyboardInterrupt: pass
|
||||
finally:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
@contextmanager
|
||||
def log_writer():
|
||||
try:
|
||||
logs = []
|
||||
yield logs
|
||||
finally:
|
||||
sep = "="*os.get_terminal_size()[1]
|
||||
print(f"{sep[:-1]}\nCHAT LOG")
|
||||
print(*logs, sep="\n")
|
||||
print(sep)
|
||||
|
||||
def listener(q: mp.Queue, event: mp.Event):
|
||||
try:
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
did_print = False
|
||||
while True:
|
||||
data = stream.read(CHUNK) # read data to avoid overflow
|
||||
if event.is_set():
|
||||
if not did_print:
|
||||
print("listening")
|
||||
did_print = True
|
||||
q.put(((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3))
|
||||
else:
|
||||
did_print = False
|
||||
finally:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int):
|
||||
with output_stream(num_channels, sample_rate) as stream:
|
||||
while True:
|
||||
try:
|
||||
stream.write(q.get())
|
||||
counter.value += 1
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
import nltk
|
||||
nltk.download("punkt")
|
||||
# Parse CLI arguments
|
||||
parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad")
|
||||
|
||||
# Whisper args
|
||||
parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
|
||||
|
||||
# LLAMA args
|
||||
parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ")
|
||||
parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate")
|
||||
parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax")
|
||||
parser.add_argument("--llama_quantize", type=str, default=None, help="Quantize the weights to int8 or nf4 in memory")
|
||||
parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
|
||||
parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use")
|
||||
parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")
|
||||
parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model")
|
||||
|
||||
# vits args
|
||||
parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
|
||||
parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 6.")
|
||||
parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
|
||||
parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
|
||||
parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
|
||||
parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is 1337.")
|
||||
parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
|
||||
parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
|
||||
parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.")
|
||||
parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.")
|
||||
parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.")
|
||||
|
||||
# conversation args
|
||||
parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init models
|
||||
model, enc = init_whisper(args.whisper_model_name)
|
||||
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed)
|
||||
|
||||
# Download tinyllama chat as a default model
|
||||
if args.llama_model is None:
|
||||
args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors")
|
||||
args.llama_gen = "tiny"
|
||||
args.llama_size = "1B-Chat"
|
||||
# Add 3 more tokens to the tokenizer
|
||||
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer()
|
||||
tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
|
||||
llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize)
|
||||
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path)
|
||||
|
||||
# Start child process for mic input
|
||||
q = mp.Queue()
|
||||
is_listening_event = mp.Event()
|
||||
p = mp.Process(target=listener, args=(q, is_listening_event,))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
# Start child process for speaker output
|
||||
out_q = mp.Queue()
|
||||
out_counter = mp.Value("i", 0)
|
||||
out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,))
|
||||
out_p.daemon = True
|
||||
out_p.start()
|
||||
|
||||
# JIT tts
|
||||
for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
|
||||
tts(
|
||||
i, synth, hps, emotion_embedding,
|
||||
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
|
||||
args.vits_noise_scale_w, args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
|
||||
)
|
||||
|
||||
# Start the pipeline
|
||||
with log_writer() as log:
|
||||
while True:
|
||||
tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
||||
total = np.array([])
|
||||
out_counter.value = 0
|
||||
|
||||
s = time.perf_counter()
|
||||
is_listening_event.set()
|
||||
prev_text = None
|
||||
while True:
|
||||
for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()])
|
||||
txt = transcribe_waveform(model, enc, [total], truncate=True)
|
||||
print(txt, end="\r")
|
||||
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue
|
||||
if prev_text is not None and prev_text == txt:
|
||||
is_listening_event.clear()
|
||||
break
|
||||
prev_text = txt
|
||||
print() # to avoid llama printing on the same line
|
||||
log.append(f"{user_delim.capitalize()}: {txt}")
|
||||
|
||||
# Generate with llama
|
||||
with Timing("llama generation: "):
|
||||
outputted, start_pos, response = llama_generate(
|
||||
llama, toks, outputted, txt, start_pos,
|
||||
user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature,
|
||||
max_tokens=args.llama_count
|
||||
)
|
||||
log.append(f"{resp_delim.capitalize()}: {response}")
|
||||
|
||||
# Convert to voice
|
||||
with Timing("tts: "):
|
||||
sentences = nltk.sent_tokenize(response.replace('"', ""))
|
||||
for i in sentences:
|
||||
total = np.array([], dtype=np.int16)
|
||||
for j in chunks(i.split(), args.max_sentence_length):
|
||||
audio_data = tts(
|
||||
" ".join(j), synth, hps, emotion_embedding,
|
||||
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
|
||||
args.vits_noise_scale_w, args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
|
||||
)
|
||||
total = np.concatenate([total, audio_data])
|
||||
out_q.put(total.tobytes())
|
||||
while out_counter.value < len(sentences): continue
|
||||
log.append(f"Total: {time.perf_counter() - s}")
|
||||
@@ -204,43 +204,6 @@ def eval_bert():
|
||||
|
||||
st = time.perf_counter()
|
||||
|
||||
def eval_mrcnn():
|
||||
from tqdm import tqdm
|
||||
from extra.models.mask_rcnn import MaskRCNN
|
||||
from extra.models.resnet import ResNet
|
||||
from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
|
||||
from examples.mask_rcnn import compute_prediction_batched, Image
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
bbox_output = '/tmp/results_bbox.json'
|
||||
mask_output = '/tmp/results_mask.json'
|
||||
|
||||
accumulate_predictions_for_coco([], bbox_output, rm=True)
|
||||
accumulate_predictions_for_coco([], mask_output, rm=True)
|
||||
|
||||
#TODO: bs > 1 not as accurate
|
||||
bs = 1
|
||||
|
||||
for batch in tqdm(iterate(images, bs=bs), total=len(images)//bs):
|
||||
batch_imgs = []
|
||||
for image_row in batch:
|
||||
image_name = image_row['file_name']
|
||||
img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
|
||||
batch_imgs.append(img)
|
||||
batch_result = compute_prediction_batched(batch_imgs, mdl)
|
||||
for image_row, result in zip(batch, batch_result):
|
||||
image_name = image_row['file_name']
|
||||
box_pred = convert_prediction_to_coco_bbox(image_name, result)
|
||||
mask_pred = convert_prediction_to_coco_mask(image_name, result)
|
||||
accumulate_predictions_for_coco(box_pred, bbox_output)
|
||||
accumulate_predictions_for_coco(mask_pred, mask_output)
|
||||
del batch_imgs
|
||||
del batch_result
|
||||
|
||||
evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
|
||||
evaluate_predictions_on_coco(mask_output, iou_type='segm')
|
||||
|
||||
def eval_llama3():
|
||||
from extra.models.llama import Transformer
|
||||
from examples.llama3 import MODEL_PARAMS, load, convert_from_huggingface
|
||||
@@ -541,7 +504,7 @@ if __name__ == "__main__":
|
||||
# inference only
|
||||
Tensor.training = False
|
||||
|
||||
models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
|
||||
models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert").split(",")
|
||||
for m in models:
|
||||
nm = f"eval_{m}"
|
||||
if nm in globals():
|
||||
|
||||
740
examples/vits.py
740
examples/vits.py
@@ -1,740 +0,0 @@
|
||||
import json, logging, math, re, sys, time, wave, argparse, numpy as np
|
||||
from phonemizer.phonemize import default_separator, _phonemize
|
||||
from phonemizer.backend import EspeakBackend
|
||||
from phonemizer.punctuation import Punctuation
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tinygrad import nn, dtypes
|
||||
from tinygrad.helpers import fetch
|
||||
from tinygrad.nn.state import torch_load
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from unidecode import unidecode
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
class Synthesizer:
|
||||
def __init__(self, n_vocab, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, n_speakers=0, gin_channels=0, use_sdp=True, emotion_embedding=False, **kwargs):
|
||||
self.n_vocab, self.spec_channels, self.inter_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.segment_size, self.n_speakers, self.gin_channels, self.use_sdp = n_vocab, spec_channels, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, segment_size, n_speakers, gin_channels, use_sdp
|
||||
self.enc_p = TextEncoder(n_vocab, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, emotion_embedding)
|
||||
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) if use_sdp else DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
|
||||
if n_speakers > 1: self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
||||
def infer(self, x, x_lengths, sid=None, noise_scale=1.0, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None, max_y_length_estimate_scale=None, pad_length=-1):
|
||||
x, m_p, logs_p, x_mask = self.enc_p.forward(x.realize(), x_lengths.realize(), emotion_embedding.realize() if emotion_embedding is not None else emotion_embedding)
|
||||
g = self.emb_g(sid.reshape(1, 1)).squeeze(1).unsqueeze(-1) if self.n_speakers > 0 else None
|
||||
logw = self.dp.forward(x, x_mask.realize(), g=g.realize(), reverse=self.use_sdp, noise_scale=noise_scale_w if self.use_sdp else 1.0)
|
||||
w_ceil = Tensor.ceil(logw.exp() * x_mask * length_scale)
|
||||
y_lengths = Tensor.maximum(w_ceil.sum([1, 2]), 1).cast(dtypes.int64)
|
||||
return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, pad_length)
|
||||
def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, pad_length):
|
||||
max_y_length = y_lengths.max().item() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale
|
||||
y_mask = sequence_mask(y_lengths, max_y_length).unsqueeze(1).cast(x_mask.dtype)
|
||||
attn_mask = x_mask.unsqueeze(2) * y_mask.unsqueeze(-1)
|
||||
attn = generate_path(w_ceil, attn_mask)
|
||||
m_p_2 = attn.squeeze(1).matmul(m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
logs_p_2 = attn.squeeze(1).matmul(logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
z_p = m_p_2 + Tensor.randn(*m_p_2.shape, dtype=m_p_2.dtype) * logs_p_2.exp() * noise_scale
|
||||
row_len = y_mask.shape[2]
|
||||
if pad_length > -1:
|
||||
# Pad flow forward inputs to enable JIT
|
||||
assert pad_length > row_len, "pad length is too small"
|
||||
y_mask = y_mask.pad(((0, 0), (0, 0), (0, pad_length - row_len))).cast(z_p.dtype)
|
||||
# New y_mask tensor to remove sts mask
|
||||
y_mask = Tensor(y_mask.numpy(), device=y_mask.device, dtype=y_mask.dtype, requires_grad=y_mask.requires_grad)
|
||||
z_p = z_p.squeeze(0).pad(((0, 0), (0, pad_length - z_p.shape[2])), value=1).unsqueeze(0)
|
||||
z = self.flow.forward(z_p.realize(), y_mask.realize(), g=g.realize(), reverse=True)
|
||||
result_length = reduce(lambda x, y: x * y, self.dec.upsample_rates, row_len)
|
||||
o = self.dec.forward((z * y_mask)[:, :, :max_len], g=g)[:, :, :result_length]
|
||||
if max_y_length_estimate_scale is not None:
|
||||
length_scaler = o.shape[-1] / max_y_length
|
||||
o.realize()
|
||||
real_max_y_length = y_lengths.max().numpy()
|
||||
if real_max_y_length > max_y_length:
|
||||
logging.warning(f"Underestimated max length by {(((real_max_y_length / max_y_length) * 100) - 100):.2f}%, recomputing inference without estimate...")
|
||||
return self.generate(g, logs_p, m_p, max_len, None, noise_scale, w_ceil, x, x_mask, y_lengths)
|
||||
if real_max_y_length < max_y_length:
|
||||
overestimation = ((max_y_length / real_max_y_length) * 100) - 100
|
||||
logging.info(f"Overestimated max length by {overestimation:.2f}%")
|
||||
if overestimation > 10: logging.warning("Warning: max length overestimated by more than 10%")
|
||||
o = o[:, :, :(real_max_y_length * length_scaler).astype(np.int32)]
|
||||
return o
|
||||
|
||||
class StochasticDurationPredictor:
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
|
||||
filter_channels = in_channels # it needs to be removed from future version.
|
||||
self.in_channels, self.filter_channels, self.kernel_size, self.p_dropout, self.n_flows, self.gin_channels = in_channels, filter_channels, kernel_size, p_dropout, n_flows, gin_channels
|
||||
self.log_flow, self.flows = Log(), [ElementwiseAffine(2)]
|
||||
for _ in range(n_flows):
|
||||
self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.flows.append(Flip())
|
||||
self.post_pre, self.post_proj = nn.Conv1d(1, filter_channels, 1), nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.post_flows = [ElementwiseAffine(2)]
|
||||
for _ in range(4):
|
||||
self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.post_flows.append(Flip())
|
||||
self.pre, self.proj = nn.Conv1d(in_channels, filter_channels, 1), nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||
@TinyJit
|
||||
def forward(self, x: Tensor, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
|
||||
x = self.pre(x.detach())
|
||||
if g is not None: x = x + self.cond(g.detach())
|
||||
x = self.convs.forward(x, x_mask)
|
||||
x = self.proj(x) * x_mask
|
||||
if not reverse:
|
||||
flows = self.flows
|
||||
assert w is not None
|
||||
log_det_tot_q = 0
|
||||
h_w = self.post_proj(self.post_convs.forward(self.post_pre(w), x_mask)) * x_mask
|
||||
e_q = Tensor.randn(w.size(0), 2, w.size(2), dtype=x.dtype).to(device=x.device) * x_mask
|
||||
z_q = e_q
|
||||
for flow in self.post_flows:
|
||||
z_q, log_det_q = flow.forward(z_q, x_mask, g=(x + h_w))
|
||||
log_det_tot_q += log_det_q
|
||||
z_u, z1 = z_q.split([1, 1], 1)
|
||||
u = z_u.sigmoid() * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
log_det_tot_q += Tensor.sum((z_u.logsigmoid() + (-z_u).logsigmoid()) * x_mask, [1,2])
|
||||
log_q = Tensor.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - log_det_tot_q
|
||||
log_det_tot = 0
|
||||
z0, log_det = self.log_flow.forward(z0, x_mask)
|
||||
log_det_tot += log_det
|
||||
z = z0.cat(z1, 1)
|
||||
for flow in flows:
|
||||
z, log_det = flow.forward(z, x_mask, g=x, reverse=reverse)
|
||||
log_det_tot = log_det_tot + log_det
|
||||
nll = Tensor.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - log_det_tot
|
||||
return (nll + log_q).realize() # [b]
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = Tensor.randn(x.shape[0], 2, x.shape[2], dtype=x.dtype).to(device=x.device) * noise_scale
|
||||
for flow in flows: z = flow.forward(z, x_mask, g=x, reverse=reverse)
|
||||
z0, z1 = z.split([1, 1], 1)
|
||||
return z0.realize()
|
||||
|
||||
class DurationPredictor:
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
||||
self.in_channels, self.filter_channels, self.kernel_size, self.p_dropout, self.gin_channels = in_channels, filter_channels, kernel_size, p_dropout, gin_channels
|
||||
self.conv_1, self.norm_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2), LayerNorm(filter_channels)
|
||||
self.conv_2, self.norm_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2), LayerNorm(filter_channels)
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
||||
def forward(self, x: Tensor, x_mask, g=None):
|
||||
x = x.detach()
|
||||
if g is not None: x = x + self.cond(g.detach())
|
||||
x = self.conv_1(x * x_mask).relu()
|
||||
x = self.norm_1(x).dropout(self.p_dropout)
|
||||
x = self.conv_2(x * x_mask).relu(x)
|
||||
x = self.norm_2(x).dropout(self.p_dropout)
|
||||
return self.proj(x * x_mask) * x_mask
|
||||
|
||||
class TextEncoder:
|
||||
def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, emotion_embedding):
|
||||
self.n_vocab, self.out_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout = n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
if n_vocab!=0:self.emb = nn.Embedding(n_vocab, hidden_channels)
|
||||
if emotion_embedding: self.emo_proj = nn.Linear(1024, hidden_channels)
|
||||
self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
@TinyJit
|
||||
def forward(self, x: Tensor, x_lengths: Tensor, emotion_embedding=None):
|
||||
if self.n_vocab!=0: x = (self.emb(x) * math.sqrt(self.hidden_channels))
|
||||
if emotion_embedding: x = x + self.emo_proj(emotion_embedding).unsqueeze(1)
|
||||
x = x.transpose(1, -1) # [b, t, h] -transpose-> [b, h, t]
|
||||
x_mask = sequence_mask(x_lengths, x.shape[2]).unsqueeze(1).cast(x.dtype)
|
||||
x = self.encoder.forward(x * x_mask, x_mask)
|
||||
m, logs = (self.proj(x) * x_mask).split(self.out_channels, dim=1)
|
||||
return x.realize(), m.realize(), logs.realize(), x_mask.realize()
|
||||
|
||||
class ResidualCouplingBlock:
|
||||
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
|
||||
self.channels, self.hidden_channels, self.kernel_size, self.dilation_rate, self.n_layers, self.n_flows, self.gin_channels = channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows, gin_channels
|
||||
self.flows = []
|
||||
for _ in range(n_flows):
|
||||
self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
|
||||
self.flows.append(Flip())
|
||||
@TinyJit
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
for flow in reversed(self.flows) if reverse else self.flows: x = flow.forward(x, x_mask, g=g, reverse=reverse)
|
||||
return x.realize()
|
||||
|
||||
class PosteriorEncoder:
|
||||
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
|
||||
self.in_channels, self.out_channels, self.hidden_channels, self.kernel_size, self.dilation_rate, self.n_layers, self.gin_channels = in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels
|
||||
self.pre, self.proj = nn.Conv1d(in_channels, hidden_channels, 1), nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).cast(x.dtype)
|
||||
stats = self.proj(self.enc.forward(self.pre(x) * x_mask, x_mask, g=g)) * x_mask
|
||||
m, logs = stats.split(self.out_channels, dim=1)
|
||||
z = (m + Tensor.randn(m.shape, m.dtype) * logs.exp()) * x_mask
|
||||
return z, m, logs, x_mask
|
||||
|
||||
class Generator:
|
||||
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
||||
self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len(upsample_rates)
|
||||
self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock = ResBlock1 if resblock == '1' else ResBlock2
|
||||
self.ups = [nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), k, u, padding=(k-u)//2) for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes))]
|
||||
self.resblocks = []
|
||||
self.upsample_rates = upsample_rates
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
@TinyJit
|
||||
def forward(self, x: Tensor, g=None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None: x = x + self.cond(g)
|
||||
for i in range(self.num_upsamples):
|
||||
x = self.ups[i](x.leaky_relu(LRELU_SLOPE))
|
||||
xs = sum(self.resblocks[i * self.num_kernels + j].forward(x) for j in range(self.num_kernels))
|
||||
x = (xs / self.num_kernels).realize()
|
||||
res = self.conv_post(x.leaky_relu()).tanh().realize()
|
||||
return res
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __init__(self, channels, eps=1e-5): super().__init__(channels, eps, elementwise_affine=True)
|
||||
def forward(self, x: Tensor): return self.__call__(x.transpose(1, -1)).transpose(1, -1)
|
||||
|
||||
class WN:
|
||||
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
|
||||
assert (kernel_size % 2 == 1)
|
||||
self.hidden_channels, self.kernel_size, self.dilation_rate, self.n_layers, self.gin_channels, self.p_dropout = hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, p_dropout
|
||||
self.in_layers, self.res_skip_layers = [], []
|
||||
if gin_channels != 0: self.cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
|
||||
for i in range(n_layers):
|
||||
dilation = dilation_rate ** i
|
||||
self.in_layers.append(nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=int((kernel_size * dilation - dilation) / 2)))
|
||||
self.res_skip_layers.append(nn.Conv1d(hidden_channels, 2 * hidden_channels if i < n_layers - 1 else hidden_channels, 1))
|
||||
def forward(self, x, x_mask, g=None, **kwargs):
|
||||
output = Tensor.zeros_like(x)
|
||||
if g is not None: g = self.cond_layer(g)
|
||||
for i in range(self.n_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
|
||||
else:
|
||||
g_l = Tensor.zeros_like(x_in)
|
||||
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels)
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
|
||||
output = output + res_skip_acts[:, self.hidden_channels:, :]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output * x_mask
|
||||
|
||||
class ResBlock1:
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
self.convs1 = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[i], padding=get_padding(kernel_size, dilation[i])) for i in range(3)]
|
||||
self.convs2 = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) for _ in range(3)]
|
||||
def forward(self, x: Tensor, x_mask=None):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = x.leaky_relu(LRELU_SLOPE)
|
||||
xt = c1(xt if x_mask is None else xt * x_mask).leaky_relu(LRELU_SLOPE)
|
||||
x = c2(xt if x_mask is None else xt * x_mask) + x
|
||||
return x if x_mask is None else x * x_mask
|
||||
|
||||
class ResBlock2:
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
self.convs = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[i], padding=get_padding(kernel_size, dilation[i])) for i in range(2)]
|
||||
def forward(self, x, x_mask=None):
|
||||
for c in self.convs:
|
||||
xt = x.leaky_relu(LRELU_SLOPE)
|
||||
xt = c(xt if x_mask is None else xt * x_mask)
|
||||
x = xt + x
|
||||
return x if x_mask is None else x * x_mask
|
||||
|
||||
class DDSConv: # Dilated and Depth-Separable Convolution
|
||||
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
|
||||
self.channels, self.kernel_size, self.n_layers, self.p_dropout = channels, kernel_size, n_layers, p_dropout
|
||||
self.convs_sep, self.convs_1x1, self.norms_1, self.norms_2 = [], [], [], []
|
||||
for i in range(n_layers):
|
||||
dilation = kernel_size ** i
|
||||
padding = (kernel_size * dilation - dilation) // 2
|
||||
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding))
|
||||
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
||||
self.norms_1.append(LayerNorm(channels))
|
||||
self.norms_2.append(LayerNorm(channels))
|
||||
def forward(self, x, x_mask, g=None):
|
||||
if g is not None: x = x + g
|
||||
for i in range(self.n_layers):
|
||||
y = self.convs_sep[i](x * x_mask)
|
||||
y = self.norms_1[i].forward(y).gelu()
|
||||
y = self.convs_1x1[i](y)
|
||||
y = self.norms_2[i].forward(y).gelu()
|
||||
x = x + y.dropout(self.p_dropout)
|
||||
return x * x_mask
|
||||
|
||||
class ConvFlow:
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
|
||||
self.in_channels, self.filter_channels, self.kernel_size, self.n_layers, self.num_bins, self.tail_bound = in_channels, filter_channels, kernel_size, n_layers, num_bins, tail_bound
|
||||
self.half_channels = in_channels // 2
|
||||
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
||||
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
|
||||
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = x.split([self.half_channels] * 2, 1)
|
||||
h = self.proj(self.convs.forward(self.pre(x0), x_mask, g=g)) * x_mask
|
||||
b, c, t = x0.shape
|
||||
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
||||
un_normalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
|
||||
un_normalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
|
||||
un_normalized_derivatives = h[..., 2 * self.num_bins:]
|
||||
x1, log_abs_det = piecewise_rational_quadratic_transform(x1, un_normalized_widths, un_normalized_heights, un_normalized_derivatives, inverse=reverse, tails='linear', tail_bound=self.tail_bound)
|
||||
x = x0.cat(x1, dim=1) * x_mask
|
||||
return x if reverse else (x, Tensor.sum(log_abs_det * x_mask, [1,2]))
|
||||
|
||||
class ResidualCouplingLayer:
|
||||
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
|
||||
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||
self.channels, self.hidden_channels, self.kernel_size, self.dilation_rate, self.n_layers, self.mean_only = channels, hidden_channels, kernel_size, dilation_rate, n_layers, mean_only
|
||||
self.half_channels = channels // 2
|
||||
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
|
||||
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = x.split([self.half_channels] * 2, 1)
|
||||
stats = self.post(self.enc.forward(self.pre(x0) * x_mask, x_mask, g=g)) * x_mask
|
||||
if not self.mean_only:
|
||||
m, logs = stats.split([self.half_channels] * 2, 1)
|
||||
else:
|
||||
m = stats
|
||||
logs = Tensor.zeros_like(m)
|
||||
if not reverse: return x0.cat((m + x1 * logs.exp() * x_mask), dim=1)
|
||||
return x0.cat(((x1 - m) * (-logs).exp() * x_mask), dim=1)
|
||||
|
||||
class Log:
|
||||
def forward(self, x : Tensor, x_mask, reverse=False):
|
||||
if not reverse:
|
||||
y = x.maximum(1e-5).log() * x_mask
|
||||
return y, (-y).sum([1, 2])
|
||||
return x.exp() * x_mask
|
||||
|
||||
class Flip:
|
||||
def forward(self, x: Tensor, *args, reverse=False, **kwargs):
|
||||
return x.flip([1]) if reverse else (x.flip([1]), Tensor.zeros(x.shape[0], dtype=x.dtype).to(device=x.device))
|
||||
|
||||
class ElementwiseAffine:
|
||||
def __init__(self, channels): self.m, self.logs = Tensor.zeros(channels, 1), Tensor.zeros(channels, 1)
|
||||
def forward(self, x, x_mask, reverse=False, **kwargs): # x if reverse else y, logdet
|
||||
return (x - self.m) * Tensor.exp(-self.logs) * x_mask if reverse \
|
||||
else ((self.m + Tensor.exp(self.logs) * x) * x_mask, Tensor.sum(self.logs * x_mask, [1, 2]))
|
||||
|
||||
class MultiHeadAttention:
|
||||
def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
|
||||
assert channels % n_heads == 0
|
||||
self.channels, self.out_channels, self.n_heads, self.p_dropout, self.window_size, self.heads_share, self.block_length, self.proximal_bias, self.proximal_init = channels, out_channels, n_heads, p_dropout, window_size, heads_share, block_length, proximal_bias, proximal_init
|
||||
self.attn, self.k_channels = None, channels // n_heads
|
||||
self.conv_q, self.conv_k, self.conv_v = [nn.Conv1d(channels, channels, 1) for _ in range(3)]
|
||||
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
||||
if window_size is not None: self.emb_rel_k, self.emb_rel_v = [Tensor.randn(1 if heads_share else n_heads, window_size * 2 + 1, self.k_channels) * (self.k_channels ** -0.5) for _ in range(2)]
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
|
||||
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
return self.conv_o(x)
|
||||
def attention(self, query: Tensor, key: Tensor, value: Tensor, mask=None):# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||
b, d, t_s, t_t = key.shape[0], key.shape[1], key.shape[2], query.shape[2]
|
||||
query = query.reshape(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
||||
key = key.reshape(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
value = value.reshape(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
scores = (query / math.sqrt(self.k_channels)) @ key.transpose(-2, -1)
|
||||
if self.window_size is not None:
|
||||
assert t_s == t_t, "Relative attention is only available for self-attention."
|
||||
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
||||
rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
|
||||
scores = scores + self._relative_position_to_absolute_position(rel_logits)
|
||||
if mask is not None:
|
||||
scores = Tensor.where(mask, scores, -1e4)
|
||||
if self.block_length is not None:
|
||||
assert t_s == t_t, "Local attention is only available for self-attention."
|
||||
scores = Tensor.where(Tensor.ones_like(scores).triu(-self.block_length).tril(self.block_length), scores, -1e4)
|
||||
p_attn = scores.softmax(axis=-1) # [b, n_h, t_t, t_s]
|
||||
output = p_attn.matmul(value)
|
||||
if self.window_size is not None:
|
||||
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
||||
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
||||
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
||||
output = output.transpose(2, 3).contiguous().reshape(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
||||
return output, p_attn
|
||||
def _matmul_with_relative_values(self, x, y): return x.matmul(y.unsqueeze(0)) # x: [b, h, l, m], y: [h or 1, m, d], ret: [b, h, l, d]
|
||||
def _matmul_with_relative_keys(self, x, y): return x.matmul(y.unsqueeze(0).transpose(-2, -1)) # x: [b, h, l, d], y: [h or 1, m, d], re, : [b, h, l, m]
|
||||
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||
pad_length, slice_start_position = max(length - (self.window_size + 1), 0), max((self.window_size + 1) - length, 0)
|
||||
padded_relative_embeddings = relative_embeddings if pad_length <= 0\
|
||||
else relative_embeddings.pad(convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
|
||||
return padded_relative_embeddings[:, slice_start_position:(slice_start_position + 2 * length - 1)] #used_relative_embeddings
|
||||
def _relative_position_to_absolute_position(self, x: Tensor): # x: [b, h, l, 2*l-1] -> [b, h, l, l]
|
||||
batch, heads, length, _ = x.shape
|
||||
x = x.pad(convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
|
||||
x_flat = x.reshape([batch, heads, length * 2 * length]).pad(convert_pad_shape([[0,0],[0,0],[0,length-1]]))
|
||||
return x_flat.reshape([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
|
||||
def _absolute_position_to_relative_position(self, x: Tensor): # x: [b, h, l, l] -> [b, h, l, 2*l-1]
|
||||
batch, heads, length, _ = x.shape
|
||||
x = x.pad(convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
|
||||
x_flat = x.reshape([batch, heads, length**2 + length*(length -1)]).pad(convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||
return x_flat.reshape([batch, heads, length, 2*length])[:,:,:,1:]
|
||||
|
||||
class FFN:
|
||||
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
|
||||
self.in_channels, self.out_channels, self.filter_channels, self.kernel_size, self.p_dropout, self.activation, self.causal = in_channels, out_channels, filter_channels, kernel_size, p_dropout, activation, causal
|
||||
self.padding = self._causal_padding if causal else self._same_padding
|
||||
self.conv_1, self.conv_2 = nn.Conv1d(in_channels, filter_channels, kernel_size), nn.Conv1d(filter_channels, out_channels, kernel_size)
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(self.padding(x * x_mask))
|
||||
x = x * (1.702 * x).sigmoid() if self.activation == "gelu" else x.relu()
|
||||
return self.conv_2(self.padding(x.dropout(self.p_dropout) * x_mask)) * x_mask
|
||||
def _causal_padding(self, x):return x if self.kernel_size == 1 else x.pad(convert_pad_shape([[0, 0], [0, 0], [self.kernel_size - 1, 0]]))
|
||||
def _same_padding(self, x): return x if self.kernel_size == 1 else x.pad(convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1) // 2, self.kernel_size // 2]]))
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
|
||||
self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.window_size = hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, window_size
|
||||
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2 = [], [], [], []
|
||||
for _ in range(n_layers):
|
||||
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
|
||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask, x = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1), x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
y = self.attn_layers[i].forward(x, x, attn_mask).dropout(self.p_dropout)
|
||||
x = self.norm_layers_1[i].forward(x + y)
|
||||
y = self.ffn_layers[i].forward(x, x_mask).dropout(self.p_dropout)
|
||||
x = self.norm_layers_2[i].forward(x + y)
|
||||
return x * x_mask
|
||||
|
||||
DEFAULT_MIN_BIN_WIDTH, DEFAULT_MIN_BIN_HEIGHT, DEFAULT_MIN_DERIVATIVE = 1e-3, 1e-3, 1e-3
|
||||
def piecewise_rational_quadratic_transform(inputs, un_normalized_widths, un_normalized_heights, un_normalized_derivatives, inverse=False, tails=None, tail_bound=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
if tails is None: spline_fn, spline_kwargs = rational_quadratic_spline, {}
|
||||
else: spline_fn, spline_kwargs = unconstrained_rational_quadratic_spline, {'tails': tails, 'tail_bound': tail_bound}
|
||||
return spline_fn(inputs=inputs, un_normalized_widths=un_normalized_widths, un_normalized_heights=un_normalized_heights, un_normalized_derivatives=un_normalized_derivatives, inverse=inverse, min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, **spline_kwargs)
|
||||
def unconstrained_rational_quadratic_spline(inputs, un_normalized_widths, un_normalized_heights, un_normalized_derivatives, inverse=False, tails='linear', tail_bound=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
if not tails == 'linear': raise RuntimeError('{} tails are not implemented.'.format(tails))
|
||||
constant = np.log(np.exp(1 - min_derivative) - 1).item()
|
||||
un_normalized_derivatives = cat_lr(un_normalized_derivatives, constant, constant)
|
||||
output, log_abs_det = rational_quadratic_spline(inputs=inputs.squeeze(dim=0).squeeze(dim=0), unnormalized_widths=un_normalized_widths.squeeze(dim=0).squeeze(dim=0), unnormalized_heights=un_normalized_heights.squeeze(dim=0).squeeze(dim=0), unnormalized_derivatives=un_normalized_derivatives.squeeze(dim=0).squeeze(dim=0), inverse=inverse, left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative)
|
||||
return output.unsqueeze(dim=0).unsqueeze(dim=0), log_abs_det.unsqueeze(dim=0).unsqueeze(dim=0)
|
||||
def rational_quadratic_spline(inputs: Tensor, unnormalized_widths: Tensor, unnormalized_heights: Tensor, unnormalized_derivatives: Tensor, inverse=False, left=0., right=1., bottom=0., top=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
num_bins = unnormalized_widths.shape[-1]
|
||||
if min_bin_width * num_bins > 1.0: raise ValueError('Minimal bin width too large for the number of bins')
|
||||
if min_bin_height * num_bins > 1.0: raise ValueError('Minimal bin height too large for the number of bins')
|
||||
widths = min_bin_width + (1 - min_bin_width * num_bins) * unnormalized_widths.softmax(axis=-1)
|
||||
cum_widths = cat_lr(((right - left) * widths[..., :-1].cumsum(axis=1) + left), left, right + 1e-6 if not inverse else right)
|
||||
widths = cum_widths[..., 1:] - cum_widths[..., :-1]
|
||||
derivatives = min_derivative + (unnormalized_derivatives.exp()+1).log()
|
||||
heights = min_bin_height + (1 - min_bin_height * num_bins) * unnormalized_heights.softmax(axis=-1)
|
||||
cum_heights = cat_lr(((top - bottom) * heights[..., :-1].cumsum(axis=1) + bottom), bottom, top + 1e-6 if inverse else top)
|
||||
heights = cum_heights[..., 1:] - cum_heights[..., :-1]
|
||||
bin_idx = ((inputs[..., None] >= (cum_heights if inverse else cum_widths)).sum(axis=-1) - 1)[..., None]
|
||||
input_cum_widths = gather(cum_widths, bin_idx, axis=-1)[..., 0]
|
||||
input_bin_widths = gather(widths, bin_idx, axis=-1)[..., 0]
|
||||
input_cum_heights = gather(cum_heights, bin_idx, axis=-1)[..., 0]
|
||||
input_delta = gather(heights / widths, bin_idx, axis=-1)[..., 0]
|
||||
input_derivatives = gather(derivatives, bin_idx, axis=-1)[..., 0]
|
||||
input_derivatives_plus_one = gather(derivatives[..., 1:], bin_idx, axis=-1)[..., 0]
|
||||
input_heights = gather(heights, bin_idx, axis=-1)[..., 0]
|
||||
if inverse:
|
||||
a = ((inputs - input_cum_heights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + input_heights * (input_delta - input_derivatives))
|
||||
b = (input_heights * input_derivatives - (inputs - input_cum_heights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta))
|
||||
c = - input_delta * (inputs - input_cum_heights)
|
||||
discriminant = b.square() - 4 * a * c
|
||||
# assert (discriminant.numpy() >= 0).all()
|
||||
root = (2 * c) / (-b - discriminant.sqrt())
|
||||
theta_one_minus_theta = root * (1 - root)
|
||||
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta)
|
||||
derivative_numerator = input_delta.square() * (input_derivatives_plus_one * root.square() + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - root).square())
|
||||
return root * input_bin_widths + input_cum_widths, -(derivative_numerator.log() - 2 * denominator.log())
|
||||
theta = (inputs - input_cum_widths) / input_bin_widths
|
||||
theta_one_minus_theta = theta * (1 - theta)
|
||||
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
|
||||
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta)
|
||||
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
|
||||
return input_cum_heights + numerator / denominator, derivative_numerator.log() - 2 * denominator.log()
|
||||
|
||||
def sequence_mask(length: Tensor, max_length): return Tensor.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)
|
||||
def generate_path(duration: Tensor, mask: Tensor): # duration: [b, 1, t_x], mask: [b, 1, t_y, t_x]
|
||||
b, _, t_y, t_x = mask.shape
|
||||
path = sequence_mask(duration.cumsum(axis=2).reshape(b * t_x), t_y).cast(mask.dtype).reshape(b, t_x, t_y)
|
||||
path = path - path.pad(convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||
return path.unsqueeze(1).transpose(2, 3) * mask
|
||||
def fused_add_tanh_sigmoid_multiply(input_a: Tensor, input_b: Tensor, n_channels: int):
|
||||
n_channels_int, in_act = n_channels, input_a + input_b
|
||||
t_act, s_act = in_act[:, :n_channels_int, :].tanh(), in_act[:, n_channels_int:, :].sigmoid()
|
||||
return t_act * s_act
|
||||
|
||||
def cat_lr(t, left, right): return Tensor.full(get_shape(t), left).cat(t, dim=-1).cat(Tensor.full(get_shape(t), right), dim=-1)
|
||||
def get_shape(tensor):
|
||||
(shape := list(tensor.shape))[-1] = 1
|
||||
return tuple(shape)
|
||||
def convert_pad_shape(pad_shape): return tuple(tuple(x) for x in pad_shape)
|
||||
def get_padding(kernel_size, dilation=1): return int((kernel_size*dilation - dilation)/2)
|
||||
|
||||
def gather(x, indices, axis):
|
||||
indices = (indices < 0).where(indices + x.shape[axis], indices).transpose(0, axis)
|
||||
permute_args = list(range(x.ndim))
|
||||
permute_args[0], permute_args[axis] = permute_args[axis], permute_args[0]
|
||||
permute_args.append(permute_args.pop(0))
|
||||
x = x.permute(*permute_args)
|
||||
reshape_arg = [1] * x.ndim + [x.shape[-1]]
|
||||
return ((indices.unsqueeze(indices.ndim).expand(*indices.shape, x.shape[-1]) ==
|
||||
Tensor.arange(x.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, x.shape[-1])) * x).sum(indices.ndim).transpose(0, axis)
|
||||
|
||||
def norm_except_dim(v, dim):
|
||||
if dim == -1: return np.linalg.norm(v)
|
||||
if dim == 0:
|
||||
(output_shape := [1] * v.ndim)[0] = v.shape[0]
|
||||
return np.linalg.norm(v.reshape(v.shape[0], -1), axis=1).reshape(output_shape)
|
||||
if dim == v.ndim - 1:
|
||||
(output_shape := [1] * v.ndim)[-1] = v.shape[-1]
|
||||
return np.linalg.norm(v.reshape(-1, v.shape[-1]), axis=0).reshape(output_shape)
|
||||
transposed_v = np.transpose(v, (dim,) + tuple(i for i in range(v.ndim) if i != dim))
|
||||
return np.transpose(norm_except_dim(transposed_v, 0), (dim,) + tuple(i for i in range(v.ndim) if i != dim))
|
||||
def weight_norm(v: Tensor, g: Tensor, dim):
|
||||
v, g = v.numpy(), g.numpy()
|
||||
return Tensor(v * (g / norm_except_dim(v, dim)))
|
||||
|
||||
# HPARAMS LOADING
|
||||
def get_hparams_from_file(path):
|
||||
with open(path, "r") as f:
|
||||
data = f.read()
|
||||
return HParams(**json.loads(data))
|
||||
class HParams:
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items(): self[k] = v if type(v) != dict else HParams(**v)
|
||||
def keys(self): return self.__dict__.keys()
|
||||
def items(self): return self.__dict__.items()
|
||||
def values(self): return self.__dict__.values()
|
||||
def __len__(self): return len(self.__dict__)
|
||||
def __getitem__(self, key): return getattr(self, key)
|
||||
def __setitem__(self, key, value): return setattr(self, key, value)
|
||||
def __contains__(self, key): return key in self.__dict__
|
||||
def __repr__(self): return self.__dict__.__repr__()
|
||||
|
||||
# MODEL LOADING
|
||||
def load_model(symbols, hps, model) -> Synthesizer:
|
||||
net_g = Synthesizer(len(symbols), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers = hps.data.n_speakers, **hps.model)
|
||||
_ = load_checkpoint(fetch(model[1]), net_g, None)
|
||||
return net_g
|
||||
def load_checkpoint(checkpoint_path, model: Synthesizer, optimizer=None, skip_list=[]):
|
||||
assert Path(checkpoint_path).is_file()
|
||||
start_time = time.time()
|
||||
checkpoint_dict = torch_load(checkpoint_path)
|
||||
iteration, learning_rate = checkpoint_dict['iteration'], checkpoint_dict['learning_rate']
|
||||
if optimizer: optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||||
saved_state_dict = checkpoint_dict['model']
|
||||
weight_g, weight_v, parent = None, None, None
|
||||
for key, v in saved_state_dict.items():
|
||||
if any(layer in key for layer in skip_list): continue
|
||||
try:
|
||||
obj, skip = model, False
|
||||
for k in key.split('.'):
|
||||
if k.isnumeric(): obj = obj[int(k)]
|
||||
elif isinstance(obj, dict): obj = obj[k]
|
||||
else:
|
||||
if isinstance(obj, (LayerNorm, nn.LayerNorm)) and k in ["gamma", "beta"]:
|
||||
k = "weight" if k == "gamma" else "bias"
|
||||
elif k in ["weight_g", "weight_v"]:
|
||||
parent, skip = obj, True
|
||||
if k == "weight_g": weight_g = v
|
||||
else: weight_v = v
|
||||
if not skip: obj = getattr(obj, k)
|
||||
if weight_g is not None and weight_v is not None:
|
||||
setattr(obj, "weight_g", weight_g.numpy())
|
||||
setattr(obj, "weight_v", weight_v.numpy())
|
||||
obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0)
|
||||
weight_g, weight_v, parent, skip = None, None, None, False
|
||||
if not skip and obj.shape == v.shape: obj.assign(v.to(obj.device))
|
||||
elif not skip: logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}")
|
||||
except Exception as e: raise e
|
||||
logging.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration}) in {time.time() - start_time:.4f}s")
|
||||
return model, optimizer, learning_rate, iteration
|
||||
|
||||
# Used for cleaning input text and mapping to symbols
|
||||
class TextMapper: # Based on https://github.com/keithito/tacotron
|
||||
def __init__(self, symbols, apply_cleaners=True):
|
||||
self.apply_cleaners, self.symbols, self._inflect = apply_cleaners, symbols, None
|
||||
self._symbol_to_id, _id_to_symbol = {s: i for i, s in enumerate(symbols)}, {i: s for i, s in enumerate(symbols)}
|
||||
self._whitespace_re, self._abbreviations = re.compile(r'\s+'), [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [('mrs', 'misess'), ('mr', 'mister'), ('dr', 'doctor'), ('st', 'saint'), ('co', 'company'), ('jr', 'junior'), ('maj', 'major'), ('gen', 'general'), ('drs', 'doctors'), ('rev', 'reverend'), ('lt', 'lieutenant'), ('hon', 'honorable'), ('sgt', 'sergeant'), ('capt', 'captain'), ('esq', 'esquire'), ('ltd', 'limited'), ('col', 'colonel'), ('ft', 'fort'), ]]
|
||||
self.phonemizer = EspeakBackend(
|
||||
language="en-us", punctuation_marks=Punctuation.default_marks(), preserve_punctuation=True, with_stress=True,
|
||||
)
|
||||
def text_to_sequence(self, text, cleaner_names):
|
||||
if self.apply_cleaners:
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(self, name)
|
||||
if not cleaner: raise ModuleNotFoundError('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
else: text = text.strip()
|
||||
return [self._symbol_to_id[symbol] for symbol in text]
|
||||
def get_text(self, text, add_blank=False, cleaners=('english_cleaners2',)):
|
||||
text_norm = self.text_to_sequence(text, cleaners)
|
||||
return Tensor(self.intersperse(text_norm, 0) if add_blank else text_norm, dtype=dtypes.int64)
|
||||
def intersperse(self, lst, item):
|
||||
(result := [item] * (len(lst) * 2 + 1))[1::2] = lst
|
||||
return result
|
||||
def phonemize(self, text, strip=True): return _phonemize(self.phonemizer, text, default_separator, strip, 1, False, False)
|
||||
def filter_oov(self, text): return "".join(list(filter(lambda x: x in self._symbol_to_id, text)))
|
||||
def base_english_cleaners(self, text): return self.collapse_whitespace(self.phonemize(self.expand_abbreviations(unidecode(text.lower()))))
|
||||
def english_cleaners2(self, text): return self.base_english_cleaners(text)
|
||||
def transliteration_cleaners(self, text): return self.collapse_whitespace(unidecode(text.lower()))
|
||||
def cjke_cleaners(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_ipa2(text).replace('ɑ', 'a').replace('ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u')))
|
||||
def cjke_cleaners2(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_ipa2(text)))
|
||||
def cjks_cleaners(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_lazy_ipa(text)))
|
||||
def english_to_ipa2(self, text):
|
||||
_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ ('r', 'ɹ'), ('ʤ', 'dʒ'), ('ʧ', 'tʃ')]]
|
||||
return reduce(lambda t, rx: re.sub(rx[0], rx[1], t), _ipa_to_ipa2, self.mark_dark_l(self.english_to_ipa(text))).replace('...', '…')
|
||||
def mark_dark_l(self, text): return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ' + x.group(1), text)
|
||||
def english_to_ipa(self, text):
|
||||
import eng_to_ipa as ipa
|
||||
return self.collapse_whitespace(ipa.convert(self.normalize_numbers(self.expand_abbreviations(unidecode(text).lower()))))
|
||||
def english_to_lazy_ipa(self, text):
|
||||
_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [('r', 'ɹ'), ('æ', 'e'), ('ɑ', 'a'), ('ɔ', 'o'), ('ð', 'z'), ('θ', 's'), ('ɛ', 'e'), ('ɪ', 'i'), ('ʊ', 'u'), ('ʒ', 'ʥ'), ('ʤ', 'ʥ'), ('ˈ', '↓')]]
|
||||
return reduce(lambda t, rx: re.sub(rx[0], rx[1], t), _lazy_ipa, self.english_to_ipa(text))
|
||||
def expand_abbreviations(self, text): return reduce(lambda t, abbr: re.sub(abbr[0], abbr[1], t), self._abbreviations, text)
|
||||
def collapse_whitespace(self, text): return re.sub(self._whitespace_re, ' ', text)
|
||||
def normalize_numbers(self, text):
|
||||
import inflect
|
||||
self._inflect = inflect.engine()
|
||||
text = re.sub(re.compile(r'([0-9][0-9\,]+[0-9])'), self._remove_commas, text)
|
||||
text = re.sub(re.compile(r'£([0-9\,]*[0-9]+)'), r'\1 pounds', text)
|
||||
text = re.sub(re.compile(r'\$([0-9\.\,]*[0-9]+)'), self._expand_dollars, text)
|
||||
text = re.sub(re.compile(r'([0-9]+\.[0-9]+)'), self._expand_decimal_point, text)
|
||||
text = re.sub(re.compile(r'[0-9]+(st|nd|rd|th)'), self._expand_ordinal, text)
|
||||
text = re.sub(re.compile(r'[0-9]+'), self._expand_number, text)
|
||||
return text
|
||||
def _remove_commas(self, m): return m.group(1).replace(',', '') # george won't like this
|
||||
def _expand_dollars(self, m):
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
if len(parts) > 2: return match + ' dollars' # Unexpected format
|
||||
dollars, cents = int(parts[0]) if parts[0] else 0, int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents: return '%s %s, %s %s' % (dollars, 'dollar' if dollars == 1 else 'dollars', cents, 'cent' if cents == 1 else 'cents')
|
||||
if dollars: return '%s %s' % (dollars, 'dollar' if dollars == 1 else 'dollars')
|
||||
if cents: return '%s %s' % (cents, 'cent' if cents == 1 else 'cents')
|
||||
return 'zero dollars'
|
||||
def _expand_decimal_point(self, m): return m.group(1).replace('.', ' point ')
|
||||
def _expand_ordinal(self, m): return self._inflect.number_to_words(m.group(0))
|
||||
def _expand_number(self, _inflect, m):
|
||||
num = int(m.group(0))
|
||||
if 1000 < num < 3000:
|
||||
if num == 2000: return 'two thousand'
|
||||
if 2000 < num < 2010: return 'two thousand ' + self._inflect.number_to_words(num % 100)
|
||||
if num % 100 == 0: return self._inflect.number_to_words(num // 100) + ' hundred'
|
||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
return self._inflect.number_to_words(num, andword='')
|
||||
|
||||
#########################################################################################
|
||||
# PAPER: https://arxiv.org/abs/2106.06103
|
||||
# CODE: https://github.com/jaywalnut310/vits/tree/main
|
||||
#########################################################################################
|
||||
# INSTALLATION: this is based on default config, dependencies are for preprocessing.
|
||||
# vctk, ljs | pip3 install unidecode phonemizer | phonemizer requires [eSpeak](https://espeak.sourceforge.net) backend to be installed on your system
|
||||
# mmts-tts | pip3 install unidecode |
|
||||
# uma_trilingual, cjks, voistock | pip3 install unidecode inflect eng_to_ipa |
|
||||
#########################################################################################
|
||||
# Some good speakers to try out, there may be much better ones, I only tried out a few:
|
||||
# male vctk 1 | --model_to_use vctk --speaker_id 2
|
||||
# male vctk 2 | --model_to_use vctk --speaker_id 6
|
||||
# anime lady 1 | --model_to_use uma_trilingual --speaker_id 36
|
||||
# anime lady 2 | --model_to_use uma_trilingual --speaker_id 121
|
||||
#########################################################################################
|
||||
VITS_PATH = Path(__file__).parents[1] / "weights/VITS/"
|
||||
MODELS = { # config_url, weights_url
|
||||
"ljs": ("https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/ljs_base.json", "https://drive.google.com/uc?export=download&id=1q86w74Ygw2hNzYP9cWkeClGT5X25PvBT&confirm=t"),
|
||||
"vctk": ("https://huggingface.co/csukuangfj/vits-vctk/resolve/main/vctk_base.json", "https://huggingface.co/csukuangfj/vits-vctk/resolve/main/pretrained_vctk.pth"),
|
||||
"mmts-tts": ("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/config.json", "https://huggingface.co/facebook/mms-tts/resolve/main/full_models/eng/G_100000.pth"),
|
||||
"uma_trilingual": ("https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/raw/main/configs/uma_trilingual.json", "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth"),
|
||||
"cjks": ("https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/model.pth"),
|
||||
"voistock": ("https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/model.pth"),
|
||||
}
|
||||
Y_LENGTH_ESTIMATE_SCALARS = {"ljs": 2.8, "vctk": 1.74, "mmts-tts": 1.9, "uma_trilingual": 2.3, "cjks": 3.3, "voistock": 3.1}
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
|
||||
parser.add_argument("--speaker_id", type=int, default=6, help="Specify the speaker ID. Default is 6.")
|
||||
parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --name parameter.")
|
||||
parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.")
|
||||
parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.")
|
||||
parser.add_argument("--text_to_synthesize", default="""Hello person. If the code you are contributing isn't some of the highest quality code you've written in your life, either put in the effort to make it great, or don't bother.""", help="Specify the text to synthesize. Default is a greeting message.")
|
||||
parser.add_argument("--noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
|
||||
parser.add_argument("--noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
|
||||
parser.add_argument("--length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
|
||||
parser.add_argument("--seed", type=int, default=1337, help="Specify the seed (set to None if no seed). Default is 1337.")
|
||||
parser.add_argument("--num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
|
||||
parser.add_argument("--sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
|
||||
parser.add_argument("--emotion_path", type=str, default=None, help="Specify the path to emotion reference.")
|
||||
parser.add_argument("--estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_config = MODELS[args.model_to_use]
|
||||
|
||||
# Load the hyperparameters from the config file.
|
||||
hps = get_hparams_from_file(fetch(model_config[0]))
|
||||
|
||||
# If model has multiple speakers, validate speaker id and retrieve name if available.
|
||||
model_has_multiple_speakers = hps.data.n_speakers > 0
|
||||
if model_has_multiple_speakers:
|
||||
logging.info(f"Model has {hps.data.n_speakers} speakers")
|
||||
if args.speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {args.speaker_id} is invalid for this model.")
|
||||
speaker_name = "?"
|
||||
if hps.__contains__("speakers"): # maps speaker ids to names
|
||||
speakers = hps.speakers
|
||||
if isinstance(speakers, List): speakers = {speaker: i for i, speaker in enumerate(speakers)}
|
||||
speaker_name = next((key for key, value in speakers.items() if value == args.speaker_id), None)
|
||||
logging.info(f"You selected speaker {args.speaker_id} (name: {speaker_name})")
|
||||
|
||||
# Load emotions if any. TODO: find an english model with emotions, this is untested atm.
|
||||
emotion_embedding = None
|
||||
if args.emotion_path is not None:
|
||||
if args.emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(args.emotion_path), dtype=dtypes.int64).unsqueeze(0)
|
||||
else: raise ValueError("Emotion path must be a .npy file.")
|
||||
|
||||
# Load symbols, instantiate TextMapper and clean the text.
|
||||
if hps.__contains__("symbols"): symbols = hps.symbols
|
||||
elif args.model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
|
||||
else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ")
|
||||
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
|
||||
|
||||
# Load the model.
|
||||
if args.seed is not None:
|
||||
Tensor.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
net_g = load_model(text_mapper.symbols, hps, model_config)
|
||||
logging.debug(f"Loaded model with hps: {hps}")
|
||||
|
||||
# Convert the input text to a tensor.
|
||||
text_to_synthesize = args.text_to_synthesize
|
||||
if args.model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
||||
stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
|
||||
logging.debug(f"Converted input text to tensor \"{text_to_synthesize}\" -> Tensor({stn_tst.shape}): {stn_tst.numpy()}")
|
||||
x_tst, x_tst_lengths = stn_tst.unsqueeze(0), Tensor([stn_tst.shape[0]], dtype=dtypes.int64)
|
||||
sid = Tensor([args.speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
|
||||
|
||||
# Perform inference.
|
||||
start_time = time.time()
|
||||
audio_tensor = net_g.infer(x_tst, x_tst_lengths, sid, args.noise_scale, args.length_scale, args.noise_scale_w, emotion_embedding=emotion_embedding,
|
||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[args.model_to_use] if args.estimate_max_y_length else None)[0, 0].realize()
|
||||
logging.info(f"Inference took {(time.time() - start_time):.2f}s")
|
||||
|
||||
# Save the audio output.
|
||||
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
|
||||
out_path = Path(args.out_path or Path(args.out_dir)/f"{args.model_to_use}{f'_sid_{args.speaker_id}' if model_has_multiple_speakers else ''}_{args.base_name}.wav")
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with wave.open(str(out_path), 'wb') as wav_file:
|
||||
wav_file.setnchannels(args.num_channels)
|
||||
wav_file.setsampwidth(args.sample_width)
|
||||
wav_file.setframerate(hps.data.sampling_rate)
|
||||
wav_file.setnframes(len(audio_data))
|
||||
wav_file.writeframes(audio_data.tobytes())
|
||||
logging.info(f"Saved audio output to {out_path}")
|
||||
@@ -1,199 +0,0 @@
|
||||
import json
|
||||
import pathlib
|
||||
import zipfile
|
||||
import numpy as np
|
||||
from tinygrad.helpers import fetch
|
||||
import pycocotools._mask as _mask
|
||||
from examples.mask_rcnn import Masker
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
iou = _mask.iou
|
||||
merge = _mask.merge
|
||||
frPyObjects = _mask.frPyObjects
|
||||
|
||||
BASEDIR = pathlib.Path(__file__).parent / "COCO"
|
||||
BASEDIR.mkdir(exist_ok=True)
|
||||
|
||||
def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}
|
||||
|
||||
|
||||
if not pathlib.Path(BASEDIR/'val2017').is_dir():
|
||||
fn = fetch('http://images.cocodataset.org/zips/val2017.zip')
|
||||
with zipfile.ZipFile(fn, 'r') as zip_ref:
|
||||
zip_ref.extractall(BASEDIR)
|
||||
fn.unlink()
|
||||
|
||||
|
||||
if not pathlib.Path(BASEDIR/'annotations').is_dir():
|
||||
fn = fetch('http://images.cocodataset.org/annotations/annotations_trainval2017.zip')
|
||||
with zipfile.ZipFile(fn, 'r') as zip_ref:
|
||||
zip_ref.extractall(BASEDIR)
|
||||
fn.unlink()
|
||||
|
||||
with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f:
|
||||
annotations_raw = json.loads(f.read())
|
||||
images = annotations_raw['images']
|
||||
categories = annotations_raw['categories']
|
||||
annotations = annotations_raw['annotations']
|
||||
file_name_to_id = create_dict('file_name', 'id', images)
|
||||
id_to_width = create_dict('id', 'width', images)
|
||||
id_to_height = create_dict('id', 'height', images)
|
||||
json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)}
|
||||
contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()}
|
||||
|
||||
|
||||
def encode(bimask):
|
||||
if len(bimask.shape) == 3:
|
||||
return _mask.encode(bimask)
|
||||
elif len(bimask.shape) == 2:
|
||||
h, w = bimask.shape
|
||||
return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]
|
||||
|
||||
def decode(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
return _mask.decode(rleObjs)
|
||||
else:
|
||||
return _mask.decode([rleObjs])[:,:,0]
|
||||
|
||||
def area(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
return _mask.area(rleObjs)
|
||||
else:
|
||||
return _mask.area([rleObjs])[0]
|
||||
|
||||
def toBbox(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
return _mask.toBbox(rleObjs)
|
||||
else:
|
||||
return _mask.toBbox([rleObjs])[0]
|
||||
|
||||
|
||||
def convert_prediction_to_coco_bbox(file_name, prediction):
|
||||
coco_results = []
|
||||
try:
|
||||
original_id = file_name_to_id[file_name]
|
||||
if len(prediction) == 0:
|
||||
return coco_results
|
||||
|
||||
image_width = id_to_width[original_id]
|
||||
image_height = id_to_height[original_id]
|
||||
prediction = prediction.resize((image_width, image_height))
|
||||
prediction = prediction.convert("xywh")
|
||||
|
||||
boxes = prediction.bbox.numpy().tolist()
|
||||
scores = prediction.get_field("scores").numpy().tolist()
|
||||
labels = prediction.get_field("labels").numpy().tolist()
|
||||
|
||||
mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
|
||||
|
||||
coco_results.extend(
|
||||
[
|
||||
{
|
||||
"image_id": original_id,
|
||||
"category_id": mapped_labels[k],
|
||||
"bbox": box,
|
||||
"score": scores[k],
|
||||
}
|
||||
for k, box in enumerate(boxes)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(file_name, e)
|
||||
return coco_results
|
||||
|
||||
masker = Masker(threshold=0.5, padding=1)
|
||||
|
||||
def convert_prediction_to_coco_mask(file_name, prediction):
|
||||
coco_results = []
|
||||
try:
|
||||
original_id = file_name_to_id[file_name]
|
||||
if len(prediction) == 0:
|
||||
return coco_results
|
||||
|
||||
image_width = id_to_width[original_id]
|
||||
image_height = id_to_height[original_id]
|
||||
prediction = prediction.resize((image_width, image_height))
|
||||
masks = prediction.get_field("mask")
|
||||
|
||||
scores = prediction.get_field("scores").numpy().tolist()
|
||||
labels = prediction.get_field("labels").numpy().tolist()
|
||||
|
||||
masks = masker([masks], [prediction])[0].numpy()
|
||||
|
||||
rles = [
|
||||
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
|
||||
for mask in masks
|
||||
]
|
||||
for rle in rles:
|
||||
rle["counts"] = rle["counts"].decode("utf-8")
|
||||
|
||||
mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
|
||||
|
||||
coco_results.extend(
|
||||
[
|
||||
{
|
||||
"image_id": original_id,
|
||||
"category_id": mapped_labels[k],
|
||||
"segmentation": rle,
|
||||
"score": scores[k],
|
||||
}
|
||||
for k, rle in enumerate(rles)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(file_name, e)
|
||||
return coco_results
|
||||
|
||||
|
||||
|
||||
def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
|
||||
path = pathlib.Path(json_result_file)
|
||||
if rm and path.exists(): path.unlink()
|
||||
with open(path, "a") as f:
|
||||
for s in coco_results:
|
||||
f.write(json.dumps(s))
|
||||
f.write('\n')
|
||||
|
||||
def remove_dup(l):
|
||||
seen = set()
|
||||
seen_add = seen.add
|
||||
return [x for x in l if not (x in seen or seen_add(x))]
|
||||
|
||||
class NpEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
if isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
return super(NpEncoder, self).default(obj)
|
||||
|
||||
|
||||
def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"):
|
||||
coco_results = []
|
||||
with open(json_result_file, "r") as f:
|
||||
for line in f:
|
||||
coco_results.append(json.loads(line))
|
||||
|
||||
coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json'))
|
||||
set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
|
||||
unique_list = [json.loads(s) for s in set_of_json]
|
||||
|
||||
with open(f'{json_result_file}.flattend', "w") as f:
|
||||
json.dump(unique_list, f)
|
||||
|
||||
coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend'))
|
||||
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
|
||||
coco_eval.evaluate()
|
||||
coco_eval.accumulate()
|
||||
coco_eval.summarize()
|
||||
return coco_eval
|
||||
|
||||
def iterate(files, bs=1):
|
||||
batch = []
|
||||
for file in files:
|
||||
batch.append(file)
|
||||
if len(batch) >= bs: yield batch; batch = []
|
||||
if len(batch) > 0: yield batch; batch = []
|
||||
1
extra/disassemblers/adreno/.gitignore
vendored
1
extra/disassemblers/adreno/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
disasm.so
|
||||
@@ -1,5 +0,0 @@
|
||||
From the Freedreno project
|
||||
|
||||
https://gallium.readthedocs.io/en/latest/gallium/drivers/freedreno.html
|
||||
|
||||
In Mesa3D, so licensed MIT.
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,906 +0,0 @@
|
||||
/*
|
||||
* Mesa 3-D graphics library
|
||||
*
|
||||
* Copyright (C) 1999-2008 Brian Paul All Rights Reserved.
|
||||
* Copyright (C) 2009 VMware, Inc. All Rights Reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included
|
||||
* in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
* OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
* OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef SHADER_ENUMS_H
|
||||
#define SHADER_ENUMS_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
/* Project-wide (GL and Vulkan) maximum. */
|
||||
#define MAX_DRAW_BUFFERS 8
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Shader stages.
|
||||
*
|
||||
* The order must match how shaders are ordered in the pipeline.
|
||||
* The GLSL linker assumes that if i<j, then the j-th shader is
|
||||
* executed later than the i-th shader.
|
||||
*/
|
||||
typedef enum
|
||||
{
|
||||
MESA_SHADER_NONE = -1,
|
||||
MESA_SHADER_VERTEX = 0,
|
||||
MESA_SHADER_TESS_CTRL = 1,
|
||||
MESA_SHADER_TESS_EVAL = 2,
|
||||
MESA_SHADER_GEOMETRY = 3,
|
||||
MESA_SHADER_FRAGMENT = 4,
|
||||
MESA_SHADER_COMPUTE = 5,
|
||||
/* must be last so it doesn't affect the GL pipeline */
|
||||
MESA_SHADER_KERNEL = 6,
|
||||
} gl_shader_stage;
|
||||
|
||||
static inline bool
|
||||
gl_shader_stage_is_compute(gl_shader_stage stage)
|
||||
{
|
||||
return stage == MESA_SHADER_COMPUTE || stage == MESA_SHADER_KERNEL;
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of STATE_* values we need to address any GL state.
|
||||
* Used to dimension arrays.
|
||||
*/
|
||||
#define STATE_LENGTH 5
|
||||
|
||||
typedef short gl_state_index16; /* see enum gl_state_index */
|
||||
|
||||
const char *gl_shader_stage_name(gl_shader_stage stage);
|
||||
|
||||
/**
|
||||
* Translate a gl_shader_stage to a short shader stage name for debug
|
||||
* printouts and error messages.
|
||||
*/
|
||||
const char *_mesa_shader_stage_to_string(unsigned stage);
|
||||
|
||||
/**
|
||||
* Translate a gl_shader_stage to a shader stage abbreviation (VS, GS, FS)
|
||||
* for debug printouts and error messages.
|
||||
*/
|
||||
const char *_mesa_shader_stage_to_abbrev(unsigned stage);
|
||||
|
||||
/**
|
||||
* GL related stages (not including CL)
|
||||
*/
|
||||
#define MESA_SHADER_STAGES (MESA_SHADER_COMPUTE + 1)
|
||||
|
||||
/**
|
||||
* All stages
|
||||
*/
|
||||
#define MESA_ALL_SHADER_STAGES (MESA_SHADER_KERNEL + 1)
|
||||
|
||||
|
||||
/**
|
||||
* Indexes for vertex program attributes.
|
||||
* GL_NV_vertex_program aliases generic attributes over the conventional
|
||||
* attributes. In GL_ARB_vertex_program shader the aliasing is optional.
|
||||
* In GL_ARB_vertex_shader / OpenGL 2.0 the aliasing is disallowed (the
|
||||
* generic attributes are distinct/separate).
|
||||
*/
|
||||
typedef enum
|
||||
{
|
||||
VERT_ATTRIB_POS,
|
||||
VERT_ATTRIB_NORMAL,
|
||||
VERT_ATTRIB_COLOR0,
|
||||
VERT_ATTRIB_COLOR1,
|
||||
VERT_ATTRIB_FOG,
|
||||
VERT_ATTRIB_COLOR_INDEX,
|
||||
VERT_ATTRIB_EDGEFLAG,
|
||||
VERT_ATTRIB_TEX0,
|
||||
VERT_ATTRIB_TEX1,
|
||||
VERT_ATTRIB_TEX2,
|
||||
VERT_ATTRIB_TEX3,
|
||||
VERT_ATTRIB_TEX4,
|
||||
VERT_ATTRIB_TEX5,
|
||||
VERT_ATTRIB_TEX6,
|
||||
VERT_ATTRIB_TEX7,
|
||||
VERT_ATTRIB_POINT_SIZE,
|
||||
VERT_ATTRIB_GENERIC0,
|
||||
VERT_ATTRIB_GENERIC1,
|
||||
VERT_ATTRIB_GENERIC2,
|
||||
VERT_ATTRIB_GENERIC3,
|
||||
VERT_ATTRIB_GENERIC4,
|
||||
VERT_ATTRIB_GENERIC5,
|
||||
VERT_ATTRIB_GENERIC6,
|
||||
VERT_ATTRIB_GENERIC7,
|
||||
VERT_ATTRIB_GENERIC8,
|
||||
VERT_ATTRIB_GENERIC9,
|
||||
VERT_ATTRIB_GENERIC10,
|
||||
VERT_ATTRIB_GENERIC11,
|
||||
VERT_ATTRIB_GENERIC12,
|
||||
VERT_ATTRIB_GENERIC13,
|
||||
VERT_ATTRIB_GENERIC14,
|
||||
VERT_ATTRIB_GENERIC15,
|
||||
VERT_ATTRIB_MAX
|
||||
} gl_vert_attrib;
|
||||
|
||||
const char *gl_vert_attrib_name(gl_vert_attrib attrib);
|
||||
|
||||
/**
|
||||
* Symbolic constats to help iterating over
|
||||
* specific blocks of vertex attributes.
|
||||
*
|
||||
* VERT_ATTRIB_FF
|
||||
* includes all fixed function attributes as well as
|
||||
* the aliased GL_NV_vertex_program shader attributes.
|
||||
* VERT_ATTRIB_TEX
|
||||
* include the classic texture coordinate attributes.
|
||||
* Is a subset of VERT_ATTRIB_FF.
|
||||
* VERT_ATTRIB_GENERIC
|
||||
* include the OpenGL 2.0+ GLSL generic shader attributes.
|
||||
* These alias the generic GL_ARB_vertex_shader attributes.
|
||||
* VERT_ATTRIB_MAT
|
||||
* include the generic shader attributes used to alias
|
||||
* varying material values for the TNL shader programs.
|
||||
* They are located at the end of the generic attribute
|
||||
* block not to overlap with the generic 0 attribute.
|
||||
*/
|
||||
#define VERT_ATTRIB_FF(i) (VERT_ATTRIB_POS + (i))
|
||||
#define VERT_ATTRIB_FF_MAX VERT_ATTRIB_GENERIC0
|
||||
|
||||
#define VERT_ATTRIB_TEX(i) (VERT_ATTRIB_TEX0 + (i))
|
||||
#define VERT_ATTRIB_TEX_MAX MAX_TEXTURE_COORD_UNITS
|
||||
|
||||
#define VERT_ATTRIB_GENERIC(i) (VERT_ATTRIB_GENERIC0 + (i))
|
||||
#define VERT_ATTRIB_GENERIC_MAX MAX_VERTEX_GENERIC_ATTRIBS
|
||||
|
||||
#define VERT_ATTRIB_MAT0 \
|
||||
(VERT_ATTRIB_GENERIC_MAX - VERT_ATTRIB_MAT_MAX)
|
||||
#define VERT_ATTRIB_MAT(i) \
|
||||
VERT_ATTRIB_GENERIC((i) + VERT_ATTRIB_MAT0)
|
||||
#define VERT_ATTRIB_MAT_MAX MAT_ATTRIB_MAX
|
||||
|
||||
/**
|
||||
* Bitflags for vertex attributes.
|
||||
* These are used in bitfields in many places.
|
||||
*/
|
||||
/*@{*/
|
||||
#define VERT_BIT_POS BITFIELD_BIT(VERT_ATTRIB_POS)
|
||||
#define VERT_BIT_NORMAL BITFIELD_BIT(VERT_ATTRIB_NORMAL)
|
||||
#define VERT_BIT_COLOR0 BITFIELD_BIT(VERT_ATTRIB_COLOR0)
|
||||
#define VERT_BIT_COLOR1 BITFIELD_BIT(VERT_ATTRIB_COLOR1)
|
||||
#define VERT_BIT_FOG BITFIELD_BIT(VERT_ATTRIB_FOG)
|
||||
#define VERT_BIT_COLOR_INDEX BITFIELD_BIT(VERT_ATTRIB_COLOR_INDEX)
|
||||
#define VERT_BIT_EDGEFLAG BITFIELD_BIT(VERT_ATTRIB_EDGEFLAG)
|
||||
#define VERT_BIT_TEX0 BITFIELD_BIT(VERT_ATTRIB_TEX0)
|
||||
#define VERT_BIT_TEX1 BITFIELD_BIT(VERT_ATTRIB_TEX1)
|
||||
#define VERT_BIT_TEX2 BITFIELD_BIT(VERT_ATTRIB_TEX2)
|
||||
#define VERT_BIT_TEX3 BITFIELD_BIT(VERT_ATTRIB_TEX3)
|
||||
#define VERT_BIT_TEX4 BITFIELD_BIT(VERT_ATTRIB_TEX4)
|
||||
#define VERT_BIT_TEX5 BITFIELD_BIT(VERT_ATTRIB_TEX5)
|
||||
#define VERT_BIT_TEX6 BITFIELD_BIT(VERT_ATTRIB_TEX6)
|
||||
#define VERT_BIT_TEX7 BITFIELD_BIT(VERT_ATTRIB_TEX7)
|
||||
#define VERT_BIT_POINT_SIZE BITFIELD_BIT(VERT_ATTRIB_POINT_SIZE)
|
||||
#define VERT_BIT_GENERIC0 BITFIELD_BIT(VERT_ATTRIB_GENERIC0)
|
||||
|
||||
#define VERT_BIT(i) BITFIELD_BIT(i)
|
||||
#define VERT_BIT_ALL BITFIELD_RANGE(0, VERT_ATTRIB_MAX)
|
||||
|
||||
#define VERT_BIT_FF(i) VERT_BIT(i)
|
||||
#define VERT_BIT_FF_ALL BITFIELD_RANGE(0, VERT_ATTRIB_FF_MAX)
|
||||
#define VERT_BIT_TEX(i) VERT_BIT(VERT_ATTRIB_TEX(i))
|
||||
#define VERT_BIT_TEX_ALL \
|
||||
BITFIELD_RANGE(VERT_ATTRIB_TEX(0), VERT_ATTRIB_TEX_MAX)
|
||||
|
||||
#define VERT_BIT_GENERIC(i) VERT_BIT(VERT_ATTRIB_GENERIC(i))
|
||||
#define VERT_BIT_GENERIC_ALL \
|
||||
BITFIELD_RANGE(VERT_ATTRIB_GENERIC(0), VERT_ATTRIB_GENERIC_MAX)
|
||||
|
||||
#define VERT_BIT_MAT(i) VERT_BIT(VERT_ATTRIB_MAT(i))
|
||||
#define VERT_BIT_MAT_ALL \
|
||||
BITFIELD_RANGE(VERT_ATTRIB_MAT(0), VERT_ATTRIB_MAT_MAX)
|
||||
/*@}*/
|
||||
|
||||
#define MAX_VARYING 32 /**< number of float[4] vectors */
|
||||
|
||||
/**
|
||||
* Indexes for vertex shader outputs, geometry shader inputs/outputs, and
|
||||
* fragment shader inputs.
|
||||
*
|
||||
* Note that some of these values are not available to all pipeline stages.
|
||||
*
|
||||
* When this enum is updated, the following code must be updated too:
|
||||
* - vertResults (in prog_print.c's arb_output_attrib_string())
|
||||
* - fragAttribs (in prog_print.c's arb_input_attrib_string())
|
||||
* - _mesa_varying_slot_in_fs()
|
||||
*/
|
||||
typedef enum
|
||||
{
|
||||
VARYING_SLOT_POS,
|
||||
VARYING_SLOT_COL0, /* COL0 and COL1 must be contiguous */
|
||||
VARYING_SLOT_COL1,
|
||||
VARYING_SLOT_FOGC,
|
||||
VARYING_SLOT_TEX0, /* TEX0-TEX7 must be contiguous */
|
||||
VARYING_SLOT_TEX1,
|
||||
VARYING_SLOT_TEX2,
|
||||
VARYING_SLOT_TEX3,
|
||||
VARYING_SLOT_TEX4,
|
||||
VARYING_SLOT_TEX5,
|
||||
VARYING_SLOT_TEX6,
|
||||
VARYING_SLOT_TEX7,
|
||||
VARYING_SLOT_PSIZ, /* Does not appear in FS */
|
||||
VARYING_SLOT_BFC0, /* Does not appear in FS */
|
||||
VARYING_SLOT_BFC1, /* Does not appear in FS */
|
||||
VARYING_SLOT_EDGE, /* Does not appear in FS */
|
||||
VARYING_SLOT_CLIP_VERTEX, /* Does not appear in FS */
|
||||
VARYING_SLOT_CLIP_DIST0,
|
||||
VARYING_SLOT_CLIP_DIST1,
|
||||
VARYING_SLOT_CULL_DIST0,
|
||||
VARYING_SLOT_CULL_DIST1,
|
||||
VARYING_SLOT_PRIMITIVE_ID, /* Does not appear in VS */
|
||||
VARYING_SLOT_LAYER, /* Appears as VS or GS output */
|
||||
VARYING_SLOT_VIEWPORT, /* Appears as VS or GS output */
|
||||
VARYING_SLOT_FACE, /* FS only */
|
||||
VARYING_SLOT_PNTC, /* FS only */
|
||||
VARYING_SLOT_TESS_LEVEL_OUTER, /* Only appears as TCS output. */
|
||||
VARYING_SLOT_TESS_LEVEL_INNER, /* Only appears as TCS output. */
|
||||
VARYING_SLOT_BOUNDING_BOX0, /* Only appears as TCS output. */
|
||||
VARYING_SLOT_BOUNDING_BOX1, /* Only appears as TCS output. */
|
||||
VARYING_SLOT_VIEW_INDEX,
|
||||
VARYING_SLOT_VIEWPORT_MASK, /* Does not appear in FS */
|
||||
VARYING_SLOT_VAR0, /* First generic varying slot */
|
||||
/* the remaining are simply for the benefit of gl_varying_slot_name()
|
||||
* and not to be construed as an upper bound:
|
||||
*/
|
||||
VARYING_SLOT_VAR1,
|
||||
VARYING_SLOT_VAR2,
|
||||
VARYING_SLOT_VAR3,
|
||||
VARYING_SLOT_VAR4,
|
||||
VARYING_SLOT_VAR5,
|
||||
VARYING_SLOT_VAR6,
|
||||
VARYING_SLOT_VAR7,
|
||||
VARYING_SLOT_VAR8,
|
||||
VARYING_SLOT_VAR9,
|
||||
VARYING_SLOT_VAR10,
|
||||
VARYING_SLOT_VAR11,
|
||||
VARYING_SLOT_VAR12,
|
||||
VARYING_SLOT_VAR13,
|
||||
VARYING_SLOT_VAR14,
|
||||
VARYING_SLOT_VAR15,
|
||||
VARYING_SLOT_VAR16,
|
||||
VARYING_SLOT_VAR17,
|
||||
VARYING_SLOT_VAR18,
|
||||
VARYING_SLOT_VAR19,
|
||||
VARYING_SLOT_VAR20,
|
||||
VARYING_SLOT_VAR21,
|
||||
VARYING_SLOT_VAR22,
|
||||
VARYING_SLOT_VAR23,
|
||||
VARYING_SLOT_VAR24,
|
||||
VARYING_SLOT_VAR25,
|
||||
VARYING_SLOT_VAR26,
|
||||
VARYING_SLOT_VAR27,
|
||||
VARYING_SLOT_VAR28,
|
||||
VARYING_SLOT_VAR29,
|
||||
VARYING_SLOT_VAR30,
|
||||
VARYING_SLOT_VAR31,
|
||||
} gl_varying_slot;
|
||||
|
||||
|
||||
#define VARYING_SLOT_MAX (VARYING_SLOT_VAR0 + MAX_VARYING)
|
||||
#define VARYING_SLOT_PATCH0 (VARYING_SLOT_MAX)
|
||||
#define VARYING_SLOT_TESS_MAX (VARYING_SLOT_PATCH0 + MAX_VARYING)
|
||||
#define MAX_VARYINGS_INCL_PATCH (VARYING_SLOT_TESS_MAX - VARYING_SLOT_VAR0)
|
||||
|
||||
const char *gl_varying_slot_name(gl_varying_slot slot);
|
||||
|
||||
/**
|
||||
* Bitflags for varying slots.
|
||||
*/
|
||||
/*@{*/
|
||||
#define VARYING_BIT_POS BITFIELD64_BIT(VARYING_SLOT_POS)
|
||||
#define VARYING_BIT_COL0 BITFIELD64_BIT(VARYING_SLOT_COL0)
|
||||
#define VARYING_BIT_COL1 BITFIELD64_BIT(VARYING_SLOT_COL1)
|
||||
#define VARYING_BIT_FOGC BITFIELD64_BIT(VARYING_SLOT_FOGC)
|
||||
#define VARYING_BIT_TEX0 BITFIELD64_BIT(VARYING_SLOT_TEX0)
|
||||
#define VARYING_BIT_TEX1 BITFIELD64_BIT(VARYING_SLOT_TEX1)
|
||||
#define VARYING_BIT_TEX2 BITFIELD64_BIT(VARYING_SLOT_TEX2)
|
||||
#define VARYING_BIT_TEX3 BITFIELD64_BIT(VARYING_SLOT_TEX3)
|
||||
#define VARYING_BIT_TEX4 BITFIELD64_BIT(VARYING_SLOT_TEX4)
|
||||
#define VARYING_BIT_TEX5 BITFIELD64_BIT(VARYING_SLOT_TEX5)
|
||||
#define VARYING_BIT_TEX6 BITFIELD64_BIT(VARYING_SLOT_TEX6)
|
||||
#define VARYING_BIT_TEX7 BITFIELD64_BIT(VARYING_SLOT_TEX7)
|
||||
#define VARYING_BIT_TEX(U) BITFIELD64_BIT(VARYING_SLOT_TEX0 + (U))
|
||||
#define VARYING_BITS_TEX_ANY BITFIELD64_RANGE(VARYING_SLOT_TEX0, \
|
||||
MAX_TEXTURE_COORD_UNITS)
|
||||
#define VARYING_BIT_PSIZ BITFIELD64_BIT(VARYING_SLOT_PSIZ)
|
||||
#define VARYING_BIT_BFC0 BITFIELD64_BIT(VARYING_SLOT_BFC0)
|
||||
#define VARYING_BIT_BFC1 BITFIELD64_BIT(VARYING_SLOT_BFC1)
|
||||
#define VARYING_BITS_COLOR (VARYING_BIT_COL0 | \
|
||||
VARYING_BIT_COL1 | \
|
||||
VARYING_BIT_BFC0 | \
|
||||
VARYING_BIT_BFC1)
|
||||
#define VARYING_BIT_EDGE BITFIELD64_BIT(VARYING_SLOT_EDGE)
|
||||
#define VARYING_BIT_CLIP_VERTEX BITFIELD64_BIT(VARYING_SLOT_CLIP_VERTEX)
|
||||
#define VARYING_BIT_CLIP_DIST0 BITFIELD64_BIT(VARYING_SLOT_CLIP_DIST0)
|
||||
#define VARYING_BIT_CLIP_DIST1 BITFIELD64_BIT(VARYING_SLOT_CLIP_DIST1)
|
||||
#define VARYING_BIT_CULL_DIST0 BITFIELD64_BIT(VARYING_SLOT_CULL_DIST0)
|
||||
#define VARYING_BIT_CULL_DIST1 BITFIELD64_BIT(VARYING_SLOT_CULL_DIST1)
|
||||
#define VARYING_BIT_PRIMITIVE_ID BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_ID)
|
||||
#define VARYING_BIT_LAYER BITFIELD64_BIT(VARYING_SLOT_LAYER)
|
||||
#define VARYING_BIT_VIEWPORT BITFIELD64_BIT(VARYING_SLOT_VIEWPORT)
|
||||
#define VARYING_BIT_FACE BITFIELD64_BIT(VARYING_SLOT_FACE)
|
||||
#define VARYING_BIT_PNTC BITFIELD64_BIT(VARYING_SLOT_PNTC)
|
||||
#define VARYING_BIT_TESS_LEVEL_OUTER BITFIELD64_BIT(VARYING_SLOT_TESS_LEVEL_OUTER)
|
||||
#define VARYING_BIT_TESS_LEVEL_INNER BITFIELD64_BIT(VARYING_SLOT_TESS_LEVEL_INNER)
|
||||
#define VARYING_BIT_BOUNDING_BOX0 BITFIELD64_BIT(VARYING_SLOT_BOUNDING_BOX0)
|
||||
#define VARYING_BIT_BOUNDING_BOX1 BITFIELD64_BIT(VARYING_SLOT_BOUNDING_BOX1)
|
||||
#define VARYING_BIT_VIEWPORT_MASK BITFIELD64_BIT(VARYING_SLOT_VIEWPORT_MASK)
|
||||
#define VARYING_BIT_VAR(V) BITFIELD64_BIT(VARYING_SLOT_VAR0 + (V))
|
||||
/*@}*/
|
||||
|
||||
/**
|
||||
* Bitflags for system values.
|
||||
*/
|
||||
#define SYSTEM_BIT_SAMPLE_ID ((uint64_t)1 << SYSTEM_VALUE_SAMPLE_ID)
|
||||
#define SYSTEM_BIT_SAMPLE_POS ((uint64_t)1 << SYSTEM_VALUE_SAMPLE_POS)
|
||||
#define SYSTEM_BIT_SAMPLE_MASK_IN ((uint64_t)1 << SYSTEM_VALUE_SAMPLE_MASK_IN)
|
||||
#define SYSTEM_BIT_LOCAL_INVOCATION_ID ((uint64_t)1 << SYSTEM_VALUE_LOCAL_INVOCATION_ID)
|
||||
|
||||
/**
|
||||
* If the gl_register_file is PROGRAM_SYSTEM_VALUE, the register index will be
|
||||
* one of these values. If a NIR variable's mode is nir_var_system_value, it
|
||||
* will be one of these values.
|
||||
*/
|
||||
typedef enum
|
||||
{
|
||||
/**
|
||||
* \name System values applicable to all shaders
|
||||
*/
|
||||
/*@{*/
|
||||
|
||||
/**
|
||||
* Builtin variables added by GL_ARB_shader_ballot.
|
||||
*/
|
||||
/*@{*/
|
||||
|
||||
/**
|
||||
* From the GL_ARB_shader-ballot spec:
|
||||
*
|
||||
* "A sub-group is a collection of invocations which execute in lockstep.
|
||||
* The variable <gl_SubGroupSizeARB> is the maximum number of
|
||||
* invocations in a sub-group. The maximum <gl_SubGroupSizeARB>
|
||||
* supported in this extension is 64."
|
||||
*
|
||||
* The spec defines this as a uniform. However, it's highly unlikely that
|
||||
* implementations actually treat it as a uniform (which is loaded from a
|
||||
* constant buffer). Most likely, this is an implementation-wide constant,
|
||||
* or perhaps something that depends on the shader stage.
|
||||
*/
|
||||
SYSTEM_VALUE_SUBGROUP_SIZE,
|
||||
|
||||
/**
|
||||
* From the GL_ARB_shader_ballot spec:
|
||||
*
|
||||
* "The variable <gl_SubGroupInvocationARB> holds the index of the
|
||||
* invocation within sub-group. This variable is in the range 0 to
|
||||
* <gl_SubGroupSizeARB>-1, where <gl_SubGroupSizeARB> is the total
|
||||
* number of invocations in a sub-group."
|
||||
*/
|
||||
SYSTEM_VALUE_SUBGROUP_INVOCATION,
|
||||
|
||||
/**
|
||||
* From the GL_ARB_shader_ballot spec:
|
||||
*
|
||||
* "The <gl_SubGroup??MaskARB> variables provide a bitmask for all
|
||||
* invocations, with one bit per invocation starting with the least
|
||||
* significant bit, according to the following table,
|
||||
*
|
||||
* variable equation for bit values
|
||||
* -------------------- ------------------------------------
|
||||
* gl_SubGroupEqMaskARB bit index == gl_SubGroupInvocationARB
|
||||
* gl_SubGroupGeMaskARB bit index >= gl_SubGroupInvocationARB
|
||||
* gl_SubGroupGtMaskARB bit index > gl_SubGroupInvocationARB
|
||||
* gl_SubGroupLeMaskARB bit index <= gl_SubGroupInvocationARB
|
||||
* gl_SubGroupLtMaskARB bit index < gl_SubGroupInvocationARB
|
||||
*/
|
||||
SYSTEM_VALUE_SUBGROUP_EQ_MASK,
|
||||
SYSTEM_VALUE_SUBGROUP_GE_MASK,
|
||||
SYSTEM_VALUE_SUBGROUP_GT_MASK,
|
||||
SYSTEM_VALUE_SUBGROUP_LE_MASK,
|
||||
SYSTEM_VALUE_SUBGROUP_LT_MASK,
|
||||
/*@}*/
|
||||
|
||||
/**
|
||||
* Builtin variables added by VK_KHR_subgroups
|
||||
*/
|
||||
/*@{*/
|
||||
SYSTEM_VALUE_NUM_SUBGROUPS,
|
||||
SYSTEM_VALUE_SUBGROUP_ID,
|
||||
/*@}*/
|
||||
|
||||
/*@}*/
|
||||
|
||||
/**
|
||||
* \name Vertex shader system values
|
||||
*/
|
||||
/*@{*/
|
||||
/**
|
||||
* OpenGL-style vertex ID.
|
||||
*
|
||||
* Section 2.11.7 (Shader Execution), subsection Shader Inputs, of the
|
||||
* OpenGL 3.3 core profile spec says:
|
||||
*
|
||||
* "gl_VertexID holds the integer index i implicitly passed by
|
||||
* DrawArrays or one of the other drawing commands defined in section
|
||||
* 2.8.3."
|
||||
*
|
||||
* Section 2.8.3 (Drawing Commands) of the same spec says:
|
||||
*
|
||||
* "The commands....are equivalent to the commands with the same base
|
||||
* name (without the BaseVertex suffix), except that the ith element
|
||||
* transferred by the corresponding draw call will be taken from
|
||||
* element indices[i] + basevertex of each enabled array."
|
||||
*
|
||||
* Additionally, the overview in the GL_ARB_shader_draw_parameters spec
|
||||
* says:
|
||||
*
|
||||
* "In unextended GL, vertex shaders have inputs named gl_VertexID and
|
||||
* gl_InstanceID, which contain, respectively the index of the vertex
|
||||
* and instance. The value of gl_VertexID is the implicitly passed
|
||||
* index of the vertex being processed, which includes the value of
|
||||
* baseVertex, for those commands that accept it."
|
||||
*
|
||||
* gl_VertexID gets basevertex added in. This differs from DirectX where
|
||||
* SV_VertexID does \b not get basevertex added in.
|
||||
*
|
||||
* \note
|
||||
* If all system values are available, \c SYSTEM_VALUE_VERTEX_ID will be
|
||||
* equal to \c SYSTEM_VALUE_VERTEX_ID_ZERO_BASE plus
|
||||
* \c SYSTEM_VALUE_BASE_VERTEX.
|
||||
*
|
||||
* \sa SYSTEM_VALUE_VERTEX_ID_ZERO_BASE, SYSTEM_VALUE_BASE_VERTEX
|
||||
*/
|
||||
SYSTEM_VALUE_VERTEX_ID,
|
||||
|
||||
/**
|
||||
* Instanced ID as supplied to gl_InstanceID
|
||||
*
|
||||
* Values assigned to gl_InstanceID always begin with zero, regardless of
|
||||
* the value of baseinstance.
|
||||
*
|
||||
* Section 11.1.3.9 (Shader Inputs) of the OpenGL 4.4 core profile spec
|
||||
* says:
|
||||
*
|
||||
* "gl_InstanceID holds the integer instance number of the current
|
||||
* primitive in an instanced draw call (see section 10.5)."
|
||||
*
|
||||
* Through a big chain of pseudocode, section 10.5 describes that
|
||||
* baseinstance is not counted by gl_InstanceID. In that section, notice
|
||||
*
|
||||
* "If an enabled vertex attribute array is instanced (it has a
|
||||
* non-zero divisor as specified by VertexAttribDivisor), the element
|
||||
* index that is transferred to the GL, for all vertices, is given by
|
||||
*
|
||||
* floor(instance/divisor) + baseinstance
|
||||
*
|
||||
* If an array corresponding to an attribute required by a vertex
|
||||
* shader is not enabled, then the corresponding element is taken from
|
||||
* the current attribute state (see section 10.2)."
|
||||
*
|
||||
* Note that baseinstance is \b not included in the value of instance.
|
||||
*/
|
||||
SYSTEM_VALUE_INSTANCE_ID,
|
||||
|
||||
/**
|
||||
* Vulkan InstanceIndex.
|
||||
*
|
||||
* InstanceIndex = gl_InstanceID + gl_BaseInstance
|
||||
*/
|
||||
SYSTEM_VALUE_INSTANCE_INDEX,
|
||||
|
||||
/**
|
||||
* DirectX-style vertex ID.
|
||||
*
|
||||
* Unlike \c SYSTEM_VALUE_VERTEX_ID, this system value does \b not include
|
||||
* the value of basevertex.
|
||||
*
|
||||
* \sa SYSTEM_VALUE_VERTEX_ID, SYSTEM_VALUE_BASE_VERTEX
|
||||
*/
|
||||
SYSTEM_VALUE_VERTEX_ID_ZERO_BASE,
|
||||
|
||||
/**
|
||||
* Value of \c basevertex passed to \c glDrawElementsBaseVertex and similar
|
||||
* functions.
|
||||
*
|
||||
* \sa SYSTEM_VALUE_VERTEX_ID, SYSTEM_VALUE_VERTEX_ID_ZERO_BASE
|
||||
*/
|
||||
SYSTEM_VALUE_BASE_VERTEX,
|
||||
|
||||
/**
|
||||
* Depending on the type of the draw call (indexed or non-indexed),
|
||||
* is the value of \c basevertex passed to \c glDrawElementsBaseVertex and
|
||||
* similar, or is the value of \c first passed to \c glDrawArrays and
|
||||
* similar.
|
||||
*
|
||||
* \note
|
||||
* It can be used to calculate the \c SYSTEM_VALUE_VERTEX_ID as
|
||||
* \c SYSTEM_VALUE_VERTEX_ID_ZERO_BASE plus \c SYSTEM_VALUE_FIRST_VERTEX.
|
||||
*
|
||||
* \sa SYSTEM_VALUE_VERTEX_ID_ZERO_BASE, SYSTEM_VALUE_VERTEX_ID
|
||||
*/
|
||||
SYSTEM_VALUE_FIRST_VERTEX,
|
||||
|
||||
/**
|
||||
* If the Draw command used to start the rendering was an indexed draw
|
||||
* or not (~0/0). Useful to calculate \c SYSTEM_VALUE_BASE_VERTEX as
|
||||
* \c SYSTEM_VALUE_IS_INDEXED_DRAW & \c SYSTEM_VALUE_FIRST_VERTEX.
|
||||
*/
|
||||
SYSTEM_VALUE_IS_INDEXED_DRAW,
|
||||
|
||||
/**
|
||||
* Value of \c baseinstance passed to instanced draw entry points
|
||||
*
|
||||
* \sa SYSTEM_VALUE_INSTANCE_ID
|
||||
*/
|
||||
SYSTEM_VALUE_BASE_INSTANCE,
|
||||
|
||||
/**
|
||||
* From _ARB_shader_draw_parameters:
|
||||
*
|
||||
* "Additionally, this extension adds a further built-in variable,
|
||||
* gl_DrawID to the shading language. This variable contains the index
|
||||
* of the draw currently being processed by a Multi* variant of a
|
||||
* drawing command (such as MultiDrawElements or
|
||||
* MultiDrawArraysIndirect)."
|
||||
*
|
||||
* If GL_ARB_multi_draw_indirect is not supported, this is always 0.
|
||||
*/
|
||||
SYSTEM_VALUE_DRAW_ID,
|
||||
/*@}*/
|
||||
|
||||
/**
|
||||
* \name Geometry shader system values
|
||||
*/
|
||||
/*@{*/
|
||||
SYSTEM_VALUE_INVOCATION_ID, /**< (Also in Tessellation Control shader) */
|
||||
/*@}*/
|
||||
|
||||
/**
|
||||
* \name Fragment shader system values
|
||||
*/
|
||||
/*@{*/
|
||||
SYSTEM_VALUE_FRAG_COORD,
|
||||
SYSTEM_VALUE_POINT_COORD,
|
||||
SYSTEM_VALUE_FRONT_FACE,
|
||||
SYSTEM_VALUE_SAMPLE_ID,
|
||||
SYSTEM_VALUE_SAMPLE_POS,
|
||||
SYSTEM_VALUE_SAMPLE_MASK_IN,
|
||||
SYSTEM_VALUE_HELPER_INVOCATION,
|
||||
SYSTEM_VALUE_COLOR0,
|
||||
SYSTEM_VALUE_COLOR1,
|
||||
/*@}*/
|
||||
|
||||
/**
|
||||
* \name Tessellation Evaluation shader system values
|
||||
*/
|
||||
/*@{*/
|
||||
SYSTEM_VALUE_TESS_COORD,
|
||||
SYSTEM_VALUE_VERTICES_IN, /**< Tessellation vertices in input patch */
|
||||
SYSTEM_VALUE_PRIMITIVE_ID,
|
||||
SYSTEM_VALUE_TESS_LEVEL_OUTER, /**< TES input */
|
||||
SYSTEM_VALUE_TESS_LEVEL_INNER, /**< TES input */
|
||||
SYSTEM_VALUE_TESS_LEVEL_OUTER_DEFAULT, /**< TCS input for passthru TCS */
|
||||
SYSTEM_VALUE_TESS_LEVEL_INNER_DEFAULT, /**< TCS input for passthru TCS */
|
||||
/*@}*/
|
||||
|
||||
/**
|
||||
* \name Compute shader system values
|
||||
*/
|
||||
/*@{*/
|
||||
SYSTEM_VALUE_LOCAL_INVOCATION_ID,
|
||||
SYSTEM_VALUE_LOCAL_INVOCATION_INDEX,
|
||||
SYSTEM_VALUE_GLOBAL_INVOCATION_ID,
|
||||
SYSTEM_VALUE_GLOBAL_INVOCATION_INDEX,
|
||||
SYSTEM_VALUE_WORK_GROUP_ID,
|
||||
SYSTEM_VALUE_NUM_WORK_GROUPS,
|
||||
SYSTEM_VALUE_LOCAL_GROUP_SIZE,
|
||||
SYSTEM_VALUE_GLOBAL_GROUP_SIZE,
|
||||
SYSTEM_VALUE_WORK_DIM,
|
||||
SYSTEM_VALUE_USER_DATA_AMD,
|
||||
/*@}*/
|
||||
|
||||
/** Required for VK_KHR_device_group */
|
||||
SYSTEM_VALUE_DEVICE_INDEX,
|
||||
|
||||
/** Required for VK_KHX_multiview */
|
||||
SYSTEM_VALUE_VIEW_INDEX,
|
||||
|
||||
/**
|
||||
* Driver internal vertex-count, used (for example) for drivers to
|
||||
* calculate stride for stream-out outputs. Not externally visible.
|
||||
*/
|
||||
SYSTEM_VALUE_VERTEX_CNT,
|
||||
|
||||
/**
|
||||
* Required for AMD_shader_explicit_vertex_parameter and also used for
|
||||
* varying-fetch instructions.
|
||||
*
|
||||
* The _SIZE value is "primitive size", used to scale i/j in primitive
|
||||
* space to pixel space.
|
||||
*/
|
||||
SYSTEM_VALUE_BARYCENTRIC_PERSP_PIXEL,
|
||||
SYSTEM_VALUE_BARYCENTRIC_PERSP_SAMPLE,
|
||||
SYSTEM_VALUE_BARYCENTRIC_PERSP_CENTROID,
|
||||
SYSTEM_VALUE_BARYCENTRIC_PERSP_SIZE,
|
||||
SYSTEM_VALUE_BARYCENTRIC_LINEAR_PIXEL,
|
||||
SYSTEM_VALUE_BARYCENTRIC_LINEAR_CENTROID,
|
||||
SYSTEM_VALUE_BARYCENTRIC_LINEAR_SAMPLE,
|
||||
SYSTEM_VALUE_BARYCENTRIC_PULL_MODEL,
|
||||
|
||||
/**
|
||||
* IR3 specific geometry shader and tesselation control shader system
|
||||
* values that packs invocation id, thread id and vertex id. Having this
|
||||
* as a nir level system value lets us do the unpacking in nir.
|
||||
*/
|
||||
SYSTEM_VALUE_GS_HEADER_IR3,
|
||||
SYSTEM_VALUE_TCS_HEADER_IR3,
|
||||
|
||||
SYSTEM_VALUE_MAX /**< Number of values */
|
||||
} gl_system_value;
|
||||
|
||||
const char *gl_system_value_name(gl_system_value sysval);
|
||||
|
||||
/**
|
||||
* The possible interpolation qualifiers that can be applied to a fragment
|
||||
* shader input in GLSL.
|
||||
*
|
||||
* Note: INTERP_MODE_NONE must be 0 so that memsetting the
|
||||
* ir_variable data structure to 0 causes the default behavior.
|
||||
*/
|
||||
enum glsl_interp_mode
|
||||
{
|
||||
INTERP_MODE_NONE = 0,
|
||||
INTERP_MODE_SMOOTH,
|
||||
INTERP_MODE_FLAT,
|
||||
INTERP_MODE_NOPERSPECTIVE,
|
||||
INTERP_MODE_EXPLICIT,
|
||||
INTERP_MODE_COUNT /**< Number of interpolation qualifiers */
|
||||
};
|
||||
|
||||
enum glsl_interface_packing {
|
||||
GLSL_INTERFACE_PACKING_STD140,
|
||||
GLSL_INTERFACE_PACKING_SHARED,
|
||||
GLSL_INTERFACE_PACKING_PACKED,
|
||||
GLSL_INTERFACE_PACKING_STD430
|
||||
};
|
||||
|
||||
const char *glsl_interp_mode_name(enum glsl_interp_mode qual);
|
||||
|
||||
/**
|
||||
* Fragment program results
|
||||
*/
|
||||
typedef enum
|
||||
{
|
||||
FRAG_RESULT_DEPTH = 0,
|
||||
FRAG_RESULT_STENCIL = 1,
|
||||
/* If a single color should be written to all render targets, this
|
||||
* register is written. No FRAG_RESULT_DATAn will be written.
|
||||
*/
|
||||
FRAG_RESULT_COLOR = 2,
|
||||
FRAG_RESULT_SAMPLE_MASK = 3,
|
||||
|
||||
/* FRAG_RESULT_DATAn are the per-render-target (GLSL gl_FragData[n]
|
||||
* or ARB_fragment_program fragment.color[n]) color results. If
|
||||
* any are written, FRAG_RESULT_COLOR will not be written.
|
||||
* FRAG_RESULT_DATA1 and up are simply for the benefit of
|
||||
* gl_frag_result_name() and not to be construed as an upper bound
|
||||
*/
|
||||
FRAG_RESULT_DATA0 = 4,
|
||||
FRAG_RESULT_DATA1,
|
||||
FRAG_RESULT_DATA2,
|
||||
FRAG_RESULT_DATA3,
|
||||
FRAG_RESULT_DATA4,
|
||||
FRAG_RESULT_DATA5,
|
||||
FRAG_RESULT_DATA6,
|
||||
FRAG_RESULT_DATA7,
|
||||
} gl_frag_result;
|
||||
|
||||
const char *gl_frag_result_name(gl_frag_result result);
|
||||
|
||||
#define FRAG_RESULT_MAX (FRAG_RESULT_DATA0 + MAX_DRAW_BUFFERS)
|
||||
|
||||
/**
|
||||
* \brief Layout qualifiers for gl_FragDepth.
|
||||
*
|
||||
* Extension AMD_conservative_depth allows gl_FragDepth to be redeclared with
|
||||
* a layout qualifier.
|
||||
*
|
||||
* \see enum ir_depth_layout
|
||||
*/
|
||||
enum gl_frag_depth_layout
|
||||
{
|
||||
FRAG_DEPTH_LAYOUT_NONE, /**< No layout is specified. */
|
||||
FRAG_DEPTH_LAYOUT_ANY,
|
||||
FRAG_DEPTH_LAYOUT_GREATER,
|
||||
FRAG_DEPTH_LAYOUT_LESS,
|
||||
FRAG_DEPTH_LAYOUT_UNCHANGED
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Buffer access qualifiers
|
||||
*/
|
||||
enum gl_access_qualifier
|
||||
{
|
||||
ACCESS_COHERENT = (1 << 0),
|
||||
ACCESS_RESTRICT = (1 << 1),
|
||||
ACCESS_VOLATILE = (1 << 2),
|
||||
ACCESS_NON_READABLE = (1 << 3),
|
||||
ACCESS_NON_WRITEABLE = (1 << 4),
|
||||
|
||||
/** The access may use a non-uniform buffer or image index */
|
||||
ACCESS_NON_UNIFORM = (1 << 5),
|
||||
|
||||
/* This has the same semantics as NIR_INTRINSIC_CAN_REORDER, only to be
|
||||
* used with loads. In other words, it means that the load can be
|
||||
* arbitrarily reordered, or combined with other loads to the same address.
|
||||
* It is implied by ACCESS_NON_WRITEABLE together with ACCESS_RESTRICT, and
|
||||
* a lack of ACCESS_COHERENT and ACCESS_VOLATILE.
|
||||
*/
|
||||
ACCESS_CAN_REORDER = (1 << 6),
|
||||
|
||||
/** Use as little cache space as possible. */
|
||||
ACCESS_STREAM_CACHE_POLICY = (1 << 7),
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Blend support qualifiers
|
||||
*/
|
||||
enum gl_advanced_blend_mode
|
||||
{
|
||||
BLEND_NONE = 0x0000,
|
||||
|
||||
BLEND_MULTIPLY = 0x0001,
|
||||
BLEND_SCREEN = 0x0002,
|
||||
BLEND_OVERLAY = 0x0004,
|
||||
BLEND_DARKEN = 0x0008,
|
||||
BLEND_LIGHTEN = 0x0010,
|
||||
BLEND_COLORDODGE = 0x0020,
|
||||
BLEND_COLORBURN = 0x0040,
|
||||
BLEND_HARDLIGHT = 0x0080,
|
||||
BLEND_SOFTLIGHT = 0x0100,
|
||||
BLEND_DIFFERENCE = 0x0200,
|
||||
BLEND_EXCLUSION = 0x0400,
|
||||
BLEND_HSL_HUE = 0x0800,
|
||||
BLEND_HSL_SATURATION = 0x1000,
|
||||
BLEND_HSL_COLOR = 0x2000,
|
||||
BLEND_HSL_LUMINOSITY = 0x4000,
|
||||
|
||||
BLEND_ALL = 0x7fff,
|
||||
};
|
||||
|
||||
enum blend_func
|
||||
{
|
||||
BLEND_FUNC_ADD,
|
||||
BLEND_FUNC_SUBTRACT,
|
||||
BLEND_FUNC_REVERSE_SUBTRACT,
|
||||
BLEND_FUNC_MIN,
|
||||
BLEND_FUNC_MAX,
|
||||
};
|
||||
|
||||
enum blend_factor
|
||||
{
|
||||
BLEND_FACTOR_ZERO,
|
||||
BLEND_FACTOR_SRC_COLOR,
|
||||
BLEND_FACTOR_DST_COLOR,
|
||||
BLEND_FACTOR_SRC_ALPHA,
|
||||
BLEND_FACTOR_DST_ALPHA,
|
||||
BLEND_FACTOR_CONSTANT_COLOR,
|
||||
BLEND_FACTOR_CONSTANT_ALPHA,
|
||||
BLEND_FACTOR_SRC_ALPHA_SATURATE,
|
||||
};
|
||||
|
||||
enum gl_tess_spacing
|
||||
{
|
||||
TESS_SPACING_UNSPECIFIED,
|
||||
TESS_SPACING_EQUAL,
|
||||
TESS_SPACING_FRACTIONAL_ODD,
|
||||
TESS_SPACING_FRACTIONAL_EVEN,
|
||||
};
|
||||
|
||||
/**
|
||||
* A compare function enum for use in compiler lowering passes. This is in
|
||||
* the same order as GL's compare functions (shifted down by GL_NEVER), and is
|
||||
* exactly the same as gallium's PIPE_FUNC_*.
|
||||
*/
|
||||
enum compare_func
|
||||
{
|
||||
COMPARE_FUNC_NEVER,
|
||||
COMPARE_FUNC_LESS,
|
||||
COMPARE_FUNC_EQUAL,
|
||||
COMPARE_FUNC_LEQUAL,
|
||||
COMPARE_FUNC_GREATER,
|
||||
COMPARE_FUNC_NOTEQUAL,
|
||||
COMPARE_FUNC_GEQUAL,
|
||||
COMPARE_FUNC_ALWAYS,
|
||||
};
|
||||
|
||||
/**
|
||||
* Arrangements for grouping invocations from NV_compute_shader_derivatives.
|
||||
*
|
||||
* The extension provides new layout qualifiers that support two different
|
||||
* arrangements of compute shader invocations for the purpose of derivative
|
||||
* computation. When specifying
|
||||
*
|
||||
* layout(derivative_group_quadsNV) in;
|
||||
*
|
||||
* compute shader invocations are grouped into 2x2x1 arrays whose four local
|
||||
* invocation ID values follow the pattern:
|
||||
*
|
||||
* +-----------------+------------------+
|
||||
* | (2x+0, 2y+0, z) | (2x+1, 2y+0, z) |
|
||||
* +-----------------+------------------+
|
||||
* | (2x+0, 2y+1, z) | (2x+1, 2y+1, z) |
|
||||
* +-----------------+------------------+
|
||||
*
|
||||
* where Y increases from bottom to top. When specifying
|
||||
*
|
||||
* layout(derivative_group_linearNV) in;
|
||||
*
|
||||
* compute shader invocations are grouped into 2x2x1 arrays whose four local
|
||||
* invocation index values follow the pattern:
|
||||
*
|
||||
* +------+------+
|
||||
* | 4n+0 | 4n+1 |
|
||||
* +------+------+
|
||||
* | 4n+2 | 4n+3 |
|
||||
* +------+------+
|
||||
*
|
||||
* If neither layout qualifier is specified, derivatives in compute shaders
|
||||
* return zero, which is consistent with the handling of built-in texture
|
||||
* functions like texture() in GLSL 4.50 compute shaders.
|
||||
*/
|
||||
enum gl_derivative_group {
|
||||
DERIVATIVE_GROUP_NONE = 0,
|
||||
DERIVATIVE_GROUP_QUADS,
|
||||
DERIVATIVE_GROUP_LINEAR,
|
||||
};
|
||||
|
||||
enum float_controls
|
||||
{
|
||||
FLOAT_CONTROLS_DEFAULT_FLOAT_CONTROL_MODE = 0x0000,
|
||||
FLOAT_CONTROLS_DENORM_PRESERVE_FP16 = 0x0001,
|
||||
FLOAT_CONTROLS_DENORM_PRESERVE_FP32 = 0x0002,
|
||||
FLOAT_CONTROLS_DENORM_PRESERVE_FP64 = 0x0004,
|
||||
FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP16 = 0x0008,
|
||||
FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP32 = 0x0010,
|
||||
FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP64 = 0x0020,
|
||||
FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 = 0x0040,
|
||||
FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32 = 0x0080,
|
||||
FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64 = 0x0100,
|
||||
FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16 = 0x0200,
|
||||
FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP32 = 0x0400,
|
||||
FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP64 = 0x0800,
|
||||
FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP16 = 0x1000,
|
||||
FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP32 = 0x2000,
|
||||
FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP64 = 0x4000,
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif
|
||||
|
||||
#endif /* SHADER_ENUMS_H */
|
||||
@@ -1,326 +0,0 @@
|
||||
/**************************************************************************
|
||||
*
|
||||
* Copyright 2008 VMware, Inc.
|
||||
* All Rights Reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the
|
||||
* "Software"), to deal in the Software without restriction, including
|
||||
* without limitation the rights to use, copy, modify, merge, publish,
|
||||
* distribute, sub license, and/or sell copies of the Software, and to
|
||||
* permit persons to whom the Software is furnished to do so, subject to
|
||||
* the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice (including the
|
||||
* next paragraph) shall be included in all copies or substantial portions
|
||||
* of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
|
||||
* IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
|
||||
* ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
**************************************************************************/
|
||||
|
||||
|
||||
#ifndef BITSCAN_H
|
||||
#define BITSCAN_H
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
#include <string.h>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#include <intrin.h>
|
||||
#endif
|
||||
|
||||
#if defined(__POPCNT__)
|
||||
#include <popcntintrin.h>
|
||||
#endif
|
||||
|
||||
//#include "c99_compat.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* Find first bit set in word. Least significant bit is 1.
|
||||
* Return 0 if no bits set.
|
||||
*/
|
||||
#ifdef HAVE___BUILTIN_FFS
|
||||
#define ffs __builtin_ffs
|
||||
#elif defined(_MSC_VER) && (_M_IX86 || _M_ARM || _M_AMD64 || _M_IA64)
|
||||
static inline
|
||||
int ffs(int i)
|
||||
{
|
||||
unsigned long index;
|
||||
if (_BitScanForward(&index, i))
|
||||
return index + 1;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
extern
|
||||
int ffs(int i);
|
||||
#endif
|
||||
|
||||
#ifdef HAVE___BUILTIN_FFSLL
|
||||
#define ffsll __builtin_ffsll
|
||||
#elif defined(_MSC_VER) && (_M_AMD64 || _M_ARM64 || _M_IA64)
|
||||
static inline int
|
||||
ffsll(long long int i)
|
||||
{
|
||||
unsigned long index;
|
||||
if (_BitScanForward64(&index, i))
|
||||
return index + 1;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
extern int
|
||||
ffsll(long long int val);
|
||||
#endif
|
||||
|
||||
|
||||
/* Destructively loop over all of the bits in a mask as in:
|
||||
*
|
||||
* while (mymask) {
|
||||
* int i = u_bit_scan(&mymask);
|
||||
* ... process element i
|
||||
* }
|
||||
*
|
||||
*/
|
||||
static inline int
|
||||
u_bit_scan(unsigned *mask)
|
||||
{
|
||||
const int i = ffs(*mask) - 1;
|
||||
*mask ^= (1u << i);
|
||||
return i;
|
||||
}
|
||||
|
||||
static inline int
|
||||
u_bit_scan64(uint64_t *mask)
|
||||
{
|
||||
const int i = ffsll(*mask) - 1;
|
||||
*mask ^= (((uint64_t)1) << i);
|
||||
return i;
|
||||
}
|
||||
|
||||
/* Determine if an unsigned value is a power of two.
|
||||
*
|
||||
* \note
|
||||
* Zero is treated as a power of two.
|
||||
*/
|
||||
static inline bool
|
||||
util_is_power_of_two_or_zero(unsigned v)
|
||||
{
|
||||
return (v & (v - 1)) == 0;
|
||||
}
|
||||
|
||||
/* Determine if an uint64_t value is a power of two.
|
||||
*
|
||||
* \note
|
||||
* Zero is treated as a power of two.
|
||||
*/
|
||||
static inline bool
|
||||
util_is_power_of_two_or_zero64(uint64_t v)
|
||||
{
|
||||
return (v & (v - 1)) == 0;
|
||||
}
|
||||
|
||||
/* Determine if an unsigned value is a power of two.
|
||||
*
|
||||
* \note
|
||||
* Zero is \b not treated as a power of two.
|
||||
*/
|
||||
static inline bool
|
||||
util_is_power_of_two_nonzero(unsigned v)
|
||||
{
|
||||
/* __POPCNT__ is different from HAVE___BUILTIN_POPCOUNT. The latter
|
||||
* indicates the existence of the __builtin_popcount function. The former
|
||||
* indicates that _mm_popcnt_u32 exists and is a native instruction.
|
||||
*
|
||||
* The other alternative is to use SSE 4.2 compile-time flags. This has
|
||||
* two drawbacks. First, there is currently no build infrastructure for
|
||||
* SSE 4.2 (only 4.1), so that would have to be added. Second, some AMD
|
||||
* CPUs support POPCNT but not SSE 4.2 (e.g., Barcelona).
|
||||
*/
|
||||
#ifdef __POPCNT__
|
||||
return _mm_popcnt_u32(v) == 1;
|
||||
#else
|
||||
return v != 0 && (v & (v - 1)) == 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
/* For looping over a bitmask when you want to loop over consecutive bits
|
||||
* manually, for example:
|
||||
*
|
||||
* while (mask) {
|
||||
* int start, count, i;
|
||||
*
|
||||
* u_bit_scan_consecutive_range(&mask, &start, &count);
|
||||
*
|
||||
* for (i = 0; i < count; i++)
|
||||
* ... process element (start+i)
|
||||
* }
|
||||
*/
|
||||
static inline void
|
||||
u_bit_scan_consecutive_range(unsigned *mask, int *start, int *count)
|
||||
{
|
||||
if (*mask == 0xffffffff) {
|
||||
*start = 0;
|
||||
*count = 32;
|
||||
*mask = 0;
|
||||
return;
|
||||
}
|
||||
*start = ffs(*mask) - 1;
|
||||
*count = ffs(~(*mask >> *start)) - 1;
|
||||
*mask &= ~(((1u << *count) - 1) << *start);
|
||||
}
|
||||
|
||||
static inline void
|
||||
u_bit_scan_consecutive_range64(uint64_t *mask, int *start, int *count)
|
||||
{
|
||||
if (*mask == ~0ull) {
|
||||
*start = 0;
|
||||
*count = 64;
|
||||
*mask = 0;
|
||||
return;
|
||||
}
|
||||
*start = ffsll(*mask) - 1;
|
||||
*count = ffsll(~(*mask >> *start)) - 1;
|
||||
*mask &= ~(((((uint64_t)1) << *count) - 1) << *start);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Find last bit set in a word. The least significant bit is 1.
|
||||
* Return 0 if no bits are set.
|
||||
* Essentially ffs() in the reverse direction.
|
||||
*/
|
||||
static inline unsigned
|
||||
util_last_bit(unsigned u)
|
||||
{
|
||||
#if defined(HAVE___BUILTIN_CLZ)
|
||||
return u == 0 ? 0 : 32 - __builtin_clz(u);
|
||||
#elif defined(_MSC_VER) && (_M_IX86 || _M_ARM || _M_AMD64 || _M_IA64)
|
||||
unsigned long index;
|
||||
if (_BitScanReverse(&index, u))
|
||||
return index + 1;
|
||||
else
|
||||
return 0;
|
||||
#else
|
||||
unsigned r = 0;
|
||||
while (u) {
|
||||
r++;
|
||||
u >>= 1;
|
||||
}
|
||||
return r;
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Find last bit set in a word. The least significant bit is 1.
|
||||
* Return 0 if no bits are set.
|
||||
* Essentially ffsll() in the reverse direction.
|
||||
*/
|
||||
static inline unsigned
|
||||
util_last_bit64(uint64_t u)
|
||||
{
|
||||
#if defined(HAVE___BUILTIN_CLZLL)
|
||||
return u == 0 ? 0 : 64 - __builtin_clzll(u);
|
||||
#elif defined(_MSC_VER) && (_M_AMD64 || _M_ARM64 || _M_IA64)
|
||||
unsigned long index;
|
||||
if (_BitScanReverse64(&index, u))
|
||||
return index + 1;
|
||||
else
|
||||
return 0;
|
||||
#else
|
||||
unsigned r = 0;
|
||||
while (u) {
|
||||
r++;
|
||||
u >>= 1;
|
||||
}
|
||||
return r;
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Find last bit in a word that does not match the sign bit. The least
|
||||
* significant bit is 1.
|
||||
* Return 0 if no bits are set.
|
||||
*/
|
||||
static inline unsigned
|
||||
util_last_bit_signed(int i)
|
||||
{
|
||||
if (i >= 0)
|
||||
return util_last_bit(i);
|
||||
else
|
||||
return util_last_bit(~(unsigned)i);
|
||||
}
|
||||
|
||||
/* Returns a bitfield in which the first count bits starting at start are
|
||||
* set.
|
||||
*/
|
||||
static inline unsigned
|
||||
u_bit_consecutive(unsigned start, unsigned count)
|
||||
{
|
||||
assert(start + count <= 32);
|
||||
if (count == 32)
|
||||
return ~0;
|
||||
return ((1u << count) - 1) << start;
|
||||
}
|
||||
|
||||
static inline uint64_t
|
||||
u_bit_consecutive64(unsigned start, unsigned count)
|
||||
{
|
||||
assert(start + count <= 64);
|
||||
if (count == 64)
|
||||
return ~(uint64_t)0;
|
||||
return (((uint64_t)1 << count) - 1) << start;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return number of bits set in n.
|
||||
*/
|
||||
static inline unsigned
|
||||
util_bitcount(unsigned n)
|
||||
{
|
||||
#if defined(HAVE___BUILTIN_POPCOUNT)
|
||||
return __builtin_popcount(n);
|
||||
#else
|
||||
/* K&R classic bitcount.
|
||||
*
|
||||
* For each iteration, clear the LSB from the bitfield.
|
||||
* Requires only one iteration per set bit, instead of
|
||||
* one iteration per bit less than highest set bit.
|
||||
*/
|
||||
unsigned bits;
|
||||
for (bits = 0; n; bits++) {
|
||||
n &= n - 1;
|
||||
}
|
||||
return bits;
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline unsigned
|
||||
util_bitcount64(uint64_t n)
|
||||
{
|
||||
#ifdef HAVE___BUILTIN_POPCOUNTLL
|
||||
return __builtin_popcountll(n);
|
||||
#else
|
||||
return util_bitcount(n) + util_bitcount(n >> 32);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* BITSCAN_H */
|
||||
@@ -1,261 +0,0 @@
|
||||
/*
|
||||
* Mesa 3-D graphics library
|
||||
*
|
||||
* Copyright (C) 2006 Brian Paul All Rights Reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included
|
||||
* in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
* OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
* OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
/**
|
||||
* \file bitset.h
|
||||
* \brief Bitset of arbitrary size definitions.
|
||||
* \author Michal Krol
|
||||
*/
|
||||
|
||||
#ifndef BITSET_H
|
||||
#define BITSET_H
|
||||
|
||||
//#include "util/bitscan.h"
|
||||
//#include "util/macros.h"
|
||||
|
||||
/****************************************************************************
|
||||
* generic bitset implementation
|
||||
*/
|
||||
|
||||
#define BITSET_WORD unsigned int
|
||||
#define BITSET_WORDBITS (sizeof (BITSET_WORD) * 8)
|
||||
|
||||
/* bitset declarations
|
||||
*/
|
||||
#define BITSET_WORDS(bits) (((bits) + BITSET_WORDBITS - 1) / BITSET_WORDBITS)
|
||||
#define BITSET_DECLARE(name, bits) BITSET_WORD name[BITSET_WORDS(bits)]
|
||||
|
||||
/* bitset operations
|
||||
*/
|
||||
#define BITSET_COPY(x, y) memcpy( (x), (y), sizeof (x) )
|
||||
#define BITSET_EQUAL(x, y) (memcmp( (x), (y), sizeof (x) ) == 0)
|
||||
#define BITSET_ZERO(x) memset( (x), 0, sizeof (x) )
|
||||
#define BITSET_ONES(x) memset( (x), 0xff, sizeof (x) )
|
||||
|
||||
#define BITSET_BITWORD(b) ((b) / BITSET_WORDBITS)
|
||||
#define BITSET_BIT(b) (1u << ((b) % BITSET_WORDBITS))
|
||||
|
||||
/* single bit operations
|
||||
*/
|
||||
#define BITSET_TEST(x, b) (((x)[BITSET_BITWORD(b)] & BITSET_BIT(b)) != 0)
|
||||
#define BITSET_SET(x, b) ((x)[BITSET_BITWORD(b)] |= BITSET_BIT(b))
|
||||
#define BITSET_CLEAR(x, b) ((x)[BITSET_BITWORD(b)] &= ~BITSET_BIT(b))
|
||||
|
||||
#define BITSET_MASK(b) (((b) % BITSET_WORDBITS == 0) ? ~0 : BITSET_BIT(b) - 1)
|
||||
#define BITSET_RANGE(b, e) ((BITSET_MASK((e) + 1)) & ~(BITSET_BIT(b) - 1))
|
||||
|
||||
/* bit range operations
|
||||
*/
|
||||
#define BITSET_TEST_RANGE(x, b, e) \
|
||||
(BITSET_BITWORD(b) == BITSET_BITWORD(e) ? \
|
||||
(((x)[BITSET_BITWORD(b)] & BITSET_RANGE(b, e)) != 0) : \
|
||||
(assert (!"BITSET_TEST_RANGE: bit range crosses word boundary"), 0))
|
||||
#define BITSET_SET_RANGE(x, b, e) \
|
||||
(BITSET_BITWORD(b) == BITSET_BITWORD(e) ? \
|
||||
((x)[BITSET_BITWORD(b)] |= BITSET_RANGE(b, e)) : \
|
||||
(assert (!"BITSET_SET_RANGE: bit range crosses word boundary"), 0))
|
||||
#define BITSET_CLEAR_RANGE(x, b, e) \
|
||||
(BITSET_BITWORD(b) == BITSET_BITWORD(e) ? \
|
||||
((x)[BITSET_BITWORD(b)] &= ~BITSET_RANGE(b, e)) : \
|
||||
(assert (!"BITSET_CLEAR_RANGE: bit range crosses word boundary"), 0))
|
||||
|
||||
/* Get first bit set in a bitset.
|
||||
*/
|
||||
static inline int
|
||||
__bitset_ffs(const BITSET_WORD *x, int n)
|
||||
{
|
||||
int i;
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
if (x[i])
|
||||
return ffs(x[i]) + BITSET_WORDBITS * i;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define BITSET_FFS(x) __bitset_ffs(x, ARRAY_SIZE(x))
|
||||
|
||||
static inline unsigned
|
||||
__bitset_next_set(unsigned i, BITSET_WORD *tmp,
|
||||
const BITSET_WORD *set, unsigned size)
|
||||
{
|
||||
unsigned bit, word;
|
||||
|
||||
/* NOTE: The initial conditions for this function are very specific. At
|
||||
* the start of the loop, the tmp variable must be set to *set and the
|
||||
* initial i value set to 0. This way, if there is a bit set in the first
|
||||
* word, we ignore the i-value and just grab that bit (so 0 is ok, even
|
||||
* though 0 may be returned). If the first word is 0, then the value of
|
||||
* `word` will be 0 and we will go on to look at the second word.
|
||||
*/
|
||||
word = BITSET_BITWORD(i);
|
||||
while (*tmp == 0) {
|
||||
word++;
|
||||
|
||||
if (word >= BITSET_WORDS(size))
|
||||
return size;
|
||||
|
||||
*tmp = set[word];
|
||||
}
|
||||
|
||||
/* Find the next set bit in the non-zero word */
|
||||
bit = ffs(*tmp) - 1;
|
||||
|
||||
/* Unset the bit */
|
||||
*tmp &= ~(1ull << bit);
|
||||
|
||||
return word * BITSET_WORDBITS + bit;
|
||||
}
|
||||
|
||||
/**
|
||||
* Iterates over each set bit in a set
|
||||
*
|
||||
* @param __i iteration variable, bit number
|
||||
* @param __set the bitset to iterate (will not be modified)
|
||||
* @param __size number of bits in the set to consider
|
||||
*/
|
||||
#define BITSET_FOREACH_SET(__i, __set, __size) \
|
||||
for (BITSET_WORD __tmp = *(__set), *__foo = &__tmp; __foo != NULL; __foo = NULL) \
|
||||
for (__i = 0; \
|
||||
(__i = __bitset_next_set(__i, &__tmp, __set, __size)) < __size;)
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
/**
|
||||
* Simple C++ wrapper of a bitset type of static size, with value semantics
|
||||
* and basic bitwise arithmetic operators. The operators defined below are
|
||||
* expected to have the same semantics as the same operator applied to other
|
||||
* fundamental integer types. T is the name of the struct to instantiate
|
||||
* it as, and N is the number of bits in the bitset.
|
||||
*/
|
||||
#define DECLARE_BITSET_T(T, N) struct T { \
|
||||
EXPLICIT_CONVERSION \
|
||||
operator bool() const \
|
||||
{ \
|
||||
for (unsigned i = 0; i < BITSET_WORDS(N); i++) \
|
||||
if (words[i]) \
|
||||
return true; \
|
||||
return false; \
|
||||
} \
|
||||
\
|
||||
T & \
|
||||
operator=(int x) \
|
||||
{ \
|
||||
const T c = {{ (BITSET_WORD)x }}; \
|
||||
return *this = c; \
|
||||
} \
|
||||
\
|
||||
friend bool \
|
||||
operator==(const T &b, const T &c) \
|
||||
{ \
|
||||
return BITSET_EQUAL(b.words, c.words); \
|
||||
} \
|
||||
\
|
||||
friend bool \
|
||||
operator!=(const T &b, const T &c) \
|
||||
{ \
|
||||
return !(b == c); \
|
||||
} \
|
||||
\
|
||||
friend bool \
|
||||
operator==(const T &b, int x) \
|
||||
{ \
|
||||
const T c = {{ (BITSET_WORD)x }}; \
|
||||
return b == c; \
|
||||
} \
|
||||
\
|
||||
friend bool \
|
||||
operator!=(const T &b, int x) \
|
||||
{ \
|
||||
return !(b == x); \
|
||||
} \
|
||||
\
|
||||
friend T \
|
||||
operator~(const T &b) \
|
||||
{ \
|
||||
T c; \
|
||||
for (unsigned i = 0; i < BITSET_WORDS(N); i++) \
|
||||
c.words[i] = ~b.words[i]; \
|
||||
return c; \
|
||||
} \
|
||||
\
|
||||
T & \
|
||||
operator|=(const T &b) \
|
||||
{ \
|
||||
for (unsigned i = 0; i < BITSET_WORDS(N); i++) \
|
||||
words[i] |= b.words[i]; \
|
||||
return *this; \
|
||||
} \
|
||||
\
|
||||
friend T \
|
||||
operator|(const T &b, const T &c) \
|
||||
{ \
|
||||
T d = b; \
|
||||
d |= c; \
|
||||
return d; \
|
||||
} \
|
||||
\
|
||||
T & \
|
||||
operator&=(const T &b) \
|
||||
{ \
|
||||
for (unsigned i = 0; i < BITSET_WORDS(N); i++) \
|
||||
words[i] &= b.words[i]; \
|
||||
return *this; \
|
||||
} \
|
||||
\
|
||||
friend T \
|
||||
operator&(const T &b, const T &c) \
|
||||
{ \
|
||||
T d = b; \
|
||||
d &= c; \
|
||||
return d; \
|
||||
} \
|
||||
\
|
||||
bool \
|
||||
test(unsigned i) const \
|
||||
{ \
|
||||
return BITSET_TEST(words, i); \
|
||||
} \
|
||||
\
|
||||
T & \
|
||||
set(unsigned i) \
|
||||
{ \
|
||||
BITSET_SET(words, i); \
|
||||
return *this; \
|
||||
} \
|
||||
\
|
||||
T & \
|
||||
clear(unsigned i) \
|
||||
{ \
|
||||
BITSET_CLEAR(words, i); \
|
||||
return *this; \
|
||||
} \
|
||||
\
|
||||
BITSET_WORD words[BITSET_WORDS(N)]; \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -1,262 +0,0 @@
|
||||
/**************************************************************************
|
||||
*
|
||||
* Copyright 2006 VMware, Inc., Bismarck, ND. USA.
|
||||
* All Rights Reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the
|
||||
* "Software"), to deal in the Software without restriction, including
|
||||
* without limitation the rights to use, copy, modify, merge, publish,
|
||||
* distribute, sub license, and/or sell copies of the Software, and to
|
||||
* permit persons to whom the Software is furnished to do so, subject to
|
||||
* the following conditions:
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
|
||||
* THE COPYRIGHT HOLDERS, AUTHORS AND/OR ITS SUPPLIERS BE LIABLE FOR ANY CLAIM,
|
||||
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
||||
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
|
||||
* USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
* The above copyright notice and this permission notice (including the
|
||||
* next paragraph) shall be included in all copies or substantial portions
|
||||
* of the Software.
|
||||
*
|
||||
**************************************************************************/
|
||||
|
||||
/**
|
||||
* \file
|
||||
* List macros heavily inspired by the Linux kernel
|
||||
* list handling. No list looping yet.
|
||||
*
|
||||
* Is not threadsafe, so common operations need to
|
||||
* be protected using an external mutex.
|
||||
*/
|
||||
|
||||
#ifndef _UTIL_LIST_H_
|
||||
#define _UTIL_LIST_H_
|
||||
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <assert.h>
|
||||
|
||||
#ifdef DEBUG
|
||||
# define list_assert(cond, msg) assert(cond && msg)
|
||||
#else
|
||||
# define list_assert(cond, msg) (void)(0 && (cond))
|
||||
#endif
|
||||
|
||||
struct list_head
|
||||
{
|
||||
struct list_head *prev;
|
||||
struct list_head *next;
|
||||
};
|
||||
|
||||
static inline void list_inithead(struct list_head *item)
|
||||
{
|
||||
item->prev = item;
|
||||
item->next = item;
|
||||
}
|
||||
|
||||
static inline void list_add(struct list_head *item, struct list_head *list)
|
||||
{
|
||||
item->prev = list;
|
||||
item->next = list->next;
|
||||
list->next->prev = item;
|
||||
list->next = item;
|
||||
}
|
||||
|
||||
static inline void list_addtail(struct list_head *item, struct list_head *list)
|
||||
{
|
||||
item->next = list;
|
||||
item->prev = list->prev;
|
||||
list->prev->next = item;
|
||||
list->prev = item;
|
||||
}
|
||||
|
||||
static inline bool list_is_empty(const struct list_head *list);
|
||||
|
||||
static inline void list_replace(struct list_head *from, struct list_head *to)
|
||||
{
|
||||
if (list_is_empty(from)) {
|
||||
list_inithead(to);
|
||||
} else {
|
||||
to->prev = from->prev;
|
||||
to->next = from->next;
|
||||
from->next->prev = to;
|
||||
from->prev->next = to;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void list_del(struct list_head *item)
|
||||
{
|
||||
item->prev->next = item->next;
|
||||
item->next->prev = item->prev;
|
||||
item->prev = item->next = NULL;
|
||||
}
|
||||
|
||||
static inline void list_delinit(struct list_head *item)
|
||||
{
|
||||
item->prev->next = item->next;
|
||||
item->next->prev = item->prev;
|
||||
item->next = item;
|
||||
item->prev = item;
|
||||
}
|
||||
|
||||
static inline bool list_is_empty(const struct list_head *list)
|
||||
{
|
||||
return list->next == list;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns whether the list has exactly one element.
|
||||
*/
|
||||
static inline bool list_is_singular(const struct list_head *list)
|
||||
{
|
||||
return list->next != NULL && list->next != list && list->next->next == list;
|
||||
}
|
||||
|
||||
static inline unsigned list_length(const struct list_head *list)
|
||||
{
|
||||
struct list_head *node;
|
||||
unsigned length = 0;
|
||||
for (node = list->next; node != list; node = node->next)
|
||||
length++;
|
||||
return length;
|
||||
}
|
||||
|
||||
static inline void list_splice(struct list_head *src, struct list_head *dst)
|
||||
{
|
||||
if (list_is_empty(src))
|
||||
return;
|
||||
|
||||
src->next->prev = dst;
|
||||
src->prev->next = dst->next;
|
||||
dst->next->prev = src->prev;
|
||||
dst->next = src->next;
|
||||
}
|
||||
|
||||
static inline void list_splicetail(struct list_head *src, struct list_head *dst)
|
||||
{
|
||||
if (list_is_empty(src))
|
||||
return;
|
||||
|
||||
src->prev->next = dst;
|
||||
src->next->prev = dst->prev;
|
||||
dst->prev->next = src->next;
|
||||
dst->prev = src->prev;
|
||||
}
|
||||
|
||||
static inline void list_validate(const struct list_head *list)
|
||||
{
|
||||
struct list_head *node;
|
||||
assert(list->next->prev == list && list->prev->next == list);
|
||||
for (node = list->next; node != list; node = node->next)
|
||||
assert(node->next->prev == node && node->prev->next == node);
|
||||
}
|
||||
|
||||
#define LIST_ENTRY(__type, __item, __field) \
|
||||
((__type *)(((char *)(__item)) - offsetof(__type, __field)))
|
||||
|
||||
/**
|
||||
* Cast from a pointer to a member of a struct back to the containing struct.
|
||||
*
|
||||
* 'sample' MUST be initialized, or else the result is undefined!
|
||||
*/
|
||||
#ifndef container_of
|
||||
#define container_of(ptr, sample, member) \
|
||||
(void *)((char *)(ptr) \
|
||||
- ((char *)&(sample)->member - (char *)(sample)))
|
||||
#endif
|
||||
|
||||
#define list_first_entry(ptr, type, member) \
|
||||
LIST_ENTRY(type, (ptr)->next, member)
|
||||
|
||||
#define list_last_entry(ptr, type, member) \
|
||||
LIST_ENTRY(type, (ptr)->prev, member)
|
||||
|
||||
|
||||
#define LIST_FOR_EACH_ENTRY(pos, head, member) \
|
||||
for (pos = NULL, pos = container_of((head)->next, pos, member); \
|
||||
&pos->member != (head); \
|
||||
pos = container_of(pos->member.next, pos, member))
|
||||
|
||||
#define LIST_FOR_EACH_ENTRY_SAFE(pos, storage, head, member) \
|
||||
for (pos = NULL, pos = container_of((head)->next, pos, member), \
|
||||
storage = container_of(pos->member.next, pos, member); \
|
||||
&pos->member != (head); \
|
||||
pos = storage, storage = container_of(storage->member.next, storage, member))
|
||||
|
||||
#define LIST_FOR_EACH_ENTRY_SAFE_REV(pos, storage, head, member) \
|
||||
for (pos = NULL, pos = container_of((head)->prev, pos, member), \
|
||||
storage = container_of(pos->member.prev, pos, member); \
|
||||
&pos->member != (head); \
|
||||
pos = storage, storage = container_of(storage->member.prev, storage, member))
|
||||
|
||||
#define LIST_FOR_EACH_ENTRY_FROM(pos, start, head, member) \
|
||||
for (pos = NULL, pos = container_of((start), pos, member); \
|
||||
&pos->member != (head); \
|
||||
pos = container_of(pos->member.next, pos, member))
|
||||
|
||||
#define LIST_FOR_EACH_ENTRY_FROM_REV(pos, start, head, member) \
|
||||
for (pos = NULL, pos = container_of((start), pos, member); \
|
||||
&pos->member != (head); \
|
||||
pos = container_of(pos->member.prev, pos, member))
|
||||
|
||||
#define list_for_each_entry(type, pos, head, member) \
|
||||
for (type *pos = LIST_ENTRY(type, (head)->next, member), \
|
||||
*__next = LIST_ENTRY(type, pos->member.next, member); \
|
||||
&pos->member != (head); \
|
||||
pos = LIST_ENTRY(type, pos->member.next, member), \
|
||||
list_assert(pos == __next, "use _safe iterator"), \
|
||||
__next = LIST_ENTRY(type, __next->member.next, member))
|
||||
|
||||
#define list_for_each_entry_safe(type, pos, head, member) \
|
||||
for (type *pos = LIST_ENTRY(type, (head)->next, member), \
|
||||
*__next = LIST_ENTRY(type, pos->member.next, member); \
|
||||
&pos->member != (head); \
|
||||
pos = __next, \
|
||||
__next = LIST_ENTRY(type, __next->member.next, member))
|
||||
|
||||
#define list_for_each_entry_rev(type, pos, head, member) \
|
||||
for (type *pos = LIST_ENTRY(type, (head)->prev, member), \
|
||||
*__prev = LIST_ENTRY(type, pos->member.prev, member); \
|
||||
&pos->member != (head); \
|
||||
pos = LIST_ENTRY(type, pos->member.prev, member), \
|
||||
list_assert(pos == __prev, "use _safe iterator"), \
|
||||
__prev = LIST_ENTRY(type, __prev->member.prev, member))
|
||||
|
||||
#define list_for_each_entry_safe_rev(type, pos, head, member) \
|
||||
for (type *pos = LIST_ENTRY(type, (head)->prev, member), \
|
||||
*__prev = LIST_ENTRY(type, pos->member.prev, member); \
|
||||
&pos->member != (head); \
|
||||
pos = __prev, \
|
||||
__prev = LIST_ENTRY(type, __prev->member.prev, member))
|
||||
|
||||
#define list_for_each_entry_from(type, pos, start, head, member) \
|
||||
for (type *pos = LIST_ENTRY(type, (start), member); \
|
||||
&pos->member != (head); \
|
||||
pos = LIST_ENTRY(type, pos->member.next, member))
|
||||
|
||||
#define list_for_each_entry_from_safe(type, pos, start, head, member) \
|
||||
for (type *pos = LIST_ENTRY(type, (start), member), \
|
||||
*__next = LIST_ENTRY(type, pos->member.next, member); \
|
||||
&pos->member != (head); \
|
||||
pos = __next, \
|
||||
__next = LIST_ENTRY(type, __next->member.next, member))
|
||||
|
||||
#define list_for_each_entry_from_rev(type, pos, start, head, member) \
|
||||
for (type *pos = LIST_ENTRY(type, (start), member); \
|
||||
&pos->member != (head); \
|
||||
pos = LIST_ENTRY(type, pos->member.prev, member))
|
||||
|
||||
#define list_pair_for_each_entry(type, pos1, pos2, head1, head2, member) \
|
||||
for (type *pos1 = LIST_ENTRY(type, (head1)->next, member), \
|
||||
*pos2 = LIST_ENTRY(type, (head2)->next, member); \
|
||||
&pos1->member != (head1) && &pos2->member != (head2); \
|
||||
pos1 = LIST_ENTRY(type, pos1->member.next, member), \
|
||||
pos2 = LIST_ENTRY(type, pos2->member.next, member))
|
||||
|
||||
#endif /*_UTIL_LIST_H_*/
|
||||
@@ -1,346 +0,0 @@
|
||||
/*
|
||||
* Copyright © 2014 Intel Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice (including the next
|
||||
* paragraph) shall be included in all copies or substantial portions of the
|
||||
* Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef UTIL_MACROS_H
|
||||
#define UTIL_MACROS_H
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
/* Compute the size of an array */
|
||||
#ifndef ARRAY_SIZE
|
||||
# define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
|
||||
#endif
|
||||
|
||||
/* For compatibility with Clang's __has_builtin() */
|
||||
#ifndef __has_builtin
|
||||
# define __has_builtin(x) 0
|
||||
#endif
|
||||
|
||||
/**
|
||||
* __builtin_expect macros
|
||||
*/
|
||||
#if !defined(HAVE___BUILTIN_EXPECT)
|
||||
# define __builtin_expect(x, y) (x)
|
||||
#endif
|
||||
|
||||
#ifndef likely
|
||||
# ifdef HAVE___BUILTIN_EXPECT
|
||||
# define likely(x) __builtin_expect(!!(x), 1)
|
||||
# define unlikely(x) __builtin_expect(!!(x), 0)
|
||||
# else
|
||||
# define likely(x) (x)
|
||||
# define unlikely(x) (x)
|
||||
# endif
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* Static (compile-time) assertion.
|
||||
* Basically, use COND to dimension an array. If COND is false/zero the
|
||||
* array size will be -1 and we'll get a compilation error.
|
||||
*/
|
||||
#define STATIC_ASSERT(COND) \
|
||||
do { \
|
||||
(void) sizeof(char [1 - 2*!(COND)]); \
|
||||
} while (0)
|
||||
|
||||
|
||||
/**
|
||||
* Unreachable macro. Useful for suppressing "control reaches end of non-void
|
||||
* function" warnings.
|
||||
*/
|
||||
#if defined(HAVE___BUILTIN_UNREACHABLE) || __has_builtin(__builtin_unreachable)
|
||||
#define unreachable(str) \
|
||||
do { \
|
||||
assert(!str); \
|
||||
__builtin_unreachable(); \
|
||||
} while (0)
|
||||
#elif defined (_MSC_VER)
|
||||
#define unreachable(str) \
|
||||
do { \
|
||||
assert(!str); \
|
||||
__assume(0); \
|
||||
} while (0)
|
||||
#else
|
||||
#define unreachable(str) assert(!str)
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Assume macro. Useful for expressing our assumptions to the compiler,
|
||||
* typically for purposes of silencing warnings.
|
||||
*/
|
||||
#if __has_builtin(__builtin_assume)
|
||||
#define assume(expr) \
|
||||
do { \
|
||||
assert(expr); \
|
||||
__builtin_assume(expr); \
|
||||
} while (0)
|
||||
#elif defined HAVE___BUILTIN_UNREACHABLE
|
||||
#define assume(expr) ((expr) ? ((void) 0) \
|
||||
: (assert(!"assumption failed"), \
|
||||
__builtin_unreachable()))
|
||||
#elif defined (_MSC_VER)
|
||||
#define assume(expr) __assume(expr)
|
||||
#else
|
||||
#define assume(expr) assert(expr)
|
||||
#endif
|
||||
|
||||
/* Attribute const is used for functions that have no effects other than their
|
||||
* return value, and only rely on the argument values to compute the return
|
||||
* value. As a result, calls to it can be CSEed. Note that using memory
|
||||
* pointed to by the arguments is not allowed for const functions.
|
||||
*/
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_CONST
|
||||
#define ATTRIBUTE_CONST __attribute__((__const__))
|
||||
#else
|
||||
#define ATTRIBUTE_CONST
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_FLATTEN
|
||||
#define FLATTEN __attribute__((__flatten__))
|
||||
#else
|
||||
#define FLATTEN
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_FORMAT
|
||||
#define PRINTFLIKE(f, a) __attribute__ ((format(__printf__, f, a)))
|
||||
#else
|
||||
#define PRINTFLIKE(f, a)
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_MALLOC
|
||||
#define MALLOCLIKE __attribute__((__malloc__))
|
||||
#else
|
||||
#define MALLOCLIKE
|
||||
#endif
|
||||
|
||||
/* Forced function inlining */
|
||||
/* Note: Clang also sets __GNUC__ (see other cases below) */
|
||||
#ifndef ALWAYS_INLINE
|
||||
# if defined(__GNUC__)
|
||||
# define ALWAYS_INLINE inline __attribute__((always_inline))
|
||||
# elif defined(_MSC_VER)
|
||||
# define ALWAYS_INLINE __forceinline
|
||||
# else
|
||||
# define ALWAYS_INLINE inline
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Used to optionally mark structures with misaligned elements or size as
|
||||
* packed, to trade off performance for space.
|
||||
*/
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_PACKED
|
||||
#define PACKED __attribute__((__packed__))
|
||||
#else
|
||||
#define PACKED
|
||||
#endif
|
||||
|
||||
/* Attribute pure is used for functions that have no effects other than their
|
||||
* return value. As a result, calls to it can be dead code eliminated.
|
||||
*/
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_PURE
|
||||
#define ATTRIBUTE_PURE __attribute__((__pure__))
|
||||
#else
|
||||
#define ATTRIBUTE_PURE
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_RETURNS_NONNULL
|
||||
#define ATTRIBUTE_RETURNS_NONNULL __attribute__((__returns_nonnull__))
|
||||
#else
|
||||
#define ATTRIBUTE_RETURNS_NONNULL
|
||||
#endif
|
||||
|
||||
#ifndef NORETURN
|
||||
# ifdef _MSC_VER
|
||||
# define NORETURN __declspec(noreturn)
|
||||
# elif defined HAVE_FUNC_ATTRIBUTE_NORETURN
|
||||
# define NORETURN __attribute__((__noreturn__))
|
||||
# else
|
||||
# define NORETURN
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
/**
|
||||
* Macro function that evaluates to true if T is a trivially
|
||||
* destructible type -- that is, if its (non-virtual) destructor
|
||||
* performs no action and all member variables and base classes are
|
||||
* trivially destructible themselves.
|
||||
*/
|
||||
# if (defined(__clang__) && defined(__has_feature))
|
||||
# if __has_feature(has_trivial_destructor)
|
||||
# define HAS_TRIVIAL_DESTRUCTOR(T) __has_trivial_destructor(T)
|
||||
# endif
|
||||
# elif defined(__GNUC__)
|
||||
# if ((__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ >= 3)))
|
||||
# define HAS_TRIVIAL_DESTRUCTOR(T) __has_trivial_destructor(T)
|
||||
# endif
|
||||
# elif defined(_MSC_VER) && !defined(__INTEL_COMPILER)
|
||||
# define HAS_TRIVIAL_DESTRUCTOR(T) __has_trivial_destructor(T)
|
||||
# endif
|
||||
# ifndef HAS_TRIVIAL_DESTRUCTOR
|
||||
/* It's always safe (if inefficient) to assume that a
|
||||
* destructor is non-trivial.
|
||||
*/
|
||||
# define HAS_TRIVIAL_DESTRUCTOR(T) (false)
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/**
|
||||
* PUBLIC/USED macros
|
||||
*
|
||||
* If we build the library with gcc's -fvisibility=hidden flag, we'll
|
||||
* use the PUBLIC macro to mark functions that are to be exported.
|
||||
*
|
||||
* We also need to define a USED attribute, so the optimizer doesn't
|
||||
* inline a static function that we later use in an alias. - ajax
|
||||
*/
|
||||
#ifndef PUBLIC
|
||||
# if defined(__GNUC__)
|
||||
# define PUBLIC __attribute__((visibility("default")))
|
||||
# define USED __attribute__((used))
|
||||
# elif defined(_MSC_VER)
|
||||
# define PUBLIC __declspec(dllexport)
|
||||
# define USED
|
||||
# else
|
||||
# define PUBLIC
|
||||
# define USED
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/**
|
||||
* UNUSED marks variables (or sometimes functions) that have to be defined,
|
||||
* but are sometimes (or always) unused beyond that. A common case is for
|
||||
* a function parameter to be used in some build configurations but not others.
|
||||
* Another case is fallback vfuncs that don't do anything with their params.
|
||||
*
|
||||
* Note that this should not be used for identifiers used in `assert()`;
|
||||
* see ASSERTED below.
|
||||
*/
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_UNUSED
|
||||
#define UNUSED __attribute__((unused))
|
||||
#else
|
||||
#define UNUSED
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Use ASSERTED to indicate that an identifier is unused outside of an `assert()`,
|
||||
* so that assert-free builds don't get "unused variable" warnings.
|
||||
*/
|
||||
#ifdef NDEBUG
|
||||
#define ASSERTED UNUSED
|
||||
#else
|
||||
#define ASSERTED
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_FUNC_ATTRIBUTE_WARN_UNUSED_RESULT
|
||||
#define MUST_CHECK __attribute__((warn_unused_result))
|
||||
#else
|
||||
#define MUST_CHECK
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#define ATTRIBUTE_NOINLINE __attribute__((noinline))
|
||||
#else
|
||||
#define ATTRIBUTE_NOINLINE
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* Check that STRUCT::FIELD can hold MAXVAL. We use a lot of bitfields
|
||||
* in Mesa/gallium. We have to be sure they're of sufficient size to
|
||||
* hold the largest expected value.
|
||||
* Note that with MSVC, enums are signed and enum bitfields need one extra
|
||||
* high bit (always zero) to ensure the max value is handled correctly.
|
||||
* This macro will detect that with MSVC, but not GCC.
|
||||
*/
|
||||
#define ASSERT_BITFIELD_SIZE(STRUCT, FIELD, MAXVAL) \
|
||||
do { \
|
||||
ASSERTED STRUCT s; \
|
||||
s.FIELD = (MAXVAL); \
|
||||
assert((int) s.FIELD == (MAXVAL) && "Insufficient bitfield size!"); \
|
||||
} while (0)
|
||||
|
||||
|
||||
/** Compute ceiling of integer quotient of A divided by B. */
|
||||
#define DIV_ROUND_UP( A, B ) ( ((A) + (B) - 1) / (B) )
|
||||
|
||||
/** Clamp X to [MIN,MAX]. Turn NaN into MIN, arbitrarily. */
|
||||
#define CLAMP( X, MIN, MAX ) ( (X)>(MIN) ? ((X)>(MAX) ? (MAX) : (X)) : (MIN) )
|
||||
|
||||
/** Minimum of two values: */
|
||||
#define MIN2( A, B ) ( (A)<(B) ? (A) : (B) )
|
||||
|
||||
/** Maximum of two values: */
|
||||
#define MAX2( A, B ) ( (A)>(B) ? (A) : (B) )
|
||||
|
||||
/** Minimum and maximum of three values: */
|
||||
#define MIN3( A, B, C ) ((A) < (B) ? MIN2(A, C) : MIN2(B, C))
|
||||
#define MAX3( A, B, C ) ((A) > (B) ? MAX2(A, C) : MAX2(B, C))
|
||||
|
||||
/** Align a value to a power of two */
|
||||
#define ALIGN_POT(x, pot_align) (((x) + (pot_align) - 1) & ~((pot_align) - 1))
|
||||
|
||||
/**
|
||||
* Macro for declaring an explicit conversion operator. Defaults to an
|
||||
* implicit conversion if C++11 is not supported.
|
||||
*/
|
||||
#if __cplusplus >= 201103L
|
||||
#define EXPLICIT_CONVERSION explicit
|
||||
#elif defined(__cplusplus)
|
||||
#define EXPLICIT_CONVERSION
|
||||
#endif
|
||||
|
||||
/** Set a single bit */
|
||||
#define BITFIELD_BIT(b) (1u << (b))
|
||||
/** Set all bits up to excluding bit b */
|
||||
#define BITFIELD_MASK(b) \
|
||||
((b) == 32 ? (~0u) : BITFIELD_BIT((b) % 32) - 1)
|
||||
/** Set count bits starting from bit b */
|
||||
#define BITFIELD_RANGE(b, count) \
|
||||
(BITFIELD_MASK((b) + (count)) & ~BITFIELD_MASK(b))
|
||||
|
||||
/** Set a single bit */
|
||||
#define BITFIELD64_BIT(b) (1ull << (b))
|
||||
/** Set all bits up to excluding bit b */
|
||||
#define BITFIELD64_MASK(b) \
|
||||
((b) == 64 ? (~0ull) : BITFIELD64_BIT(b) - 1)
|
||||
/** Set count bits starting from bit b */
|
||||
#define BITFIELD64_RANGE(b, count) \
|
||||
(BITFIELD64_MASK((b) + (count)) & ~BITFIELD64_MASK(b))
|
||||
|
||||
/* TODO: In future we should try to move this to u_debug.h once header
|
||||
* dependencies are reorganised to allow this.
|
||||
*/
|
||||
enum pipe_debug_type
|
||||
{
|
||||
PIPE_DEBUG_TYPE_OUT_OF_MEMORY = 1,
|
||||
PIPE_DEBUG_TYPE_ERROR,
|
||||
PIPE_DEBUG_TYPE_SHADER_INFO,
|
||||
PIPE_DEBUG_TYPE_PERF_INFO,
|
||||
PIPE_DEBUG_TYPE_INFO,
|
||||
PIPE_DEBUG_TYPE_FALLBACK,
|
||||
PIPE_DEBUG_TYPE_CONFORMANCE,
|
||||
};
|
||||
|
||||
#endif /* UTIL_MACROS_H */
|
||||
@@ -1,13 +0,0 @@
|
||||
from tinygrad.helpers import fetch, Timing
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
from examples.stable_diffusion import StableDiffusion
|
||||
|
||||
# run "sudo purge" before testing on OS X to avoid the memory cache
|
||||
|
||||
if __name__ == "__main__":
|
||||
fn = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
|
||||
model = StableDiffusion()
|
||||
with Timing():
|
||||
load_state_dict(model, torch_load(fn)['state_dict'], strict=False)
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
8
test/external/external_test_embedding.py
vendored
8
test/external/external_test_embedding.py
vendored
@@ -1,8 +0,0 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Embedding
|
||||
|
||||
if __name__ == "__main__":
|
||||
vocab_size = 50257
|
||||
dim = 128
|
||||
test = Embedding(vocab_size, dim)
|
||||
ret = test(Tensor([[1,2,3]])).numpy()
|
||||
119
test/external/external_test_hsa_driver.py
vendored
119
test/external/external_test_hsa_driver.py
vendored
@@ -1,119 +0,0 @@
|
||||
import ctypes, unittest
|
||||
from tinygrad.helpers import init_c_struct_t
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.runtime.support.hsa import AQLQueue
|
||||
from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.engine.realize import BufferXfer
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
|
||||
def get_hsa_inc_prog(dev, inc=1):
|
||||
prg = f"""
|
||||
extern "C" __attribute__((global)) void test_inc(int* data0) {{
|
||||
data0[0] = (data0[0]+{inc});
|
||||
}}
|
||||
"""
|
||||
return dev.runtime("test_inc", dev.compiler.compile(prg))
|
||||
|
||||
def get_hsa_buffer_and_kernargs(dev):
|
||||
test_buf = Buffer(Device.DEFAULT, 1, dtypes.int)
|
||||
test_buf.copyin(memoryview(bytearray(4))) # zero mem
|
||||
assert test_buf.as_buffer().cast('I')[0] == 0 # check mem is visible + sync to exec
|
||||
|
||||
args_struct_t = init_c_struct_t(tuple([('f0', ctypes.c_void_p)]))
|
||||
kernargs = dev.alloc_kernargs(8)
|
||||
args_st = args_struct_t.from_address(kernargs)
|
||||
args_st.__setattr__('f0', test_buf._buf)
|
||||
dev.flush_hdp()
|
||||
return test_buf, kernargs
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "HSA", "only run on HSA")
|
||||
class TestHSADriver(unittest.TestCase):
|
||||
def test_hsa_simple_enqueue(self):
|
||||
dev = Device[Device.DEFAULT]
|
||||
queue = AQLQueue(dev, sz=256)
|
||||
|
||||
clprg = get_hsa_inc_prog(dev, inc=1)
|
||||
test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
|
||||
|
||||
queue.submit_kernel(clprg, [1,1,1], [1,1,1], kernargs)
|
||||
queue.wait()
|
||||
|
||||
assert test_buf.as_buffer().cast('I')[0] == 1, f"{test_buf.as_buffer().cast('I')[0]} != 1, all packets executed?"
|
||||
del queue
|
||||
|
||||
def test_hsa_ring_enqueue(self):
|
||||
dev = Device[Device.DEFAULT]
|
||||
|
||||
queue_size = 256
|
||||
exec_cnt = int(queue_size * 1.5)
|
||||
queue = AQLQueue(dev, sz=queue_size)
|
||||
|
||||
clprg_inc1 = get_hsa_inc_prog(dev, inc=1)
|
||||
clprg_inc2 = get_hsa_inc_prog(dev, inc=2)
|
||||
test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
|
||||
|
||||
for _ in range(exec_cnt):
|
||||
queue.submit_kernel(clprg_inc1, [1,1,1], [1,1,1], kernargs)
|
||||
for _ in range(exec_cnt):
|
||||
queue.submit_kernel(clprg_inc2, [1,1,1], [1,1,1], kernargs)
|
||||
queue.wait()
|
||||
|
||||
expected = exec_cnt + exec_cnt * 2
|
||||
assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
|
||||
del queue
|
||||
|
||||
def test_hsa_blit_enqueue(self):
|
||||
dev = Device[Device.DEFAULT]
|
||||
|
||||
queue_size = 256
|
||||
exec_cnt = 178
|
||||
queue = AQLQueue(dev, sz=queue_size)
|
||||
|
||||
test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
|
||||
|
||||
# Using VirtAQLQueue to blit them
|
||||
virt_queue_packets_cnt = 31
|
||||
virt_queue = VirtAQLQueue(dev, sz=virt_queue_packets_cnt)
|
||||
|
||||
clprogs = []
|
||||
sum_per_blit = 0
|
||||
for i in range(virt_queue_packets_cnt):
|
||||
sum_per_blit += i+1
|
||||
clprogs.append(get_hsa_inc_prog(dev, inc=i+1))
|
||||
|
||||
for i in range(virt_queue_packets_cnt):
|
||||
virt_queue.submit_kernel(clprogs[i], [1,1,1], [1,1,1], kernargs)
|
||||
|
||||
for _ in range(exec_cnt):
|
||||
queue.blit_packets(virt_queue.queue_base, virt_queue.packets_count)
|
||||
queue.wait()
|
||||
|
||||
expected = exec_cnt * sum_per_blit
|
||||
assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
|
||||
del queue, clprogs
|
||||
|
||||
def test_hsa_copies_sync(self):
|
||||
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
||||
|
||||
test_buf0 = Buffer(d0, 1, dtypes.int)
|
||||
test_buf1 = Buffer(d0, 1, dtypes.int)
|
||||
test_buf2 = Buffer(d1, 1, dtypes.int)
|
||||
test_buf0.copyin(memoryview(bytearray(1*4)))
|
||||
test_buf1.copyin(memoryview(bytearray(1*4)))
|
||||
test_buf2.copyin(memoryview(bytearray(1*4)))
|
||||
|
||||
jit_cache = [ExecItem(UOp(Ops.NOOP), [test_buf0, test_buf2], prg=BufferXfer(test_buf0.nbytes, test_buf0.device, test_buf2.device)),
|
||||
ExecItem(UOp(Ops.NOOP), [test_buf2, test_buf1], prg=BufferXfer(test_buf2.nbytes, test_buf2.device, test_buf1.device))]
|
||||
graph = HSAGraph(jit_cache, [], {})
|
||||
|
||||
for i in range(10000):
|
||||
test_buf0.copyin(memoryview(bytearray(1*4)))
|
||||
test_buf2.copyin(memoryview(bytearray(int.to_bytes(4, length=1*4, byteorder='little'))))
|
||||
graph([], {})
|
||||
assert test_buf0.as_buffer().cast('I')[0] == 4
|
||||
assert test_buf2.as_buffer().cast('I')[0] == 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
32
test/external/external_test_yolo.py
vendored
32
test/external/external_test_yolo.py
vendored
@@ -1,32 +0,0 @@
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
|
||||
from examples.yolov3 import Darknet, infer, show_labels
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
chicken_img = cv2.imread(str(Path(__file__).parent.parent / 'models/efficientnet/Chicken.jpg'))
|
||||
car_img = cv2.imread(str(Path(__file__).parent.parent / 'models/efficientnet/car.jpg'))
|
||||
|
||||
class TestYOLO(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = Darknet(fetch("https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg").read_bytes())
|
||||
print("Loading weights file (237MB). This might take a while…")
|
||||
cls.model.load_weights("https://pjreddie.com/media/files/yolov3.weights")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
del cls.model
|
||||
|
||||
def test_chicken(self):
|
||||
labels = show_labels(infer(self.model, chicken_img), confidence=0.56)
|
||||
self.assertEqual(labels, ["bird"])
|
||||
|
||||
def test_car(self):
|
||||
labels = show_labels(infer(self.model, car_img))
|
||||
self.assertEqual(labels, ["car"])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
61
test/external/graph_batchnorm.py
vendored
61
test/external/graph_batchnorm.py
vendored
@@ -1,61 +0,0 @@
|
||||
import unittest
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, BatchNorm2d, optim
|
||||
|
||||
def model_step(lm):
|
||||
with Tensor.train():
|
||||
x = Tensor.ones(8,12,128,256, requires_grad=False)
|
||||
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
|
||||
loss = lm.forward(x).sum()
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
del x,loss
|
||||
optimizer.step()
|
||||
|
||||
class TestBatchnorm(unittest.TestCase):
|
||||
def test_conv(self):
|
||||
class LilModel:
|
||||
def __init__(self):
|
||||
self.c = Conv2d(12, 32, 3, padding=1, bias=False)
|
||||
def forward(self, x):
|
||||
return self.c(x).relu()
|
||||
lm = LilModel()
|
||||
model_step(lm)
|
||||
|
||||
def test_two_conv(self):
|
||||
class LilModel:
|
||||
def __init__(self):
|
||||
self.c = Conv2d(12, 32, 3, padding=1, bias=False)
|
||||
self.c2 = Conv2d(32, 32, 3, padding=1, bias=False)
|
||||
def forward(self, x):
|
||||
return self.c2(self.c(x)).relu()
|
||||
lm = LilModel()
|
||||
model_step(lm)
|
||||
|
||||
def test_two_conv_bn(self):
|
||||
class LilModel:
|
||||
def __init__(self):
|
||||
self.c = Conv2d(12, 24, 3, padding=1, bias=False)
|
||||
self.bn = BatchNorm2d(24, track_running_stats=False)
|
||||
self.c2 = Conv2d(24, 32, 3, padding=1, bias=False)
|
||||
self.bn2 = BatchNorm2d(32, track_running_stats=False)
|
||||
def forward(self, x):
|
||||
x = self.bn(self.c(x)).relu()
|
||||
return self.bn2(self.c2(x)).relu()
|
||||
lm = LilModel()
|
||||
model_step(lm)
|
||||
|
||||
def test_conv_bn(self):
|
||||
class LilModel:
|
||||
def __init__(self):
|
||||
self.c = Conv2d(12, 32, 3, padding=1, bias=False)
|
||||
self.bn = BatchNorm2d(32, track_running_stats=False)
|
||||
def forward(self, x):
|
||||
return self.bn(self.c(x)).relu()
|
||||
lm = LilModel()
|
||||
model_step(lm)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user