From df6cde8a00e110eafead5dea5b5acdc1a0e9065c Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 19 Dec 2025 16:27:37 -0400 Subject: [PATCH] cleanup stale examples/extra (#13764) * cleanup stale files * examples * move those back * old * delete more --- examples/coder.py | 93 --- examples/efficientnet.py | 89 --- examples/flux1.py | 498 --------------- examples/mask_rcnn.py | 299 --------- examples/openelm.py | 118 ---- examples/other_mnist/beautiful_mnist_mlx.py | 55 -- examples/rl/lightupbutton.py | 45 -- examples/serious_mnist.py | 136 ---- examples/simple_conv_bn.py | 17 - examples/so_vits_svc.py | 669 -------------------- examples/sovits_helpers/preprocess.py | 204 ------ {extra => examples/tools}/bandwidth_test.py | 0 {extra => examples/tools}/gpuburn.py | 0 examples/train_efficientnet.py | 104 --- examples/vit.py | 46 -- extra/assembly/assembly.py | 189 ------ extra/assembly/assembly_arm64.py | 177 ------ extra/assembly/assembly_ptx.py | 105 --- extra/assembly/assembly_rdna.py | 203 ------ extra/assembly/ptx/test.py | 23 - extra/augment.py | 42 -- extra/backends/clang_graph.py | 39 -- extra/backends/graph_hip.py | 27 - extra/backends/hsa_driver.py | 143 ----- extra/backends/hsa_graph.py | 171 ----- extra/backends/ops_hsa.py | 275 -------- extra/backends/rdna.py | 127 ---- extra/backends/triton.py | 131 ---- extra/disassemblers/adreno/__init__.py | 22 - extra/disk_read_speed.py | 120 ---- extra/dump_cache.py | 21 - extra/gemm/jax_pmatmul.py | 27 - extra/gemm/mlx_matmul.py | 10 - extra/gemm/tf_gemm.py | 33 - extra/hip_events.py | 12 - extra/junk/sentencepiece_model_pb2.py | 45 -- extra/mcts_search.py | 176 ----- extra/replay_pkl.py | 75 --- extra/resnet18/resnet_mlx.py | 114 ---- extra/resnet18/resnet_tinygrad.py | 109 ---- extra/ring_copy.py | 15 - extra/self_tokenize.py | 47 -- extra/threefry.py | 16 - extra/to_movement_ops.py | 154 ----- extra/transfer_speed.py | 18 - 45 files changed, 5039 deletions(-) delete mode 100644 examples/coder.py delete mode 100644 examples/efficientnet.py delete mode 100644 examples/flux1.py delete mode 100644 examples/mask_rcnn.py delete mode 100644 examples/openelm.py delete mode 100644 examples/other_mnist/beautiful_mnist_mlx.py delete mode 100644 examples/rl/lightupbutton.py delete mode 100644 examples/serious_mnist.py delete mode 100644 examples/simple_conv_bn.py delete mode 100644 examples/so_vits_svc.py delete mode 100644 examples/sovits_helpers/preprocess.py rename {extra => examples/tools}/bandwidth_test.py (100%) rename {extra => examples/tools}/gpuburn.py (100%) delete mode 100644 examples/train_efficientnet.py delete mode 100644 examples/vit.py delete mode 100644 extra/assembly/assembly.py delete mode 100644 extra/assembly/assembly_arm64.py delete mode 100644 extra/assembly/assembly_ptx.py delete mode 100644 extra/assembly/assembly_rdna.py delete mode 100644 extra/assembly/ptx/test.py delete mode 100644 extra/augment.py delete mode 100644 extra/backends/clang_graph.py delete mode 100644 extra/backends/graph_hip.py delete mode 100644 extra/backends/hsa_driver.py delete mode 100644 extra/backends/hsa_graph.py delete mode 100644 extra/backends/ops_hsa.py delete mode 100644 extra/backends/rdna.py delete mode 100644 extra/backends/triton.py delete mode 100644 extra/disassemblers/adreno/__init__.py delete mode 100644 extra/disk_read_speed.py delete mode 100644 extra/dump_cache.py delete mode 100755 extra/gemm/jax_pmatmul.py delete mode 100644 extra/gemm/mlx_matmul.py delete mode 100644 extra/gemm/tf_gemm.py delete mode 100644 extra/hip_events.py delete mode 100644 extra/junk/sentencepiece_model_pb2.py delete mode 100644 extra/mcts_search.py delete mode 100644 extra/replay_pkl.py delete mode 100644 extra/resnet18/resnet_mlx.py delete mode 100644 extra/resnet18/resnet_tinygrad.py delete mode 100644 extra/ring_copy.py delete mode 100644 extra/self_tokenize.py delete mode 100644 extra/threefry.py delete mode 100644 extra/to_movement_ops.py delete mode 100644 extra/transfer_speed.py diff --git a/examples/coder.py b/examples/coder.py deleted file mode 100644 index 8a8d3c5b78..0000000000 --- a/examples/coder.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python3 -import os, sys, traceback -sys.path.append(os.getcwd()) - -from io import StringIO -from contextlib import redirect_stdout -from tinygrad import Tensor, nn -from tinygrad.helpers import Timing, colored, getenv, fetch -from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16 -from sentencepiece import SentencePieceProcessor - -def create_fixed_tokenizer(output_file): - print("creating fixed tokenizer") - import extra.junk.sentencepiece_model_pb2 as spb2 - mp = spb2.ModelProto() - mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes()) - mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0)) - mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0)) - with open(output_file, "wb") as f: - f.write(mp.SerializeToString()) - -# example: -# echo -en "write 2+2\nwrite hello world\ny\n" | TEMP=0 python3 examples/coder.py - -if __name__ == "__main__": - # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json - with Timing("create model: "): - model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096, jit=getenv("JIT", 1)) - - with Timing("download weights: "): - part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true")) - part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true")) - - with Timing("weights -> model: "): - nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part1, 32, 32, 8)), strict=False) - nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part2, 32, 32, 8)), strict=False) - - if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model") - spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model") - - # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json - # "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - IM_END = 32000 - IM_START = 32001 - def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n") - def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n") - def output(outputted, toks, color): - cur = spp.decode(toks)[len(outputted):] - sys.stdout.write(colored(cur, color)) - sys.stdout.flush() - outputted += cur - return outputted - - # *** app below this line *** - - toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input") - - PROMPT = getenv("PROMPT", 1) - temperature = getenv("TEMP", 0.7) - - start_pos = 0 - outputted = output("", toks, "green") - turn = True - while 1: - if PROMPT: - toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant") - else: - toks += start_prompt("user" if turn else "assistant") - turn = not turn - old_output_len = len(outputted) - while 1: - tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item() - start_pos = len(toks) - toks.append(tok) - outputted = output(outputted, toks, "blue" if not turn else "cyan") - if tok == IM_END: break - if tok == spp.eos_id(): break - new_output = outputted[old_output_len:] - - if new_output.endswith("```") and '```python\n' in new_output: - python_code = new_output.split('```python\n')[1].split("```")[0] - # AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things. - if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y': - my_stdout = StringIO() - try: - with redirect_stdout(my_stdout): exec(python_code) - result = my_stdout.getvalue() - except Exception as e: - result = ''.join(traceback.format_exception_only(e)) - toks += spp.encode(f"\nOutput:\n```\n{result}```") - outputted = output(outputted, toks, "yellow") - old_output_len = len(outputted) - print("") \ No newline at end of file diff --git a/examples/efficientnet.py b/examples/efficientnet.py deleted file mode 100644 index e8e8bd916b..0000000000 --- a/examples/efficientnet.py +++ /dev/null @@ -1,89 +0,0 @@ -# load weights from -# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth -# a rough copy of -# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py -import sys -import ast -import time -import numpy as np -from PIL import Image -from tinygrad.tensor import Tensor -from tinygrad.helpers import getenv, fetch, Timing -from tinygrad.engine.jit import TinyJit -from extra.models.efficientnet import EfficientNet -np.set_printoptions(suppress=True) - -# TODO: you should be able to put these in the jitted function -bias = Tensor([0.485, 0.456, 0.406]) -scale = Tensor([0.229, 0.224, 0.225]) - -@TinyJit -def _infer(model, img): - img = img.permute((2,0,1)) - img = img / 255.0 - img = img - bias.reshape((1,-1,1,1)) - img = img / scale.reshape((1,-1,1,1)) - return model.forward(img).realize() - -def infer(model, img): - # preprocess image - aspect_ratio = img.size[0] / img.size[1] - img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0)))) - - img = np.array(img) - y0,x0=(np.asarray(img.shape)[:2]-224)//2 - retimg = img = img[y0:y0+224, x0:x0+224] - - # if you want to look at the image - """ - import matplotlib.pyplot as plt - plt.imshow(img) - plt.show() - """ - - # run the net - out = _infer(model, Tensor(img.astype("float32"))).numpy() - - # if you want to look at the outputs - """ - import matplotlib.pyplot as plt - plt.plot(out[0]) - plt.show() - """ - return out, retimg - -if __name__ == "__main__": - # instantiate my net - model = EfficientNet(getenv("NUM", 0)) - model.load_from_pretrained() - - # category labels - lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text()) - - # load image and preprocess - url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg" - if url == 'webcam': - import cv2 - cap = cv2.VideoCapture(0) - cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) - while 1: - _ = cap.grab() # discard one frame to circumvent capture buffering - ret, frame = cap.read() - img = Image.fromarray(frame[:, :, [2,1,0]]) - lt = time.monotonic_ns() - out, retimg = infer(model, img) - print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)]) - SCALE = 3 - simg = cv2.resize(retimg, (224*SCALE, 224*SCALE)) - retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR) - cv2.imshow('capture', retimg) - if cv2.waitKey(1) & 0xFF == ord('q'): - break - cap.release() - cv2.destroyAllWindows() - else: - img = Image.open(fetch(url)) - for i in range(getenv("CNT", 1)): - with Timing("did inference in "): - out, _ = infer(model, img) - print(np.argmax(out), np.max(out), lbls[np.argmax(out)]) diff --git a/examples/flux1.py b/examples/flux1.py deleted file mode 100644 index b0fcff13b1..0000000000 --- a/examples/flux1.py +++ /dev/null @@ -1,498 +0,0 @@ -# pip3 install sentencepiece - -# This file incorporates code from the following: -# Github Name | License | Link -# black-forest-labs/flux | Apache | https://github.com/black-forest-labs/flux/tree/main/model_licenses - -from tinygrad import Tensor, nn, dtypes, TinyJit -from tinygrad.nn.state import safe_load, load_state_dict -from tinygrad.helpers import fetch, tqdm, colored -from sdxl import FirstStage -from extra.models.clip import FrozenClosedClipEmbedder -from extra.models.t5 import T5Embedder -import numpy as np - -import math, time, argparse, tempfile -from typing import List, Dict, Optional, Union, Tuple, Callable -from dataclasses import dataclass -from pathlib import Path -from PIL import Image - -urls:dict = { - "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors", - "flux-dev": "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev.sft", - "ae": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors", - "T5_1_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors", - "T5_2_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors", - "T5_tokenizer": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model", - "clip": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder/model.safetensors" -} - -def tensor_identity(x:Tensor) -> Tensor: return x - -class AutoEncoder: - def __init__(self, scale_factor:float, shift_factor:float): - self.decoder = FirstStage.Decoder(128, 3, 3, 16, [1, 2, 4, 4], 2, 256) - self.scale_factor = scale_factor - self.shift_factor = shift_factor - - def decode(self, z:Tensor) -> Tensor: - z = z / self.scale_factor + self.shift_factor - return self.decoder(z) - -# Conditioner -class ClipEmbedder(FrozenClosedClipEmbedder): - def __call__(self, texts:Union[str, List[str], Tensor]) -> Tensor: - if isinstance(texts, str): texts = [texts] - assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}" - tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0) - return self.transformer.text_model(tokens.reshape(len(texts),-1))[:, tokens.argmax(-1)] - -# https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py -def attention(q:Tensor, k:Tensor, v:Tensor, pe:Tensor) -> Tensor: - q, k = apply_rope(q, k, pe) - x = Tensor.scaled_dot_product_attention(q, k, v) - return x.rearrange("B H L D -> B L (H D)") - -def rope(pos:Tensor, dim:int, theta:int) -> Tensor: - assert dim % 2 == 0 - scale = Tensor.arange(0, dim, 2, dtype=dtypes.float32, device=pos.device) / dim # NOTE: this is torch.float64 in reference implementation - omega = 1.0 / (theta**scale) - out = Tensor.einsum("...n,d->...nd", pos, omega) - out = Tensor.stack(Tensor.cos(out), -Tensor.sin(out), Tensor.sin(out), Tensor.cos(out), dim=-1) - out = out.rearrange("b n d (i j) -> b n d i j", i=2, j=2) - return out.float() - -def apply_rope(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]: - xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).cast(xq.dtype), xk_out.reshape(*xk.shape).cast(xk.dtype) - - -# https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py -class EmbedND: - def __init__(self, dim:int, theta:int, axes_dim:List[int]): - self.dim = dim - self.theta = theta - self.axes_dim = axes_dim - - def __call__(self, ids:Tensor) -> Tensor: - n_axes = ids.shape[-1] - emb = Tensor.cat(*[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) - return emb.unsqueeze(1) - -class MLPEmbedder: - def __init__(self, in_dim:int, hidden_dim:int): - self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) - self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) - - def __call__(self, x:Tensor) -> Tensor: - return self.out_layer(self.in_layer(x).silu()) - -class QKNorm: - def __init__(self, dim:int): - self.query_norm = nn.RMSNorm(dim) - self.key_norm = nn.RMSNorm(dim) - - def __call__(self, q:Tensor, k:Tensor) -> Tuple[Tensor, Tensor]: - return self.query_norm(q), self.key_norm(k) - -class SelfAttention: - def __init__(self, dim:int, num_heads:int = 8, qkv_bias:bool = False): - self.num_heads = num_heads - head_dim = dim // num_heads - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.norm = QKNorm(head_dim) - self.proj = nn.Linear(dim, dim) - - def __call__(self, x:Tensor, pe:Tensor) -> Tensor: - qkv = self.qkv(x) - q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads) - q, k = self.norm(q, k) - x = attention(q, k, v, pe=pe) - return self.proj(x) - -@dataclass -class ModulationOut: - shift:Tensor - scale:Tensor - gate:Tensor - -class Modulation: - def __init__(self, dim:int, double:bool): - self.is_double = double - self.multiplier = 6 if double else 3 - self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) - - def __call__(self, vec:Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]: - out = self.lin(vec.silu())[:, None, :].chunk(self.multiplier, dim=-1) - return ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None - -class DoubleStreamBlock: - def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float, qkv_bias:bool = False): - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.num_heads = num_heads - self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True) - self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) - - self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.img_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)] - - self.txt_mod = Modulation(hidden_size, double=True) - self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) - - self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.txt_mlp = [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)] - - def __call__(self, img:Tensor, txt:Tensor, vec:Tensor, pe:Tensor) -> tuple[Tensor, Tensor]: - img_mod1, img_mod2 = self.img_mod(vec) - txt_mod1, txt_mod2 = self.txt_mod(vec) - assert img_mod2 is not None and txt_mod2 is not None - # prepare image for attention - img_modulated = self.img_norm1(img) - img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift - img_qkv = self.img_attn.qkv(img_modulated) - img_q, img_k, img_v = img_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads) - img_q, img_k = self.img_attn.norm(img_q, img_k) - - # prepare txt for attention - txt_modulated = self.txt_norm1(txt) - txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift - txt_qkv = self.txt_attn.qkv(txt_modulated) - txt_q, txt_k, txt_v = txt_qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads) - txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k) - - # run actual attention - q = Tensor.cat(txt_q, img_q, dim=2) - k = Tensor.cat(txt_k, img_k, dim=2) - v = Tensor.cat(txt_v, img_v, dim=2) - - attn = attention(q, k, v, pe=pe) - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - - # calculate the img bloks - img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * ((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift).sequential(self.img_mlp) - - # calculate the txt bloks - txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt = txt + txt_mod2.gate * ((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift).sequential(self.txt_mlp) - return img, txt - - -class SingleStreamBlock: - """ - A DiT block with parallel linear layers as described in - https://arxiv.org/abs/2302.05442 and adapted modulation interface. - """ - - def __init__(self,hidden_size:int, num_heads:int, mlp_ratio:float=4.0, qk_scale:Optional[float]=None): - self.hidden_dim = hidden_size - self.num_heads = num_heads - head_dim = hidden_size // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.mlp_hidden_dim = int(hidden_size * mlp_ratio) - # qkv and mlp_in - self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) - # proj and mlp_out - self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) - - self.norm = QKNorm(head_dim) - - self.hidden_size = hidden_size - self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - - self.mlp_act = Tensor.gelu - self.modulation = Modulation(hidden_size, double=False) - - def __call__(self, x:Tensor, vec:Tensor, pe:Tensor) -> Tensor: - mod, _ = self.modulation(vec) - x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift - qkv, mlp = Tensor.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads) - q, k = self.norm(q, k) - - # compute attention - attn = attention(q, k, v, pe=pe) - # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(Tensor.cat(attn, self.mlp_act(mlp), dim=2)) - return x + mod.gate * output - - -class LastLayer: - def __init__(self, hidden_size:int, patch_size:int, out_channels:int): - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - self.adaLN_modulation:List[Callable[[Tensor], Tensor]] = [Tensor.silu, nn.Linear(hidden_size, 2 * hidden_size, bias=True)] - - def __call__(self, x:Tensor, vec:Tensor) -> Tensor: - shift, scale = vec.sequential(self.adaLN_modulation).chunk(2, dim=1) - x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] - return self.linear(x) - -def timestep_embedding(t:Tensor, dim:int, max_period:int=10000, time_factor:float=1000.0) -> Tensor: - """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. - """ - t = time_factor * t - half = dim // 2 - freqs = Tensor.exp(-math.log(max_period) * Tensor.arange(0, stop=half, dtype=dtypes.float32) / half).to(t.device) - - args = t[:, None].float() * freqs[None] - embedding = Tensor.cat(Tensor.cos(args), Tensor.sin(args), dim=-1) - if dim % 2: embedding = Tensor.cat(*[embedding, Tensor.zeros_like(embedding[:, :1])], dim=-1) - if Tensor.is_floating_point(t): embedding = embedding.cast(t.dtype) - return embedding - -# https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py -class Flux: - """ - Transformer model for flow matching on sequences. - """ - - def __init__( - self, - guidance_embed:bool, - in_channels:int = 64, - vec_in_dim:int = 768, - context_in_dim:int = 4096, - hidden_size:int = 3072, - mlp_ratio:float = 4.0, - num_heads:int = 24, - depth:int = 19, - depth_single_blocks:int = 38, - axes_dim:Optional[List[int]] = None, - theta:int = 10_000, - qkv_bias:bool = True, - ): - - axes_dim = axes_dim or [16, 56, 56] - self.guidance_embed = guidance_embed - self.in_channels = in_channels - self.out_channels = self.in_channels - if hidden_size % num_heads != 0: - raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}") - pe_dim = hidden_size // num_heads - if sum(axes_dim) != pe_dim: - raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") - self.hidden_size = hidden_size - self.num_heads = num_heads - self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) - self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size) - self.guidance_in:Callable[[Tensor], Tensor] = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else tensor_identity - self.txt_in = nn.Linear(context_in_dim, self.hidden_size) - - self.double_blocks = [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) for _ in range(depth)] - self.single_blocks = [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)] - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) - - def __call__(self, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, timesteps:Tensor, y:Tensor, guidance:Optional[Tensor] = None) -> Tensor: - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") - # running on sequences img - img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256)) - if self.guidance_embed: - if guidance is None: - raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) - vec = vec + self.vector_in(y) - txt = self.txt_in(txt) - ids = Tensor.cat(txt_ids, img_ids, dim=1) - pe = self.pe_embedder(ids) - for double_block in self.double_blocks: - img, txt = double_block(img=img, txt=txt, vec=vec, pe=pe) - - img = Tensor.cat(txt, img, dim=1) - for single_block in self.single_blocks: - img = single_block(img, vec=vec, pe=pe) - - img = img[:, txt.shape[1] :, ...] - - return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) - -# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py -def load_flow_model(name:str, model_path:str): - # Loading Flux - print("Init model") - model = Flux(guidance_embed=(name != "flux-schnell")) - if not model_path: model_path = fetch(urls[name]) - state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(model_path).items()} - load_state_dict(model, state_dict) - return model - -def load_T5(max_length:int=512): - # max length 64, 128, 256 and 512 should work (if your sequence is short enough) - print("Init T5") - T5 = T5Embedder(max_length, fetch(urls["T5_tokenizer"])) - pt_1 = fetch(urls["T5_1_of_2"]) - pt_2 = fetch(urls["T5_2_of_2"]) - load_state_dict(T5.encoder, safe_load(pt_1) | safe_load(pt_2), strict=False) - return T5 - -def load_clip(): - print("Init Clip") - clip = ClipEmbedder() - load_state_dict(clip.transformer, safe_load(fetch(urls["clip"]))) - return clip - -def load_ae() -> AutoEncoder: - # Loading the autoencoder - print("Init AE") - ae = AutoEncoder(0.3611, 0.1159) - load_state_dict(ae, safe_load(fetch(urls["ae"]))) - return ae - -# https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py -def prepare(T5:T5Embedder, clip:ClipEmbedder, img:Tensor, prompt:Union[str, List[str]]) -> Dict[str, Tensor]: - bs, _, h, w = img.shape - if bs == 1 and not isinstance(prompt, str): - bs = len(prompt) - - img = img.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img.shape[0] == 1 and bs > 1: - img = img.expand((bs, *img.shape[1:])) - - img_ids = Tensor.zeros(h // 2, w // 2, 3).contiguous() - img_ids[..., 1] = img_ids[..., 1] + Tensor.arange(h // 2)[:, None] - img_ids[..., 2] = img_ids[..., 2] + Tensor.arange(w // 2)[None, :] - img_ids = img_ids.rearrange("h w c -> 1 (h w) c") - img_ids = img_ids.expand((bs, *img_ids.shape[1:])) - - if isinstance(prompt, str): - prompt = [prompt] - txt = T5(prompt).realize() - if txt.shape[0] == 1 and bs > 1: - txt = txt.expand((bs, *txt.shape[1:])) - txt_ids = Tensor.zeros(bs, txt.shape[1], 3) - - vec = clip(prompt).realize() - if vec.shape[0] == 1 and bs > 1: - vec = vec.expand((bs, *vec.shape[1:])) - - return {"img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device)} - - -def get_schedule(num_steps:int, image_seq_len:int, base_shift:float=0.5, max_shift:float=1.15, shift:bool=True) -> List[float]: - # extra step for zero - step_size = -1.0 / num_steps - timesteps = Tensor.arange(1, 0 + step_size, step_size) - - # shifting the schedule to favor high timesteps for higher signal images - if shift: - # estimate mu based on linear estimation between two points - mu = 0.5 + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256) - timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1)) - return timesteps.tolist() - -@TinyJit -def run(model, *args): return model(*args).realize() - -def denoise(model, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, vec:Tensor, timesteps:List[float], guidance:float=4.0) -> Tensor: - # this is ignored for schnell - guidance_vec = Tensor((guidance,), device=img.device, dtype=img.dtype).expand((img.shape[0],)) - for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:])), "Denoising"): - t_vec = Tensor((t_curr,), device=img.device, dtype=img.dtype).expand((img.shape[0],)) - pred = run(model, img, img_ids, txt, txt_ids, t_vec, vec, guidance_vec) - img = img + (t_prev - t_curr) * pred - - return img - -def unpack(x:Tensor, height:int, width:int) -> Tensor: - return x.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2) - -# https://github.com/black-forest-labs/flux/blob/main/src/flux/cli.py -if __name__ == "__main__": - default_prompt = "bananas and a can of coke" - parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load") - parser.add_argument("--model_path", type=str, default="", help="path of the model file") - parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)") - parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)") - parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling") - parser.add_argument("--prompt", type=str, default=default_prompt, help="Prompt used for sampling") - parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename") - parser.add_argument("--num_steps", type=int, default=None, help="number of sampling steps (default 4 for schnell, 50 for guidance distilled)") #noqa:E501 - parser.add_argument("--guidance", type=float, default=3.5, help="guidance value used for guidance distillation") - parser.add_argument("--output_dir", type=str, default="output", help="output directory") - args = parser.parse_args() - - if args.name not in ["flux-schnell", "flux-dev"]: - raise ValueError(f"Got unknown model name: {args.name}, chose from flux-schnell and flux-dev") - - if args.num_steps is None: - args.num_steps = 4 if args.name == "flux-schnell" else 50 - - # allow for packing and conversion to latent space - height = 16 * (args.height // 16) - width = 16 * (args.width // 16) - - if args.seed is None: args.seed = Tensor._seed - else: Tensor.manual_seed(args.seed) - - print(f"Generating with seed {args.seed}:\n{args.prompt}") - t0 = time.perf_counter() - - # prepare input noise - x = Tensor.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype="bfloat16") - - # load text embedders - T5 = load_T5(max_length=256 if args.name == "flux-schnell" else 512) - clip = load_clip() - - # embed text to get inputs for model - inp = prepare(T5, clip, x, prompt=args.prompt) - timesteps = get_schedule(args.num_steps, inp["img"].shape[1], shift=(args.name != "flux-schnell")) - - # done with text embedders - del T5, clip - - # load model - model = load_flow_model(args.name, args.model_path) - - # denoise initial noise - x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance) - - # done with model - del model, run - - # load autoencoder - ae = load_ae() - - # decode latents to pixel space - x = unpack(x.float(), height, width) - x = ae.decode(x).realize() - - t1 = time.perf_counter() - print(f"Done in {t1 - t0:.1f}s. Saving {args.out}") - - # bring into PIL format and save - x = x.clamp(-1, 1) - x = x[0].rearrange("c h w -> h w c") - x = (127.5 * (x + 1.0)).cast("uint8") - - img = Image.fromarray(x.numpy()) - - img.save(args.out) - - # validation! - if args.prompt == default_prompt and args.name=="flux-schnell" and args.seed == 0 and args.width == args.height == 512: - ref_image = Tensor(np.array(Image.open("examples/flux1_seed0.png"))) - distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item() - assert distance < 4e-3, colored(f"validation failed with {distance=}", "red") - print(colored(f"output validated with {distance=}", "green")) \ No newline at end of file diff --git a/examples/mask_rcnn.py b/examples/mask_rcnn.py deleted file mode 100644 index 00d4c240c3..0000000000 --- a/examples/mask_rcnn.py +++ /dev/null @@ -1,299 +0,0 @@ -from extra.models.mask_rcnn import MaskRCNN -from extra.models.resnet import ResNet -from extra.models.mask_rcnn import BoxList -from torch.nn import functional as F -from torchvision import transforms as T -from torchvision.transforms import functional as Ft -import random -from tinygrad.tensor import Tensor -from PIL import Image -import numpy as np -import torch -import argparse -import cv2 - - -class Resize: - def __init__(self, min_size, max_size): - if not isinstance(min_size, (list, tuple)): - min_size = (min_size,) - self.min_size = min_size - self.max_size = max_size - - # modified from torchvision to add support for max size - def get_size(self, image_size): - w, h = image_size - size = random.choice(self.min_size) - max_size = self.max_size - if max_size is not None: - min_original_size = float(min((w, h))) - max_original_size = float(max((w, h))) - if max_original_size / min_original_size * size > max_size: - size = int(round(max_size * min_original_size / max_original_size)) - - if (w <= h and w == size) or (h <= w and h == size): - return (h, w) - - if w < h: - ow = size - oh = int(size * h / w) - else: - oh = size - ow = int(size * w / h) - - return (oh, ow) - - def __call__(self, image): - size = self.get_size(image.size) - image = Ft.resize(image, size) - return image - - -class Normalize: - def __init__(self, mean, std, to_bgr255=True): - self.mean = mean - self.std = std - self.to_bgr255 = to_bgr255 - - def __call__(self, image): - if self.to_bgr255: - image = image[[2, 1, 0]] * 255 - else: - image = image[[0, 1, 2]] * 255 - image = Ft.normalize(image, mean=self.mean, std=self.std) - return image - -transforms = lambda size_scale: T.Compose( - [ - Resize(int(800*size_scale), int(1333*size_scale)), - T.ToTensor(), - Normalize( - mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True - ), - ] -) - -def expand_boxes(boxes, scale): - w_half = (boxes[:, 2] - boxes[:, 0]) * .5 - h_half = (boxes[:, 3] - boxes[:, 1]) * .5 - x_c = (boxes[:, 2] + boxes[:, 0]) * .5 - y_c = (boxes[:, 3] + boxes[:, 1]) * .5 - - w_half *= scale - h_half *= scale - - boxes_exp = torch.zeros_like(boxes) - boxes_exp[:, 0] = x_c - w_half - boxes_exp[:, 2] = x_c + w_half - boxes_exp[:, 1] = y_c - h_half - boxes_exp[:, 3] = y_c + h_half - return boxes_exp - - -def expand_masks(mask, padding): - N = mask.shape[0] - M = mask.shape[-1] - pad2 = 2 * padding - scale = float(M + pad2) / M - padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2)) - padded_mask[:, :, padding:-padding, padding:-padding] = mask - return padded_mask, scale - - -def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1): - # TODO: remove torch - mask = torch.tensor(mask.numpy()) - box = torch.tensor(box.numpy()) - padded_mask, scale = expand_masks(mask[None], padding=padding) - mask = padded_mask[0, 0] - box = expand_boxes(box[None], scale)[0] - box = box.to(dtype=torch.int32) - - TO_REMOVE = 1 - w = int(box[2] - box[0] + TO_REMOVE) - h = int(box[3] - box[1] + TO_REMOVE) - w = max(w, 1) - h = max(h, 1) - - mask = mask.expand((1, 1, -1, -1)) - - mask = mask.to(torch.float32) - mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) - mask = mask[0][0] - - if thresh >= 0: - mask = mask > thresh - else: - mask = (mask * 255).to(torch.uint8) - - im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8) - x_0 = max(box[0], 0) - x_1 = min(box[2] + 1, im_w) - y_0 = max(box[1], 0) - y_1 = min(box[3] + 1, im_h) - - im_mask[y_0:y_1, x_0:x_1] = mask[ - (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0]) - ] - return im_mask - - -class Masker: - def __init__(self, threshold=0.5, padding=1): - self.threshold = threshold - self.padding = padding - - def forward_single_image(self, masks, boxes): - boxes = boxes.convert("xyxy") - im_w, im_h = boxes.size - res = [ - paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding) - for mask, box in zip(masks, boxes.bbox) - ] - if len(res) > 0: - res = torch.stack(*res, dim=0)[:, None] - else: - res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1])) - return Tensor(res.numpy()) - - def __call__(self, masks, boxes): - if isinstance(boxes, BoxList): - boxes = [boxes] - - results = [] - for mask, box in zip(masks, boxes): - result = self.forward_single_image(mask, box) - results.append(result) - return results - - -masker = Masker(threshold=0.5, padding=1) - -def select_top_predictions(predictions, confidence_threshold=0.9): - scores = predictions.get_field("scores").numpy() - keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold] - return predictions[keep] - -def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0): - image = transforms(size_scale)(original_image).numpy() - image = Tensor(image, requires_grad=False) - predictions = model(image) - prediction = predictions[0] - prediction = select_top_predictions(prediction, confidence_threshold) - width, height = original_image.size - prediction = prediction.resize((width, height)) - - if prediction.has_field("mask"): - masks = prediction.get_field("mask") - masks = masker([masks], [prediction])[0] - prediction.add_field("mask", masks) - return prediction - -def compute_prediction_batched(batch, model, size_scale=1.0): - imgs = [] - for img in batch: - imgs.append(transforms(size_scale)(img).numpy()) - image = [Tensor(image, requires_grad=False) for image in imgs] - predictions = model(image) - del image - return predictions - -palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - -def findContours(*args, **kwargs): - if cv2.__version__.startswith('4'): - contours, hierarchy = cv2.findContours(*args, **kwargs) - elif cv2.__version__.startswith('3'): - _, contours, hierarchy = cv2.findContours(*args, **kwargs) - return contours, hierarchy - -def compute_colors_for_labels(labels): - l = labels[:, None] - colors = l * palette - colors = (colors % 255).astype("uint8") - return colors - -def overlay_mask(image, predictions): - image = np.asarray(image) - masks = predictions.get_field("mask").numpy() - labels = predictions.get_field("labels").numpy() - - colors = compute_colors_for_labels(labels).tolist() - - for mask, color in zip(masks, colors): - thresh = mask[0, :, :, None] - contours, hierarchy = findContours( - thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE - ) - image = cv2.drawContours(image, contours, -1, color, 3) - - composite = image - - return composite - -CATEGORIES = [ - "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", - "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", - "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", - "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", - "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", - "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", - "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", - "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", -] - -def overlay_boxes(image, predictions): - labels = predictions.get_field("labels").numpy() - boxes = predictions.bbox - image = np.asarray(image) - colors = compute_colors_for_labels(labels).tolist() - - for box, color in zip(boxes, colors): - box = torch.tensor(box.numpy()) - box = box.to(torch.int64) - top_left, bottom_right = box[:2].tolist(), box[2:].tolist() - image = cv2.rectangle( - image, tuple(top_left), tuple(bottom_right), tuple(color), 1 - ) - - return image - -def overlay_class_names(image, predictions): - scores = predictions.get_field("scores").numpy().tolist() - labels = predictions.get_field("labels").numpy().tolist() - labels = [CATEGORIES[int(i)] for i in labels] - boxes = predictions.bbox.numpy() - image = np.asarray(image) - template = "{}: {:.2f}" - for box, score, label in zip(boxes, scores, labels): - x, y = box[:2] - s = template.format(label, score) - x, y = int(x), int(y) - cv2.putText( - image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1 - ) - - return image - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--image', type=str, help="Path of the image to run") - parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold") - parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier") - parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename") - args = parser.parse_args() - - resnet = ResNet(50, num_classes=None, stride_in_1x1=True) - model_tiny = MaskRCNN(resnet) - model_tiny.load_from_pretrained() - img = Image.open(args.image) - top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale) - bbox_image = overlay_boxes(img, top_result_tiny) - mask_image = overlay_mask(bbox_image, top_result_tiny) - final_image = overlay_class_names(mask_image, top_result_tiny) - - im = Image.fromarray(final_image) - print(f"saving {args.out}") - im.save(args.out) - im.show() diff --git a/examples/openelm.py b/examples/openelm.py deleted file mode 100644 index 71e9cd89df..0000000000 --- a/examples/openelm.py +++ /dev/null @@ -1,118 +0,0 @@ -import json, pprint -from tinygrad import fetch, nn, Tensor -from tinygrad.helpers import DEBUG - -class FeedForward: - def __init__(self, model_dim, intermediate_dim): - self.proj_1 = nn.Linear(model_dim, 2*intermediate_dim, bias=False) - self.proj_2 = nn.Linear(intermediate_dim, model_dim, bias=False) - - def __call__(self, x): - y_12 = self.proj_1(x) - y_1, y_2 = y_12.chunk(2, dim=-1) - return self.proj_2(y_1.silu() * y_2) - -# NOTE: this RoPE doesn't match LLaMA's? -def _rotate_half(x: Tensor) -> Tensor: - x1, x2 = x.chunk(2, dim=-1) - return Tensor.cat(-x2, x1, dim=-1) - -def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor: - return (x * pos_cos) + (_rotate_half(x) * pos_sin) - -class Attention: - def __init__(self, model_dim, num_query_heads, num_kv_heads, head_dim): - self.qkv_proj = nn.Linear(model_dim, (num_query_heads + num_kv_heads*2) * head_dim, bias=False) - self.num_query_heads, self.num_kv_heads = num_query_heads, num_kv_heads - self.head_dim = head_dim - self.q_norm = nn.RMSNorm(head_dim) - self.k_norm = nn.RMSNorm(head_dim) - self.out_proj = nn.Linear(num_query_heads * head_dim, model_dim, bias=False) - - def __call__(self, x:Tensor) -> Tensor: - batch_size, seq_len, embed_dim = x.shape - qkv = self.qkv_proj(x) - qkv = qkv.reshape(batch_size, seq_len, self.num_query_heads+self.num_kv_heads*2, self.head_dim).transpose(1, 2) - xq,xk,xv = qkv.split([self.num_query_heads, self.num_kv_heads, self.num_kv_heads], dim=1) - xq = self.q_norm(xq) - xk = self.k_norm(xk) - - # add positional embedding (how many kernels is this?) - freq_constant = 10000 - inv_freq = 1.0 / (freq_constant ** (Tensor.arange(0, self.head_dim, 2) / self.head_dim)) - pos_index_theta = Tensor.einsum("i,j->ij", Tensor.arange(seq_len), inv_freq) - emb = Tensor.cat(pos_index_theta, pos_index_theta, dim=-1) - cos_emb, sin_emb = emb.cos()[None, None, :, :], emb.sin()[None, None, :, :] - xq = _apply_rotary_pos_emb(xq, sin_emb, cos_emb) - xk = _apply_rotary_pos_emb(xk, sin_emb, cos_emb) - - # grouped-query attention - num_groups = self.num_query_heads // self.num_kv_heads - xk = xk.repeat_interleave(num_groups, dim=1) - xv = xv.repeat_interleave(num_groups, dim=1) - - # masked attention - #start_pos = 0 - #mask = Tensor.full((1, 1, seq_len, start_pos+seq_len), float("-inf"), dtype=xq.dtype, device=xq.device).triu(start_pos+1) - #attn_output = xq.scaled_dot_product_attention(xk, xv, mask).transpose(1, 2) - - # causal is fine, no mask needed - attn_output = xq.scaled_dot_product_attention(xk, xv, is_causal=True).transpose(1, 2) - return self.out_proj(attn_output.reshape(batch_size, seq_len, self.num_query_heads * self.head_dim)) - -class Layer: - def __init__(self, model_dim, intermediate_dim, num_query_heads, num_kv_heads, head_dim): - self.ffn = FeedForward(model_dim, intermediate_dim) - self.attn = Attention(model_dim, num_query_heads, num_kv_heads, head_dim) - self.ffn_norm = nn.RMSNorm(model_dim) - self.attn_norm = nn.RMSNorm(model_dim) - - def __call__(self, x:Tensor) -> Tensor: # (batch, seq_len, embed_dim) - x = x + self.attn(self.attn_norm(x)) - x = x + self.ffn(self.ffn_norm(x)) - return x - -# stupidly complex -def make_divisible(v, divisor): - new_v = max(divisor, int(v + divisor / 2) // divisor * divisor) - if new_v < 0.9 * v: new_v += divisor - return new_v - -class Transformer: - def __init__(self, cfg): - if DEBUG >= 3: pprint.pp(cfg) - self.layers = [Layer(cfg['model_dim'], make_divisible(int(cfg["model_dim"] * cfg['ffn_multipliers'][i]), cfg['ffn_dim_divisor']), - cfg['num_query_heads'][i], cfg['num_kv_heads'][i], cfg['head_dim']) for i in range(cfg['num_transformer_layers'])] - self.norm = nn.RMSNorm(cfg['model_dim']) - self.token_embeddings = nn.Embedding(cfg['vocab_size'], cfg['model_dim']) - - def __call__(self, tokens:Tensor): - # _bsz, seqlen = tokens.shape - x = self.token_embeddings(tokens) - for l in self.layers: x = l(x) - return self.norm(x) @ self.token_embeddings.weight.T - -if __name__ == "__main__": - #model_name = "OpenELM-270M-Instruct" - model_name = "OpenELM-270M" # this is fp32 - model = Transformer(json.loads(fetch(f"https://huggingface.co/apple/{model_name}/resolve/main/config.json?download=true").read_bytes())) - weights = nn.state.safe_load(fetch(f"https://huggingface.co/apple/{model_name}/resolve/main/model.safetensors?download=true")) - if DEBUG >= 3: - for k, v in weights.items(): print(k, v.shape) - nn.state.load_state_dict(model, {k.removeprefix("transformer."):v for k,v in weights.items()}) - - from sentencepiece import SentencePieceProcessor - tokenizer = SentencePieceProcessor(fetch("https://github.com/karpathy/llama2.c/raw/master/tokenizer.model").as_posix()) - toks = [tokenizer.bos_id()] + tokenizer.encode("Some car brands include") - for i in range(100): - ttoks = Tensor([toks]) - out = model(ttoks).realize() - t0 = out[0].argmax(axis=-1).tolist() - toks.append(t0[-1]) - # hmmm...passthrough still doesn't match (it shouldn't, it outputs the most likely) - print(tokenizer.decode(toks)) - #print(toks) - #print(tokenizer.decode(t0)) - #print(t0) - - diff --git a/examples/other_mnist/beautiful_mnist_mlx.py b/examples/other_mnist/beautiful_mnist_mlx.py deleted file mode 100644 index 8261ed4472..0000000000 --- a/examples/other_mnist/beautiful_mnist_mlx.py +++ /dev/null @@ -1,55 +0,0 @@ -from tinygrad.helpers import trange -from tinygrad.nn.datasets import mnist -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as optim -from functools import partial - -class Model(nn.Module): - def __init__(self): - super().__init__() - self.c1 = nn.Conv2d(1, 32, 5) - self.c2 = nn.Conv2d(32, 32, 5) - self.bn1 = nn.BatchNorm(32) - self.m1 = nn.MaxPool2d(2) - self.c3 = nn.Conv2d(32, 64, 3) - self.c4 = nn.Conv2d(64, 64, 3) - self.bn2 = nn.BatchNorm(64) - self.m2 = nn.MaxPool2d(2) - self.lin = nn.Linear(576, 10) - def __call__(self, x): - x = mx.maximum(self.c1(x), 0) - x = mx.maximum(self.c2(x), 0) - x = self.m1(self.bn1(x)) - x = mx.maximum(self.c3(x), 0) - x = mx.maximum(self.c4(x), 0) - x = self.m2(self.bn2(x)) - return self.lin(mx.flatten(x, 1)) - -if __name__ == "__main__": - X_train, Y_train, X_test, Y_test = mnist() - X_train = mx.array(X_train.float().permute((0,2,3,1)).numpy()) - Y_train = mx.array(Y_train.numpy()) - X_test = mx.array(X_test.float().permute((0,2,3,1)).numpy()) - Y_test = mx.array(Y_test.numpy()) - - model = Model() - optimizer = optim.Adam(1e-3) - def loss_fn(model, x, y): return nn.losses.cross_entropy(model(x), y).mean() - - state = [model.state, optimizer.state] - @partial(mx.compile, inputs=state, outputs=state) - def step(samples): - # Compiled functions will also treat any inputs not in the parameter list as constants. - X,Y = X_train[samples], Y_train[samples] - loss_and_grad_fn = nn.value_and_grad(model, loss_fn) - loss, grads = loss_and_grad_fn(model, X, Y) - optimizer.update(model, grads) - return loss - - test_acc = float('nan') - for i in (t:=trange(70)): - samples = mx.random.randint(0, X_train.shape[0], (512,)) # putting this in JIT didn't work well - loss = step(samples) - if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item() - t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%") diff --git a/examples/rl/lightupbutton.py b/examples/rl/lightupbutton.py deleted file mode 100644 index 65eeff27fd..0000000000 --- a/examples/rl/lightupbutton.py +++ /dev/null @@ -1,45 +0,0 @@ -import gymnasium as gym -import numpy as np -from gymnasium.envs.registration import register - -# a very simple game -# one of lights will light up -# take the action of the lit up light -# in , you act differently based on the step number and need to track this - -class PressTheLightUpButton(gym.Env): - metadata = {"render_modes": []} - def __init__(self, render_mode=None, size=2, game_length=10, hard_mode=False): - self.size, self.game_length = size, game_length - self.observation_space = gym.spaces.Box(0, 1, shape=(self.size,), dtype=np.float32) - self.action_space = gym.spaces.Discrete(self.size) - self.step_num = 0 - self.done = True - self.hard_mode = hard_mode - - def _get_obs(self): - obs = [0]*self.size - if self.step_num < len(self.state): - obs[self.state[self.step_num]] = 1 - return np.array(obs, dtype=np.float32) - - def reset(self, seed=None, options=None): - super().reset(seed=seed) - self.state = np.random.randint(0, self.size, size=self.game_length) - self.step_num = 0 - self.done = False - return self._get_obs(), {} - - def step(self, action): - target = ((action + self.step_num) % self.size) if self.hard_mode else action - reward = int(target == self.state[self.step_num]) - self.step_num += 1 - if not reward: - self.done = True - return self._get_obs(), reward, self.done, self.step_num >= self.game_length, {} - -register( - id="PressTheLightUpButton-v0", - entry_point="examples.rl.lightupbutton:PressTheLightUpButton", - max_episode_steps=None, -) \ No newline at end of file diff --git a/examples/serious_mnist.py b/examples/serious_mnist.py deleted file mode 100644 index 752f5dab41..0000000000 --- a/examples/serious_mnist.py +++ /dev/null @@ -1,136 +0,0 @@ -#!/usr/bin/env python -#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb -import sys -import numpy as np -from tinygrad.nn.state import get_parameters -from tinygrad.tensor import Tensor -from tinygrad.nn import BatchNorm2d, optim -from tinygrad.helpers import getenv -from extra.datasets import fetch_mnist -from extra.augment import augment_img -from extra.training import train, evaluate -GPU = getenv("GPU") -QUICK = getenv("QUICK") -DEBUG = getenv("DEBUG") - -class SqueezeExciteBlock2D: - def __init__(self, filters): - self.filters = filters - self.weight1 = Tensor.scaled_uniform(self.filters, self.filters//32) - self.bias1 = Tensor.scaled_uniform(1,self.filters//32) - self.weight2 = Tensor.scaled_uniform(self.filters//32, self.filters) - self.bias2 = Tensor.scaled_uniform(1, self.filters) - - def __call__(self, input): - se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D - se = se.reshape(shape=(-1, self.filters)) - se = se.dot(self.weight1) + self.bias1 - se = se.relu() - se = se.dot(self.weight2) + self.bias2 - se = se.sigmoid().reshape(shape=(-1,self.filters,1,1)) #for broadcasting - se = input.mul(se) - return se - -class ConvBlock: - def __init__(self, h, w, inp, filters=128, conv=3): - self.h, self.w = h, w - self.inp = inp - #init weights - self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)] - self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)] - #init layers - self._bn = BatchNorm2d(128) - self._seb = SqueezeExciteBlock2D(filters) - - def __call__(self, input): - x = input.reshape(shape=(-1, self.inp, self.w, self.h)) - for cweight, cbias in zip(self.cweights, self.cbiases): - x = x.pad(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu() - x = self._bn(x) - x = self._seb(x) - return x - -class BigConvNet: - def __init__(self): - self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)] - self.weight1 = Tensor.scaled_uniform(128,10) - self.weight2 = Tensor.scaled_uniform(128,10) - - def parameters(self): - if DEBUG: #keeping this for a moment - pars = [par for par in get_parameters(self) if par.requires_grad] - no_pars = 0 - for par in pars: - print(par.shape) - no_pars += np.prod(par.shape) - print('no of parameters', no_pars) - return pars - else: - return get_parameters(self) - - def save(self, filename): - with open(filename+'.npy', 'wb') as f: - for par in get_parameters(self): - #if par.requires_grad: - np.save(f, par.numpy()) - - def load(self, filename): - with open(filename+'.npy', 'rb') as f: - for par in get_parameters(self): - #if par.requires_grad: - try: - par.numpy()[:] = np.load(f) - if GPU: - par.gpu() - except: - print('Could not load parameter') - - def forward(self, x): - x = self.conv[0](x) - x = self.conv[1](x) - x = x.avg_pool2d(kernel_size=(2,2)) - x = self.conv[2](x) - x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global - x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global - xo = x1.dot(self.weight1) + x2.dot(self.weight2) - return xo - - -if __name__ == "__main__": - lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5] - epochss = [2, 1] if QUICK else [13, 3, 3, 1] - BS = 32 - - lmbd = 0.00025 - lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum() - X_train, Y_train, X_test, Y_test = fetch_mnist() - X_train = X_train.reshape(-1, 28, 28).astype(np.uint8) - X_test = X_test.reshape(-1, 28, 28).astype(np.uint8) - steps = len(X_train)//BS - np.random.seed(1337) - if QUICK: - steps = 1 - X_test, Y_test = X_test[:BS], Y_test[:BS] - - model = BigConvNet() - - if len(sys.argv) > 1: - try: - model.load(sys.argv[1]) - print('Loaded weights "'+sys.argv[1]+'", evaluating...') - evaluate(model, X_test, Y_test, BS=BS) - except: - print('could not load weights "'+sys.argv[1]+'".') - - if GPU: - params = get_parameters(model) - [x.gpu_() for x in params] - - for lr, epochs in zip(lrs, epochss): - optimizer = optim.Adam(model.parameters(), lr=lr) - for epoch in range(1,epochs+1): - #first epoch without augmentation - X_aug = X_train if epoch == 1 else augment_img(X_train) - train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS) - accuracy = evaluate(model, X_test, Y_test, BS=BS) - model.save(f'examples/checkpoint{accuracy * 1e6:.0f}') diff --git a/examples/simple_conv_bn.py b/examples/simple_conv_bn.py deleted file mode 100644 index 287691016f..0000000000 --- a/examples/simple_conv_bn.py +++ /dev/null @@ -1,17 +0,0 @@ -from tinygrad.tensor import Tensor -from tinygrad.nn import Conv2d, BatchNorm2d -from tinygrad.nn.state import get_parameters - -if __name__ == "__main__": - with Tensor.train(): - - BS, C1, H, W = 4, 16, 224, 224 - C2, K, S, P = 64, 7, 2, 1 - - x = Tensor.uniform(BS, C1, H, W) - conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) - bn = BatchNorm2d(C2, track_running_stats=False) - for t in get_parameters([x, conv, bn]): t.realize() - - print("running network") - x.sequential([conv, bn]).numpy() diff --git a/examples/so_vits_svc.py b/examples/so_vits_svc.py deleted file mode 100644 index 6b6eeab7ef..0000000000 --- a/examples/so_vits_svc.py +++ /dev/null @@ -1,669 +0,0 @@ -# original implementation: https://github.com/svc-develop-team/so-vits-svc -from __future__ import annotations -import sys, logging, time, io, math, argparse, operator, numpy as np -from functools import partial, reduce -from pathlib import Path -from typing import Tuple, Optional, Type -from tinygrad import nn, dtypes, Tensor -from tinygrad.helpers import getenv, fetch -from tinygrad.nn.state import torch_load -from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, get_hparams_from_file, load_checkpoint, weight_norm, HParams -from examples.sovits_helpers import preprocess -import soundfile - -DEBUG = getenv("DEBUG") - -F0_BIN = 256 -F0_MAX = 1100.0 -F0_MIN = 50.0 -F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700) -F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700) - -class SpeechEncoder: - def __init__(self, hidden_dim, model:ContentVec): self.hidden_dim, self.model = hidden_dim, model - def encode(self, ): raise NotImplementedError("implement me") - @classmethod - def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec: - contentvec = ContentVec.load_from_pretrained(checkpoint_path, checkpoint_url) - return cls(contentvec) - -class ContentVec256L9(SpeechEncoder): - def __init__(self, model:ContentVec): super().__init__(hidden_dim=256, model=model) - def encode(self, wav: Tensor): - feats = wav - if len(feats.shape) == 2: # double channels - feats = feats.mean(-1) - assert len(feats.shape) == 1, feats.dim() - feats = feats.reshape(1, -1) - padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool) - logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=9) - feats = self.model.final_proj(logits[0]) - return feats.transpose(1,2) - -class ContentVec768L12(SpeechEncoder): - def __init__(self, model:ContentVec): super().__init__(hidden_dim=768, model=model) - def encode(self, wav: Tensor): - feats = wav - if len(feats.shape) == 2: # double channels - feats = feats.mean(-1) - assert len(feats.shape) == 1, feats.dim() - feats = feats.reshape(1, -1) - padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool) - logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=12) - return logits[0].transpose(1,2) - -# original code for contentvec: https://github.com/auspicious3000/contentvec/ -class ContentVec: - # self.final_proj dims are hardcoded and depend on fairseq.data.dictionary Dictionary in the checkpoint. This param can't yet be loaded since there is no pickle for it. See with DEBUG=2. - # This means that the ContentVec only works with the hubert weights used in all SVC models - def __init__(self, cfg: HParams): - self.feature_grad_mult, self.untie_final_proj = cfg.feature_grad_mult, cfg.untie_final_proj - feature_enc_layers = eval(cfg.conv_feature_layers) - self.embed = feature_enc_layers[-1][0] - final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim - self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias) - self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None - self.encoder = TransformerEncoder(cfg) - self.layer_norm = nn.LayerNorm(self.embed) - self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim * 1) if self.untie_final_proj else nn.Linear(cfg.encoder_embed_dim, final_dim) - self.mask_emb = Tensor.uniform(cfg.encoder_embed_dim, dtype=dtypes.float32) - self.label_embs_concat = Tensor.uniform(504, final_dim, dtype=dtypes.float32) - def forward_features(self, source, padding_mask): - if self.feature_grad_mult > 0: - features = self.feature_extractor(source, padding_mask) - if self.feature_grad_mult != 1.0: pass # training: GradMultiply.forward(features, self.feature_grad_mult) - else: - features = self.feature_extractor(source, padding_mask) - return features - def forward_padding_mask(self, features, padding_mask): # replaces original forward_padding_mask for batch inference - lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) # ensure its bool for tilde - lengths = (lengths_org - 400).float().div(320).floor().cast(dtypes.int64) + 1 # intermediate float to divide - padding_mask = lengths_to_padding_mask(lengths) - return padding_mask - def extract_features(self, source: Tensor, spk_emb:Tensor=None, padding_mask=None, ret_conv=False, output_layer=None, tap=False): - features = self.forward_features(source, padding_mask) - if padding_mask is not None: - padding_mask = self.forward_padding_mask(features, padding_mask) - features = features.transpose(1, 2) - features = self.layer_norm(features) - if self.post_extract_proj is not None: - features = self.post_extract_proj(features) - x, _ = self.encoder(features, spk_emb, padding_mask=padding_mask, layer=(None if output_layer is None else output_layer - 1), tap=tap) - res = features if ret_conv else x - return res, padding_mask - @classmethod - def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec: - fetch(checkpoint_url, checkpoint_path) - cfg = load_fairseq_cfg(checkpoint_path) - enc = cls(cfg.model) - _ = load_checkpoint_enc(checkpoint_path, enc, None) - logging.debug(f"{cls.__name__}: Loaded model with cfg={cfg}") - return enc - -class TransformerEncoder: - def __init__(self, cfg: HParams): - def make_conv() -> nn.Conv1d: - layer = nn.Conv1d(self.embedding_dim, self.embedding_dim, kernel_size=cfg.conv_pos, padding=cfg.conv_pos // 2, groups=cfg.conv_pos_groups) - std = std = math.sqrt(4 / (cfg.conv_pos * self.embedding_dim)) - layer.weight, layer.bias = (Tensor.normal(*layer.weight.shape, std=std)), (Tensor.zeros(*layer.bias.shape)) - # for training: layer.weights need to be weight_normed - return layer - self.dropout, self.embedding_dim, self.layer_norm_first, self.layerdrop, self.num_layers, self.num_layers_1 = cfg.dropout, cfg.encoder_embed_dim, cfg.layer_norm_first, cfg.encoder_layerdrop, cfg.encoder_layers, cfg.encoder_layers_1 - self.pos_conv, self.pos_conv_remove = [make_conv()], (1 if cfg.conv_pos % 2 == 0 else 0) - self.layers = [ - TransformerEncoderLayer(self.embedding_dim, cfg.encoder_ffn_embed_dim, cfg.encoder_attention_heads, self.dropout, cfg.attention_dropout, cfg.activation_dropout, cfg.activation_fn, self.layer_norm_first, cond_layer_norm=(i >= cfg.encoder_layers)) - for i in range(cfg.encoder_layers + cfg.encoder_layers_1) - ] - self.layer_norm = nn.LayerNorm(self.embedding_dim) - self.cond_layer_norm = CondLayerNorm(self.embedding_dim) if cfg.encoder_layers_1 > 0 else None - # training: apply init_bert_params - def __call__(self, x, spk_emb, padding_mask=None, layer=None, tap=False): - x, layer_results = self.extract_features(x, spk_emb, padding_mask, layer, tap) - if self.layer_norm_first and layer is None: - x = self.cond_layer_norm(x, spk_emb) if (self.num_layers_1 > 0) else self.layer_norm(x) - return x, layer_results - def extract_features(self, x: Tensor, spk_emb: Tensor, padding_mask=None, tgt_layer=None, tap=False): - if tgt_layer is not None: # and not self.training - assert tgt_layer >= 0 and tgt_layer < len(self.layers) - if padding_mask is not None: - # x[padding_mask] = 0 - assert padding_mask.shape == x.shape[:len(padding_mask.shape)] # first few dims of x must match padding_mask - tmp_mask = padding_mask.unsqueeze(-1).repeat((1, 1, x.shape[-1])) - tmp_mask = tilde(tmp_mask.cast(dtypes.bool)) - x = tmp_mask.where(x, 0) - x_conv = self.pos_conv[0](x.transpose(1,2)) - if self.pos_conv_remove > 0: x_conv = x_conv[:, :, : -self.pos_conv_remove] - x_conv = x_conv.gelu().transpose(1, 2) - x = (x + x_conv).transpose(0, 1) # B x T x C -> T x B x C - if not self.layer_norm_first: x = self.layer_norm(x) - x = x.dropout(p=self.dropout) - layer_results = [] - r = None - for i, layer in enumerate(self.layers): - if i < self.num_layers: # if (not self.training or (dropout_probability > self.layerdrop)) and (i < self.num_layers): - assert layer.cond_layer_norm == False - x = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) - if tgt_layer is not None or tap: - layer_results.append(x.transpose(0, 1)) - if i>= self.num_layers: - assert layer.cond_layer_norm == True - x = layer(x, emb=spk_emb, self_attn_padding_mask=padding_mask, need_weights=False) - if i == tgt_layer: - r = x - break - if r is not None: - x = r - x = x.transpose(0, 1) # T x B x C -> B x T x C - return x, layer_results - -class TransformerEncoderLayer: - def __init__(self, embedding_dim=768.0, ffn_embedding_dim=3072.0, num_attention_heads=8.0, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1, activation_fn="relu", layer_norm_first=False, cond_layer_norm=False): - def get_activation_fn(activation): - if activation == "relu": return Tensor.relu - if activation == "gelu": return Tensor.gelu - else: raise RuntimeError(f"activation function={activation} is not forseen") - self.embedding_dim, self.dropout, self.activation_dropout, self.layer_norm_first, self.num_attention_heads, self.cond_layer_norm, self.activation_fn = embedding_dim, dropout, activation_dropout, layer_norm_first, num_attention_heads, cond_layer_norm, get_activation_fn(activation_fn) - self.self_attn = MultiHeadAttention(self.embedding_dim, self.num_attention_heads) - self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim) - self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) - self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) - self.final_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim) - def __call__(self, x:Tensor, self_attn_mask:Tensor=None, self_attn_padding_mask:Tensor=None, emb:Tensor=None, need_weights=False): - #self_attn_padding_mask = self_attn_padding_mask.reshape(x.shape[0], 1, 1, self_attn_padding_mask.shape[1]).expand(-1, self.num_attention_heads, -1, -1).reshape(x.shape[0] * self.num_attention_heads, 1, self_attn_padding_mask.shape[1]) if self_attn_padding_mask is not None else None - assert self_attn_mask is None and self_attn_padding_mask is not None - residual = x - if self.layer_norm_first: - x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb) - x = self.self_attn(x=x, mask=self_attn_padding_mask) - x = x.dropout(self.dropout) - x = residual + x - x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb) - x = self.activation_fn(self.fc1(x)) - x = x.dropout(self.activation_dropout) - x = self.fc2(x) - x = x.dropout(self.dropout) - x = residual + x - else: - x = self.self_attn(x=x, mask=self_attn_padding_mask) - x = x.dropout(self.dropout) - x = residual + x - x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb) - residual = x - x = self.activation_fn(self.fc1(x)) - x = x.dropout(self.activation_dropout) - x = self.fc2(x) - x = x.dropout(self.dropout) - x = residual + x - x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb) - return x - -class MultiHeadAttention: - def __init__(self, n_state, n_head): - self.n_state, self.n_head = n_state, n_head - self.q_proj, self.k_proj, self.v_proj, self.out_proj = [nn.Linear(n_state, n_state) for _ in range(4)] - def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None): - x = x.transpose(0,1) # TxBxC -> BxTxC - q, k, v = self.q_proj(x), self.k_proj(xa or x), self.v_proj(xa or x) - q, k, v = [x.reshape(*q.shape[:2], self.n_head, -1) for x in (q, k, v)] - wv = Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), None).transpose(1, 2).reshape(*x.shape[:2], -1) - ret = self.out_proj(wv).transpose(0,1) # BxTxC -> TxBxC - return ret - -class ConvFeatureExtractionModel: - def __init__(self, conv_layers, dropout=.0, mode="default", conv_bias=False): - assert mode in {"default", "group_norm_masked", "layer_norm"} - def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False): - def make_conv(): - conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) - conv.weight = Tensor.kaiming_normal(*conv.weight.shape) - return conv - assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive" - if is_layer_norm: - return [make_conv(), partial(Tensor.dropout, p=dropout),[partial(Tensor.transpose, dim0=-2, dim1=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, dim0=-2, dim1=-1)], Tensor.gelu] - elif is_group_norm and mode == "default": - return [make_conv(), partial(Tensor.dropout, p=dropout), nn.GroupNorm(dim, dim, affine=True), Tensor.gelu] - elif is_group_norm and mode == "group_norm_masked": - return [make_conv(), partial(Tensor.dropout, p=dropout), GroupNormMasked(dim, dim, affine=True), Tensor.gelu] - else: - return [make_conv(), partial(Tensor.dropout, p=dropout), Tensor.gelu] - in_d, self.conv_layers, self.mode = 1, [], mode - for i, cl in enumerate(conv_layers): - assert len(cl) == 3, "invalid conv definition: " + str(cl) - (dim, k, stride) = cl - if i == 0: self.cl = cl - self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=(mode == "layer_norm"), is_group_norm=((mode == "default" or mode == "group_norm_masked") and i == 0), conv_bias=conv_bias)) - in_d = dim - def __call__(self, x:Tensor, padding_mask:Tensor): - x = x.unsqueeze(1) # BxT -> BxCxT - if self.mode == "group_norm_masked": - if padding_mask is not None: - _, k, stride = self.cl - lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) # ensure padding_mask is bool for tilde - lengths = (((lengths_org - k) / stride) + 1).floor().cast(dtypes.int64) - padding_mask = tilde(lengths_to_padding_mask(lengths)).cast(dtypes.int64) # lengths_to_padding_mask returns bool tensor - x = self.conv_layers[0][0](x) # padding_mask is numeric - x = self.conv_layers[0][1](x) - x = self.conv_layers[0][2](x, padding_mask) - x = self.conv_layers[0][3](x) - else: - x = x.sequential(self.conv_layers[0]) # default - for _, conv in enumerate(self.conv_layers[1:], start=1): - conv = reduce(lambda a,b: operator.iconcat(a,b if isinstance(b, list) else [b]), conv, []) # flatten - x = x.sequential(conv) - return x - -class CondLayerNorm: # https://github.com/auspicious3000/contentvec/blob/main/contentvec/modules/cond_layer_norm.py#L10 - def __init__(self, dim_last, eps=1e-5, dim_spk=256, elementwise_affine=True): - self.dim_last, self.eps, self.dim_spk, self.elementwise_affine = dim_last, eps, dim_spk, elementwise_affine - if self.elementwise_affine: - self.weight_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False) - self.bias_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False) - self.weight_ln.weight, self.bias_ln.weight = (Tensor.ones(*self.weight_ln.weight.shape)), (Tensor.zeros(*self.bias_ln.weight.shape)) - def __call__(self, x: Tensor, spk_emb: Tensor): - axis = tuple(-1-i for i in range(len(x.shape[1:]))) - x = x.layernorm(axis=axis, eps=self.eps) - if not self.elementwise_affine: return x - weights, bias = self.weight_ln(spk_emb), self.bias_ln(spk_emb) - return weights * x + bias - -class GroupNormMasked: # https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/modules/fp32_group_norm.py#L16 - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): - self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine - self.weight, self.bias = (Tensor.ones(num_channels)), (Tensor.zeros(num_channels)) if self.affine else (None, None) - def __call__(self, x:Tensor, mask:Tensor): - bsz, n_c, length = x.shape - assert n_c % self.num_groups == 0 - x = x.reshape(bsz, self.num_groups, n_c // self.num_groups, length) - if mask is None: mask = Tensor.ones_like(x) - else: mask = mask.reshape(bsz, 1, 1, length) - x = x * mask - lengths = mask.sum(axis=3, keepdim=True) - assert x.shape[2] == 1 - mean_ = x.mean(dim=3, keepdim=True) - mean = mean_ * length / lengths - var = (((x.std(axis=3, keepdim=True) ** 2) + mean_**2) * length / lengths - mean**2) + self.eps - return x.add(-mean).div(var.sqrt()).reshape(bsz, n_c, length).mul(self.weight.reshape(1,-1,1)).add(self.bias.reshape(1,-1,1)) - -class Synthesizer: - def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, ssl_dim, n_speakers, sampling_rate=44100, vol_embedding=False, n_flow_layer=4, **kwargs): - self.spec_channels, self.inter_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.segment_size, self.n_speakers, self.gin_channels, self.vol_embedding = spec_channels, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, segment_size, n_speakers, gin_channels, vol_embedding - self.emb_g = nn.Embedding(n_speakers, gin_channels) - if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels) - self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) - self.enc_p = TextEncoder(inter_channels, hidden_channels, kernel_size, n_layers, filter_channels=filter_channels, n_heads=n_heads, p_dropout=p_dropout) - self.dec = Generator(sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels) - self.emb_uv = nn.Embedding(vocab_size=2, embed_size=hidden_channels) - def infer(self, c:Tensor, f0:Tensor, uv:Tensor, g:Tensor=None, noise_scale=0.35, seed=52468, vol=None) -> Tuple[Tensor, Tensor]: - Tensor.manual_seed(getenv('SEED', seed)) - c_lengths = (Tensor.ones([c.shape[0]]) * c.shape[-1]).to(c.device) - if len(g.shape) == 1: g = g.unsqueeze(0) - g = self.emb_g(g).transpose(1, 2) - x_mask = sequence_mask(c_lengths, c.shape[2]).unsqueeze(1).cast(c.dtype) - vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0 - x = self.pre(c) * x_mask + self.emb_uv(uv.cast(dtypes.int64)).transpose(1, 2) + vol - z_p, _, _, c_mask = self.enc_p.forward(x, x_mask, f0=self._f0_to_coarse(f0), noise_scale=noise_scale) - z = self.flow.forward(z_p, c_mask, g=g, reverse=True) - o = self.dec.forward(z * c_mask, g=g, f0=f0) - return o,f0 - def _f0_to_coarse(self, f0 : Tensor): - f0_mel = 1127 * (1 + f0 / 700).log() - a = (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN) - b = F0_MEL_MIN * a - 1. - f0_mel = (f0_mel > 0).where(f0_mel * a - b, f0_mel) - f0_coarse = f0_mel.ceil().cast(dtype=dtypes.int64) - f0_coarse = f0_coarse * (f0_coarse > 0) - f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) - f0_coarse = f0_coarse * (f0_coarse < F0_BIN) - f0_coarse = f0_coarse + ((f0_coarse >= F0_BIN) * (F0_BIN - 1)) - return f0_coarse - @classmethod - def load_from_pretrained(cls, config_path:str, config_url:str, weights_path:str, weights_url:str) -> Synthesizer: - fetch(config_url, config_path) - hps = get_hparams_from_file(config_path) - fetch(weights_url, weights_path) - net_g = cls(hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model) - _ = load_checkpoint(weights_path, net_g, None, skip_list=["f0_decoder"]) - logging.debug(f"{cls.__name__}:Loaded model with hps: {hps}") - return net_g, hps - -class TextEncoder: - def __init__(self, out_channels, hidden_channels, kernel_size, n_layers, gin_channels=0, filter_channels=None, n_heads=None, p_dropout=None): - self.out_channels, self.hidden_channels, self.kernel_size, self.n_layers, self.gin_channels = out_channels, hidden_channels, kernel_size, n_layers, gin_channels - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - self.f0_emb = nn.Embedding(256, hidden_channels) # n_vocab = 256 - self.enc_ = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) - def forward(self, x, x_mask, f0=None, noise_scale=1): - x = x + self.f0_emb(f0).transpose(1, 2) - x = self.enc_.forward(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - m, logs = split(stats, self.out_channels, dim=1) - z = (m + randn_like(m) * logs.exp() * noise_scale) * x_mask - return z, m, logs, x_mask - -class Upsample: - def __init__(self, scale_factor): - assert scale_factor % 1 == 0, "Only integer scale factor allowed." - self.scale = int(scale_factor) - def forward(self, x:Tensor): - repeats = tuple([1] * len(x.shape) + [self.scale]) - new_shape = (*x.shape[:-1], x.shape[-1] * self.scale) - return x.unsqueeze(-1).repeat(repeats).reshape(new_shape) - -class SineGen: - def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voice_threshold=0, flag_for_pulse=False): - self.sine_amp, self.noise_std, self.harmonic_num, self.sampling_rate, self.voiced_threshold, self.flag_for_pulse = sine_amp, noise_std, harmonic_num, samp_rate, voice_threshold, flag_for_pulse - self.dim = self.harmonic_num + 1 - def _f02uv(self, f0): return (f0 > self.voiced_threshold).float() #generate uv signal - def _f02sine(self, f0_values): - def padDiff(x : Tensor): return (x.pad((0,0,-1,1)) - x).pad((0,0,0,-1)) - def mod(x: Tensor, n: int) -> Tensor: return x - n * x.div(n).floor() # this is what the % operator does in pytorch. - rad_values = mod((f0_values / self.sampling_rate) , 1) # convert to F0 in rad - rand_ini = Tensor.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) # initial phase noise - - #rand_ini[:, 0] = 0 - m = Tensor.ones(f0_values.shape[0]).unsqueeze(1).pad((0,f0_values.shape[2]-1,0,0)).cast(dtypes.bool) - m = tilde(m) - rand_ini = m.where(rand_ini, 0) - - #rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini - tmp = rad_values[:, 0, :] + rand_ini - m = Tensor.ones(tmp.shape).pad((0,0,0,rad_values.shape[1]-1,0)).cast(dtypes.bool) - m = tilde(m) - tmp = tmp.unsqueeze(1).pad((0,0,0,rad_values.shape[1]-1,0)) - rad_values = m.where(rad_values, tmp) - - tmp_over_one = mod(rad_values.cumsum(1), 1) - tmp_over_one_idx = padDiff(tmp_over_one) < 0 - cumsum_shift = Tensor.zeros_like(rad_values) - - #cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad((0,0,1,0)) - cumsum_shift = tmp_over_one_idx - - sines = ((rad_values + cumsum_shift).cumsum(1) * 2 * np.pi).sin() - return sines - def forward(self, f0, upp=None): - fn = f0.mul(Tensor([[range(1, self.harmonic_num + 2)]], dtype=dtypes.float32).to(f0.device)) - sine_waves = self._f02sine(fn) * self.sine_amp #generate sine waveforms - uv = self._f02uv(f0) # generate uv signal - noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 - noise = noise_amp * randn_like(sine_waves) - sine_waves = sine_waves * uv + noise - return sine_waves, uv, noise - -class SourceHnNSF: - def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0): - self.sine_amp, self.noise_std = sine_amp, add_noise_std - self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold) - self.l_linear = nn.Linear(harmonic_num + 1, 1) - def forward(self, x, upp=None): - sine_waves, uv, _ = self.l_sin_gen.forward(x, upp) - sine_merge = self.l_linear(sine_waves.cast(self.l_linear.weight.dtype)).tanh() - noise = randn_like(uv) * self.sine_amp / 3 - return sine_merge, noise, uv - -# most of the hifigan in standard vits is reused here, but need to upsample and construct harmonic source from f0 -class Generator: - def __init__(self, sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels): - self.sampling_rate, self.inter_channels, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.gin_channels = sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels - self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len(upsample_rates) - self.conv_pre = nn.Conv1d(inter_channels, upsample_initial_channel, 7, 1, padding=3) - self.f0_upsamp = Upsample(scale_factor=np.prod(upsample_rates)) - self.m_source = SourceHnNSF(sampling_rate, harmonic_num=8) - resblock = ResBlock1 if resblock == '1' else ResBlock2 - self.ups, self.noise_convs, self.resblocks = [], [], [] - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - c_cur = upsample_initial_channel//(2**(i+1)) - self.ups.append(nn.ConvTranspose1d(upsample_initial_channel//(2**i), c_cur, k, u, padding=(k-u)//2)) - stride_f0 = int(np.prod(upsample_rates[i + 1:])) - self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2) if (i + 1 < len(upsample_rates)) else nn.Conv1d(1, c_cur, kernel_size=1)) - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d)) - self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3) - if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - self.upp = np.prod(upsample_rates) - def forward(self, x, f0, g=None): - f0 = self.f0_upsamp.forward(f0[:, None]).transpose(1, 2) # bs,n,t - har_source, _, _ = self.m_source.forward(f0, self.upp) - har_source = har_source.transpose(1, 2) - x = self.conv_pre(x) - if g is not None: x = x + self.cond(g) - for i in range(self.num_upsamples): - x, xs = self.ups[i](x.leaky_relu(LRELU_SLOPE)), None - x_source = self.noise_convs[i](har_source) - x = x + x_source - for j in range(self.num_kernels): - if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x) - else: xs += self.resblocks[i * self.num_kernels + j].forward(x) - x = xs / self.num_kernels - return self.conv_post(x.leaky_relu()).tanh() - -# **** helpers **** - -def randn_like(x:Tensor) -> Tensor: return Tensor.randn(*x.shape, dtype=x.dtype).to(device=x.device) - -def tilde(x: Tensor) -> Tensor: - if x.dtype == dtypes.bool: return (1 - x).cast(dtypes.bool) - return (x + 1) * -1 # this seems to be what the ~ operator does in pytorch for non bool - -def lengths_to_padding_mask(lens:Tensor) -> Tensor: - bsz, max_lens = lens.shape[0], lens.max().numpy().item() - mask = Tensor.arange(max_lens).to(lens.device).reshape(1, max_lens) - mask = mask.expand(bsz, -1) >= lens.reshape(bsz, 1).expand(-1, max_lens) - return mask.cast(dtypes.bool) - -def repeat_expand_2d_left(content, target_len): # content : [h, t] - src_len = content.shape[-1] - temp = np.arange(src_len+1) * target_len / src_len - current_pos, cols = 0, [] - for i in range(target_len): - if i >= temp[current_pos+1]: - current_pos += 1 - cols.append(content[:, current_pos]) - return Tensor.stack(*cols).transpose(0, 1) - -def load_fairseq_cfg(checkpoint_path): - assert Path(checkpoint_path).is_file() - state = torch_load(checkpoint_path) - cfg = state["cfg"] if ("cfg" in state and state["cfg"] is not None) else None - if cfg is None: raise RuntimeError(f"No cfg exist in state keys = {state.keys()}") - return HParams(**cfg) - -def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip_list=[]): - assert Path(checkpoint_path).is_file() - start_time = time.time() - checkpoint_dict = torch_load(checkpoint_path) - saved_state_dict = checkpoint_dict['model'] - weight_g, weight_v, parent = None, None, None - for key, v in saved_state_dict.items(): - if any(layer in key for layer in skip_list): continue - try: - obj, skip = model, False - for k in key.split('.'): - if k.isnumeric(): obj = obj[int(k)] - elif isinstance(obj, dict): obj = obj[k] - else: - if k in ["weight_g", "weight_v"]: - parent, skip = obj, True - if k == "weight_g": weight_g = v - else: weight_v = v - if not skip: - parent = obj - obj = getattr(obj, k) - if weight_g and weight_v: - setattr(obj, "weight_g", weight_g.numpy()) - setattr(obj, "weight_v", weight_v.numpy()) - obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0) - weight_g, weight_v, parent, skip = None, None, None, False - if not skip and obj.shape == v.shape: - if "feature_extractor" in key and (isinstance(parent, (nn.GroupNorm, nn.LayerNorm))): # cast - obj.assign(v.to(obj.device).float()) - else: - obj.assign(v.to(obj.device)) - elif not skip: logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}") - except Exception as e: raise e - logging.info(f"Loaded checkpoint '{checkpoint_path}' in {time.time() - start_time:.4f}s") - return model, optimizer - -def pad_array(arr, target_length): - current_length = arr.shape[0] - if current_length >= target_length: return arr - pad_width = target_length - current_length - pad_left = pad_width // 2 - pad_right = pad_width - pad_left - padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0)) - return padded_arr - -def split_list_by_n(list_collection, n, pre=0): - for i in range(0, len(list_collection), n): - yield list_collection[i-pre if i-pre>=0 else i: i + n] - -def get_sid(spk2id:HParams, speaker:str) -> Tensor: - speaker_id = spk2id[speaker] - if not speaker_id and type(speaker) is int: - if len(spk2id.__dict__) >= speaker: speaker_id = speaker - if speaker_id is None: raise RuntimeError(f"speaker={speaker} not in the speaker list") - return Tensor([int(speaker_id)], dtype=dtypes.int64).unsqueeze(0) - -def get_encoder(ssl_dim) -> Type[SpeechEncoder]: - if ssl_dim == 256: return ContentVec256L9 - if ssl_dim == 768: return ContentVec768L12 - -######################################################################################### -# CODE: https://github.com/svc-develop-team/so-vits-svc -######################################################################################### -# CONTENTVEC: -# CODE: https://github.com/auspicious3000/contentvec -# PAPER: https://arxiv.org/abs/2204.09224 -######################################################################################### -# INSTALLATION: dependencies are for preprocessing and loading/saving audio. -# pip3 install soundfile librosa praat-parselmouth -######################################################################################### -# EXAMPLE USAGE: -# python3 examples/so_vits_svc.py --model tf2spy --file ~/recording.wav -######################################################################################### -# DEMO USAGE (uses audio sample from LJ-Speech): -# python3 examples/so_vits_svc.py --model saul_goodman -######################################################################################### -SO_VITS_SVC_PATH = Path(__file__).parents[1] / "weights/So-VITS-SVC" -VITS_MODELS = { # config_path, weights_path, config_url, weights_url - "saul_goodman" : (SO_VITS_SVC_PATH / "config_saul_gman.json", SO_VITS_SVC_PATH / "pretrained_saul_gman.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/G_80000.pth"), - "drake" : (SO_VITS_SVC_PATH / "config_drake.json", SO_VITS_SVC_PATH / "pretrained_drake.pth", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/config_aubrey.json", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/pretrained_aubrey.pth"), - "cartman" : (SO_VITS_SVC_PATH / "config_cartman.json", SO_VITS_SVC_PATH / "pretrained_cartman.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/G_10200.pth"), - "tf2spy" : (SO_VITS_SVC_PATH / "config_tf2spy.json", SO_VITS_SVC_PATH / "pretrained_tf2spy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/G_60000.pth"), - "tf2heavy" : (SO_VITS_SVC_PATH / "config_tf2heavy.json", SO_VITS_SVC_PATH / "pretrained_tf2heavy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/G_100000.pth"), - "lady_gaga" : (SO_VITS_SVC_PATH / "config_gaga.json", SO_VITS_SVC_PATH / "pretrained_gaga.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/G_14400.pth") -} -ENCODER_MODELS = { # weights_path, weights_url - "contentvec": (SO_VITS_SVC_PATH / "contentvec_checkpoint.pt", "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt") -} -ENCODER_MODEL = "contentvec" -DEMO_PATH, DEMO_URL = Path(__file__).parents[1] / "temp/LJ037-0171.wav", "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav" -if __name__=="__main__": - logging.basicConfig(stream=sys.stdout, level=(logging.INFO if DEBUG < 1 else logging.DEBUG)) - parser = argparse.ArgumentParser() - parser.add_argument("-m", "--model", default=None, help=f"Specify the model to use. All supported models: {VITS_MODELS.keys()}", required=True) - parser.add_argument("-f", "--file", default=DEMO_PATH, help=f"Specify the path of the input file") - parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.") - parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --name parameter.") - parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.") - parser.add_argument("--speaker", default=None, help="If not specified, the first available speaker is chosen. Usually there is only one speaker per model.") - parser.add_argument("--noise_scale", default=0.4) - parser.add_argument("--tran", default=0.0, help="Pitch shift, supports positive and negative (semitone) values. Default 0.0") - parser.add_argument("--pad_seconds", default=0.5) - parser.add_argument("--lg_num", default=0.0) - parser.add_argument("--clip_seconds", default=0.0) - parser.add_argument("--slice_db", default=-40) - args = parser.parse_args() - - vits_model = args.model - encoder_location, vits_location = ENCODER_MODELS[ENCODER_MODEL], VITS_MODELS[vits_model] - - Tensor.training = False - # Get Synthesizer and ContentVec - net_g, hps = Synthesizer.load_from_pretrained(vits_location[0], vits_location[2], vits_location[1], vits_location[3]) - Encoder = get_encoder(hps.model.ssl_dim) - encoder = Encoder.load_from_pretrained(encoder_location[0], encoder_location[1]) - - # model config args - target_sample, spk2id, hop_length, target_sample = hps.data.sampling_rate, hps.spk, hps.data.hop_length, hps.data.sampling_rate - vol_embedding = hps.model.vol_embedding if hasattr(hps.data, "vol_embedding") and hps.model.vol_embedding is not None else False - - # args - slice_db, clip_seconds, lg_num, pad_seconds, tran, noise_scale, audio_path = args.slice_db, args.clip_seconds, args.lg_num, args.pad_seconds, args.tran, args.noise_scale, args.file - speaker = args.speaker if args.speaker is not None else list(hps.spk.__dict__.keys())[0] - - ### Loading audio and slicing ### - if audio_path == DEMO_PATH: fetch(DEMO_URL, DEMO_PATH) - assert Path(audio_path).is_file() and Path(audio_path).suffix == ".wav" - chunks = preprocess.cut(audio_path, db_thresh=slice_db) - audio_data, audio_sr = preprocess.chunks2audio(audio_path, chunks) - - per_size = int(clip_seconds * audio_sr) - lg_size = int(lg_num * audio_sr) - - ### Infer per slice ### - global_frame = 0 - audio = [] - for (slice_tag, data) in audio_data: - print(f"\n====segment start, {round(len(data) / audio_sr, 3)}s====") - length = int(np.ceil(len(data) / audio_sr * target_sample)) - - if slice_tag: - print("empty segment") - _audio = np.zeros(length) - audio.extend(list(pad_array(_audio, length))) - global_frame += length // hop_length - continue - - datas = [data] if per_size == 0 else split_list_by_n(data, per_size, lg_size) - - for k, dat in enumerate(datas): - per_length = int(np.ceil(len(dat) / audio_sr * target_sample)) if clip_seconds!=0 else length - pad_len = int(audio_sr * pad_seconds) - dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])]) - raw_path = io.BytesIO() - soundfile.write(raw_path, dat, audio_sr, format="wav") - raw_path.seek(0) - - ### Infer START ### - wav, sr = preprocess.load_audiofile(raw_path) - wav = preprocess.sinc_interp_resample(wav, sr, target_sample)[0] - wav16k, f0, uv = preprocess.get_unit_f0(wav, tran, hop_length, target_sample) - sid = get_sid(spk2id, speaker) - n_frames = f0.shape[1] - - # ContentVec infer - start = time.time() - c = encoder.encode(wav16k) - c = repeat_expand_2d_left(c.squeeze(0).realize(), f0.shape[1]) # interpolate speech encoding to match f0 - c = c.unsqueeze(0).realize() - enc_time = time.time() - start - - # VITS infer - vits_start = time.time() - out_audio, f0 = net_g.infer(c, f0=f0, uv=uv, g=sid, noise_scale=noise_scale, vol=None) - out_audio = out_audio[0,0].float().realize() - vits_time = time.time() - vits_start - - infer_time = time.time() - start - logging.info("total infer time:{:.2f}s, speech_enc time:{:.2f}s, vits time:{:.2f}s".format(infer_time, enc_time, vits_time)) - ### Infer END ### - - out_sr, out_frame = out_audio.shape[-1], n_frames - global_frame += out_frame - _audio = out_audio.numpy() - pad_len = int(target_sample * pad_seconds) - _audio = _audio[pad_len:-pad_len] - _audio = pad_array(_audio, per_length) - audio.extend(list(_audio)) - - audio = np.array(audio) - out_path = Path(args.out_path or Path(args.out_dir)/f"{args.model}{f'_spk_{speaker}'}_{args.base_name}.wav") - out_path.parent.mkdir(parents=True, exist_ok=True) - soundfile.write(out_path, audio, target_sample, format="flac") - logging.info(f"Saved audio output to {out_path}") diff --git a/examples/sovits_helpers/preprocess.py b/examples/sovits_helpers/preprocess.py deleted file mode 100644 index 17a265040a..0000000000 --- a/examples/sovits_helpers/preprocess.py +++ /dev/null @@ -1,204 +0,0 @@ -import math -from typing import Optional, Tuple -from tinygrad import Tensor, dtypes -import librosa -import soundfile -import numpy as np -import parselmouth - -class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/ - def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): - self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm" - def interpolate_f0(self,f0): - vuv_vector = np.zeros_like(f0, dtype=np.float32) - vuv_vector[f0 > 0.0] = 1.0 - vuv_vector[f0 <= 0.0] = 0.0 - nzindex = np.nonzero(f0)[0] - data = f0[nzindex] - nzindex = nzindex.astype(np.float32) - time_org = self.hop_length / self.sampling_rate * nzindex - time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate - if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector - if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector - f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) - return f0,vuv_vector - def compute_f0(self,wav,p_len=None): - x = wav - if p_len is None: p_len = x.shape[0]//self.hop_length - else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" - time_step = self.hop_length / self.sampling_rate * 1000 - f0 = parselmouth.Sound(x, self.sampling_rate) \ - .to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6,pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \ - .selected_array['frequency'] - pad_size=(p_len - len(f0) + 1) // 2 - if(pad_size>0 or p_len - len(f0) - pad_size>0): - f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') - f0,uv = self.interpolate_f0(f0) - return f0 - def compute_f0_uv(self,wav,p_len=None): - x = wav - if p_len is None: p_len = x.shape[0]//self.hop_length - else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" - time_step = self.hop_length / self.sampling_rate * 1000 - f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac( - time_step=time_step / 1000, voicing_threshold=0.6, - pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency'] - pad_size=(p_len - len(f0) + 1) // 2 - if(pad_size>0 or p_len - len(f0) - pad_size>0): - f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') - f0,uv = self.interpolate_f0(f0) - return f0,uv - -class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/ - def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000): - if not min_length >= min_interval >= hop_size: - raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') - if not max_sil_kept >= hop_size: - raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') - min_interval = sr * min_interval / 1000 - self.threshold = 10 ** (threshold / 20.) - self.hop_size = round(sr * hop_size / 1000) - self.win_size = min(round(min_interval), 4 * self.hop_size) - self.min_length = round(sr * min_length / 1000 / self.hop_size) - self.min_interval = round(min_interval / self.hop_size) - self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) - def _apply_slice(self, waveform, begin, end): - if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] - else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] - def slice(self, waveform): - samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform - if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} - rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) - sil_tags, silence_start, clip_start = [], None, 0 - for i, rms in enumerate(rms_list): - if rms < self.threshold: # Keep looping while frame is silent. - if silence_start is None: # Record start of silent frames. - silence_start = i - continue - if silence_start is None: continue # Keep looping while frame is not silent and silence start has not been recorded. - # Clear recorded silence start if interval is not enough or clip is too short - is_leading_silence = silence_start == 0 and i > self.max_sil_kept - need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length - if not is_leading_silence and not need_slice_middle: - silence_start = None - continue - if i - silence_start <= self.max_sil_kept: # Need slicing. Record the range of silent frames to be removed. - pos = rms_list[silence_start: i + 1].argmin() + silence_start - sil_tags.append((0, pos) if silence_start == 0 else (pos, pos)) - clip_start = pos - elif i - silence_start <= self.max_sil_kept * 2: - pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() - pos += i - self.max_sil_kept - pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start - pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept - if silence_start == 0: - sil_tags.append((0, pos_r)) - clip_start = pos_r - else: - sil_tags.append((min(pos_l, pos), max(pos_r, pos))) - clip_start = max(pos_r, pos) - else: - pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start - pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept - sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r)) - clip_start = pos_r - silence_start = None - total_frames = rms_list.shape[0] - if silence_start is not None and total_frames - silence_start >= self.min_interval: # Deal with trailing silence. - silence_end = min(total_frames, silence_start + self.max_sil_kept) - pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start - sil_tags.append((pos, total_frames + 1)) - if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} # Apply and return slices. - chunks = [] - if sil_tags[0][0]: - chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"}) - for i in range(0, len(sil_tags)): - if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"}) - chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"}) - if sil_tags[-1][1] * self.hop_size < len(waveform): - chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"}) - chunk_dict = {} - for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i] - return chunk_dict - -# sinc_interp_hann audio resampling -class Resample: - def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None): - self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta - self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq)) - self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None) - def __call__(self, waveform:Tensor) -> Tensor: - if self.orig_freq == self.new_freq: return waveform - return self._apply_sinc_resample_kernel(waveform) - def _apply_sinc_resample_kernel(self, waveform:Tensor): - if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.") - orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd) - shape = waveform.shape - waveform = waveform.reshape(-1, shape[-1]) # pack batch - num_wavs, length = waveform.shape - target_length = int(math.ceil(new_freq * length / orig_freq)) - waveform = waveform.pad((self.width, self.width + orig_freq)) - resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq) - resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) - resampled = resampled[..., :target_length] - resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch - return resampled - def _get_sinc_resample_kernel(self, dtype=None): - orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd) - if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.") - base_freq = min(orig_freq, new_freq) - base_freq *= self.rolloff - width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq) - idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq - t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx - t *= base_freq - t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width) - window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2 - t *= math.pi - scale = base_freq / orig_freq - kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t) - kernels *= window * scale - if dtype is None: kernels = kernels.cast(dtype=dtypes.float32) - return kernels, width - -def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=1600, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None): - resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype) - return resamp(x) - -def cut(audio_path, db_thresh=-30, min_len=5000): - audio, sr = librosa.load(audio_path, sr=None) - slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len) - chunks = slicer.slice(audio) - return chunks - -def chunks2audio(audio_path, chunks): - chunks = dict(chunks) - audio, sr = load_audiofile(audio_path) - if len(audio.shape) == 2 and audio.shape[1] >= 2: - audio = audio.mean(0).unsqueeze(0) - audio = audio.numpy()[0] - result = [] - for k, v in chunks.items(): - tag = v["split_time"].split(",") - if tag[0] != tag[1]: - result.append((v["slice"], audio[int(tag[0]):int(tag[1])])) - return result, sr - -def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True): - with soundfile.SoundFile(filepath, "r") as file_: - frames = file_._prepare_read(frame_offset, None, num_frames) - waveform = file_.read(frames, "float32", always_2d=True) - sample_rate = file_.samplerate - waveform = Tensor(waveform) - if channels_first: waveform = waveform.transpose(0, 1) - return waveform, sample_rate - -def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor,Tensor,Tensor]: - f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample) - f0, uv = f0_predictor.compute_f0_uv(wav.numpy()) - if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected") - f0 = Tensor(f0.astype(np.float32)).float() - f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0) - uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0) - wav16k = sinc_interp_resample(wav[None,:], target_sample, 16000)[0] - return wav16k.realize(), f0.realize(), uv.realize() diff --git a/extra/bandwidth_test.py b/examples/tools/bandwidth_test.py similarity index 100% rename from extra/bandwidth_test.py rename to examples/tools/bandwidth_test.py diff --git a/extra/gpuburn.py b/examples/tools/gpuburn.py similarity index 100% rename from extra/gpuburn.py rename to examples/tools/gpuburn.py diff --git a/examples/train_efficientnet.py b/examples/train_efficientnet.py deleted file mode 100644 index 521c981182..0000000000 --- a/examples/train_efficientnet.py +++ /dev/null @@ -1,104 +0,0 @@ -import traceback -import time -from multiprocessing import Process, Queue -import numpy as np -from tinygrad.nn.state import get_parameters -from tinygrad.nn import optim -from tinygrad.helpers import getenv, trange -from tinygrad.tensor import Tensor -from extra.datasets import fetch_cifar -from extra.models.efficientnet import EfficientNet - -class TinyConvNet: - def __init__(self, classes=10): - conv = 3 - inter_chan, out_chan = 8, 16 # for speed - self.c1 = Tensor.uniform(inter_chan,3,conv,conv) - self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv) - self.l1 = Tensor.uniform(out_chan*6*6, classes) - - def forward(self, x): - x = x.conv2d(self.c1).relu().max_pool2d() - x = x.conv2d(self.c2).relu().max_pool2d() - x = x.reshape(shape=[x.shape[0], -1]) - return x.dot(self.l1) - -if __name__ == "__main__": - IMAGENET = getenv("IMAGENET") - classes = 1000 if IMAGENET else 10 - - TINY = getenv("TINY") - TRANSFER = getenv("TRANSFER") - if TINY: - model = TinyConvNet(classes) - elif TRANSFER: - model = EfficientNet(getenv("NUM", 0), classes, has_se=True) - model.load_from_pretrained() - else: - model = EfficientNet(getenv("NUM", 0), classes, has_se=False) - - parameters = get_parameters(model) - print("parameter count", len(parameters)) - optimizer = optim.Adam(parameters, lr=0.001) - - BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048) - print(f"training with batch size {BS} for {steps} steps") - - if IMAGENET: - from extra.datasets.imagenet import fetch_batch - def loader(q): - while 1: - try: - q.put(fetch_batch(BS)) - except Exception: - traceback.print_exc() - q = Queue(16) - for i in range(2): - p = Process(target=loader, args=(q,)) - p.daemon = True - p.start() - else: - X_train, Y_train, _, _ = fetch_cifar() - X_train = X_train.reshape((-1, 3, 32, 32)) - Y_train = Y_train.reshape((-1,)) - - with Tensor.train(): - for i in (t := trange(steps)): - if IMAGENET: - X, Y = q.get(True) - else: - samp = np.random.randint(0, X_train.shape[0], size=(BS)) - X, Y = X_train.numpy()[samp], Y_train.numpy()[samp] - - st = time.time() - out = model.forward(Tensor(X.astype(np.float32), requires_grad=False)) - fp_time = (time.time()-st)*1000.0 - - y = np.zeros((BS,classes), np.float32) - y[range(y.shape[0]),Y] = -classes - y = Tensor(y, requires_grad=False) - loss = out.log_softmax().mul(y).mean() - - optimizer.zero_grad() - - st = time.time() - loss.backward() - bp_time = (time.time()-st)*1000.0 - - st = time.time() - optimizer.step() - opt_time = (time.time()-st)*1000.0 - - st = time.time() - loss = loss.numpy() - cat = out.argmax(axis=1).numpy() - accuracy = (cat == Y).mean() - finish_time = (time.time()-st)*1000.0 - - # printing - t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" % - (loss, accuracy, - fp_time, bp_time, opt_time, finish_time, - fp_time + bp_time + opt_time + finish_time)) - - del out, y, loss diff --git a/examples/vit.py b/examples/vit.py deleted file mode 100644 index bf9a8f5d31..0000000000 --- a/examples/vit.py +++ /dev/null @@ -1,46 +0,0 @@ -import ast -import numpy as np -from PIL import Image -from tinygrad.tensor import Tensor -from tinygrad.helpers import getenv, fetch -from extra.models.vit import ViT -""" -fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz" -import tensorflow as tf -with tf.io.gfile.GFile(fn, "rb") as f: - dat = f.read() - with open("cache/"+ fn.rsplit("/", 1)[1], "wb") as g: - g.write(dat) -""" - -Tensor.training = False -if getenv("LARGE", 0) == 1: - m = ViT(embed_dim=768, num_heads=12) -else: - # tiny - m = ViT(embed_dim=192, num_heads=3) -m.load_from_pretrained() - -# category labels -lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text()) - -#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg" -url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0" - -# junk -img = Image.open(fetch(url)) -aspect_ratio = img.size[0] / img.size[1] -img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0)))) -img = np.array(img) -y0,x0=(np.asarray(img.shape)[:2]-224)//2 -img = img[y0:y0+224, x0:x0+224] -img = np.moveaxis(img, [2,0,1], [0,1,2]) -img = img.astype(np.float32)[:3].reshape(1,3,224,224) -img /= 255.0 -img -= 0.5 -img /= 0.5 - -out = m.forward(Tensor(img)) -outnp = out.numpy().ravel() -choice = outnp.argmax() -print(out.shape, choice, outnp[choice], lbls[choice]) diff --git a/extra/assembly/assembly.py b/extra/assembly/assembly.py deleted file mode 100644 index ca19c4ff2d..0000000000 --- a/extra/assembly/assembly.py +++ /dev/null @@ -1,189 +0,0 @@ -from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast -from tinygrad.codegen.opt.kernel import Ops, MemOp, UOp -from tinygrad.uop.ops import BinaryOps, UnaryOps -from tinygrad.dtype import DType, dtypes -from tinygrad.helpers import DEBUG -from tinygrad.uop.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode -import functools -import math -from collections import defaultdict - -_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h', - dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'} - -class Register(NamedTuple): - nm:str - dtype:DType - scalar:bool - off:Optional[int] = None - def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}" - def subregs(self): - if self.dtype == dtypes.float.vec(4): - return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)] - return [] - -class AssemblyInstruction(NamedTuple): - op: Ops - out: Optional[Register] - vin: List[Union[Register, int, float]] - arg: Any = None - -# warp size of 32, s registers are shared across the warp, v are 32-wide vectors -class AssemblyLanguage: - supports_load3: bool = False - sin_is_sin2pi: bool = False - no_div: bool = False - #TODO: these should be global vars - cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int) - tor: Dict[Any, Register] = {} - ins: List[AssemblyInstruction] = [] - - def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]] - def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register: - self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar) - if dtype == dtypes.float.vec(4): - for off in range(4): - self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off) - self.cnts[(dtype, scalar)] += 1 - return ret - - def render_numnode(self, b) -> Register: - key = ("num", b) - if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b)) - return self.tor[key] - - def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register: - key = (op, a, b) - if key not in self.tor: - #if not isinstance(b, Register): b = render_numnode(b) - self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op)) - return self.tor[key] - - def render_cast(self, a:Register, new_dtype:DType) -> Register: - if a.dtype == new_dtype: return a - key = (a, new_dtype) - if key not in self.tor: - self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a])) - return self.tor[key] - - render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b), - MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b), - DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b), - ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b), - LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool), - SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)), - AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } - - def addr_w_offset(self, args): - assert isinstance(args, MemOp) - idx = args.idx*args.memory_dtype.itemsize - off = 0 # TODO: should this be None? - if isinstance(idx, SumNode): - nums = [n.b for n in idx.nodes if isinstance(n, NumNode)] - if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU? - idx -= nums[0] - off = cast(int, nums[0]) - reg = idx.render(self.render_ops, self) - if self.supports_load3: - if reg.scalar: - new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype) - self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP)) - reg = new_reg - return self.tor[args.name], reg, off - reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64) - return reg, None, off - -def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]): - #TODO: Do not use clear() - lang.ins.clear() - lang.tor.clear() - lang.cnts.clear() - buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL} - global_size, local_size = [], [] - skipload_branch = 0 - lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype] - for u in uops: - uop,dtype,vin,args,_ = u - if uop == Ops.DEFINE_LOCAL: - lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args)) - lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP)) - elif uop == Ops.LOOP: - if args[1] == "global": - for i,var in enumerate(args[0]): - global_size.append(var.max+1) - lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}")) - elif args[1] == "local": - for i,var in enumerate(args[0]): - local_size.append(var.max+1) - lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}")) - else: - for var in args[0]: - if not isinstance(var, NumNode): # TODO: why is this coming through? - lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) - lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr)) - elif uop == Ops.ENDLOOP: - if args[1] not in ["global", "local", "global+local"]: - for var in reversed(args[0]): - if not isinstance(var, NumNode): # TODO: why is this coming through? - lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD)) - pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool) - lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True))) - elif args[1] == "global+local": - for i, var in enumerate(reversed(args[0])): - lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}"))) - elif args[1] == 'local': - for i, var in enumerate(reversed(args[0])): - lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}"))) - elif uop == Ops.CAST: - # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies - out = lang.newreg(u, dtype) - for i,sr in enumerate(out.subregs()): - lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP)) - elif uop == Ops.ALU: - out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u] - # this is the only thing that can violate SSA - if args in [BinaryOps.CMPLT]: - pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool) - lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args)) - lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args)) - elif args == BinaryOps.DIV and lang.no_div: - tmp = lang.newreg((u, "rcp")) - lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP)) - lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL)) - elif args == UnaryOps.SIN and lang.sin_is_sin2pi: - tmp = lang.newreg((u, "2pi")) - lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL)) - lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args)) - else: - lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args)) - elif uop == Ops.DEFINE_REG: - reg = lang.newreg(u, dtype=dtype) - lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args)) - elif uop == Ops.SPECIAL: - lang.tor[u] = lang.tor[args] - elif uop == Ops.CONST: - lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args)) - elif uop == Ops.LOAD: - idx, treg, off = lang.addr_w_offset(args) - reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) - if args.valid.min == 0: - lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0)) - if args.valid.max == 1: - pred = args.valid.render(lang.render_ops, lang) - lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False))) - if args.valid.max == 1: - # NOTE: you can't compute the index in here, because it assumes it's all available later - lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None))) - if args.valid.min == 0 and args.valid.max == 1: - lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}")) - skipload_branch += 1 - elif uop == Ops.STORE: - if args is None: - lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP)) - else: - idx, treg, off = lang.addr_w_offset(args) - lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None))) - - if DEBUG >= 4: - for tins in lang.ins: print(tins) - return global_size, local_size diff --git a/extra/assembly/assembly_arm64.py b/extra/assembly/assembly_arm64.py deleted file mode 100644 index c5a3ad49b8..0000000000 --- a/extra/assembly/assembly_arm64.py +++ /dev/null @@ -1,177 +0,0 @@ -import struct -from platform import system -from typing import Tuple, Dict, List, Optional -from tinygrad import dtypes -from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps -from tinygrad.codegen.opt.kernel import Ops, UOp -from tinygrad.helpers import CI -from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage - -def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1]) -def compute_offsets(total): - quotient, remainder = divmod(total, 4096) - return [4096]*quotient + [remainder] if remainder else [4096]*quotient - -#NOTE: Darwin needs names to start with a "_" -def get_name(name): return ('_' if system() == 'Darwin' else '') + name - -class ARM64Language(AssemblyLanguage): pass - -def specialize_to_arm64(fn_nm, asm): - var_size = 16 - prev_uop:Optional[Ops] = None - ins = [] - x_regs = ['x' + str(i) for i in reversed(range(12))] - s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16] - type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'} - alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", - BinaryOps.MOD: "", BinaryOps.CMPLT: "subs", - UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg", - UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"), - TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"} - - def mov_imm(value, reg): - # Manually move value into reg if value can't fit - if value.__class__ is not float and abs(value) > abs(65535): - ins.append(f"movz w15, #{value & 0xffff}") - ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16") - ins.append(f"sxtw {reg}, w15") - elif reg[0] == 's': - ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}") - ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16") - ins.append("str x15, [sp, 16]") - ins.append(f"ldr {reg}, [sp, 16]") - else: - ins.append(f"mov {reg}, #{value}") - - # Get variables intervals - live_range:Dict[str, List[int]] = {} - for i, (uop, out, vin, arg) in enumerate(asm): - for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]): - live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i] - - mem_vars:Dict[str, int] = {} - rtor:Dict[str, str] = {} - def allocate_regs(mvars): - nonlocal var_size - for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]: - available_regs = s_regs if dtypes.is_float(v[1]) else x_regs - #NOTE: Very simple spill, everything that don't fit in regs goes to mem - if not available_regs: - # ARM needs the stack 16-byte aligned - var_size += 16 - available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12') - mem_vars[v.nm] = var_size - rtor[v.nm] = available_regs.pop() - - temp_floats = ['s0', 's1', 's2'] - temp_ints = ['x12', 'x13', 'x16'] - for i, (uop, out, vin, arg) in enumerate(asm): - # Clear regs out of interval - for var, reg in list(rtor.items()): - available_regs = s_regs if reg[0] == 's' else x_regs - if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]: - available_regs.append(rtor.pop(var)) - # Assign a registers to the variables using live ranges. - allocate_regs([out] + vin) - # Assign temp regs to vin and load them before direct use - for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]): - rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i] - # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912 - ins.append(f"mov x15, {mem_vars[v.nm]}") - ins.append(f"ldr {rtor[v.nm]}, [sp, x15]") - - if uop == Ops.SPECIAL: - if arg.startswith('data'): - # data 8 to n into the stack - if int(arg[4:]) >= 8: - ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]") - ins.append(f"mov {rtor[out.nm]}, x15") - else: - ins.append(f"mov {rtor[out.nm]}, #0") - ins.append(f"loop_{arg}:") - elif uop == Ops.CAST: - if arg == BinaryOps.CMPLT: - if rtor[out.nm][0] == 's': - mov_imm(0.0, 's0') - mov_imm(1.0, 's1') - ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt") - if rtor[out.nm][0] == 'x': - mov_imm(0, 'x14') - mov_imm(1, 'x15') - ins.append(f"csel {rtor[out.nm]}, x15, x14, lt") - else: - ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}") - elif uop == Ops.ALU: - if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15') - if arg == BinaryOps.MUL and out.dtype == dtypes.bool: - ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}") - elif arg == TernaryOps.WHERE: - ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0") - ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne") - elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]: - #NOTE: Not a real instruction, use to emulate a ext call in unicorn - if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}") - else: - save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars] - ins.append(f"sub sp, sp, #{(len(save_regs))*16}") - # Save the registers before they are cleared by func call - for i,k in enumerate(save_regs,1): - ins.append(f"str {rtor[k]}, [sp, #{16*i}]") - ins.append("stp x29, x30, [sp, #0]!") - ins.append("mov x29, sp") - ins.append(f"fmov s0, {rtor[vin[0].nm]}") - ins.append(alu[arg]) - ins.append(f"fmov {rtor[out.nm]}, s0") - ins.append("mov sp, x29") - ins.append("ldp x29, x30, [sp], #0") - for i,k in enumerate(save_regs,1): - ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]") - ins.append(f"add sp, sp, #{len(save_regs)*16}") - elif arg == BinaryOps.CMPLT: - ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}") - elif arg == BinaryOps.MOD: - rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm] - ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}") - ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}") - else: - ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}") - elif uop == Ops.LOAD: - if arg.__class__ in (int, float): - mov_imm(arg, rtor[out.nm]) - else: - #NOTE: if need casting load var in s/h0 or x/w12 temp regs - reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm] - mov_imm(arg[0], "x15") - ins.append(f"add x15, {rtor[vin[0].nm]}, x15") - ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]") - if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}") - elif uop == Ops.STORE: - #NOTE: if need casting load var in s/h0 or x/w12 temp regs - reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm]) - if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}") - ins.append(f"mov x15, #{arg[0]}") - ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]") - elif uop == Ops.COND_BRANCH: - #TODO: this is a hack it shouldn't always be a cmp before a cond branch? - if prev_uop == Ops.LOAD: - ins.append(f"cmp {rtor[vin[0].nm]}, #0") - ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}") - elif uop == Ops.LABEL: - ins.append(f"{arg[1:]}:") - elif uop == Ops.ENDLOOP: - mov_imm(arg[0], "x15") - ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1") - ins.append(f"cmp {rtor[vin[0].nm]}, x15") - ins.append(f"b.lt loop_{arg[1]}") - prev_uop = uop - # store regs into memory if needed - if out is not None and out.nm in mem_vars: - ins.append(f"mov x15, {mem_vars[out.nm]}") - ins.append(f"str {rtor[out.nm]}, [sp, x15]") - return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"]) - -def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]: - lang = ARM64Language() - global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops) - return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True \ No newline at end of file diff --git a/extra/assembly/assembly_ptx.py b/extra/assembly/assembly_ptx.py deleted file mode 100644 index 9a9593eb20..0000000000 --- a/extra/assembly/assembly_ptx.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import List -import struct -from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage -from tinygrad.codegen.opt.kernel import Ops, UOp -from tinygrad import dtypes -from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps -from tinygrad.runtime.ops_cuda import arch - -dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"} -def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1]) - -def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize) - -def render_cast(ins, inp, out): - if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)): - ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};") - elif out.dtype == dtypes.bool: - if inp.dtype == dtypes.bool: - ins.append(f"mov.pred {out}, {inp};") - else: - ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};") - else: - round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else '' - ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};") - -# https://docs.nvidia.com/cuda/parallel-thread-execution/# - -class PTXLanguage(AssemblyLanguage): - supports_constant_folding: bool = True - -def specialize_to_ptx(lang, function_name): - param_cnt = 0 - ins = [] - alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", - BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx", - UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg", - UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz", - TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"} - for uop, out, vin, arg in lang.ins: - if uop == Ops.ENDLOOP: - ins.append("bar.sync 0;") - elif uop == Ops.DEFINE_LOCAL: - ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];") - elif uop == Ops.SPECIAL: - if arg.startswith('data'): - param_cnt += 1 - ins.append(f"ld.param.u64 {out}, [{arg}];") - # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to? - # ins.append(f"cvta.to.global.u64 {out}, {out};") - elif arg.startswith('gid'): - ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};") - elif arg.startswith('lid'): - ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};") - elif uop == Ops.ALU: - if arg == BinaryOps.MUL and out.dtype == dtypes.bool: - ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};") - else: - otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype - if arg == TernaryOps.WHERE: - if vin[0].dtype == dtypes.bool: - reg = vin[0] - else: - reg = lang.newreg((vin[0], 'bool'), dtypes.bool) - ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};") - vin = vin[1:] + [reg] - ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};") - elif uop == Ops.LOAD: - if arg.__class__ in (int, float): - ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};") - elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype): - dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2]) - reg = lang.newreg((out, dt[0]), dtype=dt[1]) - ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];") - render_cast(ins, reg, out) - else: - ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];") - elif uop == Ops.STORE: - if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool: - if arg[2] == dtypes.bool != vin[1].dtype: - prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool) - render_cast(ins, vin[1], prereg) - else: prereg = vin[1] - reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]) - render_cast(ins, prereg, reg) - ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};") - else: - ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};") - elif uop == Ops.CAST: - render_cast(ins, vin[0], out) - elif uop == Ops.LABEL: - ins.append(f"{arg}:") - elif uop == Ops.COND_BRANCH: - ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};") - - ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64", - f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"] - for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",) - ins = ins_prefix + ins - ins += ["ret;", "}"] - return '\n'.join(ins) - -def uops_to_ptx_asm(function_name:str, uops:List[UOp]): - lang = PTXLanguage() - global_size, local_size = uops_to_asmstyle(lang, function_name, uops) - return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True diff --git a/extra/assembly/assembly_rdna.py b/extra/assembly/assembly_rdna.py deleted file mode 100644 index 297639d676..0000000000 --- a/extra/assembly/assembly_rdna.py +++ /dev/null @@ -1,203 +0,0 @@ -import yaml -from typing import Tuple, Set, Dict -from tinygrad import dtypes -from tinygrad.codegen.assembly import AssemblyCodegen, Register -from tinygrad.codegen.opt.kernel import Ops -from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps -from tinygrad.runtime.ops_cl import ROCM_LLVM_PATH - -# ugh, is this really needed? -from extra.helpers import enable_early_exec -early_exec = enable_early_exec() - -boilerplate_start = """ -.global _start -_start: -.rodata -.align 0x10 -.global code.kd -.type code.kd,STT_OBJECT -.amdhsa_kernel code""" - -code_start = """.end_amdhsa_kernel -.text -code: -""" - -# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst -# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state -# RDNA3 is actually a SIMD machine! -class RDNACodegen(AssemblyCodegen): - supports_float4: bool = True - supports_float4_alu: bool = True - supports_load3: bool = True - sin_is_sin2pi: bool = True - no_div: bool = True - - def specialize(self, asm) -> Tuple[str, str]: - args = [] - for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'}) - ins = [] - - v_cnt = 3 # v[0:2] is local_xyz - s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz - - dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"} - alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma", - BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp", - UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp", - BinaryOps.CMPLT: "cmp_lt"} - - pend_regs:Set[Register] = set() - rtor:Dict[Register, str] = {} - def reg_in(x): - nonlocal pend_regs - #print("reg_in", x, rtor[x], pend_regs) - if x in pend_regs: - #print("clear") - ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)') - pend_regs.clear() - return rtor[x] - def reg_out(x): - return rtor[x] - for uop, out, vin, arg in asm: - if uop == Ops.DEFINE_REGISTER: - if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]: - for i in range(arg[2]): - # TODO: Re-use gaps created by this to avoid wasting registers - align = int(arg[0][0].itemsize / 4) - if arg[0][1]: - s_cnt += s_cnt % align - reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}" - s_cnt += align - else: - v_cnt += v_cnt % align - reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}" - v_cnt += align - rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name - - if arg[0][0] == dtypes.float.vec(4): - for off in range(4): - reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}" - rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name - elif arg[0][0] == dtypes.bool: - for i in range(arg[2]): - reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32 - rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name - else: - raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg) - elif uop == Ops.SPECIAL: - if arg.startswith('buf'): - i = int(arg[3:]) - ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}') - pend_regs.add(out) - for r in out.subregs(): pend_regs.add(r) - elif arg.startswith('gid'): - ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}') - # the docs lied, this is actually y - if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested - if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10") - elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0") - # get local size - offset = len(args)*8 - args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8}) - ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}') - ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)') - pend_regs.clear() - ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}') - ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}') - elif uop == Ops.CONST: - if arg == float('inf'): arg = "0x7f800000" - elif arg == float('-inf'): arg = "0xff800000" - if out.dtype == dtypes.float.vec(4): - for off in range(4): - ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}") - else: - ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}") - elif uop == Ops.ALU: - if arg in [BinaryOps.CMPLT]: - ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}") - else: - alu_arg = alu[arg] - if arg == TernaryOps.MULACC and out == vin[2]: - alu_arg = "fmac" - vin = vin[0:2] - if out.dtype == dtypes.float.vec(4): - for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]): - ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}") - else: - ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}") - elif uop == Ops.LOAD: - if out.scalar: - # swap arg order - ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}') - else: - ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}') - pend_regs.add(out) - for r in out.subregs(): pend_regs.add(r) - elif uop == Ops.STORE: - ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}') - elif uop == Ops.LABEL: - ins.append(f"{arg}:") - elif uop == Ops.COND_BRANCH: - ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}") - elif uop == Ops.CAST: - if vin[0].dtype == dtypes.bool: - if out.dtype == dtypes.float32: - ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}") - else: - raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}") - else: - raise NotImplementedError(uop) - - ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end'] - - # dual alu group - seen = set() - new_ins = [] - for i,tins in enumerate(ins): - if tins in seen: continue - if tins.startswith("v_fmac_f32"): - for gins in reversed(ins[i+1:]): - if gins in seen: continue - if gins.startswith("v_fmac_f32"): - r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]] - r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]] - if r0[0]%2 == r1[0]%2: continue - if r0[1]%2 == r1[1]%2: continue - if r0[2]%2 == r1[2]%2: continue - new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_")) - seen.add(tins) - seen.add(gins) - break - if tins not in seen: - new_ins.append(tins) - ins = new_ins - - return 'code', self.assemble(args, ins, v_cnt, s_cnt) - - def assemble(self, args, ins, v_cnt, s_cnt): - kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0, - '.amdhsa_next_free_vgpr': v_cnt, # this matters! - '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0, - '.amdhsa_next_free_sgpr': s_cnt, - '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1, - '.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0, - '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1, - '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real? - '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0, - '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1, - '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0} - - metadata = {'amdhsa.kernels': [{'.args': args, - '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"], - '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256, - '.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0, - '.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0, - '.wavefront_size': 32}], - 'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]} - - code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata" - obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8"))) - asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj)) - return asm diff --git a/extra/assembly/ptx/test.py b/extra/assembly/ptx/test.py deleted file mode 100644 index f30348b8c4..0000000000 --- a/extra/assembly/ptx/test.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/usr/bin/env python3 -import numpy as np -from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer - -if __name__ == "__main__": - test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32)) - prg = CUDAProgram("test", """ - .version 7.8 - .target sm_86 - .address_size 64 - .visible .entry test(.param .u64 x) { - .reg .b32 %r<2>; - .reg .b64 %rd<3>; - - ld.param.u64 %rd1, [x]; - cvta.to.global.u64 %rd2, %rd1; - mov.u32 %r1, 0x40000000; // 2.0 in float - st.global.u32 [%rd2], %r1; - ret; - }""", binary=True) - prg([1], [1], test) - print(test.toCPU()) - diff --git a/extra/augment.py b/extra/augment.py deleted file mode 100644 index 06e7906c7d..0000000000 --- a/extra/augment.py +++ /dev/null @@ -1,42 +0,0 @@ -import numpy as np -from PIL import Image -from pathlib import Path -import sys -cwd = Path.cwd() -sys.path.append(cwd.as_posix()) -sys.path.append((cwd / 'test').as_posix()) -from extra.datasets import fetch_mnist -from tqdm import trange - -def augment_img(X, rotate=10, px=3): - Xaug = np.zeros_like(X) - for i in trange(len(X)): - im = Image.fromarray(X[i]) - im = im.rotate(np.random.randint(-rotate,rotate), resample=Image.BICUBIC) - w, h = X.shape[1:] - #upper left, lower left, lower right, upper right - quad = np.random.randint(-px,px,size=(8)) + np.array([0,0,0,h,w,h,w,0]) - im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC) - Xaug[i] = im - return Xaug - -if __name__ == "__main__": - import matplotlib.pyplot as plt - X_train, Y_train, X_test, Y_test = fetch_mnist() - X_train = X_train.reshape(-1, 28, 28).astype(np.uint8) - X_test = X_test.reshape(-1, 28, 28).astype(np.uint8) - X = np.vstack([X_train[:1]]*10+[X_train[1:2]]*10) - fig, a = plt.subplots(2,len(X)) - Xaug = augment_img(X) - for i in range(len(X)): - a[0][i].imshow(X[i], cmap='gray') - a[1][i].imshow(Xaug[i],cmap='gray') - a[0][i].axis('off') - a[1][i].axis('off') - plt.show() - - #create some nice gifs for doc?! - for i in range(10): - im = Image.fromarray(X_train[7353+i]) - im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))] - im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0) diff --git a/extra/backends/clang_graph.py b/extra/backends/clang_graph.py deleted file mode 100644 index 2e946d54c4..0000000000 --- a/extra/backends/clang_graph.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import List, Dict, cast -import ctypes -from tinygrad.helpers import dedup, cpu_time_execution, DEBUG -from tinygrad.engine.jit import GraphRunner, GraphException -from tinygrad.device import Buffer, Device -from tinygrad.engine.realize import ExecItem, CompiledRunner -from tinygrad.uop.ops import Variable -from tinygrad.runtime.ops_cpu import ClangProgram -from tinygrad.renderer.cstyle import ClangRenderer -render_dtype = ClangRenderer().render_dtype - -class ClangGraph(GraphRunner): - def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]): - super().__init__(jit_cache, input_rawbuffers, var_vals) - if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException - - prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache])) - args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)] - args += sorted([f"int {v}" for v in var_vals]) - code = ["void batched("+','.join(args)+") {"] - for ji in jit_cache: - args = [] - for buf in ji.bufs: - assert buf is not None - if buf in input_rawbuffers: - args.append(f"arg{input_rawbuffers.index(buf)}") - else: - args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}") - args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars] - code.append(f" {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});") - code.append("}") - if DEBUG >= 4: print("\n".join(code)) - compiler = Device["CPU"].compiler - assert compiler is not None - self._prg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers - - def __call__(self, rawbufs: List[Buffer], var_vals: Dict[str, int], wait=False): - return cpu_time_execution( - lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0])]), enable=wait) diff --git a/extra/backends/graph_hip.py b/extra/backends/graph_hip.py deleted file mode 100644 index ddcb3d58b1..0000000000 --- a/extra/backends/graph_hip.py +++ /dev/null @@ -1,27 +0,0 @@ -import ctypes -from typing import Tuple -import tinygrad.runtime.autogen.hip as hip -from tinygrad.helpers import init_c_var, time_execution_cuda_style -from tinygrad.runtime.ops_hip import check, hip_set_device -from tinygrad.runtime.graph.cuda import CUDAGraph - -# TODO: this is only used in graph -def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) # noqa: E501 - -class HIPGraph(CUDAGraph): - def __del__(self): - if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph)) - if hasattr(self, 'instance'): check(hip.hipGraphExecDestroy(self.instance)) - def set_device(self): hip_set_device(self.dev) - def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3)) - def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0))) - def graph_instantiate(self, graph): - return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0))) - def graph_add_kernel_node(self, graph, c_deps, c_params): - return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params)))) # noqa: E501 - def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait) - def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args)) - def build_kernel_node_params(self, prg, global_size, local_size, c_config): - return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0) - def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): - node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size diff --git a/extra/backends/hsa_driver.py b/extra/backends/hsa_driver.py deleted file mode 100644 index 3091e0c0e2..0000000000 --- a/extra/backends/hsa_driver.py +++ /dev/null @@ -1,143 +0,0 @@ -import ctypes, collections -import tinygrad.runtime.autogen.hsa as hsa -from tinygrad.helpers import init_c_var - -def check(status): - if status != 0: - hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)())) - raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}") - -# Precalulated AQL info -AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t) -EMPTY_SIGNAL = hsa.hsa_signal_t() - -DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS -DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER -DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE -DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE -DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE - -BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER -BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE -BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE -BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE - -class AQLQueue: - def __init__(self, device, sz=-1): - self.device = device - - check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32()))) - queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value - - null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)() - self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check( - hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x)))) - - self.next_doorbell_index = 0 - self.queue_base = self.hw_queue.contents.base_address - self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes - self.write_addr = self.queue_base - self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time - self.available_packet_slots = self.hw_queue.contents.size - - check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH)) - check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1)) - - def __del__(self): - if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue)) - - def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None): - if self.available_packet_slots == 0: self._wait_queue() - - packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr) - packet.workgroup_size_x = local_size[0] - packet.workgroup_size_y = local_size[1] - packet.workgroup_size_z = local_size[2] - packet.reserved0 = 0 - packet.grid_size_x = global_size[0] * local_size[0] - packet.grid_size_y = global_size[1] * local_size[1] - packet.grid_size_z = global_size[2] * local_size[2] - packet.private_segment_size = prg.private_segment_size - packet.group_segment_size = prg.group_segment_size - packet.kernel_object = prg.handle - packet.kernarg_address = kernargs - packet.reserved2 = 0 - packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL - packet.setup = DISPATCH_KERNEL_SETUP - packet.header = DISPATCH_KERNEL_HEADER - self._submit_packet() - - def submit_barrier(self, wait_signals=None, completion_signal=None): - assert wait_signals is None or len(wait_signals) <= 5 - if self.available_packet_slots == 0: self._wait_queue() - - packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr) - packet.reserved0 = 0 - packet.reserved1 = 0 - for i in range(5): - packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL - packet.reserved2 = 0 - packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL - packet.header = BARRIER_HEADER - self._submit_packet() - - def blit_packets(self, packet_addr, packet_cnt): - if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt) - - tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt) - rem_packet_cnt = packet_cnt - tail_blit_packets - ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets) - if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt) - - self._submit_packet(packet_cnt) - - def wait(self): - self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True)) - hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) - self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - - def _wait_queue(self, need_packets=1): - while self.available_packet_slots < need_packets: - rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue) - self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex) - - def _submit_packet(self, cnt=1): - self.available_packet_slots -= cnt - self.next_doorbell_index += cnt - hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index) - hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1) - - self.write_addr += AQL_PACKET_SIZE * cnt - if self.write_addr > self.write_addr_end: - self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size - -def scan_agents(): - agents = collections.defaultdict(list) - - @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p) - def __scan_agents(agent, data): - status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t())) - if status == 0: agents[device_type.value].append(agent) - return hsa.HSA_STATUS_SUCCESS - - hsa.hsa_iterate_agents(__scan_agents, None) - return agents - -def find_memory_pool(agent, segtyp=-1, location=-1): - @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p) - def __filter_amd_memory_pools(mem_pool, data): - check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t()))) - if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS - - check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t()))) - if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS - - check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t()))) - if sz.value == 0: return hsa.HSA_STATUS_SUCCESS - - ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t)) - ret[0] = mem_pool - return hsa.HSA_STATUS_INFO_BREAK - - hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t())) - return region diff --git a/extra/backends/hsa_graph.py b/extra/backends/hsa_graph.py deleted file mode 100644 index b8df58857b..0000000000 --- a/extra/backends/hsa_graph.py +++ /dev/null @@ -1,171 +0,0 @@ -import ctypes, collections, time, itertools -from typing import List, Any, Dict, cast, Optional, Tuple -from tinygrad.helpers import init_c_var, round_up -from tinygrad.device import Buffer, BufferSpec -from tinygrad.device import Compiled, Device -from tinygrad.uop.ops import Variable -from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler -from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner -from tinygrad.engine.jit import MultiGraphRunner, GraphException -import tinygrad.runtime.autogen.hsa as hsa -from tinygrad.runtime.support.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL - -def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])] - -class VirtAQLQueue(AQLQueue): - def __init__(self, device, sz): - self.device = device - self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)() - self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue) - self.packets_count = 0 - self.available_packet_slots = sz - def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!" - def _submit_packet(self): - self.write_addr += AQL_PACKET_SIZE - self.packets_count += 1 - self.available_packet_slots -= 1 - -class HSAGraph(MultiGraphRunner): - def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]): - super().__init__(jit_cache, input_rawbuffers, var_vals) - - # Check all jit items are compatible. - compiled_devices = set() - for ji in self.jit_cache: - if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.dev) - elif isinstance(ji.prg, BufferXfer): - for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device]) - else: raise GraphException - if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException - - self.devices: List[HSADevice] = list(compiled_devices) #type:ignore - - # Allocate kernel args. - kernargs_size: Dict[Compiled, int] = collections.defaultdict(int) - for ji in self.jit_cache: - if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16) - kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferSpec()) for dev,sz in kernargs_size.items()} - - # Fill initial arguments. - self.ji_kargs_structs: Dict[int, ctypes.Structure] = {} - for j,ji in enumerate(self.jit_cache): - if not isinstance(ji.prg, CompiledRunner): continue - self.ji_kargs_structs[j] = ji.prg._prg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev]) - kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16) - for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf) - for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i].expr]) - - # Build queues. - self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices} - self.packets = {} - self.transfers = [] - self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table. - self.signals_to_reset: List[hsa.hsa_signal_t] = [] - self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {} - self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list) - - # Special packet to wait for the world. - self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices} - for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev]) - - for j,ji in enumerate(self.jit_cache): - if isinstance(ji.prg, CompiledRunner): - wait_signals = self.access_resources(ji.bufs, ji.prg.p.outs, new_dependency=j, sync_with_aql_packets=False) - for i in range(0, len(wait_signals), 5): - self.virt_aql_queues[ji.prg.dev].submit_barrier(wait_signals[i:i+5]) - self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.dev].write_addr) - - sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None - self.virt_aql_queues[ji.prg.dev].submit_kernel(ji.prg._prg, *ji.prg.p.launch_dims(var_vals), #type:ignore - ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal) - if PROFILE: self.profile_info[ji.prg.dev].append((sync_signal, ji.prg._prg.name, False)) - elif isinstance(ji.prg, BufferXfer): - dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] - dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device]) - sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev]) - - wait_signals = self.access_resources([dest, src], write=[0], new_dependency=sync_signal, sync_with_aql_packets=True) - self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals), - (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True]) - self.ji_to_transfer[j] = len(self.transfers) - 1 - if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True)) - - # Wait for all active signals to finish the graph - wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list) - for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))): - for dev in self.signals_to_devices[v.handle]: - wait_signals_to_finish[dev].append(v) - - self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) - for dev in self.devices: - wait_signals = wait_signals_to_finish[dev] - for i in range(0, max(1, len(wait_signals)), 5): - self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None) - - # Zero signals to allow graph to start and execute. - for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0) - hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0) - - def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[str, int], wait=False) -> Optional[float]: - # Wait and restore signals - hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) - for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1) - hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices)) - - # Update rawbuffers - for (j,i),input_idx in self.input_replace.items(): - if j in self.ji_kargs_structs: - self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf) - else: - if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest - elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src - - # Update var_vals - for j in self.jc_idx_with_updatable_var_vals: - for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars): - self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v.expr]) - - # Update launch dims - for j in self.jc_idx_with_updatable_launch_dims: - gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals) - self.packets[j].workgroup_size_x = lc[0] - self.packets[j].workgroup_size_y = lc[1] - self.packets[j].workgroup_size_z = lc[2] - self.packets[j].grid_size_x = gl[0] * lc[0] - self.packets[j].grid_size_y = gl[1] * lc[1] - self.packets[j].grid_size_z = gl[2] * lc[2] - - for dev in self.devices: - dev.flush_hdp() - dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count) - - for transfer_data in self.transfers: - check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data)) - - et = None - if wait: - st = time.perf_counter() - hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) - et = time.perf_counter() - st - - for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata - return et - - def alloc_signal(self, reset_on_start=False, wait_on=None): - sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) - if reset_on_start: self.signals_to_reset.append(sync_signal) - if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on - return sync_signal - - def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]: - if isinstance(dep, hsa.hsa_signal_t): return dep - elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t): - if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True) - return packet.completion_signal - return None - - def access_resources(self, rawbufs, write, new_dependency, sync_with_aql_packets=False): - rdeps = self._access_resources(rawbufs, write, new_dependency) - wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps] - if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in rawbufs] - return dedup_signals(wait_signals) diff --git a/extra/backends/ops_hsa.py b/extra/backends/ops_hsa.py deleted file mode 100644 index 3b2fcc9ac8..0000000000 --- a/extra/backends/ops_hsa.py +++ /dev/null @@ -1,275 +0,0 @@ -from __future__ import annotations -import ctypes, functools, subprocess, io, atexit, collections, json -from typing import Tuple, TypeVar, List, Dict, Any -import tinygrad.runtime.autogen.hsa as hsa -from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv, PROFILE -from tinygrad.device import Compiled, Compiler, CompileError, BufferSpec, LRUAllocator -from tinygrad.renderer.cstyle import HIPRenderer -from tinygrad.runtime.support.hsa import check, scan_agents, find_memory_pool, AQLQueue -from tinygrad.runtime.support.hip_comgr import compile_hip -if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 - -class HSAProfiler: - def __init__(self): - self.tracked_signals = collections.defaultdict(list) - self.collected_events: List[Tuple[Any, ...]] = [] - self.copy_timings = hsa.hsa_amd_profiling_async_copy_time_t() - self.disp_timings = hsa.hsa_amd_profiling_dispatch_time_t() - - def track(self, signal, device, name, is_copy=False): self.tracked_signals[device].append((signal, name, is_copy)) - def process(self, device): - # Process all tracked signals, should be called before any of tracked signals are reused. - for sig,name,is_copy in self.tracked_signals[device]: - if is_copy: check(hsa.hsa_amd_profiling_get_async_copy_time(sig, ctypes.byref(timings := self.copy_timings))) - else: check(hsa.hsa_amd_profiling_get_dispatch_time(device.agent, sig, ctypes.byref(timings := self.disp_timings))) #type:ignore - self.collected_events.append((device.device_id, 1 if is_copy else 0, name, timings.start, timings.end)) - self.tracked_signals.pop(device) - - def save(self, path): - mjson = [] - for i in range(len(HSADevice.devices)): - mjson.append({"name": "process_name", "ph": "M", "pid": i, "args": {"name": "HSA"}}) - mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 0, "args": {"name": "AQL"}}) - mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 1, "args": {"name": "SDMA"}}) - - for dev_id,queue_id,name,st,et in self.collected_events: - mjson.append({"name": name, "ph": "B", "pid": dev_id, "tid": queue_id, "ts": st*1e-3}) - mjson.append({"name": name, "ph": "E", "pid": dev_id, "tid": queue_id, "ts": et*1e-3}) - with open(path, "w") as f: f.write(json.dumps({"traceEvents": mjson})) - print(f"Saved HSA profile to {path}") -Profiler = HSAProfiler() - -class HSACompiler(Compiler): - def __init__(self, arch:str): - self.arch = arch - super().__init__(f"compile_hip_{self.arch}") - def compile(self, src:str) -> bytes: - try: return compile_hip(src, self.arch) - except RuntimeError as e: raise CompileError(e) - -class HSAProgram: - def __init__(self, device:HSADevice, name:str, lib:bytes): - self.device, self.name, self.lib = device, name, lib - - if DEBUG >= 6: - asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib) - print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x])) - - self.exec = init_c_var(hsa.hsa_executable_t(), lambda x: check(hsa.hsa_executable_create_alt(hsa.HSA_PROFILE_FULL, hsa.HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, None, ctypes.byref(x)))) # noqa: E501 - self.code_reader = init_c_var(hsa.hsa_code_object_reader_t(), - lambda x: check(hsa.hsa_code_object_reader_create_from_memory(lib, len(lib), ctypes.byref(x)))) - check(hsa.hsa_executable_load_agent_code_object(self.exec, self.device.agent, self.code_reader, None, None)) - check(hsa.hsa_executable_freeze(self.exec, None)) - - self.kernel = init_c_var(hsa.hsa_executable_symbol_t(), lambda x: check(hsa.hsa_executable_get_symbol_by_name(self.exec, (name+".kd").encode("utf-8"), ctypes.byref(self.device.agent), ctypes.byref(x)))) # noqa: E501 - self.handle = init_c_var(ctypes.c_uint64(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, ctypes.byref(x)))) # noqa: E501 - self.kernargs_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501 - self.group_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501 - self.private_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501 - - def __del__(self): - self.device.synchronize() - if hasattr(self, 'code_reader'): check(hsa.hsa_code_object_reader_destroy(self.code_reader)) - if hasattr(self, 'exec'): check(hsa.hsa_executable_destroy(self.exec)) - - def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False): - if not hasattr(self, "args_struct_t"): - self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] + - [(f'v{i}', ctypes.c_int) for i in range(len(vals))])) - if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size: - raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}") - - kernargs = None - if self.kernargs_segment_size > 0: - kernargs = self.device.alloc_kernargs(self.kernargs_segment_size) - args_st = self.args_struct_t.from_address(kernargs) - for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i]) - for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i]) - self.device.flush_hdp() - - signal = self.device.alloc_signal(reusable=True) if wait or PROFILE else None - self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, completion_signal=signal) - if PROFILE: Profiler.track(signal, self.device, self.name) - if wait: - hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) - check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t()))) - return (timings.end - timings.start) * self.device.clocks_to_time - -T = TypeVar("T") -CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000 -class HSAAllocator(LRUAllocator): - def __init__(self, device:HSADevice): - self.device = device - super().__init__() - - def _alloc(self, size:int, options:BufferSpec): - if options.host: - check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p()))) - check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem)) - return mem.value - c_agents = (hsa.hsa_agent_t * len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]))(*HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]) - check(hsa.hsa_amd_memory_pool_allocate(self.device.gpu_mempool, size, 0, ctypes.byref(buf := ctypes.c_void_p()))) - check(hsa.hsa_amd_agents_allow_access(len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]), c_agents, None, buf)) - return buf.value - - def _free(self, opaque:T, options:BufferSpec): - HSADevice.synchronize_system() - check(hsa.hsa_amd_memory_pool_free(opaque)) - - def _copyin(self, dest:T, src: memoryview): - # Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets. - self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True)) - mem = self._alloc(src.nbytes, BufferSpec(host=True)) - ctypes.memmove(mem, from_mv(src), src.nbytes) - check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal), - copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True)) - self.device.hw_queue.submit_barrier([copy_signal]) - self.device.delayed_free.append(mem) - if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True) - - def copy_from_fd(self, dest, fd, offset, size): - self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True)) - - if not hasattr(self, 'hb'): - self.hb = [self._alloc(CHUNK_SIZE, BufferSpec(host=True)) for _ in range(2)] - self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)] - self.hb_polarity = 0 - self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1] - for sig in self.hb_signals: hsa.hsa_signal_store_relaxed(sig, 0) - - fo = io.FileIO(fd, "a+b", closefd=False) - fo.seek(offset - (minor_offset:=offset % PAGE_SIZE)) - - copies_called = 0 - copied_in = 0 - for local_offset in range(0, size+minor_offset, CHUNK_SIZE): - local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE) - copy_size = min(local_size-minor_offset, size-copied_in) - if copy_size == 0: break - - hsa.hsa_signal_wait_scacquire(self.hb_signals[self.hb_polarity], hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) - self.device.reusable_signals.append(self.hb_signals[self.hb_polarity]) # it's free now and can be reused - self.hb_signals[self.hb_polarity] = self.device.alloc_signal(reusable=False) - - fo.readinto(to_mv(self.hb[self.hb_polarity], local_size)) - check(hsa.hsa_amd_memory_async_copy_on_engine(dest+copied_in, self.device.agent, self.hb[self.hb_polarity]+minor_offset, HSADevice.cpu_agent, - copy_size, 1, ctypes.byref(sync_signal), self.hb_signals[self.hb_polarity], - self.sdma[self.hb_polarity], True)) - copied_in += copy_size - self.hb_polarity = (self.hb_polarity + 1) % len(self.hb) - minor_offset = 0 # only on the first - copies_called += 1 - - wait_signals = [self.hb_signals[self.hb_polarity - 1]] - if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity]) - self.device.hw_queue.submit_barrier(wait_signals) - - def _copyout(self, dest:memoryview, src:T): - HSADevice.synchronize_system() - copy_signal = self.device.alloc_signal(reusable=True) - c_agents = (hsa.hsa_agent_t*2)(self.device.agent, HSADevice.cpu_agent) - check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p()))) - check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal)) - hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) - check(hsa.hsa_amd_memory_unlock(from_mv(dest))) - if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True) - - def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None): - src_dev.hw_queue.submit_barrier([], sync_signal_1 := src_dev.alloc_signal(reusable=True)) - dest_dev.hw_queue.submit_barrier([], sync_signal_2 := dest_dev.alloc_signal(reusable=True)) - c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2) - check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, - copy_signal := dest_dev.alloc_signal(reusable=False), hsa.HSA_AMD_SDMA_ENGINE_0, True)) - src_dev.hw_queue.submit_barrier([copy_signal]) - dest_dev.hw_queue.submit_barrier([copy_signal]) - if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True) - -class HSADevice(Compiled): - devices: List[HSADevice] = [] - agents: Dict[int, List[hsa.hsa_agent_t]] = {} - cpu_agent: hsa.hsa_agent_t - cpu_mempool: hsa.hsa_amd_memory_pool_t - def __init__(self, device:str=""): - if not HSADevice.agents: - check(hsa.hsa_init()) - atexit.register(hsa_terminate) - HSADevice.agents = scan_agents() - HSADevice.cpu_agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_CPU][0] - HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU) - if PROFILE: check(hsa.hsa_amd_profiling_async_copy_enable(1)) - - self.device_id = int(device.split(":")[1]) if ":" in device else 0 - self.agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU][self.device_id] - self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU) - self.hw_queue = AQLQueue(self) - HSADevice.devices.append(self) - - check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256)))) - self.arch = ctypes.string_at(agent_name_buf).decode() - - check(hsa.hsa_system_get_info(hsa.HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, ctypes.byref(gpu_freq := ctypes.c_uint64()))) - self.clocks_to_time: float = 1 / gpu_freq.value - - check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AMD_AGENT_INFO_HDP_FLUSH, ctypes.byref(hdp_flush := hsa.hsa_amd_hdp_flush_t()))) - self.hdp_flush = hdp_flush - - self.delayed_free: List[int] = [] - self.reusable_signals: List[hsa.hsa_signal_t] = [] - - from tinygrad.runtime.graph.hsa import HSAGraph - super().__init__(device, HSAAllocator(self), HIPRenderer(), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph) - - # Finish init: preallocate some signals + space for kernargs - self.signal_pool = [init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x)))) for _ in range(4096)] - self._new_kernargs_region(16 << 20) # initial region size is 16mb - - def synchronize(self): - self.hw_queue.wait() - - for sig in self.reusable_signals: hsa.hsa_signal_silent_store_relaxed(sig, 1) - self.signal_pool.extend(self.reusable_signals) - self.reusable_signals.clear() - - for opaque_to_free in self.delayed_free: check(hsa.hsa_amd_memory_pool_free(opaque_to_free)) - self.delayed_free.clear() - - self.kernarg_next_addr = self.kernarg_start_addr - Profiler.process(self) - - @staticmethod - def synchronize_system(): - for d in HSADevice.devices: d.synchronize() - - def alloc_signal(self, reusable=False): - if len(self.signal_pool): signal = self.signal_pool.pop() - else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t()))) - - # reusable means a signal could be reused after synchronize for the device it's allocated from is called. - if reusable: self.reusable_signals.append(signal) - return signal - - def alloc_kernargs(self, sz): - if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz: self._new_kernargs_region(int(self.kernarg_pool_sz * 2)) - result = self.kernarg_next_addr - self.kernarg_next_addr = round_up(self.kernarg_next_addr + sz, 16) - return result - - def _new_kernargs_region(self, sz:int): - if hasattr(self, 'kernarg_start_addr'): self.delayed_free.append(self.kernarg_start_addr) - self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferSpec()) - self.kernarg_next_addr = self.kernarg_start_addr - self.kernarg_pool_sz: int = sz - - def flush_hdp(self): self.hdp_flush.HDP_MEM_FLUSH_CNTL[0] = 1 - -def hsa_terminate(): - # Need to stop/delete aql queue before hsa shut down, this leads to gpu hangs. - for dev in HSADevice.devices: - Profiler.process(dev) - del dev.hw_queue - - # hsa_shut_down cleans up all hsa-related resources. - hsa.hsa_shut_down() - HSADevice.synchronize = lambda: None #type:ignore - HSAProgram.__del__ = lambda _: None #type:ignore - if Profiler.collected_events: Profiler.save("/tmp/profile.json") diff --git a/extra/backends/rdna.py b/extra/backends/rdna.py deleted file mode 100644 index a5b775b734..0000000000 --- a/extra/backends/rdna.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Dict, Set -import yaml -from tinygrad.codegen.uops import UOpGraph, UOps, UOp -from tinygrad.uop.ops import BinaryOps -from tinygrad.dtype import dtypes - -def uops_to_rdna(function_name:str, uops:UOpGraph) -> str: - replace: Dict[UOp, UOp] = {} - seen: Set[UOp] = set() - for u in uops: - if u in seen: continue - seen.add(u) - for o,n in replace.items(): - if o in u.vin and u is not n: - u.vin = tuple(n if x == o else x for x in u.vin) - # pointer indexing - if u.uop in {UOps.LOAD, UOps.STORE} and u.vin[0].dtype.itemsize > 1: - val = UOp(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_at=uops.uops.index(u)) - ptr = UOp(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_at=uops.uops.index(u)) - u.vin = (u.vin[0], ptr) + u.vin[2:] - #uops.print() - - args = [] - ins = [] - - v_cnt = 3 # v[0:2] is local_xyz - s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz - - r: Dict[UOp, str] = {} - for u in uops: - if u.uop == UOps.SPECIAL: - if u.arg.startswith("lidx"): - r[u] = f'v{u.src[0].arg}' - elif u.arg.startswith("gidx"): - r[u] = f's{2+u.src[0].arg}' - else: - raise NotImplementedError - elif u.uop == UOps.CONST: - #r[u] = u.arg - - # TODO: sometimes we can use s - #r[u] = f"s{s_cnt}" - #s_cnt += 1 - #ins.append(f"s_mov_b32 {r[u]}, {u.arg}") - - r[u] = f"v{v_cnt}" - v_cnt += 1 - ins.append(f"v_mov_b32 {r[u]}, {u.arg}") - elif u.uop == UOps.ALU: - if u.arg == BinaryOps.ADD: - r[u] = f"v{v_cnt}" - v_cnt += 1 - ins.append(f"v_add_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}") - elif u.arg == BinaryOps.MUL: - r[u] = f"v{v_cnt}" - v_cnt += 1 - if dtypes.is_float(u.dtype): - ins.append(f"v_mul_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}") - else: - ins.append(f"v_mul_u32_u24 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}") - else: - raise NotImplementedError - elif u.uop == UOps.LOAD: - r[u] = f"v{v_cnt}" - v_cnt += 1 - ins.append(f"global_load_b32 {r[u]}, {r[u.vin[1]]}, {r[u.vin[0]]}") - ins.append("s_waitcnt vmcnt(0)") - elif u.uop == UOps.STORE: - ins.append(f"global_store_b32 {r[u.vin[1]]}, {r[u.vin[2]]}, {r[u.vin[0]]}") - elif u.uop == UOps.DEFINE_GLOBAL: - i = u.arg[0] - args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, - '.type_name': u.dtype.name+"*", '.value_kind': 'global_buffer'}) - s_cnt += s_cnt%2 # skip - r[u] = f"s[{s_cnt}:{s_cnt+1}]" - s_cnt += 2 - ins.append(f"s_load_b64 {r[u]}, s[0:1], {i*8}") - ins.append("s_waitcnt lgkmcnt(0)") - else: - raise NotImplementedError(f"can't render {u.uop}") - - # *** boilerplate rendering *** - - metadata = { - 'amdhsa.kernels': [{'.args': args, - '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"], - '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256, - '.name': function_name, '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0, - '.symbol': f'{function_name}.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0, - '.wavefront_size': 32}], - 'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]} - - boilerplate_start = f""" -.rodata -.global {function_name}.kd -.type {function_name}.kd,STT_OBJECT -.align 0x10 -.amdhsa_kernel {function_name}""" - - kernel_desc = { - '.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0, - '.amdhsa_next_free_vgpr': v_cnt, # this matters! - '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0, - '.amdhsa_next_free_sgpr': s_cnt, - '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, - '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1, '.amdhsa_fp16_overflow': 0, - '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0, - '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1, - '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real? - '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, - '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0, - '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, - '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1, - '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0} - - code_start = f""".end_amdhsa_kernel -.text -.global {function_name} -.type {function_name},@function -.p2align 8 -{function_name}: -""" - - ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end'] - return ".amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata" + \ - boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + \ - '\n'.join(ins) + f"\n.size {function_name}, .-{function_name}" diff --git a/extra/backends/triton.py b/extra/backends/triton.py deleted file mode 100644 index 646c19d60d..0000000000 --- a/extra/backends/triton.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Dict, List, Final, Callable, DefaultDict -from collections import defaultdict -from tinygrad.uop.ops import UnaryOps, BinaryOps, TernaryOps, Op -from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv -from tinygrad.codegen.opt.kernel import UOp, Ops -from triton.compiler import compile as triton_compile -import linecache -import math -import re - -triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"} -signature_dtypes = {dtypes.double: "fp64",dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i8", dtypes.int8: "i1", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"} - -def next_power_of_2(x): - return 1 << (x - 1).bit_length() - -def render_valid(valid): - return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True' - -#NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile -def fill_dims_for_idx(idx, dims): - return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx - -def get_max(var): - if isinstance(var, int): return var - return re.sub(r'\[(.*?)\]', '', str(var))[1:-1] - -#NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved -def remove_single_scalar_curly_braces(ptx_code): - return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')]) - -def render_const(args,dtype:DType): - return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args)) - -def render_cast(x:str, dtype:DType, bitcast=False): - return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})" - -def define_scalar(local_size, dtype, args): - if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})" - return render_const(args,dtype) - -def uops_to_triton(function_name:str, uops:List[UOp]): - local_size: List[int] = [] - depth = 1 - signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore - - c: DefaultDict[str, int] = defaultdict(int) - r: Dict[UOp, str] = {} - def ssa(u, prefix="t"): - nonlocal c, r - c[prefix] += 1 - r[u]=f"{prefix}{c[prefix]-1}" - return r[u] - - child_count: DefaultDict[UOp, int] = defaultdict(int) - for ru in uops: - for v in ru.vin: - child_count[v] += 1 - - def kk(s): kernel.append(" "*depth+s) - code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.EXP2: lambda x,dtype,: f"tl.math.exp2({x})", - UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})", - UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})", - UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})", - UnaryOps.NEG: lambda x,dtype: f"-{x}", - BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})", - BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))", - BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})", - BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})", - BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)", - TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})", - TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})", - } - def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))" - for u in uops: - uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg - if uop == Ops.LOOP: - kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):") - depth += 1 - elif uop == Ops.END: depth -= 1 - elif uop == Ops.ALU: - assert dtype is not None - val = code_for_op[args](*[r[x] for x in vin]) - if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val - else: kk(f"{ssa(u, 'alu')} = ({val})") - elif uop == Ops.LOAD: - assert dtype is not None - if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}") - else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}") - elif uop == Ops.DEFINE_REG: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}") - elif uop == Ops.CONST: r[u] = define_scalar([], dtype, args) - elif uop == Ops.ASSIGN: - kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}") - r[u] = r[vin[0]] - elif uop == Ops.STORE: - assert not isinstance(dtype, ImageDType), "unimplemented: image store" - kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ") - elif uop == Ops.DEFINE_GLOBAL: - bufs.append(args) - signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype]) - r[u] = args - elif uop == Ops.SPECIAL: - dims.append(args[1]) - valid.append(f"{args[1]}<{get_max(args[2])}") - if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}") - elif args[1].startswith("l"): - kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})") - local_size.append(args[2]) - r[u] = args[1] - elif uop == Ops.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1]) - else: raise NotImplementedError(f"unimplemented: {uop}") - - prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n" - for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]" - prg += "\n".join(kernel) - - acc_local_size = 1 - for x in local_size: acc_local_size *= next_power_of_2(x) - local_size = [acc_local_size] + [1] * (len(local_size) - 1) - - if DEBUG >= 4: print(prg) - getlines = linecache.getlines - linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "" == filename else getlines(filename, module_globals) - exec(compile(prg, "", "exec"), globals()) # pylint: disable=W0122\ - compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None)) - prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0]) - max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")] - for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i]) - - return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))} diff --git a/extra/disassemblers/adreno/__init__.py b/extra/disassemblers/adreno/__init__.py deleted file mode 100644 index 71974b72c9..0000000000 --- a/extra/disassemblers/adreno/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -import ctypes -import os -import pathlib -import struct -from hexdump import hexdump - -fxn = None -def disasm_raw(buf): - global fxn - if fxn is None: - shared = pathlib.Path(__file__).parent / "disasm.so" - if not shared.is_file(): - os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so') - fxn = ctypes.CDLL(shared.as_posix())['disasm'] - fxn(buf, len(buf)) - -def disasm(buf): - def _read_lib(off): return struct.unpack("I", buf[off:off+4])[0] - - image_offset = _read_lib(0xc0) - image_size = _read_lib(0x100) - disasm_raw(buf[image_offset:image_offset+image_size]) diff --git a/extra/disk_read_speed.py b/extra/disk_read_speed.py deleted file mode 100644 index 6d7303c6a2..0000000000 --- a/extra/disk_read_speed.py +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env python3 -import os, ctypes, ctypes.util, io, mmap, pathlib -from tinygrad import Tensor, dtypes, Device -from tinygrad.helpers import Timing, from_mv -libc = ctypes.CDLL(ctypes.util.find_library("c")) - -#from extra.hip_gpu_driver import hip_ioctl - -# sudo su -c "echo 3 > /proc/sys/vm/drop_caches" - -# sudo su -c 'echo 8 > /proc/sys/kernel/printk' -# sudo su -c "echo 'module amdgpu +p' > /sys/kernel/debug/dynamic_debug/control" - -libc.memcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t] - -libc.read.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t] -libc.read.restype = ctypes.c_size_t - -libc.malloc.argtypes = [ctypes.c_size_t] -libc.malloc.restype = ctypes.c_void_p - -def read_direct(fd, sz): - with Timing("mmap: ", lambda x: f", {sz/x:.2f} GB/s"): - buf = mmap.mmap(-1, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE) - with Timing("read: ", lambda x: f", {sz/x:.2f} GB/s"): - ret = libc.read(fd, from_mv(buf), sz) - assert ret == sz - -def read_mmap(fd, sz): - with Timing("mmfd: ", lambda x: f", {sz/x:.2f} GB/s"): - buf = mmap.mmap(fd, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE) #|MAP_LOCKED) - t = 0 - for i in range(0, sz, 0x1000): t += buf[i] - -# def _copyin_async(self, dest:T, src:T, size:int): check(hip.hipMemcpyAsync(dest, src, size, hip.hipMemcpyHostToDevice, None)) - -def read_to_gpu_mmap(fd, sz, gpubuf): - with Timing("gpu copyin: ", lambda x: f", {sz/x:.2f} GB/s"): - with Timing("mmfd: ", lambda x: f", {sz/x:.2f} GB/s"): - buf = mmap.mmap(fd, sz, flags=mmap.MAP_SHARED|mmap.MAP_POPULATE) #|MAP_LOCKED) - dev.allocator._copyin_async(gpubuf, from_mv(buf), sz) - dev.synchronize() - -def read_to_gpu_single(fd, sz, gpubuf): - os.lseek(fd, 0, os.SEEK_SET) - with Timing("total: ", lambda x: f", {sz/x:.2f} GB/s"): - with Timing("gpu host alloc: ", lambda x: f", {sz/x:.2f} GB/s"): - hst = dev.allocator._hostalloc(sz) - with Timing("read to host: ", lambda x: f", {sz/x:.2f} GB/s"): - ret = libc.read(fd, hst, sz) - with Timing("gpu host copy: ", lambda x: f", {sz/x:.2f} GB/s"): - dev.allocator._copyin_async(gpubuf, hst, sz) - dev.synchronize() - -def read_to_gpu_pingpong(fd, sz, gpubuf): - psz = 256*1024*1024 - print(f"piece size {psz/(1024*1024):.2f} MB") - with Timing("gpu host alloc: ", lambda x: f", {sz/x:.2f} GB/s"): - hst1 = dev.allocator._hostalloc(psz) - hst2 = dev.allocator._hostalloc(psz) - - os.lseek(fd, 0, os.SEEK_SET) - with Timing("total: ", lambda x: f", {sz/x:.2f} GB/s"): - for i in range(sz//(psz*2)): - with Timing("tfer(0): ", lambda x: f", {psz/x:.2f} GB/s"): - ret = libc.read(fd, hst1, psz) - dev.synchronize() - dev.allocator._copyin_async(gpubuf, hst1, psz) - with Timing("tfer(1): ", lambda x: f", {psz/x:.2f} GB/s"): - ret = libc.read(fd, hst2, psz) - dev.synchronize() - dev.allocator._copyin_async(gpubuf, hst2, psz) - dev.synchronize() - -MAP_LOCKED = 0x2000 -MAP_HUGETLB = 0x40000 - -if __name__ == "__main__": - dev = Device[Device.DEFAULT] - - warm = (Tensor.ones(1024, device=Device.DEFAULT).contiguous() + Tensor.ones(1024, device=Device.DEFAULT).contiguous()).realize() - #fn = "/home/tiny/tinygrad/weights/rng" - fn = pathlib.Path(__file__).parents[1] / "weights/LLaMA-2/70B/consolidated.00.pth" - sz = os.stat(fn).st_size - t = Tensor.empty(sz, dtype=dtypes.uint8, device=f"disk:{fn}") - with Timing("copy: ", lambda x: f", {sz/x:.2f} GB/s"): - on_dev = t.to(Device.DEFAULT).realize() - - exit(0) - - # 4GB of random numbers - #fd = os.open("/home/tiny/tinygrad/weights/rng", os.O_RDWR|os.O_DIRECT) - #sz = os.fstat(fd).st_size // 4 - fd = os.open("/home/tiny/tinygrad/weights/LLaMA/7B/consolidated.00.pth", os.O_RDWR|os.O_DIRECT) - sz = os.fstat(fd).st_size - print(f"read {sz} from {fd}") - - with Timing("gpu alloc: ", lambda x: f", {sz/x:.2f} GB/s"): - gpubuf = dev.allocator._alloc(sz) - # warmup - dev.allocator._copyin_async(gpubuf, from_mv(bytearray(b"\x00\x00\x00\x00"*0x1000)), 0x4000) - print("copying, is warm") - - print("****** read to gpu pingpong") - read_to_gpu_pingpong(fd, sz, gpubuf) - exit(0) - - print("****** read direct") - read_direct(fd, sz) - - print("****** read mmap") - read_mmap(fd, sz) - - print("****** read to gpu single") - read_to_gpu_single(fd, sz, gpubuf) - - print("****** read to gpu mmap") - read_to_gpu_mmap(fd, sz, gpubuf) - - os._exit(0) diff --git a/extra/dump_cache.py b/extra/dump_cache.py deleted file mode 100644 index 325d2bd227..0000000000 --- a/extra/dump_cache.py +++ /dev/null @@ -1,21 +0,0 @@ -import sys, sqlite3, pickle -from tinygrad.helpers import CACHEDB - -if __name__ == "__main__": - fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB - conn = sqlite3.connect(fn) - cur = conn.cursor() - cur.execute("SELECT name FROM sqlite_master WHERE type='table'") - for f in cur.fetchall(): - table = f[0] - cur2 = conn.cursor() - cur2.execute(f"SELECT COUNT(*) FROM {table}") - cnt = cur2.fetchone()[0] - print(f"{table:20s} : {cnt}") - - cur3 = conn.cursor() - cur3.execute(f"SELECT * FROM {table} LIMIT 10") - for f in cur3.fetchall(): - v = pickle.loads(f[-1]) - print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50]) - #print(f"{len(k):10d}, {sk} -> {v}") diff --git a/extra/gemm/jax_pmatmul.py b/extra/gemm/jax_pmatmul.py deleted file mode 100755 index b69a2b9b47..0000000000 --- a/extra/gemm/jax_pmatmul.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 -import time -import jax -import jax.numpy as jnp - -print(jax.devices()) -DEVICES = len(jax.devices()) -BS = 32 -N = 4096 -dtype = jnp.float16 -A = jnp.zeros((DEVICES, BS, N, N), dtype) -B = jnp.zeros((1, 1, N, N), dtype) -A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices()) -B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices()) - -OPS = DEVICES*BS*N*N*N*2 -def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32) -pmatmul = jax.pmap(matmul) - -MAX_TFLOPS = 123*DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX) -for i in range(10): - st = time.perf_counter() - C = pmatmul(A,B).block_until_ready() - et = time.perf_counter()-st - tflops = (OPS*1e-12)/et - print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}") - diff --git a/extra/gemm/mlx_matmul.py b/extra/gemm/mlx_matmul.py deleted file mode 100644 index 821c7fff29..0000000000 --- a/extra/gemm/mlx_matmul.py +++ /dev/null @@ -1,10 +0,0 @@ -import mlx.core as mx -from tinygrad.helpers import Timing -N = 4096 -x = mx.random.normal((N,N)) -w = mx.random.normal((N,N)) - -FLOPS = N*N*N*2 -for i in range(10): - with Timing("", lambda x: f" {FLOPS/x:.2f} GFLOPS"): - mx.eval(x@w) diff --git a/extra/gemm/tf_gemm.py b/extra/gemm/tf_gemm.py deleted file mode 100644 index 802b344358..0000000000 --- a/extra/gemm/tf_gemm.py +++ /dev/null @@ -1,33 +0,0 @@ -import time -import tensorflow as tf - -gpus = tf.config.list_physical_devices('GPU') -if gpus: - try: - # Currently, memory growth needs to be the same across GPUs - for gpu in gpus: - tf.config.experimental.set_memory_growth(gpu, True) - logical_gpus = tf.config.list_logical_devices('GPU') - print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") - except RuntimeError as e: - # Memory growth must be set before GPUs have been initialized - print(e) - -for dtype in [tf.float16, tf.float32]: - for N in [256, 512, 1024, 2048, 4096, 8192]: - FLOPS = N*N*N*2 - - b = tf.random.uniform((N, N), dtype=dtype) - c = tf.random.uniform((N, N), dtype=dtype) - - b = tf.Variable(b) - c = tf.Variable(c) - - def tf_prog(b, c): - st = time.perf_counter() - a = tf.matmul(b, c) - tf.debugging.check_numerics(a, "Nan or Inf in result") # Ensures that the calculation is done. - return time.perf_counter() - st - - tm = min([tf_prog(b, c) for _ in range(20)]) - print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}") \ No newline at end of file diff --git a/extra/hip_events.py b/extra/hip_events.py deleted file mode 100644 index 5719a18ce9..0000000000 --- a/extra/hip_events.py +++ /dev/null @@ -1,12 +0,0 @@ -import ctypes -import tinygrad.runtime.autogen.hip as hip -from tinygrad.runtime.ops_hip import check -from tinygrad.helpers import init_c_var - -if __name__ == "__main__": - check(hip.hipSetDevice(0)) - evt = init_c_var(hip.hipEvent_t(), lambda x: check(hip.hipEventCreate(ctypes.byref(x)))) - check(hip.hipSetDevice(1)) - check(hip.hipStreamWaitEvent(None, evt, 0)) - check(hip.hipSetDevice(0)) - check(hip.hipEventRecord(evt, None)) \ No newline at end of file diff --git a/extra/junk/sentencepiece_model_pb2.py b/extra/junk/sentencepiece_model_pb2.py deleted file mode 100644 index 5de978fad0..0000000000 --- a/extra/junk/sentencepiece_model_pb2.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: sentencepiece_model.proto -# Protobuf Python Version: 4.25.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - _globals['DESCRIPTOR']._options = None - _globals['DESCRIPTOR']._serialized_options = b'H\003' - _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None - _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001' - _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None - _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001' - _globals['_TRAINERSPEC']._serialized_start=45 - _globals['_TRAINERSPEC']._serialized_end=1581 - _globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517 - _globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570 - _globals['_NORMALIZERSPEC']._serialized_start=1584 - _globals['_NORMALIZERSPEC']._serialized_end=1793 - _globals['_SELFTESTDATA']._serialized_start=1795 - _globals['_SELFTESTDATA']._serialized_end=1916 - _globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864 - _globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905 - _globals['_MODELPROTO']._serialized_start=1919 - _globals['_MODELPROTO']._serialized_end=2429 - _globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208 - _globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418 - _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323 - _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407 -# @@protoc_insertion_point(module_scope) diff --git a/extra/mcts_search.py b/extra/mcts_search.py deleted file mode 100644 index 825d97c3bf..0000000000 --- a/extra/mcts_search.py +++ /dev/null @@ -1,176 +0,0 @@ -from __future__ import annotations -from typing import List, Optional, Dict, cast -import numpy as np -np.set_printoptions(suppress=True) -import math, functools, time, random, statistics -from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling -from tinygrad.codegen.opt.kernel import Kernel -from tinygrad.device import Buffer, Device, CompileError -from tinygrad.codegen.opt.search import _ensure_buffer_alloc, get_kernel_actions, _time_program -from tinygrad.engine.realize import get_program - -class MCTSNode: - def __init__(self, kernel:Kernel, parent=None): - self.kernel:Kernel = kernel - self.t = math.inf - self.n = 0 - self.tm = math.inf - self.i = -1 - self.parents: List[MCTSNode] = [parent] if parent is not None else [] - self.children: Optional[List[MCTSNode]] = None - self.removed_children: List[MCTSNode] = [] - -def expand_node(node:MCTSNode): - assert node.children is None - node.children = [MCTSNode(x, node) for x in get_kernel_actions(node.kernel, include_0=False).values()] - -def remove_node(node:MCTSNode): - for parent in node.parents: - assert parent.children is not None - parent.children.remove(node) - parent.removed_children.append(node) - -C = math.sqrt(2) -TEMP = 0.5 -def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode: - if node.children is None or len(node.children) == 0: return node - unexplored_children = [] - explored_children = [] - ucb_explored_children: List[float] = [] - for child in node.children: - if child.n == 0: unexplored_children.append(child) - else: - ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n) - if not math.isinf(ucb): - explored_children.append(child) - ucb_explored_children.append(ucb) - if len(unexplored_children): return random.choice(unexplored_children) - if not len(explored_children): return node - # safe softmax - ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP) - return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm) - -# this will expand/remove sometimes -def sample_tree(root:MCTSNode, best_tm:float) -> Optional[MCTSNode]: - if root.children is None: expand_node(root) - while root.children: - # tree traversal - node = _sample_tree(root, best_tm) - - if node.children is not None and len(node.children) == 0: - remove_node(node) - continue - - # node expansion - if node.n != 0: - if node.children is None: expand_node(node) - assert node.children is not None - if len(node.children) == 0: - remove_node(node) - continue - node = random.choice(node.children) - return node - return None - -def backprop(bnode:MCTSNode, tm, strength=1.0): - if bnode.t > tm: bnode.t = tm - bnode.n += strength - for parent in bnode.parents: backprop(parent, tm, strength/len(bnode.parents)) - -graph_mcts_cnt = 0 -def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel: - global graph_mcts_cnt - # TODO: copied from BEAM - key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix} - if not getenv("IGNORE_MCTS_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None: - ret = lin.copy() - for o in val[len(lin.applied_opts):]: ret.apply_opt(o) - return ret - - rawbufs = _ensure_buffer_alloc(rawbufs) - var_vals = {k.expr:(k.vmax+k.vmin)//2 for k in lin.ast.variables()} - dev = Device[lin.opts.device] - root = MCTSNode(lin) - - st = time.perf_counter() - best, best_idx, best_tm = lin, 0, math.inf - seen_libs: Dict[bytes, MCTSNode] = {} - seen_asts: Dict[bytes, MCTSNode] = {} - compile_time, runtime_time = 0.0, 0.0 - for i in range(amt): - node = sample_tree(root, best_tm) # sample and expand - if node is None: break # finished the whole tree - node.i = i # when was node explored - - opt_ast = node.kernel.get_optimized_ast() - if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None: - # early check for same optimized AST hit - remove_node(node) - tm = sibling_node.t - else: - seen_asts[opt_ast.key] = node - - # lowering (50% of the time) - p = get_program(node.kernel.get_optimized_ast(name_override="test"), node.kernel.opts) - - # rollout - tm1 = time.perf_counter() - try: - lib = dev.compiler.compile(p.src) - except CompileError: - # NOTE: many of these "compiler errors" are caused by bad code output from the lowerer - lib = None - tm2 = time.perf_counter() - if lib is None: - tm = math.inf - else: - if (sibling_node:=seen_libs.get(lib, None)) is not None: - # NOTE: these should all be caught by the AST check, need to canonicalize - # remove this node, it's a duplicate - remove_node(node) - tm = sibling_node.t - else: - seen_libs[lib] = node - try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6 - except RuntimeError: tm = math.inf - node.tm = tm - tm3 = time.perf_counter() - compile_time += tm2-tm1 - runtime_time += tm3-tm2 - - # mock rollout - #node.tm = tm = random.random() + 0.1 - - if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm - et = time.perf_counter() - st - if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {int(round((i+1)/et)):4d}/s {node.kernel.colored_shape()}\033[K", end="") # noqa: E501 - - # backprop - backprop(node, tm) - if DEBUG>=2: print() - - if getenv("MCTSGRAPH"): - import networkx as nx - import os - GRAPHPATH = "/tmp/net" - def save_graph(G, fn, opt=""): - print("saving", G, f"to {fn}.svg") - nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot') - os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg') - - G = nx.DiGraph() - def add_node(node:MCTSNode): - if node.n == 0: return - for parent in node.parents: G.add_edge(parent, node) - gopts = node.kernel.applied_opts - edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}" if len(gopts) else "ROOT" - G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}", - fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '') - if node.children is not None: - for child in node.children+node.removed_children: add_node(child) - add_node(root) - save_graph(G, f"{GRAPHPATH}.{graph_mcts_cnt}.mcts", '-Grankdir=LR') - graph_mcts_cnt += 1 - - if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts) - return best diff --git a/extra/replay_pkl.py b/extra/replay_pkl.py deleted file mode 100644 index e4cb5ed543..0000000000 --- a/extra/replay_pkl.py +++ /dev/null @@ -1,75 +0,0 @@ -import pickle, sys -from dataclasses import replace -from tinygrad import Device, Context, Tensor, GlobalCounters -from tinygrad.device import Buffer -from tinygrad.helpers import getenv, BEAM -from tinygrad.engine.jit import TinyJit -from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item, get_program -from tinygrad.renderer import ProgramSpec -from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps -from tinygrad.codegen.opt.heuristic import hand_coded_optimizations -import numpy as np - -def move_jit_captured_to_dev(captured, device="DSP"): - captured.expected_st_vars_dtype_device = [x[:3] + (device,) for x in captured.expected_st_vars_dtype_device] - - assign = {} - def move_buffer(b): - if b in assign: return assign[b] - - if b._base is not None: - newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset) - else: - newbuf = Buffer(device, b.size, b.dtype) - if b.is_allocated(): newbuf.ensure_allocated().copyin(b.as_buffer()) - assign[b] = newbuf - return assign[b] - - for item in captured.jit_cache: - for b in item.bufs: - if b is not None: move_buffer(b) - captured.jit_cache = [ExecItem(item.prg, [assign.get(b,b) for b in item.bufs]) for item in captured.jit_cache] - return captured - -if __name__ == "__main__": - with Context(DEBUG=0): - with open(sys.argv[1], "rb") as f: - fxn: TinyJit = pickle.load(f) - print(f"{f.tell()/1e6:.2f}M loaded") - print(type(fxn)) - - # Move all buffers to DSP device. - fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP") - new_jit = [] - - knum = 1 - for ei in fxn.captured.jit_cache: - # skip the copy and the first kernel - if isinstance(ei.prg, CompiledRunner) and all(x is not None for x in ei.bufs): - if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0: - p: ProgramSpec = ei.prg.p - k = Kernel(p.ast, Device["DSP"].renderer) - - if getenv("VALIDATE"): - with Context(NOOPT=1): - lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run() - correct = ei.bufs[0].numpy() - ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes))) - GlobalCounters.kernel_count -= 1 - - if not getenv("NOOPT"): k.apply_opts(hand_coded_optimizations(k)) - p2 = get_program(k.ast, k.opts, k.applied_opts) - new_ei = replace(ei, prg=CompiledRunner(p2)) - new_ei.run() - new_jit.append(new_ei) - test = ei.bufs[0].numpy() - - if getenv("VALIDATE"): - import numpy as np - np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3) - knum += 1 - - if getenv("RUN_JIT", 0): - fxn.captured.free_intermediates() - fxn.captured.jit_cache = new_jit - fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP")) diff --git a/extra/resnet18/resnet_mlx.py b/extra/resnet18/resnet_mlx.py deleted file mode 100644 index b59477cd69..0000000000 --- a/extra/resnet18/resnet_mlx.py +++ /dev/null @@ -1,114 +0,0 @@ -# code from https://x.com/awnihannun/status/1832511021602500796 -from huggingface_hub import snapshot_download -import mlx.core as mx -import mlx.nn as nn -import time - - -class Block(nn.Module): - def __init__(self, in_dims, dims, stride=1): - super().__init__() - - self.conv1 = nn.Conv2d( - in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False - ) - self.bn1 = nn.BatchNorm(dims) - - self.conv2 = nn.Conv2d( - dims, dims, kernel_size=3, stride=1, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm(dims) - - self.downsample = [] - if stride != 1: - self.downsample = [ - nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm(dims) - ] - - def __call__(self, x): - out = nn.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - for l in self.downsample: - x = l(x) - out += x - out = nn.relu(out) - return out - - -class ResNet(nn.Module): - def __init__(self, block, num_blocks, num_classes=10): - super().__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = nn.BatchNorm(64) - - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2) - - self.fc = nn.Linear(512, num_classes) - - def _make_layer(self, block, in_dims, dims, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(block(in_dims, dims, stride)) - in_dims = dims - return layers - - def __call__(self, x): - x = nn.relu(self.bn1(self.conv1(x))) - x = self.maxpool(x) - for l in self.layer1 + self.layer2 + self.layer3 + self.layer4: - x = l(x) - x = mx.mean(x, axis=[1, 2]) - x = self.fc(x) - return x - - - -def load(): - model = ResNet(Block, [2, 2, 2, 2], num_classes=1000) - file = "model.safetensors" - model_path = snapshot_download( - repo_id="awni/resnet18-mlx", - allow_patterns=[file], - ) - model.load_weights(model_path + "/" + file) - model.eval() - mx.eval(model) - return model - -if __name__ == "__main__": - - resnet18 = load() - - @mx.compile - def forward(im): - return resnet18(im) - - batch_sizes = [1, 2, 4, 8, 16, 32, 64] - #its = 200 - #batch_sizes = [64] - its = 20 - print(f"Batch Size | Images-per-second | Milliseconds-per-image") - print(f"---- | ---- | ---- ") - for N in batch_sizes: - image = mx.random.uniform(shape=(N, 288, 288, 3)) - - # Warmup - for _ in range(5): - output = forward(image) - mx.eval(output) - - tic = time.time() - for _ in range(its): - output = forward(image) - mx.async_eval(output) - mx.eval(output) - toc = time.time() - ims_per_sec = N * its / (toc - tic) - ms_per_im = 1e3 / ims_per_sec - print(f"{N} | {ims_per_sec:.3f} | {ms_per_im:.3f}") \ No newline at end of file diff --git a/extra/resnet18/resnet_tinygrad.py b/extra/resnet18/resnet_tinygrad.py deleted file mode 100644 index a34a27b2bc..0000000000 --- a/extra/resnet18/resnet_tinygrad.py +++ /dev/null @@ -1,109 +0,0 @@ -from huggingface_hub import snapshot_download -from tinygrad import nn, Tensor, TinyJit, Device, GlobalCounters, Context -import time - -class Block: - def __init__(self, in_dims, dims, stride=1): - super().__init__() - - self.conv1 = nn.Conv2d( - in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False - ) - self.bn1 = nn.BatchNorm(dims) - - self.conv2 = nn.Conv2d( - dims, dims, kernel_size=3, stride=1, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm(dims) - - self.downsample = [] - if stride != 1: - self.downsample = [ - nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm(dims) - ] - - def __call__(self, x): - out = self.bn1(self.conv1(x)).relu() - out = self.bn2(self.conv2(out)) - for l in self.downsample: - x = l(x) - out += x - return out.relu() - - -class ResNet: - def __init__(self, block, num_blocks, num_classes=10): - super().__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = nn.BatchNorm(64) - - self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2) - - self.fc = nn.Linear(512, num_classes) - - def _make_layer(self, block, in_dims, dims, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(block(in_dims, dims, stride)) - in_dims = dims - return layers - - def __call__(self, x:Tensor): - x = self.bn1(self.conv1(x)).relu().max_pool2d() - x = x.sequential(self.layer1) - with Context(WINO=1): x = x.sequential(self.layer2 + self.layer3 + self.layer4) - x = x.mean([2, 3]) - x = self.fc(x) - return x - - - -def load(): - model = ResNet(Block, [2, 2, 2, 2], num_classes=1000) - file = "model.safetensors" - model_path = snapshot_download( - repo_id="awni/resnet18-mlx", - allow_patterns=[file], - ) - state = nn.state.safe_load(model_path + "/" + file) - # mlx is NHWC, tinygrad is NCHW - nn.state.load_state_dict(model, {k:v if len(v.shape) != 4 else v.to(None).permute(0,3,1,2).contiguous() for k,v in state.items()}, strict=False) - return model - -if __name__ == "__main__": - - resnet18 = load() - - def _forward(im): return resnet18(im) - forward = TinyJit(_forward, prune=True) - - batch_sizes = [1, 2, 4, 8, 16, 32, 64] - #its = 200 - #batch_sizes = [64] - its = 20 - print(f"Batch Size | Images-per-second | Milliseconds-per-image") - print(f"---- | ---- | ---- ") - for N in batch_sizes: - forward.reset() # reset the JIT for a new batch size (could make automatic) - image = Tensor.uniform(N, 3, 288, 288) - - # Warmup - for _ in range(5): - GlobalCounters.reset() - output = forward(image) - Device.default.synchronize() - - tic = time.time() - for _ in range(its): - GlobalCounters.reset() - output = forward(image) - Device.default.synchronize() - toc = time.time() - ims_per_sec = N * its / (toc - tic) - ms_per_im = 1e3 / ims_per_sec - print(f"{N} | {ims_per_sec:.3f} | {ms_per_im:.3f}") \ No newline at end of file diff --git a/extra/ring_copy.py b/extra/ring_copy.py deleted file mode 100644 index 1e3863b89a..0000000000 --- a/extra/ring_copy.py +++ /dev/null @@ -1,15 +0,0 @@ -from tinygrad import Tensor, Device, GlobalCounters -from tinygrad.helpers import Timing - -N = 512 -GPUS = 5 -ds = tuple([f"{Device.DEFAULT}:{i+1}" for i in range(GPUS)]) -t = [Tensor.ones(N, N, N, device=d).contiguous().realize() for d in ds] - -for _ in range(10): - GlobalCounters.reset() - with Timing(): - for ti in t: - ti.to_(ds[(ds.index(ti.device)+1+len(ds))%len(ds)]) - # ti.to_(ds[(ds.index(ti.device)-1+len(ds))%len(ds)]) # reversed order - ti.realize() \ No newline at end of file diff --git a/extra/self_tokenize.py b/extra/self_tokenize.py deleted file mode 100644 index a90e8ca3e5..0000000000 --- a/extra/self_tokenize.py +++ /dev/null @@ -1,47 +0,0 @@ -import os, pathlib, argparse -from examples.llama3 import Tokenizer -from tabulate import tabulate -from tinygrad import fetch -from tinygrad.helpers import flatten, getenv -from sz import NONCORE_DIRS - -# llama 3 tokenizer -tokenizer = Tokenizer(fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model").as_posix()) - -def read_code(base_path, full=False): - ret = [] - for path, _, files in os.walk(os.path.join(base_path, "tinygrad")): - if not full and any(path.split("./")[1].startswith(x) for x in NONCORE_DIRS): continue - for name in files: - if not name.endswith(".py"): continue - if 'tinygrad/runtime/autogen' in path.replace('\\', '/'): continue - fullpath = os.path.join(path, name) - code = pathlib.Path(fullpath).read_text() - ret.append((fullpath.split("tinygrad/", 1)[1], code)) - return ret - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Analyze and optionally save tinygrad code.") - parser.add_argument("--output", help="Output file to write the combined code to.") - parser.add_argument("--full", action="store_true", help="All directories") - args = parser.parse_args() - - ret = read_code(".", args.full) - - table = [] - for name,code in ret: - table.append([name, len(tokenizer.encode(code))]) - print(tabulate([["name", "llm tokens"]]+sorted(table, key=lambda x: -x[1]), headers="firstrow")) - - banner = "#"*40 - code_str = ''.join([f"{banner}\n# {name}\n{banner}\n\n{code}\n" for name,code in ret]) - print(f"code has {len(code_str)} chars") - newline_count = code_str.count('\n') - print(f"code has {newline_count} newlines") - - encoded = tokenizer.encode(code_str) - print(f"code has {len(encoded)} tokens") - - if args.output: - with open(args.output, 'w') as f: f.write(code_str) - print(f"Combined code written to {args.output}") diff --git a/extra/threefry.py b/extra/threefry.py deleted file mode 100644 index 6de61734f2..0000000000 --- a/extra/threefry.py +++ /dev/null @@ -1,16 +0,0 @@ -if __name__ == "__main__": - import os - if "DEBUG" not in os.environ: os.environ["DEBUG"] = "2" - - from tinygrad import Tensor, GlobalCounters - from tinygrad.helpers import getenv - - if (seed := getenv("SEED", 0)) != 0: - Tensor.manual_seed(seed) - print(f"using seed {Tensor._seed}") - - for N in [10_000_000, 100_000_000, 1_000_000_000]: - GlobalCounters.reset() - t = Tensor.rand(N) - t.realize() - print(f"N {N:>20_}, global_ops {GlobalCounters.global_ops:>20_}, global_mem {GlobalCounters.global_mem:>20_}") diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py deleted file mode 100644 index 3170cd8c61..0000000000 --- a/extra/to_movement_ops.py +++ /dev/null @@ -1,154 +0,0 @@ -import itertools -from enum import Enum, auto -from collections import defaultdict -from typing import List, Tuple, DefaultDict -from tinygrad.helpers import prod, tqdm -from tinygrad.uop.ops import UOp, Ops -from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.uop.ops import sym_infer -from tinygrad.tensor import Tensor - -class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702 - -def apply_mop(st: Tensor|ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker: - mop, arg = mop_arg - if mop == MovementOps.RESHAPE: - # shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE - if arg == (-1,): return st.reshape((prod(st.shape),)) - return st.reshape(arg) - if mop == MovementOps.PERMUTE: return st.permute(arg) - if mop == MovementOps.EXPAND: - if len(arg) != len(st.shape): st = st.reshape((1,*st.shape)) - return st.expand(arg) - if mop == MovementOps.PAD: return st.pad(arg) - if mop == MovementOps.SHRINK: return st.shrink(arg) - if mop == MovementOps.STRIDE: - assert all(x in [-1, 1] for x in arg) - return st.flip(tuple(i for i,x in enumerate(arg) if x == -1)) - raise ValueError("invalid mop") - -def make_scratch_st(st: ShapeTracker) -> ShapeTracker: - return ShapeTracker.from_shape((get_buffer_size(st.views[0].shape, st.views[0].strides, st.views[0].offset, st.views[0].mask),)) - -# ShapeTracker to an equivalent series of MovementOps (https://github.com/tinygrad/tinygrad/pull/2216) -def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]: - to_apply:List[Tuple[MovementOps, Tuple]] = [] - for i, v in enumerate(st.views): - real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape - offset = (v.offset or 0) + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0) - real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0) - real_real_shape = [s for s,st in zip(real_shape, v.strides) if st] - strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st] - buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1 - if i: buffer_size = prod(st.views[i-1].shape) - real_offset if real_shape else 1 - def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True) - ordered_shape_strides, order = sort_by_strides(real_real_shape, strides) - to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))]) - if strides: - if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),))) - for i, shape_stride in enumerate(ordered_shape_strides): - if i0 else buffer_size - to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer))) - to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1))) - to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer))) - to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1])))) - to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1]))) - ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1]) - else: - to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1])))) - to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1]))) - to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))]) - if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides))))) - to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides)))) - if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides))) - # then, we apply pre expand pads - if v.mask is not None: - pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) - post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) - if any(x != (0,0) for x in pre_expand_pads): - to_apply.append((MovementOps.PAD, pre_expand_pads)) - real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads)) - # then, we do any expands - if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape)) - # lastly, we apply post expand pads - if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads)) - - scratch_st = make_scratch_st(st) - ret = [] - seen = {} # {shapetracker: list of mops to generate that shapetracker} - for mop_arg in to_apply: - scratch_st = apply_mop(scratch_st, mop_arg) - if scratch_st in seen: - ret = seen[scratch_st][:] - else: - if len(ret) and ret[-1][0] == MovementOps.RESHAPE and mop_arg[0] == MovementOps.RESHAPE: - ret[-1] = mop_arg - else: - if mop_arg == (MovementOps.RESHAPE, -1): mop_arg = (MovementOps.RESHAPE, (prod(st.shape),)) - ret.append(mop_arg) - seen[scratch_st] = ret[:] - return ret - -def get_real_view(shape, strides, offset, mask): - real_shape = tuple(y-x for x,y in mask) if mask else shape - offset = offset + sum(st * (s-1) for s,st in zip(real_shape, strides) if st<0) - real_offset = offset + (sum(x*st for (x,_),st in zip(mask, strides)) if mask else 0) - real_real_shape = [s for s,st in zip(real_shape, strides) if st] - strides = [abs(st) if isinstance(st,int) else st for st in strides if st] - return real_real_shape, strides, real_offset - -def get_buffer_size(shape, strides, offset, mask): - real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask) - return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1 - -def st_equivalent(st1: ShapeTracker, st2: ShapeTracker): - if (idxs1:=st1.expr_idxs()) == (idxs2:=st2.expr_idxs()): return True - idx1, valid1 = idxs1 - idx2, valid2 = idxs2 - # always invalid - if valid1 == 0 and valid2 == 0: return True - - var1 = idx1.vars() | valid1.vars() - var2 = idx2.vars() | valid2.vars() - # Maybe there are cases that vars are different yet the sts are the same? - if var1 != var2: return False - - # brute force over the vars range - vs = list(var1) - for i, ranges in enumerate(itertools.product(*[range(v.min, v.max+1) for v in vs])): - if i > 1000: - print("WARNING: did not search all possible combinations") - break - var_vals = {k.expr:v for k,v in zip(vs, ranges)} - r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0 - r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0 - if r1 != r2: return False - - return True - -c: DefaultDict[int,int] = defaultdict(int) -def test_rebuild(st: ShapeTracker): - rebuilt_st = make_scratch_st(st) - mops = to_movement_ops(st) - c[len(mops)] += 1 - for mop_arg in mops: rebuilt_st = apply_mop(rebuilt_st, mop_arg) - rebuilt_st = rebuilt_st.simplify() - # why is the "all(x == 0 for x in rebuilt_st.views[-1].strides)" hack needed? - assert st_equivalent(st, rebuilt_st) or all(x == 0 for x in rebuilt_st.views[-1].strides), f"mismatch {st} {rebuilt_st}" - last_v1 = st.views[-1] - last_v2 = rebuilt_st.views[-1] - assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}" - -def test_rebuild_bufferop_st(ast:UOp): - if ast.op is Ops.SHAPETRACKER: - test_rebuild(ast.arg) - for src in ast.src: test_rebuild_bufferop_st(src) - -if __name__ == "__main__": - from extra.optimization.helpers import load_worlds, ast_str_to_ast - ast_strs = load_worlds(False, False, True)[:2000] - for ast_str in tqdm(ast_strs): - test_rebuild_bufferop_st(ast_str_to_ast(ast_str)) - - print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}") diff --git a/extra/transfer_speed.py b/extra/transfer_speed.py deleted file mode 100644 index 9abe475a48..0000000000 --- a/extra/transfer_speed.py +++ /dev/null @@ -1,18 +0,0 @@ -from tinygrad import Tensor, Device - -#N = 1024 -N = 32 -t = Tensor.rand(N, N, N, device="CPU").realize() -d1 = Device.DEFAULT + ":1" -d2 = Device.DEFAULT + ":2" -d3 = Device.DEFAULT + ":3" - -for i in range(3): - t.to_(d1) - t.realize() - # t.to_("CPU") - # t.realize() - t.to_(d2) - t.realize() - t.to_(d3) - t.realize()