mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-06 21:53:53 -05:00)
cleanup stale examples/extra (#13764)
* cleanup stale files
* examples
* move those back
* old
* delete more
@@ -1,93 +0,0 @@
#!/usr/bin/env python3
import os, sys, traceback
sys.path.append(os.getcwd())

from io import StringIO
from contextlib import redirect_stdout
from tinygrad import Tensor, nn
from tinygrad.helpers import Timing, colored, getenv, fetch
from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
from sentencepiece import SentencePieceProcessor

def create_fixed_tokenizer(output_file):
  print("creating fixed tokenizer")
  import extra.junk.sentencepiece_model_pb2 as spb2
  mp = spb2.ModelProto()
  mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
  mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
  mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
  with open(output_file, "wb") as f:
    f.write(mp.SerializeToString())

# example:
# echo -en "write 2+2\nwrite hello world\ny\n" | TEMP=0 python3 examples/coder.py

if __name__ == "__main__":
  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
  with Timing("create model: "):
    model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096, jit=getenv("JIT", 1))

  with Timing("download weights: "):
    part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
    part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))

  with Timing("weights -> model: "):
    nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part1, 32, 32, 8)), strict=False)
    nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part2, 32, 32, 8)), strict=False)

  if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
  spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")

  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
  # "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
  IM_END = 32000
  IM_START = 32001
  def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
  def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")
  def output(outputted, toks, color):
    cur = spp.decode(toks)[len(outputted):]
    sys.stdout.write(colored(cur, color))
    sys.stdout.flush()
    outputted += cur
    return outputted

  # *** app below this line ***

  toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")

  PROMPT = getenv("PROMPT", 1)
  temperature = getenv("TEMP", 0.7)

  start_pos = 0
  outputted = output("", toks, "green")
  turn = True
  while 1:
    if PROMPT:
      toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant")
    else:
      toks += start_prompt("user" if turn else "assistant")
      turn = not turn
    old_output_len = len(outputted)
    while 1:
      tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
      start_pos = len(toks)
      toks.append(tok)
      outputted = output(outputted, toks, "blue" if not turn else "cyan")
      if tok == IM_END: break
      if tok == spp.eos_id(): break
      new_output = outputted[old_output_len:]

      if new_output.endswith("```") and '```python\n' in new_output:
        python_code = new_output.split('```python\n')[1].split("```")[0]
        # AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
        if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
          my_stdout = StringIO()
          try:
            with redirect_stdout(my_stdout): exec(python_code)
            result = my_stdout.getvalue()
          except Exception as e:
            result = ''.join(traceback.format_exception_only(e))
          toks += spp.encode(f"\nOutput:\n```\n{result}```")
          outputted = output(outputted, toks, "yellow")
          old_output_len = len(outputted)
    print("")
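For clarity: the IM_START/IM_END framing above is the model's ChatML chat template, quoted verbatim in the tokenizer_config comment. As a string-level sketch only (the script itself works on token ids, with the added pieces mapped to ids 32001/32000; the render helper here is hypothetical), one turn plus the generation prompt looks like:

# string-level sketch of the ChatML template used above; not run through SentencePiece
def render(messages, add_generation_prompt=True):
  s = "".join(f"<|im_start|>{role}\n{content}<|im_end|>\n" for role, content in messages)
  if add_generation_prompt: s += "<|im_start|>assistant\n"  # model continues from here
  return s

print(render([("system", "You are Quentin."), ("user", "write 2+2")]))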
@@ -1,89 +0,0 @@
|
||||
# load weights from
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
|
||||
# a rough copy of
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
import sys
|
||||
import ast
|
||||
import time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv, fetch, Timing
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from extra.models.efficientnet import EfficientNet
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
# TODO: you should be able to put these in the jitted function
|
||||
bias = Tensor([0.485, 0.456, 0.406])
|
||||
scale = Tensor([0.229, 0.224, 0.225])
|
||||
|
||||
@TinyJit
|
||||
def _infer(model, img):
|
||||
img = img.permute((2,0,1))
|
||||
img = img / 255.0
|
||||
img = img - bias.reshape((1,-1,1,1))
|
||||
img = img / scale.reshape((1,-1,1,1))
|
||||
return model.forward(img).realize()
|
||||
|
||||
def infer(model, img):
|
||||
# preprocess image
|
||||
aspect_ratio = img.size[0] / img.size[1]
|
||||
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
|
||||
|
||||
img = np.array(img)
|
||||
y0,x0=(np.asarray(img.shape)[:2]-224)//2
|
||||
retimg = img = img[y0:y0+224, x0:x0+224]
|
||||
|
||||
# if you want to look at the image
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
plt.imshow(img)
|
||||
plt.show()
|
||||
"""
|
||||
|
||||
# run the net
|
||||
out = _infer(model, Tensor(img.astype("float32"))).numpy()
|
||||
|
||||
# if you want to look at the outputs
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
plt.plot(out[0])
|
||||
plt.show()
|
||||
"""
|
||||
return out, retimg
|
||||
|
||||
if __name__ == "__main__":
|
||||
# instantiate my net
|
||||
model = EfficientNet(getenv("NUM", 0))
|
||||
model.load_from_pretrained()
|
||||
|
||||
# category labels
|
||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
||||
|
||||
# load image and preprocess
|
||||
url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
|
||||
if url == 'webcam':
|
||||
import cv2
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
while 1:
|
||||
_ = cap.grab() # discard one frame to circumvent capture buffering
|
||||
ret, frame = cap.read()
|
||||
img = Image.fromarray(frame[:, :, [2,1,0]])
|
||||
lt = time.monotonic_ns()
|
||||
out, retimg = infer(model, img)
|
||||
print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
SCALE = 3
|
||||
simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
|
||||
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
|
||||
cv2.imshow('capture', retimg)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
else:
|
||||
img = Image.open(fetch(url))
|
||||
for i in range(getenv("CNT", 1)):
|
||||
with Timing("did inference in "):
|
||||
out, _ = infer(model, img)
|
||||
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
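The normalization inside _infer above is the standard ImageNet per-channel mean/std, applied by broadcasting over NCHW after moving channels to axis 1. A standalone numpy check of the same arithmetic (random input, shapes only):

# numpy sketch of the preprocessing in _infer
import numpy as np
img = np.random.randint(0, 256, (224, 224, 3)).astype("float32")
x = img.transpose(2, 0, 1)[None] / 255.0            # HWC -> 1,C,H,W, scale to [0,1]
mean = np.array([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
std  = np.array([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
x = (x - mean) / std                                # per-channel normalization
print(x.shape)                                      # (1, 3, 224, 224)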
@@ -1,498 +0,0 @@
# pip3 install sentencepiece

# This file incorporates code from the following:
# Github Name             | License | Link
# black-forest-labs/flux  | Apache  | https://github.com/black-forest-labs/flux/tree/main/model_licenses

from tinygrad import Tensor, nn, dtypes, TinyJit
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, tqdm, colored
from sdxl import FirstStage
from extra.models.clip import FrozenClosedClipEmbedder
from extra.models.t5 import T5Embedder
import numpy as np

import math, time, argparse, tempfile
from typing import List, Dict, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from pathlib import Path
from PIL import Image

urls:dict = {
  "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors",
  "flux-dev": "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev.sft",
  "ae": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors",
  "T5_1_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors",
  "T5_2_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors",
  "T5_tokenizer": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model",
  "clip": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder/model.safetensors"
}

def tensor_identity(x:Tensor) -> Tensor: return x

class AutoEncoder:
  def __init__(self, scale_factor:float, shift_factor:float):
    self.decoder = FirstStage.Decoder(128, 3, 3, 16, [1, 2, 4, 4], 2, 256)
    self.scale_factor = scale_factor
    self.shift_factor = shift_factor

  def decode(self, z:Tensor) -> Tensor:
    z = z / self.scale_factor + self.shift_factor
    return self.decoder(z)

# Conditioner
class ClipEmbedder(FrozenClosedClipEmbedder):
  def __call__(self, texts:Union[str, List[str], Tensor]) -> Tensor:
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
    return self.transformer.text_model(tokens.reshape(len(texts),-1))[:, tokens.argmax(-1)]

# https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def attention(q:Tensor, k:Tensor, v:Tensor, pe:Tensor) -> Tensor:
  q, k = apply_rope(q, k, pe)
  x = Tensor.scaled_dot_product_attention(q, k, v)
  return x.rearrange("B H L D -> B L (H D)")

def rope(pos:Tensor, dim:int, theta:int) -> Tensor:
  assert dim % 2 == 0
  scale = Tensor.arange(0, dim, 2, dtype=dtypes.float32, device=pos.device) / dim  # NOTE: this is torch.float64 in reference implementation
  omega = 1.0 / (theta**scale)
  out = Tensor.einsum("...n,d->...nd", pos, omega)
  out = Tensor.stack(Tensor.cos(out), -Tensor.sin(out), Tensor.sin(out), Tensor.cos(out), dim=-1)
  out = out.rearrange("b n d (i j) -> b n d i j", i=2, j=2)
  return out.float()

def apply_rope(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
  xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
  xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
  return xq_out.reshape(*xq.shape).cast(xq.dtype), xk_out.reshape(*xk.shape).cast(xk.dtype)


# https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND:
  def __init__(self, dim:int, theta:int, axes_dim:List[int]):
    self.dim = dim
    self.theta = theta
    self.axes_dim = axes_dim

  def __call__(self, ids:Tensor) -> Tensor:
    n_axes = ids.shape[-1]
    emb = Tensor.cat(*[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
    return emb.unsqueeze(1)

class MLPEmbedder:
  def __init__(self, in_dim:int, hidden_dim:int):
    self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
    self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

  def __call__(self, x:Tensor) -> Tensor:
    return self.out_layer(self.in_layer(x).silu())

class QKNorm:
  def __init__(self, dim:int):
    self.query_norm = nn.RMSNorm(dim)
    self.key_norm = nn.RMSNorm(dim)

  def __call__(self, q:Tensor, k:Tensor) -> Tuple[Tensor, Tensor]:
    return self.query_norm(q), self.key_norm(k)

class SelfAttention:
  def __init__(self, dim:int, num_heads:int = 8, qkv_bias:bool = False):
    self.num_heads = num_heads
    head_dim = dim // num_heads

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.norm = QKNorm(head_dim)
    self.proj = nn.Linear(dim, dim)

  def __call__(self, x:Tensor, pe:Tensor) -> Tensor:
    qkv = self.qkv(x)
    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)
    x = attention(q, k, v, pe=pe)
    return self.proj(x)

@dataclass
class ModulationOut:
  shift:Tensor
  scale:Tensor
  gate:Tensor

class Modulation:
  def __init__(self, dim:int, double:bool):
    self.is_double = double
    self.multiplier = 6 if double else 3
    self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

  def __call__(self, vec:Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
    out = self.lin(vec.silu())[:, None, :].chunk(self.multiplier, dim=-1)
    return ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None

class DoubleStreamBlock:
  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float, qkv_bias:bool = False):
    mlp_hidden_dim = int(hidden_size * mlp_ratio)
    self.num_heads = num_heads
    self.hidden_size = hidden_size
    self.img_mod = Modulation(hidden_size, double=True)
    self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

    self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.img_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]

    self.txt_mod = Modulation(hidden_size, double=True)
    self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

    self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.txt_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]

  def __call__(self, img:Tensor, txt:Tensor, vec:Tensor, pe:Tensor) -> tuple[Tensor, Tensor]:
    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)
    assert img_mod2 is not None and txt_mod2 is not None
    # prepare image for attention
    img_modulated = self.img_norm1(img)
    img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
    img_qkv = self.img_attn.qkv(img_modulated)
    img_q, img_k, img_v = img_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    img_q, img_k = self.img_attn.norm(img_q, img_k)

    # prepare txt for attention
    txt_modulated = self.txt_norm1(txt)
    txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
    txt_qkv = self.txt_attn.qkv(txt_modulated)
    txt_q, txt_k, txt_v = txt_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)

    # run actual attention
    q = Tensor.cat(txt_q, img_q, dim=2)
    k = Tensor.cat(txt_k, img_k, dim=2)
    v = Tensor.cat(txt_v, img_v, dim=2)

    attn = attention(q, k, v, pe=pe)
    txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

    # calculate the img blocks
    img = img + img_mod1.gate * self.img_attn.proj(img_attn)
    img = img + img_mod2.gate * ((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift).sequential(self.img_mlp)

    # calculate the txt blocks
    txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
    txt = txt + txt_mod2.gate * ((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift).sequential(self.txt_mlp)
    return img, txt


class SingleStreamBlock:
  """
  A DiT block with parallel linear layers as described in
  https://arxiv.org/abs/2302.05442 and adapted modulation interface.
  """

  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float=4.0, qk_scale:Optional[float]=None):
    self.hidden_dim = hidden_size
    self.num_heads = num_heads
    head_dim = hidden_size // num_heads
    self.scale = qk_scale or head_dim**-0.5

    self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
    # qkv and mlp_in
    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
    # proj and mlp_out
    self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

    self.norm = QKNorm(head_dim)

    self.hidden_size = hidden_size
    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    self.mlp_act = Tensor.gelu
    self.modulation = Modulation(hidden_size, double=False)

  def __call__(self, x:Tensor, vec:Tensor, pe:Tensor) -> Tensor:
    mod, _ = self.modulation(vec)
    x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
    qkv, mlp = Tensor.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)

    # compute attention
    attn = attention(q, k, v, pe=pe)
    # compute activation in mlp stream, cat again and run second linear layer
    output = self.linear2(Tensor.cat(attn, self.mlp_act(mlp), dim=2))
    return x + mod.gate * output


class LastLayer:
  def __init__(self, hidden_size:int, patch_size:int, out_channels:int):
    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
    self.adaLN_modulation:List[Callable[[Tensor], Tensor]] = [Tensor.silu, nn.Linear(hidden_size, 2 * hidden_size, bias=True)]

  def __call__(self, x:Tensor, vec:Tensor) -> Tensor:
    shift, scale = vec.sequential(self.adaLN_modulation).chunk(2, dim=1)
    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
    return self.linear(x)

def timestep_embedding(t:Tensor, dim:int, max_period:int=10000, time_factor:float=1000.0) -> Tensor:
  """
  Create sinusoidal timestep embeddings.
  :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
  :param dim: the dimension of the output.
  :param max_period: controls the minimum frequency of the embeddings.
  :return: an (N, D) Tensor of positional embeddings.
  """
  t = time_factor * t
  half = dim // 2
  freqs = Tensor.exp(-math.log(max_period) * Tensor.arange(0, stop=half, dtype=dtypes.float32) / half).to(t.device)

  args = t[:, None].float() * freqs[None]
  embedding = Tensor.cat(Tensor.cos(args), Tensor.sin(args), dim=-1)
  if dim % 2: embedding = Tensor.cat(*[embedding, Tensor.zeros_like(embedding[:, :1])], dim=-1)
  if Tensor.is_floating_point(t): embedding = embedding.cast(t.dtype)
  return embedding

# https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py
class Flux:
  """
  Transformer model for flow matching on sequences.
  """

  def __init__(
    self,
    guidance_embed:bool,
    in_channels:int = 64,
    vec_in_dim:int = 768,
    context_in_dim:int = 4096,
    hidden_size:int = 3072,
    mlp_ratio:float = 4.0,
    num_heads:int = 24,
    depth:int = 19,
    depth_single_blocks:int = 38,
    axes_dim:Optional[List[int]] = None,
    theta:int = 10_000,
    qkv_bias:bool = True,
  ):

    axes_dim = axes_dim or [16, 56, 56]
    self.guidance_embed = guidance_embed
    self.in_channels = in_channels
    self.out_channels = self.in_channels
    if hidden_size % num_heads != 0:
      raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
    pe_dim = hidden_size // num_heads
    if sum(axes_dim) != pe_dim:
      raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
    self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
    self.guidance_in:Callable[[Tensor], Tensor] = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else tensor_identity
    self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

    self.double_blocks = [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) for _ in range(depth)]
    self.single_blocks = [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)]
    self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

  def __call__(self, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, timesteps:Tensor, y:Tensor, guidance:Optional[Tensor] = None) -> Tensor:
    if img.ndim != 3 or txt.ndim != 3:
      raise ValueError("Input img and txt tensors must have 3 dimensions.")
    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256))
    if self.guidance_embed:
      if guidance is None:
        raise ValueError("Didn't get guidance strength for guidance distilled model.")
      vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)
    ids = Tensor.cat(txt_ids, img_ids, dim=1)
    pe = self.pe_embedder(ids)
    for double_block in self.double_blocks:
      img, txt = double_block(img=img, txt=txt, vec=vec, pe=pe)

    img = Tensor.cat(txt, img, dim=1)
    for single_block in self.single_blocks:
      img = single_block(img, vec=vec, pe=pe)

    img = img[:, txt.shape[1]:, ...]

    return self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py
def load_flow_model(name:str, model_path:str):
  # Loading Flux
  print("Init model")
  model = Flux(guidance_embed=(name != "flux-schnell"))
  if not model_path: model_path = fetch(urls[name])
  state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(model_path).items()}
  load_state_dict(model, state_dict)
  return model

def load_T5(max_length:int=512):
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
  print("Init T5")
  T5 = T5Embedder(max_length, fetch(urls["T5_tokenizer"]))
  pt_1 = fetch(urls["T5_1_of_2"])
  pt_2 = fetch(urls["T5_2_of_2"])
  load_state_dict(T5.encoder, safe_load(pt_1) | safe_load(pt_2), strict=False)
  return T5

def load_clip():
  print("Init Clip")
  clip = ClipEmbedder()
  load_state_dict(clip.transformer, safe_load(fetch(urls["clip"])))
  return clip

def load_ae() -> AutoEncoder:
  # Loading the autoencoder
  print("Init AE")
  ae = AutoEncoder(0.3611, 0.1159)
  load_state_dict(ae, safe_load(fetch(urls["ae"])))
  return ae

# https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py
def prepare(T5:T5Embedder, clip:ClipEmbedder, img:Tensor, prompt:Union[str, List[str]]) -> Dict[str, Tensor]:
  bs, _, h, w = img.shape
  if bs == 1 and not isinstance(prompt, str):
    bs = len(prompt)

  img = img.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
  if img.shape[0] == 1 and bs > 1:
    img = img.expand((bs, *img.shape[1:]))

  img_ids = Tensor.zeros(h // 2, w // 2, 3).contiguous()
  img_ids[..., 1] = img_ids[..., 1] + Tensor.arange(h // 2)[:, None]
  img_ids[..., 2] = img_ids[..., 2] + Tensor.arange(w // 2)[None, :]
  img_ids = img_ids.rearrange("h w c -> 1 (h w) c")
  img_ids = img_ids.expand((bs, *img_ids.shape[1:]))

  if isinstance(prompt, str):
    prompt = [prompt]
  txt = T5(prompt).realize()
  if txt.shape[0] == 1 and bs > 1:
    txt = txt.expand((bs, *txt.shape[1:]))
  txt_ids = Tensor.zeros(bs, txt.shape[1], 3)

  vec = clip(prompt).realize()
  if vec.shape[0] == 1 and bs > 1:
    vec = vec.expand((bs, *vec.shape[1:]))

  return {"img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device)}


def get_schedule(num_steps:int, image_seq_len:int, base_shift:float=0.5, max_shift:float=1.15, shift:bool=True) -> List[float]:
  # extra step for zero
  step_size = -1.0 / num_steps
  timesteps = Tensor.arange(1, 0 + step_size, step_size)

  # shifting the schedule to favor high timesteps for higher signal images
  if shift:
    # estimate mu based on linear estimation between two points
    mu = 0.5 + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256)
    timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
  return timesteps.tolist()

@TinyJit
def run(model, *args): return model(*args).realize()

def denoise(model, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, vec:Tensor, timesteps:List[float], guidance:float=4.0) -> Tensor:
  # this is ignored for schnell
  guidance_vec = Tensor((guidance,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
  for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:])), "Denoising"):
    t_vec = Tensor((t_curr,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
    pred = run(model, img, img_ids, txt, txt_ids, t_vec, vec, guidance_vec)
    img = img + (t_prev - t_curr) * pred

  return img

def unpack(x:Tensor, height:int, width:int) -> Tensor:
  return x.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/cli.py
if __name__ == "__main__":
  default_prompt = "bananas and a can of coke"
  parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load")
  parser.add_argument("--model_path", type=str, default="", help="path of the model file")
  parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling")
  parser.add_argument("--prompt", type=str, default=default_prompt, help="Prompt used for sampling")
  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument("--num_steps", type=int, default=None, help="number of sampling steps (default 4 for schnell, 50 for guidance distilled)") #noqa:E501
  parser.add_argument("--guidance", type=float, default=3.5, help="guidance value used for guidance distillation")
  parser.add_argument("--output_dir", type=str, default="output", help="output directory")
  args = parser.parse_args()

  if args.name not in ["flux-schnell", "flux-dev"]:
    raise ValueError(f"Got unknown model name: {args.name}, choose from flux-schnell and flux-dev")

  if args.num_steps is None:
    args.num_steps = 4 if args.name == "flux-schnell" else 50

  # allow for packing and conversion to latent space
  height = 16 * (args.height // 16)
  width = 16 * (args.width // 16)

  if args.seed is None: args.seed = Tensor._seed
  else: Tensor.manual_seed(args.seed)

  print(f"Generating with seed {args.seed}:\n{args.prompt}")
  t0 = time.perf_counter()

  # prepare input noise
  x = Tensor.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype="bfloat16")

  # load text embedders
  T5 = load_T5(max_length=256 if args.name == "flux-schnell" else 512)
  clip = load_clip()

  # embed text to get inputs for model
  inp = prepare(T5, clip, x, prompt=args.prompt)
  timesteps = get_schedule(args.num_steps, inp["img"].shape[1], shift=(args.name != "flux-schnell"))

  # done with text embedders
  del T5, clip

  # load model
  model = load_flow_model(args.name, args.model_path)

  # denoise initial noise
  x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)

  # done with model
  del model, run

  # load autoencoder
  ae = load_ae()

  # decode latents to pixel space
  x = unpack(x.float(), height, width)
  x = ae.decode(x).realize()

  t1 = time.perf_counter()
  print(f"Done in {t1 - t0:.1f}s. Saving {args.out}")

  # bring into PIL format and save
  x = x.clamp(-1, 1)
  x = x[0].rearrange("c h w -> h w c")
  x = (127.5 * (x + 1.0)).cast("uint8")

  img = Image.fromarray(x.numpy())

  img.save(args.out)

  # validation!
  if args.prompt == default_prompt and args.name=="flux-schnell" and args.seed == 0 and args.width == args.height == 512:
    ref_image = Tensor(np.array(Image.open("examples/flux1_seed0.png")))
    distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
    assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
    print(colored(f"output validated with {distance=}", "green"))
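The shift in get_schedule maps each uniform timestep t through exp(mu) / (exp(mu) + (1/t - 1)), which keeps the endpoints pinned at 1 and 0 while skewing interior steps toward 1 for longer image sequences. A standalone numpy check mirroring the formula above (values are illustrative):

# numpy sketch of the get_schedule shift
import math
import numpy as np
num_steps, image_seq_len = 4, 1024
base_shift, max_shift = 0.5, 1.15
t = np.linspace(1.0, 0.0, num_steps + 1)   # extra step for zero
mu = 0.5 + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256)
with np.errstate(divide="ignore"):         # 1/t is inf at t=0, which maps cleanly to 0
  shifted = math.exp(mu) / (math.exp(mu) + (1 / t - 1))
print(shifted)                             # endpoints stay 1 and 0, interior skews toward 1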
@@ -1,299 +0,0 @@
from extra.models.mask_rcnn import MaskRCNN
from extra.models.resnet import ResNet
from extra.models.mask_rcnn import BoxList
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import functional as Ft
import random
from tinygrad.tensor import Tensor
from PIL import Image
import numpy as np
import torch
import argparse
import cv2


class Resize:
  def __init__(self, min_size, max_size):
    if not isinstance(min_size, (list, tuple)):
      min_size = (min_size,)
    self.min_size = min_size
    self.max_size = max_size

  # modified from torchvision to add support for max size
  def get_size(self, image_size):
    w, h = image_size
    size = random.choice(self.min_size)
    max_size = self.max_size
    if max_size is not None:
      min_original_size = float(min((w, h)))
      max_original_size = float(max((w, h)))
      if max_original_size / min_original_size * size > max_size:
        size = int(round(max_size * min_original_size / max_original_size))

    if (w <= h and w == size) or (h <= w and h == size):
      return (h, w)

    if w < h:
      ow = size
      oh = int(size * h / w)
    else:
      oh = size
      ow = int(size * w / h)

    return (oh, ow)

  def __call__(self, image):
    size = self.get_size(image.size)
    image = Ft.resize(image, size)
    return image


class Normalize:
  def __init__(self, mean, std, to_bgr255=True):
    self.mean = mean
    self.std = std
    self.to_bgr255 = to_bgr255

  def __call__(self, image):
    if self.to_bgr255:
      image = image[[2, 1, 0]] * 255
    else:
      image = image[[0, 1, 2]] * 255
    image = Ft.normalize(image, mean=self.mean, std=self.std)
    return image

transforms = lambda size_scale: T.Compose(
  [
    Resize(int(800*size_scale), int(1333*size_scale)),
    T.ToTensor(),
    Normalize(
      mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
    ),
  ]
)

def expand_boxes(boxes, scale):
  w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  y_c = (boxes[:, 3] + boxes[:, 1]) * .5

  w_half *= scale
  h_half *= scale

  boxes_exp = torch.zeros_like(boxes)
  boxes_exp[:, 0] = x_c - w_half
  boxes_exp[:, 2] = x_c + w_half
  boxes_exp[:, 1] = y_c - h_half
  boxes_exp[:, 3] = y_c + h_half
  return boxes_exp


def expand_masks(mask, padding):
  N = mask.shape[0]
  M = mask.shape[-1]
  pad2 = 2 * padding
  scale = float(M + pad2) / M
  padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
  padded_mask[:, :, padding:-padding, padding:-padding] = mask
  return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
  # TODO: remove torch
  mask = torch.tensor(mask.numpy())
  box = torch.tensor(box.numpy())
  padded_mask, scale = expand_masks(mask[None], padding=padding)
  mask = padded_mask[0, 0]
  box = expand_boxes(box[None], scale)[0]
  box = box.to(dtype=torch.int32)

  TO_REMOVE = 1
  w = int(box[2] - box[0] + TO_REMOVE)
  h = int(box[3] - box[1] + TO_REMOVE)
  w = max(w, 1)
  h = max(h, 1)

  mask = mask.expand((1, 1, -1, -1))

  mask = mask.to(torch.float32)
  mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
  mask = mask[0][0]

  if thresh >= 0:
    mask = mask > thresh
  else:
    mask = (mask * 255).to(torch.uint8)

  im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
  x_0 = max(box[0], 0)
  x_1 = min(box[2] + 1, im_w)
  y_0 = max(box[1], 0)
  y_1 = min(box[3] + 1, im_h)

  im_mask[y_0:y_1, x_0:x_1] = mask[
    (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
  ]
  return im_mask


class Masker:
  def __init__(self, threshold=0.5, padding=1):
    self.threshold = threshold
    self.padding = padding

  def forward_single_image(self, masks, boxes):
    boxes = boxes.convert("xyxy")
    im_w, im_h = boxes.size
    res = [
      paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
      for mask, box in zip(masks, boxes.bbox)
    ]
    if len(res) > 0:
      res = torch.stack(res, dim=0)[:, None]
    else:
      res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
    return Tensor(res.numpy())

  def __call__(self, masks, boxes):
    if isinstance(boxes, BoxList):
      boxes = [boxes]

    results = []
    for mask, box in zip(masks, boxes):
      result = self.forward_single_image(mask, box)
      results.append(result)
    return results


masker = Masker(threshold=0.5, padding=1)

def select_top_predictions(predictions, confidence_threshold=0.9):
  scores = predictions.get_field("scores").numpy()
  keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
  return predictions[keep]

def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
  image = transforms(size_scale)(original_image).numpy()
  image = Tensor(image, requires_grad=False)
  predictions = model(image)
  prediction = predictions[0]
  prediction = select_top_predictions(prediction, confidence_threshold)
  width, height = original_image.size
  prediction = prediction.resize((width, height))

  if prediction.has_field("mask"):
    masks = prediction.get_field("mask")
    masks = masker([masks], [prediction])[0]
    prediction.add_field("mask", masks)
  return prediction

def compute_prediction_batched(batch, model, size_scale=1.0):
  imgs = []
  for img in batch:
    imgs.append(transforms(size_scale)(img).numpy())
  image = [Tensor(image, requires_grad=False) for image in imgs]
  predictions = model(image)
  del image
  return predictions

palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])

def findContours(*args, **kwargs):
  if cv2.__version__.startswith('4'):
    contours, hierarchy = cv2.findContours(*args, **kwargs)
  elif cv2.__version__.startswith('3'):
    _, contours, hierarchy = cv2.findContours(*args, **kwargs)
  return contours, hierarchy

def compute_colors_for_labels(labels):
  l = labels[:, None]
  colors = l * palette
  colors = (colors % 255).astype("uint8")
  return colors

def overlay_mask(image, predictions):
  image = np.asarray(image)
  masks = predictions.get_field("mask").numpy()
  labels = predictions.get_field("labels").numpy()

  colors = compute_colors_for_labels(labels).tolist()

  for mask, color in zip(masks, colors):
    thresh = mask[0, :, :, None]
    contours, hierarchy = findContours(
      thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    image = cv2.drawContours(image, contours, -1, color, 3)

  composite = image

  return composite

CATEGORIES = [
  "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
  "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
  "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
  "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
  "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
  "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
  "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

def overlay_boxes(image, predictions):
  labels = predictions.get_field("labels").numpy()
  boxes = predictions.bbox
  image = np.asarray(image)
  colors = compute_colors_for_labels(labels).tolist()

  for box, color in zip(boxes, colors):
    box = torch.tensor(box.numpy())
    box = box.to(torch.int64)
    top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
    image = cv2.rectangle(
      image, tuple(top_left), tuple(bottom_right), tuple(color), 1
    )

  return image

def overlay_class_names(image, predictions):
  scores = predictions.get_field("scores").numpy().tolist()
  labels = predictions.get_field("labels").numpy().tolist()
  labels = [CATEGORIES[int(i)] for i in labels]
  boxes = predictions.bbox.numpy()
  image = np.asarray(image)
  template = "{}: {:.2f}"
  for box, score, label in zip(boxes, scores, labels):
    x, y = box[:2]
    s = template.format(label, score)
    x, y = int(x), int(y)
    cv2.putText(
      image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
    )

  return image


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--image', type=str, help="Path of the image to run")
  parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
  parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
  parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
  args = parser.parse_args()

  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  model_tiny = MaskRCNN(resnet)
  model_tiny.load_from_pretrained()
  img = Image.open(args.image)
  top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
  bbox_image = overlay_boxes(img, top_result_tiny)
  mask_image = overlay_mask(bbox_image, top_result_tiny)
  final_image = overlay_class_names(mask_image, top_result_tiny)

  im = Image.fromarray(final_image)
  print(f"saving {args.out}")
  im.save(args.out)
  im.show()
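expand_boxes simply grows each xyxy box about its center by scale, which is what lets the padded mask from expand_masks be pasted back at the right place. The same arithmetic as a standalone numpy sketch:

# numpy sketch of expand_boxes: grow each xyxy box about its center by `scale`
import numpy as np
def expand_boxes_np(boxes, scale):
  w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 * scale
  h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 * scale
  x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  return np.stack([x_c - w_half, y_c - h_half, x_c + w_half, y_c + h_half], axis=1)

print(expand_boxes_np(np.array([[10., 10., 20., 30.]]), 1.2))  # [[ 9.  8. 21. 32.]]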
@@ -1,118 +0,0 @@
import json, pprint
from tinygrad import fetch, nn, Tensor
from tinygrad.helpers import DEBUG

class FeedForward:
  def __init__(self, model_dim, intermediate_dim):
    self.proj_1 = nn.Linear(model_dim, 2*intermediate_dim, bias=False)
    self.proj_2 = nn.Linear(intermediate_dim, model_dim, bias=False)

  def __call__(self, x):
    y_12 = self.proj_1(x)
    y_1, y_2 = y_12.chunk(2, dim=-1)
    return self.proj_2(y_1.silu() * y_2)

# NOTE: this RoPE doesn't match LLaMA's?
def _rotate_half(x: Tensor) -> Tensor:
  x1, x2 = x.chunk(2, dim=-1)
  return Tensor.cat(-x2, x1, dim=-1)

def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
  return (x * pos_cos) + (_rotate_half(x) * pos_sin)

class Attention:
  def __init__(self, model_dim, num_query_heads, num_kv_heads, head_dim):
    self.qkv_proj = nn.Linear(model_dim, (num_query_heads + num_kv_heads*2) * head_dim, bias=False)
    self.num_query_heads, self.num_kv_heads = num_query_heads, num_kv_heads
    self.head_dim = head_dim
    self.q_norm = nn.RMSNorm(head_dim)
    self.k_norm = nn.RMSNorm(head_dim)
    self.out_proj = nn.Linear(num_query_heads * head_dim, model_dim, bias=False)

  def __call__(self, x:Tensor) -> Tensor:
    batch_size, seq_len, embed_dim = x.shape
    qkv = self.qkv_proj(x)
    qkv = qkv.reshape(batch_size, seq_len, self.num_query_heads+self.num_kv_heads*2, self.head_dim).transpose(1, 2)
    xq,xk,xv = qkv.split([self.num_query_heads, self.num_kv_heads, self.num_kv_heads], dim=1)
    xq = self.q_norm(xq)
    xk = self.k_norm(xk)

    # add positional embedding (how many kernels is this?)
    freq_constant = 10000
    inv_freq = 1.0 / (freq_constant ** (Tensor.arange(0, self.head_dim, 2) / self.head_dim))
    pos_index_theta = Tensor.einsum("i,j->ij", Tensor.arange(seq_len), inv_freq)
    emb = Tensor.cat(pos_index_theta, pos_index_theta, dim=-1)
    cos_emb, sin_emb = emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
    xq = _apply_rotary_pos_emb(xq, sin_emb, cos_emb)
    xk = _apply_rotary_pos_emb(xk, sin_emb, cos_emb)

    # grouped-query attention
    num_groups = self.num_query_heads // self.num_kv_heads
    xk = xk.repeat_interleave(num_groups, dim=1)
    xv = xv.repeat_interleave(num_groups, dim=1)

    # masked attention
    #start_pos = 0
    #mask = Tensor.full((1, 1, seq_len, start_pos+seq_len), float("-inf"), dtype=xq.dtype, device=xq.device).triu(start_pos+1)
    #attn_output = xq.scaled_dot_product_attention(xk, xv, mask).transpose(1, 2)

    # causal is fine, no mask needed
    attn_output = xq.scaled_dot_product_attention(xk, xv, is_causal=True).transpose(1, 2)
    return self.out_proj(attn_output.reshape(batch_size, seq_len, self.num_query_heads * self.head_dim))

class Layer:
  def __init__(self, model_dim, intermediate_dim, num_query_heads, num_kv_heads, head_dim):
    self.ffn = FeedForward(model_dim, intermediate_dim)
    self.attn = Attention(model_dim, num_query_heads, num_kv_heads, head_dim)
    self.ffn_norm = nn.RMSNorm(model_dim)
    self.attn_norm = nn.RMSNorm(model_dim)

  def __call__(self, x:Tensor) -> Tensor:  # (batch, seq_len, embed_dim)
    x = x + self.attn(self.attn_norm(x))
    x = x + self.ffn(self.ffn_norm(x))
    return x

# stupidly complex
def make_divisible(v, divisor):
  new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
  if new_v < 0.9 * v: new_v += divisor
  return new_v

class Transformer:
  def __init__(self, cfg):
    if DEBUG >= 3: pprint.pp(cfg)
    self.layers = [Layer(cfg['model_dim'], make_divisible(int(cfg["model_dim"] * cfg['ffn_multipliers'][i]), cfg['ffn_dim_divisor']),
                         cfg['num_query_heads'][i], cfg['num_kv_heads'][i], cfg['head_dim']) for i in range(cfg['num_transformer_layers'])]
    self.norm = nn.RMSNorm(cfg['model_dim'])
    self.token_embeddings = nn.Embedding(cfg['vocab_size'], cfg['model_dim'])

  def __call__(self, tokens:Tensor):
    # _bsz, seqlen = tokens.shape
    x = self.token_embeddings(tokens)
    for l in self.layers: x = l(x)
    return self.norm(x) @ self.token_embeddings.weight.T

if __name__ == "__main__":
  #model_name = "OpenELM-270M-Instruct"
  model_name = "OpenELM-270M"  # this is fp32
  model = Transformer(json.loads(fetch(f"https://huggingface.co/apple/{model_name}/resolve/main/config.json?download=true").read_bytes()))
  weights = nn.state.safe_load(fetch(f"https://huggingface.co/apple/{model_name}/resolve/main/model.safetensors?download=true"))
  if DEBUG >= 3:
    for k, v in weights.items(): print(k, v.shape)
  nn.state.load_state_dict(model, {k.removeprefix("transformer."):v for k,v in weights.items()})

  from sentencepiece import SentencePieceProcessor
  tokenizer = SentencePieceProcessor(fetch("https://github.com/karpathy/llama2.c/raw/master/tokenizer.model").as_posix())
  toks = [tokenizer.bos_id()] + tokenizer.encode("Some car brands include")
  for i in range(100):
    ttoks = Tensor([toks])
    out = model(ttoks).realize()
    t0 = out[0].argmax(axis=-1).tolist()
    toks.append(t0[-1])
    # hmmm...passthrough still doesn't match (it shouldn't, it outputs the most likely)
    print(tokenizer.decode(toks))
    #print(toks)
    #print(tokenizer.decode(t0))
    #print(t0)
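make_divisible rounds a channel count to the nearest multiple of divisor, then bumps up by one divisor whenever rounding down would lose more than 10% of the requested width. A quick self-contained check of the same function:

def make_divisible(v, divisor):
  new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
  if new_v < 0.9 * v: new_v += divisor
  return new_v

assert make_divisible(100, 32) == 96   # nearest multiple of 32
assert make_divisible(120, 32) == 128  # rounds up past the midpoint
assert make_divisible(55, 16) == 64    # 48 would lose >10%, so bump by one divisor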
@@ -1,55 +0,0 @@
from tinygrad.helpers import trange
from tinygrad.nn.datasets import mnist
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial

class Model(nn.Module):
  def __init__(self):
    super().__init__()
    self.c1 = nn.Conv2d(1, 32, 5)
    self.c2 = nn.Conv2d(32, 32, 5)
    self.bn1 = nn.BatchNorm(32)
    self.m1 = nn.MaxPool2d(2)
    self.c3 = nn.Conv2d(32, 64, 3)
    self.c4 = nn.Conv2d(64, 64, 3)
    self.bn2 = nn.BatchNorm(64)
    self.m2 = nn.MaxPool2d(2)
    self.lin = nn.Linear(576, 10)
  def __call__(self, x):
    x = mx.maximum(self.c1(x), 0)
    x = mx.maximum(self.c2(x), 0)
    x = self.m1(self.bn1(x))
    x = mx.maximum(self.c3(x), 0)
    x = mx.maximum(self.c4(x), 0)
    x = self.m2(self.bn2(x))
    return self.lin(mx.flatten(x, 1))

if __name__ == "__main__":
  X_train, Y_train, X_test, Y_test = mnist()
  X_train = mx.array(X_train.float().permute((0,2,3,1)).numpy())
  Y_train = mx.array(Y_train.numpy())
  X_test = mx.array(X_test.float().permute((0,2,3,1)).numpy())
  Y_test = mx.array(Y_test.numpy())

  model = Model()
  optimizer = optim.Adam(1e-3)
  def loss_fn(model, x, y): return nn.losses.cross_entropy(model(x), y).mean()

  state = [model.state, optimizer.state]
  @partial(mx.compile, inputs=state, outputs=state)
  def step(samples):
    # Compiled functions will also treat any inputs not in the parameter list as constants.
    X,Y = X_train[samples], Y_train[samples]
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, X, Y)
    optimizer.update(model, grads)
    return loss

  test_acc = float('nan')
  for i in (t:=trange(70)):
    samples = mx.random.randint(0, X_train.shape[0], (512,))  # putting this in JIT didn't work well
    loss = step(samples)
    if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()
    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
@@ -1,45 +0,0 @@
import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register

# a very simple game
# one of <size> lights will light up
# take the action of the lit up light
# in <hard_mode>, you act differently based on the step number and need to track this

class PressTheLightUpButton(gym.Env):
  metadata = {"render_modes": []}
  def __init__(self, render_mode=None, size=2, game_length=10, hard_mode=False):
    self.size, self.game_length = size, game_length
    self.observation_space = gym.spaces.Box(0, 1, shape=(self.size,), dtype=np.float32)
    self.action_space = gym.spaces.Discrete(self.size)
    self.step_num = 0
    self.done = True
    self.hard_mode = hard_mode

  def _get_obs(self):
    obs = [0]*self.size
    if self.step_num < len(self.state):
      obs[self.state[self.step_num]] = 1
    return np.array(obs, dtype=np.float32)

  def reset(self, seed=None, options=None):
    super().reset(seed=seed)
    self.state = np.random.randint(0, self.size, size=self.game_length)
    self.step_num = 0
    self.done = False
    return self._get_obs(), {}

  def step(self, action):
    target = ((action + self.step_num) % self.size) if self.hard_mode else action
    reward = int(target == self.state[self.step_num])
    self.step_num += 1
    if not reward:
      self.done = True
    return self._get_obs(), reward, self.done, self.step_num >= self.game_length, {}

register(
  id="PressTheLightUpButton-v0",
  entry_point="examples.rl.lightupbutton:PressTheLightUpButton",
  max_episode_steps=None,
)
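A short sketch of driving the registered env with a random policy (assumes gymnasium is installed and, per the entry_point above, that the module is importable as examples.rl.lightupbutton from the repo root):

import gymnasium as gym
import examples.rl.lightupbutton  # noqa: F401  # importing runs the register() call above

env = gym.make("PressTheLightUpButton-v0", size=2, game_length=10)
obs, info = env.reset()
total = 0
while True:
  # random actions: expected score is low since a wrong press ends the episode
  obs, reward, done, truncated, info = env.step(env.action_space.sample())
  total += reward
  if done or truncated: break
print("score:", total)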
@@ -1,136 +0,0 @@
#!/usr/bin/env python
# inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
import sys
import numpy as np
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d, optim
from tinygrad.helpers import getenv
from extra.datasets import fetch_mnist
from extra.augment import augment_img
from extra.training import train, evaluate
GPU = getenv("GPU")
QUICK = getenv("QUICK")
DEBUG = getenv("DEBUG")

class SqueezeExciteBlock2D:
  def __init__(self, filters):
    self.filters = filters
    self.weight1 = Tensor.scaled_uniform(self.filters, self.filters//32)
    self.bias1 = Tensor.scaled_uniform(1,self.filters//32)
    self.weight2 = Tensor.scaled_uniform(self.filters//32, self.filters)
    self.bias2 = Tensor.scaled_uniform(1, self.filters)

  def __call__(self, input):
    se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3]))  # GlobalAveragePool2D
    se = se.reshape(shape=(-1, self.filters))
    se = se.dot(self.weight1) + self.bias1
    se = se.relu()
    se = se.dot(self.weight2) + self.bias2
    se = se.sigmoid().reshape(shape=(-1,self.filters,1,1))  # for broadcasting
    se = input.mul(se)
    return se

class ConvBlock:
  def __init__(self, h, w, inp, filters=128, conv=3):
    self.h, self.w = h, w
    self.inp = inp
    # init weights
    self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
    self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
    # init layers
    self._bn = BatchNorm2d(128)
    self._seb = SqueezeExciteBlock2D(filters)

  def __call__(self, input):
    x = input.reshape(shape=(-1, self.inp, self.w, self.h))
    for cweight, cbias in zip(self.cweights, self.cbiases):
      x = x.pad(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu()
    x = self._bn(x)
    x = self._seb(x)
    return x

class BigConvNet:
  def __init__(self):
    self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
    self.weight1 = Tensor.scaled_uniform(128,10)
    self.weight2 = Tensor.scaled_uniform(128,10)

  def parameters(self):
    if DEBUG:  # keeping this for a moment
      pars = [par for par in get_parameters(self) if par.requires_grad]
      no_pars = 0
      for par in pars:
        print(par.shape)
        no_pars += np.prod(par.shape)
      print('no of parameters', no_pars)
      return pars
    else:
      return get_parameters(self)

  def save(self, filename):
    with open(filename+'.npy', 'wb') as f:
      for par in get_parameters(self):
        #if par.requires_grad:
        np.save(f, par.numpy())

  def load(self, filename):
    with open(filename+'.npy', 'rb') as f:
      for par in get_parameters(self):
        #if par.requires_grad:
        try:
          par.numpy()[:] = np.load(f)
          if GPU:
            par.gpu()
        except:
          print('Could not load parameter')

  def forward(self, x):
    x = self.conv[0](x)
    x = self.conv[1](x)
    x = x.avg_pool2d(kernel_size=(2,2))
    x = self.conv[2](x)
    x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128))  # global
    x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128))  # global
    xo = x1.dot(self.weight1) + x2.dot(self.weight2)
    return xo


if __name__ == "__main__":
  lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
  epochss = [2, 1] if QUICK else [13, 3, 3, 1]
  BS = 32

  lmbd = 0.00025
  lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
  X_train, Y_train, X_test, Y_test = fetch_mnist()
  X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
  X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
  steps = len(X_train)//BS
  np.random.seed(1337)
  if QUICK:
    steps = 1
    X_test, Y_test = X_test[:BS], Y_test[:BS]

  model = BigConvNet()

  if len(sys.argv) > 1:
    try:
      model.load(sys.argv[1])
      print('Loaded weights "'+sys.argv[1]+'", evaluating...')
      evaluate(model, X_test, Y_test, BS=BS)
    except:
      print('could not load weights "'+sys.argv[1]+'".')

  if GPU:
    params = get_parameters(model)
    [x.gpu_() for x in params]

  for lr, epochs in zip(lrs, epochss):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(1,epochs+1):
      # first epoch without augmentation
      X_aug = X_train if epoch == 1 else augment_img(X_train)
      train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
      accuracy = evaluate(model, X_test, Y_test, BS=BS)
      model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')
@@ -1,17 +0,0 @@
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.state import get_parameters

if __name__ == "__main__":
  with Tensor.train():

    BS, C1, H, W = 4, 16, 224, 224
    C2, K, S, P = 64, 7, 2, 1

    x = Tensor.uniform(BS, C1, H, W)
    conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
    bn = BatchNorm2d(C2, track_running_stats=False)
    for t in get_parameters([x, conv, bn]): t.realize()

    print("running network")
    x.sequential([conv, bn]).numpy()
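
# Output-shape check for the benchmark above: with H = W = 224, kernel 7,
# stride 2, padding 1, the conv produces floor((224 + 2*1 - 7)/2) + 1 = 110,
# so x.sequential([conv, bn]) yields a (4, 64, 110, 110) tensor.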
@@ -1,669 +0,0 @@
# original implementation: https://github.com/svc-develop-team/so-vits-svc
from __future__ import annotations
import sys, logging, time, io, math, argparse, operator, numpy as np
from functools import partial, reduce
from pathlib import Path
from typing import Tuple, Optional, Type
from tinygrad import nn, dtypes, Tensor
from tinygrad.helpers import getenv, fetch
from tinygrad.nn.state import torch_load
from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, get_hparams_from_file, load_checkpoint, weight_norm, HParams
from examples.sovits_helpers import preprocess
import soundfile

DEBUG = getenv("DEBUG")

F0_BIN = 256
F0_MAX = 1100.0
F0_MIN = 50.0
F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700)
F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700)

class SpeechEncoder:
  def __init__(self, hidden_dim, model:ContentVec): self.hidden_dim, self.model = hidden_dim, model
  def encode(self): raise NotImplementedError("implement me")
  @classmethod
  def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> SpeechEncoder:
    contentvec = ContentVec.load_from_pretrained(checkpoint_path, checkpoint_url)
    return cls(contentvec)  # only called on the subclasses below, whose __init__ takes just the model

class ContentVec256L9(SpeechEncoder):
  def __init__(self, model:ContentVec): super().__init__(hidden_dim=256, model=model)
  def encode(self, wav: Tensor):
    feats = wav
    if len(feats.shape) == 2:  # double channels
      feats = feats.mean(-1)
    assert len(feats.shape) == 1, feats.dim()
    feats = feats.reshape(1, -1)
    padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool)
    logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=9)
    feats = self.model.final_proj(logits[0])
    return feats.transpose(1, 2)

class ContentVec768L12(SpeechEncoder):
  def __init__(self, model:ContentVec): super().__init__(hidden_dim=768, model=model)
  def encode(self, wav: Tensor):
    feats = wav
    if len(feats.shape) == 2:  # double channels
      feats = feats.mean(-1)
    assert len(feats.shape) == 1, feats.dim()
    feats = feats.reshape(1, -1)
    padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool)
    logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=12)
    return logits[0].transpose(1, 2)

# original code for contentvec: https://github.com/auspicious3000/contentvec/
class ContentVec:
  # self.final_proj dims are hardcoded and depend on fairseq.data.dictionary Dictionary in the checkpoint. This param can't yet be loaded since there is no pickle for it. See with DEBUG=2.
  # This means that the ContentVec only works with the hubert weights used in all SVC models
  def __init__(self, cfg: HParams):
    self.feature_grad_mult, self.untie_final_proj = cfg.feature_grad_mult, cfg.untie_final_proj
    feature_enc_layers = eval(cfg.conv_feature_layers)
    self.embed = feature_enc_layers[-1][0]
    final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
    self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
    self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
    self.encoder = TransformerEncoder(cfg)
    self.layer_norm = nn.LayerNorm(self.embed)
    self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim * 1) if self.untie_final_proj else nn.Linear(cfg.encoder_embed_dim, final_dim)
    self.mask_emb = Tensor.uniform(cfg.encoder_embed_dim, dtype=dtypes.float32)
    self.label_embs_concat = Tensor.uniform(504, final_dim, dtype=dtypes.float32)
  def forward_features(self, source, padding_mask):
    if self.feature_grad_mult > 0:
      features = self.feature_extractor(source, padding_mask)
      if self.feature_grad_mult != 1.0: pass  # training: GradMultiply.forward(features, self.feature_grad_mult)
    else:
      features = self.feature_extractor(source, padding_mask)
    return features
  def forward_padding_mask(self, features, padding_mask):  # replaces original forward_padding_mask for batch inference
    lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1)  # ensure it's bool for tilde
    lengths = (lengths_org - 400).float().div(320).floor().cast(dtypes.int64) + 1  # intermediate float to divide
    padding_mask = lengths_to_padding_mask(lengths)
    return padding_mask
  def extract_features(self, source: Tensor, spk_emb:Tensor=None, padding_mask=None, ret_conv=False, output_layer=None, tap=False):
    features = self.forward_features(source, padding_mask)
    if padding_mask is not None:
      padding_mask = self.forward_padding_mask(features, padding_mask)
    features = features.transpose(1, 2)
    features = self.layer_norm(features)
    if self.post_extract_proj is not None:
      features = self.post_extract_proj(features)
    x, _ = self.encoder(features, spk_emb, padding_mask=padding_mask, layer=(None if output_layer is None else output_layer - 1), tap=tap)
    res = features if ret_conv else x
    return res, padding_mask
  @classmethod
  def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec:
    fetch(checkpoint_url, checkpoint_path)
    cfg = load_fairseq_cfg(checkpoint_path)
    enc = cls(cfg.model)
    _ = load_checkpoint_enc(checkpoint_path, enc, None)
    logging.debug(f"{cls.__name__}: Loaded model with cfg={cfg}")
    return enc

class TransformerEncoder:
  def __init__(self, cfg: HParams):
    def make_conv() -> nn.Conv1d:
      layer = nn.Conv1d(self.embedding_dim, self.embedding_dim, kernel_size=cfg.conv_pos, padding=cfg.conv_pos // 2, groups=cfg.conv_pos_groups)
      std = math.sqrt(4 / (cfg.conv_pos * self.embedding_dim))  # was a duplicated "std = std = ..." assignment
      layer.weight, layer.bias = Tensor.normal(*layer.weight.shape, std=std), Tensor.zeros(*layer.bias.shape)
      # for training: layer.weights need to be weight_normed
      return layer
    self.dropout, self.embedding_dim, self.layer_norm_first, self.layerdrop, self.num_layers, self.num_layers_1 = cfg.dropout, cfg.encoder_embed_dim, cfg.layer_norm_first, cfg.encoder_layerdrop, cfg.encoder_layers, cfg.encoder_layers_1
    self.pos_conv, self.pos_conv_remove = [make_conv()], (1 if cfg.conv_pos % 2 == 0 else 0)
    self.layers = [
      TransformerEncoderLayer(self.embedding_dim, cfg.encoder_ffn_embed_dim, cfg.encoder_attention_heads, self.dropout, cfg.attention_dropout, cfg.activation_dropout, cfg.activation_fn, self.layer_norm_first, cond_layer_norm=(i >= cfg.encoder_layers))
      for i in range(cfg.encoder_layers + cfg.encoder_layers_1)
    ]
    self.layer_norm = nn.LayerNorm(self.embedding_dim)
    self.cond_layer_norm = CondLayerNorm(self.embedding_dim) if cfg.encoder_layers_1 > 0 else None
    # training: apply init_bert_params
  def __call__(self, x, spk_emb, padding_mask=None, layer=None, tap=False):
    x, layer_results = self.extract_features(x, spk_emb, padding_mask, layer, tap)
    if self.layer_norm_first and layer is None:
      x = self.cond_layer_norm(x, spk_emb) if (self.num_layers_1 > 0) else self.layer_norm(x)
    return x, layer_results
  def extract_features(self, x: Tensor, spk_emb: Tensor, padding_mask=None, tgt_layer=None, tap=False):
    if tgt_layer is not None:  # and not self.training
      assert tgt_layer >= 0 and tgt_layer < len(self.layers)
    if padding_mask is not None:
      # x[padding_mask] = 0
      assert padding_mask.shape == x.shape[:len(padding_mask.shape)]  # first few dims of x must match padding_mask
      tmp_mask = padding_mask.unsqueeze(-1).repeat((1, 1, x.shape[-1]))
      tmp_mask = tilde(tmp_mask.cast(dtypes.bool))
      x = tmp_mask.where(x, 0)
    x_conv = self.pos_conv[0](x.transpose(1, 2))
    if self.pos_conv_remove > 0: x_conv = x_conv[:, :, :-self.pos_conv_remove]
    x_conv = x_conv.gelu().transpose(1, 2)
    x = (x + x_conv).transpose(0, 1)  # B x T x C -> T x B x C
    if not self.layer_norm_first: x = self.layer_norm(x)
    x = x.dropout(p=self.dropout)
    layer_results = []
    r = None
    for i, layer in enumerate(self.layers):
      if i < self.num_layers:  # if (not self.training or (dropout_probability > self.layerdrop)) and (i < self.num_layers):
        assert layer.cond_layer_norm == False
        x = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
      if tgt_layer is not None or tap:
        layer_results.append(x.transpose(0, 1))
      if i >= self.num_layers:
        assert layer.cond_layer_norm == True
        x = layer(x, emb=spk_emb, self_attn_padding_mask=padding_mask, need_weights=False)
      if i == tgt_layer:
        r = x
        break
    if r is not None:
      x = r
    x = x.transpose(0, 1)  # T x B x C -> B x T x C
    return x, layer_results

class TransformerEncoderLayer:
  def __init__(self, embedding_dim=768, ffn_embedding_dim=3072, num_attention_heads=8, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1, activation_fn="relu", layer_norm_first=False, cond_layer_norm=False):
    # defaults were floats (768.0 etc.); layer dimensions must be ints
    def get_activation_fn(activation):
      if activation == "relu": return Tensor.relu
      if activation == "gelu": return Tensor.gelu
      else: raise RuntimeError(f"activation function={activation} is not foreseen")
    self.embedding_dim, self.dropout, self.activation_dropout, self.layer_norm_first, self.num_attention_heads, self.cond_layer_norm, self.activation_fn = embedding_dim, dropout, activation_dropout, layer_norm_first, num_attention_heads, cond_layer_norm, get_activation_fn(activation_fn)
    self.self_attn = MultiHeadAttention(self.embedding_dim, self.num_attention_heads)
    self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim)
    self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
    self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
    self.final_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim)
  def __call__(self, x:Tensor, self_attn_mask:Tensor=None, self_attn_padding_mask:Tensor=None, emb:Tensor=None, need_weights=False):
    #self_attn_padding_mask = self_attn_padding_mask.reshape(x.shape[0], 1, 1, self_attn_padding_mask.shape[1]).expand(-1, self.num_attention_heads, -1, -1).reshape(x.shape[0] * self.num_attention_heads, 1, self_attn_padding_mask.shape[1]) if self_attn_padding_mask is not None else None
    assert self_attn_mask is None and self_attn_padding_mask is not None
    residual = x
    if self.layer_norm_first:
      x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb)
      x = self.self_attn(x=x, mask=self_attn_padding_mask)
      x = x.dropout(self.dropout)
      x = residual + x
      residual = x  # reset the residual before the FFN, as the upstream fairseq/contentvec layer does; this line was dropped here
      x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb)
      x = self.activation_fn(self.fc1(x))
      x = x.dropout(self.activation_dropout)
      x = self.fc2(x)
      x = x.dropout(self.dropout)
      x = residual + x
    else:
      x = self.self_attn(x=x, mask=self_attn_padding_mask)
      x = x.dropout(self.dropout)
      x = residual + x
      x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb)
      residual = x
      x = self.activation_fn(self.fc1(x))
      x = x.dropout(self.activation_dropout)
      x = self.fc2(x)
      x = x.dropout(self.dropout)
      x = residual + x
      x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb)
    return x

class MultiHeadAttention:
  def __init__(self, n_state, n_head):
    self.n_state, self.n_head = n_state, n_head
    self.q_proj, self.k_proj, self.v_proj, self.out_proj = [nn.Linear(n_state, n_state) for _ in range(4)]
  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None):
    x = x.transpose(0, 1)  # TxBxC -> BxTxC
    q, k, v = self.q_proj(x), self.k_proj(xa or x), self.v_proj(xa or x)  # xa is always None here, so "xa or x" resolves to x
    q, k, v = [x.reshape(*q.shape[:2], self.n_head, -1) for x in (q, k, v)]
    wv = Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), None).transpose(1, 2).reshape(*x.shape[:2], -1)
    ret = self.out_proj(wv).transpose(0, 1)  # BxTxC -> TxBxC
    return ret

class ConvFeatureExtractionModel:
  def __init__(self, conv_layers, dropout=.0, mode="default", conv_bias=False):
    assert mode in {"default", "group_norm_masked", "layer_norm"}
    def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
      def make_conv():
        conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
        conv.weight = Tensor.kaiming_normal(*conv.weight.shape)
        return conv
      assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
      if is_layer_norm:
        return [make_conv(), partial(Tensor.dropout, p=dropout), [partial(Tensor.transpose, dim0=-2, dim1=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, dim0=-2, dim1=-1)], Tensor.gelu]
      elif is_group_norm and mode == "default":
        return [make_conv(), partial(Tensor.dropout, p=dropout), nn.GroupNorm(dim, dim, affine=True), Tensor.gelu]
      elif is_group_norm and mode == "group_norm_masked":
        return [make_conv(), partial(Tensor.dropout, p=dropout), GroupNormMasked(dim, dim, affine=True), Tensor.gelu]
      else:
        return [make_conv(), partial(Tensor.dropout, p=dropout), Tensor.gelu]
    in_d, self.conv_layers, self.mode = 1, [], mode
    for i, cl in enumerate(conv_layers):
      assert len(cl) == 3, "invalid conv definition: " + str(cl)
      (dim, k, stride) = cl  # block() closes over this dim
      if i == 0: self.cl = cl
      self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=(mode == "layer_norm"), is_group_norm=((mode == "default" or mode == "group_norm_masked") and i == 0), conv_bias=conv_bias))
      in_d = dim
  def __call__(self, x:Tensor, padding_mask:Tensor):
    x = x.unsqueeze(1)  # BxT -> BxCxT
    if self.mode == "group_norm_masked":
      if padding_mask is not None:
        _, k, stride = self.cl
        lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1)  # ensure padding_mask is bool for tilde
        lengths = (((lengths_org - k) / stride) + 1).floor().cast(dtypes.int64)
        padding_mask = tilde(lengths_to_padding_mask(lengths)).cast(dtypes.int64)  # lengths_to_padding_mask returns bool tensor
      x = self.conv_layers[0][0](x)  # padding_mask is numeric
      x = self.conv_layers[0][1](x)
      x = self.conv_layers[0][2](x, padding_mask)
      x = self.conv_layers[0][3](x)
    else:
      x = x.sequential(self.conv_layers[0])  # default
    for _, conv in enumerate(self.conv_layers[1:], start=1):
      conv = reduce(lambda a, b: operator.iconcat(a, b if isinstance(b, list) else [b]), conv, [])  # flatten
      x = x.sequential(conv)
    return x

class CondLayerNorm:  # https://github.com/auspicious3000/contentvec/blob/main/contentvec/modules/cond_layer_norm.py#L10
  def __init__(self, dim_last, eps=1e-5, dim_spk=256, elementwise_affine=True):
    self.dim_last, self.eps, self.dim_spk, self.elementwise_affine = dim_last, eps, dim_spk, elementwise_affine
    if self.elementwise_affine:
      self.weight_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False)
      self.bias_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False)
      self.weight_ln.weight, self.bias_ln.weight = Tensor.ones(*self.weight_ln.weight.shape), Tensor.zeros(*self.bias_ln.weight.shape)
  def __call__(self, x: Tensor, spk_emb: Tensor):
    axis = tuple(-1-i for i in range(len(x.shape[1:])))
    x = x.layernorm(axis=axis, eps=self.eps)
    if not self.elementwise_affine: return x
    weights, bias = self.weight_ln(spk_emb), self.bias_ln(spk_emb)
    return weights * x + bias

class GroupNormMasked:  # https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/modules/fp32_group_norm.py#L16
  def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
    self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
    # the original conditional only covered the bias; parenthesize so both weight and bias respect affine
    self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if self.affine else (None, None)
  def __call__(self, x:Tensor, mask:Tensor):
    bsz, n_c, length = x.shape
    assert n_c % self.num_groups == 0
    x = x.reshape(bsz, self.num_groups, n_c // self.num_groups, length)
    if mask is None: mask = Tensor.ones_like(x)
    else: mask = mask.reshape(bsz, 1, 1, length)
    x = x * mask
    lengths = mask.sum(axis=3, keepdim=True)
    assert x.shape[2] == 1
    mean_ = x.mean(axis=3, keepdim=True)  # was dim=3; tinygrad reductions take axis=
    mean = mean_ * length / lengths
    var = (((x.std(axis=3, keepdim=True) ** 2) + mean_**2) * length / lengths - mean**2) + self.eps
    return x.add(-mean).div(var.sqrt()).reshape(bsz, n_c, length).mul(self.weight.reshape(1,-1,1)).add(self.bias.reshape(1,-1,1))
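
# Masked statistics above: padded positions are zeroed first, so sums over the
# time axis are rescaled by length/lengths (full length over true length per
# sample) to recover the mean/variance of only the valid frames.
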
class Synthesizer:
  def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, ssl_dim, n_speakers, sampling_rate=44100, vol_embedding=False, n_flow_layer=4, **kwargs):
    self.spec_channels, self.inter_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.segment_size, self.n_speakers, self.gin_channels, self.vol_embedding = spec_channels, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, segment_size, n_speakers, gin_channels, vol_embedding
    self.emb_g = nn.Embedding(n_speakers, gin_channels)
    if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels)
    self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
    self.enc_p = TextEncoder(inter_channels, hidden_channels, kernel_size, n_layers, filter_channels=filter_channels, n_heads=n_heads, p_dropout=p_dropout)
    self.dec = Generator(sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels)
    self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
    self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels)
    self.emb_uv = nn.Embedding(vocab_size=2, embed_size=hidden_channels)
  def infer(self, c:Tensor, f0:Tensor, uv:Tensor, g:Tensor=None, noise_scale=0.35, seed=52468, vol=None) -> Tuple[Tensor, Tensor]:
    Tensor.manual_seed(getenv('SEED', seed))
    c_lengths = (Tensor.ones([c.shape[0]]) * c.shape[-1]).to(c.device)
    if len(g.shape) == 1: g = g.unsqueeze(0)
    g = self.emb_g(g).transpose(1, 2)
    x_mask = sequence_mask(c_lengths, c.shape[2]).unsqueeze(1).cast(c.dtype)
    vol = self.emb_vol(vol[:,:,None]).transpose(1, 2) if vol is not None and self.vol_embedding else 0
    x = self.pre(c) * x_mask + self.emb_uv(uv.cast(dtypes.int64)).transpose(1, 2) + vol
    z_p, _, _, c_mask = self.enc_p.forward(x, x_mask, f0=self._f0_to_coarse(f0), noise_scale=noise_scale)
    z = self.flow.forward(z_p, c_mask, g=g, reverse=True)
    o = self.dec.forward(z * c_mask, g=g, f0=f0)
    return o, f0
  def _f0_to_coarse(self, f0:Tensor):
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN)
    b = F0_MEL_MIN * a - 1.
    f0_mel = (f0_mel > 0).where(f0_mel * a - b, f0_mel)
    f0_coarse = f0_mel.ceil().cast(dtype=dtypes.int64)
    f0_coarse = f0_coarse * (f0_coarse > 0)
    f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
    f0_coarse = f0_coarse * (f0_coarse < F0_BIN)
    f0_coarse = f0_coarse + ((f0_coarse >= F0_BIN) * (F0_BIN - 1))
    return f0_coarse
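  # Worked example for _f0_to_coarse (values approximate): f0 = 440 Hz gives
  # mel = 1127*ln(1 + 440/700) ~= 549.6; the affine rescale with
  # a = 254/(F0_MEL_MAX - F0_MEL_MIN) and ceil places it around bin 123 of the
  # 256 mel-spaced bins, with bins 1 and 255 acting as clamps.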
  @classmethod
  def load_from_pretrained(cls, config_path:str, config_url:str, weights_path:str, weights_url:str) -> Tuple[Synthesizer, HParams]:
    fetch(config_url, config_path)
    hps = get_hparams_from_file(config_path)
    fetch(weights_url, weights_path)
    net_g = cls(hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model)
    _ = load_checkpoint(weights_path, net_g, None, skip_list=["f0_decoder"])
    logging.debug(f"{cls.__name__}: Loaded model with hps: {hps}")
    return net_g, hps

class TextEncoder:
  def __init__(self, out_channels, hidden_channels, kernel_size, n_layers, gin_channels=0, filter_channels=None, n_heads=None, p_dropout=None):
    self.out_channels, self.hidden_channels, self.kernel_size, self.n_layers, self.gin_channels = out_channels, hidden_channels, kernel_size, n_layers, gin_channels
    self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    self.f0_emb = nn.Embedding(256, hidden_channels)  # n_vocab = 256
    self.enc_ = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
  def forward(self, x, x_mask, f0=None, noise_scale=1):
    x = x + self.f0_emb(f0).transpose(1, 2)
    x = self.enc_.forward(x * x_mask, x_mask)
    stats = self.proj(x) * x_mask
    m, logs = split(stats, self.out_channels, dim=1)
    z = (m + randn_like(m) * logs.exp() * noise_scale) * x_mask
    return z, m, logs, x_mask

class Upsample:
  def __init__(self, scale_factor):
    assert scale_factor % 1 == 0, "Only integer scale factor allowed."
    self.scale = int(scale_factor)
  def forward(self, x:Tensor):
    repeats = tuple([1] * len(x.shape) + [self.scale])
    new_shape = (*x.shape[:-1], x.shape[-1] * self.scale)
    return x.unsqueeze(-1).repeat(repeats).reshape(new_shape)
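
# e.g. with scale_factor=2, Upsample.forward maps [..., a, b] to [..., a, a, b, b]:
# nearest-neighbor upsampling along the last axis via repeat + reshape.
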
class SineGen:
  def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voice_threshold=0, flag_for_pulse=False):
    self.sine_amp, self.noise_std, self.harmonic_num, self.sampling_rate, self.voiced_threshold, self.flag_for_pulse = sine_amp, noise_std, harmonic_num, samp_rate, voice_threshold, flag_for_pulse
    self.dim = self.harmonic_num + 1
  def _f02uv(self, f0): return (f0 > self.voiced_threshold).float()  # generate uv signal
  def _f02sine(self, f0_values):
    def padDiff(x:Tensor): return (x.pad((0,0,-1,1)) - x).pad((0,0,0,-1))
    def mod(x:Tensor, n:int) -> Tensor: return x - n * x.div(n).floor()  # this is what the % operator does in pytorch
    rad_values = mod((f0_values / self.sampling_rate), 1)  # convert to F0 in rad
    rand_ini = Tensor.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)  # initial phase noise

    # rand_ini[:, 0] = 0
    m = Tensor.ones(f0_values.shape[0]).unsqueeze(1).pad((0,f0_values.shape[2]-1,0,0)).cast(dtypes.bool)
    m = tilde(m)
    rand_ini = m.where(rand_ini, 0)

    # rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
    tmp = rad_values[:, 0, :] + rand_ini
    m = Tensor.ones(tmp.shape).pad((0,0,0,rad_values.shape[1]-1,0)).cast(dtypes.bool)
    m = tilde(m)
    tmp = tmp.unsqueeze(1).pad((0,0,0,rad_values.shape[1]-1,0))
    rad_values = m.where(rad_values, tmp)

    tmp_over_one = mod(rad_values.cumsum(1), 1)
    tmp_over_one_idx = padDiff(tmp_over_one) < 0
    cumsum_shift = Tensor.zeros_like(rad_values)

    # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
    tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad((0,0,1,0))
    cumsum_shift = tmp_over_one_idx

    sines = ((rad_values + cumsum_shift).cumsum(1) * 2 * np.pi).sin()
    return sines
  def forward(self, f0, upp=None):
    fn = f0.mul(Tensor([[range(1, self.harmonic_num + 2)]], dtype=dtypes.float32).to(f0.device))
    sine_waves = self._f02sine(fn) * self.sine_amp  # generate sine waveforms
    uv = self._f02uv(f0)  # generate uv signal
    noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
    noise = noise_amp * randn_like(sine_waves)
    sine_waves = sine_waves * uv + noise
    return sine_waves, uv, noise

class SourceHnNSF:
  def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
    self.sine_amp, self.noise_std = sine_amp, add_noise_std
    self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
    self.l_linear = nn.Linear(harmonic_num + 1, 1)
  def forward(self, x, upp=None):
    sine_waves, uv, _ = self.l_sin_gen.forward(x, upp)
    sine_merge = self.l_linear(sine_waves.cast(self.l_linear.weight.dtype)).tanh()
    noise = randn_like(uv) * self.sine_amp / 3
    return sine_merge, noise, uv

# most of the hifigan in standard vits is reused here, but we need to upsample and construct the harmonic source from f0
class Generator:
  def __init__(self, sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels):
    self.sampling_rate, self.inter_channels, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.gin_channels = sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels
    self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len(upsample_rates)
    self.conv_pre = nn.Conv1d(inter_channels, upsample_initial_channel, 7, 1, padding=3)
    self.f0_upsamp = Upsample(scale_factor=np.prod(upsample_rates))
    self.m_source = SourceHnNSF(sampling_rate, harmonic_num=8)
    resblock = ResBlock1 if resblock == '1' else ResBlock2
    self.ups, self.noise_convs, self.resblocks = [], [], []
    for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
      c_cur = upsample_initial_channel//(2**(i+1))
      self.ups.append(nn.ConvTranspose1d(upsample_initial_channel//(2**i), c_cur, k, u, padding=(k-u)//2))
      stride_f0 = int(np.prod(upsample_rates[i + 1:]))
      self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2) if (i + 1 < len(upsample_rates)) else nn.Conv1d(1, c_cur, kernel_size=1))
    for i in range(len(self.ups)):
      ch = upsample_initial_channel // (2 ** (i + 1))
      for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
        self.resblocks.append(resblock(ch, k, d))
    self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3)
    if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
    self.upp = np.prod(upsample_rates)
  def forward(self, x, f0, g=None):
    f0 = self.f0_upsamp.forward(f0[:, None]).transpose(1, 2)  # bs,n,t
    har_source, _, _ = self.m_source.forward(f0, self.upp)
    har_source = har_source.transpose(1, 2)
    x = self.conv_pre(x)
    if g is not None: x = x + self.cond(g)
    for i in range(self.num_upsamples):
      x, xs = self.ups[i](x.leaky_relu(LRELU_SLOPE)), None
      x_source = self.noise_convs[i](har_source)
      x = x + x_source
      for j in range(self.num_kernels):
        if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x)
        else: xs += self.resblocks[i * self.num_kernels + j].forward(x)
      x = xs / self.num_kernels
    return self.conv_post(x.leaky_relu()).tanh()

# **** helpers ****

def randn_like(x:Tensor) -> Tensor: return Tensor.randn(*x.shape, dtype=x.dtype).to(device=x.device)

def tilde(x: Tensor) -> Tensor:
  if x.dtype == dtypes.bool: return (1 - x).cast(dtypes.bool)
  return (x + 1) * -1  # this seems to be what the ~ operator does in pytorch for non-bool tensors
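# e.g. tilde(Tensor([True, False])) -> [False, True]; for int tensors,
# tilde(Tensor([0, 1, 2])) -> [-1, -2, -3], matching pytorch's bitwise ~.
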
def lengths_to_padding_mask(lens:Tensor) -> Tensor:
  bsz, max_lens = lens.shape[0], lens.max().numpy().item()
  mask = Tensor.arange(max_lens).to(lens.device).reshape(1, max_lens)
  mask = mask.expand(bsz, -1) >= lens.reshape(bsz, 1).expand(-1, max_lens)
  return mask.cast(dtypes.bool)
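# e.g. lens=Tensor([1, 3]) gives [[False, True, True], [False, False, False]]:
# True marks padded positions past each sequence's true length.
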
def repeat_expand_2d_left(content, target_len):  # content : [h, t]
  src_len = content.shape[-1]
  temp = np.arange(src_len+1) * target_len / src_len
  current_pos, cols = 0, []
  for i in range(target_len):
    if i >= temp[current_pos+1]:
      current_pos += 1
    cols.append(content[:, current_pos])
  return Tensor.stack(*cols).transpose(0, 1)
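# e.g. stretching src_len=2 to target_len=4 picks source columns [0, 0, 1, 1]:
# each source column is held until the next boundary (left-hold interpolation).
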
def load_fairseq_cfg(checkpoint_path):
  assert Path(checkpoint_path).is_file()
  state = torch_load(checkpoint_path)
  cfg = state["cfg"] if ("cfg" in state and state["cfg"] is not None) else None
  if cfg is None: raise RuntimeError(f"No cfg exist in state keys = {state.keys()}")
  return HParams(**cfg)

def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip_list=[]):
  assert Path(checkpoint_path).is_file()
  start_time = time.time()
  checkpoint_dict = torch_load(checkpoint_path)
  saved_state_dict = checkpoint_dict['model']
  weight_g, weight_v, parent = None, None, None
  for key, v in saved_state_dict.items():
    if any(layer in key for layer in skip_list): continue
    try:
      obj, skip = model, False
      for k in key.split('.'):
        if k.isnumeric(): obj = obj[int(k)]
        elif isinstance(obj, dict): obj = obj[k]
        else:
          if k in ["weight_g", "weight_v"]:
            parent, skip = obj, True
            if k == "weight_g": weight_g = v
            else: weight_v = v
          if not skip:
            parent = obj
            obj = getattr(obj, k)
      if weight_g and weight_v:
        setattr(obj, "weight_g", weight_g.numpy())
        setattr(obj, "weight_v", weight_v.numpy())
        obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0)
        weight_g, weight_v, parent, skip = None, None, None, False
      if not skip and obj.shape == v.shape:
        if "feature_extractor" in key and (isinstance(parent, (nn.GroupNorm, nn.LayerNorm))):  # cast
          obj.assign(v.to(obj.device).float())
        else:
          obj.assign(v.to(obj.device))
      elif not skip: logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}")
    except Exception as e: raise e
  logging.info(f"Loaded checkpoint '{checkpoint_path}' in {time.time() - start_time:.4f}s")
  return model, optimizer

def pad_array(arr, target_length):
  current_length = arr.shape[0]
  if current_length >= target_length: return arr
  pad_width = target_length - current_length
  pad_left = pad_width // 2
  pad_right = pad_width - pad_left
  padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0))
  return padded_arr

def split_list_by_n(list_collection, n, pre=0):
  for i in range(0, len(list_collection), n):
    yield list_collection[i-pre if i-pre>=0 else i: i + n]

def get_sid(spk2id:HParams, speaker:str) -> Tensor:
  speaker_id = spk2id[speaker]
  if not speaker_id and type(speaker) is int:
    if len(spk2id.__dict__) >= speaker: speaker_id = speaker
  if speaker_id is None: raise RuntimeError(f"speaker={speaker} not in the speaker list")
  return Tensor([int(speaker_id)], dtype=dtypes.int64).unsqueeze(0)

def get_encoder(ssl_dim) -> Type[SpeechEncoder]:
  if ssl_dim == 256: return ContentVec256L9
  if ssl_dim == 768: return ContentVec768L12

#########################################################################################
# CODE: https://github.com/svc-develop-team/so-vits-svc
#########################################################################################
# CONTENTVEC:
#   CODE: https://github.com/auspicious3000/contentvec
#   PAPER: https://arxiv.org/abs/2204.09224
#########################################################################################
# INSTALLATION: dependencies are for preprocessing and loading/saving audio.
#   pip3 install soundfile librosa praat-parselmouth
#########################################################################################
# EXAMPLE USAGE:
#   python3 examples/so_vits_svc.py --model tf2spy --file ~/recording.wav
#########################################################################################
# DEMO USAGE (uses audio sample from LJ-Speech):
#   python3 examples/so_vits_svc.py --model saul_goodman
#########################################################################################
SO_VITS_SVC_PATH = Path(__file__).parents[1] / "weights/So-VITS-SVC"
VITS_MODELS = { # config_path, weights_path, config_url, weights_url
  "saul_goodman" : (SO_VITS_SVC_PATH / "config_saul_gman.json", SO_VITS_SVC_PATH / "pretrained_saul_gman.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/G_80000.pth"),
  "drake" : (SO_VITS_SVC_PATH / "config_drake.json", SO_VITS_SVC_PATH / "pretrained_drake.pth", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/config_aubrey.json", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/pretrained_aubrey.pth"),
  "cartman" : (SO_VITS_SVC_PATH / "config_cartman.json", SO_VITS_SVC_PATH / "pretrained_cartman.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/G_10200.pth"),
  "tf2spy" : (SO_VITS_SVC_PATH / "config_tf2spy.json", SO_VITS_SVC_PATH / "pretrained_tf2spy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/G_60000.pth"),
  "tf2heavy" : (SO_VITS_SVC_PATH / "config_tf2heavy.json", SO_VITS_SVC_PATH / "pretrained_tf2heavy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/G_100000.pth"),
  "lady_gaga" : (SO_VITS_SVC_PATH / "config_gaga.json", SO_VITS_SVC_PATH / "pretrained_gaga.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/G_14400.pth")
}
ENCODER_MODELS = { # weights_path, weights_url
  "contentvec": (SO_VITS_SVC_PATH / "contentvec_checkpoint.pt", "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt")
}
ENCODER_MODEL = "contentvec"
DEMO_PATH, DEMO_URL = Path(__file__).parents[1] / "temp/LJ037-0171.wav", "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
if __name__=="__main__":
  logging.basicConfig(stream=sys.stdout, level=(logging.INFO if DEBUG < 1 else logging.DEBUG))
  parser = argparse.ArgumentParser()
  parser.add_argument("-m", "--model", default=None, help=f"Specify the model to use. All supported models: {VITS_MODELS.keys()}", required=True)
  parser.add_argument("-f", "--file", default=DEMO_PATH, help="Specify the path of the input file")
  parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.")
  parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --base_name parameters.")
  parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.")
  parser.add_argument("--speaker", default=None, help="If not specified, the first available speaker is chosen. Usually there is only one speaker per model.")
  parser.add_argument("--noise_scale", default=0.4)
  parser.add_argument("--tran", default=0.0, help="Pitch shift, supports positive and negative (semitone) values. Default 0.0")
  parser.add_argument("--pad_seconds", default=0.5)
  parser.add_argument("--lg_num", default=0.0)
  parser.add_argument("--clip_seconds", default=0.0)
  parser.add_argument("--slice_db", default=-40)
  args = parser.parse_args()

  vits_model = args.model
  encoder_location, vits_location = ENCODER_MODELS[ENCODER_MODEL], VITS_MODELS[vits_model]

  Tensor.training = False
  # Get Synthesizer and ContentVec
  net_g, hps = Synthesizer.load_from_pretrained(vits_location[0], vits_location[2], vits_location[1], vits_location[3])
  Encoder = get_encoder(hps.model.ssl_dim)
  encoder = Encoder.load_from_pretrained(encoder_location[0], encoder_location[1])

  # model config args
  target_sample, spk2id, hop_length = hps.data.sampling_rate, hps.spk, hps.data.hop_length  # target_sample was assigned twice in the original
  vol_embedding = hps.model.vol_embedding if hasattr(hps.model, "vol_embedding") and hps.model.vol_embedding is not None else False  # the original checked hasattr(hps.data, ...), which never carries vol_embedding

  # args
  slice_db, clip_seconds, lg_num, pad_seconds, tran, noise_scale, audio_path = args.slice_db, args.clip_seconds, args.lg_num, args.pad_seconds, args.tran, args.noise_scale, args.file
  speaker = args.speaker if args.speaker is not None else list(hps.spk.__dict__.keys())[0]

  ### Loading audio and slicing ###
  if audio_path == DEMO_PATH: fetch(DEMO_URL, DEMO_PATH)
  assert Path(audio_path).is_file() and Path(audio_path).suffix == ".wav"
  chunks = preprocess.cut(audio_path, db_thresh=slice_db)
  audio_data, audio_sr = preprocess.chunks2audio(audio_path, chunks)

  per_size = int(clip_seconds * audio_sr)
  lg_size = int(lg_num * audio_sr)

  ### Infer per slice ###
  global_frame = 0
  audio = []
  for (slice_tag, data) in audio_data:
    print(f"\n====segment start, {round(len(data) / audio_sr, 3)}s====")
    length = int(np.ceil(len(data) / audio_sr * target_sample))

    if slice_tag:
      print("empty segment")
      _audio = np.zeros(length)
      audio.extend(list(pad_array(_audio, length)))
      global_frame += length // hop_length
      continue

    datas = [data] if per_size == 0 else split_list_by_n(data, per_size, lg_size)

    for k, dat in enumerate(datas):
      per_length = int(np.ceil(len(dat) / audio_sr * target_sample)) if clip_seconds != 0 else length
      pad_len = int(audio_sr * pad_seconds)
      dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
      raw_path = io.BytesIO()
      soundfile.write(raw_path, dat, audio_sr, format="wav")
      raw_path.seek(0)

      ### Infer START ###
      wav, sr = preprocess.load_audiofile(raw_path)
      wav = preprocess.sinc_interp_resample(wav, sr, target_sample)[0]
      wav16k, f0, uv = preprocess.get_unit_f0(wav, tran, hop_length, target_sample)
      sid = get_sid(spk2id, speaker)
      n_frames = f0.shape[1]

      # ContentVec infer
      start = time.time()
      c = encoder.encode(wav16k)
      c = repeat_expand_2d_left(c.squeeze(0).realize(), f0.shape[1])  # interpolate speech encoding to match f0
      c = c.unsqueeze(0).realize()
      enc_time = time.time() - start

      # VITS infer
      vits_start = time.time()
      out_audio, f0 = net_g.infer(c, f0=f0, uv=uv, g=sid, noise_scale=noise_scale, vol=None)
      out_audio = out_audio[0,0].float().realize()
      vits_time = time.time() - vits_start

      infer_time = time.time() - start
      logging.info("total infer time:{:.2f}s, speech_enc time:{:.2f}s, vits time:{:.2f}s".format(infer_time, enc_time, vits_time))
      ### Infer END ###

      out_sr, out_frame = out_audio.shape[-1], n_frames
      global_frame += out_frame
      _audio = out_audio.numpy()
      pad_len = int(target_sample * pad_seconds)
      _audio = _audio[pad_len:-pad_len]
      _audio = pad_array(_audio, per_length)
      audio.extend(list(_audio))

  audio = np.array(audio)
  out_path = Path(args.out_path or Path(args.out_dir)/f"{args.model}{f'_spk_{speaker}'}_{args.base_name}.wav")
  out_path.parent.mkdir(parents=True, exist_ok=True)
  soundfile.write(out_path, audio, target_sample, format="flac")  # note: the explicit format writes FLAC data despite the .wav suffix
  logging.info(f"Saved audio output to {out_path}")
@@ -1,204 +0,0 @@
import math
from typing import Optional, Tuple
from tinygrad import Tensor, dtypes
import librosa
import soundfile
import numpy as np
import parselmouth

class PMF0Predictor:  # from https://github.com/svc-develop-team/so-vits-svc/
  def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
    self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm"
  def interpolate_f0(self, f0):
    vuv_vector = np.zeros_like(f0, dtype=np.float32)
    vuv_vector[f0 > 0.0] = 1.0
    vuv_vector[f0 <= 0.0] = 0.0
    nzindex = np.nonzero(f0)[0]
    data = f0[nzindex]
    nzindex = nzindex.astype(np.float32)
    time_org = self.hop_length / self.sampling_rate * nzindex
    time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
    if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32), vuv_vector
    if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0], vuv_vector
    f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
    return f0, vuv_vector
  def compute_f0(self, wav, p_len=None):
    x = wav
    if p_len is None: p_len = x.shape[0]//self.hop_length
    else: assert abs(p_len - x.shape[0]//self.hop_length) < 4, "pad length error"
    time_step = self.hop_length / self.sampling_rate * 1000
    f0 = parselmouth.Sound(x, self.sampling_rate) \
      .to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \
      .selected_array['frequency']
    pad_size = (p_len - len(f0) + 1) // 2
    if pad_size > 0 or p_len - len(f0) - pad_size > 0:
      f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')
    f0, uv = self.interpolate_f0(f0)
    return f0
  def compute_f0_uv(self, wav, p_len=None):
    x = wav
    if p_len is None: p_len = x.shape[0]//self.hop_length
    else: assert abs(p_len - x.shape[0]//self.hop_length) < 4, "pad length error"
    time_step = self.hop_length / self.sampling_rate * 1000
    f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
      time_step=time_step / 1000, voicing_threshold=0.6,
      pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
    pad_size = (p_len - len(f0) + 1) // 2
    if pad_size > 0 or p_len - len(f0) - pad_size > 0:
      f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')
    f0, uv = self.interpolate_f0(f0)
    return f0, uv

class Slicer:  # from https://github.com/svc-develop-team/so-vits-svc/
  def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000):
    if not min_length >= min_interval >= hop_size:
      raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
    if not max_sil_kept >= hop_size:
      raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
    min_interval = sr * min_interval / 1000
    self.threshold = 10 ** (threshold / 20.)  # dBFS -> linear amplitude, e.g. -40 dB -> 0.01
    self.hop_size = round(sr * hop_size / 1000)
    self.win_size = min(round(min_interval), 4 * self.hop_size)
    self.min_length = round(sr * min_length / 1000 / self.hop_size)
    self.min_interval = round(min_interval / self.hop_size)
    self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
  def _apply_slice(self, waveform, begin, end):
    if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
    else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
  def slice(self, waveform):
    samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
    if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
    rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
    sil_tags, silence_start, clip_start = [], None, 0
    for i, rms in enumerate(rms_list):
      if rms < self.threshold:  # Keep looping while frame is silent.
        if silence_start is None:  # Record start of silent frames.
          silence_start = i
        continue
      if silence_start is None: continue  # Keep looping while frame is not silent and silence start has not been recorded.
      # Clear recorded silence start if interval is not enough or clip is too short
      is_leading_silence = silence_start == 0 and i > self.max_sil_kept
      need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
      if not is_leading_silence and not need_slice_middle:
        silence_start = None
        continue
      if i - silence_start <= self.max_sil_kept:  # Need slicing. Record the range of silent frames to be removed.
        pos = rms_list[silence_start: i + 1].argmin() + silence_start
        sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
        clip_start = pos
      elif i - silence_start <= self.max_sil_kept * 2:
        pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
        pos += i - self.max_sil_kept
        pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
        pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
        if silence_start == 0:
          sil_tags.append((0, pos_r))
          clip_start = pos_r
        else:
          sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
          clip_start = max(pos_r, pos)
      else:
        pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
        pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
        sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
        clip_start = pos_r
      silence_start = None
    total_frames = rms_list.shape[0]
    if silence_start is not None and total_frames - silence_start >= self.min_interval:  # Deal with trailing silence.
      silence_end = min(total_frames, silence_start + self.max_sil_kept)
      pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
      sil_tags.append((pos, total_frames + 1))
    if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}  # Apply and return slices.
    chunks = []
    if sil_tags[0][0]:
      chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
    for i in range(0, len(sil_tags)):
      if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
      chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
    if sil_tags[-1][1] * self.hop_size < len(waveform):
      chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
    chunk_dict = {}
    for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i]
    return chunk_dict

# sinc_interp_hann audio resampling
class Resample:
  def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None):
    self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta
    self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
    self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None)
  def __call__(self, waveform:Tensor) -> Tensor:
    if self.orig_freq == self.new_freq: return waveform
    return self._apply_sinc_resample_kernel(waveform)
  def _apply_sinc_resample_kernel(self, waveform:Tensor):
    if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.")
    orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
    shape = waveform.shape
    waveform = waveform.reshape(-1, shape[-1])  # pack batch
    num_wavs, length = waveform.shape
    target_length = int(math.ceil(new_freq * length / orig_freq))
    waveform = waveform.pad((self.width, self.width + orig_freq))
    resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq)
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    resampled = resampled[..., :target_length]
    resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:])  # unpack batch
    return resampled
  def _get_sinc_resample_kernel(self, dtype=None):
    orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
    if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.")
    base_freq = min(orig_freq, new_freq)
    base_freq *= self.rolloff
    width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
    idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq
    t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
    t *= base_freq
    t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
    window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
    t *= math.pi
    scale = base_freq / orig_freq
    kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t)
    kernels *= window * scale
    if dtype is None: kernels = kernels.cast(dtype=dtypes.float32)
    return kernels, width

def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None):
  # the new_freq default was 1600, almost certainly a typo for 16000 (torchaudio's default); callers always pass it explicitly
  resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
  return resamp(x)
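
# A minimal usage sketch (illustrative; the sample tensor is made up):
# downsample one second of 44.1 kHz audio to 16 kHz with the sinc kernel above.
#
#   wav = Tensor.randn(1, 44100)
#   wav16k = sinc_interp_resample(wav, orig_freq=44100, new_freq=16000)
#   # wav16k.shape[-1] == ceil(16000/44100 * 44100) == 16000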

def cut(audio_path, db_thresh=-30, min_len=5000):
  audio, sr = librosa.load(audio_path, sr=None)
  slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
  chunks = slicer.slice(audio)
  return chunks

def chunks2audio(audio_path, chunks):
  chunks = dict(chunks)
  audio, sr = load_audiofile(audio_path)
  if len(audio.shape) == 2 and audio.shape[1] >= 2:
    audio = audio.mean(0).unsqueeze(0)
  audio = audio.numpy()[0]
  result = []
  for k, v in chunks.items():
    tag = v["split_time"].split(",")
    if tag[0] != tag[1]:
      result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
  return result, sr

def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True):
  with soundfile.SoundFile(filepath, "r") as file_:
    frames = file_._prepare_read(frame_offset, None, num_frames)
    waveform = file_.read(frames, "float32", always_2d=True)
    sample_rate = file_.samplerate
  waveform = Tensor(waveform)
  if channels_first: waveform = waveform.transpose(0, 1)
  return waveform, sample_rate

def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor, Tensor, Tensor]:
  f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
  f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
  if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected")
  f0 = Tensor(f0.astype(np.float32)).float()
  f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)  # semitone pitch shift: each semitone scales f0 by 2**(1/12)
  uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)
  wav16k = sinc_interp_resample(wav[None,:], target_sample, 16000)[0]
  return wav16k.realize(), f0.realize(), uv.realize()
@@ -1,104 +0,0 @@
import traceback
import time
from multiprocessing import Process, Queue
import numpy as np
from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import getenv, trange
from tinygrad.tensor import Tensor
from extra.datasets import fetch_cifar
from extra.models.efficientnet import EfficientNet

class TinyConvNet:
  def __init__(self, classes=10):
    conv = 3
    inter_chan, out_chan = 8, 16 # for speed
    self.c1 = Tensor.uniform(inter_chan,3,conv,conv)
    self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv)
    self.l1 = Tensor.uniform(out_chan*6*6, classes)

  def forward(self, x):
    x = x.conv2d(self.c1).relu().max_pool2d()
    x = x.conv2d(self.c2).relu().max_pool2d()
    x = x.reshape(shape=[x.shape[0], -1])
    return x.dot(self.l1)
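  # editor note (not in the original file): with 32x32 CIFAR inputs, 3x3 valid convs and
  # 2x2 max pools the spatial size goes 32 -> 30 -> 15 -> 13 -> 6, which is where the
  # out_chan*6*6 in l1 comes from.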

if __name__ == "__main__":
  IMAGENET = getenv("IMAGENET")
  classes = 1000 if IMAGENET else 10

  TINY = getenv("TINY")
  TRANSFER = getenv("TRANSFER")
  if TINY:
    model = TinyConvNet(classes)
  elif TRANSFER:
    model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
    model.load_from_pretrained()
  else:
    model = EfficientNet(getenv("NUM", 0), classes, has_se=False)

  parameters = get_parameters(model)
  print("parameter count", len(parameters))
  optimizer = optim.Adam(parameters, lr=0.001)

  BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
  print(f"training with batch size {BS} for {steps} steps")

  if IMAGENET:
    from extra.datasets.imagenet import fetch_batch
    def loader(q):
      while 1:
        try:
          q.put(fetch_batch(BS))
        except Exception:
          traceback.print_exc()
    q = Queue(16)
    for i in range(2):
      p = Process(target=loader, args=(q,))
      p.daemon = True
      p.start()
  else:
    X_train, Y_train, _, _ = fetch_cifar()
    X_train = X_train.reshape((-1, 3, 32, 32))
    Y_train = Y_train.reshape((-1,))

  with Tensor.train():
    for i in (t := trange(steps)):
      if IMAGENET:
        X, Y = q.get(True)
      else:
        samp = np.random.randint(0, X_train.shape[0], size=(BS))
        X, Y = X_train.numpy()[samp], Y_train.numpy()[samp]

      st = time.time()
      out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
      fp_time = (time.time()-st)*1000.0

      y = np.zeros((BS,classes), np.float32)
      y[range(y.shape[0]),Y] = -classes
      y = Tensor(y, requires_grad=False)
      loss = out.log_softmax().mul(y).mean()
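      # editor note (not in the original file): y holds -classes at each label index and 0
      # elsewhere, so mean() over the (BS, classes) grid divides by BS*classes and the
      # -classes factor cancels the extra /classes, leaving the usual mean NLL over the batch.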

      optimizer.zero_grad()

      st = time.time()
      loss.backward()
      bp_time = (time.time()-st)*1000.0

      st = time.time()
      optimizer.step()
      opt_time = (time.time()-st)*1000.0

      st = time.time()
      loss = loss.numpy()
      cat = out.argmax(axis=1).numpy()
      accuracy = (cat == Y).mean()
      finish_time = (time.time()-st)*1000.0

      # printing
      t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
        (loss, accuracy,
         fp_time, bp_time, opt_time, finish_time,
         fp_time + bp_time + opt_time + finish_time))

      del out, y, loss
@@ -1,46 +0,0 @@
import ast
import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, fetch
from extra.models.vit import ViT
"""
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
import tensorflow as tf
with tf.io.gfile.GFile(fn, "rb") as f:
  dat = f.read()
  with open("cache/"+ fn.rsplit("/", 1)[1], "wb") as g:
    g.write(dat)
"""

Tensor.training = False
if getenv("LARGE", 0) == 1:
  m = ViT(embed_dim=768, num_heads=12)
else:
  # tiny
  m = ViT(embed_dim=192, num_heads=3)
m.load_from_pretrained()

# category labels
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())

#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"

# junk
img = Image.open(fetch(url))
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
img = np.array(img)
y0,x0=(np.asarray(img.shape)[:2]-224)//2
img = img[y0:y0+224, x0:x0+224]
img = np.moveaxis(img, [2,0,1], [0,1,2])
img = img.astype(np.float32)[:3].reshape(1,3,224,224)
img /= 255.0
img -= 0.5
img /= 0.5
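# editor note (not in the original file): the three ops above compute (x/255 - 0.5)/0.5,
# mapping uint8 [0,255] into [-1,1], the normalization these ViT checkpoints were trained with.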

out = m.forward(Tensor(img))
outnp = out.numpy().ravel()
choice = outnp.argmax()
print(out.shape, choice, outnp[choice], lbls[choice])
@@ -1,189 +0,0 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.opt.kernel import Ops, MemOp, UOp
from tinygrad.uop.ops import BinaryOps, UnaryOps
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import DEBUG
from tinygrad.uop.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict

_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
                   dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}

class Register(NamedTuple):
  nm:str
  dtype:DType
  scalar:bool
  off:Optional[int] = None
  def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
  def subregs(self):
    if self.dtype == dtypes.float.vec(4):
      return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
    return []

class AssemblyInstruction(NamedTuple):
  op: Ops
  out: Optional[Register]
  vin: List[Union[Register, int, float]]
  arg: Any = None

# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
  supports_load3: bool = False
  sin_is_sin2pi: bool = False
  no_div: bool = False
  #TODO: these should be global vars
  cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
  tor: Dict[Any, Register] = {}
  ins: List[AssemblyInstruction] = []

  def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
  def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
    self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
    if dtype == dtypes.float.vec(4):
      for off in range(4):
        self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
    self.cnts[(dtype, scalar)] += 1
    return ret

  def render_numnode(self, b) -> Register:
    key = ("num", b)
    if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
    return self.tor[key]

  def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
    key = (op, a, b)
    if key not in self.tor:
      #if not isinstance(b, Register): b = render_numnode(b)
      self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
    return self.tor[key]

  def render_cast(self, a:Register, new_dtype:DType) -> Register:
    if a.dtype == new_dtype: return a
    key = (a, new_dtype)
    if key not in self.tor:
      self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
    return self.tor[key]
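  # editor note (not in the original file): render_numnode/render_alu/render_cast all memoize
  # into self.tor keyed on the expression, so identical subexpressions are emitted once; this is
  # a simple common-subexpression elimination over the generated instruction stream.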

  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
                      MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                      DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                      ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                      LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
                      SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                      AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

  def addr_w_offset(self, args):
    assert isinstance(args, MemOp)
    idx = args.idx*args.memory_dtype.itemsize
    off = 0  # TODO: should this be None?
    if isinstance(idx, SumNode):
      nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
      if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
        idx -= nums[0]
        off = cast(int, nums[0])
    reg = idx.render(self.render_ops, self)
    if self.supports_load3:
      if reg.scalar:
        new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
        self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
        reg = new_reg
      return self.tor[args.name], reg, off
    reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
    return reg, None, off

def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
  #TODO: Do not use clear()
  lang.ins.clear()
  lang.tor.clear()
  lang.cnts.clear()
  buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
  global_size, local_size = [], []
  skipload_branch = 0
  lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
  for u in uops:
    uop,dtype,vin,args,_ = u
    if uop == Ops.DEFINE_LOCAL:
      lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
      lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
    elif uop == Ops.LOOP:
      if args[1] == "global":
        for i,var in enumerate(args[0]):
          global_size.append(var.max+1)
          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
      elif args[1] == "local":
        for i,var in enumerate(args[0]):
          local_size.append(var.max+1)
          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
      else:
        for var in args[0]:
          if not isinstance(var, NumNode): # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
            lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
    elif uop == Ops.ENDLOOP:
      if args[1] not in ["global", "local", "global+local"]:
        for var in reversed(args[0]):
          if not isinstance(var, NumNode): # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
            pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
            lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif args[1] == "global+local":
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
      elif args[1] == 'local':
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
    elif uop == Ops.CAST:
      # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
      out = lang.newreg(u, dtype)
      for i,sr in enumerate(out.subregs()):
        lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
    elif uop == Ops.ALU:
      out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
      # this is the only thing that can violate SSA
      if args in [BinaryOps.CMPLT]:
        pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
        lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
        lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
      elif args == BinaryOps.DIV and lang.no_div:
        tmp = lang.newreg((u, "rcp"))
        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
      elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
        tmp = lang.newreg((u, "2pi"))
        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
      else:
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
    elif uop == Ops.DEFINE_REG:
      reg = lang.newreg(u, dtype=dtype)
      lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
    elif uop == Ops.SPECIAL:
      lang.tor[u] = lang.tor[args]
    elif uop == Ops.CONST:
      lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
    elif uop == Ops.LOAD:
      idx, treg, off = lang.addr_w_offset(args)
      reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
      if args.valid.min == 0:
        lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
        if args.valid.max == 1:
          pred = args.valid.render(lang.render_ops, lang)
          lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
      if args.valid.max == 1:
        # NOTE: you can't compute the index in here, because it assumes it's all available later
        lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
      if args.valid.min == 0 and args.valid.max == 1:
        lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
        skipload_branch += 1
    elif uop == Ops.STORE:
      if args is None:
        lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
      else:
        idx, treg, off = lang.addr_w_offset(args)
        lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))

  if DEBUG >= 4:
    for tins in lang.ins: print(tins)
  return global_size, local_size
@@ -1,177 +0,0 @@
import struct
from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad import dtypes
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.opt.kernel import Ops, UOp
from tinygrad.helpers import CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def compute_offsets(total):
  quotient, remainder = divmod(total, 4096)
  return [4096]*quotient + [remainder] if remainder else [4096]*quotient
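# editor note (not in the original file), worked examples for the two helpers above:
#   float_to_hex(2.0) == "40000000"               (the IEEE-754 bits of 2.0, big-endian hex)
#   compute_offsets(10000) == [4096, 4096, 1808]  (sp is moved in chunks small enough for an
#                                                  ARM64 add/sub immediate)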

#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name

class ARM64Language(AssemblyLanguage): pass

def specialize_to_arm64(fn_nm, asm):
  var_size = 16
  prev_uop:Optional[Ops] = None
  ins = []
  x_regs = ['x' + str(i) for i in reversed(range(12))]
  s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
  type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    # Manually move value into reg if value can't fit
    if value.__class__ is not float and abs(value) > abs(65535):
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    elif reg[0] == 's':
      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
      ins.append("str x15, [sp, 16]")
      ins.append(f"ldr {reg}, [sp, 16]")
    else:
      ins.append(f"mov {reg}, #{value}")

  # Get variable live intervals
  live_range:Dict[str, List[int]] = {}
  for i, (uop, out, vin, arg) in enumerate(asm):
    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]

  mem_vars:Dict[str, int] = {}
  rtor:Dict[str, str] = {}
  def allocate_regs(mvars):
    nonlocal var_size
    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
      #NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
      if not available_regs:
        # ARM needs the stack 16-byte aligned
        var_size += 16
        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
        mem_vars[v.nm] = var_size
      rtor[v.nm] = available_regs.pop()

  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x12', 'x13', 'x16']
  for i, (uop, out, vin, arg) in enumerate(asm):
    # Clear regs out of interval
    for var, reg in list(rtor.items()):
      available_regs = s_regs if reg[0] == 's' else x_regs
      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
        available_regs.append(rtor.pop(var))
    # Assign registers to the variables using live ranges.
    allocate_regs([out] + vin)
    # Assign temp regs to vin and load them before direct use
    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

    if uop == Ops.SPECIAL:
      if arg.startswith('data'):
        # data 8 to n into the stack
        if int(arg[4:]) >= 8:
          ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
      else:
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == Ops.CAST:
      if arg == BinaryOps.CMPLT:
        if rtor[out.nm][0] == 's':
          mov_imm(0.0, 's0')
          mov_imm(1.0, 's1')
          ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
        if rtor[out.nm][0] == 'x':
          mov_imm(0, 'x14')
          mov_imm(1, 'x15')
          ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
      else:
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == Ops.ALU:
      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == TernaryOps.WHERE:
        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: Not a real instruction, used to emulate an ext call in unicorn
        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
          # Save the registers before they are cleared by func call
          for i,k in enumerate(save_regs,1):
            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for i,k in enumerate(save_regs,1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
          ins.append(f"add sp, sp, #{len(save_regs)*16}")
      elif arg == BinaryOps.CMPLT:
        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
      elif arg == BinaryOps.MOD:
        rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
        ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
        ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
      else:
        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
    elif uop == Ops.LOAD:
      if arg.__class__ in (int, float):
        mov_imm(arg, rtor[out.nm])
      else:
        #NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
    elif uop == Ops.STORE:
      #NOTE: if casting is needed, store via the s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
      ins.append(f"mov x15, #{arg[0]}")
      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
    elif uop == Ops.COND_BRANCH:
      #TODO: this is a hack; there shouldn't always be a cmp before a cond branch
      if prev_uop == Ops.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == Ops.LABEL:
      ins.append(f"{arg[1:]}:")
    elif uop == Ops.ENDLOOP:
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")
    prev_uop = uop
    # store regs into memory if needed
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])

def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  lang = ARM64Language()
  global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
  return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
@@ -1,105 +0,0 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.opt.kernel import Ops, UOp
from tinygrad import dtypes
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch

dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])

def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
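# editor note (not in the original file): e.g. a cast between float32 and int32 (either
# direction) or between float32 and float16 needs a cvt, while float32 -> float32 does not.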

def render_cast(ins, inp, out):
  if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
    ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
  elif out.dtype == dtypes.bool:
    if inp.dtype == dtypes.bool:
      ins.append(f"mov.pred {out}, {inp};")
    else:
      ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
  else:
    round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
    ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")

# https://docs.nvidia.com/cuda/parallel-thread-execution/#

class PTXLanguage(AssemblyLanguage):
  supports_constant_folding: bool = True

def specialize_to_ptx(lang, function_name):
  param_cnt = 0
  ins = []
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
         UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
         TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
  for uop, out, vin, arg in lang.ins:
    if uop == Ops.ENDLOOP:
      ins.append("bar.sync 0;")
    elif uop == Ops.DEFINE_LOCAL:
      ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
    elif uop == Ops.SPECIAL:
      if arg.startswith('data'):
        param_cnt += 1
        ins.append(f"ld.param.u64 {out}, [{arg}];")
        # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
        # ins.append(f"cvta.to.global.u64 {out}, {out};")
      elif arg.startswith('gid'):
        ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
      elif arg.startswith('lid'):
        ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
    elif uop == Ops.ALU:
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
      else:
        otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
        if arg == TernaryOps.WHERE:
          if vin[0].dtype == dtypes.bool:
            reg = vin[0]
          else:
            reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
            ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
          vin = vin[1:] + [reg]
        ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
    elif uop == Ops.LOAD:
      if arg.__class__ in (int, float):
        ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
      elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
        dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
        reg = lang.newreg((out, dt[0]), dtype=dt[1])
        ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
        render_cast(ins, reg, out)
      else:
        ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
    elif uop == Ops.STORE:
      if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
        if arg[2] == dtypes.bool != vin[1].dtype:
          prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
          render_cast(ins, vin[1], prereg)
        else: prereg = vin[1]
        reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
        render_cast(ins, prereg, reg)
        ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
      else:
        ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
    elif uop == Ops.CAST:
      render_cast(ins, vin[0], out)
    elif uop == Ops.LABEL:
      ins.append(f"{arg}:")
    elif uop == Ops.COND_BRANCH:
      ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

  ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
                f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
  for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
  ins = ins_prefix + ins
  ins += ["ret;", "}"]
  return '\n'.join(ins)

def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
  lang = PTXLanguage()
  global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
  return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
@@ -1,203 +0,0 @@
import yaml
from typing import Tuple, Set, Dict
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.opt.kernel import Ops
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cl import ROCM_LLVM_PATH

# ugh, is this really needed?
from extra.helpers import enable_early_exec
early_exec = enable_early_exec()

boilerplate_start = """
.global _start
_start:
.rodata
.align 0x10
.global code.kd
.type code.kd,STT_OBJECT
.amdhsa_kernel code"""

code_start = """.end_amdhsa_kernel
.text
code:
"""

# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
# RDNA3 is actually a SIMD machine!
class RDNACodegen(AssemblyCodegen):
  supports_float4: bool = True
  supports_float4_alu: bool = True
  supports_load3: bool = True
  sin_is_sin2pi: bool = True
  no_div: bool = True

  def specialize(self, asm) -> Tuple[str, str]:
    args = []
    for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
    ins = []

    v_cnt = 3 # v[0:2] is local_xyz
    s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz

    dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
    alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
           BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
           UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
           BinaryOps.CMPLT: "cmp_lt"}

    pend_regs:Set[Register] = set()
    rtor:Dict[Register, str] = {}
    def reg_in(x):
      nonlocal pend_regs
      #print("reg_in", x, rtor[x], pend_regs)
      if x in pend_regs:
        #print("clear")
        ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
        pend_regs.clear()
      return rtor[x]
    def reg_out(x):
      return rtor[x]
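    # editor note (not in the original file): pend_regs acts as a small scoreboard; registers
    # written by still-in-flight s_load/global_load sit in it, and the first read of any such
    # register flushes them all with one s_waitcnt.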
    for uop, out, vin, arg in asm:
      if uop == Ops.DEFINE_REGISTER:
        if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
          for i in range(arg[2]):
            # TODO: Re-use gaps created by this to avoid wasting registers
            align = int(arg[0][0].itemsize / 4)
            if arg[0][1]:
              s_cnt += s_cnt % align
              reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
              s_cnt += align
            else:
              v_cnt += v_cnt % align
              reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
              v_cnt += align
            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name

            if arg[0][0] == dtypes.float.vec(4):
              for off in range(4):
                reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
                rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
        elif arg[0][0] == dtypes.bool:
          for i in range(arg[2]):
            reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
        else:
          raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
      elif uop == Ops.SPECIAL:
        if arg.startswith('buf'):
          i = int(arg[3:])
          ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
          pend_regs.add(out)
          for r in out.subregs(): pend_regs.add(r)
        elif arg.startswith('gid'):
          ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
          # the docs lied, this is actually y
          if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
          if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
          elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
          # get local size
          offset = len(args)*8
          args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
          ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
          ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
          pend_regs.clear()
          ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
          ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
      elif uop == Ops.CONST:
        if arg == float('inf'): arg = "0x7f800000"
        elif arg == float('-inf'): arg = "0xff800000"
        if out.dtype == dtypes.float.vec(4):
          for off in range(4):
            ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
        else:
          ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
      elif uop == Ops.ALU:
        if arg in [BinaryOps.CMPLT]:
          ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
        else:
          alu_arg = alu[arg]
          if arg == TernaryOps.MULACC and out == vin[2]:
            alu_arg = "fmac"
            vin = vin[0:2]
          if out.dtype == dtypes.float.vec(4):
            for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
              ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
          else:
            ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
      elif uop == Ops.LOAD:
        if out.scalar:
          # swap arg order
          ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
        else:
          ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
        pend_regs.add(out)
        for r in out.subregs(): pend_regs.add(r)
      elif uop == Ops.STORE:
        ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
      elif uop == Ops.LABEL:
        ins.append(f"{arg}:")
      elif uop == Ops.COND_BRANCH:
        ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
      elif uop == Ops.CAST:
        if vin[0].dtype == dtypes.bool:
          if out.dtype == dtypes.float32:
            ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
          else:
            raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
      else:
        raise NotImplementedError(uop)

    ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']

    # dual alu group
    seen = set()
    new_ins = []
    for i,tins in enumerate(ins):
      if tins in seen: continue
      if tins.startswith("v_fmac_f32"):
        for gins in reversed(ins[i+1:]):
          if gins in seen: continue
          if gins.startswith("v_fmac_f32"):
            r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
            r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
            if r0[0]%2 == r1[0]%2: continue
            if r0[1]%2 == r1[1]%2: continue
            if r0[2]%2 == r1[2]%2: continue
            new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
            seen.add(tins)
            seen.add(gins)
            break
      if tins not in seen:
        new_ins.append(tins)
    ins = new_ins
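    # editor note (not in the original file): the loop above pairs two independent v_fmac_f32
    # into one dual-issue "v_dual_fmac_f32 ... :: v_dual_fmac_f32 ..."; the %2 checks are an
    # approximation of RDNA3's rule that paired operands must come from opposite register banks.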

    return 'code', self.assemble(args, ins, v_cnt, s_cnt)

  def assemble(self, args, ins, v_cnt, s_cnt):
    kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
                   '.amdhsa_next_free_vgpr': v_cnt, # this matters!
                   '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
                   '.amdhsa_next_free_sgpr': s_cnt,
                   '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
                   '.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
                   '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
                   '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
                   '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
                   '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
                   '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}

    metadata = {'amdhsa.kernels': [{'.args': args,
                                    '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
                                    '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
                                    '.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
                                    '.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
                                    '.wavefront_size': 32}],
                'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}

    code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
    obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
    asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
    return asm
@@ -1,23 +0,0 @@
#!/usr/bin/env python3
import numpy as np
from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer

if __name__ == "__main__":
  test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
  prg = CUDAProgram("test", """
.version 7.8
.target sm_86
.address_size 64
.visible .entry test(.param .u64 x) {
.reg .b32 %r<2>;
.reg .b64 %rd<3>;

ld.param.u64 %rd1, [x];
cvta.to.global.u64 %rd2, %rd1;
mov.u32 %r1, 0x40000000; // 2.0 in float
st.global.u32 [%rd2], %r1;
ret;
}""", binary=True)
  prg([1], [1], test)
  print(test.toCPU())
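  # editor note (not in the original file): expected output is [2. 0. 0. ... 0.]; the PTX
  # above stores the single float 2.0 (bit pattern 0x40000000) at the base of the 10-float buffer.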
@@ -1,42 +0,0 @@
import numpy as np
from PIL import Image
from pathlib import Path
import sys
cwd = Path.cwd()
sys.path.append(cwd.as_posix())
sys.path.append((cwd / 'test').as_posix())
from extra.datasets import fetch_mnist
from tqdm import trange

def augment_img(X, rotate=10, px=3):
  Xaug = np.zeros_like(X)
  for i in trange(len(X)):
    im = Image.fromarray(X[i])
    im = im.rotate(np.random.randint(-rotate,rotate), resample=Image.BICUBIC)
    w, h = X.shape[1:]
    #upper left, lower left, lower right, upper right
    quad = np.random.randint(-px,px,size=(8)) + np.array([0,0,0,h,w,h,w,0])
    im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC)
    Xaug[i] = im
  return Xaug

if __name__ == "__main__":
  import matplotlib.pyplot as plt
  X_train, Y_train, X_test, Y_test = fetch_mnist()
  X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
  X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
  X = np.vstack([X_train[:1]]*10+[X_train[1:2]]*10)
  fig, a = plt.subplots(2,len(X))
  Xaug = augment_img(X)
  for i in range(len(X)):
    a[0][i].imshow(X[i], cmap='gray')
    a[1][i].imshow(Xaug[i],cmap='gray')
    a[0][i].axis('off')
    a[1][i].axis('off')
  plt.show()

  #create some nice gifs for doc?!
  for i in range(10):
    im = Image.fromarray(X_train[7353+i])
    im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
    im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0)
@@ -1,39 +0,0 @@
from typing import List, Dict, cast
import ctypes
from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.uop.ops import Variable
from tinygrad.runtime.ops_cpu import ClangProgram
from tinygrad.renderer.cstyle import ClangRenderer
render_dtype = ClangRenderer().render_dtype

class ClangGraph(GraphRunner):
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

    prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
    args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
    args += sorted([f"int {v}" for v in var_vals])
    code = ["void batched("+','.join(args)+") {"]
    for ji in jit_cache:
      args = []
      for buf in ji.bufs:
        assert buf is not None
        if buf in input_rawbuffers:
          args.append(f"arg{input_rawbuffers.index(buf)}")
        else:
          args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
      args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
      code.append(f" {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
    code.append("}")
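    # editor sketch (not in the original file): for two cached kernels the generated source is
    # roughly "void batched(float* arg0, int v0) { k0(arg0, (float*)0x...); k1((float*)0x..., v0); }",
    # i.e. one plain C call per jit item with non-input buffer addresses baked in as constants.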
    if DEBUG >= 4: print("\n".join(code))
    compiler = Device["CPU"].compiler
    assert compiler is not None
    self._prg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers

  def __call__(self, rawbufs: List[Buffer], var_vals: Dict[str, int], wait=False):
    return cpu_time_execution(
      lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0])]), enable=wait)
@@ -1,27 +0,0 @@
import ctypes
from typing import Tuple
import tinygrad.runtime.autogen.hip as hip
from tinygrad.helpers import init_c_var, time_execution_cuda_style
from tinygrad.runtime.ops_hip import check, hip_set_device
from tinygrad.runtime.graph.cuda import CUDAGraph

# TODO: this is only used in graph
def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) # noqa: E501

class HIPGraph(CUDAGraph):
  def __del__(self):
    if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph))
    if hasattr(self, 'instance'): check(hip.hipGraphExecDestroy(self.instance))
  def set_device(self): hip_set_device(self.dev)
  def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
  def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
  def graph_instantiate(self, graph):
    return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0)))
  def graph_add_kernel_node(self, graph, c_deps, c_params):
    return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params)))) # noqa: E501
  def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait)
  def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args))
  def build_kernel_node_params(self, prg, global_size, local_size, c_config):
    return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0)
  def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
    node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size
@@ -1,143 +0,0 @@
import ctypes, collections
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import init_c_var

def check(status):
  if status != 0:
    hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
    raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}")

# Precalculated AQL info
AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
EMPTY_SIGNAL = hsa.hsa_signal_t()

DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE

BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE

class AQLQueue:
  def __init__(self, device, sz=-1):
    self.device = device

    check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
    queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value

    null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
    self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
      hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))

    self.next_doorbell_index = 0
    self.queue_base = self.hw_queue.contents.base_address
    self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
    self.write_addr = self.queue_base
    self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
    self.available_packet_slots = self.hw_queue.contents.size

    check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
    check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))

  def __del__(self):
    if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))

  def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
    if self.available_packet_slots == 0: self._wait_queue()

    packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
    packet.workgroup_size_x = local_size[0]
    packet.workgroup_size_y = local_size[1]
    packet.workgroup_size_z = local_size[2]
    packet.reserved0 = 0
    packet.grid_size_x = global_size[0] * local_size[0]
    packet.grid_size_y = global_size[1] * local_size[1]
    packet.grid_size_z = global_size[2] * local_size[2]
    packet.private_segment_size = prg.private_segment_size
    packet.group_segment_size = prg.group_segment_size
    packet.kernel_object = prg.handle
    packet.kernarg_address = kernargs
    packet.reserved2 = 0
    packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
    packet.setup = DISPATCH_KERNEL_SETUP
    packet.header = DISPATCH_KERNEL_HEADER
    self._submit_packet()

  def submit_barrier(self, wait_signals=None, completion_signal=None):
    assert wait_signals is None or len(wait_signals) <= 5
    if self.available_packet_slots == 0: self._wait_queue()

    packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
    packet.reserved0 = 0
    packet.reserved1 = 0
    for i in range(5):
      packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
    packet.reserved2 = 0
    packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
    packet.header = BARRIER_HEADER
    self._submit_packet()

  def blit_packets(self, packet_addr, packet_cnt):
    if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)

    tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
    rem_packet_cnt = packet_cnt - tail_blit_packets
    ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
    if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)

    self._submit_packet(packet_cnt)

  def wait(self):
    self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
    hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE

  def _wait_queue(self, need_packets=1):
    while self.available_packet_slots < need_packets:
      rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
      self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)

  def _submit_packet(self, cnt=1):
    self.available_packet_slots -= cnt
    self.next_doorbell_index += cnt
    hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
    hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)

    self.write_addr += AQL_PACKET_SIZE * cnt
    if self.write_addr > self.write_addr_end:
      self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
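    # editor note (not in the original file): the queue is a ring buffer; once write_addr runs
    # past write_addr_end it wraps modulo queue_size back into [queue_base, queue_base+queue_size).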
|
||||
|
||||
def scan_agents():
|
||||
agents = collections.defaultdict(list)
|
||||
|
||||
@ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
|
||||
def __scan_agents(agent, data):
|
||||
status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
|
||||
if status == 0: agents[device_type.value].append(agent)
|
||||
return hsa.HSA_STATUS_SUCCESS
|
||||
|
||||
hsa.hsa_iterate_agents(__scan_agents, None)
|
||||
return agents
|
||||
|
||||
def find_memory_pool(agent, segtyp=-1, location=-1):
|
||||
@ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
|
||||
def __filter_amd_memory_pools(mem_pool, data):
|
||||
check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
|
||||
if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS
|
||||
|
||||
check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
|
||||
if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS
|
||||
|
||||
check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
|
||||
if sz.value == 0: return hsa.HSA_STATUS_SUCCESS
|
||||
|
||||
ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
|
||||
ret[0] = mem_pool
|
||||
return hsa.HSA_STATUS_INFO_BREAK
|
||||
|
||||
hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
|
||||
return region
|
||||
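The pointer bump in _submit_packet is ordinary ring-buffer arithmetic: the write address grows linearly and is folded back into [queue_base, queue_base + queue_size) once it runs past the end. A minimal standalone sketch of just that index math (the constants here are made up for illustration, not taken from the HSA runtime):

# hypothetical ring-buffer pointer math mirroring _submit_packet's wraparound
QUEUE_BASE, AQL_PACKET_SIZE, SLOTS = 0x1000, 64, 8
queue_size = AQL_PACKET_SIZE * SLOTS

def advance(write_addr, cnt=1):
  write_addr += AQL_PACKET_SIZE * cnt
  # fold back into the ring once we run past the end of the queue
  return QUEUE_BASE + (write_addr - QUEUE_BASE) % queue_size

addr = QUEUE_BASE
for _ in range(SLOTS + 2): addr = advance(addr)  # two packets past capacity
assert addr == QUEUE_BASE + 2 * AQL_PACKET_SIZE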
@@ -1,171 +0,0 @@
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Tuple
from tinygrad.helpers import init_c_var, round_up
from tinygrad.device import Buffer, BufferSpec
from tinygrad.device import Compiled, Device
from tinygrad.uop.ops import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.runtime.support.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL

def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]

class VirtAQLQueue(AQLQueue):
  def __init__(self, device, sz):
    self.device = device
    self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
    self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
    self.packets_count = 0
    self.available_packet_slots = sz
  def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
  def _submit_packet(self):
    self.write_addr += AQL_PACKET_SIZE
    self.packets_count += 1
    self.available_packet_slots -= 1

class HSAGraph(MultiGraphRunner):
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # Check that all jit items are compatible.
    compiled_devices = set()
    for ji in self.jit_cache:
      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.dev)
      elif isinstance(ji.prg, BufferXfer):
        for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
      else: raise GraphException
    if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException

    self.devices: List[HSADevice] = list(compiled_devices) #type:ignore

    # Allocate kernel args.
    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
    for ji in self.jit_cache:
      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16)
    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferSpec()) for dev,sz in kernargs_size.items()}

    # Fill initial arguments.
    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
    for j,ji in enumerate(self.jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue
      self.ji_kargs_structs[j] = ji.prg._prg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev])
      kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16)
      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
      for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i].expr])

    # Build queues.
    self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
    self.packets = {}
    self.transfers = []
    self.ji_to_transfer: Dict[int, int] = {} # it's faster to store transfers as a list and update them through this mapping table
    self.signals_to_reset: List[hsa.hsa_signal_t] = []
    self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
    self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)

    # Special packet to wait for the world.
    self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
    for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])

    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        wait_signals = self.access_resources(ji.bufs, ji.prg.p.outs, new_dependency=j, sync_with_aql_packets=False)
        for i in range(0, len(wait_signals), 5):
          self.virt_aql_queues[ji.prg.dev].submit_barrier(wait_signals[i:i+5])
        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.dev].write_addr)

        sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
        self.virt_aql_queues[ji.prg.dev].submit_kernel(ji.prg._prg, *ji.prg.p.launch_dims(var_vals), #type:ignore
                                                       ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
        if PROFILE: self.profile_info[ji.prg.dev].append((sync_signal, ji.prg._prg.name, False))
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
        sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])

        wait_signals = self.access_resources([dest, src], write=[0], new_dependency=sync_signal, sync_with_aql_packets=True)
        self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
                              (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
        self.ji_to_transfer[j] = len(self.transfers) - 1
        if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))

    # Wait for all active signals to finish the graph.
    wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
    for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
      for dev in self.signals_to_devices[v.handle]:
        wait_signals_to_finish[dev].append(v)

    self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
    for dev in self.devices:
      wait_signals = wait_signals_to_finish[dev]
      for i in range(0, max(1, len(wait_signals)), 5):
        self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)

    # Zero the signals to allow the graph to start and execute.
    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[str, int], wait=False) -> Optional[float]:
    # Wait and restore signals.
    hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))

    # Update rawbuffers.
    for (j,i),input_idx in self.input_replace.items():
      if j in self.ji_kargs_structs:
        self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
      else:
        if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
        elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src

    # Update var_vals.
    for j in self.jc_idx_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v.expr])

    # Update launch dims.
    for j in self.jc_idx_with_updatable_launch_dims:
      gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
      self.packets[j].workgroup_size_x = lc[0]
      self.packets[j].workgroup_size_y = lc[1]
      self.packets[j].workgroup_size_z = lc[2]
      self.packets[j].grid_size_x = gl[0] * lc[0]
      self.packets[j].grid_size_y = gl[1] * lc[1]
      self.packets[j].grid_size_z = gl[2] * lc[2]

    for dev in self.devices:
      dev.flush_hdp()
      dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)

    for transfer_data in self.transfers:
      check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))

    et = None
    if wait:
      st = time.perf_counter()
      hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      et = time.perf_counter() - st

    for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
    return et

  def alloc_signal(self, reset_on_start=False, wait_on=None):
    sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
    if reset_on_start: self.signals_to_reset.append(sync_signal)
    if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
    return sync_signal

  def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
    if isinstance(dep, hsa.hsa_signal_t): return dep
    elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
      if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
      return packet.completion_signal
    return None

  def access_resources(self, rawbufs, write, new_dependency, sync_with_aql_packets=False):
    rdeps = self._access_resources(rawbufs, write, new_dependency)
    wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
    if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in rawbufs]
    return dedup_signals(wait_signals)
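An HSA barrier-AND packet carries at most five dep_signal slots, which is why the graph above chunks its wait lists with range(0, len(wait_signals), 5) and emits one barrier per chunk. A tiny illustration of that chunking (the handles are stand-in integers, not real hsa_signal_t values):

# hypothetical: split a dependency list into barrier-sized groups of five
wait_signals = list(range(12))
barriers = [wait_signals[i:i+5] for i in range(0, len(wait_signals), 5)]
assert barriers == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]  # 3 barrier packets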
@@ -1,275 +0,0 @@
from __future__ import annotations
import ctypes, functools, subprocess, io, atexit, collections, json
from typing import Tuple, TypeVar, List, Dict, Any
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv, PROFILE
from tinygrad.device import Compiled, Compiler, CompileError, BufferSpec, LRUAllocator
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.support.hsa import check, scan_agents, find_memory_pool, AQLQueue
from tinygrad.runtime.support.hip_comgr import compile_hip
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl  # noqa: F401

class HSAProfiler:
  def __init__(self):
    self.tracked_signals = collections.defaultdict(list)
    self.collected_events: List[Tuple[Any, ...]] = []
    self.copy_timings = hsa.hsa_amd_profiling_async_copy_time_t()
    self.disp_timings = hsa.hsa_amd_profiling_dispatch_time_t()

  def track(self, signal, device, name, is_copy=False): self.tracked_signals[device].append((signal, name, is_copy))
  def process(self, device):
    # Process all tracked signals; this must be called before any of the tracked signals are reused.
    for sig,name,is_copy in self.tracked_signals[device]:
      if is_copy: check(hsa.hsa_amd_profiling_get_async_copy_time(sig, ctypes.byref(timings := self.copy_timings)))
      else: check(hsa.hsa_amd_profiling_get_dispatch_time(device.agent, sig, ctypes.byref(timings := self.disp_timings))) #type:ignore
      self.collected_events.append((device.device_id, 1 if is_copy else 0, name, timings.start, timings.end))
    self.tracked_signals.pop(device)

  def save(self, path):
    mjson = []
    for i in range(len(HSADevice.devices)):
      mjson.append({"name": "process_name", "ph": "M", "pid": i, "args": {"name": "HSA"}})
      mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 0, "args": {"name": "AQL"}})
      mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 1, "args": {"name": "SDMA"}})

    for dev_id,queue_id,name,st,et in self.collected_events:
      mjson.append({"name": name, "ph": "B", "pid": dev_id, "tid": queue_id, "ts": st*1e-3})
      mjson.append({"name": name, "ph": "E", "pid": dev_id, "tid": queue_id, "ts": et*1e-3})
    with open(path, "w") as f: f.write(json.dumps({"traceEvents": mjson}))
    print(f"Saved HSA profile to {path}")
Profiler = HSAProfiler()

class HSACompiler(Compiler):
  def __init__(self, arch:str):
    self.arch = arch
    super().__init__(f"compile_hip_{self.arch}")
  def compile(self, src:str) -> bytes:
    try: return compile_hip(src, self.arch)
    except RuntimeError as e: raise CompileError(e)

class HSAProgram:
  def __init__(self, device:HSADevice, name:str, lib:bytes):
    self.device, self.name, self.lib = device, name, lib

    if DEBUG >= 6:
      asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
      print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))

    self.exec = init_c_var(hsa.hsa_executable_t(), lambda x: check(hsa.hsa_executable_create_alt(hsa.HSA_PROFILE_FULL, hsa.HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, None, ctypes.byref(x)))) # noqa: E501
    self.code_reader = init_c_var(hsa.hsa_code_object_reader_t(),
                                  lambda x: check(hsa.hsa_code_object_reader_create_from_memory(lib, len(lib), ctypes.byref(x))))
    check(hsa.hsa_executable_load_agent_code_object(self.exec, self.device.agent, self.code_reader, None, None))
    check(hsa.hsa_executable_freeze(self.exec, None))

    self.kernel = init_c_var(hsa.hsa_executable_symbol_t(), lambda x: check(hsa.hsa_executable_get_symbol_by_name(self.exec, (name+".kd").encode("utf-8"), ctypes.byref(self.device.agent), ctypes.byref(x)))) # noqa: E501
    self.handle = init_c_var(ctypes.c_uint64(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, ctypes.byref(x)))) # noqa: E501
    self.kernargs_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
    self.group_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
    self.private_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501

  def __del__(self):
    self.device.synchronize()
    if hasattr(self, 'code_reader'): check(hsa.hsa_code_object_reader_destroy(self.code_reader))
    if hasattr(self, 'exec'): check(hsa.hsa_executable_destroy(self.exec))

  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    if not hasattr(self, "args_struct_t"):
      self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
                                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
      if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size:
        raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")

    kernargs = None
    if self.kernargs_segment_size > 0:
      kernargs = self.device.alloc_kernargs(self.kernargs_segment_size)
      args_st = self.args_struct_t.from_address(kernargs)
      for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i])
      for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])
      self.device.flush_hdp()

    signal = self.device.alloc_signal(reusable=True) if wait or PROFILE else None
    self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, completion_signal=signal)
    if PROFILE: Profiler.track(signal, self.device, self.name)
    if wait:
      hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t())))
      return (timings.end - timings.start) * self.device.clocks_to_time

T = TypeVar("T")
CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000
class HSAAllocator(LRUAllocator):
  def __init__(self, device:HSADevice):
    self.device = device
    super().__init__()

  def _alloc(self, size:int, options:BufferSpec):
    if options.host:
      check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p())))
      check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem))
      return mem.value
    c_agents = (hsa.hsa_agent_t * len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]))(*HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU])
    check(hsa.hsa_amd_memory_pool_allocate(self.device.gpu_mempool, size, 0, ctypes.byref(buf := ctypes.c_void_p())))
    check(hsa.hsa_amd_agents_allow_access(len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]), c_agents, None, buf))
    return buf.value

  def _free(self, opaque:T, options:BufferSpec):
    HSADevice.synchronize_system()
    check(hsa.hsa_amd_memory_pool_free(opaque))

  def _copyin(self, dest:T, src: memoryview):
    # The async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
    self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
    mem = self._alloc(src.nbytes, BufferSpec(host=True))
    ctypes.memmove(mem, from_mv(src), src.nbytes)
    check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal),
                                                  copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True))
    self.device.hw_queue.submit_barrier([copy_signal])
    self.device.delayed_free.append(mem)
    if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True)

  def copy_from_fd(self, dest, fd, offset, size):
    self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))

    if not hasattr(self, 'hb'):
      self.hb = [self._alloc(CHUNK_SIZE, BufferSpec(host=True)) for _ in range(2)]
      self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)]
      self.hb_polarity = 0
      self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1]
      for sig in self.hb_signals: hsa.hsa_signal_store_relaxed(sig, 0)

    fo = io.FileIO(fd, "a+b", closefd=False)
    fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))

    copies_called = 0
    copied_in = 0
    for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
      local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
      copy_size = min(local_size-minor_offset, size-copied_in)
      if copy_size == 0: break

      hsa.hsa_signal_wait_scacquire(self.hb_signals[self.hb_polarity], hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      self.device.reusable_signals.append(self.hb_signals[self.hb_polarity]) # it's free now and can be reused
      self.hb_signals[self.hb_polarity] = self.device.alloc_signal(reusable=False)

      fo.readinto(to_mv(self.hb[self.hb_polarity], local_size))
      check(hsa.hsa_amd_memory_async_copy_on_engine(dest+copied_in, self.device.agent, self.hb[self.hb_polarity]+minor_offset, HSADevice.cpu_agent,
                                                    copy_size, 1, ctypes.byref(sync_signal), self.hb_signals[self.hb_polarity],
                                                    self.sdma[self.hb_polarity], True))
      copied_in += copy_size
      self.hb_polarity = (self.hb_polarity + 1) % len(self.hb)
      minor_offset = 0 # only the first chunk can start page-misaligned
      copies_called += 1

    wait_signals = [self.hb_signals[self.hb_polarity - 1]]
    if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
    self.device.hw_queue.submit_barrier(wait_signals)

  def _copyout(self, dest:memoryview, src:T):
    HSADevice.synchronize_system()
    copy_signal = self.device.alloc_signal(reusable=True)
    c_agents = (hsa.hsa_agent_t*2)(self.device.agent, HSADevice.cpu_agent)
    check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p())))
    check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal))
    hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    check(hsa.hsa_amd_memory_unlock(from_mv(dest)))
    if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True)

  def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
    src_dev.hw_queue.submit_barrier([], sync_signal_1 := src_dev.alloc_signal(reusable=True))
    dest_dev.hw_queue.submit_barrier([], sync_signal_2 := dest_dev.alloc_signal(reusable=True))
    c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2)
    check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal,
                                                  copy_signal := dest_dev.alloc_signal(reusable=False), hsa.HSA_AMD_SDMA_ENGINE_0, True))
    src_dev.hw_queue.submit_barrier([copy_signal])
    dest_dev.hw_queue.submit_barrier([copy_signal])
    if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True)

class HSADevice(Compiled):
  devices: List[HSADevice] = []
  agents: Dict[int, List[hsa.hsa_agent_t]] = {}
  cpu_agent: hsa.hsa_agent_t
  cpu_mempool: hsa.hsa_amd_memory_pool_t
  def __init__(self, device:str=""):
    if not HSADevice.agents:
      check(hsa.hsa_init())
      atexit.register(hsa_terminate)
      HSADevice.agents = scan_agents()
      HSADevice.cpu_agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_CPU][0]
      HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU)
      if PROFILE: check(hsa.hsa_amd_profiling_async_copy_enable(1))

    self.device_id = int(device.split(":")[1]) if ":" in device else 0
    self.agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU][self.device_id]
    self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)
    self.hw_queue = AQLQueue(self)
    HSADevice.devices.append(self)

    check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256))))
    self.arch = ctypes.string_at(agent_name_buf).decode()

    check(hsa.hsa_system_get_info(hsa.HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, ctypes.byref(gpu_freq := ctypes.c_uint64())))
    self.clocks_to_time: float = 1 / gpu_freq.value

    check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AMD_AGENT_INFO_HDP_FLUSH, ctypes.byref(hdp_flush := hsa.hsa_amd_hdp_flush_t())))
    self.hdp_flush = hdp_flush

    self.delayed_free: List[int] = []
    self.reusable_signals: List[hsa.hsa_signal_t] = []

    from tinygrad.runtime.graph.hsa import HSAGraph
    super().__init__(device, HSAAllocator(self), HIPRenderer(), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)

    # Finish init: preallocate some signals + space for kernargs.
    self.signal_pool = [init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x)))) for _ in range(4096)]
    self._new_kernargs_region(16 << 20) # initial region size is 16MB

  def synchronize(self):
    self.hw_queue.wait()

    for sig in self.reusable_signals: hsa.hsa_signal_silent_store_relaxed(sig, 1)
    self.signal_pool.extend(self.reusable_signals)
    self.reusable_signals.clear()

    for opaque_to_free in self.delayed_free: check(hsa.hsa_amd_memory_pool_free(opaque_to_free))
    self.delayed_free.clear()

    self.kernarg_next_addr = self.kernarg_start_addr
    Profiler.process(self)

  @staticmethod
  def synchronize_system():
    for d in HSADevice.devices: d.synchronize()

  def alloc_signal(self, reusable=False):
    if len(self.signal_pool): signal = self.signal_pool.pop()
    else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))

    # A reusable signal can be handed out again after synchronize() is called on the device it was allocated from.
    if reusable: self.reusable_signals.append(signal)
    return signal

  def alloc_kernargs(self, sz):
    if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz: self._new_kernargs_region(int(self.kernarg_pool_sz * 2))
    result = self.kernarg_next_addr
    self.kernarg_next_addr = round_up(self.kernarg_next_addr + sz, 16)
    return result

  def _new_kernargs_region(self, sz:int):
    if hasattr(self, 'kernarg_start_addr'): self.delayed_free.append(self.kernarg_start_addr)
    self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferSpec())
    self.kernarg_next_addr = self.kernarg_start_addr
    self.kernarg_pool_sz: int = sz

  def flush_hdp(self): self.hdp_flush.HDP_MEM_FLUSH_CNTL[0] = 1

def hsa_terminate():
  # The AQL queues need to be stopped/deleted before HSA shuts down; not doing so leads to GPU hangs.
  for dev in HSADevice.devices:
    Profiler.process(dev)
    del dev.hw_queue

  # hsa_shut_down cleans up all remaining hsa-related resources.
  hsa.hsa_shut_down()
  HSADevice.synchronize = lambda: None #type:ignore
  HSAProgram.__del__ = lambda _: None #type:ignore
  if Profiler.collected_events: Profiler.save("/tmp/profile.json")
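HSAProfiler.save writes the Chrome trace-event JSON format, so the resulting file should load in chrome://tracing or Perfetto. A minimal sketch of the same {"traceEvents": [...]} shape it emits (the event name and path here are invented):

import json
# one "B"/"E" begin/end pair per event; "ts" is in microseconds, matching the st*1e-3 above
events = [
  {"name": "matmul", "ph": "B", "pid": 0, "tid": 0, "ts": 0.0},
  {"name": "matmul", "ph": "E", "pid": 0, "tid": 0, "ts": 125.0},
]
with open("/tmp/mini_profile.json", "w") as f: json.dump({"traceEvents": events}, f)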
@@ -1,127 +0,0 @@
from typing import Dict, Set
import yaml
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
from tinygrad.uop.ops import BinaryOps
from tinygrad.dtype import dtypes

def uops_to_rdna(function_name:str, uops:UOpGraph) -> str:
  replace: Dict[UOp, UOp] = {}
  seen: Set[UOp] = set()
  for u in uops:
    if u in seen: continue
    seen.add(u)
    for o,n in replace.items():
      if o in u.vin and u is not n:
        u.vin = tuple(n if x == o else x for x in u.vin)
    # pointer indexing: scale the index by the element size
    if u.uop in {UOps.LOAD, UOps.STORE} and u.vin[0].dtype.itemsize > 1:
      val = UOp(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_at=uops.uops.index(u))
      ptr = UOp(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_at=uops.uops.index(u))
      u.vin = (u.vin[0], ptr) + u.vin[2:]
  #uops.print()

  args = []
  ins = []

  v_cnt = 3  # v[0:2] is local_xyz
  s_cnt = 5  # s[0:1] is the kernarg address, s[2:4] is global_xyz

  r: Dict[UOp, str] = {}
  for u in uops:
    if u.uop == UOps.SPECIAL:
      if u.arg.startswith("lidx"):
        r[u] = f'v{u.src[0].arg}'
      elif u.arg.startswith("gidx"):
        r[u] = f's{2+u.src[0].arg}'
      else:
        raise NotImplementedError
    elif u.uop == UOps.CONST:
      #r[u] = u.arg

      # TODO: sometimes we can use an s register
      #r[u] = f"s{s_cnt}"
      #s_cnt += 1
      #ins.append(f"s_mov_b32 {r[u]}, {u.arg}")

      r[u] = f"v{v_cnt}"
      v_cnt += 1
      ins.append(f"v_mov_b32 {r[u]}, {u.arg}")
    elif u.uop == UOps.ALU:
      if u.arg == BinaryOps.ADD:
        r[u] = f"v{v_cnt}"
        v_cnt += 1
        ins.append(f"v_add_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
      elif u.arg == BinaryOps.MUL:
        r[u] = f"v{v_cnt}"
        v_cnt += 1
        if dtypes.is_float(u.dtype):
          ins.append(f"v_mul_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
        else:
          ins.append(f"v_mul_u32_u24 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
      else:
        raise NotImplementedError
    elif u.uop == UOps.LOAD:
      r[u] = f"v{v_cnt}"
      v_cnt += 1
      ins.append(f"global_load_b32 {r[u]}, {r[u.vin[1]]}, {r[u.vin[0]]}")
      ins.append("s_waitcnt vmcnt(0)")
    elif u.uop == UOps.STORE:
      ins.append(f"global_store_b32 {r[u.vin[1]]}, {r[u.vin[2]]}, {r[u.vin[0]]}")
    elif u.uop == UOps.DEFINE_GLOBAL:
      i = u.arg[0]
      args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8,
                   '.type_name': u.dtype.name+"*", '.value_kind': 'global_buffer'})
      s_cnt += s_cnt%2  # s register pairs must start even-aligned, skip one if needed
      r[u] = f"s[{s_cnt}:{s_cnt+1}]"
      s_cnt += 2
      ins.append(f"s_load_b64 {r[u]}, s[0:1], {i*8}")
      ins.append("s_waitcnt lgkmcnt(0)")
    else:
      raise NotImplementedError(f"can't render {u.uop}")

  # *** boilerplate rendering ***

  metadata = {
    'amdhsa.kernels': [{'.args': args,
      '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
      '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
      '.name': function_name, '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
      '.symbol': f'{function_name}.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
      '.wavefront_size': 32}],
    'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}

  boilerplate_start = f"""
.rodata
.global {function_name}.kd
.type {function_name}.kd,STT_OBJECT
.align 0x10
.amdhsa_kernel {function_name}"""

  kernel_desc = {
    '.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
    '.amdhsa_next_free_vgpr': v_cnt,  # this matters!
    '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
    '.amdhsa_next_free_sgpr': s_cnt,
    '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3,
    '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1, '.amdhsa_fp16_overflow': 0,
    '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
    '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
    '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2,  # is amdhsa_system_vgpr_workitem_id real?
    '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0,
    '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
    '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0,
    '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
    '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}

  code_start = f""".end_amdhsa_kernel
.text
.global {function_name}
.type {function_name},@function
.p2align 8
{function_name}:
"""

  ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
  return ".amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata" + \
         boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + \
         '\n'.join(ins) + f"\n.size {function_name}, .-{function_name}"
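The closing return assembles the kernel descriptor by joining each kernel_desc entry as a "directive value" line between .amdhsa_kernel and .end_amdhsa_kernel. That rendering step in isolation, with a hypothetical two-entry descriptor:

kernel_desc = {'.amdhsa_next_free_vgpr': 4, '.amdhsa_next_free_sgpr': 7}
print('\n'.join("%s %d" % x for x in kernel_desc.items()))
# .amdhsa_next_free_vgpr 4
# .amdhsa_next_free_sgpr 7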
@@ -1,131 +0,0 @@
from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.uop.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.opt.kernel import UOp, Ops
from triton.compiler import compile as triton_compile
import linecache
import math
import re

triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
signature_dtypes = {dtypes.double: "fp64", dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i8", dtypes.int8: "i1", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"}

def next_power_of_2(x):
  return 1 << (x - 1).bit_length()

def render_valid(valid):
  return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True'

# NOTE: Triton requires matching dimensions for load/store; disable this and watch TestOps::test_output_padded_conv_transpose2d fail to compile
def fill_dims_for_idx(idx, dims):
  return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx

def get_max(var):
  if isinstance(var, int): return var
  return re.sub(r'\[(.*?)\]', '', str(var))[1:-1]

# NOTE: can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved
def remove_single_scalar_curly_braces(ptx_code):
  return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])

def render_const(args,dtype:DType):
  return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))

def render_cast(x:str, dtype:DType, bitcast=False):
  return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"

def define_scalar(local_size, dtype, args):
  if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
  return render_const(args,dtype)

def uops_to_triton(function_name:str, uops:List[UOp]):
  local_size: List[int] = []
  depth = 1
  signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore

  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, str] = {}
  def ssa(u, prefix="t"):
    nonlocal c, r
    c[prefix] += 1
    r[u]=f"{prefix}{c[prefix]-1}"
    return r[u]

  child_count: DefaultDict[UOp, int] = defaultdict(int)
  for ru in uops:
    for v in ru.vin:
      child_count[v] += 1

  def kk(s): kernel.append(" "*depth+s)
  code_for_op: Final[Dict[Op, Callable]] = {
    UnaryOps.EXP2: lambda x,dtype,: f"tl.math.exp2({x})",
    UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
    UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
    UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
    UnaryOps.NEG: lambda x,dtype: f"-{x}",
    BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
    BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
    BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
    BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
    BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
    TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
    TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
  }
  def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop == Ops.LOOP:
      kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
      depth += 1
    elif uop == Ops.END: depth -= 1
    elif uop == Ops.ALU:
      assert dtype is not None
      val = code_for_op[args](*[r[x] for x in vin])
      if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
      else: kk(f"{ssa(u, 'alu')} = ({val})")
    elif uop == Ops.LOAD:
      assert dtype is not None
      if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
      else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
    elif uop == Ops.DEFINE_REG: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
    elif uop == Ops.CONST: r[u] = define_scalar([], dtype, args)
    elif uop == Ops.ASSIGN:
      kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
      r[u] = r[vin[0]]
    elif uop == Ops.STORE:
      assert not isinstance(dtype, ImageDType), "unimplemented: image store"
      kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
    elif uop == Ops.DEFINE_GLOBAL:
      bufs.append(args)
      signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
      r[u] = args
    elif uop == Ops.SPECIAL:
      dims.append(args[1])
      valid.append(f"{args[1]}<{get_max(args[2])}")
      if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
      elif args[1].startswith("l"):
        kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
        local_size.append(args[2])
      r[u] = args[1]
    elif uop == Ops.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
    else: raise NotImplementedError(f"unimplemented: {uop}")

  prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
  for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
  prg += "\n".join(kernel)

  acc_local_size = 1
  for x in local_size: acc_local_size *= next_power_of_2(x)
  local_size = [acc_local_size] + [1] * (len(local_size) - 1)

  if DEBUG >= 4: print(prg)
  getlines = linecache.getlines
  linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
  exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122
  compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
  prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
  max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
  for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])

  return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}
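next_power_of_2 exists because Triton's tl.arange wants power-of-two extents; rounding up via the bit length of x - 1 maps a value to itself when it is already a power of two. A quick check of that behavior:

def next_power_of_2(x): return 1 << (x - 1).bit_length()
assert [next_power_of_2(x) for x in [1, 2, 3, 5, 8, 9]] == [1, 2, 4, 8, 8, 16]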
@@ -1,22 +0,0 @@
import ctypes
import os
import pathlib
import struct
from hexdump import hexdump

fxn = None
def disasm_raw(buf):
  global fxn
  if fxn is None:
    shared = pathlib.Path(__file__).parent / "disasm.so"
    if not shared.is_file():
      os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so')
    fxn = ctypes.CDLL(shared.as_posix())['disasm']
  fxn(buf, len(buf))

def disasm(buf):
  def _read_lib(off): return struct.unpack("I", buf[off:off+4])[0]

  image_offset = _read_lib(0xc0)
  image_size = _read_lib(0x100)
  disasm_raw(buf[image_offset:image_offset+image_size])
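_read_lib is just "unpack an unsigned 32-bit word at a fixed offset" over the raw library bytes. The same pattern standalone (the blob below is invented for illustration):

import struct
blob = bytes([0x10, 0x00, 0x00, 0x00])  # hypothetical 4-byte field
(value,) = struct.unpack("I", blob)     # native-order u32, little-endian on these targets
assert value == 0x10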
@@ -1,120 +0,0 @@
#!/usr/bin/env python3
import os, ctypes, ctypes.util, io, mmap, pathlib
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import Timing, from_mv
libc = ctypes.CDLL(ctypes.util.find_library("c"))

#from extra.hip_gpu_driver import hip_ioctl

# sudo su -c "echo 3 > /proc/sys/vm/drop_caches"

# sudo su -c 'echo 8 > /proc/sys/kernel/printk'
# sudo su -c "echo 'module amdgpu +p' > /sys/kernel/debug/dynamic_debug/control"

libc.memcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]

libc.read.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t]
libc.read.restype = ctypes.c_size_t

libc.malloc.argtypes = [ctypes.c_size_t]
libc.malloc.restype = ctypes.c_void_p

def read_direct(fd, sz):
  with Timing("mmap: ", lambda x: f", {sz/x:.2f} GB/s"):
    buf = mmap.mmap(-1, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE)
  with Timing("read: ", lambda x: f", {sz/x:.2f} GB/s"):
    ret = libc.read(fd, from_mv(buf), sz)
    assert ret == sz

def read_mmap(fd, sz):
  with Timing("mmfd: ", lambda x: f", {sz/x:.2f} GB/s"):
    buf = mmap.mmap(fd, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE) #|MAP_LOCKED)
    t = 0
    for i in range(0, sz, 0x1000): t += buf[i]

# def _copyin_async(self, dest:T, src:T, size:int): check(hip.hipMemcpyAsync(dest, src, size, hip.hipMemcpyHostToDevice, None))

def read_to_gpu_mmap(fd, sz, gpubuf):
  with Timing("gpu copyin: ", lambda x: f", {sz/x:.2f} GB/s"):
    with Timing("mmfd: ", lambda x: f", {sz/x:.2f} GB/s"):
      buf = mmap.mmap(fd, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE) #|MAP_LOCKED)
    dev.allocator._copyin_async(gpubuf, from_mv(buf), sz)
    dev.synchronize()

def read_to_gpu_single(fd, sz, gpubuf):
  os.lseek(fd, 0, os.SEEK_SET)
  with Timing("total: ", lambda x: f", {sz/x:.2f} GB/s"):
    with Timing("gpu host alloc: ", lambda x: f", {sz/x:.2f} GB/s"):
      hst = dev.allocator._hostalloc(sz)
    with Timing("read to host: ", lambda x: f", {sz/x:.2f} GB/s"):
      ret = libc.read(fd, hst, sz)
    with Timing("gpu host copy: ", lambda x: f", {sz/x:.2f} GB/s"):
      dev.allocator._copyin_async(gpubuf, hst, sz)
      dev.synchronize()

def read_to_gpu_pingpong(fd, sz, gpubuf):
  psz = 256*1024*1024
  print(f"piece size {psz/(1024*1024):.2f} MB")
  with Timing("gpu host alloc: ", lambda x: f", {sz/x:.2f} GB/s"):
    hst1 = dev.allocator._hostalloc(psz)
    hst2 = dev.allocator._hostalloc(psz)

  os.lseek(fd, 0, os.SEEK_SET)
  with Timing("total: ", lambda x: f", {sz/x:.2f} GB/s"):
    for i in range(sz//(psz*2)):
      with Timing("tfer(0): ", lambda x: f", {psz/x:.2f} GB/s"):
        ret = libc.read(fd, hst1, psz)
        dev.synchronize()
        dev.allocator._copyin_async(gpubuf, hst1, psz)
      with Timing("tfer(1): ", lambda x: f", {psz/x:.2f} GB/s"):
        ret = libc.read(fd, hst2, psz)
        dev.synchronize()
        dev.allocator._copyin_async(gpubuf, hst2, psz)
    dev.synchronize()

MAP_LOCKED = 0x2000
MAP_HUGETLB = 0x40000

if __name__ == "__main__":
  dev = Device[Device.DEFAULT]

  warm = (Tensor.ones(1024, device=Device.DEFAULT).contiguous() + Tensor.ones(1024, device=Device.DEFAULT).contiguous()).realize()
  #fn = "/home/tiny/tinygrad/weights/rng"
  fn = pathlib.Path(__file__).parents[1] / "weights/LLaMA-2/70B/consolidated.00.pth"
  sz = os.stat(fn).st_size
  t = Tensor.empty(sz, dtype=dtypes.uint8, device=f"disk:{fn}")
  with Timing("copy: ", lambda x: f", {sz/x:.2f} GB/s"):
    on_dev = t.to(Device.DEFAULT).realize()

  exit(0)

  # 4GB of random numbers
  #fd = os.open("/home/tiny/tinygrad/weights/rng", os.O_RDWR|os.O_DIRECT)
  #sz = os.fstat(fd).st_size // 4
  fd = os.open("/home/tiny/tinygrad/weights/LLaMA/7B/consolidated.00.pth", os.O_RDWR|os.O_DIRECT)
  sz = os.fstat(fd).st_size
  print(f"read {sz} from {fd}")

  with Timing("gpu alloc: ", lambda x: f", {sz/x:.2f} GB/s"):
    gpubuf = dev.allocator._alloc(sz)
  # warmup
  dev.allocator._copyin_async(gpubuf, from_mv(bytearray(b"\x00\x00\x00\x00"*0x1000)), 0x4000)
  print("copying, is warm")

  print("****** read to gpu pingpong")
  read_to_gpu_pingpong(fd, sz, gpubuf)
  exit(0)

  print("****** read direct")
  read_direct(fd, sz)

  print("****** read mmap")
  read_mmap(fd, sz)

  print("****** read to gpu single")
  read_to_gpu_single(fd, sz, gpubuf)

  print("****** read to gpu mmap")
  read_to_gpu_mmap(fd, sz, gpubuf)

  os._exit(0)
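The lambda x: f", {sz/x:.2f} GB/s" callbacks work because tinygrad's Timing passes the elapsed time in nanoseconds, and bytes per nanosecond is numerically gigabytes per second. The unit identity, checked directly:

# 1e9 bytes in 0.5e9 ns -> 2.0 bytes/ns, i.e. 2.0 GB/s
nbytes, elapsed_ns = 1_000_000_000, 500_000_000
assert f"{nbytes/elapsed_ns:.2f} GB/s" == "2.00 GB/s"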
@@ -1,21 +0,0 @@
import sys, sqlite3, pickle
from tinygrad.helpers import CACHEDB

if __name__ == "__main__":
  fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB
  conn = sqlite3.connect(fn)
  cur = conn.cursor()
  cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
  for f in cur.fetchall():
    table = f[0]
    cur2 = conn.cursor()
    cur2.execute(f"SELECT COUNT(*) FROM {table}")
    cnt = cur2.fetchone()[0]
    print(f"{table:20s} : {cnt}")

    cur3 = conn.cursor()
    cur3.execute(f"SELECT * FROM {table} LIMIT 10")
    for f in cur3.fetchall():
      v = pickle.loads(f[-1])
      print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
      #print(f"{len(k):10d}, {sk} -> {v}")
@@ -1,27 +0,0 @@
#!/usr/bin/env python3
import time
import jax
import jax.numpy as jnp

print(jax.devices())
DEVICES = len(jax.devices())
BS = 32
N = 4096
dtype = jnp.float16
A = jnp.zeros((DEVICES, BS, N, N), dtype)
B = jnp.zeros((1, 1, N, N), dtype)
A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())

OPS = DEVICES*BS*N*N*N*2
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
pmatmul = jax.pmap(matmul)

MAX_TFLOPS = 123*DEVICES # peak FP16 tensor TFLOPS with FP32 accumulate (7900XTX)
for i in range(10):
  st = time.perf_counter()
  C = pmatmul(A,B).block_until_ready()
  et = time.perf_counter()-st
  tflops = (OPS*1e-12)/et
  print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")
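OPS follows the usual matmul accounting: an (M,K)x(K,N) product is M*N*K multiply-adds, i.e. 2*M*N*K FLOPs, scaled by the batch and device counts. Worked through for the constants above on a single device:

DEVICES, BS, N = 1, 32, 4096
OPS = DEVICES*BS*N*N*N*2
assert OPS == 4_398_046_511_104  # ~4.4 TFLOP per call, so a 10 ms run would be ~440 TFLOPS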
@@ -1,10 +0,0 @@
import mlx.core as mx
from tinygrad.helpers import Timing
N = 4096
x = mx.random.normal((N,N))
w = mx.random.normal((N,N))

FLOPS = N*N*N*2
for i in range(10):
  with Timing("", lambda x: f" {FLOPS/x:.2f} GFLOPS"):
    mx.eval(x@w)
@@ -1,33 +0,0 @@
import time
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs.
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before the GPUs have been initialized.
    print(e)

for dtype in [tf.float16, tf.float32]:
  for N in [256, 512, 1024, 2048, 4096, 8192]:
    FLOPS = N*N*N*2

    b = tf.random.uniform((N, N), dtype=dtype)
    c = tf.random.uniform((N, N), dtype=dtype)

    b = tf.Variable(b)
    c = tf.Variable(c)

    def tf_prog(b, c):
      st = time.perf_counter()
      a = tf.matmul(b, c)
      tf.debugging.check_numerics(a, "Nan or Inf in result") # forces the computation to actually finish before the clock stops
      return time.perf_counter() - st

    tm = min([tf_prog(b, c) for _ in range(20)])
    print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
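Timing a lazy or asynchronous framework needs a forced materialization before the clock stops; check_numerics plays the role here that block_until_ready() plays in the JAX script above. The same idea against a generic async API, sketched with a thread pool:

import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(1)
st = time.perf_counter()
fut = pool.submit(lambda: sum(range(10_000_000)))  # returns immediately; the work is async
fut.result()  # the sync point: block until the value actually exists
print(f"timed with sync: {time.perf_counter()-st:.3f} s")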
@@ -1,12 +0,0 @@
import ctypes
import tinygrad.runtime.autogen.hip as hip
from tinygrad.runtime.ops_hip import check
from tinygrad.helpers import init_c_var

if __name__ == "__main__":
  check(hip.hipSetDevice(0))
  evt = init_c_var(hip.hipEvent_t(), lambda x: check(hip.hipEventCreate(ctypes.byref(x))))
  check(hip.hipSetDevice(1))
  check(hip.hipStreamWaitEvent(None, evt, 0))
  check(hip.hipSetDevice(0))
  check(hip.hipEventRecord(evt, None))
@@ -1,45 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece_model.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
  _globals['DESCRIPTOR']._options = None
  _globals['DESCRIPTOR']._serialized_options = b'H\003'
  _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None
  _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
  _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None
  _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
  _globals['_TRAINERSPEC']._serialized_start=45
  _globals['_TRAINERSPEC']._serialized_end=1581
  _globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517
  _globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570
  _globals['_NORMALIZERSPEC']._serialized_start=1584
  _globals['_NORMALIZERSPEC']._serialized_end=1793
  _globals['_SELFTESTDATA']._serialized_start=1795
  _globals['_SELFTESTDATA']._serialized_end=1916
  _globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864
  _globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905
  _globals['_MODELPROTO']._serialized_start=1919
  _globals['_MODELPROTO']._serialized_end=2429
  _globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208
  _globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418
  _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323
  _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407
# @@protoc_insertion_point(module_scope)
@@ -1,176 +0,0 @@
from __future__ import annotations
from typing import List, Optional, Dict, cast
import numpy as np
np.set_printoptions(suppress=True)
import math, functools, time, random, statistics
from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling
from tinygrad.codegen.opt.kernel import Kernel
from tinygrad.device import Buffer, Device, CompileError
from tinygrad.codegen.opt.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
from tinygrad.engine.realize import get_program

class MCTSNode:
  def __init__(self, kernel:Kernel, parent=None):
    self.kernel:Kernel = kernel
    self.t = math.inf
    self.n = 0
    self.tm = math.inf
    self.i = -1
    self.parents: List[MCTSNode] = [parent] if parent is not None else []
    self.children: Optional[List[MCTSNode]] = None
    self.removed_children: List[MCTSNode] = []

def expand_node(node:MCTSNode):
  assert node.children is None
  node.children = [MCTSNode(x, node) for x in get_kernel_actions(node.kernel, include_0=False).values()]

def remove_node(node:MCTSNode):
  for parent in node.parents:
    assert parent.children is not None
    parent.children.remove(node)
    parent.removed_children.append(node)

C = math.sqrt(2)
TEMP = 0.5
def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
  if node.children is None or len(node.children) == 0: return node
  unexplored_children = []
  explored_children = []
  ucb_explored_children: List[float] = []
  for child in node.children:
    if child.n == 0: unexplored_children.append(child)
    else:
      ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
      if not math.isinf(ucb):
        explored_children.append(child)
        ucb_explored_children.append(ucb)
  if len(unexplored_children): return random.choice(unexplored_children)
  if not len(explored_children): return node
  # safe softmax
  ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP)
  return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)

# this will expand/remove sometimes
def sample_tree(root:MCTSNode, best_tm:float) -> Optional[MCTSNode]:
  if root.children is None: expand_node(root)
  while root.children:
    # tree traversal
    node = _sample_tree(root, best_tm)

    if node.children is not None and len(node.children) == 0:
      remove_node(node)
      continue

    # node expansion
    if node.n != 0:
      if node.children is None: expand_node(node)
      assert node.children is not None
      if len(node.children) == 0:
        remove_node(node)
        continue
      node = random.choice(node.children)
    return node
  return None

def backprop(bnode:MCTSNode, tm, strength=1.0):
  if bnode.t > tm: bnode.t = tm
  bnode.n += strength
  for parent in bnode.parents: backprop(parent, tm, strength/len(bnode.parents))

graph_mcts_cnt = 0
def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
  global graph_mcts_cnt
  # TODO: copied from BEAM
  key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not getenv("IGNORE_MCTS_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  rawbufs = _ensure_buffer_alloc(rawbufs)
  var_vals = {k.expr:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  dev = Device[lin.opts.device]
  root = MCTSNode(lin)

  st = time.perf_counter()
  best, best_idx, best_tm = lin, 0, math.inf
  seen_libs: Dict[bytes, MCTSNode] = {}
  seen_asts: Dict[bytes, MCTSNode] = {}
  compile_time, runtime_time = 0.0, 0.0
  for i in range(amt):
    node = sample_tree(root, best_tm) # sample and expand
    if node is None: break # finished the whole tree
    node.i = i # when was node explored

    opt_ast = node.kernel.get_optimized_ast()
    if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None:
      # early check for same optimized AST hit
      remove_node(node)
      tm = sibling_node.t
    else:
      seen_asts[opt_ast.key] = node

      # lowering (50% of the time)
      p = get_program(node.kernel.get_optimized_ast(name_override="test"), node.kernel.opts)

      # rollout
      tm1 = time.perf_counter()
      try:
        lib = dev.compiler.compile(p.src)
      except CompileError:
        # NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
        lib = None
      tm2 = time.perf_counter()
      if lib is None:
        tm = math.inf
      else:
        if (sibling_node:=seen_libs.get(lib, None)) is not None:
          # NOTE: these should all be caught by the AST check, need to canonicalize
          # remove this node, it's a duplicate
          remove_node(node)
          tm = sibling_node.t
        else:
          seen_libs[lib] = node
          try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
          except RuntimeError: tm = math.inf
          node.tm = tm
      tm3 = time.perf_counter()
      compile_time += tm2-tm1
      runtime_time += tm3-tm2

    # mock rollout
    #node.tm = tm = random.random() + 0.1

    if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
    et = time.perf_counter() - st
    if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {int(round((i+1)/et)):4d}/s {node.kernel.colored_shape()}\033[K", end="") # noqa: E501

    # backprop
    backprop(node, tm)
  if DEBUG>=2: print()

  if getenv("MCTSGRAPH"):
    import networkx as nx
    import os
    GRAPHPATH = "/tmp/net"
    def save_graph(G, fn, opt=""):
      print("saving", G, f"to {fn}.svg")
      nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
      os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')

    G = nx.DiGraph()
    def add_node(node:MCTSNode):
      if node.n == 0: return
      for parent in node.parents: G.add_edge(parent, node)
      gopts = node.kernel.applied_opts
      edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}" if len(gopts) else "ROOT"
      G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}",
                 fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '')
      if node.children is not None:
        for child in node.children+node.removed_children: add_node(child)
    add_node(root)
    save_graph(G, f"{GRAPHPATH}.{graph_mcts_cnt}.mcts", '-Grankdir=LR')
    graph_mcts_cnt += 1

  if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
  return best
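The child selection above is UCB with a temperature softmax on top: each explored child is scored by its negated best time (normalized by the global best) plus an exploration bonus, then a child is drawn in proportion to exp(score/TEMP). A self-contained sketch of that selection math, with made-up child statistics (all numbers here are illustrative, not from the file above):

import math
import numpy as np

C, TEMP, best_tm = math.sqrt(2), 0.5, 10.0
children = [(12.0, 4), (15.0, 2), (11.0, 6)]          # illustrative (best time t in us, visit count n) pairs
parent_n = sum(n for _, n in children)
ucb = np.array([-t/best_tm + C*math.sqrt(math.log(parent_n)/n) for t, n in children])
probs = np.exp((ucb - ucb.max())/TEMP)                # "safe softmax": subtract the max before exp
probs /= probs.sum()
pick = np.random.choice(len(children), p=probs)       # index of the child to descend into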
@@ -1,75 +0,0 @@
import pickle, sys
from dataclasses import replace
from tinygrad import Device, Context, Tensor, GlobalCounters
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item, get_program
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
import numpy as np

def move_jit_captured_to_dev(captured, device="DSP"):
  captured.expected_st_vars_dtype_device = [x[:3] + (device,) for x in captured.expected_st_vars_dtype_device]

  assign = {}
  def move_buffer(b):
    if b in assign: return assign[b]

    if b._base is not None:
      newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset)
    else:
      newbuf = Buffer(device, b.size, b.dtype)
      if b.is_allocated(): newbuf.ensure_allocated().copyin(b.as_buffer())
    assign[b] = newbuf
    return assign[b]

  for item in captured.jit_cache:
    for b in item.bufs:
      if b is not None: move_buffer(b)
  captured.jit_cache = [ExecItem(item.prg, [assign.get(b,b) for b in item.bufs]) for item in captured.jit_cache]
  return captured

if __name__ == "__main__":
  with Context(DEBUG=0):
    with open(sys.argv[1], "rb") as f:
      fxn: TinyJit = pickle.load(f)
      print(f"{f.tell()/1e6:.2f}M loaded")
    print(type(fxn))

  # Move all buffers to DSP device.
  fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
  new_jit = []

  knum = 1
  for ei in fxn.captured.jit_cache:
    # skip the copy and the first kernel
    if isinstance(ei.prg, CompiledRunner) and all(x is not None for x in ei.bufs):
      if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
        p: ProgramSpec = ei.prg.p
        k = Kernel(p.ast, Device["DSP"].renderer)

        if getenv("VALIDATE"):
          with Context(NOOPT=1):
            lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
          correct = ei.bufs[0].numpy()
          ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
          GlobalCounters.kernel_count -= 1

        if not getenv("NOOPT"): k.apply_opts(hand_coded_optimizations(k))
        p2 = get_program(k.ast, k.opts, k.applied_opts)
        new_ei = replace(ei, prg=CompiledRunner(p2))
        new_ei.run()
        new_jit.append(new_ei)
        test = ei.bufs[0].numpy()

        if getenv("VALIDATE"):
          import numpy as np
          np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
      knum += 1

  if getenv("RUN_JIT", 0):
    fxn.captured.free_intermediates()
    fxn.captured.jit_cache = new_jit
    fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))
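The move above memoizes on the original buffer, so a base buffer and all views of it keep aliasing each other after the move. A plain-Python sketch of that memoized recursion, with a hypothetical `Buf` standing in for a tinygrad Buffer:

class Buf:
  def __init__(self, device, size, base=None): self.device, self.size, self.base = device, size, base

assign = {}
def move(b, device):
  if b in assign: return assign[b]  # seen before: return the same moved buffer
  nb = Buf(device, b.size, base=move(b.base, device) if b.base is not None else None)
  assign[b] = nb
  return nb

base = Buf("CPU", 16)
view = Buf("CPU", 8, base=base)
assert move(view, "DSP").base is move(base, "DSP")  # the moved view still aliases the moved base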
@@ -1,114 +0,0 @@
# code from https://x.com/awnihannun/status/1832511021602500796
from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn
import time


class Block(nn.Module):
    def __init__(self, in_dims, dims, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm(dims)

        self.conv2 = nn.Conv2d(
            dims, dims, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm(dims)

        self.downsample = []
        if stride != 1:
            self.downsample = [
                nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm(dims)
            ]

    def __call__(self, x):
        out = nn.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        for l in self.downsample:
            x = l(x)
        out += x
        out = nn.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm(64)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2)

        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, in_dims, dims, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(in_dims, dims, stride))
            in_dims = dims
        return layers

    def __call__(self, x):
        x = nn.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        for l in self.layer1 + self.layer2 + self.layer3 + self.layer4:
            x = l(x)
        x = mx.mean(x, axis=[1, 2])
        x = self.fc(x)
        return x


def load():
    model = ResNet(Block, [2, 2, 2, 2], num_classes=1000)
    file = "model.safetensors"
    model_path = snapshot_download(
        repo_id="awni/resnet18-mlx",
        allow_patterns=[file],
    )
    model.load_weights(model_path + "/" + file)
    model.eval()
    mx.eval(model)
    return model

if __name__ == "__main__":

    resnet18 = load()

    @mx.compile
    def forward(im):
        return resnet18(im)

    batch_sizes = [1, 2, 4, 8, 16, 32, 64]
    #its = 200
    #batch_sizes = [64]
    its = 20
    print(f"Batch Size | Images-per-second | Milliseconds-per-image")
    print(f"---- | ---- | ---- ")
    for N in batch_sizes:
        image = mx.random.uniform(shape=(N, 288, 288, 3))

        # Warmup
        for _ in range(5):
            output = forward(image)
            mx.eval(output)

        tic = time.time()
        for _ in range(its):
            output = forward(image)
            mx.async_eval(output)
        mx.eval(output)
        toc = time.time()
        ims_per_sec = N * its / (toc - tic)
        ms_per_im = 1e3 / ims_per_sec
        print(f"{N} | {ims_per_sec:.3f} | {ms_per_im:.3f}")
@@ -1,109 +0,0 @@
from huggingface_hub import snapshot_download
from tinygrad import nn, Tensor, TinyJit, Device, GlobalCounters, Context
import time

class Block:
  def __init__(self, in_dims, dims, stride=1):
    super().__init__()

    self.conv1 = nn.Conv2d(
      in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
    )
    self.bn1 = nn.BatchNorm(dims)

    self.conv2 = nn.Conv2d(
      dims, dims, kernel_size=3, stride=1, padding=1, bias=False
    )
    self.bn2 = nn.BatchNorm(dims)

    self.downsample = []
    if stride != 1:
      self.downsample = [
        nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm(dims)
      ]

  def __call__(self, x):
    out = self.bn1(self.conv1(x)).relu()
    out = self.bn2(self.conv2(out))
    for l in self.downsample:
      x = l(x)
    out += x
    return out.relu()


class ResNet:
  def __init__(self, block, num_blocks, num_classes=10):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm(64)

    self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2)

    self.fc = nn.Linear(512, num_classes)

  def _make_layer(self, block, in_dims, dims, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for stride in strides:
      layers.append(block(in_dims, dims, stride))
      in_dims = dims
    return layers

  def __call__(self, x:Tensor):
    x = self.bn1(self.conv1(x)).relu().max_pool2d()
    x = x.sequential(self.layer1)
    with Context(WINO=1): x = x.sequential(self.layer2 + self.layer3 + self.layer4)
    x = x.mean([2, 3])
    x = self.fc(x)
    return x


def load():
  model = ResNet(Block, [2, 2, 2, 2], num_classes=1000)
  file = "model.safetensors"
  model_path = snapshot_download(
    repo_id="awni/resnet18-mlx",
    allow_patterns=[file],
  )
  state = nn.state.safe_load(model_path + "/" + file)
  # mlx is NHWC, tinygrad is NCHW
  nn.state.load_state_dict(model, {k:v if len(v.shape) != 4 else v.to(None).permute(0,3,1,2).contiguous() for k,v in state.items()}, strict=False)
  return model

if __name__ == "__main__":

  resnet18 = load()

  def _forward(im): return resnet18(im)
  forward = TinyJit(_forward, prune=True)

  batch_sizes = [1, 2, 4, 8, 16, 32, 64]
  #its = 200
  #batch_sizes = [64]
  its = 20
  print(f"Batch Size | Images-per-second | Milliseconds-per-image")
  print(f"---- | ---- | ---- ")
  for N in batch_sizes:
    forward.reset() # reset the JIT for a new batch size (could make automatic)
    image = Tensor.uniform(N, 3, 288, 288)

    # Warmup
    for _ in range(5):
      GlobalCounters.reset()
      output = forward(image)
      Device.default.synchronize()

    tic = time.time()
    for _ in range(its):
      GlobalCounters.reset()
      output = forward(image)
      Device.default.synchronize()
    toc = time.time()
    ims_per_sec = N * its / (toc - tic)
    ms_per_im = 1e3 / ims_per_sec
    print(f"{N} | {ims_per_sec:.3f} | {ms_per_im:.3f}")
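The one-line permute in `load` is the whole porting trick between the two scripts: MLX stores conv weights channels-last while tinygrad expects channels-first, so every 4-D weight gets permuted on load. A numpy sketch of the same transform (shapes illustrative):

import numpy as np
w_mlx = np.zeros((64, 7, 7, 3), dtype=np.float32)         # (out, kh, kw, in), MLX layout
w_tg = np.ascontiguousarray(w_mlx.transpose(0, 3, 1, 2))  # (out, in, kh, kw), tinygrad layout
assert w_tg.shape == (64, 3, 7, 7)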
@@ -1,15 +0,0 @@
from tinygrad import Tensor, Device, GlobalCounters
from tinygrad.helpers import Timing

N = 512
GPUS = 5
ds = tuple([f"{Device.DEFAULT}:{i+1}" for i in range(GPUS)])
t = [Tensor.ones(N, N, N, device=d).contiguous().realize() for d in ds]

for _ in range(10):
  GlobalCounters.reset()
  with Timing():
    for ti in t:
      ti.to_(ds[(ds.index(ti.device)+1+len(ds))%len(ds)])
      # ti.to_(ds[(ds.index(ti.device)-1+len(ds))%len(ds)]) # reversed order
      ti.realize()
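The modulo arithmetic above walks each tensor one device forward around the ring; the `+len(ds)` term is a no-op for the forward direction but keeps the commented-out reverse direction (`-1`) non-negative before the modulo. A quick check of the wrap-around (device names illustrative):

ds = ("GPU:1", "GPU:2", "GPU:3")
assert [ds[(i+1+len(ds)) % len(ds)] for i in range(len(ds))] == ["GPU:2", "GPU:3", "GPU:1"]
assert [ds[(i-1+len(ds)) % len(ds)] for i in range(len(ds))] == ["GPU:3", "GPU:1", "GPU:2"]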
@@ -1,47 +0,0 @@
import os, pathlib, argparse
from examples.llama3 import Tokenizer
from tabulate import tabulate
from tinygrad import fetch
from tinygrad.helpers import flatten, getenv
from sz import NONCORE_DIRS

# llama 3 tokenizer
tokenizer = Tokenizer(fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model").as_posix())

def read_code(base_path, full=False):
  ret = []
  for path, _, files in os.walk(os.path.join(base_path, "tinygrad")):
    if not full and any(path.split("./")[1].startswith(x) for x in NONCORE_DIRS): continue
    for name in files:
      if not name.endswith(".py"): continue
      if 'tinygrad/runtime/autogen' in path.replace('\\', '/'): continue
      fullpath = os.path.join(path, name)
      code = pathlib.Path(fullpath).read_text()
      ret.append((fullpath.split("tinygrad/", 1)[1], code))
  return ret

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Analyze and optionally save tinygrad code.")
  parser.add_argument("--output", help="Output file to write the combined code to.")
  parser.add_argument("--full", action="store_true", help="All directories")
  args = parser.parse_args()

  ret = read_code(".", args.full)

  table = []
  for name,code in ret:
    table.append([name, len(tokenizer.encode(code))])
  print(tabulate([["name", "llm tokens"]]+sorted(table, key=lambda x: -x[1]), headers="firstrow"))

  banner = "#"*40
  code_str = ''.join([f"{banner}\n# {name}\n{banner}\n\n{code}\n" for name,code in ret])
  print(f"code has {len(code_str)} chars")
  newline_count = code_str.count('\n')
  print(f"code has {newline_count} newlines")

  encoded = tokenizer.encode(code_str)
  print(f"code has {len(encoded)} tokens")

  if args.output:
    with open(args.output, 'w') as f: f.write(code_str)
    print(f"Combined code written to {args.output}")
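For reference, each file in the combined dump written by `--output` is framed by the `#` banner, so the concatenation stays valid Python comments plus code, e.g. (path and contents illustrative):

########################################
# tinygrad/tensor.py
########################################

<contents of tensor.py>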
@@ -1,16 +0,0 @@
if __name__ == "__main__":
  import os
  if "DEBUG" not in os.environ: os.environ["DEBUG"] = "2"

  from tinygrad import Tensor, GlobalCounters
  from tinygrad.helpers import getenv

  if (seed := getenv("SEED", 0)) != 0:
    Tensor.manual_seed(seed)
    print(f"using seed {Tensor._seed}")

  for N in [10_000_000, 100_000_000, 1_000_000_000]:
    GlobalCounters.reset()
    t = Tensor.rand(N)
    t.realize()
    print(f"N {N:>20_}, global_ops {GlobalCounters.global_ops:>20_}, global_mem {GlobalCounters.global_mem:>20_}")
@@ -1,154 +0,0 @@
import itertools
from enum import Enum, auto
from collections import defaultdict
from typing import List, Tuple, DefaultDict
from tinygrad.helpers import prod, tqdm
from tinygrad.uop.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.uop.ops import sym_infer
from tinygrad.tensor import Tensor

class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702

def apply_mop(st: Tensor|ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker:
  mop, arg = mop_arg
  if mop == MovementOps.RESHAPE:
    # shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE
    if arg == (-1,): return st.reshape((prod(st.shape),))
    return st.reshape(arg)
  if mop == MovementOps.PERMUTE: return st.permute(arg)
  if mop == MovementOps.EXPAND:
    if len(arg) != len(st.shape): st = st.reshape((1,*st.shape))
    return st.expand(arg)
  if mop == MovementOps.PAD: return st.pad(arg)
  if mop == MovementOps.SHRINK: return st.shrink(arg)
  if mop == MovementOps.STRIDE:
    assert all(x in [-1, 1] for x in arg)
    return st.flip(tuple(i for i,x in enumerate(arg) if x == -1))
  raise ValueError("invalid mop")

def make_scratch_st(st: ShapeTracker) -> ShapeTracker:
  return ShapeTracker.from_shape((get_buffer_size(st.views[0].shape, st.views[0].strides, st.views[0].offset, st.views[0].mask),))

# ShapeTracker to an equivalent series of MovementOps (https://github.com/tinygrad/tinygrad/pull/2216)
def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]:
  to_apply:List[Tuple[MovementOps, Tuple]] = []
  for i, v in enumerate(st.views):
    real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
    offset = (v.offset or 0) + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
    real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
    real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
    strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
    buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
    if i: buffer_size = prod(st.views[i-1].shape) - real_offset if real_shape else 1
    def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
    ordered_shape_strides, order = sort_by_strides(real_real_shape, strides)
    to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))])
    if strides:
      if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),)))
      for i, shape_stride in enumerate(ordered_shape_strides):
        if i<len(ordered_shape_strides)-1 and shape_stride[1] < ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
          remaining_buffer = ordered_shape_strides[i-1][1] if i>0 else buffer_size
          to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer)))
          to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1)))
          to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer)))
          to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1]))))
          to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1])))
          ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1])
        else:
          to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1]))))
          to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1])))
      to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))])
      if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides)))))
    to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides))))
    if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides)))
    # then, we apply pre expand pads
    if v.mask is not None:
      pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
      post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
      if any(x != (0,0) for x in pre_expand_pads):
        to_apply.append((MovementOps.PAD, pre_expand_pads))
        real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
    # then, we do any expands
    if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
    # lastly, we apply post expand pads
    if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))

  scratch_st = make_scratch_st(st)
  ret = []
  seen = {} # {shapetracker: list of mops to generate that shapetracker}
  for mop_arg in to_apply:
    scratch_st = apply_mop(scratch_st, mop_arg)
    if scratch_st in seen:
      ret = seen[scratch_st][:]
    else:
      if len(ret) and ret[-1][0] == MovementOps.RESHAPE and mop_arg[0] == MovementOps.RESHAPE:
        ret[-1] = mop_arg
      else:
        if mop_arg == (MovementOps.RESHAPE, -1): mop_arg = (MovementOps.RESHAPE, (prod(st.shape),))
        ret.append(mop_arg)
      seen[scratch_st] = ret[:]
  return ret

def get_real_view(shape, strides, offset, mask):
  real_shape = tuple(y-x for x,y in mask) if mask else shape
  offset = offset + sum(st * (s-1) for s,st in zip(real_shape, strides) if st<0)
  real_offset = offset + (sum(x*st for (x,_),st in zip(mask, strides)) if mask else 0)
  real_real_shape = [s for s,st in zip(real_shape, strides) if st]
  strides = [abs(st) if isinstance(st,int) else st for st in strides if st]
  return real_real_shape, strides, real_offset

def get_buffer_size(shape, strides, offset, mask):
  real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask)
  return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1

def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
  if (idxs1:=st1.expr_idxs()) == (idxs2:=st2.expr_idxs()): return True
  idx1, valid1 = idxs1
  idx2, valid2 = idxs2
  # always invalid
  if valid1 == 0 and valid2 == 0: return True

  var1 = idx1.vars() | valid1.vars()
  var2 = idx2.vars() | valid2.vars()
  # Maybe there are cases that vars are different yet the sts are the same?
  if var1 != var2: return False

  # brute force over the vars range
  vs = list(var1)
  for i, ranges in enumerate(itertools.product(*[range(v.min, v.max+1) for v in vs])):
    if i > 1000:
      print("WARNING: did not search all possible combinations")
      break
    var_vals = {k.expr:v for k,v in zip(vs, ranges)}
    r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0
    r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0
    if r1 != r2: return False

  return True

c: DefaultDict[int,int] = defaultdict(int)
def test_rebuild(st: ShapeTracker):
  rebuilt_st = make_scratch_st(st)
  mops = to_movement_ops(st)
  c[len(mops)] += 1
  for mop_arg in mops: rebuilt_st = apply_mop(rebuilt_st, mop_arg)
  rebuilt_st = rebuilt_st.simplify()
  # why is the "all(x == 0 for x in rebuilt_st.views[-1].strides)" hack needed?
  assert st_equivalent(st, rebuilt_st) or all(x == 0 for x in rebuilt_st.views[-1].strides), f"mismatch {st} {rebuilt_st}"
  last_v1 = st.views[-1]
  last_v2 = rebuilt_st.views[-1]
  assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"

def test_rebuild_bufferop_st(ast:UOp):
  if ast.op is Ops.SHAPETRACKER:
    test_rebuild(ast.arg)
  for src in ast.src: test_rebuild_bufferop_st(src)

if __name__ == "__main__":
  from extra.optimization.helpers import load_worlds, ast_str_to_ast
  ast_strs = load_worlds(False, False, True)[:2000]
  for ast_str in tqdm(ast_strs):
    test_rebuild_bufferop_st(ast_str_to_ast(ast_str))

  print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}")
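`st_equivalent` above settles equivalence by brute force over the variables' ranges rather than by symbolic proof. A self-contained sketch of the same idea, with plain lambdas standing in for the symbolic `idx` expressions:

import itertools
idx1 = lambda x, y: x*4 + y           # row-major index into a 4x4 view
idx2 = lambda x, y: (x*8 + y*2) // 2  # a less-simplified but equivalent form
assert all(idx1(x, y) == idx2(x, y) for x, y in itertools.product(range(4), range(4)))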
@@ -1,18 +0,0 @@
from tinygrad import Tensor, Device

#N = 1024
N = 32
t = Tensor.rand(N, N, N, device="CPU").realize()
d1 = Device.DEFAULT + ":1"
d2 = Device.DEFAULT + ":2"
d3 = Device.DEFAULT + ":3"

for i in range(3):
  t.to_(d1)
  t.realize()
  # t.to_("CPU")
  # t.realize()
  t.to_(d2)
  t.realize()
  t.to_(d3)
  t.realize()