cleanup stale examples/extra (#13764)

* cleanup stale files

* examples

* move those back

* old

* delete more
George Hotz
2025-12-19 16:27:37 -04:00
committed by GitHub
parent 80b84f5267
commit df6cde8a00
45 changed files with 0 additions and 5039 deletions


@@ -1,93 +0,0 @@
#!/usr/bin/env python3
import os, sys, traceback
sys.path.append(os.getcwd())
from io import StringIO
from contextlib import redirect_stdout
from tinygrad import Tensor, nn
from tinygrad.helpers import Timing, colored, getenv, fetch
from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
from sentencepiece import SentencePieceProcessor
def create_fixed_tokenizer(output_file):
print("creating fixed tokenizer")
import extra.junk.sentencepiece_model_pb2 as spb2
mp = spb2.ModelProto()
mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
with open(output_file, "wb") as f:
f.write(mp.SerializeToString())
# example:
# echo -en "write 2+2\nwrite hello world\ny\n" | TEMP=0 python3 examples/coder.py
if __name__ == "__main__":
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
with Timing("create model: "):
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096, jit=getenv("JIT", 1))
with Timing("download weights: "):
part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
with Timing("weights -> model: "):
nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part1, 32, 32, 8)), strict=False)
nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part2, 32, 32, 8)), strict=False)
if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
# "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
IM_END = 32000
IM_START = 32001
def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")
def output(outputted, toks, color):
cur = spp.decode(toks)[len(outputted):]
sys.stdout.write(colored(cur, color))
sys.stdout.flush()
outputted += cur
return outputted
# *** app below this line ***
toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")
PROMPT = getenv("PROMPT", 1)
temperature = getenv("TEMP", 0.7)
start_pos = 0
outputted = output("", toks, "green")
turn = True
while 1:
if PROMPT:
toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant")
else:
toks += start_prompt("user" if turn else "assistant")
turn = not turn
old_output_len = len(outputted)
while 1:
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
start_pos = len(toks)
toks.append(tok)
outputted = output(outputted, toks, "blue" if not turn else "cyan")
if tok == IM_END: break
if tok == spp.eos_id(): break
new_output = outputted[old_output_len:]
if new_output.endswith("```") and '```python\n' in new_output:
python_code = new_output.split('```python\n')[1].split("```")[0]
# AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
my_stdout = StringIO()
try:
with redirect_stdout(my_stdout): exec(python_code)
result = my_stdout.getvalue()
except Exception as e:
result = ''.join(traceback.format_exception_only(e))
toks += spp.encode(f"\nOutput:\n```\n{result}```")
outputted = output(outputted, toks, "yellow")
old_output_len = len(outputted)
print("")


@@ -1,89 +0,0 @@
# load weights from
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
# a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
import sys
import ast
import time
import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, fetch, Timing
from tinygrad.engine.jit import TinyJit
from extra.models.efficientnet import EfficientNet
np.set_printoptions(suppress=True)
# TODO: you should be able to put these in the jitted function
bias = Tensor([0.485, 0.456, 0.406])
scale = Tensor([0.229, 0.224, 0.225])
@TinyJit
def _infer(model, img):
img = img.permute((2,0,1))
img = img / 255.0
img = img - bias.reshape((1,-1,1,1))
img = img / scale.reshape((1,-1,1,1))
return model.forward(img).realize()
def infer(model, img):
# preprocess image
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
img = np.array(img)
y0,x0=(np.asarray(img.shape)[:2]-224)//2
retimg = img = img[y0:y0+224, x0:x0+224]
# if you want to look at the image
"""
import matplotlib.pyplot as plt
plt.imshow(img)
plt.show()
"""
# run the net
out = _infer(model, Tensor(img.astype("float32"))).numpy()
# if you want to look at the outputs
"""
import matplotlib.pyplot as plt
plt.plot(out[0])
plt.show()
"""
return out, retimg
if __name__ == "__main__":
# instantiate my net
model = EfficientNet(getenv("NUM", 0))
model.load_from_pretrained()
# category labels
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
# load image and preprocess
url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
if url == 'webcam':
import cv2
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
while 1:
_ = cap.grab() # discard one frame to circumvent capture buffering
ret, frame = cap.read()
img = Image.fromarray(frame[:, :, [2,1,0]])
lt = time.monotonic_ns()
out, retimg = infer(model, img)
print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
SCALE = 3
simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
cv2.imshow('capture', retimg)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
else:
img = Image.open(fetch(url))
for i in range(getenv("CNT", 1)):
with Timing("did inference in "):
out, _ = infer(model, img)
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
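
The preprocessing inside infer() collapses to: resize so the short side is 224, center-crop to 224x224, then apply the ImageNet mean/std. A standalone numpy sketch of the same steps (mine, not from the repo):

import numpy as np
from PIL import Image

def preprocess(img: Image.Image) -> np.ndarray:
  ar = img.size[0] / img.size[1]
  img = img.resize((int(224*max(ar, 1.0)), int(224*max(1.0/ar, 1.0))))  # short side -> 224
  arr = np.asarray(img).astype("float32")
  y0, x0 = (np.asarray(arr.shape)[:2] - 224) // 2
  arr = arr[y0:y0+224, x0:x0+224].transpose(2, 0, 1) / 255.0  # center crop, HWC -> CHW
  mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
  std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
  return (arr - mean) / std  # ImageNet normalization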


@@ -1,498 +0,0 @@
# pip3 install sentencepiece
# This file incorporates code from the following:
# Github Name | License | Link
# black-forest-labs/flux | Apache | https://github.com/black-forest-labs/flux/tree/main/model_licenses
from tinygrad import Tensor, nn, dtypes, TinyJit
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, tqdm, colored
from sdxl import FirstStage
from extra.models.clip import FrozenClosedClipEmbedder
from extra.models.t5 import T5Embedder
import numpy as np
import math, time, argparse, tempfile
from typing import List, Dict, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
urls:dict = {
"flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors",
"flux-dev": "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev.sft",
"ae": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors",
"T5_1_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors",
"T5_2_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors",
"T5_tokenizer": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model",
"clip": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder/model.safetensors"
}
def tensor_identity(x:Tensor) -> Tensor: return x
class AutoEncoder:
def __init__(self, scale_factor:float, shift_factor:float):
self.decoder = FirstStage.Decoder(128, 3, 3, 16, [1, 2, 4, 4], 2, 256)
self.scale_factor = scale_factor
self.shift_factor = shift_factor
def decode(self, z:Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
# Conditioner
class ClipEmbedder(FrozenClosedClipEmbedder):
def __call__(self, texts:Union[str, List[str], Tensor]) -> Tensor:
if isinstance(texts, str): texts = [texts]
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
return self.transformer.text_model(tokens.reshape(len(texts),-1))[:, tokens.argmax(-1)]
# https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def attention(q:Tensor, k:Tensor, v:Tensor, pe:Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = Tensor.scaled_dot_product_attention(q, k, v)
return x.rearrange("B H L D -> B L (H D)")
def rope(pos:Tensor, dim:int, theta:int) -> Tensor:
assert dim % 2 == 0
scale = Tensor.arange(0, dim, 2, dtype=dtypes.float32, device=pos.device) / dim # NOTE: this is torch.float64 in reference implementation
omega = 1.0 / (theta**scale)
out = Tensor.einsum("...n,d->...nd", pos, omega)
out = Tensor.stack(Tensor.cos(out), -Tensor.sin(out), Tensor.sin(out), Tensor.cos(out), dim=-1)
out = out.rearrange("b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).cast(xq.dtype), xk_out.reshape(*xk.shape).cast(xk.dtype)
# https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND:
def __init__(self, dim:int, theta:int, axes_dim:List[int]):
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def __call__(self, ids:Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = Tensor.cat(*[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class MLPEmbedder:
def __init__(self, in_dim:int, hidden_dim:int):
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x:Tensor) -> Tensor:
return self.out_layer(self.in_layer(x).silu())
class QKNorm:
def __init__(self, dim:int):
self.query_norm = nn.RMSNorm(dim)
self.key_norm = nn.RMSNorm(dim)
def __call__(self, q:Tensor, k:Tensor) -> Tuple[Tensor, Tensor]:
return self.query_norm(q), self.key_norm(k)
class SelfAttention:
def __init__(self, dim:int, num_heads:int = 8, qkv_bias:bool = False):
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def __call__(self, x:Tensor, pe:Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k)
x = attention(q, k, v, pe=pe)
return self.proj(x)
@dataclass
class ModulationOut:
shift:Tensor
scale:Tensor
gate:Tensor
class Modulation:
def __init__(self, dim:int, double:bool):
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def __call__(self, vec:Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
out = self.lin(vec.silu())[:, None, :].chunk(self.multiplier, dim=-1)
return ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None
class DoubleStreamBlock:
def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float, qkv_bias:bool = False):
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]
def __call__(self, img:Tensor, txt:Tensor, vec:Tensor, pe:Tensor) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
assert img_mod2 is not None and txt_mod2 is not None
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
# run actual attention
q = Tensor.cat(txt_q, img_q, dim=2)
k = Tensor.cat(txt_k, img_k, dim=2)
v = Tensor.cat(txt_v, img_v, dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * ((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift).sequential(self.img_mlp)
# calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * ((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift).sequential(self.txt_mlp)
return img, txt
class SingleStreamBlock:
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(self,hidden_size:int, num_heads:int, mlp_ratio:float=4.0, qk_scale:Optional[float]=None):
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = Tensor.gelu
self.modulation = Modulation(hidden_size, double=False)
def __call__(self, x:Tensor, vec:Tensor, pe:Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = Tensor.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(Tensor.cat(attn, self.mlp_act(mlp), dim=2))
return x + mod.gate * output
class LastLayer:
def __init__(self, hidden_size:int, patch_size:int, out_channels:int):
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation:List[Callable[[Tensor], Tensor]] = [Tensor.silu, nn.Linear(hidden_size, 2 * hidden_size, bias=True)]
def __call__(self, x:Tensor, vec:Tensor) -> Tensor:
shift, scale = vec.sequential(self.adaLN_modulation).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
return self.linear(x)
def timestep_embedding(t:Tensor, dim:int, max_period:int=10000, time_factor:float=1000.0) -> Tensor:
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = Tensor.exp(-math.log(max_period) * Tensor.arange(0, stop=half, dtype=dtypes.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = Tensor.cat(Tensor.cos(args), Tensor.sin(args), dim=-1)
if dim % 2: embedding = Tensor.cat(*[embedding, Tensor.zeros_like(embedding[:, :1])], dim=-1)
if Tensor.is_floating_point(t): embedding = embedding.cast(t.dtype)
return embedding
# https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py
class Flux:
"""
Transformer model for flow matching on sequences.
"""
def __init__(
self,
guidance_embed:bool,
in_channels:int = 64,
vec_in_dim:int = 768,
context_in_dim:int = 4096,
hidden_size:int = 3072,
mlp_ratio:float = 4.0,
num_heads:int = 24,
depth:int = 19,
depth_single_blocks:int = 38,
axes_dim:Optional[List[int]] = None,
theta:int = 10_000,
qkv_bias:bool = True,
):
axes_dim = axes_dim or [16, 56, 56]
self.guidance_embed = guidance_embed
self.in_channels = in_channels
self.out_channels = self.in_channels
if hidden_size % num_heads != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
pe_dim = hidden_size // num_heads
if sum(axes_dim) != pe_dim:
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
self.guidance_in:Callable[[Tensor], Tensor] = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else tensor_identity
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.double_blocks = [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) for _ in range(depth)]
self.single_blocks = [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)]
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def __call__(self, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, timesteps:Tensor, y:Tensor, guidance:Optional[Tensor] = None) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = Tensor.cat(txt_ids, img_ids, dim=1)
pe = self.pe_embedder(ids)
for double_block in self.double_blocks:
img, txt = double_block(img=img, txt=txt, vec=vec, pe=pe)
img = Tensor.cat(txt, img, dim=1)
for single_block in self.single_blocks:
img = single_block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py
def load_flow_model(name:str, model_path:str):
# Loading Flux
print("Init model")
model = Flux(guidance_embed=(name != "flux-schnell"))
if not model_path: model_path = fetch(urls[name])
state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(model_path).items()}
load_state_dict(model, state_dict)
return model
def load_T5(max_length:int=512):
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
print("Init T5")
T5 = T5Embedder(max_length, fetch(urls["T5_tokenizer"]))
pt_1 = fetch(urls["T5_1_of_2"])
pt_2 = fetch(urls["T5_2_of_2"])
load_state_dict(T5.encoder, safe_load(pt_1) | safe_load(pt_2), strict=False)
return T5
def load_clip():
print("Init Clip")
clip = ClipEmbedder()
load_state_dict(clip.transformer, safe_load(fetch(urls["clip"])))
return clip
def load_ae() -> AutoEncoder:
# Loading the autoencoder
print("Init AE")
ae = AutoEncoder(0.3611, 0.1159)
load_state_dict(ae, safe_load(fetch(urls["ae"])))
return ae
# https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py
def prepare(T5:T5Embedder, clip:ClipEmbedder, img:Tensor, prompt:Union[str, List[str]]) -> Dict[str, Tensor]:
bs, _, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = img.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = img.expand((bs, *img.shape[1:]))
img_ids = Tensor.zeros(h // 2, w // 2, 3).contiguous()
img_ids[..., 1] = img_ids[..., 1] + Tensor.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + Tensor.arange(w // 2)[None, :]
img_ids = img_ids.rearrange("h w c -> 1 (h w) c")
img_ids = img_ids.expand((bs, *img_ids.shape[1:]))
if isinstance(prompt, str):
prompt = [prompt]
txt = T5(prompt).realize()
if txt.shape[0] == 1 and bs > 1:
txt = txt.expand((bs, *txt.shape[1:]))
txt_ids = Tensor.zeros(bs, txt.shape[1], 3)
vec = clip(prompt).realize()
if vec.shape[0] == 1 and bs > 1:
vec = vec.expand((bs, *vec.shape[1:]))
return {"img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device)}
def get_schedule(num_steps:int, image_seq_len:int, base_shift:float=0.5, max_shift:float=1.15, shift:bool=True) -> List[float]:
# extra step for zero
step_size = -1.0 / num_steps
timesteps = Tensor.arange(1, 0 + step_size, step_size)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = 0.5 + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256)
timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
return timesteps.tolist()
@TinyJit
def run(model, *args): return model(*args).realize()
def denoise(model, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, vec:Tensor, timesteps:List[float], guidance:float=4.0) -> Tensor:
# this is ignored for schnell
guidance_vec = Tensor((guidance,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:])), "Denoising"):
t_vec = Tensor((t_curr,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
pred = run(model, img, img_ids, txt, txt_ids, t_vec, vec, guidance_vec)
img = img + (t_prev - t_curr) * pred
return img
def unpack(x:Tensor, height:int, width:int) -> Tensor:
return x.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2)
# https://github.com/black-forest-labs/flux/blob/main/src/flux/cli.py
if __name__ == "__main__":
default_prompt = "bananas and a can of coke"
parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load")
parser.add_argument("--model_path", type=str, default="", help="path of the model file")
parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)")
parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)")
parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling")
parser.add_argument("--prompt", type=str, default=default_prompt, help="Prompt used for sampling")
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
parser.add_argument("--num_steps", type=int, default=None, help="number of sampling steps (default 4 for schnell, 50 for guidance distilled)") #noqa:E501
parser.add_argument("--guidance", type=float, default=3.5, help="guidance value used for guidance distillation")
parser.add_argument("--output_dir", type=str, default="output", help="output directory")
args = parser.parse_args()
if args.name not in ["flux-schnell", "flux-dev"]:
raise ValueError(f"Got unknown model name: {args.name}, chose from flux-schnell and flux-dev")
if args.num_steps is None:
args.num_steps = 4 if args.name == "flux-schnell" else 50
# allow for packing and conversion to latent space
height = 16 * (args.height // 16)
width = 16 * (args.width // 16)
if args.seed is None: args.seed = Tensor._seed
else: Tensor.manual_seed(args.seed)
print(f"Generating with seed {args.seed}:\n{args.prompt}")
t0 = time.perf_counter()
# prepare input noise
x = Tensor.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype="bfloat16")
# load text embedders
T5 = load_T5(max_length=256 if args.name == "flux-schnell" else 512)
clip = load_clip()
# embed text to get inputs for model
inp = prepare(T5, clip, x, prompt=args.prompt)
timesteps = get_schedule(args.num_steps, inp["img"].shape[1], shift=(args.name != "flux-schnell"))
# done with text embedders
del T5, clip
# load model
model = load_flow_model(args.name, args.model_path)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)
# done with model
del model, run
# load autoencoder
ae = load_ae()
# decode latents to pixel space
x = unpack(x.float(), height, width)
x = ae.decode(x).realize()
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s. Saving {args.out}")
# bring into PIL format and save
x = x.clamp(-1, 1)
x = x[0].rearrange("c h w -> h w c")
x = (127.5 * (x + 1.0)).cast("uint8")
img = Image.fromarray(x.numpy())
img.save(args.out)
# validation!
if args.prompt == default_prompt and args.name=="flux-schnell" and args.seed == 0 and args.width == args.height == 512:
ref_image = Tensor(np.array(Image.open("examples/flux1_seed0.png")))
distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
print(colored(f"output validated with {distance=}", "green"))


@@ -1,299 +0,0 @@
from extra.models.mask_rcnn import MaskRCNN
from extra.models.resnet import ResNet
from extra.models.mask_rcnn import BoxList
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import functional as Ft
import random
from tinygrad.tensor import Tensor
from PIL import Image
import numpy as np
import torch
import argparse
import cv2
class Resize:
def __init__(self, min_size, max_size):
if not isinstance(min_size, (list, tuple)):
min_size = (min_size,)
self.min_size = min_size
self.max_size = max_size
# modified from torchvision to add support for max size
def get_size(self, image_size):
w, h = image_size
size = random.choice(self.min_size)
max_size = self.max_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
return (oh, ow)
def __call__(self, image):
size = self.get_size(image.size)
image = Ft.resize(image, size)
return image
class Normalize:
def __init__(self, mean, std, to_bgr255=True):
self.mean = mean
self.std = std
self.to_bgr255 = to_bgr255
def __call__(self, image):
if self.to_bgr255:
image = image[[2, 1, 0]] * 255
else:
image = image[[0, 1, 2]] * 255
image = Ft.normalize(image, mean=self.mean, std=self.std)
return image
transforms = lambda size_scale: T.Compose(
[
Resize(int(800*size_scale), int(1333*size_scale)),
T.ToTensor(),
Normalize(
mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
),
]
)
def expand_boxes(boxes, scale):
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half *= scale
h_half *= scale
boxes_exp = torch.zeros_like(boxes)
boxes_exp[:, 0] = x_c - w_half
boxes_exp[:, 2] = x_c + w_half
boxes_exp[:, 1] = y_c - h_half
boxes_exp[:, 3] = y_c + h_half
return boxes_exp
def expand_masks(mask, padding):
N = mask.shape[0]
M = mask.shape[-1]
pad2 = 2 * padding
scale = float(M + pad2) / M
padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
padded_mask[:, :, padding:-padding, padding:-padding] = mask
return padded_mask, scale
def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
# TODO: remove torch
mask = torch.tensor(mask.numpy())
box = torch.tensor(box.numpy())
padded_mask, scale = expand_masks(mask[None], padding=padding)
mask = padded_mask[0, 0]
box = expand_boxes(box[None], scale)[0]
box = box.to(dtype=torch.int32)
TO_REMOVE = 1
w = int(box[2] - box[0] + TO_REMOVE)
h = int(box[3] - box[1] + TO_REMOVE)
w = max(w, 1)
h = max(h, 1)
mask = mask.expand((1, 1, -1, -1))
mask = mask.to(torch.float32)
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = mask[0][0]
if thresh >= 0:
mask = mask > thresh
else:
mask = (mask * 255).to(torch.uint8)
im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
x_0 = max(box[0], 0)
x_1 = min(box[2] + 1, im_w)
y_0 = max(box[1], 0)
y_1 = min(box[3] + 1, im_h)
im_mask[y_0:y_1, x_0:x_1] = mask[
(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
]
return im_mask
class Masker:
def __init__(self, threshold=0.5, padding=1):
self.threshold = threshold
self.padding = padding
def forward_single_image(self, masks, boxes):
boxes = boxes.convert("xyxy")
im_w, im_h = boxes.size
res = [
paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
for mask, box in zip(masks, boxes.bbox)
]
if len(res) > 0:
res = torch.stack(res, dim=0)[:, None]
else:
res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
return Tensor(res.numpy())
def __call__(self, masks, boxes):
if isinstance(boxes, BoxList):
boxes = [boxes]
results = []
for mask, box in zip(masks, boxes):
result = self.forward_single_image(mask, box)
results.append(result)
return results
masker = Masker(threshold=0.5, padding=1)
def select_top_predictions(predictions, confidence_threshold=0.9):
scores = predictions.get_field("scores").numpy()
keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
return predictions[keep]
def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
image = transforms(size_scale)(original_image).numpy()
image = Tensor(image, requires_grad=False)
predictions = model(image)
prediction = predictions[0]
prediction = select_top_predictions(prediction, confidence_threshold)
width, height = original_image.size
prediction = prediction.resize((width, height))
if prediction.has_field("mask"):
masks = prediction.get_field("mask")
masks = masker([masks], [prediction])[0]
prediction.add_field("mask", masks)
return prediction
def compute_prediction_batched(batch, model, size_scale=1.0):
imgs = []
for img in batch:
imgs.append(transforms(size_scale)(img).numpy())
image = [Tensor(image, requires_grad=False) for image in imgs]
predictions = model(image)
del image
return predictions
palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
def findContours(*args, **kwargs):
if cv2.__version__.startswith('4'):
contours, hierarchy = cv2.findContours(*args, **kwargs)
elif cv2.__version__.startswith('3'):
_, contours, hierarchy = cv2.findContours(*args, **kwargs)
return contours, hierarchy
def compute_colors_for_labels(labels):
l = labels[:, None]
colors = l * palette
colors = (colors % 255).astype("uint8")
return colors
def overlay_mask(image, predictions):
image = np.asarray(image)
masks = predictions.get_field("mask").numpy()
labels = predictions.get_field("labels").numpy()
colors = compute_colors_for_labels(labels).tolist()
for mask, color in zip(masks, colors):
thresh = mask[0, :, :, None]
contours, hierarchy = findContours(
thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
image = cv2.drawContours(image, contours, -1, color, 3)
composite = image
return composite
CATEGORIES = [
"__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
"sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
def overlay_boxes(image, predictions):
labels = predictions.get_field("labels").numpy()
boxes = predictions.bbox
image = np.asarray(image)
colors = compute_colors_for_labels(labels).tolist()
for box, color in zip(boxes, colors):
box = torch.tensor(box.numpy())
box = box.to(torch.int64)
top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
image = cv2.rectangle(
image, tuple(top_left), tuple(bottom_right), tuple(color), 1
)
return image
def overlay_class_names(image, predictions):
scores = predictions.get_field("scores").numpy().tolist()
labels = predictions.get_field("labels").numpy().tolist()
labels = [CATEGORIES[int(i)] for i in labels]
boxes = predictions.bbox.numpy()
image = np.asarray(image)
template = "{}: {:.2f}"
for box, score, label in zip(boxes, scores, labels):
x, y = box[:2]
s = template.format(label, score)
x, y = int(x), int(y)
cv2.putText(
image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
)
return image
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--image', type=str, help="Path of the image to run")
parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
args = parser.parse_args()
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
model_tiny = MaskRCNN(resnet)
model_tiny.load_from_pretrained()
img = Image.open(args.image)
top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
bbox_image = overlay_boxes(img, top_result_tiny)
mask_image = overlay_mask(bbox_image, top_result_tiny)
final_image = overlay_class_names(mask_image, top_result_tiny)
im = Image.fromarray(final_image)
print(f"saving {args.out}")
im.save(args.out)
im.show()
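
expand_boxes above scales each xyxy box about its own center, which is how paste_mask_in_image makes room for the mask padding. The same arithmetic for a single box, as a numpy sketch:

import numpy as np
def expand_box(box, scale):
  x0, y0, x1, y1 = box
  cx, cy = (x0 + x1) / 2, (y0 + y1) / 2                  # box center
  wh, hh = (x1 - x0) / 2 * scale, (y1 - y0) / 2 * scale  # scaled half extents
  return np.array([cx - wh, cy - hh, cx + wh, cy + hh])
print(expand_box(np.array([0., 0., 10., 20.]), 1.1))  # [-0.5 -1. 10.5 21.]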


@@ -1,118 +0,0 @@
import json, pprint
from tinygrad import fetch, nn, Tensor
from tinygrad.helpers import DEBUG
class FeedForward:
def __init__(self, model_dim, intermediate_dim):
self.proj_1 = nn.Linear(model_dim, 2*intermediate_dim, bias=False)
self.proj_2 = nn.Linear(intermediate_dim, model_dim, bias=False)
def __call__(self, x):
y_12 = self.proj_1(x)
y_1, y_2 = y_12.chunk(2, dim=-1)
return self.proj_2(y_1.silu() * y_2)
# NOTE: this RoPE doesn't match LLaMA's?
def _rotate_half(x: Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return Tensor.cat(-x2, x1, dim=-1)
def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
return (x * pos_cos) + (_rotate_half(x) * pos_sin)
class Attention:
def __init__(self, model_dim, num_query_heads, num_kv_heads, head_dim):
self.qkv_proj = nn.Linear(model_dim, (num_query_heads + num_kv_heads*2) * head_dim, bias=False)
self.num_query_heads, self.num_kv_heads = num_query_heads, num_kv_heads
self.head_dim = head_dim
self.q_norm = nn.RMSNorm(head_dim)
self.k_norm = nn.RMSNorm(head_dim)
self.out_proj = nn.Linear(num_query_heads * head_dim, model_dim, bias=False)
def __call__(self, x:Tensor) -> Tensor:
batch_size, seq_len, embed_dim = x.shape
qkv = self.qkv_proj(x)
qkv = qkv.reshape(batch_size, seq_len, self.num_query_heads+self.num_kv_heads*2, self.head_dim).transpose(1, 2)
xq,xk,xv = qkv.split([self.num_query_heads, self.num_kv_heads, self.num_kv_heads], dim=1)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
# add positional embedding (how many kernels is this?)
freq_constant = 10000
inv_freq = 1.0 / (freq_constant ** (Tensor.arange(0, self.head_dim, 2) / self.head_dim))
pos_index_theta = Tensor.einsum("i,j->ij", Tensor.arange(seq_len), inv_freq)
emb = Tensor.cat(pos_index_theta, pos_index_theta, dim=-1)
cos_emb, sin_emb = emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
xq = _apply_rotary_pos_emb(xq, sin_emb, cos_emb)
xk = _apply_rotary_pos_emb(xk, sin_emb, cos_emb)
# grouped-query attention
num_groups = self.num_query_heads // self.num_kv_heads
xk = xk.repeat_interleave(num_groups, dim=1)
xv = xv.repeat_interleave(num_groups, dim=1)
# masked attention
#start_pos = 0
#mask = Tensor.full((1, 1, seq_len, start_pos+seq_len), float("-inf"), dtype=xq.dtype, device=xq.device).triu(start_pos+1)
#attn_output = xq.scaled_dot_product_attention(xk, xv, mask).transpose(1, 2)
# causal is fine, no mask needed
attn_output = xq.scaled_dot_product_attention(xk, xv, is_causal=True).transpose(1, 2)
return self.out_proj(attn_output.reshape(batch_size, seq_len, self.num_query_heads * self.head_dim))
class Layer:
def __init__(self, model_dim, intermediate_dim, num_query_heads, num_kv_heads, head_dim):
self.ffn = FeedForward(model_dim, intermediate_dim)
self.attn = Attention(model_dim, num_query_heads, num_kv_heads, head_dim)
self.ffn_norm = nn.RMSNorm(model_dim)
self.attn_norm = nn.RMSNorm(model_dim)
def __call__(self, x:Tensor) -> Tensor: # (batch, seq_len, embed_dim)
x = x + self.attn(self.attn_norm(x))
x = x + self.ffn(self.ffn_norm(x))
return x
# stupidly complex
def make_divisible(v, divisor):
new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v: new_v += divisor
return new_v
class Transformer:
def __init__(self, cfg):
if DEBUG >= 3: pprint.pp(cfg)
self.layers = [Layer(cfg['model_dim'], make_divisible(int(cfg["model_dim"] * cfg['ffn_multipliers'][i]), cfg['ffn_dim_divisor']),
cfg['num_query_heads'][i], cfg['num_kv_heads'][i], cfg['head_dim']) for i in range(cfg['num_transformer_layers'])]
self.norm = nn.RMSNorm(cfg['model_dim'])
self.token_embeddings = nn.Embedding(cfg['vocab_size'], cfg['model_dim'])
def __call__(self, tokens:Tensor):
# _bsz, seqlen = tokens.shape
x = self.token_embeddings(tokens)
for l in self.layers: x = l(x)
return self.norm(x) @ self.token_embeddings.weight.T
if __name__ == "__main__":
#model_name = "OpenELM-270M-Instruct"
model_name = "OpenELM-270M" # this is fp32
model = Transformer(json.loads(fetch(f"https://huggingface.co/apple/{model_name}/resolve/main/config.json?download=true").read_bytes()))
weights = nn.state.safe_load(fetch(f"https://huggingface.co/apple/{model_name}/resolve/main/model.safetensors?download=true"))
if DEBUG >= 3:
for k, v in weights.items(): print(k, v.shape)
nn.state.load_state_dict(model, {k.removeprefix("transformer."):v for k,v in weights.items()})
from sentencepiece import SentencePieceProcessor
tokenizer = SentencePieceProcessor(fetch("https://github.com/karpathy/llama2.c/raw/master/tokenizer.model").as_posix())
toks = [tokenizer.bos_id()] + tokenizer.encode("Some car brands include")
for i in range(100):
ttoks = Tensor([toks])
out = model(ttoks).realize()
t0 = out[0].argmax(axis=-1).tolist()
toks.append(t0[-1])
# hmmm...passthrough still doesn't match (it shouldn't, it outputs the most likely)
print(tokenizer.decode(toks))
#print(toks)
#print(tokenizer.decode(t0))
#print(t0)
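
For the "stupidly complex" make_divisible: it rounds v to the nearest multiple of divisor, except that if rounding down would lose more than about 10% of v it bumps up one step. A few worked values:

def make_divisible(v, divisor):
  new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
  if new_v < 0.9 * v: new_v += divisor
  return new_v
assert make_divisible(100, 32) == 96   # nearest multiple, and 96 >= 0.9*100
assert make_divisible(33, 32) == 32    # rounds down, 32 >= 0.9*33
assert make_divisible(143, 32) == 160  # 128 < 0.9*143 = 128.7, so bump up one step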


@@ -1,55 +0,0 @@
from tinygrad.helpers import trange
from tinygrad.nn.datasets import mnist
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
class Model(nn.Module):
def __init__(self):
super().__init__()
self.c1 = nn.Conv2d(1, 32, 5)
self.c2 = nn.Conv2d(32, 32, 5)
self.bn1 = nn.BatchNorm(32)
self.m1 = nn.MaxPool2d(2)
self.c3 = nn.Conv2d(32, 64, 3)
self.c4 = nn.Conv2d(64, 64, 3)
self.bn2 = nn.BatchNorm(64)
self.m2 = nn.MaxPool2d(2)
self.lin = nn.Linear(576, 10)
def __call__(self, x):
x = mx.maximum(self.c1(x), 0)
x = mx.maximum(self.c2(x), 0)
x = self.m1(self.bn1(x))
x = mx.maximum(self.c3(x), 0)
x = mx.maximum(self.c4(x), 0)
x = self.m2(self.bn2(x))
return self.lin(mx.flatten(x, 1))
if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = mnist()
X_train = mx.array(X_train.float().permute((0,2,3,1)).numpy())
Y_train = mx.array(Y_train.numpy())
X_test = mx.array(X_test.float().permute((0,2,3,1)).numpy())
Y_test = mx.array(Y_test.numpy())
model = Model()
optimizer = optim.Adam(1e-3)
def loss_fn(model, x, y): return nn.losses.cross_entropy(model(x), y).mean()
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(samples):
# Compiled functions will also treat any inputs not in the parameter list as constants.
X,Y = X_train[samples], Y_train[samples]
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, X, Y)
optimizer.update(model, grads)
return loss
test_acc = float('nan')
for i in (t:=trange(70)):
samples = mx.random.randint(0, X_train.shape[0], (512,)) # putting this in JIT didn't work well
loss = step(samples)
if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")


@@ -1,45 +0,0 @@
import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register
# a very simple game
# one of <size> lights will light up
# take the action of the lit up light
# in <hard_mode>, you act differently based on the step number and need to track this
class PressTheLightUpButton(gym.Env):
metadata = {"render_modes": []}
def __init__(self, render_mode=None, size=2, game_length=10, hard_mode=False):
self.size, self.game_length = size, game_length
self.observation_space = gym.spaces.Box(0, 1, shape=(self.size,), dtype=np.float32)
self.action_space = gym.spaces.Discrete(self.size)
self.step_num = 0
self.done = True
self.hard_mode = hard_mode
def _get_obs(self):
obs = [0]*self.size
if self.step_num < len(self.state):
obs[self.state[self.step_num]] = 1
return np.array(obs, dtype=np.float32)
def reset(self, seed=None, options=None):
super().reset(seed=seed)
self.state = np.random.randint(0, self.size, size=self.game_length)
self.step_num = 0
self.done = False
return self._get_obs(), {}
def step(self, action):
target = ((action + self.step_num) % self.size) if self.hard_mode else action
reward = int(target == self.state[self.step_num])
self.step_num += 1
if not reward:
self.done = True
return self._get_obs(), reward, self.done, self.step_num >= self.game_length, {}
register(
id="PressTheLightUpButton-v0",
entry_point="examples.rl.lightupbutton:PressTheLightUpButton",
max_episode_steps=None,
)
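
A minimal usage sketch with gymnasium's API, assuming this module is importable as examples.rl.lightupbutton so the register() call runs:

import gymnasium as gym
import numpy as np
import examples.rl.lightupbutton  # noqa: F401  # side effect: registers the env

env = gym.make("PressTheLightUpButton-v0", size=2, game_length=10)
obs, _ = env.reset()
total, done, trunc = 0, False, False
while not (done or trunc):
  action = int(np.argmax(obs))  # press the lit-up button (optimal when hard_mode=False)
  obs, reward, done, trunc, _ = env.step(action)
  total += reward
print(total)  # 10: the optimal policy survives all game_length steps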


@@ -1,136 +0,0 @@
#!/usr/bin/env python
#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
import sys
import numpy as np
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d, optim
from tinygrad.helpers import getenv
from extra.datasets import fetch_mnist
from extra.augment import augment_img
from extra.training import train, evaluate
GPU = getenv("GPU")
QUICK = getenv("QUICK")
DEBUG = getenv("DEBUG")
class SqueezeExciteBlock2D:
def __init__(self, filters):
self.filters = filters
self.weight1 = Tensor.scaled_uniform(self.filters, self.filters//32)
self.bias1 = Tensor.scaled_uniform(1,self.filters//32)
self.weight2 = Tensor.scaled_uniform(self.filters//32, self.filters)
self.bias2 = Tensor.scaled_uniform(1, self.filters)
def __call__(self, input):
se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
se = se.reshape(shape=(-1, self.filters))
se = se.dot(self.weight1) + self.bias1
se = se.relu()
se = se.dot(self.weight2) + self.bias2
se = se.sigmoid().reshape(shape=(-1,self.filters,1,1)) #for broadcasting
se = input.mul(se)
return se
class ConvBlock:
def __init__(self, h, w, inp, filters=128, conv=3):
self.h, self.w = h, w
self.inp = inp
#init weights
self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
#init layers
self._bn = BatchNorm2d(128)
self._seb = SqueezeExciteBlock2D(filters)
def __call__(self, input):
x = input.reshape(shape=(-1, self.inp, self.w, self.h))
for cweight, cbias in zip(self.cweights, self.cbiases):
x = x.pad(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu()
x = self._bn(x)
x = self._seb(x)
return x
class BigConvNet:
def __init__(self):
self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
self.weight1 = Tensor.scaled_uniform(128,10)
self.weight2 = Tensor.scaled_uniform(128,10)
def parameters(self):
if DEBUG: #keeping this for a moment
pars = [par for par in get_parameters(self) if par.requires_grad]
no_pars = 0
for par in pars:
print(par.shape)
no_pars += np.prod(par.shape)
print('no of parameters', no_pars)
return pars
else:
return get_parameters(self)
def save(self, filename):
with open(filename+'.npy', 'wb') as f:
for par in get_parameters(self):
#if par.requires_grad:
np.save(f, par.numpy())
def load(self, filename):
with open(filename+'.npy', 'rb') as f:
for par in get_parameters(self):
#if par.requires_grad:
try:
par.numpy()[:] = np.load(f)
if GPU:
par.gpu()
except Exception:
print('Could not load parameter')
def forward(self, x):
x = self.conv[0](x)
x = self.conv[1](x)
x = x.avg_pool2d(kernel_size=(2,2))
x = self.conv[2](x)
x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
xo = x1.dot(self.weight1) + x2.dot(self.weight2)
return xo
if __name__ == "__main__":
lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
epochss = [2, 1] if QUICK else [13, 3, 3, 1]
BS = 32
lmbd = 0.00025
lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
X_train, Y_train, X_test, Y_test = fetch_mnist()
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
steps = len(X_train)//BS
np.random.seed(1337)
if QUICK:
steps = 1
X_test, Y_test = X_test[:BS], Y_test[:BS]
model = BigConvNet()
if len(sys.argv) > 1:
try:
model.load(sys.argv[1])
print('Loaded weights "'+sys.argv[1]+'", evaluating...')
evaluate(model, X_test, Y_test, BS=BS)
except Exception:
print('could not load weights "'+sys.argv[1]+'".')
if GPU:
params = get_parameters(model)
[x.gpu_() for x in params]
for lr, epochs in zip(lrs, epochss):
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(1,epochs+1):
#first epoch without augmentation
X_aug = X_train if epoch == 1 else augment_img(X_train)
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
accuracy = evaluate(model, X_test, Y_test, BS=BS)
model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')
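
SqueezeExciteBlock2D above is channel attention: global-average-pool each channel to a scalar, squeeze through a C -> C//32 bottleneck with relu, expand back with a sigmoid, and rescale the input channels. A compact numpy sketch with random weights:

import numpy as np
def squeeze_excite(x, w1, b1, w2, b2):    # x: (N, C, H, W)
  se = x.mean(axis=(2, 3))                # global average pool -> (N, C)
  se = np.maximum(se @ w1 + b1, 0)        # bottleneck C -> C//32, relu
  se = 1 / (1 + np.exp(-(se @ w2 + b2)))  # expand back, sigmoid -> (N, C)
  return x * se[:, :, None, None]         # per-channel rescale
x = np.random.randn(2, 128, 14, 14)
w1, b1 = np.random.randn(128, 4), np.random.randn(1, 4)
w2, b2 = np.random.randn(4, 128), np.random.randn(1, 128)
assert squeeze_excite(x, w1, b1, w2, b2).shape == x.shape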


@@ -1,17 +0,0 @@
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.state import get_parameters
if __name__ == "__main__":
with Tensor.train():
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1
x = Tensor.uniform(BS, C1, H, W)
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
bn = BatchNorm2d(C2, track_running_stats=False)
for t in get_parameters([x, conv, bn]): t.realize()
print("running network")
x.sequential([conv, bn]).numpy()
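
Sanity check on the shapes: with kernel 7, stride 2, padding 1, the conv output spatial size is (H + 2*P - K)//S + 1.

H, K, S, P = 224, 7, 2, 1
assert (H + 2*P - K)//S + 1 == 110  # so x.sequential([conv, bn]) is (4, 64, 110, 110)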


@@ -1,669 +0,0 @@
# original implementation: https://github.com/svc-develop-team/so-vits-svc
from __future__ import annotations
import sys, logging, time, io, math, argparse, operator, numpy as np
from functools import partial, reduce
from pathlib import Path
from typing import Tuple, Optional, Type
from tinygrad import nn, dtypes, Tensor
from tinygrad.helpers import getenv, fetch
from tinygrad.nn.state import torch_load
from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, get_hparams_from_file, load_checkpoint, weight_norm, HParams
from examples.sovits_helpers import preprocess
import soundfile
DEBUG = getenv("DEBUG")
F0_BIN = 256
F0_MAX = 1100.0
F0_MIN = 50.0
F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700)
F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700)
class SpeechEncoder:
def __init__(self, hidden_dim, model:ContentVec): self.hidden_dim, self.model = hidden_dim, model
def encode(self): raise NotImplementedError("implement me")
@classmethod
def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec:
contentvec = ContentVec.load_from_pretrained(checkpoint_path, checkpoint_url)
return cls(contentvec)
class ContentVec256L9(SpeechEncoder):
def __init__(self, model:ContentVec): super().__init__(hidden_dim=256, model=model)
def encode(self, wav: Tensor):
feats = wav
if len(feats.shape) == 2: # double channels
feats = feats.mean(-1)
assert len(feats.shape) == 1, feats.dim()
feats = feats.reshape(1, -1)
padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool)
logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=9)
feats = self.model.final_proj(logits[0])
return feats.transpose(1,2)
class ContentVec768L12(SpeechEncoder):
def __init__(self, model:ContentVec): super().__init__(hidden_dim=768, model=model)
def encode(self, wav: Tensor):
feats = wav
if len(feats.shape) == 2: # double channels
feats = feats.mean(-1)
assert len(feats.shape) == 1, feats.dim()
feats = feats.reshape(1, -1)
padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool)
logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=12)
return logits[0].transpose(1,2)
# original code for contentvec: https://github.com/auspicious3000/contentvec/
class ContentVec:
# self.final_proj dims are hardcoded and depend on fairseq.data.dictionary Dictionary in the checkpoint. This param can't yet be loaded since there is no pickle for it. See with DEBUG=2.
# This means that the ContentVec only works with the hubert weights used in all SVC models
def __init__(self, cfg: HParams):
self.feature_grad_mult, self.untie_final_proj = cfg.feature_grad_mult, cfg.untie_final_proj
feature_enc_layers = eval(cfg.conv_feature_layers)
self.embed = feature_enc_layers[-1][0]
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
self.encoder = TransformerEncoder(cfg)
self.layer_norm = nn.LayerNorm(self.embed)
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim * 1) if self.untie_final_proj else nn.Linear(cfg.encoder_embed_dim, final_dim)
self.mask_emb = Tensor.uniform(cfg.encoder_embed_dim, dtype=dtypes.float32)
self.label_embs_concat = Tensor.uniform(504, final_dim, dtype=dtypes.float32)
def forward_features(self, source, padding_mask):
if self.feature_grad_mult > 0:
features = self.feature_extractor(source, padding_mask)
if self.feature_grad_mult != 1.0: pass # training: GradMultiply.forward(features, self.feature_grad_mult)
else:
features = self.feature_extractor(source, padding_mask)
return features
def forward_padding_mask(self, features, padding_mask): # replaces original forward_padding_mask for batch inference
lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) # ensure it's bool for tilde
lengths = (lengths_org - 400).float().div(320).floor().cast(dtypes.int64) + 1 # intermediate float to divide
padding_mask = lengths_to_padding_mask(lengths)
return padding_mask
def extract_features(self, source: Tensor, spk_emb:Tensor=None, padding_mask=None, ret_conv=False, output_layer=None, tap=False):
features = self.forward_features(source, padding_mask)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
features = features.transpose(1, 2)
features = self.layer_norm(features)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
x, _ = self.encoder(features, spk_emb, padding_mask=padding_mask, layer=(None if output_layer is None else output_layer - 1), tap=tap)
res = features if ret_conv else x
return res, padding_mask
@classmethod
def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec:
fetch(checkpoint_url, checkpoint_path)
cfg = load_fairseq_cfg(checkpoint_path)
enc = cls(cfg.model)
_ = load_checkpoint_enc(checkpoint_path, enc, None)
logging.debug(f"{cls.__name__}: Loaded model with cfg={cfg}")
return enc
class TransformerEncoder:
def __init__(self, cfg: HParams):
def make_conv() -> nn.Conv1d:
layer = nn.Conv1d(self.embedding_dim, self.embedding_dim, kernel_size=cfg.conv_pos, padding=cfg.conv_pos // 2, groups=cfg.conv_pos_groups)
std = math.sqrt(4 / (cfg.conv_pos * self.embedding_dim))
layer.weight, layer.bias = (Tensor.normal(*layer.weight.shape, std=std)), (Tensor.zeros(*layer.bias.shape))
# for training: layer.weights need to be weight_normed
return layer
self.dropout, self.embedding_dim, self.layer_norm_first, self.layerdrop, self.num_layers, self.num_layers_1 = cfg.dropout, cfg.encoder_embed_dim, cfg.layer_norm_first, cfg.encoder_layerdrop, cfg.encoder_layers, cfg.encoder_layers_1
self.pos_conv, self.pos_conv_remove = [make_conv()], (1 if cfg.conv_pos % 2 == 0 else 0)
self.layers = [
TransformerEncoderLayer(self.embedding_dim, cfg.encoder_ffn_embed_dim, cfg.encoder_attention_heads, self.dropout, cfg.attention_dropout, cfg.activation_dropout, cfg.activation_fn, self.layer_norm_first, cond_layer_norm=(i >= cfg.encoder_layers))
for i in range(cfg.encoder_layers + cfg.encoder_layers_1)
]
self.layer_norm = nn.LayerNorm(self.embedding_dim)
self.cond_layer_norm = CondLayerNorm(self.embedding_dim) if cfg.encoder_layers_1 > 0 else None
# training: apply init_bert_params
def __call__(self, x, spk_emb, padding_mask=None, layer=None, tap=False):
x, layer_results = self.extract_features(x, spk_emb, padding_mask, layer, tap)
if self.layer_norm_first and layer is None:
x = self.cond_layer_norm(x, spk_emb) if (self.num_layers_1 > 0) else self.layer_norm(x)
return x, layer_results
def extract_features(self, x: Tensor, spk_emb: Tensor, padding_mask=None, tgt_layer=None, tap=False):
if tgt_layer is not None: # and not self.training
assert tgt_layer >= 0 and tgt_layer < len(self.layers)
if padding_mask is not None:
# x[padding_mask] = 0
assert padding_mask.shape == x.shape[:len(padding_mask.shape)] # first few dims of x must match padding_mask
tmp_mask = padding_mask.unsqueeze(-1).repeat((1, 1, x.shape[-1]))
tmp_mask = tilde(tmp_mask.cast(dtypes.bool))
x = tmp_mask.where(x, 0)
x_conv = self.pos_conv[0](x.transpose(1,2))
if self.pos_conv_remove > 0: x_conv = x_conv[:, :, : -self.pos_conv_remove]
x_conv = x_conv.gelu().transpose(1, 2)
x = (x + x_conv).transpose(0, 1) # B x T x C -> T x B x C
if not self.layer_norm_first: x = self.layer_norm(x)
x = x.dropout(p=self.dropout)
layer_results = []
r = None
for i, layer in enumerate(self.layers):
if i < self.num_layers: # if (not self.training or (dropout_probability > self.layerdrop)) and (i < self.num_layers):
assert not layer.cond_layer_norm
x = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
if tgt_layer is not None or tap:
layer_results.append(x.transpose(0, 1))
if i >= self.num_layers:
assert layer.cond_layer_norm
x = layer(x, emb=spk_emb, self_attn_padding_mask=padding_mask, need_weights=False)
if i == tgt_layer:
r = x
break
if r is not None:
x = r
x = x.transpose(0, 1) # T x B x C -> B x T x C
return x, layer_results
class TransformerEncoderLayer:
def __init__(self, embedding_dim=768, ffn_embedding_dim=3072, num_attention_heads=8, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1, activation_fn="relu", layer_norm_first=False, cond_layer_norm=False):
def get_activation_fn(activation):
if activation == "relu": return Tensor.relu
if activation == "gelu": return Tensor.gelu
else: raise RuntimeError(f"activation function={activation} is not forseen")
self.embedding_dim, self.dropout, self.activation_dropout, self.layer_norm_first, self.num_attention_heads, self.cond_layer_norm, self.activation_fn = embedding_dim, dropout, activation_dropout, layer_norm_first, num_attention_heads, cond_layer_norm, get_activation_fn(activation_fn)
self.self_attn = MultiHeadAttention(self.embedding_dim, self.num_attention_heads)
self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
self.final_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim)
def __call__(self, x:Tensor, self_attn_mask:Tensor=None, self_attn_padding_mask:Tensor=None, emb:Tensor=None, need_weights=False):
#self_attn_padding_mask = self_attn_padding_mask.reshape(x.shape[0], 1, 1, self_attn_padding_mask.shape[1]).expand(-1, self.num_attention_heads, -1, -1).reshape(x.shape[0] * self.num_attention_heads, 1, self_attn_padding_mask.shape[1]) if self_attn_padding_mask is not None else None
assert self_attn_mask is None and self_attn_padding_mask is not None
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb)
x = self.self_attn(x=x, mask=self_attn_padding_mask)
x = x.dropout(self.dropout)
x = residual + x
x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb)
x = self.activation_fn(self.fc1(x))
x = x.dropout(self.activation_dropout)
x = self.fc2(x)
x = x.dropout(self.dropout)
x = residual + x
else:
x = self.self_attn(x=x, mask=self_attn_padding_mask)
x = x.dropout(self.dropout)
x = residual + x
x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb)
residual = x
x = self.activation_fn(self.fc1(x))
x = x.dropout(self.activation_dropout)
x = self.fc2(x)
x = x.dropout(self.dropout)
x = residual + x
x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb)
return x
class MultiHeadAttention:
def __init__(self, n_state, n_head):
self.n_state, self.n_head = n_state, n_head
self.q_proj, self.k_proj, self.v_proj, self.out_proj = [nn.Linear(n_state, n_state) for _ in range(4)]
def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None):
x = x.transpose(0,1) # TxBxC -> BxTxC
q, k, v = self.q_proj(x), self.k_proj(xa if xa is not None else x), self.v_proj(xa if xa is not None else x)
q, k, v = [x.reshape(*q.shape[:2], self.n_head, -1) for x in (q, k, v)]
wv = Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), None).transpose(1, 2).reshape(*x.shape[:2], -1)
ret = self.out_proj(wv).transpose(0,1) # BxTxC -> TxBxC
return ret
class ConvFeatureExtractionModel:
def __init__(self, conv_layers, dropout=.0, mode="default", conv_bias=False):
assert mode in {"default", "group_norm_masked", "layer_norm"}
def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
def make_conv():
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
conv.weight = Tensor.kaiming_normal(*conv.weight.shape)
return conv
assert not (is_layer_norm and is_group_norm), "layer norm and group norm are exclusive"
if is_layer_norm:
return [make_conv(), partial(Tensor.dropout, p=dropout),[partial(Tensor.transpose, dim0=-2, dim1=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, dim0=-2, dim1=-1)], Tensor.gelu]
elif is_group_norm and mode == "default":
return [make_conv(), partial(Tensor.dropout, p=dropout), nn.GroupNorm(dim, dim, affine=True), Tensor.gelu]
elif is_group_norm and mode == "group_norm_masked":
return [make_conv(), partial(Tensor.dropout, p=dropout), GroupNormMasked(dim, dim, affine=True), Tensor.gelu]
else:
return [make_conv(), partial(Tensor.dropout, p=dropout), Tensor.gelu]
in_d, self.conv_layers, self.mode = 1, [], mode
for i, cl in enumerate(conv_layers):
assert len(cl) == 3, "invalid conv definition: " + str(cl)
(dim, k, stride) = cl
if i == 0: self.cl = cl
self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=(mode == "layer_norm"), is_group_norm=((mode == "default" or mode == "group_norm_masked") and i == 0), conv_bias=conv_bias))
in_d = dim
def __call__(self, x:Tensor, padding_mask:Tensor):
x = x.unsqueeze(1) # BxT -> BxCxT
if self.mode == "group_norm_masked":
if padding_mask is not None:
_, k, stride = self.cl
lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) # ensure padding_mask is bool for tilde
lengths = (((lengths_org - k) / stride) + 1).floor().cast(dtypes.int64)
padding_mask = tilde(lengths_to_padding_mask(lengths)).cast(dtypes.int64) # lengths_to_padding_mask returns bool tensor
x = self.conv_layers[0][0](x) # padding_mask is numeric
x = self.conv_layers[0][1](x)
x = self.conv_layers[0][2](x, padding_mask)
x = self.conv_layers[0][3](x)
else:
x = x.sequential(self.conv_layers[0]) # default
for _, conv in enumerate(self.conv_layers[1:], start=1):
conv = reduce(lambda a,b: operator.iconcat(a,b if isinstance(b, list) else [b]), conv, []) # flatten
x = x.sequential(conv)
return x
class CondLayerNorm: # https://github.com/auspicious3000/contentvec/blob/main/contentvec/modules/cond_layer_norm.py#L10
def __init__(self, dim_last, eps=1e-5, dim_spk=256, elementwise_affine=True):
self.dim_last, self.eps, self.dim_spk, self.elementwise_affine = dim_last, eps, dim_spk, elementwise_affine
if self.elementwise_affine:
self.weight_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False)
self.bias_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False)
self.weight_ln.weight, self.bias_ln.weight = (Tensor.ones(*self.weight_ln.weight.shape)), (Tensor.zeros(*self.bias_ln.weight.shape))
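# shape sketch (illustrative): for x of shape (T, B, C) and spk_emb of shape (B, dim_spk),
# weight_ln/bias_ln map the speaker embedding to per-sample (B, C) scale and shift tensors
# that broadcast over the leading time axis in __call__ below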
def __call__(self, x: Tensor, spk_emb: Tensor):
axis = tuple(-1-i for i in range(len(x.shape[1:])))
x = x.layernorm(axis=axis, eps=self.eps)
if not self.elementwise_affine: return x
weights, bias = self.weight_ln(spk_emb), self.bias_ln(spk_emb)
return weights * x + bias
class GroupNormMasked: # https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/modules/fp32_group_norm.py#L16
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if self.affine else (None, None)
def __call__(self, x:Tensor, mask:Tensor):
bsz, n_c, length = x.shape
assert n_c % self.num_groups == 0
x = x.reshape(bsz, self.num_groups, n_c // self.num_groups, length)
if mask is None: mask = Tensor.ones_like(x)
else: mask = mask.reshape(bsz, 1, 1, length)
x = x * mask
lengths = mask.sum(axis=3, keepdim=True)
assert x.shape[2] == 1
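# masked statistics: the mean over the padded axis is rescaled by length/lengths so the
# zeroed padding positions do not bias the per-group mean and variance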
mean_ = x.mean(axis=3, keepdim=True)
mean = mean_ * length / lengths
var = (((x.std(axis=3, keepdim=True) ** 2) + mean_**2) * length / lengths - mean**2) + self.eps
return x.add(-mean).div(var.sqrt()).reshape(bsz, n_c, length).mul(self.weight.reshape(1,-1,1)).add(self.bias.reshape(1,-1,1))
class Synthesizer:
def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, ssl_dim, n_speakers, sampling_rate=44100, vol_embedding=False, n_flow_layer=4, **kwargs):
self.spec_channels, self.inter_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.segment_size, self.n_speakers, self.gin_channels, self.vol_embedding = spec_channels, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, segment_size, n_speakers, gin_channels, vol_embedding
self.emb_g = nn.Embedding(n_speakers, gin_channels)
if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels)
self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
self.enc_p = TextEncoder(inter_channels, hidden_channels, kernel_size, n_layers, filter_channels=filter_channels, n_heads=n_heads, p_dropout=p_dropout)
self.dec = Generator(sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels)
self.emb_uv = nn.Embedding(vocab_size=2, embed_size=hidden_channels)
def infer(self, c:Tensor, f0:Tensor, uv:Tensor, g:Tensor=None, noise_scale=0.35, seed=52468, vol=None) -> Tuple[Tensor, Tensor]:
Tensor.manual_seed(getenv('SEED', seed))
c_lengths = (Tensor.ones([c.shape[0]]) * c.shape[-1]).to(c.device)
if len(g.shape) == 1: g = g.unsqueeze(0)
g = self.emb_g(g).transpose(1, 2)
x_mask = sequence_mask(c_lengths, c.shape[2]).unsqueeze(1).cast(c.dtype)
vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0
x = self.pre(c) * x_mask + self.emb_uv(uv.cast(dtypes.int64)).transpose(1, 2) + vol
z_p, _, _, c_mask = self.enc_p.forward(x, x_mask, f0=self._f0_to_coarse(f0), noise_scale=noise_scale)
z = self.flow.forward(z_p, c_mask, g=g, reverse=True)
o = self.dec.forward(z * c_mask, g=g, f0=f0)
return o,f0
def _f0_to_coarse(self, f0 : Tensor):
f0_mel = 1127 * (1 + f0 / 700).log()
a = (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN)
b = F0_MEL_MIN * a - 1.
f0_mel = (f0_mel > 0).where(f0_mel * a - b, f0_mel)
f0_coarse = f0_mel.ceil().cast(dtype=dtypes.int64)
f0_coarse = f0_coarse * (f0_coarse > 0)
f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
f0_coarse = f0_coarse * (f0_coarse < F0_BIN)
f0_coarse = f0_coarse + ((f0_coarse >= F0_BIN) * (F0_BIN - 1))
return f0_coarse
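# worked sketch (assuming the usual so-vits-svc constants, e.g. F0_BIN=256): f0 in Hz is mapped
# to the mel scale, linearly rescaled by a and b, then ceil'd and clamped into [1, F0_BIN-1];
# unvoiced frames (f0_mel <= 0) end up in bin 1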
@classmethod
def load_from_pretrained(cls, config_path:str, config_url:str, weights_path:str, weights_url:str) -> Synthesizer:
fetch(config_url, config_path)
hps = get_hparams_from_file(config_path)
fetch(weights_url, weights_path)
net_g = cls(hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model)
_ = load_checkpoint(weights_path, net_g, None, skip_list=["f0_decoder"])
logging.debug(f"{cls.__name__}:Loaded model with hps: {hps}")
return net_g, hps
class TextEncoder:
def __init__(self, out_channels, hidden_channels, kernel_size, n_layers, gin_channels=0, filter_channels=None, n_heads=None, p_dropout=None):
self.out_channels, self.hidden_channels, self.kernel_size, self.n_layers, self.gin_channels = out_channels, hidden_channels, kernel_size, n_layers, gin_channels
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
self.f0_emb = nn.Embedding(256, hidden_channels) # n_vocab = 256
self.enc_ = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
def forward(self, x, x_mask, f0=None, noise_scale=1):
x = x + self.f0_emb(f0).transpose(1, 2)
x = self.enc_.forward(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = split(stats, self.out_channels, dim=1)
z = (m + randn_like(m) * logs.exp() * noise_scale) * x_mask
return z, m, logs, x_mask
class Upsample:
def __init__(self, scale_factor):
assert scale_factor % 1 == 0, "Only integer scale factor allowed."
self.scale = int(scale_factor)
def forward(self, x:Tensor):
repeats = tuple([1] * len(x.shape) + [self.scale])
new_shape = (*x.shape[:-1], x.shape[-1] * self.scale)
return x.unsqueeze(-1).repeat(repeats).reshape(new_shape)
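# worked example (illustrative): Upsample(2).forward(Tensor([[1., 2.]])) -> [[1., 1., 2., 2.]],
# i.e. nearest-neighbour upsampling along the last axis via repeat + reshape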
class SineGen:
def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voice_threshold=0, flag_for_pulse=False):
self.sine_amp, self.noise_std, self.harmonic_num, self.sampling_rate, self.voiced_threshold, self.flag_for_pulse = sine_amp, noise_std, harmonic_num, samp_rate, voice_threshold, flag_for_pulse
self.dim = self.harmonic_num + 1
def _f02uv(self, f0): return (f0 > self.voiced_threshold).float() #generate uv signal
def _f02sine(self, f0_values):
def padDiff(x : Tensor): return (x.pad((0,0,-1,1)) - x).pad((0,0,0,-1))
def mod(x: Tensor, n: int) -> Tensor: return x - n * x.div(n).floor() # this is what the % operator does in pytorch.
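# e.g. mod(Tensor([-0.3]), 1) gives 0.7, matching torch's floored modulo for negative values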
rad_values = mod((f0_values / self.sampling_rate) , 1) # convert to F0 in rad
rand_ini = Tensor.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) # initial phase noise
#rand_ini[:, 0] = 0
m = Tensor.ones(f0_values.shape[0]).unsqueeze(1).pad((0,f0_values.shape[2]-1,0,0)).cast(dtypes.bool)
m = tilde(m)
rand_ini = m.where(rand_ini, 0)
#rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp = rad_values[:, 0, :] + rand_ini
m = Tensor.ones(tmp.shape).pad((0,0,0,rad_values.shape[1]-1,0)).cast(dtypes.bool)
m = tilde(m)
tmp = tmp.unsqueeze(1).pad((0,0,0,rad_values.shape[1]-1,0))
rad_values = m.where(rad_values, tmp)
tmp_over_one = mod(rad_values.cumsum(1), 1)
tmp_over_one_idx = padDiff(tmp_over_one) < 0
cumsum_shift = Tensor.zeros_like(rad_values)
#cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad((0,0,1,0))
cumsum_shift = tmp_over_one_idx
sines = ((rad_values + cumsum_shift).cumsum(1) * 2 * np.pi).sin()
return sines
def forward(self, f0, upp=None):
fn = f0.mul(Tensor([[range(1, self.harmonic_num + 2)]], dtype=dtypes.float32).to(f0.device))
sine_waves = self._f02sine(fn) * self.sine_amp #generate sine waveforms
uv = self._f02uv(f0) # generate uv signal
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * randn_like(sine_waves)
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceHnNSF:
def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
self.sine_amp, self.noise_std = sine_amp, add_noise_std
self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
self.l_linear = nn.Linear(harmonic_num + 1, 1)
def forward(self, x, upp=None):
sine_waves, uv, _ = self.l_sin_gen.forward(x, upp)
sine_merge = self.l_linear(sine_waves.cast(self.l_linear.weight.dtype)).tanh()
noise = randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
# most of the hifigan in standard vits is reused here, but need to upsample and construct harmonic source from f0
class Generator:
def __init__(self, sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels):
self.sampling_rate, self.inter_channels, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.gin_channels = sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels
self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len(upsample_rates)
self.conv_pre = nn.Conv1d(inter_channels, upsample_initial_channel, 7, 1, padding=3)
self.f0_upsamp = Upsample(scale_factor=np.prod(upsample_rates))
self.m_source = SourceHnNSF(sampling_rate, harmonic_num=8)
resblock = ResBlock1 if resblock == '1' else ResBlock2
self.ups, self.noise_convs, self.resblocks = [], [], []
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel//(2**(i+1))
self.ups.append(nn.ConvTranspose1d(upsample_initial_channel//(2**i), c_cur, k, u, padding=(k-u)//2))
stride_f0 = int(np.prod(upsample_rates[i + 1:]))
self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2) if (i + 1 < len(upsample_rates)) else nn.Conv1d(1, c_cur, kernel_size=1))
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3)
if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
self.upp = np.prod(upsample_rates)
def forward(self, x, f0, g=None):
f0 = self.f0_upsamp.forward(f0[:, None]).transpose(1, 2) # bs,n,t
har_source, _, _ = self.m_source.forward(f0, self.upp)
har_source = har_source.transpose(1, 2)
x = self.conv_pre(x)
if g is not None: x = x + self.cond(g)
for i in range(self.num_upsamples):
x, xs = self.ups[i](x.leaky_relu(LRELU_SLOPE)), None
x_source = self.noise_convs[i](har_source)
x = x + x_source
for j in range(self.num_kernels):
if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x)
else: xs += self.resblocks[i * self.num_kernels + j].forward(x)
x = xs / self.num_kernels
return self.conv_post(x.leaky_relu()).tanh()
# **** helpers ****
def randn_like(x:Tensor) -> Tensor: return Tensor.randn(*x.shape, dtype=x.dtype).to(device=x.device)
def tilde(x: Tensor) -> Tensor:
if x.dtype == dtypes.bool: return (1 - x).cast(dtypes.bool)
return (x + 1) * -1 # this seems to be what the ~ operator does in pytorch for non bool
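# e.g. tilde(Tensor([True, False])) -> [False, True]; tilde(Tensor([0, 1, 2])) -> [-1, -2, -3],
# matching torch's ~ on integer tensors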
def lengths_to_padding_mask(lens:Tensor) -> Tensor:
bsz, max_lens = lens.shape[0], lens.max().numpy().item()
mask = Tensor.arange(max_lens).to(lens.device).reshape(1, max_lens)
mask = mask.expand(bsz, -1) >= lens.reshape(bsz, 1).expand(-1, max_lens)
return mask.cast(dtypes.bool)
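# worked example (illustrative): lengths_to_padding_mask(Tensor([2, 4])) ->
#   [[False, False, True,  True ],
#    [False, False, False, False]]  # True marks padding positions at or beyond each length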
def repeat_expand_2d_left(content, target_len): # content : [h, t]
src_len = content.shape[-1]
temp = np.arange(src_len+1) * target_len / src_len
current_pos, cols = 0, []
for i in range(target_len):
if i >= temp[current_pos+1]:
current_pos += 1
cols.append(content[:, current_pos])
return Tensor.stack(*cols).transpose(0, 1)
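# worked example (illustrative): for content of shape [h, 2] and target_len=5 the column
# indices chosen are [0, 0, 0, 1, 1], i.e. each source frame is held (left-aligned) until the
# next frame's interpolated position is reached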
def load_fairseq_cfg(checkpoint_path):
assert Path(checkpoint_path).is_file()
state = torch_load(checkpoint_path)
cfg = state["cfg"] if ("cfg" in state and state["cfg"] is not None) else None
if cfg is None: raise RuntimeError(f"No cfg exist in state keys = {state.keys()}")
return HParams(**cfg)
def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip_list=[]):
assert Path(checkpoint_path).is_file()
start_time = time.time()
checkpoint_dict = torch_load(checkpoint_path)
saved_state_dict = checkpoint_dict['model']
weight_g, weight_v, parent = None, None, None
for key, v in saved_state_dict.items():
if any(layer in key for layer in skip_list): continue
try:
obj, skip = model, False
for k in key.split('.'):
if k.isnumeric(): obj = obj[int(k)]
elif isinstance(obj, dict): obj = obj[k]
else:
if k in ["weight_g", "weight_v"]:
parent, skip = obj, True
if k == "weight_g": weight_g = v
else: weight_v = v
if not skip:
parent = obj
obj = getattr(obj, k)
if weight_g and weight_v:
setattr(obj, "weight_g", weight_g.numpy())
setattr(obj, "weight_v", weight_v.numpy())
obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0)
weight_g, weight_v, parent, skip = None, None, None, False
if not skip and obj.shape == v.shape:
if "feature_extractor" in key and (isinstance(parent, (nn.GroupNorm, nn.LayerNorm))): # cast
obj.assign(v.to(obj.device).float())
else:
obj.assign(v.to(obj.device))
elif not skip: logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}")
except Exception as e: raise e
logging.info(f"Loaded checkpoint '{checkpoint_path}' in {time.time() - start_time:.4f}s")
return model, optimizer
def pad_array(arr, target_length):
current_length = arr.shape[0]
if current_length >= target_length: return arr
pad_width = target_length - current_length
pad_left = pad_width // 2
pad_right = pad_width - pad_left
padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0))
return padded_arr
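# e.g. pad_array(np.ones(3), 7) -> array([0., 0., 1., 1., 1., 0., 0.]) (centered zero padding)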
def split_list_by_n(list_collection, n, pre=0):
for i in range(0, len(list_collection), n):
yield list_collection[i-pre if i-pre>=0 else i: i + n]
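# e.g. list(split_list_by_n(list(range(10)), 4, pre=1)) -> [[0,1,2,3], [3,4,5,6,7], [7,8,9]];
# pre re-includes the last `pre` items of the previous chunk as overlap between slices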
def get_sid(spk2id:HParams, speaker:str) -> Tensor:
speaker_id = spk2id[speaker]
if not speaker_id and type(speaker) is int:
if len(spk2id.__dict__) >= speaker: speaker_id = speaker
if speaker_id is None: raise RuntimeError(f"speaker={speaker} not in the speaker list")
return Tensor([int(speaker_id)], dtype=dtypes.int64).unsqueeze(0)
def get_encoder(ssl_dim) -> Type[SpeechEncoder]:
if ssl_dim == 256: return ContentVec256L9
if ssl_dim == 768: return ContentVec768L12
#########################################################################################
# CODE: https://github.com/svc-develop-team/so-vits-svc
#########################################################################################
# CONTENTVEC:
# CODE: https://github.com/auspicious3000/contentvec
# PAPER: https://arxiv.org/abs/2204.09224
#########################################################################################
# INSTALLATION: dependencies are for preprocessing and loading/saving audio.
# pip3 install soundfile librosa praat-parselmouth
#########################################################################################
# EXAMPLE USAGE:
# python3 examples/so_vits_svc.py --model tf2spy --file ~/recording.wav
#########################################################################################
# DEMO USAGE (uses audio sample from LJ-Speech):
# python3 examples/so_vits_svc.py --model saul_goodman
#########################################################################################
SO_VITS_SVC_PATH = Path(__file__).parents[1] / "weights/So-VITS-SVC"
VITS_MODELS = { # config_path, weights_path, config_url, weights_url
"saul_goodman" : (SO_VITS_SVC_PATH / "config_saul_gman.json", SO_VITS_SVC_PATH / "pretrained_saul_gman.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/G_80000.pth"),
"drake" : (SO_VITS_SVC_PATH / "config_drake.json", SO_VITS_SVC_PATH / "pretrained_drake.pth", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/config_aubrey.json", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/pretrained_aubrey.pth"),
"cartman" : (SO_VITS_SVC_PATH / "config_cartman.json", SO_VITS_SVC_PATH / "pretrained_cartman.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/G_10200.pth"),
"tf2spy" : (SO_VITS_SVC_PATH / "config_tf2spy.json", SO_VITS_SVC_PATH / "pretrained_tf2spy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/G_60000.pth"),
"tf2heavy" : (SO_VITS_SVC_PATH / "config_tf2heavy.json", SO_VITS_SVC_PATH / "pretrained_tf2heavy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/G_100000.pth"),
"lady_gaga" : (SO_VITS_SVC_PATH / "config_gaga.json", SO_VITS_SVC_PATH / "pretrained_gaga.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/G_14400.pth")
}
ENCODER_MODELS = { # weights_path, weights_url
"contentvec": (SO_VITS_SVC_PATH / "contentvec_checkpoint.pt", "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt")
}
ENCODER_MODEL = "contentvec"
DEMO_PATH, DEMO_URL = Path(__file__).parents[1] / "temp/LJ037-0171.wav", "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
if __name__=="__main__":
logging.basicConfig(stream=sys.stdout, level=(logging.INFO if DEBUG < 1 else logging.DEBUG))
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", default=None, help=f"Specify the model to use. All supported models: {VITS_MODELS.keys()}", required=True)
parser.add_argument("-f", "--file", default=DEMO_PATH, help=f"Specify the path of the input file")
parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.")
parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --name parameter.")
parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.")
parser.add_argument("--speaker", default=None, help="If not specified, the first available speaker is chosen. Usually there is only one speaker per model.")
parser.add_argument("--noise_scale", default=0.4)
parser.add_argument("--tran", default=0.0, help="Pitch shift, supports positive and negative (semitone) values. Default 0.0")
parser.add_argument("--pad_seconds", default=0.5)
parser.add_argument("--lg_num", default=0.0)
parser.add_argument("--clip_seconds", default=0.0)
parser.add_argument("--slice_db", default=-40)
args = parser.parse_args()
vits_model = args.model
encoder_location, vits_location = ENCODER_MODELS[ENCODER_MODEL], VITS_MODELS[vits_model]
Tensor.training = False
# Get Synthesizer and ContentVec
net_g, hps = Synthesizer.load_from_pretrained(vits_location[0], vits_location[2], vits_location[1], vits_location[3])
Encoder = get_encoder(hps.model.ssl_dim)
encoder = Encoder.load_from_pretrained(encoder_location[0], encoder_location[1])
# model config args
target_sample, spk2id, hop_length = hps.data.sampling_rate, hps.spk, hps.data.hop_length
vol_embedding = hps.model.vol_embedding if hasattr(hps.model, "vol_embedding") and hps.model.vol_embedding is not None else False
# args
slice_db, clip_seconds, lg_num, pad_seconds, tran, noise_scale, audio_path = args.slice_db, args.clip_seconds, args.lg_num, args.pad_seconds, args.tran, args.noise_scale, args.file
speaker = args.speaker if args.speaker is not None else list(hps.spk.__dict__.keys())[0]
### Loading audio and slicing ###
if audio_path == DEMO_PATH: fetch(DEMO_URL, DEMO_PATH)
assert Path(audio_path).is_file() and Path(audio_path).suffix == ".wav"
chunks = preprocess.cut(audio_path, db_thresh=slice_db)
audio_data, audio_sr = preprocess.chunks2audio(audio_path, chunks)
per_size = int(clip_seconds * audio_sr)
lg_size = int(lg_num * audio_sr)
### Infer per slice ###
global_frame = 0
audio = []
for (slice_tag, data) in audio_data:
print(f"\n====segment start, {round(len(data) / audio_sr, 3)}s====")
length = int(np.ceil(len(data) / audio_sr * target_sample))
if slice_tag:
print("empty segment")
_audio = np.zeros(length)
audio.extend(list(pad_array(_audio, length)))
global_frame += length // hop_length
continue
datas = [data] if per_size == 0 else split_list_by_n(data, per_size, lg_size)
for k, dat in enumerate(datas):
per_length = int(np.ceil(len(dat) / audio_sr * target_sample)) if clip_seconds!=0 else length
pad_len = int(audio_sr * pad_seconds)
dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
raw_path = io.BytesIO()
soundfile.write(raw_path, dat, audio_sr, format="wav")
raw_path.seek(0)
### Infer START ###
wav, sr = preprocess.load_audiofile(raw_path)
wav = preprocess.sinc_interp_resample(wav, sr, target_sample)[0]
wav16k, f0, uv = preprocess.get_unit_f0(wav, tran, hop_length, target_sample)
sid = get_sid(spk2id, speaker)
n_frames = f0.shape[1]
# ContentVec infer
start = time.time()
c = encoder.encode(wav16k)
c = repeat_expand_2d_left(c.squeeze(0).realize(), f0.shape[1]) # interpolate speech encoding to match f0
c = c.unsqueeze(0).realize()
enc_time = time.time() - start
# VITS infer
vits_start = time.time()
out_audio, f0 = net_g.infer(c, f0=f0, uv=uv, g=sid, noise_scale=noise_scale, vol=None)
out_audio = out_audio[0,0].float().realize()
vits_time = time.time() - vits_start
infer_time = time.time() - start
logging.info("total infer time:{:.2f}s, speech_enc time:{:.2f}s, vits time:{:.2f}s".format(infer_time, enc_time, vits_time))
### Infer END ###
out_sr, out_frame = out_audio.shape[-1], n_frames
global_frame += out_frame
_audio = out_audio.numpy()
pad_len = int(target_sample * pad_seconds)
_audio = _audio[pad_len:-pad_len]
_audio = pad_array(_audio, per_length)
audio.extend(list(_audio))
audio = np.array(audio)
out_path = Path(args.out_path or Path(args.out_dir)/f"{args.model}{f'_spk_{speaker}'}_{args.base_name}.wav")
out_path.parent.mkdir(parents=True, exist_ok=True)
soundfile.write(out_path, audio, target_sample, format="wav")
logging.info(f"Saved audio output to {out_path}")

View File

@@ -1,204 +0,0 @@
import math
from typing import Optional, Tuple
from tinygrad import Tensor, dtypes
import librosa
import soundfile
import numpy as np
import parselmouth
class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm"
def interpolate_f0(self,f0):
vuv_vector = np.zeros_like(f0, dtype=np.float32)
vuv_vector[f0 > 0.0] = 1.0
vuv_vector[f0 <= 0.0] = 0.0
nzindex = np.nonzero(f0)[0]
data = f0[nzindex]
nzindex = nzindex.astype(np.float32)
time_org = self.hop_length / self.sampling_rate * nzindex
time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
return f0,vuv_vector
def compute_f0(self,wav,p_len=None):
x = wav
if p_len is None: p_len = x.shape[0]//self.hop_length
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
time_step = self.hop_length / self.sampling_rate * 1000
f0 = parselmouth.Sound(x, self.sampling_rate) \
.to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6,pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \
.selected_array['frequency']
pad_size=(p_len - len(f0) + 1) // 2
if(pad_size>0 or p_len - len(f0) - pad_size>0):
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
f0,uv = self.interpolate_f0(f0)
return f0
def compute_f0_uv(self,wav,p_len=None):
x = wav
if p_len is None: p_len = x.shape[0]//self.hop_length
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
time_step = self.hop_length / self.sampling_rate * 1000
f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
time_step=time_step / 1000, voicing_threshold=0.6,
pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
pad_size=(p_len - len(f0) + 1) // 2
if(pad_size>0 or p_len - len(f0) - pad_size>0):
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
f0,uv = self.interpolate_f0(f0)
return f0,uv
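# usage sketch (hedged): for a mono float32 numpy waveform `wav` at 44.1 kHz,
#   f0, uv = PMF0Predictor(hop_length=512, sampling_rate=44100).compute_f0_uv(wav)
# yields one interpolated f0 value per hop and a voiced/unvoiced flag per frame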
class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000):
if not min_length >= min_interval >= hop_size:
raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
if not max_sil_kept >= hop_size:
raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
min_interval = sr * min_interval / 1000
self.threshold = 10 ** (threshold / 20.)
self.hop_size = round(sr * hop_size / 1000)
self.win_size = min(round(min_interval), 4 * self.hop_size)
self.min_length = round(sr * min_length / 1000 / self.hop_size)
self.min_interval = round(min_interval / self.hop_size)
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
def _apply_slice(self, waveform, begin, end):
if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
def slice(self, waveform):
samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
sil_tags, silence_start, clip_start = [], None, 0
for i, rms in enumerate(rms_list):
if rms < self.threshold: # Keep looping while frame is silent.
if silence_start is None: # Record start of silent frames.
silence_start = i
continue
if silence_start is None: continue # Keep looping while frame is not silent and silence start has not been recorded.
# Clear recorded silence start if interval is not enough or clip is too short
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
if not is_leading_silence and not need_slice_middle:
silence_start = None
continue
if i - silence_start <= self.max_sil_kept: # Need slicing. Record the range of silent frames to be removed.
pos = rms_list[silence_start: i + 1].argmin() + silence_start
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
clip_start = pos
elif i - silence_start <= self.max_sil_kept * 2:
pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
pos += i - self.max_sil_kept
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
if silence_start == 0:
sil_tags.append((0, pos_r))
clip_start = pos_r
else:
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
clip_start = max(pos_r, pos)
else:
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
clip_start = pos_r
silence_start = None
total_frames = rms_list.shape[0]
if silence_start is not None and total_frames - silence_start >= self.min_interval: # Deal with trailing silence.
silence_end = min(total_frames, silence_start + self.max_sil_kept)
pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
sil_tags.append((pos, total_frames + 1))
if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} # Apply and return slices.
chunks = []
if sil_tags[0][0]:
chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
for i in range(0, len(sil_tags)):
if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
if sil_tags[-1][1] * self.hop_size < len(waveform):
chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
chunk_dict = {}
for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i]
return chunk_dict
# sinc_interp_hann audio resampling
class Resample:
def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None):
self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None)
def __call__(self, waveform:Tensor) -> Tensor:
if self.orig_freq == self.new_freq: return waveform
return self._apply_sinc_resample_kernel(waveform)
def _apply_sinc_resample_kernel(self, waveform:Tensor):
if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.")
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
shape = waveform.shape
waveform = waveform.reshape(-1, shape[-1]) # pack batch
num_wavs, length = waveform.shape
target_length = int(math.ceil(new_freq * length / orig_freq))
waveform = waveform.pad((self.width, self.width + orig_freq))
resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
resampled = resampled[..., :target_length]
resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch
return resampled
def _get_sinc_resample_kernel(self, dtype=None):
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.")
base_freq = min(orig_freq, new_freq)
base_freq *= self.rolloff
width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq
t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
t *= base_freq
t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
t *= math.pi
scale = base_freq / orig_freq
kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t)
kernels *= window * scale
if dtype is None: kernels = kernels.cast(dtype=dtypes.float32)
return kernels, width
def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None):
resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
return resamp(x)
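# usage sketch (illustrative; a 1-second 440 Hz sine at 44.1 kHz):
#   wav = Tensor(np.sin(2 * np.pi * 440 * np.arange(44100) / 44100).astype(np.float32))[None, :]
#   wav16k = sinc_interp_resample(wav, 44100, 16000)  # -> shape (1, 16000)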
def cut(audio_path, db_thresh=-30, min_len=5000):
audio, sr = librosa.load(audio_path, sr=None)
slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
chunks = slicer.slice(audio)
return chunks
def chunks2audio(audio_path, chunks):
chunks = dict(chunks)
audio, sr = load_audiofile(audio_path)
if len(audio.shape) == 2 and audio.shape[1] >= 2:
audio = audio.mean(0).unsqueeze(0)
audio = audio.numpy()[0]
result = []
for k, v in chunks.items():
tag = v["split_time"].split(",")
if tag[0] != tag[1]:
result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
return result, sr
def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True):
with soundfile.SoundFile(filepath, "r") as file_:
frames = file_._prepare_read(frame_offset, None, num_frames)
waveform = file_.read(frames, "float32", always_2d=True)
sample_rate = file_.samplerate
waveform = Tensor(waveform)
if channels_first: waveform = waveform.transpose(0, 1)
return waveform, sample_rate
def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor,Tensor,Tensor]:
f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected")
f0 = Tensor(f0.astype(np.float32)).float()
f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)
uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)
wav16k = sinc_interp_resample(wav[None,:], target_sample, 16000)[0]
return wav16k.realize(), f0.realize(), uv.realize()

View File

@@ -1,104 +0,0 @@
import traceback
import time
from multiprocessing import Process, Queue
import numpy as np
from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import getenv, trange
from tinygrad.tensor import Tensor
from extra.datasets import fetch_cifar
from extra.models.efficientnet import EfficientNet
class TinyConvNet:
def __init__(self, classes=10):
conv = 3
inter_chan, out_chan = 8, 16 # for speed
self.c1 = Tensor.uniform(inter_chan,3,conv,conv)
self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv)
self.l1 = Tensor.uniform(out_chan*6*6, classes)
def forward(self, x):
x = x.conv2d(self.c1).relu().max_pool2d()
x = x.conv2d(self.c2).relu().max_pool2d()
x = x.reshape(shape=[x.shape[0], -1])
return x.dot(self.l1)
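# shape trace (illustrative, CIFAR input): (BS,3,32,32) -conv3x3-> (BS,8,30,30) -pool-> (BS,8,15,15)
# -conv3x3-> (BS,16,13,13) -pool-> (BS,16,6,6) -flatten-> (BS,576), matching l1's out_chan*6*6 rows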
if __name__ == "__main__":
IMAGENET = getenv("IMAGENET")
classes = 1000 if IMAGENET else 10
TINY = getenv("TINY")
TRANSFER = getenv("TRANSFER")
if TINY:
model = TinyConvNet(classes)
elif TRANSFER:
model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
model.load_from_pretrained()
else:
model = EfficientNet(getenv("NUM", 0), classes, has_se=False)
parameters = get_parameters(model)
print("parameter count", len(parameters))
optimizer = optim.Adam(parameters, lr=0.001)
BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
print(f"training with batch size {BS} for {steps} steps")
if IMAGENET:
from extra.datasets.imagenet import fetch_batch
def loader(q):
while 1:
try:
q.put(fetch_batch(BS))
except Exception:
traceback.print_exc()
q = Queue(16)
for i in range(2):
p = Process(target=loader, args=(q,))
p.daemon = True
p.start()
else:
X_train, Y_train, _, _ = fetch_cifar()
X_train = X_train.reshape((-1, 3, 32, 32))
Y_train = Y_train.reshape((-1,))
with Tensor.train():
for i in (t := trange(steps)):
if IMAGENET:
X, Y = q.get(True)
else:
samp = np.random.randint(0, X_train.shape[0], size=(BS))
X, Y = X_train.numpy()[samp], Y_train.numpy()[samp]
st = time.time()
out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
fp_time = (time.time()-st)*1000.0
y = np.zeros((BS,classes), np.float32)
y[range(y.shape[0]),Y] = -classes
y = Tensor(y, requires_grad=False)
loss = out.log_softmax().mul(y).mean()
optimizer.zero_grad()
st = time.time()
loss.backward()
bp_time = (time.time()-st)*1000.0
st = time.time()
optimizer.step()
opt_time = (time.time()-st)*1000.0
st = time.time()
loss = loss.numpy()
cat = out.argmax(axis=1).numpy()
accuracy = (cat == Y).mean()
finish_time = (time.time()-st)*1000.0
# printing
t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
(loss, accuracy,
fp_time, bp_time, opt_time, finish_time,
fp_time + bp_time + opt_time + finish_time))
del out, y, loss

View File

@@ -1,46 +0,0 @@
import ast
import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, fetch
from extra.models.vit import ViT
"""
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
import tensorflow as tf
with tf.io.gfile.GFile(fn, "rb") as f:
dat = f.read()
with open("cache/"+ fn.rsplit("/", 1)[1], "wb") as g:
g.write(dat)
"""
Tensor.training = False
if getenv("LARGE", 0) == 1:
m = ViT(embed_dim=768, num_heads=12)
else:
# tiny
m = ViT(embed_dim=192, num_heads=3)
m.load_from_pretrained()
# category labels
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
# junk
img = Image.open(fetch(url))
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
img = np.array(img)
y0,x0=(np.asarray(img.shape)[:2]-224)//2
img = img[y0:y0+224, x0:x0+224]
img = np.moveaxis(img, [2,0,1], [0,1,2])
img = img.astype(np.float32)[:3].reshape(1,3,224,224)
img /= 255.0
img -= 0.5
img /= 0.5
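# the three steps above map pixel values from [0, 255] to [-1, 1]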
out = m.forward(Tensor(img))
outnp = out.numpy().ravel()
choice = outnp.argmax()
print(out.shape, choice, outnp[choice], lbls[choice])

View File

@@ -1,189 +0,0 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.opt.kernel import Ops, MemOp, UOp
from tinygrad.uop.ops import BinaryOps, UnaryOps
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import DEBUG
from tinygrad.uop.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
class Register(NamedTuple):
nm:str
dtype:DType
scalar:bool
off:Optional[int] = None
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
def subregs(self):
if self.dtype == dtypes.float.vec(4):
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
class AssemblyInstruction(NamedTuple):
op: Ops
out: Optional[Register]
vin: List[Union[Register, int, float]]
arg: Any = None
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
supports_load3: bool = False
sin_is_sin2pi: bool = False
no_div: bool = False
#TODO: these should be global vars
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
tor: Dict[Any, Register] = {}
ins: List[AssemblyInstruction] = []
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes.float.vec(4):
for off in range(4):
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
self.cnts[(dtype, scalar)] += 1
return ret
def render_numnode(self, b) -> Register:
key = ("num", b)
if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return self.tor[key]
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in self.tor:
#if not isinstance(b, Register): b = render_numnode(b)
self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return self.tor[key]
def render_cast(self, a:Register, new_dtype:DType) -> Register:
if a.dtype == new_dtype: return a
key = (a, new_dtype)
if key not in self.tor:
self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
return self.tor[key]
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def addr_w_offset(self, args):
assert isinstance(args, MemOp)
idx = args.idx*args.memory_dtype.itemsize
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
idx -= nums[0]
off = cast(int, nums[0])
reg = idx.render(self.render_ops, self)
if self.supports_load3:
if reg.scalar:
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return self.tor[args.name], reg, off
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
return reg, None, off
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
#TODO: Do not use clear()
lang.ins.clear()
lang.tor.clear()
lang.cnts.clear()
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
for u in uops:
uop,dtype,vin,args,_ = u
if uop == Ops.DEFINE_LOCAL:
lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == Ops.LOOP:
if args[1] == "global":
for i,var in enumerate(args[0]):
global_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
elif args[1] == "local":
for i,var in enumerate(args[0]):
local_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
elif uop == Ops.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif args[1] == "global+local":
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
elif args[1] == 'local':
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
elif uop == Ops.CAST:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = lang.newreg(u, dtype)
for i,sr in enumerate(out.subregs()):
lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
elif uop == Ops.ALU:
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and lang.no_div:
tmp = lang.newreg((u, "rcp"))
lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
tmp = lang.newreg((u, "2pi"))
lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
else:
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
elif uop == Ops.DEFINE_REG:
reg = lang.newreg(u, dtype=dtype)
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
elif uop == Ops.SPECIAL:
lang.tor[u] = lang.tor[args]
elif uop == Ops.CONST:
lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
elif uop == Ops.LOAD:
idx, treg, off = lang.addr_w_offset(args)
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
if args.valid.min == 0:
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == Ops.STORE:
if args is None:
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
else:
idx, treg, off = lang.addr_w_offset(args)
lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
if DEBUG >= 4:
for tins in lang.ins: print(tins)
return global_size, local_size

View File

@@ -1,177 +0,0 @@
import struct
from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad import dtypes
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.opt.kernel import Ops, UOp
from tinygrad.helpers import CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
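# e.g. float_to_hex(1.0) -> "3F800000" (big-endian IEEE-754 single precision)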
def compute_offsets(total):
quotient, remainder = divmod(total, 4096)
return [4096]*quotient + [remainder] if remainder else [4096]*quotient
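# e.g. compute_offsets(10000) -> [4096, 4096, 1808], splitting a stack adjustment into
# immediate-sized chunks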
#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name
class ARM64Language(AssemblyLanguage): pass
def specialize_to_arm64(fn_nm, asm):
var_size = 16
prev_uop:Optional[Ops] = None
ins = []
x_regs = ['x' + str(i) for i in reversed(range(12))]
s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
def mov_imm(value, reg):
# Manually move value into reg if value can't fit
if value.__class__ is not float and abs(value) > 65535:
ins.append(f"movz w15, #{value & 0xffff}")
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
ins.append(f"sxtw {reg}, w15")
elif reg[0] == 's':
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
ins.append("str x15, [sp, 16]")
ins.append(f"ldr {reg}, [sp, 16]")
else:
ins.append(f"mov {reg}, #{value}")
# Get variables intervals
live_range:Dict[str, List[int]] = {}
for i, (uop, out, vin, arg) in enumerate(asm):
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
mem_vars:Dict[str, int] = {}
rtor:Dict[str, str] = {}
def allocate_regs(mvars):
nonlocal var_size
for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
#NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
if not available_regs:
# ARM needs the stack 16-byte aligned
var_size += 16
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
mem_vars[v.nm] = var_size
rtor[v.nm] = available_regs.pop()
temp_floats = ['s0', 's1', 's2']
temp_ints = ['x12', 'x13', 'x16']
for i, (uop, out, vin, arg) in enumerate(asm):
# Free regs whose live interval has ended
for var, reg in list(rtor.items()):
available_regs = s_regs if reg[0] == 's' else x_regs
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
available_regs.append(rtor.pop(var))
# Assign registers to the variables using live ranges.
allocate_regs([out] + vin)
# Assign temp regs to vin and load them before direct use
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
ins.append(f"mov x15, {mem_vars[v.nm]}")
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
if uop == Ops.SPECIAL:
if arg.startswith('data'):
# data args 8 and beyond are passed on the stack
if int(arg[4:]) >= 8:
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
ins.append(f"mov {rtor[out.nm]}, x15")
else:
ins.append(f"mov {rtor[out.nm]}, #0")
ins.append(f"loop_{arg}:")
elif uop == Ops.CAST:
if arg == BinaryOps.CMPLT:
if rtor[out.nm][0] == 's':
mov_imm(0.0, 's0')
mov_imm(1.0, 's1')
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
if rtor[out.nm][0] == 'x':
mov_imm(0, 'x14')
mov_imm(1, 'x15')
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
else:
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
elif uop == Ops.ALU:
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif arg == TernaryOps.WHERE:
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
#NOTE: not a real instruction, used to emulate an external call in unicorn
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
else:
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
# Save the registers before they are clobbered by the function call
for i,k in enumerate(save_regs,1):
ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
ins.append("stp x29, x30, [sp, #0]!")
ins.append("mov x29, sp")
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
ins.append(alu[arg])
ins.append(f"fmov {rtor[out.nm]}, s0")
ins.append("mov sp, x29")
ins.append("ldp x29, x30, [sp], #0")
for i,k in enumerate(save_regs,1):
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
ins.append(f"add sp, sp, #{len(save_regs)*16}")
elif arg == BinaryOps.CMPLT:
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
elif arg == BinaryOps.MOD:
rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
else:
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif uop == Ops.LOAD:
if arg.__class__ in (int, float):
mov_imm(arg, rtor[out.nm])
else:
#NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
mov_imm(arg[0], "x15")
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
elif uop == Ops.STORE:
#NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
elif uop == Ops.COND_BRANCH:
#TODO: this is a hack; there isn't always a cmp right before a cond branch
if prev_uop == Ops.LOAD:
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
elif uop == Ops.LABEL:
ins.append(f"{arg[1:]}:")
elif uop == Ops.ENDLOOP:
mov_imm(arg[0], "x15")
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
ins.append(f"cmp {rtor[vin[0].nm]}, x15")
ins.append(f"b.lt loop_{arg[1]}")
prev_uop = uop
# store regs into memory if needed
if out is not None and out.nm in mem_vars:
ins.append(f"mov x15, {mem_vars[out.nm]}")
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
lang = ARM64Language()
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
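
The allocator above is a one-pass linear scan: live ranges are computed first, registers are handed out as values become live, freed when their interval ends, and anything that doesn't fit is spilled to the stack. A standalone sketch of the same scheme on a toy IR (hypothetical names, not the real Register/UOp types):

# Toy linear-scan allocation in the style of specialize_to_arm64.
def linear_scan(code, regs):
  live = {}
  for i, (out, vin) in enumerate(code):
    for v in [out] + vin: live[v] = (live[v][0], i) if v in live else (i, i)
  free, rtor, spilled = list(reversed(regs)), {}, set()
  for i, (out, vin) in enumerate(code):
    for v in [x for x in list(rtor) if i > live[x][1]]: free.append(rtor.pop(v))  # expire dead vars
    for v in [out] + vin:
      if v not in rtor and v not in spilled:
        if free: rtor[v] = free.pop()
        else: spilled.add(v)  # very simple spill: whatever doesn't fit goes to memory
    print(i, out, vin, dict(rtor), spilled)

linear_scan([("a", []), ("b", []), ("c", ["a", "b"]), ("d", ["c"])], ["x0", "x1"])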

View File

@@ -1,105 +0,0 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.opt.kernel import Ops, UOp
from tinygrad import dtypes
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch
dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
def render_cast(ins, inp, out):
if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
elif out.dtype == dtypes.bool:
if inp.dtype == dtypes.bool:
ins.append(f"mov.pred {out}, {inp};")
else:
ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
else:
round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXLanguage(AssemblyLanguage):
supports_constant_folding: bool = True
def specialize_to_ptx(lang, function_name):
param_cnt = 0
ins = []
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
for uop, out, vin, arg in lang.ins:
if uop == Ops.ENDLOOP:
ins.append("bar.sync 0;")
elif uop == Ops.DEFINE_LOCAL:
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
elif uop == Ops.SPECIAL:
if arg.startswith('data'):
param_cnt += 1
ins.append(f"ld.param.u64 {out}, [{arg}];")
# TODO: we sometimes want this to be local; nvcc converts to global most of the time, not sure when we would need it
# ins.append(f"cvta.to.global.u64 {out}, {out};")
elif arg.startswith('gid'):
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
elif arg.startswith('lid'):
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
elif uop == Ops.ALU:
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
else:
otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
if arg == TernaryOps.WHERE:
if vin[0].dtype == dtypes.bool:
reg = vin[0]
else:
reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
vin = vin[1:] + [reg]
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
elif uop == Ops.LOAD:
if arg.__class__ in (int, float):
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
reg = lang.newreg((out, dt[0]), dtype=dt[1])
ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
render_cast(ins, reg, out)
else:
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
elif uop == Ops.STORE:
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
if arg[2] == dtypes.bool != vin[1].dtype:
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
render_cast(ins, vin[1], prereg)
else: prereg = vin[1]
reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
render_cast(ins, prereg, reg)
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
else:
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
elif uop == Ops.CAST:
render_cast(ins, vin[0], out)
elif uop == Ops.LABEL:
ins.append(f"{arg}:")
elif uop == Ops.COND_BRANCH:
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
ins = ins_prefix + ins
ins += ["ret;", "}"]
return '\n'.join(ins)
def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
lang = PTXLanguage()
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
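
render_cast above picks the cvt rounding modifier from the source and destination types: float-to-int takes .rzi, int-to-float and narrowing float-to-float take .rz, and everything else takes none. The same decision isolated as a sketch (simplified flags stand in for the real dtypes):

# Simplified version of the round-mode choice in render_cast.
def cvt_round_mod(dst_is_float, dst_is_int, src_is_float, src_is_int, dst_size, src_size):
  if dst_is_int and src_is_float: return ".rzi"
  if dst_is_float and (src_is_int or (src_is_float and src_size > dst_size)): return ".rz"
  return ""

print("f32 <- s32:", "cvt" + cvt_round_mod(True, False, False, True, 4, 4) + ".f32.s32")  # cvt.rz.f32.s32
print("s32 <- f32:", "cvt" + cvt_round_mod(False, True, True, False, 4, 4) + ".s32.f32")  # cvt.rzi.s32.f32
print("f32 <- f16:", "cvt" + cvt_round_mod(True, False, True, False, 4, 2) + ".f32.f16")  # cvt.f32.f16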

View File

@@ -1,203 +0,0 @@
import yaml
from typing import Tuple, Set, Dict
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.opt.kernel import Ops
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cl import ROCM_LLVM_PATH
# ugh, is this really needed?
from extra.helpers import enable_early_exec
early_exec = enable_early_exec()
boilerplate_start = """
.global _start
_start:
.rodata
.align 0x10
.global code.kd
.type code.kd,STT_OBJECT
.amdhsa_kernel code"""
code_start = """.end_amdhsa_kernel
.text
code:
"""
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
# RDNA3 is actually a SIMD machine!
class RDNACodegen(AssemblyCodegen):
supports_float4: bool = True
supports_float4_alu: bool = True
supports_load3: bool = True
sin_is_sin2pi: bool = True
no_div: bool = True
def specialize(self, asm) -> Tuple[str, str]:
args = []
for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
ins = []
v_cnt = 3 # v[0:2] is local_xyz
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
BinaryOps.CMPLT: "cmp_lt"}
pend_regs:Set[Register] = set()
rtor:Dict[Register, str] = {}
def reg_in(x):
nonlocal pend_regs
#print("reg_in", x, rtor[x], pend_regs)
if x in pend_regs:
#print("clear")
ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
pend_regs.clear()
return rtor[x]
def reg_out(x):
return rtor[x]
for uop, out, vin, arg in asm:
if uop == Ops.DEFINE_REGISTER:
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
for i in range(arg[2]):
# TODO: Re-use gaps created by this to avoid wasting registers
align = arg[0][0].itemsize // 4
if arg[0][1]:
s_cnt += (align - s_cnt % align) % align  # round up so wide regs start aligned
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
s_cnt += align
else:
v_cnt += (align - v_cnt % align) % align  # round up so wide regs start aligned
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
v_cnt += align
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
if arg[0][0] == dtypes.float.vec(4):
for off in range(4):
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
elif arg[0][0] == dtypes.bool:
for i in range(arg[2]):
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
else:
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
elif uop == Ops.SPECIAL:
if arg.startswith('buf'):
i = int(arg[3:])
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif arg.startswith('gid'):
ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
# the docs lied, this is actually y
if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
# get local size
offset = len(args)*8
args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
pend_regs.clear()
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
elif uop == Ops.CONST:
if arg == float('inf'): arg = "0x7f800000"
elif arg == float('-inf'): arg = "0xff800000"
if out.dtype == dtypes.float.vec(4):
for off in range(4):
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
elif uop == Ops.ALU:
if arg in [BinaryOps.CMPLT]:
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
else:
alu_arg = alu[arg]
if arg == TernaryOps.MULACC and out == vin[2]:
alu_arg = "fmac"
vin = vin[0:2]
if out.dtype == dtypes.float.vec(4):
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
elif uop == Ops.LOAD:
if out.scalar:
# swap arg order
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
else:
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif uop == Ops.STORE:
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
elif uop == Ops.LABEL:
ins.append(f"{arg}:")
elif uop == Ops.COND_BRANCH:
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
elif uop == Ops.CAST:
if vin[0].dtype == dtypes.bool:
if out.dtype == dtypes.float32:
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
else:
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
else:
raise NotImplementedError(uop)
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
# dual ALU group: pair independent v_fmac_f32 ops so RDNA3 can dual-issue them
seen = set()
new_ins = []
for i,tins in enumerate(ins):
if tins in seen: continue
if tins.startswith("v_fmac_f32"):
for gins in reversed(ins[i+1:]):
if gins in seen: continue
if gins.startswith("v_fmac_f32"):
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
if r0[0]%2 == r1[0]%2: continue
if r0[1]%2 == r1[1]%2: continue
if r0[2]%2 == r1[2]%2: continue
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
seen.add(tins)
seen.add(gins)
break
if tins not in seen:
new_ins.append(tins)
ins = new_ins
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
def assemble(self, args, ins, v_cnt, s_cnt):
kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
'.amdhsa_next_free_sgpr': s_cnt,
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
'.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
metadata = {'amdhsa.kernels': [{'.args': args,
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
'.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
'.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
'.wavefront_size': 32}],
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
return asm
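
The dual-ALU pass above pairs v_fmac_f32 instructions whose destination and source registers fall in opposite even/odd banks, so RDNA3 can issue them together. A minimal standalone sketch of the pairing rule (the parity check is a rough stand-in for the real bank constraints):

# Pair v_fmac_f32 ops into v_dual_fmac_f32 when every corresponding register
# index has opposite parity.
def pair_fmacs(ins):
  def regs(t): return [int(x[1:].strip(',')) for x in t.split(" ")[1:]]
  seen, out = set(), []
  for i, tins in enumerate(ins):
    if tins in seen: continue
    if tins.startswith("v_fmac_f32"):
      for gins in ins[i+1:]:
        if gins in seen or not gins.startswith("v_fmac_f32"): continue
        if all(a % 2 != b % 2 for a, b in zip(regs(tins), regs(gins))):
          out.append(tins.replace("v_", "v_dual_") + " :: " + gins.replace("v_", "v_dual_"))
          seen.update((tins, gins))
          break
    if tins not in seen: out.append(tins)
  return out

for line in pair_fmacs(["v_fmac_f32 v0, v2, v4", "v_fmac_f32 v1, v3, v5", "v_fmac_f32 v6, v8, v10"]):
  print(line)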

View File

@@ -1,23 +0,0 @@
#!/usr/bin/env python3
import numpy as np
from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer
if __name__ == "__main__":
test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
prg = CUDAProgram("test", """
.version 7.8
.target sm_86
.address_size 64
.visible .entry test(.param .u64 x) {
.reg .b32 %r<2>;
.reg .b64 %rd<3>;
ld.param.u64 %rd1, [x];
cvta.to.global.u64 %rd2, %rd1;
mov.u32 %r1, 0x40000000; // 2.0 in float
st.global.u32 [%rd2], %r1;
ret;
}""", binary=True)
prg([1], [1], test)
print(test.toCPU())
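
The kernel stores the raw bit pattern 0x40000000, which is 2.0 as an IEEE-754 float. A quick way to check such constants with the standard struct module:

import struct
# 2.0f -> 0x40000000 and back; handy when writing PTX mov/st of raw float bits.
print(hex(struct.unpack("<I", struct.pack("<f", 2.0))[0]))   # 0x40000000
print(struct.unpack("<f", struct.pack("<I", 0x40000000))[0])  # 2.0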

View File

@@ -1,42 +0,0 @@
import numpy as np
from PIL import Image
from pathlib import Path
import sys
cwd = Path.cwd()
sys.path.append(cwd.as_posix())
sys.path.append((cwd / 'test').as_posix())
from extra.datasets import fetch_mnist
from tqdm import trange
def augment_img(X, rotate=10, px=3):
Xaug = np.zeros_like(X)
for i in trange(len(X)):
im = Image.fromarray(X[i])
im = im.rotate(np.random.randint(-rotate,rotate), resample=Image.BICUBIC)
w, h = X.shape[1:]
#upper left, lower left, lower right, upper right
quad = np.random.randint(-px,px,size=(8)) + np.array([0,0,0,h,w,h,w,0])
im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC)
Xaug[i] = im
return Xaug
if __name__ == "__main__":
import matplotlib.pyplot as plt
X_train, Y_train, X_test, Y_test = fetch_mnist()
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
X = np.vstack([X_train[:1]]*10+[X_train[1:2]]*10)
fig, a = plt.subplots(2,len(X))
Xaug = augment_img(X)
for i in range(len(X)):
a[0][i].imshow(X[i], cmap='gray')
a[1][i].imshow(Xaug[i],cmap='gray')
a[0][i].axis('off')
a[1][i].axis('off')
plt.show()
#create some nice gifs for doc?!
for i in range(10):
im = Image.fromarray(X_train[7353+i])
im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0)
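
Image.transform with Image.QUAD maps a source quadrilateral onto the output rectangle; the 8-tuple lists the source corners as upper-left, lower-left, lower-right, upper-right (x, y) pairs, which is what the np.array([0,0,0,h,w,h,w,0]) base encodes. The corner layout in isolation, assuming a 28x28 image:

import numpy as np
w = h = 28
base = np.array([0,0, 0,h, w,h, w,0])  # UL, LL, LR, UR corners of the source quad
quad = np.random.randint(-3, 3, size=8) + base  # jitter each corner by up to ~3 px
print(quad.reshape(4, 2))  # the 8-tuple handed to Image.transform(..., Image.QUAD, quad)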

View File

@@ -1,39 +0,0 @@
from typing import List, Dict, cast
import ctypes
from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.uop.ops import Variable
from tinygrad.runtime.ops_cpu import ClangProgram
from tinygrad.renderer.cstyle import ClangRenderer
render_dtype = ClangRenderer().render_dtype
class ClangGraph(GraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
args += sorted([f"int {v}" for v in var_vals])
code = ["void batched("+','.join(args)+") {"]
for ji in jit_cache:
args = []
for buf in ji.bufs:
assert buf is not None
if buf in input_rawbuffers:
args.append(f"arg{input_rawbuffers.index(buf)}")
else:
args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
code.append(f" {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
code.append("}")
if DEBUG >= 4: print("\n".join(code))
compiler = Device["CPU"].compiler
assert compiler is not None
self._prg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers
def __call__(self, rawbufs: List[Buffer], var_vals: Dict[str, int], wait=False):
return cpu_time_execution(
lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0])]), enable=wait)
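
The batching trick above compiles the whole JIT trace into one C function: input buffers stay as arguments while intermediate buffers are baked in as pointer literals, so a single ClangProgram call replays every kernel. A standalone sketch of building that wrapper source (fake kernel names and addresses, for illustration):

# Build a batched() wrapper like ClangGraph does: inputs become arguments,
# intermediate buffers become hard-coded pointer literals.
def build_batched(calls, n_inputs):
  args = ", ".join(f"float* arg{i}" for i in range(n_inputs))
  body = "\n".join(f"  {name}({', '.join(bufs)});" for name, bufs in calls)
  return f"void batched({args}) {{\n{body}\n}}"

print(build_batched([("k0", ["arg0", "(float*)0x7F00DEAD0000"]),
                     ("k1", ["(float*)0x7F00DEAD0000", "arg1"])], n_inputs=2))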

View File

@@ -1,27 +0,0 @@
import ctypes
from typing import Tuple
import tinygrad.runtime.autogen.hip as hip
from tinygrad.helpers import init_c_var, time_execution_cuda_style
from tinygrad.runtime.ops_hip import check, hip_set_device
from tinygrad.runtime.graph.cuda import CUDAGraph
# TODO: this is only used in graph
def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) # noqa: E501
class HIPGraph(CUDAGraph):
def __del__(self):
if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph))
if hasattr(self, 'instance'): check(hip.hipGraphExecDestroy(self.instance))
def set_device(self): hip_set_device(self.dev)
def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
def graph_instantiate(self, graph):
return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0)))
def graph_add_kernel_node(self, graph, c_deps, c_params):
return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params)))) # noqa: E501
def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait)
def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args))
def build_kernel_node_params(self, prg, global_size, local_size, c_config):
return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0)
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size

View File

@@ -1,143 +0,0 @@
import ctypes, collections
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import init_c_var
def check(status):
if status != 0:
hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}")
# Precalculated AQL info
AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
EMPTY_SIGNAL = hsa.hsa_signal_t()
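# dispatch setup: 3 grid dimensions, shifted into the packet's SETUP_DIMENSIONS bits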
DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE
BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
class AQLQueue:
def __init__(self, device, sz=-1):
self.device = device
check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value
null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))
self.next_doorbell_index = 0
self.queue_base = self.hw_queue.contents.base_address
self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
self.write_addr = self.queue_base
self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
self.available_packet_slots = self.hw_queue.contents.size
check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))
def __del__(self):
if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))
def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
if self.available_packet_slots == 0: self._wait_queue()
packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
packet.workgroup_size_x = local_size[0]
packet.workgroup_size_y = local_size[1]
packet.workgroup_size_z = local_size[2]
packet.reserved0 = 0
packet.grid_size_x = global_size[0] * local_size[0]
packet.grid_size_y = global_size[1] * local_size[1]
packet.grid_size_z = global_size[2] * local_size[2]
packet.private_segment_size = prg.private_segment_size
packet.group_segment_size = prg.group_segment_size
packet.kernel_object = prg.handle
packet.kernarg_address = kernargs
packet.reserved2 = 0
packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
packet.setup = DISPATCH_KERNEL_SETUP
packet.header = DISPATCH_KERNEL_HEADER
self._submit_packet()
def submit_barrier(self, wait_signals=None, completion_signal=None):
assert wait_signals is None or len(wait_signals) <= 5
if self.available_packet_slots == 0: self._wait_queue()
packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
packet.reserved0 = 0
packet.reserved1 = 0
for i in range(5):
packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
packet.reserved2 = 0
packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
packet.header = BARRIER_HEADER
self._submit_packet()
def blit_packets(self, packet_addr, packet_cnt):
if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
rem_packet_cnt = packet_cnt - tail_blit_packets
ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
self._submit_packet(packet_cnt)
def wait(self):
self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
def _wait_queue(self, need_packets=1):
while self.available_packet_slots < need_packets:
rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)
def _submit_packet(self, cnt=1):
self.available_packet_slots -= cnt
self.next_doorbell_index += cnt
hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)
self.write_addr += AQL_PACKET_SIZE * cnt
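# the queue is a ring buffer: wrap the write pointer back past the base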
if self.write_addr > self.write_addr_end:
self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
def scan_agents():
agents = collections.defaultdict(list)
@ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
def __scan_agents(agent, data):
status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
if status == 0: agents[device_type.value].append(agent)
return hsa.HSA_STATUS_SUCCESS
hsa.hsa_iterate_agents(__scan_agents, None)
return agents
def find_memory_pool(agent, segtyp=-1, location=-1):
@ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
def __filter_amd_memory_pools(mem_pool, data):
check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS
check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS
check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
if sz.value == 0: return hsa.HSA_STATUS_SUCCESS
ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
ret[0] = mem_pool
return hsa.HSA_STATUS_INFO_BREAK
hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
return region
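
AQLQueue advances its write address one 64-byte packet at a time and wraps it modulo the queue size, while the doorbell index only ever grows. Just that pointer arithmetic, as a standalone sketch (hypothetical Ring class):

PACKET = 64  # sizeof(hsa_kernel_dispatch_packet_t)
class Ring:
  def __init__(self, base, slots):
    self.base, self.size = base, slots * PACKET
    self.write_addr, self.doorbell = base, 0
  def submit(self, cnt=1):
    self.doorbell += cnt  # monotonically increasing, rung on the doorbell signal
    self.write_addr += PACKET * cnt
    if self.write_addr > self.base + self.size - 1:  # past the end: wrap
      self.write_addr = self.base + (self.write_addr - self.base) % self.size
    return self.write_addr

q = Ring(base=0x1000, slots=4)
for _ in range(6): print(hex(q.submit()), q.doorbell)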

View File

@@ -1,171 +0,0 @@
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Tuple
from tinygrad.helpers import init_c_var, round_up
from tinygrad.device import Buffer, BufferSpec
from tinygrad.device import Compiled, Device
from tinygrad.uop.ops import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.runtime.support.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
class VirtAQLQueue(AQLQueue):
def __init__(self, device, sz):
self.device = device
self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
self.packets_count = 0
self.available_packet_slots = sz
def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
def _submit_packet(self):
self.write_addr += AQL_PACKET_SIZE
self.packets_count += 1
self.available_packet_slots -= 1
class HSAGraph(MultiGraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
# Check all jit items are compatible.
compiled_devices = set()
for ji in self.jit_cache:
if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.dev)
elif isinstance(ji.prg, BufferXfer):
for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
else: raise GraphException
if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
# Allocate kernel args.
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in self.jit_cache:
if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16)
kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferSpec()) for dev,sz in kernargs_size.items()}
# Fill initial arguments.
self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
for j,ji in enumerate(self.jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
self.ji_kargs_structs[j] = ji.prg._prg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev])
kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16)
for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i].expr])
# Build queues.
self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
self.packets = {}
self.transfers = []
self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
self.signals_to_reset: List[hsa.hsa_signal_t] = []
self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
# Special packet to wait for the world.
self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledRunner):
wait_signals = self.access_resources(ji.bufs, ji.prg.p.outs, new_dependency=j, sync_with_aql_packets=False)
for i in range(0, len(wait_signals), 5):
self.virt_aql_queues[ji.prg.dev].submit_barrier(wait_signals[i:i+5])
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.dev].write_addr)
sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
self.virt_aql_queues[ji.prg.dev].submit_kernel(ji.prg._prg, *ji.prg.p.launch_dims(var_vals), #type:ignore
ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
if PROFILE: self.profile_info[ji.prg.dev].append((sync_signal, ji.prg._prg.name, False))
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
wait_signals = self.access_resources([dest, src], write=[0], new_dependency=sync_signal, sync_with_aql_packets=True)
self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
(hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
self.ji_to_transfer[j] = len(self.transfers) - 1
if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
# Wait for all active signals to finish the graph
wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
for dev in self.signals_to_devices[v.handle]:
wait_signals_to_finish[dev].append(v)
self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
for dev in self.devices:
wait_signals = wait_signals_to_finish[dev]
for i in range(0, max(1, len(wait_signals)), 5):
self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
# Zero signals to allow graph to start and execute.
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[str, int], wait=False) -> Optional[float]:
# Wait and restore signals
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
# Update rawbuffers
for (j,i),input_idx in self.input_replace.items():
if j in self.ji_kargs_structs:
self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
else:
if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
# Update var_vals
for j in self.jc_idx_with_updatable_var_vals:
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v.expr])
# Update launch dims
for j in self.jc_idx_with_updatable_launch_dims:
gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
self.packets[j].workgroup_size_x = lc[0]
self.packets[j].workgroup_size_y = lc[1]
self.packets[j].workgroup_size_z = lc[2]
self.packets[j].grid_size_x = gl[0] * lc[0]
self.packets[j].grid_size_y = gl[1] * lc[1]
self.packets[j].grid_size_z = gl[2] * lc[2]
for dev in self.devices:
dev.flush_hdp()
dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
for transfer_data in self.transfers:
check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
et = None
if wait:
st = time.perf_counter()
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
et = time.perf_counter() - st
for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
return et
def alloc_signal(self, reset_on_start=False, wait_on=None):
sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
if reset_on_start: self.signals_to_reset.append(sync_signal)
if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
return sync_signal
def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
if isinstance(dep, hsa.hsa_signal_t): return dep
elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
return packet.completion_signal
return None
def access_resources(self, rawbufs, write, new_dependency, sync_with_aql_packets=False):
rdeps = self._access_resources(rawbufs, write, new_dependency)
wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in rawbufs]
return dedup_signals(wait_signals)
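
Barrier-AND packets carry at most five dependency signals, which is why the graph code above submits wait signals in groups of five. The chunking pattern, standalone:

# Split an arbitrary list of wait signals into barrier packets of <= 5 deps,
# as HSAGraph does; max(1, ...) ensures at least one (empty) barrier is issued.
def barrier_chunks(wait_signals):
  return [wait_signals[i:i+5] for i in range(0, max(1, len(wait_signals)), 5)]

print(barrier_chunks([]))              # [[]] -> one empty barrier
print(barrier_chunks(list(range(7))))  # [[0, 1, 2, 3, 4], [5, 6]]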

View File

@@ -1,275 +0,0 @@
from __future__ import annotations
import ctypes, functools, subprocess, io, atexit, collections, json
from typing import Tuple, TypeVar, List, Dict, Any
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv, PROFILE
from tinygrad.device import Compiled, Compiler, CompileError, BufferSpec, LRUAllocator
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.support.hsa import check, scan_agents, find_memory_pool, AQLQueue
from tinygrad.runtime.support.hip_comgr import compile_hip
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401
class HSAProfiler:
def __init__(self):
self.tracked_signals = collections.defaultdict(list)
self.collected_events: List[Tuple[Any, ...]] = []
self.copy_timings = hsa.hsa_amd_profiling_async_copy_time_t()
self.disp_timings = hsa.hsa_amd_profiling_dispatch_time_t()
def track(self, signal, device, name, is_copy=False): self.tracked_signals[device].append((signal, name, is_copy))
def process(self, device):
# Process all tracked signals; should be called before any of the tracked signals are reused.
for sig,name,is_copy in self.tracked_signals[device]:
if is_copy: check(hsa.hsa_amd_profiling_get_async_copy_time(sig, ctypes.byref(timings := self.copy_timings)))
else: check(hsa.hsa_amd_profiling_get_dispatch_time(device.agent, sig, ctypes.byref(timings := self.disp_timings))) #type:ignore
self.collected_events.append((device.device_id, 1 if is_copy else 0, name, timings.start, timings.end))
self.tracked_signals.pop(device)
def save(self, path):
mjson = []
for i in range(len(HSADevice.devices)):
mjson.append({"name": "process_name", "ph": "M", "pid": i, "args": {"name": "HSA"}})
mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 0, "args": {"name": "AQL"}})
mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 1, "args": {"name": "SDMA"}})
for dev_id,queue_id,name,st,et in self.collected_events:
mjson.append({"name": name, "ph": "B", "pid": dev_id, "tid": queue_id, "ts": st*1e-3})
mjson.append({"name": name, "ph": "E", "pid": dev_id, "tid": queue_id, "ts": et*1e-3})
with open(path, "w") as f: f.write(json.dumps({"traceEvents": mjson}))
print(f"Saved HSA profile to {path}")
Profiler = HSAProfiler()
class HSACompiler(Compiler):
def __init__(self, arch:str):
self.arch = arch
super().__init__(f"compile_hip_{self.arch}")
def compile(self, src:str) -> bytes:
try: return compile_hip(src, self.arch)
except RuntimeError as e: raise CompileError(e)
class HSAProgram:
def __init__(self, device:HSADevice, name:str, lib:bytes):
self.device, self.name, self.lib = device, name, lib
if DEBUG >= 6:
asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
self.exec = init_c_var(hsa.hsa_executable_t(), lambda x: check(hsa.hsa_executable_create_alt(hsa.HSA_PROFILE_FULL, hsa.HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, None, ctypes.byref(x)))) # noqa: E501
self.code_reader = init_c_var(hsa.hsa_code_object_reader_t(),
lambda x: check(hsa.hsa_code_object_reader_create_from_memory(lib, len(lib), ctypes.byref(x))))
check(hsa.hsa_executable_load_agent_code_object(self.exec, self.device.agent, self.code_reader, None, None))
check(hsa.hsa_executable_freeze(self.exec, None))
self.kernel = init_c_var(hsa.hsa_executable_symbol_t(), lambda x: check(hsa.hsa_executable_get_symbol_by_name(self.exec, (name+".kd").encode("utf-8"), ctypes.byref(self.device.agent), ctypes.byref(x)))) # noqa: E501
self.handle = init_c_var(ctypes.c_uint64(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, ctypes.byref(x)))) # noqa: E501
self.kernargs_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
self.group_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
self.private_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
def __del__(self):
self.device.synchronize()
if hasattr(self, 'code_reader'): check(hsa.hsa_code_object_reader_destroy(self.code_reader))
if hasattr(self, 'exec'): check(hsa.hsa_executable_destroy(self.exec))
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
if not hasattr(self, "args_struct_t"):
self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
[(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size:
raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")
kernargs = None
if self.kernargs_segment_size > 0:
kernargs = self.device.alloc_kernargs(self.kernargs_segment_size)
args_st = self.args_struct_t.from_address(kernargs)
for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i])
for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])
self.device.flush_hdp()
signal = self.device.alloc_signal(reusable=True) if wait or PROFILE else None
self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, completion_signal=signal)
if PROFILE: Profiler.track(signal, self.device, self.name)
if wait:
hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t())))
return (timings.end - timings.start) * self.device.clocks_to_time
T = TypeVar("T")
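# CHUNK_SIZE: size of each host staging buffer for copy_from_fd; PAGE_SIZE: file offsets are aligned down to 4 KiB pages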
CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000
class HSAAllocator(LRUAllocator):
def __init__(self, device:HSADevice):
self.device = device
super().__init__()
def _alloc(self, size:int, options:BufferSpec):
if options.host:
check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p())))
check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem))
return mem.value
c_agents = (hsa.hsa_agent_t * len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]))(*HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU])
check(hsa.hsa_amd_memory_pool_allocate(self.device.gpu_mempool, size, 0, ctypes.byref(buf := ctypes.c_void_p())))
check(hsa.hsa_amd_agents_allow_access(len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]), c_agents, None, buf))
return buf.value
def _free(self, opaque:T, options:BufferSpec):
HSADevice.synchronize_system()
check(hsa.hsa_amd_memory_pool_free(opaque))
def _copyin(self, dest:T, src: memoryview):
# Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
mem = self._alloc(src.nbytes, BufferSpec(host=True))
ctypes.memmove(mem, from_mv(src), src.nbytes)
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal),
copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True))
self.device.hw_queue.submit_barrier([copy_signal])
self.device.delayed_free.append(mem)
if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True)
def copy_from_fd(self, dest, fd, offset, size):
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
if not hasattr(self, 'hb'):
self.hb = [self._alloc(CHUNK_SIZE, BufferSpec(host=True)) for _ in range(2)]
self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)]
self.hb_polarity = 0
self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1]
for sig in self.hb_signals: hsa.hsa_signal_store_relaxed(sig, 0)
fo = io.FileIO(fd, "a+b", closefd=False)
fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))
copies_called = 0
copied_in = 0
for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
copy_size = min(local_size-minor_offset, size-copied_in)
if copy_size == 0: break
hsa.hsa_signal_wait_scacquire(self.hb_signals[self.hb_polarity], hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
self.device.reusable_signals.append(self.hb_signals[self.hb_polarity]) # it's free now and can be reused
self.hb_signals[self.hb_polarity] = self.device.alloc_signal(reusable=False)
fo.readinto(to_mv(self.hb[self.hb_polarity], local_size))
check(hsa.hsa_amd_memory_async_copy_on_engine(dest+copied_in, self.device.agent, self.hb[self.hb_polarity]+minor_offset, HSADevice.cpu_agent,
copy_size, 1, ctypes.byref(sync_signal), self.hb_signals[self.hb_polarity],
self.sdma[self.hb_polarity], True))
copied_in += copy_size
self.hb_polarity = (self.hb_polarity + 1) % len(self.hb)
minor_offset = 0 # the page-offset correction applies only to the first chunk
copies_called += 1
wait_signals = [self.hb_signals[self.hb_polarity - 1]]
if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
self.device.hw_queue.submit_barrier(wait_signals)
def _copyout(self, dest:memoryview, src:T):
HSADevice.synchronize_system()
copy_signal = self.device.alloc_signal(reusable=True)
c_agents = (hsa.hsa_agent_t*2)(self.device.agent, HSADevice.cpu_agent)
check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p())))
check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal))
hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
check(hsa.hsa_amd_memory_unlock(from_mv(dest)))
if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True)
def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
src_dev.hw_queue.submit_barrier([], sync_signal_1 := src_dev.alloc_signal(reusable=True))
dest_dev.hw_queue.submit_barrier([], sync_signal_2 := dest_dev.alloc_signal(reusable=True))
c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2)
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal,
copy_signal := dest_dev.alloc_signal(reusable=False), hsa.HSA_AMD_SDMA_ENGINE_0, True))
src_dev.hw_queue.submit_barrier([copy_signal])
dest_dev.hw_queue.submit_barrier([copy_signal])
if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True)
class HSADevice(Compiled):
devices: List[HSADevice] = []
agents: Dict[int, List[hsa.hsa_agent_t]] = {}
cpu_agent: hsa.hsa_agent_t
cpu_mempool: hsa.hsa_amd_memory_pool_t
def __init__(self, device:str=""):
if not HSADevice.agents:
check(hsa.hsa_init())
atexit.register(hsa_terminate)
HSADevice.agents = scan_agents()
HSADevice.cpu_agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_CPU][0]
HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU)
if PROFILE: check(hsa.hsa_amd_profiling_async_copy_enable(1))
self.device_id = int(device.split(":")[1]) if ":" in device else 0
self.agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU][self.device_id]
self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)
self.hw_queue = AQLQueue(self)
HSADevice.devices.append(self)
check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256))))
self.arch = ctypes.string_at(agent_name_buf).decode()
check(hsa.hsa_system_get_info(hsa.HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, ctypes.byref(gpu_freq := ctypes.c_uint64())))
self.clocks_to_time: float = 1 / gpu_freq.value
check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AMD_AGENT_INFO_HDP_FLUSH, ctypes.byref(hdp_flush := hsa.hsa_amd_hdp_flush_t())))
self.hdp_flush = hdp_flush
self.delayed_free: List[int] = []
self.reusable_signals: List[hsa.hsa_signal_t] = []
from tinygrad.runtime.graph.hsa import HSAGraph
super().__init__(device, HSAAllocator(self), HIPRenderer(), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)
# Finish init: preallocate some signals + space for kernargs
self.signal_pool = [init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x)))) for _ in range(4096)]
self._new_kernargs_region(16 << 20) # initial region size is 16mb
def synchronize(self):
self.hw_queue.wait()
for sig in self.reusable_signals: hsa.hsa_signal_silent_store_relaxed(sig, 1)
self.signal_pool.extend(self.reusable_signals)
self.reusable_signals.clear()
for opaque_to_free in self.delayed_free: check(hsa.hsa_amd_memory_pool_free(opaque_to_free))
self.delayed_free.clear()
self.kernarg_next_addr = self.kernarg_start_addr
Profiler.process(self)
@staticmethod
def synchronize_system():
for d in HSADevice.devices: d.synchronize()
def alloc_signal(self, reusable=False):
if len(self.signal_pool): signal = self.signal_pool.pop()
else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))
# reusable means the signal can be reused after synchronize is called on the device it was allocated from.
if reusable: self.reusable_signals.append(signal)
return signal
def alloc_kernargs(self, sz):
if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz: self._new_kernargs_region(int(self.kernarg_pool_sz * 2))
result = self.kernarg_next_addr
self.kernarg_next_addr = round_up(self.kernarg_next_addr + sz, 16)
return result
def _new_kernargs_region(self, sz:int):
if hasattr(self, 'kernarg_start_addr'): self.delayed_free.append(self.kernarg_start_addr)
self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferSpec())
self.kernarg_next_addr = self.kernarg_start_addr
self.kernarg_pool_sz: int = sz
def flush_hdp(self): self.hdp_flush.HDP_MEM_FLUSH_CNTL[0] = 1
def hsa_terminate():
# Need to stop/delete the AQL queues before HSA shutdown, otherwise shutdown leads to gpu hangs.
for dev in HSADevice.devices:
Profiler.process(dev)
del dev.hw_queue
# hsa_shut_down cleans up all hsa-related resources.
hsa.hsa_shut_down()
HSADevice.synchronize = lambda: None #type:ignore
HSAProgram.__del__ = lambda _: None #type:ignore
if Profiler.collected_events: Profiler.save("/tmp/profile.json")
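# A minimal, dependency-free sketch of the signal pooling above: alloc pops from
# a free pool (or makes a new signal), reusable signals are tracked, and
# synchronize returns them to the pool. Names are illustrative; the real pool
# holds hsa_signal_t handles.
class _SignalPool:
  def __init__(self, n=4096): self.pool, self.reusable = [object() for _ in range(n)], []
  def alloc(self, reusable=False):
    signal = self.pool.pop() if self.pool else object()
    if reusable: self.reusable.append(signal)
    return signal
  def synchronize(self):
    self.pool.extend(self.reusable)  # reusable signals go back to the pool
    self.reusable.clear()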

View File

@@ -1,127 +0,0 @@
from typing import Dict, Set
import yaml
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
from tinygrad.uop.ops import BinaryOps
from tinygrad.dtype import dtypes
def uops_to_rdna(function_name:str, uops:UOpGraph) -> str:
replace: Dict[UOp, UOp] = {}
seen: Set[UOp] = set()
for u in uops:
if u in seen: continue
seen.add(u)
for o,n in replace.items():
if o in u.vin and u is not n:
u.vin = tuple(n if x == o else x for x in u.vin)
# pointer indexing
if u.uop in {UOps.LOAD, UOps.STORE} and u.vin[0].dtype.itemsize > 1:
val = UOp(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_at=uops.uops.index(u))
ptr = UOp(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_at=uops.uops.index(u))
u.vin = (u.vin[0], ptr) + u.vin[2:]
#uops.print()
args = []
ins = []
v_cnt = 3 # v[0:2] is local_xyz
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
r: Dict[UOp, str] = {}
for u in uops:
if u.uop == UOps.SPECIAL:
if u.arg.startswith("lidx"):
r[u] = f'v{u.src[0].arg}'
elif u.arg.startswith("gidx"):
r[u] = f's{2+u.src[0].arg}'
else:
raise NotImplementedError
elif u.uop == UOps.CONST:
#r[u] = u.arg
# TODO: sometimes we can use s
#r[u] = f"s{s_cnt}"
#s_cnt += 1
#ins.append(f"s_mov_b32 {r[u]}, {u.arg}")
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"v_mov_b32 {r[u]}, {u.arg}")
elif u.uop == UOps.ALU:
if u.arg == BinaryOps.ADD:
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"v_add_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
elif u.arg == BinaryOps.MUL:
r[u] = f"v{v_cnt}"
v_cnt += 1
if dtypes.is_float(u.dtype):
ins.append(f"v_mul_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
else:
ins.append(f"v_mul_u32_u24 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
else:
raise NotImplementedError
elif u.uop == UOps.LOAD:
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"global_load_b32 {r[u]}, {r[u.vin[1]]}, {r[u.vin[0]]}")
ins.append("s_waitcnt vmcnt(0)")
elif u.uop == UOps.STORE:
ins.append(f"global_store_b32 {r[u.vin[1]]}, {r[u.vin[2]]}, {r[u.vin[0]]}")
elif u.uop == UOps.DEFINE_GLOBAL:
i = u.arg[0]
args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8,
'.type_name': u.dtype.name+"*", '.value_kind': 'global_buffer'})
s_cnt += s_cnt%2 # align to an even sgpr index for the 64-bit s_load_b64 pair
r[u] = f"s[{s_cnt}:{s_cnt+1}]"
s_cnt += 2
ins.append(f"s_load_b64 {r[u]}, s[0:1], {i*8}")
ins.append("s_waitcnt lgkmcnt(0)")
else:
raise NotImplementedError(f"can't render {u.uop}")
# *** boilerplate rendering ***
metadata = {
'amdhsa.kernels': [{'.args': args,
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
'.name': function_name, '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
'.symbol': f'{function_name}.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
'.wavefront_size': 32}],
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
boilerplate_start = f"""
.rodata
.global {function_name}.kd
.type {function_name}.kd,STT_OBJECT
.align 0x10
.amdhsa_kernel {function_name}"""
kernel_desc = {
'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
'.amdhsa_next_free_sgpr': s_cnt,
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3,
'.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1, '.amdhsa_fp16_overflow': 0,
'.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0,
'.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0,
'.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
code_start = f""".end_amdhsa_kernel
.text
.global {function_name}
.type {function_name},@function
.p2align 8
{function_name}:
"""
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
return ".amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata" + \
boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + \
'\n'.join(ins) + f"\n.size {function_name}, .-{function_name}"

View File

@@ -1,131 +0,0 @@
from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.uop.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.opt.kernel import UOp, Ops
from triton.compiler import compile as triton_compile
import linecache
import math
import re
triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
signature_dtypes = {dtypes.double: "fp64",dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i8", dtypes.int8: "i1", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"}
def next_power_of_2(x):
return 1 << (x - 1).bit_length()
def render_valid(valid):
return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True'
#NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile
def fill_dims_for_idx(idx, dims):
return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx
def get_max(var):
if isinstance(var, int): return var
return re.sub(r'\[(.*?)\]', '', str(var))[1:-1]
#NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved
def remove_single_scalar_curly_braces(ptx_code):
return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])
def render_const(args,dtype:DType):
return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))
def render_cast(x:str, dtype:DType, bitcast=False):
return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"
def define_scalar(local_size, dtype, args):
if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
return render_const(args,dtype)
def uops_to_triton(function_name:str, uops:List[UOp]):
local_size: List[int] = []
depth = 1
signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
def ssa(u, prefix="t"):
nonlocal c, r
c[prefix] += 1
r[u]=f"{prefix}{c[prefix]-1}"
return r[u]
child_count: DefaultDict[UOp, int] = defaultdict(int)
for ru in uops:
for v in ru.vin:
child_count[v] += 1
def kk(s): kernel.append(" "*depth+s)
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.EXP2: lambda x,dtype,: f"tl.math.exp2({x})",
UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
UnaryOps.NEG: lambda x,dtype: f"-{x}",
BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,dtype: f"({x}-{y})",
BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,dtype: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
}
def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop == Ops.LOOP:
kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
depth += 1
elif uop == Ops.END: depth -= 1
elif uop == Ops.ALU:
assert dtype is not None
val = code_for_op[args](*[r[x] for x in vin], dtype)
if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
else: kk(f"{ssa(u, 'alu')} = ({val})")
elif uop == Ops.LOAD:
assert dtype is not None
if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
elif uop == Ops.DEFINE_REG: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
elif uop == Ops.CONST: r[u] = define_scalar([], dtype, args)
elif uop == Ops.ASSIGN:
kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
r[u] = r[vin[0]]
elif uop == Ops.STORE:
assert not isinstance(dtype, ImageDType), "unimplemented: image store"
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
elif uop == Ops.DEFINE_GLOBAL:
bufs.append(args)
signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
r[u] = args
elif uop == Ops.SPECIAL:
dims.append(args[1])
valid.append(f"{args[1]}<{get_max(args[2])}")
if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
elif args[1].startswith("l"):
kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
local_size.append(args[2])
r[u] = args[1]
elif uop == Ops.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
else: raise NotImplementedError(f"unimplemented: {uop}")
prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
prg += "\n".join(kernel)
acc_local_size = 1
for x in local_size: acc_local_size *= next_power_of_2(x)
local_size = [acc_local_size] + [1] * (len(local_size) - 1)
if DEBUG >= 4: print(prg)
getlines = linecache.getlines
linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122\
compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}

View File

@@ -1,22 +0,0 @@
import ctypes
import os
import pathlib
import struct
from hexdump import hexdump
fxn = None
def disasm_raw(buf):
global fxn
if fxn is None:
shared = pathlib.Path(__file__).parent / "disasm.so"
if not shared.is_file():
os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so')
fxn = ctypes.CDLL(shared.as_posix())['disasm']
fxn(buf, len(buf))
def disasm(buf):
def _read_lib(off): return struct.unpack("I", buf[off:off+4])[0]
image_offset = _read_lib(0xc0)
image_size = _read_lib(0x100)
disasm_raw(buf[image_offset:image_offset+image_size])
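# Hedged usage sketch (the argument path is illustrative): pass a compiled
# adreno shader binary; the header words at 0xc0/0x100 locate the instruction
# image that gets fed to disasm_raw.
if __name__ == "__main__":
  import sys
  with open(sys.argv[1], "rb") as f: disasm(f.read())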

View File

@@ -1,120 +0,0 @@
#!/usr/bin/env python3
import os, ctypes, ctypes.util, io, mmap, pathlib
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import Timing, from_mv
libc = ctypes.CDLL(ctypes.util.find_library("c"))
#from extra.hip_gpu_driver import hip_ioctl
# sudo su -c "echo 3 > /proc/sys/vm/drop_caches"
# sudo su -c 'echo 8 > /proc/sys/kernel/printk'
# sudo su -c "echo 'module amdgpu +p' > /sys/kernel/debug/dynamic_debug/control"
libc.memcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
libc.read.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t]
libc.read.restype = ctypes.c_size_t
libc.malloc.argtypes = [ctypes.c_size_t]
libc.malloc.restype = ctypes.c_void_p
def read_direct(fd, sz):
with Timing("mmap: ", lambda x: f", {sz/x:.2f} GB/s"):
buf = mmap.mmap(-1, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE)
with Timing("read: ", lambda x: f", {sz/x:.2f} GB/s"):
ret = libc.read(fd, from_mv(buf), sz)
assert ret == sz
def read_mmap(fd, sz):
with Timing("mmfd: ", lambda x: f", {sz/x:.2f} GB/s"):
buf = mmap.mmap(fd, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE) #|MAP_LOCKED)
t = 0
for i in range(0, sz, 0x1000): t += buf[i]
# def _copyin_async(self, dest:T, src:T, size:int): check(hip.hipMemcpyAsync(dest, src, size, hip.hipMemcpyHostToDevice, None))
def read_to_gpu_mmap(fd, sz, gpubuf):
with Timing("gpu copyin: ", lambda x: f", {sz/x:.2f} GB/s"):
with Timing("mmfd: ", lambda x: f", {sz/x:.2f} GB/s"):
buf = mmap.mmap(fd, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE) #|MAP_LOCKED)
dev.allocator._copyin_async(gpubuf, from_mv(buf), sz)
dev.synchronize()
def read_to_gpu_single(fd, sz, gpubuf):
os.lseek(fd, 0, os.SEEK_SET)
with Timing("total: ", lambda x: f", {sz/x:.2f} GB/s"):
with Timing("gpu host alloc: ", lambda x: f", {sz/x:.2f} GB/s"):
hst = dev.allocator._hostalloc(sz)
with Timing("read to host: ", lambda x: f", {sz/x:.2f} GB/s"):
ret = libc.read(fd, hst, sz)
with Timing("gpu host copy: ", lambda x: f", {sz/x:.2f} GB/s"):
dev.allocator._copyin_async(gpubuf, hst, sz)
dev.synchronize()
def read_to_gpu_pingpong(fd, sz, gpubuf):
psz = 256*1024*1024
print(f"piece size {psz/(1024*1024):.2f} MB")
with Timing("gpu host alloc: ", lambda x: f", {sz/x:.2f} GB/s"):
hst1 = dev.allocator._hostalloc(psz)
hst2 = dev.allocator._hostalloc(psz)
os.lseek(fd, 0, os.SEEK_SET)
with Timing("total: ", lambda x: f", {sz/x:.2f} GB/s"):
for i in range(sz//(psz*2)):
with Timing("tfer(0): ", lambda x: f", {psz/x:.2f} GB/s"):
ret = libc.read(fd, hst1, psz)
dev.synchronize()
dev.allocator._copyin_async(gpubuf, hst1, psz)
with Timing("tfer(1): ", lambda x: f", {psz/x:.2f} GB/s"):
ret = libc.read(fd, hst2, psz)
dev.synchronize()
dev.allocator._copyin_async(gpubuf, hst2, psz)
dev.synchronize()
MAP_LOCKED = 0x2000
MAP_HUGETLB = 0x40000
if __name__ == "__main__":
dev = Device[Device.DEFAULT]
warm = (Tensor.ones(1024, device=Device.DEFAULT).contiguous() + Tensor.ones(1024, device=Device.DEFAULT).contiguous()).realize()
#fn = "/home/tiny/tinygrad/weights/rng"
fn = pathlib.Path(__file__).parents[1] / "weights/LLaMA-2/70B/consolidated.00.pth"
sz = os.stat(fn).st_size
t = Tensor.empty(sz, dtype=dtypes.uint8, device=f"disk:{fn}")
with Timing("copy: ", lambda x: f", {sz/x:.2f} GB/s"):
on_dev = t.to(Device.DEFAULT).realize()
exit(0)
# 4GB of random numbers
#fd = os.open("/home/tiny/tinygrad/weights/rng", os.O_RDWR|os.O_DIRECT)
#sz = os.fstat(fd).st_size // 4
fd = os.open("/home/tiny/tinygrad/weights/LLaMA/7B/consolidated.00.pth", os.O_RDWR|os.O_DIRECT)
sz = os.fstat(fd).st_size
print(f"read {sz} from {fd}")
with Timing("gpu alloc: ", lambda x: f", {sz/x:.2f} GB/s"):
gpubuf = dev.allocator._alloc(sz)
# warmup
dev.allocator._copyin_async(gpubuf, from_mv(bytearray(b"\x00\x00\x00\x00"*0x1000)), 0x4000)
print("copying, is warm")
print("****** read to gpu pingpong")
read_to_gpu_pingpong(fd, sz, gpubuf)
exit(0)
print("****** read direct")
read_direct(fd, sz)
print("****** read mmap")
read_mmap(fd, sz)
print("****** read to gpu single")
read_to_gpu_single(fd, sz, gpubuf)
print("****** read to gpu mmap")
read_to_gpu_mmap(fd, sz, gpubuf)
os._exit(0)
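# Note on the GB/s math used throughout: tinygrad's Timing passes the elapsed
# time to its callback in nanoseconds, so bytes/ns reads out directly as GB/s.
# A tiny self-contained check, with sleep standing in for a 1 GB read:
import time
with Timing("fake 1GB read: ", lambda ns: f", {1e9/ns:.2f} GB/s"):
  time.sleep(0.25) # prints roughly 4.00 GB/s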

View File

@@ -1,21 +0,0 @@
import sys, sqlite3, pickle
from tinygrad.helpers import CACHEDB
if __name__ == "__main__":
fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB
conn = sqlite3.connect(fn)
cur = conn.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
for f in cur.fetchall():
table = f[0]
cur2 = conn.cursor()
cur2.execute(f"SELECT COUNT(*) FROM {table}")
cnt = cur2.fetchone()[0]
print(f"{table:20s} : {cnt}")
cur3 = conn.cursor()
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
for f in cur3.fetchall():
v = pickle.loads(f[-1])
print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
#print(f"{len(k):10d}, {sk} -> {v}")

View File

@@ -1,27 +0,0 @@
#!/usr/bin/env python3
import time
import jax
import jax.numpy as jnp
print(jax.devices())
DEVICES = len(jax.devices())
BS = 32
N = 4096
dtype = jnp.float16
A = jnp.zeros((DEVICES, BS, N, N), dtype)
B = jnp.zeros((1, 1, N, N), dtype)
A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())
OPS = DEVICES*BS*N*N*N*2
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
pmatmul = jax.pmap(matmul)
MAX_TFLOPS = 123*DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
for i in range(10):
st = time.perf_counter()
C = pmatmul(A,B).block_until_ready()
et = time.perf_counter()-st
tflops = (OPS*1e-12)/et
print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")

View File

@@ -1,10 +0,0 @@
import mlx.core as mx
from tinygrad.helpers import Timing
N = 4096
x = mx.random.normal((N,N))
w = mx.random.normal((N,N))
FLOPS = N*N*N*2
for i in range(10):
with Timing("", lambda x: f" {FLOPS/x:.2f} GFLOPS"):
mx.eval(x@w)

View File

@@ -1,33 +0,0 @@
import time
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
if gpus:
try:
# Currently, memory growth needs to be the same across GPUs
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.list_logical_devices('GPU')
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
except RuntimeError as e:
# Memory growth must be set before GPUs have been initialized
print(e)
for dtype in [tf.float16, tf.float32]:
for N in [256, 512, 1024, 2048, 4096, 8192]:
FLOPS = N*N*N*2
b = tf.random.uniform((N, N), dtype=dtype)
c = tf.random.uniform((N, N), dtype=dtype)
b = tf.Variable(b)
c = tf.Variable(c)
def tf_prog(b, c):
st = time.perf_counter()
a = tf.matmul(b, c)
tf.debugging.check_numerics(a, "Nan or Inf in result") # Ensures that the calculation is done.
return time.perf_counter() - st
tm = min([tf_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")

View File

@@ -1,12 +0,0 @@
import ctypes
import tinygrad.runtime.autogen.hip as hip
from tinygrad.runtime.ops_hip import check
from tinygrad.helpers import init_c_var
if __name__ == "__main__":
check(hip.hipSetDevice(0))
evt = init_c_var(hip.hipEvent_t(), lambda x: check(hip.hipEventCreate(ctypes.byref(x))))
check(hip.hipSetDevice(1))
check(hip.hipStreamWaitEvent(None, evt, 0))
check(hip.hipSetDevice(0))
check(hip.hipEventRecord(evt, None))
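# This appears to probe whether a stream on device 1 can enqueue a wait on an
# event before device 0 records it. A hedged cleanup sketch; both are standard
# HIP calls, assumed present in the autogen bindings:
check(hip.hipEventSynchronize(evt))
check(hip.hipEventDestroy(evt))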

View File

@@ -1,45 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece_model.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'H\003'
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
_globals['_TRAINERSPEC']._serialized_start=45
_globals['_TRAINERSPEC']._serialized_end=1581
_globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517
_globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570
_globals['_NORMALIZERSPEC']._serialized_start=1584
_globals['_NORMALIZERSPEC']._serialized_end=1793
_globals['_SELFTESTDATA']._serialized_start=1795
_globals['_SELFTESTDATA']._serialized_end=1916
_globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864
_globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905
_globals['_MODELPROTO']._serialized_start=1919
_globals['_MODELPROTO']._serialized_end=2429
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407
# @@protoc_insertion_point(module_scope)

View File

@@ -1,176 +0,0 @@
from __future__ import annotations
from typing import List, Optional, Dict, cast
import numpy as np
np.set_printoptions(suppress=True)
import math, functools, time, random, statistics
from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling
from tinygrad.codegen.opt.kernel import Kernel
from tinygrad.device import Buffer, Device, CompileError
from tinygrad.codegen.opt.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
from tinygrad.engine.realize import get_program
class MCTSNode:
def __init__(self, kernel:Kernel, parent=None):
self.kernel:Kernel = kernel
self.t = math.inf
self.n = 0
self.tm = math.inf
self.i = -1
self.parents: List[MCTSNode] = [parent] if parent is not None else []
self.children: Optional[List[MCTSNode]] = None
self.removed_children: List[MCTSNode] = []
def expand_node(node:MCTSNode):
assert node.children is None
node.children = [MCTSNode(x, node) for x in get_kernel_actions(node.kernel, include_0=False).values()]
def remove_node(node:MCTSNode):
for parent in node.parents:
assert parent.children is not None
parent.children.remove(node)
parent.removed_children.append(node)
C = math.sqrt(2)
TEMP = 0.5
def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
if node.children is None or len(node.children) == 0: return node
unexplored_children = []
explored_children = []
ucb_explored_children: List[float] = []
for child in node.children:
if child.n == 0: unexplored_children.append(child)
else:
ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
if not math.isinf(ucb):
explored_children.append(child)
ucb_explored_children.append(ucb)
if len(unexplored_children): return random.choice(unexplored_children)
if not len(explored_children): return node
# safe softmax
ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP)
return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)
# this will expand/remove sometimes
def sample_tree(root:MCTSNode, best_tm:float) -> Optional[MCTSNode]:
if root.children is None: expand_node(root)
while root.children:
# tree traversal
node = _sample_tree(root, best_tm)
if node.children is not None and len(node.children) == 0:
remove_node(node)
continue
# node expansion
if node.n != 0:
if node.children is None: expand_node(node)
assert node.children is not None
if len(node.children) == 0:
remove_node(node)
continue
node = random.choice(node.children)
return node
return None
def backprop(bnode:MCTSNode, tm, strength=1.0):
if bnode.t > tm: bnode.t = tm
bnode.n += strength
for parent in bnode.parents: backprop(parent, tm, strength/len(bnode.parents))
graph_mcts_cnt = 0
def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
global graph_mcts_cnt
# TODO: copied from BEAM
key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not getenv("IGNORE_MCTS_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
ret = lin.copy()
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
return ret
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals = {k.expr:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
dev = Device[lin.opts.device]
root = MCTSNode(lin)
st = time.perf_counter()
best, best_idx, best_tm = lin, 0, math.inf
seen_libs: Dict[bytes, MCTSNode] = {}
seen_asts: Dict[bytes, MCTSNode] = {}
compile_time, runtime_time = 0.0, 0.0
for i in range(amt):
node = sample_tree(root, best_tm) # sample and expand
if node is None: break # finished the whole tree
node.i = i # when was node explored
opt_ast = node.kernel.get_optimized_ast()
if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None:
# early check for same optimized AST hit
remove_node(node)
tm = sibling_node.t
else:
seen_asts[opt_ast.key] = node
# lowering (this accounts for roughly 50% of the search time)
p = get_program(node.kernel.get_optimized_ast(name_override="test"), node.kernel.opts)
# rollout
tm1 = time.perf_counter()
try:
lib = dev.compiler.compile(p.src)
except CompileError:
# NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
lib = None
tm2 = time.perf_counter()
if lib is None:
tm = math.inf
else:
if (sibling_node:=seen_libs.get(lib, None)) is not None:
# NOTE: these should all be caught by the AST check, need to canonicalize
# remove this node, it's a duplicate
remove_node(node)
tm = sibling_node.t
else:
seen_libs[lib] = node
try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
except RuntimeError: tm = math.inf
node.tm = tm
tm3 = time.perf_counter()
compile_time += tm2-tm1
runtime_time += tm3-tm2
# mock rollout
#node.tm = tm = random.random() + 0.1
if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
et = time.perf_counter() - st
if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {int(round((i+1)/et)):4d}/s {node.kernel.colored_shape()}\033[K", end="") # noqa: E501
# backprop
backprop(node, tm)
if DEBUG>=2: print()
if getenv("MCTSGRAPH"):
import networkx as nx
import os
GRAPHPATH = "/tmp/net"
def save_graph(G, fn, opt=""):
print("saving", G, f"to {fn}.svg")
nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
G = nx.DiGraph()
def add_node(node:MCTSNode):
if node.n == 0: return
for parent in node.parents: G.add_edge(parent, node)
gopts = node.kernel.applied_opts
edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}" if len(gopts) else "ROOT"
G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}",
fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '')
if node.children is not None:
for child in node.children+node.removed_children: add_node(child)
add_node(root)
save_graph(G, f"{GRAPHPATH}.{graph_mcts_cnt}.mcts", '-Grankdir=LR')
graph_mcts_cnt += 1
if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
return best
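# The selection rule from _sample_tree, in isolation: exploitation is the
# child's negated best time normalized by the global best, plus the usual UCT
# exploration bonus, so rarely-visited children can outrank faster ones.
import math
C, best_tm, parent_n = math.sqrt(2), 10.0, 6
for t, n in [(12.0, 5), (15.0, 1)]: # (child best time in us, child visits)
  print(-t/best_tm + C*math.sqrt(math.log(parent_n)/n)) # -0.35, then 0.39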

View File

@@ -1,75 +0,0 @@
import pickle, sys
from dataclasses import replace
from tinygrad import Device, Context, Tensor, GlobalCounters
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item, get_program
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
import numpy as np
def move_jit_captured_to_dev(captured, device="DSP"):
captured.expected_st_vars_dtype_device = [x[:3] + (device,) for x in captured.expected_st_vars_dtype_device]
assign = {}
def move_buffer(b):
if b in assign: return assign[b]
if b._base is not None:
newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset)
else:
newbuf = Buffer(device, b.size, b.dtype)
if b.is_allocated(): newbuf.ensure_allocated().copyin(b.as_buffer())
assign[b] = newbuf
return assign[b]
for item in captured.jit_cache:
for b in item.bufs:
if b is not None: move_buffer(b)
captured.jit_cache = [ExecItem(item.prg, [assign.get(b,b) for b in item.bufs]) for item in captured.jit_cache]
return captured
if __name__ == "__main__":
with Context(DEBUG=0):
with open(sys.argv[1], "rb") as f:
fxn: TinyJit = pickle.load(f)
print(f"{f.tell()/1e6:.2f}M loaded")
print(type(fxn))
# Move all buffers to DSP device.
fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
new_jit = []
knum = 1
for ei in fxn.captured.jit_cache:
# skip the copy and the first kernel
if isinstance(ei.prg, CompiledRunner) and all(x is not None for x in ei.bufs):
if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
p: ProgramSpec = ei.prg.p
k = Kernel(p.ast, Device["DSP"].renderer)
if getenv("VALIDATE"):
with Context(NOOPT=1):
lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
correct = ei.bufs[0].numpy()
ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
GlobalCounters.kernel_count -= 1
if not getenv("NOOPT"): k.apply_opts(hand_coded_optimizations(k))
p2 = get_program(k.ast, k.opts, k.applied_opts)
new_ei = replace(ei, prg=CompiledRunner(p2))
new_ei.run()
new_jit.append(new_ei)
test = ei.bufs[0].numpy()
if getenv("VALIDATE"):
import numpy as np
np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
knum += 1
if getenv("RUN_JIT", 0):
fxn.captured.free_intermediates()
fxn.captured.jit_cache = new_jit
fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))
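# Usage sketch (script and pickle paths illustrative):
#   KNUM=3 VALIDATE=1 python3 run_pickle_on_dsp.py /path/to/jit.pkl  # re-run kernel 3, check vs NOOPT
#   RUN_JIT=1 python3 run_pickle_on_dsp.py /path/to/jit.pkl          # replay the whole jit on DSP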

View File

@@ -1,114 +0,0 @@
# code from https://x.com/awnihannun/status/1832511021602500796
from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn
import time
class Block(nn.Module):
def __init__(self, in_dims, dims, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn1 = nn.BatchNorm(dims)
self.conv2 = nn.Conv2d(
dims, dims, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm(dims)
self.downsample = []
if stride != 1:
self.downsample = [
nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm(dims)
]
def __call__(self, x):
out = nn.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
for l in self.downsample:
x = l(x)
out += x
out = nn.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2)
self.fc = nn.Linear(512, num_classes)
def _make_layer(self, block, in_dims, dims, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(in_dims, dims, stride))
in_dims = dims
return layers
def __call__(self, x):
x = nn.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x)
for l in self.layer1 + self.layer2 + self.layer3 + self.layer4:
x = l(x)
x = mx.mean(x, axis=[1, 2])
x = self.fc(x)
return x
def load():
model = ResNet(Block, [2, 2, 2, 2], num_classes=1000)
file = "model.safetensors"
model_path = snapshot_download(
repo_id="awni/resnet18-mlx",
allow_patterns=[file],
)
model.load_weights(model_path + "/" + file)
model.eval()
mx.eval(model)
return model
if __name__ == "__main__":
resnet18 = load()
@mx.compile
def forward(im):
return resnet18(im)
batch_sizes = [1, 2, 4, 8, 16, 32, 64]
#its = 200
#batch_sizes = [64]
its = 20
print(f"Batch Size | Images-per-second | Milliseconds-per-image")
print(f"---- | ---- | ---- ")
for N in batch_sizes:
image = mx.random.uniform(shape=(N, 288, 288, 3))
# Warmup
for _ in range(5):
output = forward(image)
mx.eval(output)
tic = time.time()
for _ in range(its):
output = forward(image)
mx.async_eval(output)
mx.eval(output)
toc = time.time()
ims_per_sec = N * its / (toc - tic)
ms_per_im = 1e3 / ims_per_sec
print(f"{N} | {ims_per_sec:.3f} | {ms_per_im:.3f}")

View File

@@ -1,109 +0,0 @@
from huggingface_hub import snapshot_download
from tinygrad import nn, Tensor, TinyJit, Device, GlobalCounters, Context
import time
class Block:
def __init__(self, in_dims, dims, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn1 = nn.BatchNorm(dims)
self.conv2 = nn.Conv2d(
dims, dims, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm(dims)
self.downsample = []
if stride != 1:
self.downsample = [
nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm(dims)
]
def __call__(self, x):
out = self.bn1(self.conv1(x)).relu()
out = self.bn2(self.conv2(out))
for l in self.downsample:
x = l(x)
out += x
return out.relu()
class ResNet:
def __init__(self, block, num_blocks, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm(64)
self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2)
self.fc = nn.Linear(512, num_classes)
def _make_layer(self, block, in_dims, dims, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(in_dims, dims, stride))
in_dims = dims
return layers
def __call__(self, x:Tensor):
x = self.bn1(self.conv1(x)).relu().max_pool2d()
x = x.sequential(self.layer1)
with Context(WINO=1): x = x.sequential(self.layer2 + self.layer3 + self.layer4)
x = x.mean([2, 3])
x = self.fc(x)
return x
def load():
model = ResNet(Block, [2, 2, 2, 2], num_classes=1000)
file = "model.safetensors"
model_path = snapshot_download(
repo_id="awni/resnet18-mlx",
allow_patterns=[file],
)
state = nn.state.safe_load(model_path + "/" + file)
# mlx is NHWC, tinygrad is NCHW
nn.state.load_state_dict(model, {k:v if len(v.shape) != 4 else v.to(None).permute(0,3,1,2).contiguous() for k,v in state.items()}, strict=False)
return model
if __name__ == "__main__":
resnet18 = load()
def _forward(im): return resnet18(im)
forward = TinyJit(_forward, prune=True)
batch_sizes = [1, 2, 4, 8, 16, 32, 64]
#its = 200
#batch_sizes = [64]
its = 20
print(f"Batch Size | Images-per-second | Milliseconds-per-image")
print(f"---- | ---- | ---- ")
for N in batch_sizes:
forward.reset() # reset the JIT for a new batch size (could make automatic)
image = Tensor.uniform(N, 3, 288, 288)
# Warmup
for _ in range(5):
GlobalCounters.reset()
output = forward(image)
Device.default.synchronize()
tic = time.time()
for _ in range(its):
GlobalCounters.reset()
output = forward(image)
Device.default.synchronize()
toc = time.time()
ims_per_sec = N * its / (toc - tic)
ms_per_im = 1e3 / ims_per_sec
print(f"{N} | {ims_per_sec:.3f} | {ms_per_im:.3f}")

View File

@@ -1,15 +0,0 @@
from tinygrad import Tensor, Device, GlobalCounters
from tinygrad.helpers import Timing
N = 512
GPUS = 5
ds = tuple([f"{Device.DEFAULT}:{i+1}" for i in range(GPUS)])
t = [Tensor.ones(N, N, N, device=d).contiguous().realize() for d in ds]
for _ in range(10):
GlobalCounters.reset()
with Timing():
for ti in t:
ti.to_(ds[(ds.index(ti.device)+1+len(ds))%len(ds)])
# ti.to_(ds[(ds.index(ti.device)-1+len(ds))%len(ds)]) # reversed order
ti.realize()
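# The ring step above, in isolation: each tensor hops to the next device, and
# the +len(ds) keeps the index positive for the commented reversed (-1) variant.
ring = ("GPU:1", "GPU:2", "GPU:3")
print([ring[(i+1+len(ring)) % len(ring)] for i in range(len(ring))]) # ['GPU:2', 'GPU:3', 'GPU:1']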

View File

@@ -1,47 +0,0 @@
import os, pathlib, argparse
from examples.llama3 import Tokenizer
from tabulate import tabulate
from tinygrad import fetch
from tinygrad.helpers import flatten, getenv
from sz import NONCORE_DIRS
# llama 3 tokenizer
tokenizer = Tokenizer(fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model").as_posix())
def read_code(base_path, full=False):
ret = []
for path, _, files in os.walk(os.path.join(base_path, "tinygrad")):
if not full and any(path.split("./")[1].startswith(x) for x in NONCORE_DIRS): continue
for name in files:
if not name.endswith(".py"): continue
if 'tinygrad/runtime/autogen' in path.replace('\\', '/'): continue
fullpath = os.path.join(path, name)
code = pathlib.Path(fullpath).read_text()
ret.append((fullpath.split("tinygrad/", 1)[1], code))
return ret
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Analyze and optionally save tinygrad code.")
parser.add_argument("--output", help="Output file to write the combined code to.")
parser.add_argument("--full", action="store_true", help="All directories")
args = parser.parse_args()
ret = read_code(".", args.full)
table = []
for name,code in ret:
table.append([name, len(tokenizer.encode(code))])
print(tabulate([["name", "llm tokens"]]+sorted(table, key=lambda x: -x[1]), headers="firstrow"))
banner = "#"*40
code_str = ''.join([f"{banner}\n# {name}\n{banner}\n\n{code}\n" for name,code in ret])
print(f"code has {len(code_str)} chars")
newline_count = code_str.count('\n')
print(f"code has {newline_count} newlines")
encoded = tokenizer.encode(code_str)
print(f"code has {len(encoded)} tokens")
if args.output:
with open(args.output, 'w') as f: f.write(code_str)
print(f"Combined code written to {args.output}")

View File

@@ -1,16 +0,0 @@
if __name__ == "__main__":
import os
if "DEBUG" not in os.environ: os.environ["DEBUG"] = "2"
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import getenv
if (seed := getenv("SEED", 0)) != 0:
Tensor.manual_seed(seed)
print(f"using seed {Tensor._seed}")
for N in [10_000_000, 100_000_000, 1_000_000_000]:
GlobalCounters.reset()
t = Tensor.rand(N)
t.realize()
print(f"N {N:>20_}, global_ops {GlobalCounters.global_ops:>20_}, global_mem {GlobalCounters.global_mem:>20_}")

View File

@@ -1,154 +0,0 @@
import itertools
from enum import Enum, auto
from collections import defaultdict
from typing import List, Tuple, DefaultDict
from tinygrad.helpers import prod, tqdm
from tinygrad.uop.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.uop.ops import sym_infer
from tinygrad.tensor import Tensor
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
def apply_mop(st: Tensor|ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> Tensor|ShapeTracker:
mop, arg = mop_arg
if mop == MovementOps.RESHAPE:
# ShapeTracker doesn't allow flattening with -1, but MovementOps.RESHAPE requires it
if arg == (-1,): return st.reshape((prod(st.shape),))
return st.reshape(arg)
if mop == MovementOps.PERMUTE: return st.permute(arg)
if mop == MovementOps.EXPAND:
if len(arg) != len(st.shape): st = st.reshape((1,*st.shape))
return st.expand(arg)
if mop == MovementOps.PAD: return st.pad(arg)
if mop == MovementOps.SHRINK: return st.shrink(arg)
if mop == MovementOps.STRIDE:
assert all(x in [-1, 1] for x in arg)
return st.flip(tuple(i for i,x in enumerate(arg) if x == -1))
raise ValueError("invalid mop")
def make_scratch_st(st: ShapeTracker) -> ShapeTracker:
return ShapeTracker.from_shape((get_buffer_size(st.views[0].shape, st.views[0].strides, st.views[0].offset, st.views[0].mask),))
# ShapeTracker to an equivalent series of MovementOps (https://github.com/tinygrad/tinygrad/pull/2216)
def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]:
to_apply:List[Tuple[MovementOps, Tuple]] = []
for i, v in enumerate(st.views):
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
offset = (v.offset or 0) + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
if i: buffer_size = prod(st.views[i-1].shape) - real_offset if real_shape else 1
def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
ordered_shape_strides, order = sort_by_strides(real_real_shape, strides)
to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))])
if strides:
if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),)))
for i, shape_stride in enumerate(ordered_shape_strides):
if i<len(ordered_shape_strides)-1 and shape_stride[1] < ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
remaining_buffer = ordered_shape_strides[i-1][1] if i>0 else buffer_size
to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer)))
to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1)))
to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer)))
to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1]))))
to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1])))
ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1])
else:
to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1]))))
to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1])))
to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))])
if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides)))))
to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides))))
if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides)))
# then, we apply pre expand pads
if v.mask is not None:
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
if any(x != (0,0) for x in pre_expand_pads):
to_apply.append((MovementOps.PAD, pre_expand_pads))
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
# then, we do any expands
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
# lastly, we apply post expand pads
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
scratch_st = make_scratch_st(st)
ret = []
seen = {} # {shapetracker: list of mops to generate that shapetracker}
for mop_arg in to_apply:
scratch_st = apply_mop(scratch_st, mop_arg)
if scratch_st in seen:
ret = seen[scratch_st][:]
else:
if len(ret) and ret[-1][0] == MovementOps.RESHAPE and mop_arg[0] == MovementOps.RESHAPE:
ret[-1] = mop_arg
else:
if mop_arg == (MovementOps.RESHAPE, (-1,)): mop_arg = (MovementOps.RESHAPE, (prod(st.shape),))
ret.append(mop_arg)
seen[scratch_st] = ret[:]
return ret
def get_real_view(shape, strides, offset, mask):
real_shape = tuple(y-x for x,y in mask) if mask else shape
offset = offset + sum(st * (s-1) for s,st in zip(real_shape, strides) if st<0)
real_offset = offset + (sum(x*st for (x,_),st in zip(mask, strides)) if mask else 0)
real_real_shape = [s for s,st in zip(real_shape, strides) if st]
strides = [abs(st) if isinstance(st,int) else st for st in strides if st]
return real_real_shape, strides, real_offset
def get_buffer_size(shape, strides, offset, mask):
real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask)
return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1
def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
if (idxs1:=st1.expr_idxs()) == (idxs2:=st2.expr_idxs()): return True
idx1, valid1 = idxs1
idx2, valid2 = idxs2
# always invalid
if valid1 == 0 and valid2 == 0: return True
var1 = idx1.vars() | valid1.vars()
var2 = idx2.vars() | valid2.vars()
# Maybe there are cases that vars are different yet the sts are the same?
if var1 != var2: return False
# brute force over the vars range
vs = list(var1)
for i, ranges in enumerate(itertools.product(*[range(v.min, v.max+1) for v in vs])):
if i > 1000:
print("WARNING: did not search all possible combinations")
break
var_vals = {k.expr:v for k,v in zip(vs, ranges)}
r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0
r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0
if r1 != r2: return False
return True
c: DefaultDict[int,int] = defaultdict(int)
def test_rebuild(st: ShapeTracker):
rebuilt_st = make_scratch_st(st)
mops = to_movement_ops(st)
c[len(mops)] += 1
for mop_arg in mops: rebuilt_st = apply_mop(rebuilt_st, mop_arg)
rebuilt_st = rebuilt_st.simplify()
# why is the "all(x == 0 for x in rebuilt_st.views[-1].strides)" hack needed?
assert st_equivalent(st, rebuilt_st) or all(x == 0 for x in rebuilt_st.views[-1].strides), f"mismatch {st} {rebuilt_st}"
last_v1 = st.views[-1]
last_v2 = rebuilt_st.views[-1]
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
def test_rebuild_bufferop_st(ast:UOp):
if ast.op is Ops.SHAPETRACKER:
test_rebuild(ast.arg)
for src in ast.src: test_rebuild_bufferop_st(src)
if __name__ == "__main__":
from extra.optimization.helpers import load_worlds, ast_str_to_ast
ast_strs = load_worlds(False, False, True)[:2000]
for ast_str in tqdm(ast_strs):
test_rebuild_bufferop_st(ast_str_to_ast(ast_str))
print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}")

View File

@@ -1,18 +0,0 @@
from tinygrad import Tensor, Device
#N = 1024
N = 32
t = Tensor.rand(N, N, N, device="CPU").realize()
d1 = Device.DEFAULT + ":1"
d2 = Device.DEFAULT + ":2"
d3 = Device.DEFAULT + ":3"
for i in range(3):
t.to_(d1)
t.realize()
# t.to_("CPU")
# t.realize()
t.to_(d2)
t.realize()
t.to_(d3)
t.realize()