Merge remote-tracking branch 'upstream/master' into triton

Szymon Ożóg
2023-08-18 16:12:45 +02:00
36 changed files with 954 additions and 257 deletions

View File

@@ -40,6 +40,11 @@ def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
xk_out = complex_mult(xk, c, d)
return xq_out.flatten(3), xk_out.flatten(3)
def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
bs, seqlen, n_kv_heads, head_dim = x.shape
if n_rep == 1: return x
return x[:, :, :, None, :].expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
class RMSNorm:
def __init__(self, dim, eps=1e-6):
self.eps = eps
@@ -50,14 +55,22 @@ class RMSNorm:
return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
class Attention:
def __init__(self, dim, n_heads, linear=Linear):
self.wq, self.wk, self.wv, self.wo = [linear(dim, dim, bias=False) for _ in range(4)]
def __init__(self, dim, n_heads, n_kv_heads, linear=Linear):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
self.head_dim = dim // n_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
def prepare_attention(self, x:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)]
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
return xq, xk, xv
@@ -74,6 +87,8 @@ class Attention:
# save the cache
self.cache_k, self.cache_v = keys.realize(), values.realize()
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
return Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
# NOTE: this is not called
@@ -98,8 +113,8 @@ class FeedForward:
return self.w2(self.w1(x).silu() * self.w3(x))
class TransformerBlock:
def __init__(self, dim, multiple_of, n_heads, norm_eps, linear=Linear, ffn_dim_multiplier=None):
self.attention = Attention(dim, n_heads, linear)
def __init__(self, dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear=Linear, ffn_dim_multiplier=None):
self.attention = Attention(dim, n_heads, n_kv_heads, linear)
self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)
@@ -125,8 +140,8 @@ class TransformerBlock:
return self._post(x, output) if mask is None else self.post(x, output)
class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
@@ -173,11 +188,10 @@ MODEL_PARAMS = {
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
"files": 2,
},
# # 70B is disabled because we do not yet implement n_kv_heads argument
# "70B": {
# "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
# "files": 8,
# },
"70B": {
"args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
"files": 8,
},
},
}
@@ -277,7 +291,7 @@ if __name__ == "__main__":
parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2")
parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B, 70B] for Gen 2")
parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory")
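Note on the llama.py changes above: the n_kv_heads plumbing is grouped-query attention. The newly enabled 70B config runs 64 query heads against only 8 key/value heads, so repeat_kv expands each KV head n_rep = 64/8 = 8 times before the attention matmul. A minimal numpy sketch of the same expand-then-reshape trick (illustrative shapes, not the real model sizes):

import numpy as np
bs, seqlen, n_kv_heads, head_dim, n_rep = 1, 2, 2, 4, 3   # 6 query heads sharing 2 KV heads
x = np.arange(bs * seqlen * n_kv_heads * head_dim).reshape(bs, seqlen, n_kv_heads, head_dim)
# insert a rep axis, broadcast across it, then fold it into the head axis (same as repeat_kv)
y = np.broadcast_to(x[:, :, :, None, :], (bs, seqlen, n_kv_heads, n_rep, head_dim)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
assert y.shape == (1, 2, 6, 4)
assert (y[0, 0, 0] == y[0, 0, 2]).all() and (y[0, 0, 3] == y[0, 0, 5]).all()  # heads 0-2 copy KV head 0, heads 3-5 copy KV head 1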

View File

@@ -1,61 +0,0 @@
import struct
from tinygrad.codegen.assembly import AssemblyCodegen
from tinygrad.codegen.linearizer import UOps
from tinygrad.helpers import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch
dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "u16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXCodegen(AssemblyCodegen):
#supports_constant_folding: bool = True
def specialize(self, asm):
ins = [".version 8.2", ".target " + arch(), ".address_size 64",
f".visible .entry test({', '.join(f'.param .u64 buf{i}' for i in range(len(self.bufs)))}) {{"]
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", BinaryOps.CMPEQ: "setp.eq", UnaryOps.SQRT: "sqrt.approx",
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
TernaryOps.MULACC: "fma.rn"}
for uop, out, vin, arg in asm:
if uop == UOps.DEFINE_REGISTER:
ins.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
elif uop == UOps.DEFINE_LOCAL:
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
elif uop == UOps.SPECIAL:
if arg.startswith('buf'):
ins.append(f"ld.param.u64 {out}, [{arg}];")
# TODO: is this needed?
#ins.append(f"cvta.to.global.u64 {out}, {out};")
elif arg.startswith('gid'):
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
elif arg.startswith('lid'):
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
elif uop == UOps.ALU:
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
else:
otype = vin[0].dtype if arg in [BinaryOps.CMPEQ, BinaryOps.CMPLT] else out.dtype
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
elif uop == UOps.LOAD:
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[out.dtype]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
elif uop == UOps.STORE:
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[vin[1].dtype]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
elif uop == UOps.CAST:
if vin[0].dtype == dtypes.bool:
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, 0f3F800000, 0f00000000, {vin[0]};")
else:
ins.append(f"cvt.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[vin[0].dtype]} {out}, {vin[0]};")
elif uop == UOps.CONST:
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else arg};")
elif uop == UOps.LABEL:
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
ins += ["ret;", "}"]
return "test", '\n'.join(ins)

View File

@@ -21,7 +21,9 @@ def safe_numpy(t) -> np.ndarray:
if t not in numpy_cache:
if DEBUG >= 1:
print("numpy cache miss", t)
numpy_cache[t] = t.numpy()
tmp = t.numpy()
numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1)
assert len(numpy_cache[t].shape) > 0
return numpy_cache[t]
onnx_ops = importlib.import_module('extra.onnx_ops')
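The safe_numpy change above guards against 0-d arrays: a scalar ONNX tensor comes back as a shape-() numpy array, which breaks callers that index [0]. Reshaping to (1,) keeps every cached value at least 1-d:

import numpy as np
t = np.array(3.0, dtype=np.float32)     # 0-d array, shape ()
assert t.shape == () and len(t.shape) == 0
assert t.reshape(1).shape == (1,)       # what the cache now stores
assert t.reshape(1)[0] == 3.0           # so callers can safely index [0]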
@@ -92,7 +94,7 @@ def get_run_onnx(onnx_model: ModelProto):
if inp.name in tensors: continue
tmp=inp.type.optional_type.elem_type.tensor_type if inp.type.HasField("optional_type") else (inp.type.sequence_type.elem_type.tensor_type if inp.type.HasField("sequence_type") else inp.type.tensor_type)
shape = shape_to_tuple(tmp.shape)
if len(shape) >= 1 and shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size
if len(shape) >= 1: shape = tuple([x if x != 0 else 1 for x in shape]) # replace all dynamic dims with 1 for now
if inp.name in inputs:
if isinstance(inputs[inp.name], Tensor):
input_tensors[inp.name] = inputs[inp.name]
@@ -183,6 +185,8 @@ def get_run_onnx(onnx_model: ModelProto):
steps = safe_numpy(inp[4])[0] if len(inp) > 4 else 1
starts, ends = safe_numpy(starts.cast(dtypes.int32)).tolist(), safe_numpy(ends.cast(dtypes.int32)).tolist() # TODO: when indexing is added use that
for i,axis in enumerate(axes.tolist()):
assert axis % 1 == 0
axis = int(axis)
arg[axis] = (starts[i] if starts[i] >= 0 else inp[0].shape[axis]+starts[i], ends[i] if ends[i] >= 0 else inp[0].shape[axis]+ends[i])
ret = inp[0].slice(arg=arg)
elif n.op_type == "Shrink":
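Aside on the Slice hunk above: negative starts/ends are normalized against the axis length before slicing. A standalone sketch of that normalization:

# same start/end normalization as the Slice hunk, as a pure function
def normalize(start, end, dim_size):
    return (start if start >= 0 else dim_size + start,
            end if end >= 0 else dim_size + end)
assert normalize(1, 4, 6) == (1, 4)
assert normalize(-3, -1, 6) == (3, 5)   # negative indices count back from the end of the axis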

View File

@@ -1,9 +1,10 @@
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
import numpy as np
import functools
from typing import Union, Tuple
from typing import Union, Tuple, Optional
import math
def Unsqueeze(data, axes):
@@ -201,7 +202,7 @@ def Tile(input, repeats):
final_shape = [r*s for r,s in zip(repeats_, input.shape)]
return input.reshape(new_shape).expand(expand_shape).reshape(final_shape)
def Range(start, limit, delta): return Tensor.arange(safe_numpy(start)[0], safe_numpy(limit)[0], step=safe_numpy(delta)[0])
def Range(start, limit, delta): return Tensor.arange(*[safe_numpy(x)[0].item() for x in (start, limit, delta)])
def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype)
def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
@@ -218,10 +219,7 @@ def ConstantOfShape(input, value:Tensor=None):
shape = [int(x) for x in safe_numpy(input)]
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
# this is obviously wrong, but since we don't have types, it's better than nothing
def Cast(input, to):
print(f"WARNING: attempting to cast to {to}")
return input
def Cast(input, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))
# NOTE: since we only have one type, this is valid!
def CastLike(input, target_type):
@@ -263,3 +261,70 @@ def OneHot(indices, depth, values, axis=-1):
def Floor(x:Tensor): return x.floor()
def Ceil(x:Tensor): return x.ceil()
def EmbedLayerNormalization(input_ids, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
assert (segment_ids is None) is (segment_embedding is None)
assert (mask is None) is (mask_index_type is None)
assert mask is None, "functionality not supported yet" # TODO
input_shape = input_ids.shape
bsz, seq_length = input_shape[0], input_shape[1]
compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
# bert embedding layer
if epsilon is None: epsilon = 1e-12
if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
assert num_heads is not None # required
assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params
hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
if unidirectional: # gpt-style
assert hidden_size == v_hidden_size
xqkv = input.linear(weights, bias)
xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
else: # bert-style
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else (None, None, None)
xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
if past is not None:
xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
def attn(query, key, value, attn_mask):
query_length, key_length = query.shape[-2], key.shape[-2]
cdim = max(query_length, key_length) + 1
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
# This is where Tensor.scaled_dot_product_attention differs:
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value
bsz, _, seq_len, _ = xq.shape
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present
def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
if epsilon is None: epsilon=1e-12
x = input + skip + bias if bias is not None else input + skip
return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
x = x + bias if bias is not None else x
return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
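Two details in this hunk reward a closer look. The embedding helper inside EmbedLayerNormalization expresses a row gather as a one-hot matmul, which lowers to ops tinygrad already has; and FastGelu's 0.797885/0.035677 are the tanh-GELU approximation constants sqrt(2/pi) and 0.044715*sqrt(2/pi). A quick check of both (numpy stands in for Tensor):

import math
import numpy as np
# one-hot matmul == row gather
vocab, dim = 5, 3
weight = np.arange(vocab * dim, dtype=np.float32).reshape(vocab, dim)
ids = np.array([[1, 4]])                                            # (bsz, seq)
onehot = (np.arange(vocab) == ids[..., None]).astype(np.float32)    # (bsz, seq, vocab)
np.testing.assert_allclose(onehot @ weight, weight[ids])
# FastGelu: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3)))
assert abs(math.sqrt(2 / math.pi) - 0.797885) < 1e-6
assert abs(0.044715 * math.sqrt(2 / math.pi) - 0.035677) < 1e-6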

View File

@@ -1,13 +1,14 @@
import csv
import pathlib
import time
import onnx
import csv, pathlib, time, numpy as np
from os import getenv
import torch
torch.set_num_threads(1)
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
import onnxruntime as ort
from onnx2torch import convert
from extra.utils import download_file
from extra.onnx import get_run_onnx
from tinygrad.helpers import OSX
from tinygrad.helpers import OSX, DEBUG
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
@@ -16,6 +17,7 @@ MODELS = {
"openpilot": "https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx",
"efficientnet": "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx",
"shufflenet": "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx",
"commavq": "https://github.com/commaai/commavq/raw/master/models/gpt2m.onnx",
# broken in torch MPS
#"zfnet": "https://github.com/onnx/models/raw/main/vision/classification/zfnet-512/model/zfnet512-9.onnx",
@@ -29,6 +31,7 @@ MODELS = {
CSV = {}
open_csv = None
torch.manual_seed(1)
def benchmark(mnm, nm, fxn):
tms = []
@@ -36,26 +39,31 @@ def benchmark(mnm, nm, fxn):
st = time.perf_counter_ns()
ret = fxn()
tms.append(time.perf_counter_ns() - st)
print(f"{m:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
print(f"{mnm:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
CSV[nm] = min(tms)*1e-6
return min(tms), ret
#BASE = pathlib.Path(__file__).parent.parent.parent / "weights" / "onnx"
BASE = pathlib.Path("/tmp/onnx")
def benchmark_model(m):
def benchmark_model(m, validate_outs=False):
global open_csv, CSV
CSV = {"model": m}
fn = BASE / MODELS[m].split("/")[-1]
download_file(MODELS[m], fn)
onnx_model = onnx.load(fn)
output_names = [out.name for out in onnx_model.graph.output]
excluded = {inp.name for inp in onnx_model.graph.initializer}
input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded}
np_inputs = {k:torch.randn(shp).numpy() for k,shp in input_shapes.items()}
assert len(input_shapes) < 20
input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded}
#input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast
np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()}
assert len(input_shapes) < 30, f"too many input shapes {len(input_shapes)}"
for device in ["METAL" if OSX else "GPU", "CLANG"]:
# print input names
if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
for device in ["METAL" if OSX else "GPU", "CLANG"]: # + (["CUDA"] if torch.cuda.is_available() else []):
Device.DEFAULT = device
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
@@ -67,19 +75,55 @@ def benchmark_model(m):
benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()})
del inputs, tinygrad_model, tinygrad_jitted_model
torch_model = convert(onnx_model)
torch_inputs = [torch.tensor(x) for x in np_inputs.values()]
benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs))
try:
torch_model = convert(onnx_model)
torch_inputs = [torch.tensor(x) for x in np_inputs.values()]
benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs))
torch_device = "mps" if OSX else "cuda"
torch_mps_model = torch_model.to(torch_device)
torch_mps_inputs = [x.to(torch_device) for x in torch_inputs]
benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs))
torch_device = "mps" if OSX else "cuda"
torch_mps_model = torch_model.to(torch_device)
torch_mps_inputs = [x.to(torch_device) for x in torch_inputs]
benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs))
except NotImplementedError:
print(f"{m:16s}onnx2torch doesn't support this model")
# bench onnxruntime
ort_options = ort.SessionOptions()
ort_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
ort_options.log_severity_level = 3 # no warnings
for backend in ["CPU", "CUDA" if not OSX else "CoreML"]: # https://onnxruntime.ai/docs/execution-providers/
provider = backend+"ExecutionProvider"
if provider not in ort.get_available_providers(): continue
ort_sess = ort.InferenceSession(str(fn), ort_options, [provider])
benchmark(m, f"onnxruntime_{backend.lower()}", lambda: ort_sess.run(output_names, np_inputs))
del ort_sess
if validate_outs:
rtol, atol = 8e-4, 8e-4 # tolerance for fp16 models
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
tinygrad_out = tinygrad_model(inputs)
ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"])
onnx_out = ort_sess.run(output_names, np_inputs)
onnx_out = dict(zip(output_names, onnx_out))
assert_allclose(tinygrad_out, onnx_out, rtol=rtol, atol=atol)
print(f"{m:16s}outputs validated with rtol={rtol:.1e}, atol={atol:.1e}")
if open_csv is None:
open_csv = csv.DictWriter(open('onnx_inference_speed.csv', 'w', newline=''), fieldnames=list(CSV.keys()))
open_csv.writeheader()
open_csv.writerow(CSV)
def assert_allclose(tiny_out:dict, onnx_out:dict, rtol=1e-5, atol=1e-5):
assert len(tiny_out) == len(onnx_out) and tiny_out.keys() == onnx_out.keys()
for k in tiny_out.keys():
tiny_v, onnx_v = tiny_out[k], onnx_out[k]
if tiny_v is None: assert tiny_v == onnx_v
else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tiny_out.keys()}")
if __name__ == "__main__":
for m in MODELS: benchmark_model(m)
if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), True)
else:
for m in MODELS: benchmark_model(m, True)
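For readers unfamiliar with the onnxruntime calls used above, the session/provider dance looks like this (a minimal sketch; the model path and input name are placeholders):

import numpy as np
import onnxruntime as ort
opts = ort.SessionOptions()
opts.log_severity_level = 3                      # silence warnings, as above
sess = ort.InferenceSession("model.onnx", opts, ["CPUExecutionProvider"])
# run(None, feeds) returns every graph output, in graph order
outs = sess.run(None, {"input": np.zeros((1, 3, 224, 224), dtype=np.float32)})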

View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python
import unittest, gc
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.state import get_parameters, get_state_dict
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.lazy import Device
from examples.llama import Transformer
ALLOCATED_DEV_BUFS = 0
class FakeDeviceBuffer():
def __init__(self, sz, dt, device):
self.id = 1
self.size = sz
self.dtype = dt
self.device = device
global ALLOCATED_DEV_BUFS
ALLOCATED_DEV_BUFS += 1
class FakeAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
def _do_free(self, buf):
buf.id -= 1
assert buf.id == 0, f"Free should be called once, but {buf.id}"
FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
def __init__(self, size, dtype, device='0'):
global FAKE_GLOBAL_ALLOCATOR
super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
class FakeProgram:
def __init__(self, name:str, prg:str): pass
def __call__(self, global_size, local_size, *bufs, wait=False): pass
def helper_test_correctness(gen, train):
from tinygrad.runtime.ops_gpu import CL, CLAllocator
old_alloc = CL.cl_allocator
CL.cl_allocator = CLAllocator(0)
no_alloc_result = train(*gen()).numpy()
Device[Device.DEFAULT].synchronize()
CL.cl_allocator = CLAllocator(512<<30) # Test cache correctness, so cache as much as possible, 512gb
for _ in range(4):
GlobalCounters.reset()
np.testing.assert_allclose(train(*gen()).numpy(), no_alloc_result, rtol=1e-3, atol=1e-5)
Device[Device.DEFAULT].synchronize()
assert len(CL.cl_allocator.cached_buffers) != 0, "Cache must be used"
CL.cl_allocator = old_alloc
def __helper_test_alloc_count(gen, train):
was_alloc = ALLOCATED_DEV_BUFS
for _ in range(2):
train(*gen())
return ALLOCATED_DEV_BUFS - was_alloc
def helper_test_alloc_count(mm, gen, train):
global FAKE_GLOBAL_ALLOCATOR
backup_program = Device[Device.DEFAULT].runtime
backup_buffer = Device[Device.DEFAULT].buffer
Device[Device.DEFAULT].runtime = FakeProgram
Device[Device.DEFAULT].buffer = FakeBuffer
Device[Device.DEFAULT].method_cache.clear()
FAKE_GLOBAL_ALLOCATOR = FakeAllocator(16<<30)
new_allocs = __helper_test_alloc_count(gen, train)
Device[Device.DEFAULT].method_cache.clear()
FAKE_GLOBAL_ALLOCATOR = FakeAllocator(0)
old_allocs = __helper_test_alloc_count(gen, train)
print(f"{mm}: llama: old allocs count {old_allocs}, new allocs count {new_allocs}")
assert new_allocs < old_allocs, "Hmm, doesn't the cache work anymore?"
Device[Device.DEFAULT].runtime = backup_program
Device[Device.DEFAULT].buffer = backup_buffer
FAKE_GLOBAL_ALLOCATOR = None
def check_gc():
if Device.DEFAULT == "GPU":
gc.collect() # Need to collect Tensors.
from extra.introspection import print_objects
assert print_objects() == 0
# for speed
def derandomize(x):
if isinstance(x, LazyOp):
if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
x.src = [derandomize(s) for s in x.src]
else:
x.op = derandomize(x.op)
return x
def derandomize_model(model):
for p in get_parameters(model):
p.lazydata = derandomize(p.lazydata)
p.realize()
class TestAllocators(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
def test_lru_allocator_tiny_llama(self):
old_type = Tensor.default_type
Tensor.default_type = dtypes.float16
args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
def __test():
model = Transformer(**args_tiny)
derandomize_model(model)
def test(t): return model(t, 0).realize()
helper_test_correctness(lambda: (Tensor([[1,]]),), test)
__test()
Tensor.default_type = old_type
check_gc()
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
def test_lru_allocator_tiny_llama_alloc_counts(self):
args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
def test_alloc_count(t):
model = Transformer(**args_tiny)
for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
return model(t, 0).realize()
helper_test_alloc_count("llama", lambda: (Tensor([[2,]]),), test_alloc_count)
check_gc()
@unittest.skip("huge for CI")
def test_stable_diffusion(self):
from examples.stable_diffusion import UNetModel
model = UNetModel()
derandomize_model(model)
def test(t, t2): return model(t, 801, t2).realize()
helper_test_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test)
if __name__ == "__main__":
unittest.main()
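derandomize above works by rewriting the lazy op graph, swapping LoadOps.RAND for LoadOps.EMPTY so test "weights" cost a buffer allocation but no RNG work. A functional sketch of the same rewrite with stand-in types (not tinygrad's LazyOp):

from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True)
class MiniOp:                       # stand-in for tinygrad's LazyOp
    op: Any
    src: Tuple["MiniOp", ...] = ()

def rewrite(x: MiniOp, old, new) -> MiniOp:
    # rebuild the tree, replacing every occurrence of `old` with `new`
    return MiniOp(new if x.op == old else x.op, tuple(rewrite(s, old, new) for s in x.src))

tree = MiniOp("ADD", (MiniOp("RAND"), MiniOp("MUL", (MiniOp("RAND"),))))
out = rewrite(tree, "RAND", "EMPTY")
assert out.src[0].op == "EMPTY" and out.src[1].src[0].op == "EMPTY"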

test/test_allocators.py (new file, 106 lines)
View File

@@ -0,0 +1,106 @@
#!/usr/bin/env python
import unittest
import numpy as np
from weakref import ref
from tinygrad.ops import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.lazy import Device
def check_gc():
if Device.DEFAULT == "GPU":
from extra.introspection import print_objects
assert print_objects() == 0
class FakeDeviceBuffer():
def __init__(self, sz, dt, device):
self.id = 1
self.size = sz
self.dtype = dt
self.device = device
def __del__(self):
assert self.id == 0, "Should have called _do_free() before"
class FakeAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
def _do_free(self, buf):
buf.id -= 1
assert buf.id == 0, f"Free should be called once, but {buf.id}"
FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
def __init__(self, size, dtype, device='0'):
global FAKE_GLOBAL_ALLOCATOR
super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
def alloc(allocator, size, dtype, **kwargs):
global FAKE_GLOBAL_ALLOCATOR
FAKE_GLOBAL_ALLOCATOR = allocator
buf = FakeBuffer(size, dtype, **kwargs)
assert buf.dtype == dtype and buf.size == size
FAKE_GLOBAL_ALLOCATOR = None
return buf
def alloc_free_trace(allocator, size, dtype, **kwargs):
buf = alloc(allocator, size, dtype, **kwargs)
return ref(buf._buf)
def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._buf
class TestAllocators(unittest.TestCase):
def test_lru_allocator_reusage(self):
def test():
lru_allocator = FakeAllocator(2048)
traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32)
assert GlobalCounters.mem_cached == 16*dtypes.float32.itemsize, "Buffer should be cached"
for _ in range(32):
def __test():
buf = alloc(lru_allocator, 16, dtypes.float32)
assert cmp_trace_and_buf(buf, traced_buf), "Buffer should be reused"
__test()
usedbuf = alloc(lru_allocator, 16, dtypes.float32)
for _ in range(32):
def __test():
buf = alloc(lru_allocator, 16, dtypes.float32)
assert usedbuf != buf, "Nobody should get the used buffer"
__test()
assert GlobalCounters.mem_used == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
test()
check_gc()
def test_lru_allocator_cache_free(self):
def test():
lru_allocator = FakeAllocator(128)
refs = []
for _ in range(32):
refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32))
for sz in range(32):
alloc_free_trace(lru_allocator, sz, dtypes.float32)
assert GlobalCounters.mem_used + GlobalCounters.mem_cached <= 128, "Should not allocate on device more than allowed (128)"
for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache"
test()
check_gc()
def test_lru_allocator_multidevice(self):
def test():
lru_allocator = FakeAllocator(256)
refs=[]
for i in range(8):
refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32, device=str(i)))
for i in range(64):
def __test():
dev = str(i % 8)
buf = alloc(lru_allocator, 16, dtypes.float32, device=dev)
assert cmp_trace_and_buf(buf, refs[i%8]), "Buffer should be reused"
__test()
for r in refs: assert r() is not None, "All refs should be cached"
test()
check_gc()
if __name__ == "__main__":
unittest.main()
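What these tests pin down is the buffer-caching contract: a free parks the buffer in a size/dtype/device-keyed cache (up to a byte budget), and a later matching alloc must hand the same buffer back. A self-contained sketch of that contract, not tinygrad's actual LRUAllocator:

class MiniLRUAllocator:
    def __init__(self, cache_limit):
        self.cache_limit, self.cached_bytes = cache_limit, 0
        self.cache = {}                                  # (size, dtype, device) -> free buffers
    def alloc(self, size, dtype, device="0"):
        bucket = self.cache.get((size, dtype, device))
        if bucket:
            self.cached_bytes -= size
            return bucket.pop()                          # cache hit: reuse, no device alloc
        return bytearray(size)                           # cache miss: stand-in device alloc
    def free(self, buf, size, dtype, device="0"):
        if self.cached_bytes + size <= self.cache_limit:
            self.cache.setdefault((size, dtype, device), []).append(buf)
            self.cached_bytes += size
        else:
            pass                                         # over budget: a real allocator frees on device here

lru = MiniLRUAllocator(cache_limit=1 << 10)
b = lru.alloc(64, "float32")
lru.free(b, 64, "float32")
assert lru.alloc(64, "float32") is b                     # the freed buffer comes back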

View File

@@ -91,6 +91,12 @@ class TestHalfDtype(unittest.TestCase):
def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16)
@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"], "float64 is not supported by some backends")
class TestDoubleDtype(unittest.TestCase):
def test_float64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.double), np.double, [1,2,3,4])
def test_casts_to_float64(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int32, dtypes.uint8], target_dtype=dtypes.float64)
def test_upcast_to_float64_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float64, target_dtype=dtypes.float64)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase):
def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
@@ -107,8 +113,10 @@ class TestInt8Dtype(unittest.TestCase):
def test_int8_upcast_int64(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
@unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently")
@unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap")
def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
@unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap")
def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
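Aside: a quick numpy check of the two's-complement wraparound the two overflow tests above expect (the CUDA/PTX skips exist because, per their skip messages, those backends saturate instead of wrapping):

import numpy as np
assert np.int8(-1).astype(np.uint8) == 255      # -1 reinterpreted as unsigned
assert np.uint8(255).astype(np.int8) == -1      # and back again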

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import Context, ContextVar, merge_dicts
VARIABLE = ContextVar("VARIABLE", 0)
@@ -106,5 +106,17 @@ with Context(VARIABLE=1):
...
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."
class TestMergeDicts(unittest.TestCase):
def test_merge_dicts(self):
a = {"a": 1, "b": 2}
b = {"a": 1, "c": 3}
c = {}
d = {"a": 2, "b": 2}
assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3}
assert merge_dicts([a, c]) == a
assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3}
with self.assertRaises(AssertionError):
merge_dicts([a, d])
if __name__ == '__main__':
unittest.main()
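One implementation consistent with these tests (values for a shared key must agree, otherwise assert), as a sketch rather than tinygrad's exact code:

def merge_dicts(ds):
    kvs = set((k, v) for d in ds for k, v in d.items())
    # any key appearing with two different values makes kvs larger than the key set
    assert len(kvs) == len(set(k for k, _ in kvs)), f"cannot merge, conflicting values: {kvs}"
    return dict(kvs)

assert merge_dicts([{"a": 1, "b": 2}, {"a": 1, "c": 3}]) == {"a": 1, "b": 2, "c": 3}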

View File

@@ -4,7 +4,7 @@ import math
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
from tinygrad.lazy import Device
if CI:
@@ -1127,21 +1127,37 @@ class TestOps(unittest.TestCase):
n = (x < 0).where(x, 1).numpy()
assert np.all(n == 1.)
def test_slice_fancy_indexing(self):
# indices cannot have gradient
a = torch.randint(low=-1, high=1, size=(2,1,1,1,1,1), dtype=torch.int64, requires_grad=False)
b = torch.randint(high=1, size=(1,3,1,1,1,1), dtype=torch.int64, requires_grad=False)
c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False)
e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False)
i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) for tor in [a,b,c,d,e]]
helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,b,c,d,e], lambda x: x[i,j,k,o,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[:,b,c,d,e], lambda x: x[:,j,k,o,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[:,b,c,d,:], lambda x: x[:,j,k,o,:])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,b,...], lambda x: x[i,j,...])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,...,e], lambda x: x[i,...,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[...,c,:,e], lambda x: x[...,k,:,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[1,:,10:11,d,0:2], lambda x: x[1,:,10:11,o,0:2])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[1,4,c,d,2], lambda x: x[1,4,k,o,2])
helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)], lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)])
helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])], lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])])
def test_gather(self):
nda = np.random.randn(4,5,6,9,5).astype(np.float32)
ten = Tensor(nda, requires_grad=True)
tor = torch.tensor(nda, requires_grad=True)
c = np.random.randint(low=-4, high=4, size=[3,4,5]).astype(np.int32)
a = Tensor(c, requires_grad=False)
b = torch.tensor(c, requires_grad=False)
helper_test_op([], lambda: tor[b,:,:,:,:], lambda: ten.gather(a, dim=0))
helper_test_op([], lambda: tor[:,b,:,:,:], lambda: ten.gather(a, dim=1))
helper_test_op([], lambda: tor[:,:,b,:,:], lambda: ten.gather(a, dim=2))
helper_test_op([], lambda: tor[:,:,:,b,:], lambda: ten.gather(a, dim=3))
helper_test_op([], lambda: tor[:,:,:,:,b], lambda: ten.gather(a, dim=4))
ta = Tensor(c, requires_grad=True)
tb = torch.tensor(c, requires_grad=True, dtype=torch.float32)
self.helper_test_exception([], lambda: tor[tb,:,:,:,:].sum().backward(), lambda: ten.gather(ta, dim=0).sum().backward(), expected=(IndexError, RuntimeError)) # torch raises IndexError, Tensor raises RuntimeError
# indices cannot have gradient
# indices cannot be negative (torch gather)
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1))
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2))
helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError))
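For reference, the torch.gather these later assertions use as ground truth has, for dim=0, the semantics out[i][j] = x[index[i][j]][j]; a tiny worked case:

import torch
x = torch.tensor([[1., 2.], [3., 4.]])
idx = torch.tensor([[0, 1], [1, 0]])
# out[i][j] = x[idx[i][j]][j] for dim=0
assert torch.equal(x.gather(0, idx), torch.tensor([[1., 4.], [3., 2.]]))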
def test_scaled_product_attention(self):
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))

test/test_symbolic_ops.py (new file, 113 lines)
View File

@@ -0,0 +1,113 @@
import unittest
from tinygrad.shape.symbolic import Variable
from tinygrad.helpers import getenv, CI
from tinygrad.tensor import Tensor, Device
import numpy as np
@unittest.skipIf(getenv("ARM64"), "ARM64 is not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG"], f"{Device.DEFAULT} is not supported")
class TestSymbolicOps(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi)).reshape(3, i).cpu().numpy()
expected = f(a).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_add(self):
def f(a, b): return (a+b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_matmul(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_matmul_same_var_different_val(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
a = Tensor.rand(3, 4)
b = Tensor.rand(7, 5)
with self.assertRaises(AssertionError):
f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
@unittest.skipIf(Device.DEFAULT == "CLANG" and CI, "broken on CLANG CI")
def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
q = Tensor.rand(2, 1, 4, 8)
k = Tensor.rand(2, i, 4, 8)
v = Tensor.rand(2, i, 4, 8)
symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).cpu().numpy()
expected = f(q, k, v).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(2, 3)
symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_cat_dim1(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, 2)
symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_cat_dim0_two_vars(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(j, 3)
symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_cat_dim1_two_vars(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_two_vars_plus1(self):
def f(a, b): return (a@b+1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
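The pattern every test here exercises: build a tensor with a concrete size, reshape the varying dimension to a Variable carrying its allowed range, run the op, then reshape back to concrete for comparison. The payoff is one symbolic kernel covering the whole range instead of a recompile per size. A sketch mirroring test_plus1 above:

from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor

def f(a): return (a + 1).realize()
vi = Variable("i", 1, 10)
for i in (3, 7):                                    # the same compiled kernel serves both sizes
    a = Tensor.rand(3, i)
    out = f(a.reshape(3, vi)).reshape(3, i).cpu().numpy()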

View File

@@ -26,17 +26,18 @@ class TestSymbolic(unittest.TestCase):
i = Variable("i", 1, 5)
j = Variable("j", 1, 5)
k = Variable("k", 1, 5)
t1 = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t1.lazydata.st
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t.lazydata.st
assert st.shape == (i+j+k, 4)
assert st.real_strides() == (4, 1)
i = Variable("i", 1, 5)
j = Variable("j", 1, 5)
k = Variable("k", 1, 5)
t1 = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t1.lazydata.st
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
assert st.shape == (2*i+3, 3)
assert st.real_strides() == (3, 1)
class TestSymbolicReshape(unittest.TestCase):
def test_reshape_into_symbols_simple(self):

View File

@@ -80,14 +80,14 @@ class TestFloatUOps(TestUOps):
def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c)
# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM" or getenv('ARM64', False), "only test for compiled backends, broken on some")
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32)
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
@unittest.skipIf(Device.DEFAULT == "CLANG", "broken in CLANG")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), dtypes.bool)
if __name__ == '__main__':

View File

@@ -283,6 +283,7 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
assert idx0 // (i*3) == 0
assert i // i == 1
def test_node_mod_node(self):
i = Variable("i", 1, 10)

View File

@@ -7,8 +7,8 @@ import functools
import math
from collections import defaultdict
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
class Register(NamedTuple):
nm:str
@@ -37,9 +37,10 @@ class AssemblyLanguage:
tor: Dict[Any, Register] = {}
ins: List[AssemblyInstruction] = []
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
def newreg(self, tok, dtype=dtypes.float32, scalar=False):
if isinstance(tok, Token): dtype = tok.dtype # this
self.tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes._float4:
for off in range(4):
self.tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
@@ -96,8 +97,8 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
#TODO: Do not use clear()
lang.ins.clear()
lang.tor.clear()
lang.cnts.clear()
buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
@@ -117,7 +118,7 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) #FIXME: what should valid be here?
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
elif uop == UOps.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
@@ -129,7 +130,9 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
elif args[1] == "global+local":
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
elif args[1] == 'local':
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
elif uop == UOps.CAST and newvar is not None:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = lang.newreg(newvar)
@@ -155,10 +158,11 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
elif uop == UOps.LOAD and newvar is not None:
if isinstance(args, ConstOp):
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.invalid_value))
reg = lang.newreg(newvar, dtype=newvar.dtype)
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args.invalid_value))
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value))
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args.value))
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
else:
@@ -173,16 +177,16 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None)))
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
idx, treg, off = lang.addr_w_offset(args)
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None)))
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
# define registers
lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, lang.type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
if DEBUG >= 4:
for tins in lang.ins: print(tins)
return global_size, local_size

View File

@@ -22,7 +22,7 @@ def specialize_to_arm64(fn_nm, asm):
ins = []
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
s_regs = ['s' + str(i) for i in reversed(range(3,30))]
type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
@@ -137,12 +137,12 @@ def specialize_to_arm64(fn_nm, asm):
mov_imm(arg[0], "x15")
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
elif uop == UOps.STORE:
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
elif uop == UOps.COND_BRANCH:

View File

@@ -0,0 +1,98 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch
dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
def render_cast(ins, inp, out):
if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
elif out.dtype == dtypes.bool:
ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
else:
round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXLanguage(AssemblyLanguage):
supports_constant_folding: bool = True
def specialize_to_ptx(lang, function_name):
param_cnt = 0
ins = []
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
for uop, out, vin, arg in lang.ins:
if uop == UOps.ENDLOOP:
ins.append("bar.sync 0;")
elif uop == UOps.DEFINE_LOCAL:
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
elif uop == UOps.SPECIAL:
if arg.startswith('data'):
param_cnt += 1
ins.append(f"ld.param.u64 {out}, [{arg}];")
# TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
# ins.append(f"cvta.to.global.u64 {out}, {out};")
elif arg.startswith('gid'):
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
elif arg.startswith('lid'):
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
elif uop == UOps.ALU:
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
else:
otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
if arg == TernaryOps.WHERE:
reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
vin = vin[1:] + [reg]
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
elif uop == UOps.LOAD:
if arg.__class__ in (int, float):
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
reg = lang.newreg((out, dt[0]), dtype=dt[1])
ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
render_cast(ins, reg, out)
else:
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
elif uop == UOps.STORE:
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
if arg[2] == dtypes.bool != vin[1].dtype:
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
render_cast(ins, vin[1], prereg)
else: prereg = vin[1]
reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
render_cast(ins, prereg, reg)
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
else:
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
elif uop == UOps.CAST:
render_cast(ins, vin[0], out)
elif uop == UOps.LABEL:
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
ins = ins_prefix + ins
ins += ["ret;", "}"]
return '\n'.join(ins)
def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
lang = PTXLanguage()
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
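
For reference, here is a standalone sketch (plain Python, illustrative names only, not the real dtype objects) of the rounding-modifier choice render_cast makes above: float-to-int conversions get .rzi (round toward zero, integer result), int-to-float and narrowing float-to-float conversions get .rz, and widening or same-type conversions need no modifier.

def round_mod(inp_is_float, inp_size, out_is_float, out_size):
  if not out_is_float and inp_is_float: return ".rzi"  # e.g. float32 -> int32
  if out_is_float and (not inp_is_float or inp_size > out_size): return ".rz"  # int32 -> float32, float32 -> float16
  return ""

print(round_mod(True, 4, False, 4))  # '.rzi' (float32 -> int32)
print(round_mod(False, 4, True, 4))  # '.rz'  (int32 -> float32)
print(round_mod(True, 4, True, 2))   # '.rz'  (float32 -> float16)
print(round_mod(True, 2, True, 4))   # ''     (float16 -> float32, widening)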

View File

@@ -9,7 +9,7 @@ from tinygrad.lazy import LazyBuffer
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps
from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape, View
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename
VariableOrNum = Union[Variable, NumNode, Node]
# bottom ones are asm only
@@ -301,6 +301,9 @@ class Linearizer:
# add global buffers
for buf,name in self.arg_bufs.items():
self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype))
# add variables from symbolic shapes
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32))
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
@@ -317,7 +320,7 @@ class Linearizer:
if DEBUG >= 3: self.printbufs()
# kernel name (before late upcast)
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape])
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
# parse AST
@@ -548,7 +551,7 @@ class Linearizer:
assert len(colors) == self.shape_len, "colors size mismatch"
return colors
def colored_shape(self) -> str: return ' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors()))
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
def printbufs(self, prefix=""):
for i in range(len(self.sts)):
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)

View File

@@ -162,7 +162,7 @@ def hand_coded_optimizations(k:Linearizer):
# early exit
return
if k.opts.has_local:
if k.opts.has_local and all(isinstance(s, int) for s in k.sts[0].shape[:k.first_reduce]):
# are we grouping? (requires local shape support)
if not k.float4_axis(0) and k.first_reduce <= 2 and k.first_reduce + 1 <= k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
@@ -204,8 +204,8 @@ def hand_coded_optimizations(k:Linearizer):
while prod(k.sts[0].shape[:k.first_reduce]) >= 1024:
xb_choices = []
for axis, upcast_amount in itertools.product(range(k.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
# if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
if len(xb_choices):
xb_choices = sorted(xb_choices)
@@ -219,7 +219,7 @@ def hand_coded_optimizations(k:Linearizer):
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
if k.first_reduce < (k.shape_len-k.upcasted) and (len(list(k.shape_offsets(k.full_buf_index))) <= 4 or not any(r for _,_,r in k.upcasted_axis(k.full_buf_index))):
if (s:=k.full_unupcasted_shape[-1]) <= 32:
if (s:=k.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
k.upcast()
# if it's small, upcast a second reduce dimension too
if k.first_reduce < (k.shape_len-k.upcasted) and s <= 3 and k.full_unupcasted_shape[-1] <= 3: k.upcast()

View File

@@ -3,7 +3,7 @@ import os, functools, platform, time, re, contextlib
from weakref import KeyedRef, ref
from _weakref import _remove_dead_weakref # type: ignore
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
from math import prod # noqa: F401 # pylint:disable=unused-import
ShapeType = Tuple[int, ...]
@@ -22,6 +22,10 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def merge_dicts(ds:Iterable[Dict]) -> Dict:
kvs = set([(k,v) for d in ds for k,v in d.items()])
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for k,v in kvs}
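
A quick usage sketch of the merge_dicts helper above (assuming this revision's tinygrad.helpers is importable; the example values are made up): merging succeeds only when every dict agrees on the value of any shared key.

from tinygrad.helpers import merge_dicts

print(merge_dicts([{"a": 1}, {"b": 2}, {"a": 1}]))  # {'a': 1, 'b': 2} (order may vary)
try: merge_dicts([{"a": 1}, {"a": 2}])              # same key, conflicting values
except AssertionError as e: print("refused:", e)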
@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
@@ -83,11 +87,11 @@ class ImageDType(DType):
class dtypes:
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes._half4, dtypes._float4)
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
@staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod
@@ -115,6 +119,7 @@ class dtypes:
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
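
Quick checks of the widened dtype predicates above, as a sketch assuming this revision of tinygrad is importable:

from tinygrad.helpers import dtypes

assert dtypes.is_int(dtypes.int16) and dtypes.is_int(dtypes.uint64)  # 16-bit and unsigned widths now count as int
assert dtypes.is_unsigned(dtypes.uint16)                             # uint16 now counts as unsigned
assert dtypes.is_float(dtypes.float64)                               # float64 now counts as float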
@@ -125,6 +130,7 @@ class GlobalCounters:
time_sum_s: ClassVar[float] = 0.0
kernel_count: ClassVar[int] = 0
mem_used: ClassVar[int] = 0 # NOTE: this is not reset
mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None

View File

@@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Any, Dict, cast, Union
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
import functools, itertools
from tinygrad.helpers import DEBUG, DType
@@ -12,7 +12,7 @@ class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
self.cnt: int = 0
self.jit_cache: List[Tuple[Callable, Any]] = [] # TODO: Any should be List[RawBuffer], but this fails
self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]]]] = []
self.ret: Any = None
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], int, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_size, expected_type)
@@ -29,7 +29,8 @@ class TinyJit:
for (j,i),(input_name, expected_size, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name].size == expected_size and input_rawbuffers[input_name].dtype == expected_type, f"size or type mismatch in JIT, {input_rawbuffers[input_name]} != <{expected_size}, {expected_type}>"
self.jit_cache[j][1][i] = input_rawbuffers[input_name]
for prg, args in self.jit_cache: prg(args, jit=True)
for prg, pargs in self.jit_cache: # type: Callable, List[Optional[RawBuffer]]
prg(pargs, jit=True)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 1:
GlobalCounters.cache = []
@@ -40,10 +41,10 @@ class TinyJit:
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# get the inputs for replacement
for j,(prg,args) in enumerate(self.jit_cache): # pylint: disable=E1133
for i,a in enumerate(args):
for j_,(_,pargs_) in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]]]]
for i,a in enumerate(pargs_):
if a in input_rawbuffers.values():
self.input_replace[(j,i)] = [(k, v.size, v.dtype) for k,v in input_rawbuffers.items() if v == a][0]
self.input_replace[(j_,i)] = [(k, v.size, v.dtype) for k,v in input_rawbuffers.items() if v == a][0]
#if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
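
To make the replay loop above concrete, here is a toy model of input_replace with plain strings standing in for kernels and RawBuffers (all names illustrative): the captured argument lists are patched with the new inputs, every kernel replays, and the input slots are cleared again so the buffers can be freed.

jit_cache = [("kernel0", [None, "w0"]), ("kernel1", ["w1", None])]  # None marks an input slot
input_replace = {(0, 0): "x", (1, 1): "x"}  # (kernel idx, arg idx) -> input name

def replay(inputs):
  for (j, i), name in input_replace.items(): jit_cache[j][1][i] = inputs[name]  # patch inputs in
  for prg, pargs in jit_cache: print(prg, pargs)  # stand-in for prg(pargs, jit=True)
  for (j, i) in input_replace: jit_cache[j][1][i] = None  # drop the references after the run

replay({"x": "xbuf"})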

View File

@@ -2,8 +2,9 @@ from __future__ import annotations
import functools, time
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
from tinygrad.shape.shapetracker import MovementOps
from tinygrad.shape.symbolic import Variable, sym_infer
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
if TYPE_CHECKING:
from tinygrad.lazy import LazyBuffer
@@ -131,20 +132,24 @@ class ASTRunner:
self.clprg = runtime(self.name, self.prg, **self.runtime_args)
return self
def exec(self, bufs, force_wait=False, optimizing=False) -> Optional[float]:
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs))
return self(rawbufs, force_wait=force_wait)
return self(rawbufs, var_vals, force_wait=force_wait)
def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]:
if et := self.clprg((self.global_size + [1]*(3-len(self.global_size))) if self.global_size is not None else None,
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
if var_vals is None: var_vals = {}
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
if et := self.clprg((global_size + [1]*(3-len(global_size))) if global_size is not None else None,
(local_size + [1]*(3-len(local_size))) if local_size is not None else None,
*rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
op_estimate = sym_infer(self.op_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += self.op_estimate
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += self.mem_estimate
if getenv("EARLY_STOPPING") and GlobalCounters.kernel_count == getenv("EARLY_STOPPING"): exit(0)
return et
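
A toy version of the launch-size resolution above: sym_infer walks a small expression tree (illustrative tuples standing in for tinygrad's Node types) so a symbolic global size becomes concrete right before the kernel call.

def toy_sym_infer(n, var_vals):
  if isinstance(n, int): return n
  if isinstance(n, str): return var_vals[n]  # a Variable, looked up in var_vals
  if n[0] == "mul": return toy_sym_infer(n[1], var_vals) * toy_sym_infer(n[2], var_vals)
  if n[0] == "sum": return sum(toy_sym_infer(x, var_vals) for x in n[1])
  raise NotImplementedError(n)

global_size = [("mul", "seqlen", 4), 8]  # one symbolic dim, one static
print([toy_sym_infer(s, {"seqlen": 16}) for s in global_size])  # [64, 8]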
@@ -178,9 +183,11 @@ class Compiled:
output.realized = None
break
# we don't have an output buffer, we have to create it
# we don't have an output buffer, we have to create it (at max size if it has a symbolic shape)
if not output.realized:
output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)
output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
# update the output var_vals from src
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, output, self.linearizer_opts)
@@ -200,5 +207,5 @@ class Compiled:
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
prg.exec(k.bufs)
prg.exec(k.bufs, var_vals=output.st.var_vals)
return output.realized

View File

@@ -3,7 +3,7 @@ import math
from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
# div is different in cl than python
render_cl = render_python.copy()
@@ -17,6 +17,7 @@ class CStyleLanguage(NamedTuple):
buffer_prefix: str = ""
buffer_suffix: str = ""
smem_prefix: str = ""
arg_int_prefix: str = ""
barrier: str = ""
gid: List[str] = []
lid: List[str] = []
@@ -52,7 +53,7 @@ class CStyleLanguage(NamedTuple):
def render_const(self, x:Union[float,int], var_dtype) -> str:
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
else: val = f"{x}" + ("f" if isinstance(x, float) else "")
else: val = f"{x}f" if dtypes.is_float(var_dtype) and isinstance(x, float) else f"{int(x)}"
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
# returns a str expression of the loaded value with the output type
@@ -69,7 +70,7 @@ class CStyleLanguage(NamedTuple):
def render_local(self, name:str, size:int):
return self.smem_prefix + f"float {name}[{size}];"
def render_for(self, expr: str, _min:int, _max:int) -> str:
def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
def render_conditional(self, cond: str, x:str, y:str) -> str:
@@ -78,6 +79,7 @@ class CStyleLanguage(NamedTuple):
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
self.arg_int_prefix if dtype == dtypes._arg_int32 else
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
prg = ''.join([f"{self.kernel_prefix} void {function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
@@ -128,7 +130,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
else:
if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)

View File

@@ -22,17 +22,18 @@ code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)),
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=('fast',)),
BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)),
BinaryOps.CMPLT: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("<", x, y, flags=('fast',)), ir.FloatType()),
BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=('fast',)),
BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=('fast',)),
BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=('fast',)),
BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=('fast',)),
BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=('fast',)), ir.FloatType()),
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y),
TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
}
dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
def cast(bb, val, input_type, output_type):
if input_type == output_type: return val
@@ -44,6 +45,8 @@ def cast(bb, val, input_type, output_type):
val = bb[-1].sext(val, ir.IntType(32))
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].bitcast(val, ir.FloatType())
elif input_type == dtypes.float64:
val = bb[-1].fptrunc(val, ir.FloatType())
else:
val = bb[-1].fpext(val, ir.FloatType())
return val
@@ -55,6 +58,8 @@ def cast(bb, val, input_type, output_type):
val = bb[-1].bitcast(val, ir.IntType(32))
val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].trunc(val, ir.IntType(16))
elif output_type == dtypes.float64:
val = bb[-1].fpext(val, ir.DoubleType())
else:
val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
return val
@@ -114,11 +119,11 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
assert newvar is not None and isinstance(args, (MemOp, ConstOp))
valid = args.valid.render(render_llvm, bb[-1])
if isinstance(args, ConstOp):
assert newvar.dtype == dtypes.float, "newvar must be float"
value, invalid_value = [int(args.value), int(args.invalid_value)] if dtypes.is_int(newvar.dtype) else ([bool(args.value), bool(args.invalid_value)] if newvar.dtype == dtypes.bool else [args.value, args.invalid_value]) # type: ignore
if args.valid.min == 0 and args.valid.max == 1:
val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value))
val = bb[-1].select(valid, ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value), ir.Constant(dtype_to_llvm_dtype[newvar.dtype], invalid_value))
else:
val = ir.Constant(ir.FloatType(), args.value if args.valid.min == 1 else args.invalid_value)
val = ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value if args.valid.min == 1 else invalid_value)
# TODO: this is a hack. it shouldn't be const that signals this
reduce_phis.append(newvar)
else:

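A small llvmlite sketch (assuming llvmlite is installed) of the int/float dispatch added to code_for_op above: the same ADD lambda emits an integer add for IntType operands and a fast-math fadd otherwise.

import llvmlite.ir as ir

ADD = lambda builder, x, y: builder.add(x, y) if isinstance(x.type, ir.IntType) else builder.fadd(x, y, flags=('fast',))

module = ir.Module()
fn = ir.Function(module, ir.FunctionType(ir.IntType(32), [ir.IntType(32), ir.IntType(32)]), name="iadd")
builder = ir.IRBuilder(fn.append_basic_block())
builder.ret(ADD(builder, *fn.args))  # IntType operands -> plain integer 'add'
print(module)                        # the printed IR contains 'add i32'
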
View File

@@ -38,7 +38,7 @@ class WGSLLanguage(CStyleLanguage):
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
return prg, global_size[::-1] if len(global_size) else [1], local_size
def render_for(self, expr:str, _min:int, _max:int) -> str:
def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
def render_conditional(self, cond:str, x:str, y:str) -> str:

View File

@@ -1,18 +1,21 @@
import ctypes
import numpy as np
from typing import TypeVar, Type, Any
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters
from collections import defaultdict, deque
from typing import TypeVar, Type, Any, Dict, Deque, Tuple
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
_T = TypeVar("_T")
class RawBuffer: # pylint: disable=abstract-method
def __init__(self, size:int, dtype:DType, buf:Any=None):
def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
self.size: int = size
self.dtype: DType = dtype
self._buf = buf
self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
self._memsz: int = size*dtype.itemsize
self._allocator = allocator
GlobalCounters.mem_used += self._memsz
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
@property
def key(self): return (self.size, self.dtype.key)
@@ -39,7 +42,7 @@ class RawBufferMapped(RawBufferCopyIn):
# this one is simple enough that I moved it out of the runtimes
class RawMallocBuffer(RawBufferMapped):
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
def _buffer(self): return memoryview(self._buf)
class RawBufferCopyInOut(RawBufferCopyIn):
@@ -66,3 +69,43 @@ class RawConst(RawBuffer): # pylint: disable=abstract-method
def buf_is_kernel_arg(x) -> bool:
return x.realized is not None and x.realized.__class__ is not RawConst
class LRUAllocator:
def __init__(self, dev_memsz=(4<<30)):
self.epoch = 0
self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, split by type and size, newest first.
self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
def __del__(self):
for v in self.cached_buffers.values():
for buf, _ in v: self._free_buffer(buf)
def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
return rawbufs.popleft()[0]
def _alloc_buffer(self, size, dtype, device, **kwargs):
self.free_space[device] -= size*dtype.itemsize
while len(self.aging_order[device]) and self.free_space[device] < 0: # On OOM, evict the least recently used cached buffers.
bucket, epoch = self.aging_order[device].popleft()
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
newbuf = self._do_alloc(size, dtype, device, **kwargs)
self.buffer_info[newbuf] = (size, dtype, device)
return newbuf
def _free_buffer(self, buf_to_free):
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
self.buffer_info.pop(buf_to_free)
self._do_free(buf_to_free)
def alloc(self, size, dtype, device='0', **kwargs):
rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
def free(self, buf): # free() just caches the buffer; it may actually be freed later if an allocation runs out of memory.
self.epoch += 1
size, dtype, device = self.buffer_info[buf]
self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
def _do_free(self, buf): pass
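
A toy subclass exercising the LRUAllocator above (a sketch assuming this revision's tinygrad.runtime.lib is importable): alloc() is served from the cache when a buffer with the same (device, size, dtype) key was free()'d earlier, so the backing allocation happens only once.

from tinygrad.helpers import dtypes
from tinygrad.runtime.lib import LRUAllocator

class CountingAllocator(LRUAllocator):
  def __init__(self, dev_memsz): super().__init__(dev_memsz); self.allocs, self.frees = 0, 0
  def _do_alloc(self, size, dtype, device, **kwargs): self.allocs += 1; return object()  # opaque stand-in for a device handle
  def _do_free(self, buf): self.frees += 1

a = CountingAllocator(dev_memsz=1024)
b1 = a.alloc(4, dtypes.float32)  # cache miss: a real allocation
a.free(b1)                       # parked in the cache, not released
b2 = a.alloc(4, dtypes.float32)  # same (device, size, dtype) bufkey: reused
assert b1 is b2 and a.allocs == 1 and a.frees == 0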

View File

@@ -74,8 +74,8 @@ class ClangProgram:
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
else:
self.fxn(*[x._buf for x in args])
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
if wait: return time.monotonic()-st
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)

View File

@@ -2,9 +2,9 @@ import subprocess, time, re, hashlib, tempfile, os, functools
from typing import Optional
import numpy as np
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv, colored
from tinygrad.helpers import DEBUG, getenv, colored, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@@ -47,8 +47,12 @@ if getenv("CUDACPU", 0) == 1:
else:
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
import pycuda.driver as cuda # type: ignore
class CUDAAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
CUDAAlloc = CUDAAllocator(pycuda.driver.Context.get_device().total_memory())
class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize)) # type: ignore
def __init__(self, size, dtype): super().__init__(size, dtype, allocator=CUDAAlloc)
def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x.ravel(), stream) # type: ignore
def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore
@@ -91,5 +95,5 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
__device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
__device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
};
"""))
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
""")) if not getenv("PTX") else fromimport("tinygrad.codegen.assembly_ptx", "uops_to_ptx_asm")
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)

View File

@@ -5,7 +5,7 @@ import pyopencl as cl # type: ignore
from typing import Optional, List
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@@ -17,30 +17,36 @@ ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
if DEBUG >= 5:
early_exec = fromimport("extra.helpers", "enable_early_exec")()
class CLAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs):
if isinstance(dtype, ImageDType):
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
else:
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
return buf
class _CL:
def __init__(self):
cl_platforms = cl.get_platforms()
platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if len(y)]
self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
self.cl_platform = self.devices[0].platform
def post_init(self, device=None):
platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
self.cl_platform = cl.get_platforms()[getenv('CL_PLATFORM', 0)]
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])]
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])]
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
def synchronize(self):
for q in self.cl_queue: q.finish()
CL = _CL()
CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None
if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
def __init__(self, size, dtype, device='0'):
if isinstance(dtype, ImageDType):
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
else:
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
super().__init__(size, dtype, buf)
def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
def _copyin(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
@@ -80,7 +86,7 @@ class CLProgram:
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs]
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
if wait:
e.wait()
@@ -91,9 +97,8 @@ class CLProgram:
return None
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), renderer, CLProgram, CL.synchronize)

View File

@@ -3,7 +3,7 @@ import ctypes, functools
import extra.hip_wrapper as hip
from tinygrad.helpers import DEBUG
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@@ -14,9 +14,14 @@ if DEBUG >= 5:
# The default HIP stream is used for everything.
class HIPAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return hip.hipMalloc(size * dtype.itemsize)
def _do_free(self, buf): hip.hipFree(buf)
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
HIPAlloc = HIPAllocator(hip.hipGetDeviceProperties(hip.hipGetDevice()).totalGlobalMem)
class RawHIPBuffer(RawBufferCopyInOut):
def __init__(self, size, dtype): super().__init__(size, dtype, hip.hipMalloc(size * dtype.itemsize))
def __del__(self): hip.hipFree(self._buf)
def __init__(self, size, dtype): super().__init__(size, dtype, allocator=HIPAlloc)
def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.size * self.dtype.itemsize, 0)
def _copyout(self, x:np.ndarray): hip.hipMemcpy_dtoh(x.ctypes.data, self._buf, self.size * self.dtype.itemsize)

View File

@@ -1,20 +1,26 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, functools
import os, subprocess, pathlib, functools, ctypes
import Metal, Cocoa, libdispatch # type: ignore
from typing import List, Any
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
from tinygrad.helpers import prod, getenv, DEBUG, DType
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
METAL_XCODE = getenv("METAL_XCODE")
class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
def _do_free(self, buf): buf.release()
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
class _METAL:
def __init__(self):
self.mtl_buffers_in_flight: List[Any] = []
self.device = Metal.MTLCreateSystemDefaultDevice()
self.mtl_queue = self.device.newCommandQueue()
self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize())
# TODO: is there a better way to do this?
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
@@ -23,10 +29,8 @@ METAL = _METAL()
class RawMetalBuffer(RawBufferMapped):
def __init__(self, size:int, dtype:DType):
super().__init__(size, dtype, METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared))
def __del__(self):
self._buf.release()
super().__del__()
assert dtype != dtypes.double, f"METAL does not support {dtype.name}"
super().__init__(size, dtype, allocator=METAL.allocator)
def _buffer(self):
METAL.synchronize()
return self._buf.contents().as_buffer(self._buf.length())
@@ -64,7 +68,10 @@ class MetalProgram:
command_buffer = METAL.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
for i,a in enumerate(bufs):
if isinstance(a, RawMetalBuffer): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
elif isinstance(a, int): encoder.setBytes_length_atIndex_((arg:=ctypes.c_int32(a)), ctypes.sizeof(arg), i)
else: raise RuntimeError(f"arg at index {i} has unsupported type {type(a)}")
encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
encoder.endEncoding()
command_buffer.commit()
@@ -74,7 +81,7 @@ class MetalProgram:
METAL.mtl_buffers_in_flight.append(command_buffer)
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))

View File

@@ -6,7 +6,7 @@ from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.lib import RawBuffer
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
inverse_type_map = {v:k for k,v in type_map.items()}
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{

View File

@@ -1,7 +1,7 @@
import numpy as np
import functools
from wgpu.utils._device import get_default_device # type: ignore
from tinygrad.runtime.lib import RawBufferCopyIn
from tinygrad.runtime.lib import RawBufferCopyIn, LRUAllocator
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import Compiled
from tinygrad.codegen.linearizer import LinearizerOptions
@@ -9,32 +9,37 @@ from tinygrad.renderer.cstyle import uops_to_cstyle
from tinygrad.renderer.wgsl import WGSLLanguage
import wgpu # type: ignore
device = get_default_device()
wgpu_device = get_default_device()
class WebGPUProgram:
def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,device.create_shader_module(code=prg)
def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
def __call__(self, global_size, local_size, *bufs, wait=False):
assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)]
bind_group_layout = device.create_bind_group_layout(entries=binding_layouts)
pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)
compute_pipeline = device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
command_encoder = device.create_command_encoder()
bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts)
pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings)
compute_pipeline = wgpu_device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
command_encoder = wgpu_device.create_command_encoder()
compute_pass = command_encoder.begin_compute_pass()
compute_pass.set_pipeline(compute_pipeline)
compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used
compute_pass.dispatch_workgroups(*global_size) # x y z
compute_pass.end()
device.queue.submit([command_encoder.finish()])
wgpu_device.queue.submit([command_encoder.finish()])
class RawWebGPUAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return wgpu_device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
WebGPUAlloc = RawWebGPUAllocator(wgpu_device.limits['max_buffer_size'])
class RawWebGPUBuffer(RawBufferCopyIn):
def __init__(self, size:int, dtype:DType):
assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64], f"dtype {dtype} not supported on WEBGPU"
super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC))
def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64,dtypes.double], f"dtype {dtype} not supported on WEBGPU"
super().__init__(size, dtype, allocator=WebGPUAlloc)
def _copyin(self, x:np.ndarray): wgpu_device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
renderer = functools.partial(uops_to_cstyle, WGSLLanguage())
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, WebGPUProgram)

View File

@@ -59,7 +59,7 @@ class View(ViewInternal):
# generate an expression if you have a single idx variable
def expr_node(self, idx=None) -> Node:
if idx is None: idx = Variable('idx', 0, prod(self.shape)-1)
ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
ret: List[Node] = [Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] if self.offset else []
acc = 1
for d,s in reversed(self.shape_strides):
ret.append(((idx//acc)%d)*s)
@@ -69,7 +69,7 @@ class View(ViewInternal):
# generate an expression if you have a variable or expression for each index
def expr_idxs(self, idxs) -> Node:
assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
return Variable.sum([Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
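
The arithmetic behind expr_idxs above, restated as a standalone sketch with concrete numbers: the flat index is the view's offset plus sum(idx*stride) over the dims whose size isn't 1 and stride isn't 0.

def flat_index(offset, shape, strides, idxs):
  return offset + sum(i*st for i, sh, st in zip(idxs, shape, strides) if sh != 1 and st != 0)

print(flat_index(0, (2, 3), (3, 1), (1, 2)))  # 2x3 row-major view, element (1, 2) -> 0 + 1*3 + 2*1 = 5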
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
@@ -162,7 +162,7 @@ class ShapeTracker:
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
ret[idxs.index(this_dim.a)] = this_dim.b
elif isinstance(this_dim, Variable):
ret[idxs.index(this_dim)] = 1

View File

@@ -65,6 +65,7 @@ class Node:
def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}")
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if self == b: return NumNode(1)
if (b > self).min > 0 and self.min >= 0: return NumNode(0)
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
@@ -262,6 +263,14 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
return create_node(ret)
def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
if isinstance(n, (int, NumNode)): return int(n)
if isinstance(n, Variable): return var_vals[n]
if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
raise NotImplementedError(n)
@functools.lru_cache(maxsize=None)
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
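
The sym_rename trick above in isolation (a standalone sketch): lru_cache inserts the cache entry only after the wrapped call returns, so cache_info().currsize numbers each distinct symbol 0, 1, 2, ... while repeats are served from the cache.

import functools

@functools.lru_cache(maxsize=None)
def toy_sym_rename(s) -> str: return f"s{toy_sym_rename.cache_info().currsize}"

print(toy_sym_rename("i*4"), toy_sym_rename("j+1"), toy_sym_rename("i*4"))  # s0 s1 s0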
render_python: Dict[Type, Callable] = {

View File

@@ -274,12 +274,20 @@ class Tensor:
# - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
# is possible.
# - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
# - Fancy indexing and combined indexing are supported
# - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t. the Tensors passed in -> fancy indexing
# - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
# - The first iteration will expand the dim of self while consecutive iterations will reduce the dim
# - The dims are reduced at sum_dim for each Tensor passed in
# - There's a special case where a permute is needed at the end:
# - if the first Tensor passed in (expand dims) is not at dim 0
# - and the following Tensors do not follow consecutively to the end of the fancy-indexing dims
def __getitem__(self, val):
def normalize_int(e, i, dim_sz):
if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
val = list(val) if isinstance(val, tuple) else [val]
if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape):
if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in val)) > len(self.shape):
raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
orig_slices = list(val)
ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis]
@@ -290,6 +298,8 @@ class Tensor:
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
else:
orig_slices += [slice(None)] * (len(self.shape) - num_slices)
tensor_found = [(i,v) for i, v in enumerate(orig_slices) if isinstance(v, Tensor)]
orig_slices = [slice(None, None, None) if isinstance(v, Tensor) else v for v in orig_slices]
valid_slices = list(filterfalse(lambda x: x is None, orig_slices))
valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
@@ -315,21 +325,46 @@ class Tensor:
final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
sliced_tensor = reshaped_tensor.shrink(final_slice)
final_shape = []
sub = [0] * len(tensor_found)
it_shape = iter(new_shape)
for i in orig_slices:
if isinstance(i, (int, slice)):
for i,s in enumerate(orig_slices):
if isinstance(s, (int, slice)):
dim_shape = next(it_shape)
if isinstance(i, slice): final_shape.append(dim_shape)
else: # i is None
if isinstance(s, slice): final_shape.append(dim_shape)
elif tensor_found:
for i_ in range(len(tensor_found)):
if tensor_found[i_][0] > i: sub[i_] -= 1
else: # s is None
final_shape.append(1)
return sliced_tensor.reshape(tuple(final_shape)) # Reshape
ret = sliced_tensor.reshape(tuple(final_shape)) # Reshape
if tensor_found: # Fancy/tensor indexing
for i,s in enumerate(sub): tensor_found[i] = (tensor_found[i][0]+s, tensor_found[i][1])
dim = [i[0] for i in tensor_found]
idx = [i[1].sign().contiguous().__neg__().contiguous().relu() * ret.shape[i[0]] + i[1] for i in tensor_found] # TODO first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
max_dim = max(idx, key=lambda i: i.ndim).ndim
idx = [i if i.ndim == max_dim else i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
sum_dim = [d if n==0 else d+i.ndim-n for n,(d,i) in enumerate(zip(dim,idx))]
new_idx = idx[0].reshape(*[1]*sum_dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-sum_dim[0]-1))
arange = Tensor.arange(ret.shape[sum_dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*sum_dim[0], ret.shape[sum_dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-sum_dim[0]-1))
ret = (ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*idx[0].ndim, *ret.shape[sum_dim[0]+1:]) * (arange == new_idx)).sum(sum_dim[0])
for idx_,d in zip(idx[1:],sum_dim[1:]):
new_idx = idx_.reshape(*[1]*sum_dim[0], *idx_.shape, *[1]*(ret.ndim-sum_dim[0]-idx_.ndim))
arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
ret = ((new_idx == arange) * ret).sum(d)
if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
order = list(range(ret.ndim))
order = order[dim[0]:dim[0]+idx[0].ndim] + order[:dim[0]] + order[dim[0]+idx[0].ndim:]
ret = ret.permute(order=order)
return ret
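
A usage sketch of the tensor-indexing path above (assuming this revision of tinygrad): an integer Tensor index gathers rows via the arange-compare-multiply-sum trick described in the comments.

from tinygrad.tensor import Tensor

t = Tensor([[1, 2], [3, 4], [5, 6]])
print(t[Tensor([2, 0])].numpy())  # rows 2 and 0 -> [[5. 6.] [1. 2.]]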
def gather(self, idx, dim):
idx = (idx < 0).where(idx+self.shape[dim], idx) # Turn neg idx pos
new_self = self.reshape(*self.shape[:dim+1], *[1]*idx.ndim, *self.shape[dim+1:])
arange = Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim, self.shape[dim], *[1]*idx.ndim, *[1]*(self.ndim-dim-1))
new_idx = idx.reshape(*[1]*dim, 1, *idx.shape, *[1]*(self.ndim-dim-1))
return (new_self * (arange == new_idx)).sum(dim)
def gather(self: Tensor, idx: Tensor, dim: int):
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dims of idx.shape must be no larger than the corresponding dims of self.shape"
if dim < 0: dim += self.ndim
idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
permarg = list(range(self.ndim))
permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
return ((idx == Tensor.arange(self.shape[dim])) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
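
And a usage sketch for the reworked gather (assuming this revision): the semantics mirror torch.gather, i.e. for dim=1, out[i][j] = self[i][idx[i][j]].

from tinygrad.tensor import Tensor

t = Tensor([[1, 2], [3, 4]])
print(t.gather(Tensor([[0, 0], [1, 0]]), dim=1).numpy())  # [[1. 1.] [4. 3.]]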
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim